/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify.sequential;

import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.MutableInstance;
import edu.cmu.minorthird.classify.sequential.BatchSegmenterLearner;
import edu.cmu.minorthird.classify.sequential.CandidateSegmentGroup;
import edu.cmu.minorthird.classify.sequential.CollinsPerceptronLearner;
import edu.cmu.minorthird.classify.sequential.InstanceFromSequence;
import edu.cmu.minorthird.classify.sequential.SegmentDataset;
import edu.cmu.minorthird.classify.sequential.Segmentation;
import edu.cmu.minorthird.classify.sequential.Segmenter;
import edu.cmu.minorthird.classify.sequential.SequenceConstants;
import edu.cmu.minorthird.util.ProgressCounter;
import edu.cmu.minorthird.util.gui.ComponentViewer;
import edu.cmu.minorthird.util.gui.SmartVanillaViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
import java.awt.BorderLayout;
import java.awt.Component;
import java.io.Serializable;
import java.text.DecimalFormat;
import java.util.Iterator;
import java.util.Map;
import java.util.TreeMap;
import javax.swing.JComponent;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import javax.swing.border.TitledBorder;
import org.apache.log4j.Logger;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
public class SegmentCollinsPerceptronLearner
implements BatchSegmenterLearner,
SequenceConstants {
    private static Logger log = Logger.getLogger(SegmentCollinsPerceptronLearner.class);
    private static final boolean DEBUG = log.isDebugEnabled();
    private int numberOfEpochs;
    private boolean updatedViterbi = false;

    /**
     * Creates a learner with the given epoch count; the alternate Viterbi
     * update flag keeps its default value of {@code false}.
     *
     * @param epochs value stored as the number of training epochs
     */
    public SegmentCollinsPerceptronLearner(int epochs) {
        numberOfEpochs = epochs;
    }

    /**
     * Creates a learner with an explicit epoch count and update mode.
     *
     * @param epochs value stored as the number of training epochs
     * @param updatedViterbi when {@code true}, training switches the classifier
     *        into vote mode before the first epoch
     */
    public SegmentCollinsPerceptronLearner(int epochs, boolean updatedViterbi) {
        this.numberOfEpochs = epochs;
        this.updatedViterbi = updatedViterbi;
    }

    /**
     * Creates a learner that uses the default of 5 training epochs.
     */
    public SegmentCollinsPerceptronLearner() {
        this(5);
    }

    /**
     * Intentionally does nothing: this learner ignores the supplied schema and
     * reads the schema from the dataset passed to {@code batchTrain} instead.
     *
     * @param schema ignored
     */
    @Override
    public void setSchema(ExampleSchema schema) {
    }

    /**
     * @return the configured upper bound on training passes over the dataset
     */
    public int getNumberOfEpochs() {
        return numberOfEpochs;
    }

    /**
     * Replaces the configured number of training epochs.
     *
     * @param newNumberOfEpochs the new epoch count; stored without validation
     */
    public void setNumberOfEpochs(int newNumberOfEpochs) {
        numberOfEpochs = newNumberOfEpochs;
    }

    /**
     * Returns the history size, which is the constant 1. This matches the
     * one-element {@code String[]} history arrays built elsewhere in this class.
     *
     * @return always 1
     */
    public int getHistorySize() {
        return 1;
    }

    /**
     * Trains a segmenter by repeatedly decoding each candidate segment group
     * with the current classifier and correcting it against the labelled
     * segmentation.
     *
     * <p>For every group in every epoch the method:
     * <ol>
     *   <li>computes the best segmentation under the current classifier,</li>
     *   <li>derives the correct segmentation from the group's labels,</li>
     *   <li>applies a -1.0 update for each predicted segment missing from the
     *       correct segmentation and a +1.0 update for each correct segment
     *       missing from the prediction,</li>
     *   <li>calls {@code completeUpdate()} on the classifier.</li>
     * </ol>
     * A one-line summary is printed to standard output after each epoch, and
     * training stops early after an epoch that produced no updates.
     *
     * @param dataset supplies the schema, the maximum window size and the
     *        candidate segment groups
     * @return a {@link ViterbiSegmenter} wrapping the trained classifier, which
     *         is switched into vote mode before being returned
     */
    @Override
    public Segmenter batchTrain(SegmentDataset dataset) {
        int maxSegmentSize = dataset.getMaxWindowSize();
        ExampleSchema schema = dataset.getSchema();
        if (DEBUG) {
            log.debug("schema: " + schema);
        }
        CollinsPerceptronLearner.MultiClassVPClassifier model = new CollinsPerceptronLearner.MultiClassVPClassifier(schema);
        int totalSteps = numberOfEpochs * dataset.getNumberOfSegmentGroups();
        ProgressCounter progress = new ProgressCounter("training semi-markov voted-perceptron", "sequence", totalSteps);
        if (updatedViterbi) {
            model.setVoteMode(true);
        }
        // The searcher only holds references to its arguments, so one instance serves every sequence.
        ViterbiSearcher decoder = new ViterbiSearcher(model, schema, maxSegmentSize);
        for (int epoch = 0; epoch < numberOfEpochs; ++epoch) {
            int sequenceErrors = 0;
            int transitionErrors = 0;
            int transitions = 0;
            for (Iterator<CandidateSegmentGroup> groups = dataset.candidateSegmentGroupIterator(); groups.hasNext(); ) {
                CandidateSegmentGroup group = groups.next();
                if (DEBUG) {
                    log.debug("classifier is: " + model);
                }
                Segmentation predicted = decoder.bestSegments(group);
                if (DEBUG) {
                    log.debug("viterbi:\n" + predicted);
                }
                Segmentation gold = correctSegments(group, schema, maxSegmentSize);
                if (DEBUG) {
                    log.debug("correct segments:\n" + gold);
                }
                // Demote segments that were predicted but are wrong, then promote the ones that were missed.
                int falsePositives = compareSegmentsAndRevise(model, schema, predicted, gold, -1.0, group);
                int falseNegatives = compareSegmentsAndRevise(model, schema, gold, predicted, 1.0, group);
                if (falsePositives > 0 || falseNegatives > 0) {
                    ++sequenceErrors;
                }
                transitionErrors += falsePositives + falseNegatives;
                transitions += gold.size();
                model.completeUpdate();
                progress.progress();
            }
            System.out.println("Epoch " + epoch + ": sequenceErr=" + sequenceErrors + " transitionErrors=" + transitionErrors + "/" + transitions);
            if (transitionErrors == 0) {
                // Nothing was updated during this epoch, so further passes would be identical.
                break;
            }
        }
        progress.finished();
        model.setVoteMode(true);
        return new ViterbiSegmenter(model, schema, maxSegmentSize);
    }

    /**
     * Updates the classifier for every segment of {@code segments} that has no
     * exact counterpart in {@code otherSegments}.
     *
     * <p>A segment counts as matched only when {@code otherSegments} contains
     * it and the class name of the segment preceding it is the same in both
     * segmentations. Segments whose {@code lo} is negative are skipped.
     *
     * @param classifier classifier receiving the updates
     * @param schema used to translate class indices into class names
     * @param segments segmentation whose unmatched segments trigger updates
     * @param otherSegments segmentation to compare against
     * @param delta amount passed to {@code classifier.update} for each
     *        unmatched segment
     * @param g group providing the example for each segment span
     * @return the number of unmatched segments, i.e. the number of updates made
     */
    private int compareSegmentsAndRevise(CollinsPerceptronLearner.MultiClassVPClassifier classifier, ExampleSchema schema, Segmentation segments, Segmentation otherSegments, double delta, CandidateSegmentGroup g) {
        Map<Segmentation.Segment, String> previousClassOf = previousClassMap(segments, schema);
        Map<Segmentation.Segment, String> otherPreviousClassOf = previousClassMap(otherSegments, schema);
        String[] history = new String[1];
        int mismatches = 0;
        for (Iterator<Segmentation.Segment> it = segments.iterator(); it.hasNext(); ) {
            Segmentation.Segment segment = it.next();
            String previousClass = previousClassOf.get(segment);
            if (segment.lo < 0) {
                continue;
            }
            if (otherSegments.contains(segment) && otherPreviousClassOf.get(segment).equals(previousClass)) {
                // Same span, same label and same predecessor class: nothing to correct.
                continue;
            }
            ++mismatches;
            history[0] = previousClass;
            InstanceFromSequence instance = new InstanceFromSequence(g.getSubsequenceExample(segment.lo, segment.hi), history);
            if (DEBUG) {
                log.debug("update " + delta + " for: " + instance.getSource());
            }
            classifier.update(schema.getClassName(segment.y), instance, delta);
        }
        return mismatches;
    }

    /**
     * Maps every segment to the class name of the segment that precedes it in
     * the segmentation's iteration order.
     *
     * @param segments segmentation to walk
     * @param schema used to translate class indices into class names
     * @return a sorted map from segment to predecessor class name; the first
     *         segment is mapped to the literal string {@code "null"}
     */
    private Map<Segmentation.Segment, String> previousClassMap(Segmentation segments, ExampleSchema schema) {
        TreeMap<Segmentation.Segment, String> predecessorClass = new TreeMap<Segmentation.Segment, String>();
        Segmentation.Segment previous = null;
        for (Iterator<Segmentation.Segment> it = segments.iterator(); it.hasNext(); ) {
            Segmentation.Segment current = it.next();
            String previousName;
            if (previous == null) {
                previousName = "null";
            } else {
                previousName = schema.getClassName(previous.y);
            }
            predecessorClass.put(current, previousName);
            previous = current;
        }
        return predecessorClass;
    }

    /**
     * Builds the reference segmentation of a group from its labels.
     *
     * <p>The sequence is scanned left to right. At each position the shortest
     * length in {@code 1..maxSegmentSize} whose candidate instance is non-null
     * and whose label is not negative becomes a segment with that label's best
     * class. If no such length exists, a single-position segment of class
     * {@code "NEG"} is emitted and the scan advances by one.
     *
     * @param g group supplying candidate instances and labels
     * @param schema used to translate class names into class indices
     * @param maxSegmentSize longest candidate length to try at each position
     * @return the segmentation implied by the labels
     */
    private Segmentation correctSegments(CandidateSegmentGroup g, ExampleSchema schema, int maxSegmentSize) {
        Segmentation result = new Segmentation(schema);
        int pos = 0;
        while (pos < g.getSequenceLength()) {
            boolean found = false;
            int len = 1;
            while (!found && len <= maxSegmentSize) {
                int end = pos + len;
                Instance candidate = g.getSubsequenceInstance(pos, end);
                ClassLabel label = g.getSubsequenceLabel(pos, end);
                if (candidate != null && !label.isNegative()) {
                    result.add(new Segmentation.Segment(pos, end, schema.getClassIndex(label.bestClassName())));
                    found = true;
                    pos = end;
                }
                ++len;
            }
            if (!found) {
                // No positively labelled candidate starts here: treat the position as background.
                result.add(new Segmentation.Segment(pos, pos + 1, schema.getClassIndex("NEG")));
                ++pos;
            }
        }
        return result;
    }

    /**
     * Prints the Viterbi score table and back-pointers to standard output as
     * tab-separated rows, one row per (t, y) cell.
     *
     * <p>Cells without a back-pointer are shown with a placeholder example
     * whose source is {@code "*NULL*"} and predecessor {@code -1.-1}. Cells
     * whose back-pointer is flagged {@code onBestPath} are marked with
     * {@code <==}.
     *
     * @param g group used to look up the example for each back-pointer span
     * @param fty score table indexed as {@code [t][y]}
     * @param trace back-pointer table with the same indexing as {@code fty}
     */
    private static void dumpStuff(CandidateSegmentGroup g, double[][] fty, BackPointer[][] trace) {
        Example placeholder = new Example(new MutableInstance("*NULL*"), new ClassLabel("*NULL*"));
        DecimalFormat scoreFormat = new DecimalFormat("####.###");
        System.out.println("t.y\tf(t,y)\tt'.y'\tspan");
        for (int t = 0; t < fty.length; ++t) {
            for (int y = 0; y < fty[t].length; ++y) {
                BackPointer pointer = trace[t][y];
                Example span;
                if (pointer == null) {
                    span = placeholder;
                    pointer = new BackPointer(-1, -1, -1);
                } else {
                    span = g.getSubsequenceExample(pointer.lastT, pointer.t);
                }
                String marker = pointer.onBestPath ? "<==" : "";
                StringBuilder row = new StringBuilder();
                row.append(t).append(".").append(y);
                row.append("\t").append(scoreFormat.format(fty[t][y]));
                row.append("\t").append(pointer.lastT).append(".").append(pointer.lastY);
                row.append("\t'").append(span.getSource()).append("' ").append(marker);
                System.out.println(row.toString());
            }
        }
    }

    /**
     * Segmenter that labels a candidate segment group by running a
     * {@link ViterbiSearcher} over a fixed classifier, schema and maximum
     * segment size.
     */
    public static class ViterbiSegmenter
    implements Segmenter,
    Visible,
    Serializable {
        private static final long serialVersionUID = 20080207L;
        private Classifier c;
        private ExampleSchema schema;
        private int maxSegSize;

        /**
         * @param c classifier used to score candidate segments
         * @param schema class schema handed to the searcher
         * @param maxSegSize maximum segment size handed to the searcher
         */
        public ViterbiSegmenter(Classifier c, ExampleSchema schema, int maxSegSize) {
            this.maxSegSize = maxSegSize;
            this.schema = schema;
            this.c = c;
        }

        /**
         * Decodes the best segmentation of {@code g} with a freshly created
         * searcher.
         *
         * @param g group to segment
         * @return the searcher's best segmentation
         */
        public Segmentation segmentation(CandidateSegmentGroup g) {
            ViterbiSearcher searcher = new ViterbiSearcher(c, schema, maxSegSize);
            return searcher.bestSegments(g);
        }

        /**
         * Placeholder: no explanation is produced.
         *
         * @param g ignored
         * @return the fixed string {@code "not implemented yet"}
         */
        public String explain(CandidateSegmentGroup g) {
            return "not implemented yet";
        }

        /**
         * Builds a viewer showing the maximum segment size above a view of
         * the wrapped classifier.
         *
         * @return a viewer whose content is this segmenter
         */
        public Viewer toGUI() {
            ComponentViewer viewer = new ComponentViewer(){
                static final long serialVersionUID = 20080207L;

                public JComponent componentFor(Object o) {
                    ViterbiSegmenter segmenter = (ViterbiSegmenter)o;
                    JPanel panel = new JPanel();
                    panel.setLayout(new BorderLayout());
                    JLabel header = new JLabel("ViterbiSegmenter: maxSegSize=" + segmenter.maxSegSize);
                    panel.add((Component)header, "North");
                    SmartVanillaViewer classifierView = new SmartVanillaViewer(segmenter.c);
                    // Register this viewer as the parent of the nested classifier view.
                    classifierView.setSuperView(this);
                    panel.add((Component)classifierView, "South");
                    panel.setBorder(new TitledBorder("ViterbiSegmenter"));
                    return new JScrollPane(panel);
                }
            };
            viewer.setContent(this);
            return viewer;
        }
    }

    /**
     * One cell of the Viterbi trace: records that the best segment ending at
     * position {@code t} started at {@code lastT} and followed a segment of
     * class index {@code lastY}.
     */
    private static class BackPointer {
        /** Start position of the segment this pointer describes. */
        public int lastT;
        /** End position of the segment this pointer describes. */
        public int t;
        /** Class index of the preceding segment. */
        public int lastY;
        /** Set to true by the searcher while it walks back along the winning path. */
        public boolean onBestPath = false;

        public BackPointer(int lastT, int t, int lastY) {
            this.t = t;
            this.lastT = lastT;
            this.lastY = lastY;
        }
    }

    /**
     * Semi-Markov Viterbi decoder: finds the highest-scoring way to cover a
     * candidate segment group with labelled segments, scoring each segment
     * with a classifier that sees the segment instance plus the class name of
     * the preceding segment.
     */
    public static class ViterbiSearcher {
        private Classifier classifier;
        private ExampleSchema schema;
        private int maxSegmentSize;

        /**
         * @param classifier scores a segment instance; the weight it assigns
         *        to a class name is used as that segment's score for the class
         * @param schema defines the class indices and names; the class named
         *        {@code "NEG"} is treated as background
         * @param maxSegmentSize longest segment considered for non-background
         *        classes; background segments always have length 1
         */
        public ViterbiSearcher(Classifier classifier, ExampleSchema schema, int maxSegmentSize) {
            this.classifier = classifier;
            this.schema = schema;
            this.maxSegmentSize = maxSegmentSize;
        }

        /**
         * Computes the best segmentation of {@code g}.
         *
         * <p>{@code fty[t][y]} holds the best score of any segmentation of
         * the first {@code t} positions whose last segment has class
         * {@code y}; {@code trace[t][y]} records the segment that achieved
         * it. Row 0 is the start state with score 0 for every class. All
         * other cells start at the sentinel -99999.0.
         *
         * <p>Each (lastT, t, lastY) combination is classified once and the
         * resulting label is consulted for every target class. Candidates are
         * still visited in ascending lastY, then ascending lastT order for
         * each cell, so ties are resolved exactly as if each class had been
         * scored separately.
         *
         * @param g group to segment
         * @return the best segmentation, or an empty segmentation when no end
         *         state exists (for example when the schema has no classes)
         */
        public Segmentation bestSegments(CandidateSegmentGroup g) {
            String[] history = new String[1];
            int seqLen = g.getSequenceLength();
            int ny = this.schema.getNumberOfClasses();
            int backgroundClass = this.schema.getClassIndex("NEG");

            // Resolve per-class constants once instead of inside the innermost loop.
            String[] classNames = new String[ny];
            int[] maxLengthForClass = new int[ny];
            int widestSegment = 0;
            for (int y = 0; y < ny; ++y) {
                classNames[y] = this.schema.getClassName(y);
                maxLengthForClass[y] = y == backgroundClass ? 1 : this.maxSegmentSize;
                widestSegment = Math.max(widestSegment, maxLengthForClass[y]);
            }

            double[][] fty = new double[seqLen + 1][ny];
            BackPointer[][] trace = new BackPointer[seqLen + 1][ny];
            // Row 0 keeps the array default of 0.0 (start state); every later cell starts at the sentinel.
            for (int t = 1; t < seqLen + 1; ++t) {
                for (int y = 0; y < ny; ++y) {
                    fty[t][y] = -99999.0;
                }
            }

            // t = 0 has no predecessor positions, so the recursion starts at t = 1.
            for (int t = 1; t < seqLen + 1; ++t) {
                for (int lastY = 0; lastY < ny; ++lastY) {
                    history[0] = classNames[lastY];
                    for (int lastT = Math.max(0, t - widestSegment); lastT < t; ++lastT) {
                        Instance segmentInstance = g.getSubsequenceInstance(lastT, t);
                        if (segmentInstance == null) {
                            continue;
                        }
                        // The classifier input depends only on (lastT, t, lastY), so classify once
                        // and read the weight of every target class from the same label.
                        // NOTE(review): assumes classification() is deterministic and side-effect free - confirm.
                        InstanceFromSequence seqSegmentInstance = new InstanceFromSequence(segmentInstance, history);
                        ClassLabel segmentLabel = this.classifier.classification(seqSegmentInstance);
                        int segmentLength = t - lastT;
                        double prefixScore = fty[lastT][lastY];
                        for (int y = 0; y < ny; ++y) {
                            if (segmentLength > maxLengthForClass[y]) {
                                continue;
                            }
                            double candidateScore = segmentLabel.getWeight(classNames[y]) + prefixScore;
                            if (candidateScore > fty[t][y]) {
                                fty[t][y] = candidateScore;
                                trace[t][y] = new BackPointer(lastT, t, lastY);
                            }
                        }
                    }
                }
            }

            int bestEndY = -1;
            double bestEndYScore = -Double.MAX_VALUE;
            for (int y = 0; y < ny; ++y) {
                if (fty[seqLen][y] > bestEndYScore) {
                    bestEndYScore = fty[seqLen][y];
                    bestEndY = y;
                }
            }

            Segmentation result = new Segmentation(this.schema);
            // bestEndY stays -1 when there is no class to end in; indexing trace with it would throw.
            if (bestEndY >= 0) {
                int y = bestEndY;
                BackPointer bp = trace[seqLen][y];
                while (bp != null) {
                    bp.onBestPath = true;
                    result.add(new Segmentation.Segment(bp.lastT, bp.t, y));
                    y = bp.lastY;
                    bp = trace[bp.lastT][bp.lastY];
                }
            }
            if (DEBUG) {
                SegmentCollinsPerceptronLearner.dumpStuff(g, fty, trace);
            }
            return result;
        }
    }
}

