/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify.sequential;

import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Explanation;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.algorithms.linear.Hyperplane;
import edu.cmu.minorthird.classify.sequential.BatchSequenceClassifierLearner;
import edu.cmu.minorthird.classify.sequential.BeamSearcher;
import edu.cmu.minorthird.classify.sequential.CMM;
import edu.cmu.minorthird.classify.sequential.InstanceFromSequence;
import edu.cmu.minorthird.classify.sequential.SequenceClassifier;
import edu.cmu.minorthird.classify.sequential.SequenceConstants;
import edu.cmu.minorthird.classify.sequential.SequenceDataset;
import edu.cmu.minorthird.util.ProgressCounter;
import edu.cmu.minorthird.util.StringUtil;
import edu.cmu.minorthird.util.gui.ComponentViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
import java.io.Serializable;
import java.util.Iterator;
import javax.swing.JComponent;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import javax.swing.border.TitledBorder;
import org.apache.log4j.Logger;

/**
 * Learns a sequence classifier with a Collins-style structured perceptron.
 *
 * <p>Each epoch shuffles the dataset and decodes every sequence with a {@link BeamSearcher}
 * using the current weights. At every position where the predicted label, or any of the
 * preceding {@code historySize} predicted labels, disagrees with the gold label, it does two
 * updates. It adds the features built from the gold history to the gold class, and subtracts
 * the features built from the predicted history from the predicted class. Training stops
 * early after an epoch with no transition errors.
 *
 * <p>After every sequence the current weights are added into a running sum. The returned
 * {@link CMM} classifies with that running sum ("vote mode").
 */
public class CollinsPerceptronLearner
implements BatchSequenceClassifierLearner,
SequenceConstants {
    protected static Logger log = Logger.getLogger(CollinsPerceptronLearner.class);
    protected static final boolean DEBUG = log.isDebugEnabled();
    /** Number of previous predicted labels used as features for the current token. */
    protected int historySize;
    /** Maximum number of passes over the training data. */
    protected int numberOfEpochs;
    /**
     * Scratch buffer handed to {@link InstanceFromSequence#fillHistory}; reused for every
     * update and kept at length {@code historySize}.
     */
    protected String[] history;

    /** Creates a learner with a history of 3 labels and 5 epochs. */
    public CollinsPerceptronLearner() {
        this(3, 5);
    }

    /**
     * Creates a learner with a history of 3 labels.
     *
     * @param numberOfEpochs maximum number of passes over the training data
     */
    public CollinsPerceptronLearner(int numberOfEpochs) {
        this(3, numberOfEpochs);
    }

    /**
     * @param historySize number of previous labels used as features
     * @param numberOfEpochs maximum number of passes over the training data
     */
    public CollinsPerceptronLearner(int historySize, int numberOfEpochs) {
        this.historySize = historySize;
        this.numberOfEpochs = numberOfEpochs;
        this.history = new String[historySize];
    }

    public int getNumberOfEpochs() {
        return numberOfEpochs;
    }

    public void setNumberOfEpochs(int newNumberOfEpochs) {
        numberOfEpochs = newNumberOfEpochs;
    }

    public int getHistorySize() {
        return historySize;
    }

    /**
     * Sets the number of previous labels used as features.
     *
     * <p>The scratch history buffer is reallocated as well. Previously it kept the length
     * chosen in the constructor, so training filled a buffer of the old size while the
     * decoder and the returned {@link CMM} used the new one.
     *
     * @param newHistorySize new history window length
     */
    public void setHistorySize(int newHistorySize) {
        historySize = newHistorySize;
        history = new String[newHistorySize];
    }

    public String getHistorySizeHelp() {
        return "Number of tokens to look back on. <br>The predicted labels for the history are used as features to help classify the current token.";
    }

    /** Ignored: the schema is taken from the dataset passed to {@link #batchTrain}. */
    public void setSchema(ExampleSchema schema) {
    }

    /**
     * Trains a perceptron sequence classifier on {@code dataset}.
     *
     * <p>Side effects: shuffles {@code dataset} once per epoch, and prints a one-line error
     * summary per epoch to standard output.
     *
     * @param dataset training sequences; its schema defines the set of class labels
     * @return a {@link CMM} wrapping the vote-mode (summed-weights) classifier
     */
    public SequenceClassifier batchTrain(SequenceDataset dataset) {
        ExampleSchema schema = dataset.getSchema();
        MultiClassVPClassifier c = new MultiClassVPClassifier(schema);
        // historySize is protected and may have been changed without going through
        // setHistorySize, so make sure the scratch buffer matches it before training.
        if (history == null || history.length != historySize) {
            history = new String[historySize];
        }
        ProgressCounter pc = new ProgressCounter("training sequence perceptron", "sequence", numberOfEpochs * dataset.numberOfSequences());
        for (int epoch = 0; epoch < numberOfEpochs; ++epoch) {
            dataset.shuffle();
            int sequenceErrors = 0;
            int transitionErrors = 0;
            int transitions = 0;
            for (Iterator<Example[]> it = dataset.sequenceIterator(); it.hasNext(); ) {
                Example[] sequence = it.next();
                // The fillHistory(String[], ClassLabel[], int) overload used below needs a
                // ClassLabel[], so cast the decoder output once up front.
                ClassLabel[] viterbi = (ClassLabel[]) new BeamSearcher(c, historySize, schema).bestLabelSequence(sequence);
                if (DEBUG) {
                    log.debug("classifier: " + c);
                    log.debug("viterbi:\n" + StringUtil.toString(viterbi));
                }
                boolean errorOnThisSequence = false;
                for (int j = 0; j < sequence.length; ++j) {
                    if (!isTransitionError(viterbi, sequence, j)) {
                        continue;
                    }
                    ++transitionErrors;
                    errorOnThisSequence = true;

                    // Promote the gold class, using features built from the gold history.
                    String correctClass = sequence[j].getLabel().bestClassName();
                    InstanceFromSequence.fillHistory(history, sequence, j);
                    InstanceFromSequence correctXj = new InstanceFromSequence(sequence[j], history);
                    c.update(correctClass, correctXj, 1.0);
                    if (DEBUG) {
                        log.debug("+ update " + correctClass + " " + correctXj.getSource());
                    }

                    // Demote the predicted class, using features built from the predicted
                    // history. The shared buffer is refilled only after the + update above
                    // has been applied.
                    String predictedClass = viterbi[j].bestClassName();
                    InstanceFromSequence.fillHistory(history, viterbi, j);
                    InstanceFromSequence wrongXj = new InstanceFromSequence(sequence[j], history);
                    c.update(predictedClass, wrongXj, -1.0);
                    if (DEBUG) {
                        log.debug("- update " + predictedClass + " " + wrongXj.getSource());
                    }
                }
                // Add the current weights into the running sum once per sequence.
                c.completeUpdate();
                if (errorOnThisSequence) {
                    ++sequenceErrors;
                }
                transitions += sequence.length;
                pc.progress();
            }
            System.out.println("Epoch " + epoch + ": sequenceErr=" + sequenceErrors + " transitionErrors=" + transitionErrors + "/" + transitions);
            if (transitionErrors == 0) {
                break;
            }
        }
        pc.finished();
        c.setVoteMode(true);
        return new CMM(c, historySize, schema);
    }

    /**
     * Decides whether position {@code j} triggers an update.
     *
     * <p>An update is needed when the predicted label at {@code j}, or at any of the up to
     * {@code historySize} positions before it, is not correct for the gold label. Position
     * {@code j} is checked first, then {@code j-1}, {@code j-2}, and so on.
     */
    private boolean isTransitionError(ClassLabel[] predicted, Example[] sequence, int j) {
        if (!predicted[j].isCorrect(sequence[j].getLabel())) {
            return true;
        }
        for (int k = 1; k <= historySize && j - k >= 0; ++k) {
            if (!predicted[j - k].isCorrect(sequence[j - k].getLabel())) {
                return true;
            }
        }
        return false;
    }

    /**
     * One-vs-all linear classifier with one {@link Hyperplane} per class.
     *
     * <p>It keeps both the current weights ({@code w_t}) and a running sum of the weights
     * ({@code s_t}). In vote mode it scores with the running sum.
     */
    public static class MultiClassVPClassifier
    implements Classifier,
    Visible,
    Serializable {
        private static final long serialVersionUID = 1L;
        private ExampleSchema schema;
        /** Running sum of {@code w_t}, one entry per class; updated by {@link #completeUpdate}. */
        private Hyperplane[] s_t;
        /** Current weights, one entry per class; updated by {@link #update}. */
        private Hyperplane[] w_t;
        private int numClasses;
        private boolean voteMode = false;

        /** @param schema defines the classes; one pair of hyperplanes is allocated per class */
        public MultiClassVPClassifier(ExampleSchema schema) {
            this.schema = schema;
            this.numClasses = schema.getNumberOfClasses();
            this.reset();
        }

        /** @param flag if true, score with the summed weights instead of the current ones */
        public void setVoteMode(boolean flag) {
            voteMode = flag;
        }

        /** Returns the hyperplanes used for scoring: the summed weights in vote mode, else the current ones. */
        public Hyperplane[] getHyperplanes() {
            return voteMode ? s_t : w_t;
        }

        public ExampleSchema getSchema() {
            return schema;
        }

        /**
         * Adds {@code delta} times the features of {@code instance} to the current weights
         * of {@code className}.
         */
        public void update(String className, Instance instance, double delta) {
            int index = schema.getClassIndex(className);
            w_t[index].increment(instance, delta);
        }

        /** Adds the current weights of every class into the running sum. */
        public void completeUpdate() {
            for (int i = 0; i < numClasses; ++i) {
                s_t[i].increment(w_t[i], 1.0);
            }
        }

        /** Scores {@code instance} against every class's active hyperplane. */
        public ClassLabel classification(Instance instance) {
            Hyperplane[] planes = getHyperplanes();
            ClassLabel label = new ClassLabel();
            for (int i = 0; i < numClasses; ++i) {
                label.add(schema.getClassName(i), planes[i].score(instance));
            }
            return label;
        }

        /** Concatenates each class's hyperplane explanation for {@code instance}. */
        public String explain(Instance instance) {
            Hyperplane[] planes = getHyperplanes();
            StringBuilder buf = new StringBuilder();
            for (int i = 0; i < numClasses; ++i) {
                buf.append("Hyperplane for class " + schema.getClassName(i) + ":\n");
                buf.append(planes[i].explain(instance));
                buf.append("\n");
            }
            return buf.toString();
        }

        /** Builds an explanation tree with one child node per class hyperplane. */
        public Explanation getExplanation(Instance instance) {
            Hyperplane[] planes = getHyperplanes();
            Explanation.Node top = new Explanation.Node("CollinsPerceptron Explanation");
            for (int i = 0; i < numClasses; ++i) {
                Explanation.Node hyp = new Explanation.Node("Hyperplane for class " + schema.getClassName(i) + ":\n");
                Explanation.Node explanation = planes[i].getExplanation(instance).getTopNode();
                hyp.add(explanation);
                top.add(hyp);
            }
            return new Explanation(top);
        }

        /**
         * Returns a viewer with one titled panel per class, each showing that class's
         * active hyperplane.
         *
         * <p>All state is read from the object being viewed ({@code o}). Previously the class
         * count, vote mode and hyperplanes came from the enclosing instance, which could mix
         * two classifiers if the viewer's content was changed.
         */
        public Viewer toGUI() {
            ComponentViewer gui = new ComponentViewer(){
                static final long serialVersionUID = 20080207L;

                public JComponent componentFor(Object o) {
                    MultiClassVPClassifier c = (MultiClassVPClassifier)o;
                    Hyperplane[] planes = c.getHyperplanes();
                    JPanel main = new JPanel();
                    for (int i = 0; i < c.numClasses; ++i) {
                        JPanel classPanel = new JPanel();
                        classPanel.setBorder(new TitledBorder("Class " + c.schema.getClassName(i)));
                        Viewer subviewer = planes[i].toGUI();
                        subviewer.setSuperView(this);
                        classPanel.add(subviewer);
                        main.add(classPanel);
                    }
                    return new JScrollPane(main);
                }
            };
            gui.setContent(this);
            return gui;
        }

        /** Discards all learned weights by allocating fresh hyperplanes for every class. */
        public void reset() {
            s_t = new Hyperplane[numClasses];
            w_t = new Hyperplane[numClasses];
            for (int i = 0; i < numClasses; ++i) {
                s_t[i] = new Hyperplane();
                w_t[i] = new Hyperplane();
            }
        }

        /** Renders the current (non-summed) weights, regardless of vote mode. */
        public String toString() {
            return "[MultiClassVPClassifier:" + StringUtil.toString(w_t, "\n", "\n]", "\n - ");
        }
    }
}

