/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify.sequential;

import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.OnlineClassifierLearner;
import edu.cmu.minorthird.classify.algorithms.linear.Hyperplane;
import edu.cmu.minorthird.classify.algorithms.linear.MarginPerceptron;
import edu.cmu.minorthird.classify.sequential.BatchSegmenterLearner;
import edu.cmu.minorthird.classify.sequential.CandidateSegmentGroup;
import edu.cmu.minorthird.classify.sequential.CollinsPerceptronLearner;
import edu.cmu.minorthird.classify.sequential.HyperplaneInstance;
import edu.cmu.minorthird.classify.sequential.InstanceFromSequence;
import edu.cmu.minorthird.classify.sequential.SegmentCollinsPerceptronLearner;
import edu.cmu.minorthird.classify.sequential.SegmentDataset;
import edu.cmu.minorthird.classify.sequential.Segmentation;
import edu.cmu.minorthird.classify.sequential.Segmenter;
import edu.cmu.minorthird.classify.sequential.SequenceConstants;
import edu.cmu.minorthird.classify.sequential.SequenceUtils;
import edu.cmu.minorthird.util.ProgressCounter;
import java.util.Iterator;
import java.util.Map;
import java.util.TreeMap;
import org.apache.log4j.Logger;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
public class SegmentGenericCollinsLearner
implements BatchSegmenterLearner,
SequenceConstants {
    private static Logger log = Logger.getLogger(CollinsPerceptronLearner.class);
    private static final boolean DEBUG = log.isDebugEnabled();
    private OnlineClassifierLearner innerLearnerPrototype;
    private OnlineClassifierLearner[] innerLearner;
    private int numberOfEpochs;
    private int maxSegmentSize;

    public SegmentGenericCollinsLearner() {
        this(new MarginPerceptron(0.0, false, true));
    }

    public SegmentGenericCollinsLearner(OnlineClassifierLearner innerLearner) {
        this(innerLearner, 5);
    }

    public SegmentGenericCollinsLearner(int epochs) {
        this(new MarginPerceptron(0.0, false, true), epochs);
    }

    public SegmentGenericCollinsLearner(OnlineClassifierLearner innerLearner, int epochs) {
        this(innerLearner, 4, epochs);
    }

    public SegmentGenericCollinsLearner(OnlineClassifierLearner innerLearner, int maxSegmentSize, int epochs) {
        this.maxSegmentSize = maxSegmentSize;
        this.innerLearnerPrototype = innerLearner;
        this.numberOfEpochs = epochs;
    }

    @Override
    public void setSchema(ExampleSchema schema) {
    }

    public OnlineClassifierLearner getInnerLearner() {
        return this.innerLearnerPrototype;
    }

    public void setInnerLearner(OnlineClassifierLearner newInnerLearner) {
        this.innerLearnerPrototype = newInnerLearner;
    }

    public int getHistorySize() {
        return 1;
    }

    public int getNumberOfEpochs() {
        return this.numberOfEpochs;
    }

    public void setNumberOfEpochs(int newNumberOfEpochs) {
        this.numberOfEpochs = newNumberOfEpochs;
    }

    @Override
    public Segmenter batchTrain(SegmentDataset dataset) {
        ExampleSchema schema = dataset.getSchema();
        this.innerLearner = SequenceUtils.duplicatePrototypeLearner(this.innerLearnerPrototype, schema.getNumberOfClasses());
        ProgressCounter pc = new ProgressCounter("training segments " + this.innerLearnerPrototype.toString(), "sequence", this.numberOfEpochs * dataset.getNumberOfSegmentGroups());
        for (int epoch = 0; epoch < this.numberOfEpochs; ++epoch) {
            int sequenceErrors = 0;
            int transitionErrors = 0;
            int transitions = 0;
            Iterator<CandidateSegmentGroup> i = dataset.candidateSegmentGroupIterator();
            while (i.hasNext()) {
                int fn;
                SequenceUtils.MultiClassClassifier c = new SequenceUtils.MultiClassClassifier(schema, this.innerLearner);
                if (DEBUG) {
                    log.debug("classifier is: " + c);
                }
                CandidateSegmentGroup g = i.next();
                Segmentation viterbi = new SegmentCollinsPerceptronLearner.ViterbiSearcher(c, schema, this.maxSegmentSize).bestSegments(g);
                if (DEBUG) {
                    log.debug("viterbi " + this.maxSegmentSize + "\n" + viterbi);
                }
                Segmentation correct = this.correctSegments(g, schema, this.maxSegmentSize);
                if (DEBUG) {
                    log.debug("correct segments:\n" + correct);
                }
                boolean errorOnThisSequence = false;
                Hyperplane[] accumPos = new Hyperplane[schema.getNumberOfClasses()];
                Hyperplane[] accumNeg = new Hyperplane[schema.getNumberOfClasses()];
                for (int k = 0; k < schema.getNumberOfClasses(); ++k) {
                    accumPos[k] = new Hyperplane();
                    accumNeg[k] = new Hyperplane();
                }
                int fp = this.compareSegmentsAndIncrement(schema, viterbi, correct, accumNeg, 1.0, g);
                if (fp > 0) {
                    errorOnThisSequence = true;
                }
                if ((fn = this.compareSegmentsAndIncrement(schema, correct, viterbi, accumPos, 1.0, g)) > 0) {
                    errorOnThisSequence = true;
                }
                if (errorOnThisSequence) {
                    ++sequenceErrors;
                }
                transitionErrors += fp + fn;
                if (errorOnThisSequence) {
                    ++sequenceErrors;
                    String subPopId = g.getSubpopulationId();
                    String source = "no source";
                    for (int k = 0; k < schema.getNumberOfClasses(); ++k) {
                        this.innerLearner[k].addExample(new Example(new HyperplaneInstance(accumPos[k], subPopId, source), ClassLabel.positiveLabel(1.0)));
                        this.innerLearner[k].addExample(new Example(new HyperplaneInstance(accumNeg[k], subPopId, source), ClassLabel.negativeLabel(-1.0)));
                    }
                }
                transitions += correct.size();
                pc.progress();
            }
            System.out.println("Epoch " + epoch + ": sequenceErr=" + sequenceErrors + " transitionErrors=" + transitionErrors + "/" + transitions);
            if (transitionErrors == 0) break;
        }
        pc.finished();
        for (int k = 0; k < schema.getNumberOfClasses(); ++k) {
            this.innerLearner[k].completeTraining();
        }
        SequenceUtils.MultiClassClassifier c = new SequenceUtils.MultiClassClassifier(schema, this.innerLearner);
        return new SegmentCollinsPerceptronLearner.ViterbiSegmenter(c, schema, this.maxSegmentSize);
    }

    private int compareSegmentsAndIncrement(ExampleSchema schema, Segmentation segments, Segmentation otherSegments, Hyperplane[] accum, double delta, CandidateSegmentGroup g) {
        int errors = 0;
        Map<Segmentation.Segment, String> map = this.previousClassMap(segments, schema);
        Map<Segmentation.Segment, String> otherMap = this.previousClassMap(otherSegments, schema);
        String[] history = new String[1];
        Iterator<Segmentation.Segment> j = segments.iterator();
        while (j.hasNext()) {
            Segmentation.Segment seg = j.next();
            String previousClass = map.get(seg);
            if (seg.lo < 0 || otherSegments.contains(seg) && otherMap.get(seg).equals(previousClass)) continue;
            ++errors;
            history[0] = previousClass;
            InstanceFromSequence instance = new InstanceFromSequence(g.getSubsequenceExample(seg.lo, seg.hi), history);
            if (DEBUG) {
                log.debug("class " + schema.getClassName(seg.y) + " update " + delta + " for: " + instance.getSource());
            }
            accum[seg.y].increment(instance, delta);
        }
        return errors;
    }

    private Map<Segmentation.Segment, String> previousClassMap(Segmentation segments, ExampleSchema schema) {
        TreeMap<Segmentation.Segment, String> map = new TreeMap<Segmentation.Segment, String>();
        Segmentation.Segment previousSeg = null;
        Iterator<Segmentation.Segment> j = segments.iterator();
        while (j.hasNext()) {
            Segmentation.Segment seg = j.next();
            String previousClassName = previousSeg == null ? "null" : schema.getClassName(previousSeg.y);
            map.put(seg, previousClassName);
            previousSeg = seg;
        }
        return map;
    }

    private Segmentation correctSegments(CandidateSegmentGroup g, ExampleSchema schema, int maxSegmentSize) {
        Segmentation result = new Segmentation(schema);
        int pos = 0;
        while (pos < g.getSequenceLength()) {
            boolean addedASegmentStartingAtPos = false;
            for (int len = 1; !addedASegmentStartingAtPos && len <= maxSegmentSize; ++len) {
                Instance inst = g.getSubsequenceInstance(pos, pos + len);
                ClassLabel label = g.getSubsequenceLabel(pos, pos + len);
                if (inst == null || label.isNegative()) continue;
                result.add(new Segmentation.Segment(pos, pos + len, schema.getClassIndex(label.bestClassName())));
                addedASegmentStartingAtPos = true;
                pos += len;
            }
            if (addedASegmentStartingAtPos) continue;
            result.add(new Segmentation.Segment(pos, pos + 1, schema.getClassIndex("NEG")));
            ++pos;
        }
        return result;
    }
}

