/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify.algorithms.trees;

import edu.cmu.minorthird.classify.BasicDataset;
import edu.cmu.minorthird.classify.BatchBinaryClassifierLearner;
import edu.cmu.minorthird.classify.BatchClassifierLearner;
import edu.cmu.minorthird.classify.BinaryClassifier;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.Explanation;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.algorithms.trees.DecisionTreeLearner;
import edu.cmu.minorthird.util.ProgressCounter;
import edu.cmu.minorthird.util.StringUtil;
import edu.cmu.minorthird.util.gui.ComponentViewer;
import edu.cmu.minorthird.util.gui.VanillaViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
import java.awt.Component;
import java.awt.GridBagConstraints;
import java.awt.GridBagLayout;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import javax.swing.JComponent;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import org.apache.log4j.Logger;

/**
 * Boosting learner that repeatedly trains a base learner on a reweighted copy of the data and
 * returns the sum of the resulting binary classifiers' scores.
 *
 * <p>Each round, every example's weight is divided by {@link #discountFactor(double, double)}
 * of its label and the new classifier's score, then all weights are renormalized to sum to 1.
 */
public class AdaBoost
extends BatchBinaryClassifierLearner {
    private static final Logger log = Logger.getLogger(AdaBoost.class);

    /** Learner trained each round; its output is cast to {@link BinaryClassifier}. */
    private BatchClassifierLearner baseLearner;

    /** Number of boosting rounds to run (training may stop earlier if weights degenerate). */
    private int maxRounds;

    /** Creates a booster running 10 rounds over a default {@link DecisionTreeLearner}. */
    public AdaBoost() {
        this(new DecisionTreeLearner(), 10);
    }

    /**
     * @param baseLearner learner trained on the reweighted data each round; its
     *                    {@code batchTrain} result must be a {@link BinaryClassifier}
     * @param maxRounds   maximum number of boosting rounds
     */
    public AdaBoost(BatchClassifierLearner baseLearner, int maxRounds) {
        this.baseLearner = baseLearner;
        this.maxRounds = maxRounds;
    }

    /** @return the maximum number of boosting rounds */
    public int getMaxRounds() {
        return this.maxRounds;
    }

    /** @param n the maximum number of boosting rounds */
    public void setMaxRounds(int n) {
        this.maxRounds = n;
    }

    /** @return the learner trained in each boosting round */
    public BatchClassifierLearner getBaseLearner() {
        return this.baseLearner;
    }

    /** @param learner the learner to train in each boosting round */
    public void setBaseLearner(BatchClassifierLearner learner) {
        this.baseLearner = learner;
    }

    /**
     * Runs up to {@code maxRounds} rounds of boosting over {@code dataset}.
     *
     * <p>Boosting stops early (keeping the classifiers learned so far) if the total example
     * weight after reweighting is zero, NaN or infinite, since normalizing by it would turn
     * every weight into NaN.
     *
     * @param dataset training data; its examples are copied, never reweighted in place
     * @return a classifier whose score is the sum of the per-round classifiers' scores
     * @throws ClassCastException if the base learner does not return a {@link BinaryClassifier}
     */
    public Classifier batchTrain(Dataset dataset) {
        // Rebuild each example from its instance and label so the reweighting below only touches
        // the copies. NOTE(review): any weight on the input examples is presumably not carried
        // over (depends on the Example(Instance, label) constructor's default) — confirm.
        BasicDataset weightedData = new BasicDataset();
        for (Iterator<Example> it = dataset.iterator(); it.hasNext();) {
            Example e = it.next();
            weightedData.add(new Example(e.asInstance(), e.getLabel()));
        }
        List<BinaryClassifier> classifiers = new ArrayList<BinaryClassifier>(this.maxRounds);
        ProgressCounter pc = new ProgressCounter("boosting", "round", this.maxRounds);
        for (int t = 0; t < this.maxRounds; ++t) {
            log.info("Adaboost is starting round " + (t + 1) + "/" + this.maxRounds);
            log.info("Learning classifier with " + this.baseLearner);
            BinaryClassifier c = (BinaryClassifier)this.baseLearner.batchTrain(weightedData);
            classifiers.add(c);
            if (log.isDebugEnabled()) {
                log.debug("classifier is " + c);
            }
            log.info("Generating new distribution");
            // Discount each example by how well c scored it, accumulating the new total weight.
            double z = 0.0;
            boolean sawExample = false;
            for (Iterator<Example> it = weightedData.iterator(); it.hasNext();) {
                Example xk = it.next();
                double yk = xk.getLabel().numericLabel();
                double yhatk = c.score(xk);
                xk.setWeight(xk.getWeight() / this.discountFactor(yk, yhatk));
                z += xk.getWeight();
                sawExample = true;
            }
            // A zero/NaN/infinite total (e.g. all weights underflowed after a very confident
            // classifier) would make every normalized weight NaN; stop instead of training on that.
            if (sawExample && !(z > 0.0 && !Double.isInfinite(z))) {
                log.warn("Adaboost stopping after round " + (t + 1)
                        + ": total example weight is " + z);
                break;
            }
            // Renormalize so the weights form a distribution summing to 1.
            for (Iterator<Example> it = weightedData.iterator(); it.hasNext();) {
                Example e = it.next();
                e.setWeight(e.getWeight() / z);
            }
            pc.progress();
        }
        pc.finished();
        return new BoostedClassifier(classifiers);
    }

    /**
     * Returns the divisor applied to an example's weight after a round.
     *
     * <p>This implementation returns {@code exp(y * yhat)}, so examples whose score agrees in
     * sign with their label lose weight and disagreeing ones gain weight.
     *
     * @param y    the example's numeric label (presumably +1/-1 — confirm against ClassLabel)
     * @param yhat the score the current round's classifier gave the example
     * @return the factor the example's weight is divided by
     */
    protected double discountFactor(double y, double yhat) {
        return Math.exp(y * yhat);
    }

    /** Shows each member classifier of a {@link BoostedClassifier} stacked vertically. */
    private static class BoostedClassifierViewer
    extends ComponentViewer {
        static final long serialVersionUID = 20080609L;

        private BoostedClassifierViewer() {
        }

        public JComponent componentFor(Object o) {
            BoostedClassifier bc = (BoostedClassifier)o;
            JPanel panel = new JPanel();
            panel.setLayout(new GridBagLayout());
            int ypos = 0;
            for (Classifier c : bc.classifiers) {
                GridBagConstraints gbc = new GridBagConstraints();
                gbc.fill = GridBagConstraints.HORIZONTAL;
                gbc.weighty = 0.0;
                gbc.weightx = 0.0;
                gbc.gridx = 0;
                gbc.gridy = ypos++;
                // Use the member's own GUI when it has one, otherwise a generic viewer.
                Viewer subview = c instanceof Visible ? ((Visible)c).toGUI() : new VanillaViewer(c);
                subview.setSuperView(this);
                panel.add((Component)subview, gbc);
            }
            JScrollPane scroller = new JScrollPane(panel);
            scroller.setHorizontalScrollBarPolicy(JScrollPane.HORIZONTAL_SCROLLBAR_AS_NEEDED);
            return scroller;
        }
    }

    /**
     * Ensemble produced by {@link AdaBoost#batchTrain}: its score is the plain (unweighted) sum
     * of the member classifiers' scores.
     */
    private static class BoostedClassifier
    extends BinaryClassifier
    implements Serializable,
    Visible {
        static final long serialVersionUID = 20080609L;
        private List<BinaryClassifier> classifiers;

        public BoostedClassifier(List<BinaryClassifier> classifiers) {
            this.classifiers = classifiers;
        }

        @Override
        public double score(Instance instance) {
            double totalScore = 0.0;
            for (BinaryClassifier member : this.classifiers) {
                totalScore += member.score(instance);
            }
            return totalScore;
        }

        @Override
        public String explain(Instance instance) {
            StringBuilder buf = new StringBuilder();
            double totalScore = 0.0;
            for (BinaryClassifier member : this.classifiers) {
                // Score each member once and reuse it for both the total and the report.
                double memberScore = member.score(instance);
                totalScore += memberScore;
                buf.append("score of " + member + ": " + memberScore + "\n");
                buf.append(StringUtil.indent(1, member.explain(instance)) + "\n");
            }
            buf.append("total score: " + totalScore);
            return buf.toString();
        }

        @Override
        public Explanation getExplanation(Instance instance) {
            Explanation.Node top = new Explanation.Node("AdaBoost Explanation");
            double totalScore = 0.0;
            for (BinaryClassifier member : this.classifiers) {
                double memberScore = member.score(instance);
                totalScore += memberScore;
                Explanation.Node score = new Explanation.Node("score of " + member);
                score.add(new Explanation.Node(memberScore + " "));
                score.add(member.getExplanation(instance).getTopNode());
                top.add(score);
            }
            top.add(new Explanation.Node("total score: " + totalScore));
            return new Explanation(top);
        }

        @Override
        public String toString() {
            StringBuilder buf = new StringBuilder("[boosted classifier:\n");
            for (BinaryClassifier member : this.classifiers) {
                buf.append(member.toString() + "\n");
            }
            buf.append("]");
            return buf.toString();
        }

        @Override
        public Viewer toGUI() {
            BoostedClassifierViewer v = new BoostedClassifierViewer();
            v.setContent(this);
            return v;
        }
    }

    /**
     * Variant whose weight divisor is {@code 1 + exp(y * yhat)} instead of {@code exp(y * yhat)},
     * presumably the logistic ("AdaBoost.L") flavour suggested by the name.
     */
    public static class L
    extends AdaBoost {
        public L() {
        }

        public L(BatchClassifierLearner baseLearner, int maxRounds) {
            super(baseLearner, maxRounds);
        }

        @Override
        protected double discountFactor(double y, double yhat) {
            return 1.0 + Math.exp(y * yhat);
        }
    }
}

