/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify.sequential;

import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.OnlineClassifierLearner;
import edu.cmu.minorthird.classify.algorithms.linear.Hyperplane;
import edu.cmu.minorthird.classify.algorithms.linear.MarginPerceptron;
import edu.cmu.minorthird.classify.sequential.BatchSequenceClassifierLearner;
import edu.cmu.minorthird.classify.sequential.BeamSearcher;
import edu.cmu.minorthird.classify.sequential.CMM;
import edu.cmu.minorthird.classify.sequential.CollinsPerceptronLearner;
import edu.cmu.minorthird.classify.sequential.HyperplaneInstance;
import edu.cmu.minorthird.classify.sequential.InstanceFromSequence;
import edu.cmu.minorthird.classify.sequential.SequenceClassifier;
import edu.cmu.minorthird.classify.sequential.SequenceConstants;
import edu.cmu.minorthird.classify.sequential.SequenceDataset;
import edu.cmu.minorthird.classify.sequential.SequenceUtils;
import edu.cmu.minorthird.util.ProgressCounter;
import edu.cmu.minorthird.util.StringUtil;
import java.util.Iterator;
import org.apache.log4j.Logger;

/**
 * Batch sequence learner that trains one online learner per class, perceptron-style, over
 * whole labelled sequences.
 *
 * <p>Training runs for up to {@link #getNumberOfEpochs()} epochs. In each epoch the dataset
 * is shuffled, and then, for every sequence:
 * <ol>
 *   <li>the current per-class learners are wrapped in a
 *       {@code SequenceUtils.MultiClassClassifier} and decoded with a {@link BeamSearcher};</li>
 *   <li>at every position {@code j} where the decoded label, or any decoded label in the
 *       preceding {@code historySize} positions, disagrees with the true label, the
 *       true-history features of {@code j} are credited to the true class and the
 *       decoded-history features are debited from the decoded class;</li>
 *   <li>if the sequence contained any such position, each class learner receives its
 *       accumulated hyperplane as one positive and one negated copy as one negative example.</li>
 * </ol>
 * Training stops early after an epoch with no position errors. The trained learners are
 * returned wrapped in a {@link CMM}.
 */
public class GenericCollinsLearner
implements BatchSequenceClassifierLearner,
SequenceConstants {
    private static final Logger log = Logger.getLogger(GenericCollinsLearner.class);
    private static final boolean DEBUG = log.isDebugEnabled();
    /** Learner that is duplicated once per class at the start of each {@link #batchTrain} call. */
    private OnlineClassifierLearner innerLearnerPrototype;
    /** Per-class learners built from the prototype by {@link #batchTrain}. */
    private OnlineClassifierLearner[] innerLearner;
    private int historySize;
    private int numberOfEpochs;
    /** Scratch buffer for label histories; invariant: {@code history.length == historySize}. */
    private String[] history;

    /** Creates a learner with a default {@code MarginPerceptron}, history size 3 and 5 epochs. */
    public GenericCollinsLearner() {
        this(new MarginPerceptron(0.0, false, true));
    }

    /**
     * Creates a learner with history size 3 and 5 epochs.
     *
     * @param innerLearner prototype learner duplicated once per class
     */
    public GenericCollinsLearner(OnlineClassifierLearner innerLearner) {
        this(innerLearner, 5);
    }

    /**
     * Creates a learner with a default {@code MarginPerceptron} and history size 3.
     *
     * @param epochs maximum number of training passes over the dataset
     */
    public GenericCollinsLearner(int epochs) {
        this(new MarginPerceptron(0.0, false, true), epochs);
    }

    /**
     * Creates a learner with history size 3.
     *
     * @param innerLearner prototype learner duplicated once per class
     * @param epochs maximum number of training passes over the dataset
     */
    public GenericCollinsLearner(OnlineClassifierLearner innerLearner, int epochs) {
        this(innerLearner, 3, epochs);
    }

    /**
     * Creates a fully configured learner.
     *
     * @param innerLearner prototype learner duplicated once per class
     * @param historySize number of previous labels used as history when decoding and
     *     when building training instances
     * @param epochs maximum number of training passes over the dataset
     * @throws NegativeArraySizeException if {@code historySize} is negative
     */
    public GenericCollinsLearner(OnlineClassifierLearner innerLearner, int historySize, int epochs) {
        this.historySize = historySize;
        this.innerLearnerPrototype = innerLearner;
        this.numberOfEpochs = epochs;
        this.history = new String[historySize];
    }

    /** Ignored: {@link #batchTrain} always uses the schema of the dataset it is given. */
    public void setSchema(ExampleSchema schema) {
    }

    /** Returns the prototype learner that is duplicated for each class. */
    public OnlineClassifierLearner getInnerLearner() {
        return this.innerLearnerPrototype;
    }

    /** Replaces the prototype learner; takes effect at the next {@link #batchTrain} call. */
    public void setInnerLearner(OnlineClassifierLearner newInnerLearner) {
        this.innerLearnerPrototype = newInnerLearner;
    }

    /** Returns the number of previous labels used as history. */
    public int getHistorySize() {
        return this.historySize;
    }

    /**
     * Sets the number of previous labels used as history.
     *
     * @throws NegativeArraySizeException if {@code newHistorySize} is negative
     */
    public void setHistorySize(int newHistorySize) {
        // Reallocate the scratch buffer as well: the fillHistory calls in batchTrain write into
        // it, so a stale length would make training histories disagree with the history size
        // used for decoding and by the returned CMM.
        this.history = new String[newHistorySize];
        this.historySize = newHistorySize;
    }

    /** Returns the maximum number of training epochs. */
    public int getNumberOfEpochs() {
        return this.numberOfEpochs;
    }

    /** Sets the maximum number of training epochs. */
    public void setNumberOfEpochs(int newNumberOfEpochs) {
        this.numberOfEpochs = newNumberOfEpochs;
    }

    /**
     * Trains per-class learners on {@code dataset} and returns them as a sequence classifier.
     *
     * <p>Side effects: shuffles {@code dataset} once per epoch and prints a one-line error
     * summary per epoch to standard output.
     *
     * @param dataset labelled sequences to train on; its schema defines the classes
     * @return a {@link CMM} over the trained learners using this learner's history size
     */
    public SequenceClassifier batchTrain(SequenceDataset dataset) {
        ExampleSchema schema = dataset.getSchema();
        int numClasses = schema.getNumberOfClasses();
        this.innerLearner = SequenceUtils.duplicatePrototypeLearner(this.innerLearnerPrototype, numClasses);
        ProgressCounter pc = new ProgressCounter("training sequential " + this.innerLearnerPrototype.toString(), "sequence", this.numberOfEpochs * dataset.numberOfSequences());
        for (int epoch = 0; epoch < this.numberOfEpochs; ++epoch) {
            dataset.shuffle();
            int sequenceErrors = 0;
            int transitionErrors = 0;
            int transitions = 0;
            for (Iterator<Example[]> it = dataset.sequenceIterator(); it.hasNext();) {
                Example[] sequence = it.next();
                // Rebuilt per sequence so decoding sees the learners' latest updates.
                SequenceUtils.MultiClassClassifier c = new SequenceUtils.MultiClassClassifier(schema, this.innerLearner);
                ClassLabel[] viterbi = new BeamSearcher(c, this.historySize, schema).bestLabelSequence(sequence);
                if (DEBUG) {
                    log.debug("classifier: " + c);
                    log.debug("viterbi:\n" + StringUtil.toString(viterbi));
                }
                // accumNeg is maintained as the exact negation of accumPos.
                Hyperplane[] accumPos = newHyperplanes(numClasses);
                Hyperplane[] accumNeg = newHyperplanes(numClasses);
                boolean errorOnThisSequence = false;
                for (int j = 0; j < sequence.length; ++j) {
                    if (!differsWithinHistory(viterbi, sequence, j)) {
                        continue;
                    }
                    ++transitionErrors;
                    errorOnThisSequence = true;
                    Example trueExample = sequence[j];
                    ClassLabel predicted = viterbi[j];

                    // Credit the true class with the features seen under the true history.
                    InstanceFromSequence.fillHistory(this.history, sequence, j);
                    InstanceFromSequence correctXj = new InstanceFromSequence(trueExample, this.history);
                    int correctClassIndex = schema.getClassIndex(trueExample.getLabel().bestClassName());
                    accumPos[correctClassIndex].increment(correctXj, 1.0);
                    accumNeg[correctClassIndex].increment(correctXj, -1.0);
                    if (DEBUG) {
                        log.debug("+ update " + trueExample.getLabel().bestClassName() + " " + correctXj.getSource() + ";" + correctXj);
                    }

                    // Debit the decoded class with the features seen under the decoded history.
                    InstanceFromSequence.fillHistory(this.history, viterbi, j);
                    InstanceFromSequence wrongXj = new InstanceFromSequence(trueExample, this.history);
                    int wrongClassIndex = schema.getClassIndex(predicted.bestClassName());
                    accumPos[wrongClassIndex].increment(wrongXj, -1.0);
                    accumNeg[wrongClassIndex].increment(wrongXj, 1.0);
                    if (DEBUG) {
                        log.debug("- update " + predicted.bestClassName() + " " + wrongXj.getSource());
                    }
                }
                if (errorOnThisSequence) {
                    ++sequenceErrors;
                    String subPopId = sequence[0].getSubpopulationId();
                    String source = "no source";
                    for (int k = 0; k < numClasses; ++k) {
                        this.innerLearner[k].addExample(new Example(new HyperplaneInstance(accumPos[k], subPopId, source), ClassLabel.positiveLabel(1.0)));
                        // NOTE(review): negativeLabel is given -1.0 here; confirm this matches
                        // ClassLabel.negativeLabel's expected weight convention.
                        this.innerLearner[k].addExample(new Example(new HyperplaneInstance(accumNeg[k], subPopId, source), ClassLabel.negativeLabel(-1.0)));
                    }
                }
                transitions += sequence.length;
                pc.progress();
            }
            System.out.println("Epoch " + epoch + ": sequenceErr=" + sequenceErrors + " transitionErrors=" + transitionErrors + "/" + transitions);
            if (transitionErrors == 0) {
                break;
            }
        }
        pc.finished();
        for (int k = 0; k < numClasses; ++k) {
            this.innerLearner[k].completeTraining();
        }
        SequenceUtils.MultiClassClassifier c = new SequenceUtils.MultiClassClassifier(schema, this.innerLearner);
        return new CMM(c, this.historySize, schema);
    }

    /**
     * Returns true when the decoded label at {@code j}, or at any of the up to
     * {@code historySize} positions before it, is not correct for the true label there.
     */
    private boolean differsWithinHistory(ClassLabel[] viterbi, Example[] sequence, int j) {
        if (!viterbi[j].isCorrect(sequence[j].getLabel())) {
            return true;
        }
        for (int k = 1; k <= this.historySize && j - k >= 0; ++k) {
            if (!viterbi[j - k].isCorrect(sequence[j - k].getLabel())) {
                return true;
            }
        }
        return false;
    }

    /** Returns an array of {@code n} hyperplanes, each built with the no-arg constructor. */
    private static Hyperplane[] newHyperplanes(int n) {
        Hyperplane[] planes = new Hyperplane[n];
        for (int i = 0; i < n; ++i) {
            planes[i] = new Hyperplane();
        }
        return planes;
    }
}

