/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify.sequential;

import edu.cmu.minorthird.classify.BinaryClassifier;
import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Explanation;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.OnlineBinaryClassifierLearner;
import edu.cmu.minorthird.classify.algorithms.linear.VotedPerceptron;
import edu.cmu.minorthird.classify.sequential.BatchSequenceClassifierLearner;
import edu.cmu.minorthird.classify.sequential.BeamSearcher;
import edu.cmu.minorthird.classify.sequential.CMM;
import edu.cmu.minorthird.classify.sequential.InstanceFromSequence;
import edu.cmu.minorthird.classify.sequential.SequenceClassifier;
import edu.cmu.minorthird.classify.sequential.SequenceConstants;
import edu.cmu.minorthird.classify.sequential.SequenceDataset;
import edu.cmu.minorthird.util.ProgressCounter;
import edu.cmu.minorthird.util.gui.ComponentViewer;
import edu.cmu.minorthird.util.gui.SmartVanillaViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
import java.io.Serializable;
import java.util.Iterator;
import javax.swing.JComponent;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import javax.swing.border.TitledBorder;

/**
 * Sequential learner that trains one online binary classifier per class and
 * decodes sequences with a {@link BeamSearcher} over label histories.
 *
 * <p>Training is perceptron-style, in the spirit of Collins' sequence training
 * (as the class name suggests). Each training sequence is decoded with the
 * current per-class classifiers. Wherever the predicted label at position
 * {@code j}, or any of the up to {@code historySize} predicted labels before it,
 * disagrees with the gold label, two examples are added:
 * <ul>
 *   <li>a positive example built from the gold label history, given to the
 *       learner for the gold class at {@code j};</li>
 *   <li>a negative example built from the predicted label history, given to the
 *       learner for the predicted class at {@code j}.</li>
 * </ul>
 * Training stops early after an epoch with no such errors.
 */
public class GenericCollinsLearnerV1
implements BatchSequenceClassifierLearner,
SequenceConstants {
    /** Template learner; copied and reset once per class at the start of every batchTrain call. */
    private OnlineBinaryClassifierLearner innerLearner​Prototype;
    /** One online binary learner per class index of the training schema (assigned by batchTrain). */
    private OnlineBinaryClassifierLearner[] innerLearner;
    /** Number of previous labels used as context by the beam search, the returned CMM and the history buffer. */
    private int historySize;
    /** Maximum number of passes over the training data. */
    private int numberOfEpochs;
    /** Reusable buffer filled by InstanceFromSequence.fillHistory; always kept at length historySize. */
    private String[] history;

    /** Creates a learner with a {@link VotedPerceptron} inner learner, history size 3 and 5 epochs. */
    public GenericCollinsLearnerV1() {
        this(3, 5);
    }

    /**
     * Creates a learner that trains for 5 epochs.
     *
     * @param innerLearner prototype online binary learner, copied once per class
     * @param historySize number of previous labels used as context; must be non-negative
     */
    public GenericCollinsLearnerV1(OnlineBinaryClassifierLearner innerLearner, int historySize) {
        this(innerLearner, historySize, 5);
    }

    /**
     * Creates a learner with a {@link VotedPerceptron} inner learner.
     *
     * @param historySize number of previous labels used as context; must be non-negative
     * @param epochs maximum number of training passes
     */
    public GenericCollinsLearnerV1(int historySize, int epochs) {
        this(new VotedPerceptron(), historySize, epochs);
    }

    /**
     * Creates a learner.
     *
     * @param innerLearner prototype online binary learner; must support copy()
     * @param historySize number of previous labels used as context; must be non-negative
     * @param epochs maximum number of training passes
     * @throws IllegalArgumentException if historySize is negative
     */
    public GenericCollinsLearnerV1(OnlineBinaryClassifierLearner innerLearner, int historySize, int epochs) {
        this.history = newHistoryBuffer(historySize);
        this.historySize = historySize;
        this.innerLearnerPrototype = innerLearner;
        this.numberOfEpochs = epochs;
    }

    /** Ignored: batchTrain always uses the schema of the dataset it is given. */
    public void setSchema(ExampleSchema schema) {
    }

    /** Returns the prototype inner learner (not one of the per-class copies). */
    public OnlineBinaryClassifierLearner getInnerLearner() {
        return this.innerLearnerPrototype;
    }

    /** Sets the prototype inner learner; takes effect on the next batchTrain call. */
    public void setInnerLearner(OnlineBinaryClassifierLearner newInnerLearner) {
        this.innerLearnerPrototype = newInnerLearner;
    }

    /** Returns the number of previous labels used as context. */
    public int getHistorySize() {
        return this.historySize;
    }

    /**
     * Sets the number of previous labels used as context and resizes the internal
     * history buffer to match. Without the resize, the training features would keep
     * using the old length while decoding used the new one.
     *
     * @throws IllegalArgumentException if newHistorySize is negative
     */
    public void setHistorySize(int newHistorySize) {
        // Allocate first so that a rejected value leaves the learner unchanged.
        this.history = newHistoryBuffer(newHistorySize);
        this.historySize = newHistorySize;
    }

    /** Returns the maximum number of training passes. */
    public int getNumberOfEpochs() {
        return this.numberOfEpochs;
    }

    /** Sets the maximum number of training passes. */
    public void setNumberOfEpochs(int newNumberOfEpochs) {
        this.numberOfEpochs = newNumberOfEpochs;
    }

    /**
     * Trains fresh per-class learners on the dataset and wraps them in a {@link CMM}.
     *
     * @param dataset training sequences; its schema defines the classes
     * @return a CMM over the trained per-class classifiers
     * @throws IllegalArgumentException if the inner learner prototype cannot be copied/reset
     */
    public SequenceClassifier batchTrain(SequenceDataset dataset) {
        ExampleSchema schema = dataset.getSchema();
        this.innerLearner = freshInnerLearners(schema);
        ProgressCounter pc = new ProgressCounter("training sequential " + this.innerLearnerPrototype.toString(), "sequence", this.numberOfEpochs * dataset.numberOfSequences());
        for (int epoch = 0; epoch < this.numberOfEpochs; ++epoch) {
            int sequenceErrors = 0;
            int transitionErrors = 0;
            int transitions = 0;
            Iterator<Example[]> i = dataset.sequenceIterator();
            while (i.hasNext()) {
                Instance[] sequence = i.next();
                // Decode with the classifiers as they stand after all updates so far.
                MultiClassClassifier c = new MultiClassClassifier(schema, this.innerLearner);
                ClassLabel[] viterbi = new BeamSearcher(c, this.historySize, schema).bestLabelSequence(sequence);
                boolean errorOnThisSequence = false;
                for (int j = 0; j < sequence.length; ++j) {
                    if (!predictionWindowHasError(viterbi, sequence, j)) continue;
                    ++transitionErrors;
                    errorOnThisSequence = true;
                    // Promote the gold class, using the gold label history as context.
                    InstanceFromSequence.fillHistory(this.history, (Example[])sequence, j);
                    InstanceFromSequence correctXj = new InstanceFromSequence(sequence[j], this.history);
                    int correctClassIndex = schema.getClassIndex(((Example)sequence[j]).getLabel().bestClassName());
                    this.innerLearner[correctClassIndex].addExample(new Example(correctXj, ClassLabel.binaryLabel(1.0)));
                    // Demote the predicted class, using the predicted label history as context.
                    InstanceFromSequence.fillHistory(this.history, viterbi, j);
                    InstanceFromSequence wrongXj = new InstanceFromSequence(sequence[j], this.history);
                    int wrongClassIndex = schema.getClassIndex(viterbi[j].bestClassName());
                    this.innerLearner[wrongClassIndex].addExample(new Example(wrongXj, ClassLabel.binaryLabel(-1.0)));
                }
                if (errorOnThisSequence) {
                    ++sequenceErrors;
                }
                transitions += sequence.length;
                pc.progress();
            }
            System.out.println("Epoch " + epoch + ": sequenceErr=" + sequenceErrors + " transitionErrors=" + transitionErrors + "/" + transitions);
            // Every window was predicted correctly, so further epochs would make no updates.
            if (transitionErrors == 0) break;
        }
        pc.finished();
        MultiClassClassifier c = new MultiClassClassifier(schema, this.innerLearner);
        return new CMM(c, this.historySize, schema);
    }

    /**
     * Builds one copied-and-reset inner learner per class of the schema.
     * The result is assigned by the caller only after every copy has succeeded.
     */
    private OnlineBinaryClassifierLearner[] freshInnerLearners(ExampleSchema schema) {
        int numClasses = schema.getNumberOfClasses();
        OnlineBinaryClassifierLearner[] learners = new OnlineBinaryClassifierLearner[numClasses];
        try {
            for (int k = 0; k < numClasses; ++k) {
                learners[k] = (OnlineBinaryClassifierLearner)this.innerLearnerPrototype.copy();
                learners[k].reset();
            }
        }
        catch (Exception ex) {
            // Keep the original failure as the cause so it remains diagnosable.
            throw new IllegalArgumentException("innerLearner must be cloneable", ex);
        }
        return learners;
    }

    /**
     * Checks position j and up to historySize preceding positions.
     * Returns true if any predicted label there is incorrect for its gold label.
     */
    private boolean predictionWindowHasError(ClassLabel[] predicted, Instance[] sequence, int j) {
        if (!predicted[j].isCorrect(((Example)sequence[j]).getLabel())) {
            return true;
        }
        for (int k = 1; k <= this.historySize && j - k >= 0; ++k) {
            if (!predicted[j - k].isCorrect(((Example)sequence[j - k]).getLabel())) {
                return true;
            }
        }
        return false;
    }

    /** Allocates a history buffer of the given size, rejecting negative sizes. */
    private static String[] newHistoryBuffer(int size) {
        if (size < 0) {
            throw new IllegalArgumentException("historySize must be non-negative: " + size);
        }
        return new String[size];
    }

    /**
     * Multi-class classifier built from one binary classifier per class.
     * The score for class i is the score returned by binary classifier i.
     */
    public static class MultiClassClassifier
    implements Classifier,
    Visible,
    Serializable {
        private static final long serialVersionUID = 1L;
        private final ExampleSchema schema;
        private final BinaryClassifier[] innerClassifier;
        private final int numClasses;

        /**
         * Wraps existing binary classifiers.
         * The array is used directly, without copying, and is indexed by the schema's
         * class index for the first schema.getNumberOfClasses() entries.
         */
        public MultiClassClassifier(ExampleSchema schema, BinaryClassifier[] learners) {
            this.schema = schema;
            this.numClasses = schema.getNumberOfClasses();
            this.innerClassifier = learners;
        }

        /**
         * Collects the binary classifier currently exposed by each online learner.
         * Whether that is a copy or a live view depends on the learner implementation.
         */
        public MultiClassClassifier(ExampleSchema schema, OnlineBinaryClassifierLearner[] innerLearner) {
            this.schema = schema;
            this.numClasses = schema.getNumberOfClasses();
            this.innerClassifier = new BinaryClassifier[this.numClasses];
            for (int i = 0; i < this.numClasses; ++i) {
                this.innerClassifier[i] = innerLearner[i].getBinaryClassifier();
            }
        }

        /** Returns a label holding, for every class name, that class's binary score. */
        public ClassLabel classification(Instance instance) {
            ClassLabel label = new ClassLabel();
            for (int i = 0; i < this.numClasses; ++i) {
                label.add(this.schema.getClassName(i), this.innerClassifier[i].score(instance));
            }
            return label;
        }

        /** Concatenates the text explanation of every per-class binary classifier. */
        public String explain(Instance instance) {
            StringBuilder buf = new StringBuilder();
            for (int i = 0; i < this.numClasses; ++i) {
                buf.append("Classifier for class ").append(this.schema.getClassName(i)).append(":\n");
                buf.append(this.innerClassifier[i].explain(instance));
                buf.append("\n");
            }
            return buf.toString();
        }

        /** Builds an explanation tree with one child node per class. */
        public Explanation getExplanation(Instance instance) {
            Explanation.Node top = new Explanation.Node("GenericCollins Explanation");
            for (int i = 0; i < this.numClasses; ++i) {
                Explanation.Node classifier = new Explanation.Node("Classifier for class " + this.schema.getClassName(i));
                Explanation.Node classEx = this.innerClassifier[i].getExplanation(instance).getTopNode();
                classifier.add(classEx);
                top.add(classifier);
            }
            return new Explanation(top);
        }

        /** Returns a viewer showing one titled panel per class with that class's binary classifier. */
        public Viewer toGUI() {
            ComponentViewer gui = new ComponentViewer(){
                static final long serialVersionUID = 20080207L;

                public JComponent componentFor(Object o) {
                    MultiClassClassifier c = (MultiClassClassifier)o;
                    JPanel main = new JPanel();
                    // Read everything from the displayed object, not from the enclosing instance.
                    for (int i = 0; i < c.numClasses; ++i) {
                        JPanel classPanel = new JPanel();
                        classPanel.setBorder(new TitledBorder("Class " + c.schema.getClassName(i)));
                        SmartVanillaViewer subviewer = new SmartVanillaViewer(c.innerClassifier[i]);
                        subviewer.setSuperView(this);
                        classPanel.add(subviewer);
                        main.add(classPanel);
                    }
                    return new JScrollPane(main);
                }
            };
            gui.setContent(this);
            return gui;
        }
    }
}

