/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify;

import edu.cmu.minorthird.classify.BasicDataset;
import edu.cmu.minorthird.classify.BatchClassifierLearner;
import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Explanation;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.MutableInstance;
import edu.cmu.minorthird.classify.Splitter;
import edu.cmu.minorthird.classify.algorithms.linear.MaxEntLearner;
import edu.cmu.minorthird.classify.algorithms.trees.AdaBoost;
import edu.cmu.minorthird.classify.experiments.CrossValSplitter;
import edu.cmu.minorthird.util.gui.ComponentViewer;
import edu.cmu.minorthird.util.gui.SmartVanillaViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
import java.awt.BorderLayout;
import java.awt.Component;
import java.util.Iterator;
import javax.swing.JComponent;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import javax.swing.border.TitledBorder;
import org.apache.log4j.Logger;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Stacked generalization learner.
 *
 * <p>Training works in two levels:
 * <ol>
 *   <li>The dataset is split with {@link #getSplitter() the splitter}. For each partition, every
 *       inner learner is trained on the training part. The held-out test examples are then
 *       rewritten into "stacked" instances: one numeric feature per (inner classifier, schema
 *       class), whose value is the weight that inner classifier assigns to that class.</li>
 *   <li>The final (level-1) learner is trained on all stacked instances. Each inner learner is
 *       then retrained on the full dataset.</li>
 * </ol>
 * At classification time, the fully-trained inner classifiers produce the stacked features and
 * the final classifier labels them.
 *
 * <p>{@link #setSchema(ExampleSchema)} must be called before {@link #batchTrain(Dataset)}, because
 * the schema's class list defines the stacked feature set.
 */
public class StackedLearner
extends BatchClassifierLearner {
    private static final Logger log = Logger.getLogger(StackedLearner.class);
    private static final boolean DEBUG = false;
    private ExampleSchema schema;
    private BatchClassifierLearner[] innerLearners;
    private BatchClassifierLearner finalLearner;
    private Splitter<Example> splitter;

    /**
     * Stacks a single inner learner under a {@link MaxEntLearner}, using the given splitter to
     * produce held-out predictions.
     */
    public StackedLearner(BatchClassifierLearner innerLearner, Splitter<Example> splitter) {
        this(new BatchClassifierLearner[]{innerLearner}, new MaxEntLearner(), splitter);
    }

    /**
     * Stacks a single inner learner under a {@link MaxEntLearner}, using 3-fold cross-validation
     * to produce held-out predictions.
     */
    public StackedLearner(BatchClassifierLearner innerLearner) {
        this(new BatchClassifierLearner[]{innerLearner}, new MaxEntLearner(), new CrossValSplitter<Example>(3));
    }

    /** Stacks {@link AdaBoost} under a {@link MaxEntLearner} with 3-fold cross-validation. */
    public StackedLearner() {
        this(new BatchClassifierLearner[]{new AdaBoost()}, new MaxEntLearner(), new CrossValSplitter<Example>(3));
    }

    /**
     * General constructor.
     *
     * @param innerLearners level-0 learners whose predictions become the stacked features
     * @param finalLearner level-1 learner trained on the stacked features
     * @param splitter splitter used to obtain held-out predictions from the inner learners
     */
    public StackedLearner(BatchClassifierLearner[] innerLearners, BatchClassifierLearner finalLearner, Splitter<Example> splitter) {
        this.innerLearners = innerLearners;
        this.finalLearner = finalLearner;
        this.splitter = splitter;
    }

    /** Returns the splitter used to produce held-out inner-learner predictions. */
    public Splitter<Example> getSplitter() {
        return this.splitter;
    }

    /** Sets the splitter used to produce held-out inner-learner predictions. */
    public void setSplitter(Splitter<Example> splitter) {
        this.splitter = splitter;
    }

    /**
     * Replaces all inner learners with the single given learner.
     *
     * <p>If a schema has already been set, it is passed on to the new learner. This keeps it
     * consistent with the learners configured by {@link #setSchema(ExampleSchema)}.
     */
    public void setInnerLearner(BatchClassifierLearner learner) {
        this.innerLearners = new BatchClassifierLearner[]{learner};
        if (this.schema != null) {
            learner.setSchema(this.schema);
        }
    }

    /**
     * Returns the sole inner learner.
     *
     * @throws IllegalStateException if there is not exactly one inner learner
     */
    public BatchClassifierLearner getInnerLearner() {
        if (this.innerLearners.length != 1) {
            throw new IllegalStateException("multiple inner learners");
        }
        return this.innerLearners[0];
    }

    /** Records the schema and passes it on to every inner learner and the final learner. */
    @Override
    public final void setSchema(ExampleSchema schema) {
        this.schema = schema;
        for (int i = 0; i < this.innerLearners.length; ++i) {
            this.innerLearners[i].setSchema(schema);
        }
        this.finalLearner.setSchema(schema);
    }

    @Override
    public final ExampleSchema getSchema() {
        return this.schema;
    }

    /**
     * Trains the stacked model on {@code dataset}. See the class comment for the procedure.
     *
     * @return the trained stacked classifier (also stored in {@code this.classifier})
     */
    @Override
    public Classifier batchTrain(Dataset dataset) {
        BasicDataset stackedData = new BasicDataset();
        Classifier[] innerClassifiers = new Classifier[this.innerLearners.length];
        Dataset.Split split = dataset.split(this.splitter);
        int numFolds = split.getNumPartitions();
        for (int k = 0; k < numFolds; ++k) {
            Dataset trainData = split.getTrain(k);
            for (int i = 0; i < this.innerLearners.length; ++i) {
                this.innerLearners[i].reset();
                log.info("training inner learner " + (i + 1) + "/" + this.innerLearners.length + " on fold " + (k + 1) + "/" + numFolds);
                innerClassifiers[i] = this.innerLearners[i].batchTrain(trainData);
            }
            Dataset testData = split.getTest(k);
            log.info("transforming test examples of fold " + (k + 1) + "/" + numFolds);
            // Only examples held out from this fold's inner training are used as level-1 data.
            for (Iterator<Example> it = testData.iterator(); it.hasNext(); ) {
                Example e = it.next();
                stackedData.add(new Example(transformInstance(this.schema, e, innerClassifiers), e.getLabel()));
            }
        }
        log.info("training level-1 learner");
        Classifier finalClassifier = this.finalLearner.batchTrain(stackedData);
        log.info("result is " + finalClassifier);
        for (int i = 0; i < this.innerLearners.length; ++i) {
            // Reset before retraining, matching the per-fold training above.
            this.innerLearners[i].reset();
            log.info("training inner learner " + (i + 1) + "/" + this.innerLearners.length + " on full dataset");
            innerClassifiers[i] = this.innerLearners[i].batchTrain(dataset);
        }
        this.classifier = new StackedClassifier(this.schema, innerClassifiers, finalClassifier);
        return this.classifier;
    }

    /**
     * Builds the level-1 instance for {@code oldInstance}.
     *
     * <p>For inner classifier {@code i} and each schema class {@code c}, the new instance gets
     * the numeric feature {@code {"learner_" + i, "class_" + c}}. Its value is the weight that
     * classifier assigns to {@code c}.
     */
    private static Instance transformInstance(ExampleSchema schema, Instance oldInstance, Classifier[] innerClassifiers) {
        MutableInstance newInstance = new MutableInstance();
        int numClasses = schema.getNumberOfClasses();
        for (int i = 0; i < innerClassifiers.length; ++i) {
            ClassLabel ithPrediction = innerClassifiers[i].classification(oldInstance);
            String learner = "learner_" + i;
            for (int h = 0; h < numClasses; ++h) {
                String className = schema.getClassName(h);
                double w = ithPrediction.getWeight(className);
                newInstance.addNumeric(new Feature(new String[]{learner, "class_" + className}), w);
            }
        }
        return newInstance;
    }

    /**
     * Produces a human-readable explanation of the level-0 stage for {@code oldInstance}.
     *
     * <p>For each inner classifier, the output has three parts:
     * <ul>
     *   <li>the schema class with the highest weight (the first one wins ties);</li>
     *   <li>the per-class weights that become the stacked features;</li>
     *   <li>that classifier's own explanation, given once per classifier.</li>
     * </ul>
     */
    private static String explainTransformedInstance(ExampleSchema schema, Instance oldInstance, Classifier[] innerClassifiers) {
        StringBuilder buf = new StringBuilder();
        int numClasses = schema.getNumberOfClasses();
        for (int i = 0; i < innerClassifiers.length; ++i) {
            ClassLabel ithPrediction = innerClassifiers[i].classification(oldInstance);
            String bestClass = null;
            double bestWeight = Double.NEGATIVE_INFINITY;
            StringBuilder weights = new StringBuilder();
            for (int h = 0; h < numClasses; ++h) {
                String className = schema.getClassName(h);
                double w = ithPrediction.getWeight(className);
                weights.append("  weight for " + className + ": " + w + "\n");
                if (bestClass == null || w > bestWeight) {
                    bestClass = className;
                    bestWeight = w;
                }
            }
            buf.append("learner#" + (i + 1) + " predicts " + bestClass + ":\n");
            buf.append(weights);
            buf.append(innerClassifiers[i].explain(oldInstance)).append("\n");
        }
        return buf.toString();
    }

    /**
     * Classifier produced by {@link StackedLearner#batchTrain(Dataset)}. It feeds the inner
     * classifiers' per-class weights to the final classifier.
     */
    private static class StackedClassifier
    implements Classifier,
    Visible {
        private final ExampleSchema schema;
        private final Classifier[] innerClassifiers;
        private final Classifier finalClassifier;

        public StackedClassifier(ExampleSchema schema, Classifier[] innerClassifiers, Classifier finalClassifier) {
            this.schema = schema;
            this.innerClassifiers = innerClassifiers;
            this.finalClassifier = finalClassifier;
        }

        /** Classifies the stacked transform of {@code instance} with the final classifier. */
        public ClassLabel classification(Instance instance) {
            Instance newInstance = StackedLearner.transformInstance(this.schema, instance, this.innerClassifiers);
            return this.finalClassifier.classification(newInstance);
        }

        /** Returns the weight the stacked classification assigns to {@code classLabelName}. */
        public double score(Instance instance, String classLabelName) {
            return this.classification(instance).getWeight(classLabelName);
        }

        /** Explains the inner classifiers' outputs, followed by the final classifier's explanation. */
        public String explain(Instance instance) {
            StringBuilder buf = new StringBuilder();
            buf.append(StackedLearner.explainTransformedInstance(this.schema, instance, this.innerClassifiers));
            Instance newInstance = StackedLearner.transformInstance(this.schema, instance, this.innerClassifiers);
            buf.append("final classifier:\n");
            buf.append(this.finalClassifier.explain(newInstance));
            return buf.toString();
        }

        /** Wraps {@link #explain(Instance)} in an {@link Explanation}. */
        public Explanation getExplanation(Instance instance) {
            return new Explanation(this.explain(instance));
        }

        /**
         * Returns a viewer showing the final classifier (north) and the inner classifiers
         * (south). The viewer's content is set to this classifier.
         */
        public Viewer toGUI() {
            ComponentViewer v = new ComponentViewer(){
                static final long serialVersionUID = 20071015L;

                public JComponent componentFor(Object o) {
                    // Render only the object being viewed, never the enclosing instance.
                    StackedClassifier sc = (StackedClassifier)o;
                    JPanel mainPanel = new JPanel();
                    mainPanel.setLayout(new BorderLayout());
                    mainPanel.setBorder(new TitledBorder("Stacked Classifier"));
                    JPanel finalPanel = new JPanel();
                    finalPanel.setBorder(new TitledBorder("Final classifier"));
                    SmartVanillaViewer w = new SmartVanillaViewer(sc.finalClassifier);
                    finalPanel.add(w);
                    w.setSuperView(this);
                    mainPanel.add(finalPanel, BorderLayout.NORTH);
                    JPanel innerPanel = new JPanel();
                    innerPanel.setBorder(new TitledBorder("Inner classifier(s)"));
                    for (int i = 0; i < sc.innerClassifiers.length; ++i) {
                        SmartVanillaViewer u = new SmartVanillaViewer(sc.innerClassifiers[i]);
                        innerPanel.add(u);
                        u.setSuperView(this);
                    }
                    mainPanel.add(innerPanel, BorderLayout.SOUTH);
                    return new JScrollPane(mainPanel);
                }
            };
            // Without content, componentFor would never be invoked with this classifier.
            v.setContent(this);
            return v;
        }
    }
}

