/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.text.learn;

import com.wcohen.ss.BasicStringWrapper;
import com.wcohen.ss.DistanceLearnerFactory;
import com.wcohen.ss.api.StringDistance;
import com.wcohen.ss.api.StringDistanceLearner;
import com.wcohen.ss.api.StringWrapper;
import com.wcohen.ss.lookup.SoftDictionary;
import edu.cmu.minorthird.classify.BinaryClassifier;
import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.OnlineBinaryClassifierLearner;
import edu.cmu.minorthird.classify.algorithms.linear.VotedPerceptron;
import edu.cmu.minorthird.classify.sequential.InstanceFromSequence;
import edu.cmu.minorthird.text.AbstractAnnotator;
import edu.cmu.minorthird.text.Annotator;
import edu.cmu.minorthird.text.EmptyLabels;
import edu.cmu.minorthird.text.MonotonicTextLabels;
import edu.cmu.minorthird.text.Span;
import edu.cmu.minorthird.text.TextLabels;
import edu.cmu.minorthird.text.learn.AnnotationExample;
import edu.cmu.minorthird.text.learn.AnnotatorLearner;
import edu.cmu.minorthird.text.learn.ExtractorAnnotator;
import edu.cmu.minorthird.text.learn.SampleFE;
import edu.cmu.minorthird.text.learn.SpanFeatureExtractor;
import edu.cmu.minorthird.util.ProgressCounter;
import edu.cmu.minorthird.util.gui.ComponentViewer;
import edu.cmu.minorthird.util.gui.SmartVanillaViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.ViewerFrame;
import edu.cmu.minorthird.util.gui.Visible;
import java.awt.BorderLayout;
import java.awt.Component;
import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;
import javax.swing.JComponent;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import javax.swing.border.TitledBorder;
import org.apache.log4j.Logger;

public class ConditionalSemiMarkovModel {
    /** Logger shared by this class and its nested classes; never reassigned. */
    private static final Logger log = Logger.getLogger(ConditionalSemiMarkovModel.class);
    /** Sampled once when the class loads: later changes to the log level do not affect it. */
    private static final boolean DEBUG = log.isDebugEnabled();

    /**
     * Finds the highest-scoring segmentation of {@code documentSpan} with a semi-Markov Viterbi search.
     *
     * <p>A segmentation is a sequence of contiguous segments covering the document. Each segment carries a
     * label y in {0, 1}. Label-0 segments are exactly one token long and score 0. Label-1 segments are 1 to
     * {@code maxSegSize} tokens long and are scored by {@code classifier}, using features from {@code fe}
     * plus the label of the preceding segment.
     *
     * @param documentSpan the document to segment
     * @param labels labels handed to the feature extractor
     * @param fe feature extractor applied to each candidate label-1 segment
     * @param classifier scorer for candidate label-1 segments
     * @param maxSegSize maximum length, in tokens, of a label-1 segment
     * @return the label-1 segments on the best-scoring path
     */
    public static Segments bestSegments(Span documentSpan, TextLabels labels, SpanFeatureExtractor fe, BinaryClassifier classifier, int maxSegSize) {
        final int n = documentSpan.size();
        // fty[t][y]: best total score of a segmentation of the first t tokens whose last segment has label y.
        // trace[t][y]: the last segment of that segmentation and the (lastT, lastY) state it extends.
        double[][] fty = new double[n + 1][2];
        BackPointer[][] trace = new BackPointer[n + 1][2];
        for (int t = 0; t <= n; ++t) {
            for (int y = 0; y < 2; ++y) {
                // Use -infinity rather than a finite sentinel. Any real path, however negative its score,
                // then replaces an unreached state. So (with finite classifier scores) every reachable
                // state gets a back pointer, and the backtrace below cannot stop early.
                fty[t][y] = Double.NEGATIVE_INFINITY;
                trace[t][y] = null;
            }
        }
        // Either label may serve as the "previous label" at the start of the document.
        fty[0][1] = 0.0;
        fty[0][0] = 0.0;
        // t = 0 has no incoming segments, so the forward pass starts at t = 1.
        for (int t = 1; t <= n; ++t) {
            for (int y = 0; y < 2; ++y) {
                int maxSegSizeForY = y == 0 ? 1 : maxSegSize;
                for (int lastY = 0; lastY < 2; ++lastY) {
                    for (int lastT = Math.max(0, t - maxSegSizeForY); lastT < t; ++lastT) {
                        Span segment = documentSpan.subSpan(lastT, t - lastT);
                        double segmentScore = ConditionalSemiMarkovModel.score(labels, lastY, y, lastT, t, segment, fe, classifier);
                        double candidate = segmentScore + fty[lastT][lastY];
                        // strict '>' keeps the first-found best on ties
                        if (candidate > fty[t][y]) {
                            fty[t][y] = candidate;
                            trace[t][y] = new BackPointer(segment, lastT, lastY);
                        }
                    }
                }
            }
        }
        // Walk the back pointers from the better end state, collecting label-1 segments.
        int y = fty[n][1] > fty[n][0] ? 1 : 0;
        TreeSet<Span> result = new TreeSet<Span>();
        BackPointer bp = trace[n][y];
        while (bp != null) {
            bp.onBestPath = true;
            if (y == 1) {
                result.add(bp.span);
            }
            y = bp.lastY;
            bp = trace[bp.lastT][bp.lastY];
        }
        if (DEBUG) {
            ConditionalSemiMarkovModel.dumpStuff(fty, trace);
        }
        return new Segments(result);
    }

    /**
     * Prints the Viterbi tables to standard output, one line per (t, y) state.
     *
     * <p>Each line shows the state, its best score, the predecessor state and the text of the segment
     * ending there. States on the chosen path are marked with {@code <==}. A state with no back pointer
     * is shown with predecessor {@code -1.-1} and text {@code *NULL*}.
     *
     * @param fty best score per state
     * @param trace back pointer per state (may contain nulls)
     */
    private static void dumpStuff(double[][] fty, BackPointer[][] trace) {
        DecimalFormat fmt = new DecimalFormat("####.###");
        System.out.println("t.y\tf(t,y)\tt'.y'\tspan");
        for (int pos = 0; pos < fty.length; pos++) {
            for (int label = 0; label < 2; label++) {
                BackPointer ptr = trace[pos][label];
                String prevState;
                String text;
                String arrow;
                if (ptr == null) {
                    prevState = "-1.-1";
                    text = "*NULL*";
                    arrow = "";
                } else {
                    prevState = ptr.lastT + "." + ptr.lastY;
                    text = ptr.span.asString();
                    arrow = ptr.onBestPath ? "<==" : "";
                }
                System.out.println(pos + "." + label + "\t" + fmt.format(fty[pos][label]) + "\t" + prevState + "  '" + text + "' " + arrow);
            }
        }
    }

    /**
     * Scores one candidate segment ending at token {@code t}.
     *
     * <p>Label-0 segments always score 0. A label-1 segment is scored by {@code cls} on the features
     * {@code fe} extracts for it. Those features are augmented with the previous segment's label,
     * {@code "POS"} if {@code lastY == 1} and {@code "NEG"} otherwise.
     *
     * @param labels labels handed to the feature extractor
     * @param lastY label of the preceding segment
     * @param y label of this segment
     * @param lastT start token of this segment (unused; kept for the call signature)
     * @param t end token (exclusive) of this segment (unused; kept for the call signature)
     * @param segment the candidate segment
     * @param fe feature extractor
     * @param cls segment classifier
     * @return the segment's score
     */
    private static double score(TextLabels labels, int lastY, int y, int lastT, int t, Span segment, SpanFeatureExtractor fe, BinaryClassifier cls) {
        if (y == 0) {
            return 0.0;
        }
        String prevLabel = lastY == 1 ? "POS" : "NEG";
        InstanceFromSequence segmentInstance = new InstanceFromSequence(fe.extractInstance(labels, segment), new String[]{prevLabel});
        // Score once. Previously the classifier was invoked a second time just for the debug message.
        double segmentScore = cls.score(segmentInstance);
        if (DEBUG) {
            log.debug("score: " + segmentScore + "\t" + segment);
        }
        return segmentScore;
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    public static class Segments {
        /** The segment spans; the set supplied to the constructor is held by reference, not copied. */
        private final Set<Span> spanSet;

        /**
         * @param spanSet the spans making up this segmentation
         */
        public Segments(Set<Span> spanSet) {
            this.spanSet = spanSet;
        }

        /** @return an iterator over the segments, in the underlying set's iteration order */
        public Iterator<Span> iterator() {
            Set<Span> spans = this.spanSet;
            return spans.iterator();
        }

        /** @return true if {@code span} is one of the segments */
        public boolean contains(Span span) {
            Set<Span> spans = this.spanSet;
            return spans.contains(span);
        }

        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder("[Segments: ");
            sb.append(this.spanSet.toString());
            sb.append(']');
            return sb.toString();
        }
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    public static class CSMMWithDictionarySpanFE
    extends CSMMSpanFE {
        static final long serialVersionUID = 20080306L;
        boolean addTrainingSegsToDictionary;
        boolean useCrossVal;
        SoftDictionary dictionary;
        StringDistance[] distances;
        Feature[] features;

        public CSMMWithDictionarySpanFE(String dictionaryFile, String distanceNames) {
            this(dictionaryFile, distanceNames, false, false);
        }

        public CSMMWithDictionarySpanFE(String dictionaryFile, String distanceNames, boolean addTraining, boolean useCrossValArg) {
            try {
                this.addTrainingSegsToDictionary = addTraining;
                this.useCrossVal = useCrossValArg;
                this.dictionary = new SoftDictionary();
                this.distances = DistanceLearnerFactory.buildArray(distanceNames);
                if (dictionaryFile.length() > 0) {
                    this.dictionary.load(new File(dictionaryFile));
                    this.trainDistances();
                }
                this.features = new Feature[this.distances.length];
                for (int d = 0; d < this.distances.length; ++d) {
                    this.features[d] = new Feature(this.distances[d].toString());
                }
            }
            catch (IOException e) {
                e.printStackTrace();
            }
        }

        public void trainDistances() {
            for (int d = 0; d < this.distances.length; ++d) {
                if (!(this.distances[d] instanceof StringDistanceLearner)) continue;
                this.distances[d] = this.dictionary.getTeacher().train((StringDistanceLearner)((Object)this.distances[d]));
            }
        }

        public void train(Iterator<AnnotationExample> iter) {
            if (!this.addTrainingSegsToDictionary) {
                return;
            }
            int numAdded = 0;
            while (iter.hasNext()) {
                AnnotationExample example = iter.next();
                String id = example.getDocumentSpan().getDocumentId();
                String type = example.getInputType();
                Iterator<Span> i = example.getLabels().instanceIterator(type, id);
                while (i.hasNext()) {
                    String thisSeg = i.next().asString();
                    ++numAdded;
                    this.dictionary.put(id, thisSeg, null);
                }
            }
            this.trainDistances();
        }

        @Override
        public void extractFeatures(TextLabels labels, Span span) {
            super.extractFeatures(labels, span);
            BasicStringWrapper spanString = new BasicStringWrapper(span.asString());
            String id = this.addTrainingSegsToDictionary && this.useCrossVal ? span.getDocumentId() : null;
            Object closestMatch = this.dictionary.lookup(id, spanString);
            if (closestMatch != null) {
                for (int d = 0; d < this.distances.length; ++d) {
                    double score = this.distances[d].score(spanString, (StringWrapper)closestMatch);
                    if (score == 0.0) continue;
                    this.instance.addNumeric(this.features[d], score);
                }
            }
        }
    }

    /**
     * Span feature extractor used to score candidate segments of the CSMM.
     *
     * <p>On top of the inherited {@code ExtractionFE} features, it emits features for the whole span
     * (lower-cased text, optional character-type patterns, size), for its first token and for
     * {@code token(-1)} (presumably the last token), and token-property features for those tokens and
     * for the interior tokens.
     */
    public static class CSMMSpanFE
    extends SampleFE.ExtractionFE {
        static final long serialVersionUID = 20080306L;

        public CSMMSpanFE() {
        }

        public CSMMSpanFE(int windowSize) {
            super(windowSize);
        }

        /**
         * @param mixupFile annotation to require; it is loaded from {@code mixupFile + ".mixup"}
         */
        public CSMMSpanFE(String mixupFile) {
            this.setRequiredAnnotation(mixupFile, mixupFile + ".mixup");
            this.setTokenPropertyFeatures("*");
        }

        /** Extracts features for {@code span} using empty labels. */
        public void extractFeatures(Span span) {
            this.extractFeatures(new EmptyLabels(), span);
        }

        public void extractFeatures(TextLabels labels, Span span) {
            super.extractFeatures(labels, span);
            // whole-span features
            this.from(span).eq().lc().emit();
            if (this.useCharType) {
                this.from(span).eq().charTypes().emit();
            }
            if (this.useCompressedCharType) {
                this.from(span).eq().charTypePattern().emit();
            }
            this.from(span).size().emit();
            this.from(span).exactSize().emit();
            // boundary-token features
            this.from(span).token(0).eq().lc().emit();
            this.from(span).token(-1).eq().lc().emit();
            if (this.useCharType) {
                this.from(span).token(0).eq().charTypes().lc().emit();
                this.from(span).token(-1).eq().charTypes().lc().emit();
            }
            if (this.useCompressedCharType) {
                this.from(span).token(0).eq().charTypePattern().lc().emit();
                this.from(span).token(-1).eq().charTypePattern().lc().emit();
            }
            // token-property features; tokenPropertyFeatures is configured by the superclass
            for (int i = 0; i < this.tokenPropertyFeatures.length; ++i) {
                String p = this.tokenPropertyFeatures[i];
                this.from(span).token(0).prop(p).emit();
                this.from(span).token(-1).prop(p).emit();
                // Interior tokens are 1 .. size-2. For a single-token span, size - 2 would be a negative
                // length, so skip it. For two tokens the interior is empty, as before.
                if (span.size() >= 2) {
                    this.from(span).subSpan(1, span.size() - 2).tokens().prop(p).emit();
                }
            }
        }
    }

    /**
     * Back pointer stored for Viterbi state (t, y): the segment ending at token t and the state it extends.
     */
    private static class BackPointer {
        /** The segment covering tokens [lastT, t) that was chosen for this state. */
        public final Span span;
        /** Token position of the predecessor state. */
        public final int lastT;
        /** Label of the predecessor state. */
        public final int lastY;
        /** Set during backtracking when this pointer lies on the best path. */
        public boolean onBestPath = false;

        public BackPointer(Span span, int lastT, int lastY) {
            this.lastY = lastY;
            this.lastT = lastT;
            this.span = span;
        }
    }

    /**
     * Annotator that labels, in each document, the label-1 segments of the best CSMM segmentation
     * with {@code annotationType}.
     */
    public static class CSMMAnnotator
    extends AbstractAnnotator
    implements Visible,
    ExtractorAnnotator,
    Serializable {
        private static final long serialVersionUID = 20080306L;
        private SpanFeatureExtractor fe;
        private BinaryClassifier classifier;
        private String annotationType;
        private int maxSegSize;

        /** Builds a viewer showing the maximum segment size and the underlying classifier. */
        public Viewer toGUI() {
            ComponentViewer v = new ComponentViewer(){
                static final long serialVersionUID = 20080306L;

                public JComponent componentFor(Object o) {
                    CSMMAnnotator ann = (CSMMAnnotator)o;
                    JPanel mainPanel = new JPanel();
                    mainPanel.setLayout(new BorderLayout());
                    // Describe the annotator being viewed (o), consistently with the classifier sub-view,
                    // rather than the annotator that happened to create this viewer.
                    mainPanel.add(new JLabel("CSMM: segsize " + ann.maxSegSize), BorderLayout.NORTH);
                    SmartVanillaViewer subView = new SmartVanillaViewer(ann.classifier);
                    subView.setSuperView(this);
                    mainPanel.add(subView, BorderLayout.SOUTH);
                    mainPanel.setBorder(new TitledBorder("Conditional Semi-Markov-Model"));
                    return new JScrollPane(mainPanel);
                }
            };
            v.setContent(this);
            return v;
        }

        /**
         * @param fe feature extractor used to score segments
         * @param classifier segment classifier
         * @param annotationType type added to each predicted segment
         * @param maxSegSize maximum length, in tokens, of a predicted segment
         */
        public CSMMAnnotator(SpanFeatureExtractor fe, BinaryClassifier classifier, String annotationType, int maxSegSize) {
            this.fe = fe;
            this.classifier = classifier;
            this.annotationType = annotationType;
            this.maxSegSize = maxSegSize;
        }

        public String getSpanType() {
            return this.annotationType;
        }

        /** Runs the Viterbi search on every document and adds each predicted segment to {@code annotationType}. */
        public void doAnnotate(MonotonicTextLabels labels) {
            ProgressCounter pc = new ProgressCounter("annotating", "document", labels.getTextBase().size());
            Iterator<Span> i = labels.getTextBase().documentSpanIterator();
            while (i.hasNext()) {
                Span doc = i.next();
                Segments viterbi = ConditionalSemiMarkovModel.bestSegments(doc, labels, this.fe, this.classifier, this.maxSegSize);
                Iterator<Span> j = viterbi.iterator();
                while (j.hasNext()) {
                    labels.addToType(j.next(), this.annotationType);
                }
                pc.progress();
            }
            pc.finished();
        }

        public String explainAnnotation(TextLabels labels, Span documentSpan) {
            return "not implemented";
        }
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    public static class CSMMLearner
    extends AnnotatorLearner {
        private SpanFeatureExtractor fe;
        private OnlineBinaryClassifierLearner classifierLearner;
        private int epochs;
        private int maxSegmentSize = 5;
        private Iterator<Span> documentLooper;
        private List<AnnotationExample> exampleList;
        private String annotationType;

        /** CSMMSpanFE features, a VotedPerceptron, 5 epochs, segments of at most 5 tokens. */
        public CSMMLearner() {
            this(new CSMMSpanFE(), new VotedPerceptron(), 5, 5, "");
        }

        public CSMMLearner(int epochs) {
            this(new CSMMSpanFE(), new VotedPerceptron(), epochs, 5, "");
        }

        public CSMMLearner(int epochs, int maxSegmentSize) {
            this(new CSMMSpanFE(), new VotedPerceptron(), epochs, maxSegmentSize, "");
        }

        public CSMMLearner(String annotation) {
            this(new CSMMSpanFE(), new VotedPerceptron(), 5, 5, annotation);
        }

        public CSMMLearner(String dictionaryFile, String distanceNames, int maxSegmentSize) {
            this(dictionaryFile, distanceNames, 5, maxSegmentSize);
        }

        public CSMMLearner(String dictionaryFile, String distanceNames, int epoch, int maxSegmentSize) {
            this(dictionaryFile, distanceNames, epoch, maxSegmentSize, "");
        }

        public CSMMLearner(String dictionaryFile, String distanceNames, int epoch, int maxSegmentSize, String mixFile) {
            this(dictionaryFile, distanceNames, epoch, maxSegmentSize, false, mixFile);
        }

        /** Uses a {@link CSMMWithDictionarySpanFE} built from the dictionary arguments. */
        public CSMMLearner(String dictionaryFile, String distanceNames, int epochSize, int maxSegmentSize, boolean addTraining, boolean doCrossVal, String mixFile) {
            this(new CSMMWithDictionarySpanFE(dictionaryFile, distanceNames, addTraining, doCrossVal), new VotedPerceptron(), epochSize, maxSegmentSize, mixFile);
        }

        public CSMMLearner(String dictionaryFile, String distanceNames, int epoch, int maxSegmentSize, boolean addTraining, String mixFile) {
            this(dictionaryFile, distanceNames, epoch, maxSegmentSize, addTraining, true, mixFile);
        }

        /**
         * @param fe segment feature extractor; must be a {@link CSMMSpanFE} when {@code annotation} is non-empty
         * @param classifierLearner online learner trained on segment mistakes
         * @param epochs number of passes over the training examples
         * @param maxSegSz maximum length, in tokens, of a predicted segment
         * @param annotation required annotation, loaded from {@code annotation + ".mixup"}; null or empty for none
         */
        public CSMMLearner(SpanFeatureExtractor fe, OnlineBinaryClassifierLearner classifierLearner, int epochs, int maxSegSz, String annotation) {
            this.fe = fe;
            if (annotation != null && annotation.length() > 0) {
                System.out.println("Reading annotations");
                ((CSMMSpanFE)fe).setRequiredAnnotation(annotation, annotation + ".mixup");
                ((CSMMSpanFE)fe).setTokenPropertyFeatures("*");
            }
            this.classifierLearner = classifierLearner;
            this.epochs = epochs;
            this.maxSegmentSize = maxSegSz;
            this.reset();
        }

        public OnlineBinaryClassifierLearner getLearner() {
            return this.classifierLearner;
        }

        public void setLearner(OnlineBinaryClassifierLearner newLearner) {
            this.classifierLearner = newLearner;
        }

        public int getEpochs() {
            return this.epochs;
        }

        public void setEpochs(int newEpochs) {
            this.epochs = newEpochs;
        }

        public int getMaxSegmentSize() {
            return this.maxSegmentSize;
        }

        public void setMaxSegmentSize(int newMaxSize) {
            this.maxSegmentSize = newMaxSize;
        }

        @Override
        public SpanFeatureExtractor getSpanFeatureExtractor() {
            return this.fe;
        }

        @Override
        public void setSpanFeatureExtractor(SpanFeatureExtractor fe) {
            this.fe = fe;
        }

        /** Discards all collected training examples. */
        @Override
        public void reset() {
            this.exampleList = new ArrayList<AnnotationExample>();
        }

        @Override
        public void setDocumentPool(Iterator<Span> documentLooper) {
            this.documentLooper = documentLooper;
        }

        @Override
        public boolean hasNextQuery() {
            return this.documentLooper.hasNext();
        }

        @Override
        public Span nextQuery() {
            return this.documentLooper.next();
        }

        /** Stores an answered query as a training example. */
        @Override
        public void setAnswer(AnnotationExample answeredQuery) {
            this.exampleList.add(answeredQuery);
        }

        @Override
        public void setAnnotationType(String s) {
            this.annotationType = s;
        }

        @Override
        public String getAnnotationType() {
            return this.annotationType;
        }

        /**
         * Trains the classifier with mistake-driven updates and returns an annotator that uses it.
         *
         * <p>The classifier learner is reset first. A dictionary-based extractor is then given the examples
         * for its own training. Then, for {@code epochs} passes over the collected examples, each document
         * is segmented with the current classifier, and every wrong or missing segment is fed back as a
         * training example.
         */
        @Override
        public Annotator getAnnotator() {
            this.classifierLearner.reset();
            log.debug("processing " + this.exampleList.size() + " examples for " + this.epochs + " epochs");
            ProgressCounter pc = new ProgressCounter("training CSMM", "document", this.epochs * this.exampleList.size());
            // instanceof (rather than matching the class name) also covers subclasses and cannot fail the cast
            if (this.fe instanceof CSMMWithDictionarySpanFE) {
                ((CSMMWithDictionarySpanFE)this.fe).train(this.exampleList.iterator());
            }
            for (int i = 0; i < this.epochs; ++i) {
                for (AnnotationExample example : this.exampleList) {
                    this.learnFromDocument(example);
                    pc.progress();
                }
                if (DEBUG) {
                    new ViewerFrame("classifier after epoch " + i, new SmartVanillaViewer(this.classifierLearner.getBinaryClassifier()));
                }
            }
            // The counter covers all epochs, so it is finished once, after the last epoch.
            pc.finished();
            return new CSMMAnnotator(this.fe, this.classifierLearner.getBinaryClassifier(), this.annotationType, this.maxSegmentSize);
        }

        /** Segments one document with the current classifier and adds an example for every mistake. */
        private void learnFromDocument(AnnotationExample example) {
            Span doc = example.getDocumentSpan();
            if (DEBUG) {
                log.debug("updating from " + doc);
            }
            Segments viterbi = ConditionalSemiMarkovModel.bestSegments(doc, example.getLabels(), this.fe, this.classifierLearner.getBinaryClassifier(), this.maxSegmentSize);
            if (DEBUG) {
                log.debug("viterbi solution:\n" + viterbi);
            }
            Segments correct = this.correctSegments(example);
            if (DEBUG) {
                log.debug("correct spans:\n" + correct);
            }
            // predicted but not correct -> negative examples; correct but not predicted -> positive examples
            this.addMistakeExamples(example, viterbi, correct, -1.0, "false pos: ");
            this.addMistakeExamples(example, correct, viterbi, 1.0, "false neg: ");
        }

        /**
         * Adds an example labelled {@code numberLabel} for each span of {@code source} missing from
         * {@code other}. The previous span in {@code source} supplies the previous-label context.
         */
        private void addMistakeExamples(AnnotationExample example, Segments source, Segments other, double numberLabel, String debugPrefix) {
            Span previousSpan = null;
            Iterator<Span> k = source.iterator();
            while (k.hasNext()) {
                Span span = k.next();
                if (!other.contains(span)) {
                    if (DEBUG) {
                        log.debug(debugPrefix + span);
                    }
                    this.classifierLearner.addExample(this.exampleFor(example, span, previousSpan, numberLabel));
                }
                previousSpan = span;
            }
        }

        /**
         * Builds a training example for {@code span}. The previous label is {@code "POS"} when
         * {@code prevSpan} ends exactly where {@code span} begins, and {@code "NEG"} otherwise.
         */
        private Example exampleFor(AnnotationExample example, Span span, Span prevSpan, double numberLabel) {
            Instance instance = this.fe.extractInstance(example.getLabels(), span);
            String prevLabel = prevSpan != null && prevSpan.getRightBoundary().equals(span.getLeftBoundary()) ? "POS" : "NEG";
            InstanceFromSequence instanceFromSeq = new InstanceFromSequence(instance, new String[]{prevLabel});
            if (DEBUG) {
                log.debug("example for " + span + ": " + instanceFromSeq);
            }
            return new Example(instanceFromSeq, ClassLabel.binaryLabel(numberLabel));
        }

        /** Collects the example's labelled spans of its input type for its document. */
        private Segments correctSegments(AnnotationExample example) {
            TreeSet<Span> set = new TreeSet<Span>();
            String id = example.getDocumentSpan().getDocumentId();
            String type = example.getInputType();
            Iterator<Span> i = example.getLabels().instanceIterator(type, id);
            while (i.hasNext()) {
                set.add(i.next());
            }
            return new Segments(set);
        }
    }
}

