/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify.sequential;

import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.sequential.BeamSearcher;
import edu.cmu.minorthird.classify.sequential.CMM;
import edu.cmu.minorthird.classify.sequential.CollinsPerceptronLearner;
import edu.cmu.minorthird.classify.sequential.InstanceFromSequence;
import edu.cmu.minorthird.classify.sequential.SequenceClassifier;
import edu.cmu.minorthird.classify.sequential.SequenceDataset;
import edu.cmu.minorthird.util.ProgressCounter;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Vector;

public class MarginPerceptronLearner
extends CollinsPerceptronLearner {
    float beta = 0.05f;
    int topK = 10;

    public MarginPerceptronLearner() {
        this(3, 5, 0.05f);
    }

    public MarginPerceptronLearner(int numberOfEpochs) {
        this(3, numberOfEpochs, 0.05f);
    }

    public MarginPerceptronLearner(int historySize, int numberOfEpochs, float beta) {
        this(historySize, numberOfEpochs, beta, 10);
    }

    public MarginPerceptronLearner(int historySize, int numberOfEpochs, float beta, int topK) {
        super(historySize, numberOfEpochs);
        this.beta = beta;
        this.topK = topK;
    }

    public SequenceClassifier batchTrain(SequenceDataset dataset) {
        ExampleSchema schema = dataset.getSchema();
        CollinsPerceptronLearner.MultiClassVPClassifier c = new CollinsPerceptronLearner.MultiClassVPClassifier(schema);
        ProgressCounter pc = new ProgressCounter("training sequence perceptron", "sequence", this.getNumberOfEpochs() * dataset.numberOfSequences());
        Vector<ClassLabel[]> viterbiS = new Vector<ClassLabel[]>();
        for (int epoch = 0; epoch < this.getNumberOfEpochs(); ++epoch) {
            int sequenceErrors = 0;
            int transitionErrors = 0;
            int transitions = 0;
            Iterator<Example[]> i = dataset.sequenceIterator();
            while (i.hasNext()) {
                Instance[] sequence = i.next();
                BeamSearcher beam = new BeamSearcher(c, this.getHistorySize(), schema);
                beam.doSearch(sequence);
                float corrScore = this.getScore((Example[])sequence, c);
                if (DEBUG) {
                    log.debug("corrScore: " + corrScore);
                }
                viterbiS.clear();
                int maxNum = Math.min(beam.getNumberOfSolutionsFound(), this.topK);
                for (int k = 0; k < maxNum; ++k) {
                    ClassLabel[] viterbi = beam.viterbi(k);
                    float thisScore = beam.score(k);
                    if (DEBUG) {
                        log.debug("viterbi: " + k + " score " + thisScore);
                    }
                    if (DEBUG) {
                        log.debug(this.sequenceToString(viterbi));
                    }
                    if (thisScore < corrScore * (1.0f - this.beta)) break;
                    if (this.isCorrect(viterbi, (Example[])sequence)) continue;
                    viterbiS.add(viterbi);
                }
                if (DEBUG) {
                    log.debug("added: " + viterbiS.size());
                }
                boolean errorOnThisSequence = false;
                if (viterbiS.size() > 0) {
                    for (int j = 0; j < sequence.length; ++j) {
                        boolean differenceAtJ = false;
                        for (int s = 0; s < viterbiS.size(); ++s) {
                            ClassLabel[] viterbi = (ClassLabel[])viterbiS.elementAt(s);
                            differenceAtJ = !viterbi[j].isCorrect(((Example)sequence[j]).getLabel());
                            for (int k = 1; j - k >= 0 && !differenceAtJ && k <= this.getHistorySize(); ++k) {
                                if (viterbi[j - k].isCorrect(((Example)sequence[j - k]).getLabel())) continue;
                                differenceAtJ = true;
                            }
                            if (differenceAtJ) break;
                        }
                        if (!differenceAtJ) continue;
                        ++transitionErrors;
                        errorOnThisSequence = true;
                        InstanceFromSequence.fillHistory(this.history, (Example[])sequence, j);
                        InstanceFromSequence correctXj = new InstanceFromSequence(sequence[j], this.history);
                        c.update(((Example)sequence[j]).getLabel().bestClassName(), correctXj, 1.0);
                        for (int s = 0; s < viterbiS.size(); ++s) {
                            ClassLabel[] viterbi = (ClassLabel[])viterbiS.elementAt(s);
                            InstanceFromSequence.fillHistory(this.history, viterbi, j);
                            InstanceFromSequence wrongXj = new InstanceFromSequence(sequence[j], this.history);
                            c.update(viterbi[j].bestClassName(), wrongXj, -1.0 / (double)viterbiS.size());
                        }
                    }
                }
                c.completeUpdate();
                if (errorOnThisSequence) {
                    ++sequenceErrors;
                }
                transitions += sequence.length;
                pc.progress();
            }
            System.out.println("Epoch " + epoch + ": sequenceErr=" + sequenceErrors + " transitionErrors=" + transitionErrors + "/" + transitions);
            if (transitionErrors == 0) break;
        }
        pc.finished();
        c.setVoteMode(true);
        return new CMM(c, this.getHistorySize(), schema);
    }

    float getScore(Example[] sequence, Classifier classifier) {
        float score = 0.0f;
        for (int j = 0; j < sequence.length; ++j) {
            InstanceFromSequence.fillHistory(this.history, sequence, j);
            InstanceFromSequence correctXj = new InstanceFromSequence(sequence[j], this.history);
            score = (float)((double)score + classifier.classification(correctXj).getWeight(sequence[j].getLabel().bestClassName()));
        }
        return score;
    }

    boolean isCorrect(ClassLabel[] viterbi, Example[] sequence) {
        for (int j = 0; j < sequence.length; ++j) {
            if (viterbi[j].isCorrect(sequence[j].getLabel())) continue;
            return false;
        }
        return true;
    }

    String sequenceToString(ClassLabel[] viterbi) {
        String path = "";
        for (int j = 0; j < viterbi.length; ++j) {
            path = path + viterbi[j].bestClassName() + " ";
        }
        return path;
    }
}

