/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify.sequential;

import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Explanation;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.algorithms.linear.Hyperplane;
import edu.cmu.minorthird.classify.sequential.BatchSequenceClassifierLearner;
import edu.cmu.minorthird.classify.sequential.CMM;
import edu.cmu.minorthird.classify.sequential.SequenceClassifier;
import edu.cmu.minorthird.classify.sequential.SequenceConstants;
import edu.cmu.minorthird.classify.sequential.SequenceDataset;
import edu.cmu.minorthird.classify.sequential.SequenceUtils;
import edu.cmu.minorthird.util.ProgressCounter;
import edu.cmu.minorthird.util.gui.ComponentViewer;
import edu.cmu.minorthird.util.gui.SmartVanillaViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
import iitb.CRF.CRF;
import iitb.CRF.DataIter;
import iitb.CRF.DataSequence;
import iitb.Model.EdgeFeatures;
import iitb.Model.EdgeLinearHistFeatures;
import iitb.Model.FeatureGenImpl;
import iitb.Model.FeatureImpl;
import iitb.Model.FeatureTypes;
import iitb.Model.StartFeatures;
import java.awt.BorderLayout;
import java.awt.Component;
import java.io.Serializable;
import java.util.Iterator;
import java.util.Properties;
import java.util.StringTokenizer;
import javax.swing.JComponent;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import javax.swing.border.TitledBorder;

/**
 * Sequence learner that trains an IITB {@link CRF} on a minorthird
 * {@link SequenceDataset}, and that also serves as the resulting
 * {@link SequenceClassifier}.
 *
 * <p>CRF options are kept in two {@link Properties} objects. {@code defaults}
 * is populated by the no-arg constructor. {@code options} is the object passed
 * to the CRF. The no-arg constructor makes {@code options} the same object as
 * {@code defaults}. The argument-taking constructors instead make
 * {@code options} a child of {@code defaults}, so keys that are not set fall
 * back to the defaults.
 *
 * <p>Explanations and the GUI view are built by converting the learned CRF
 * weight vector into a {@link CMM} over one {@link Hyperplane} per class.
 */
public class CRFLearner
implements BatchSequenceClassifierLearner,
SequenceConstants,
SequenceClassifier,
Visible,
Serializable {
    private static final long serialVersionUID = 1L;
    /** Number of previous labels passed to the CRF as its history size (default 1). */
    int histsize = 1;
    ExampleSchema schema;
    CRF crfModel;
    /** Per-instance default options; {@link #options} falls back to these. */
    Properties defaults;
    /** Options handed to the {@link CRF} constructor in {@link #allocModel}. */
    Properties options;
    public String maxItersHelp = "Number of training iterations over the training set; default set to 100";
    public String useHighPrecisionArithmeticHelp = "Make the learner use high precision arithmetic.";
    FeatureGenImpl featureGen;
    /** Lazily built CMM view of {@link #crfWs}; cleared after every training run. */
    SequenceClassifier cmmClassifier = null;
    /** Weight vector returned by the most recent {@code crfModel.train} call. */
    double[] crfWs;

    /**
     * Creates a learner whose options are {@code modelGraph=naive},
     * {@code debugLvl=1} and {@code trainer=ll}.
     */
    public CRFLearner() {
        this.defaults = new Properties();
        this.defaults.setProperty("modelGraph", "naive");
        this.defaults.setProperty("debugLvl", "1");
        this.defaults.setProperty("trainer", "ll");
        this.options = this.defaults;
    }

    /**
     * Creates a learner with history size 1 and the given options.
     *
     * @param args space-separated {@code key value} pairs
     * @throws IllegalArgumentException if {@code args} has an odd number of tokens
     */
    public CRFLearner(String args) {
        this(args, 1);
    }

    /**
     * Creates a learner with the given history size and options.
     *
     * @param args space-separated {@code key value} pairs that override the defaults
     * @param histsize history size passed to the CRF
     * @throws IllegalArgumentException if {@code args} has an odd number of tokens
     */
    public CRFLearner(String args, int histsize) {
        this();
        this.histsize = histsize;
        this.options = new Properties(this.defaults);
        StringTokenizer argTok = new StringTokenizer(args, " ");
        while (argTok.hasMoreTokens()) {
            String key = argTok.nextToken();
            if (!argTok.hasMoreTokens()) {
                // Fail with a clear message instead of a bare NoSuchElementException.
                throw new IllegalArgumentException("CRF option '" + key + "' has no value in: " + args);
            }
            this.options.setProperty(key, argTok.nextToken());
        }
    }

    /**
     * Creates a learner from alternating key/value array elements.
     * A trailing unpaired element is ignored.
     *
     * @param args {@code {key1, value1, key2, value2, ...}}
     */
    public CRFLearner(String[] args) {
        this();
        this.options = new Properties(this.defaults);
        for (int i = 0; i < args.length - 1; i += 2) {
            this.options.setProperty(args[i], args[i + 1]);
        }
    }

    /** Sets the CRF {@code trainer} option to {@code "ll"}. */
    public void setLogSpaceOption() {
        this.options.setProperty("trainer", "ll");
    }

    /**
     * Removes the CRF {@code trainer} option.
     *
     * <p>The key is removed from both {@code options} and {@code defaults}.
     * Otherwise, when {@code options} is a child of {@code defaults}, the
     * default {@code "ll"} would still be visible through the fallback chain.
     */
    public void removeLogSpaceOption() {
        this.options.remove("trainer");
        this.defaults.remove("trainer");
    }

    public void setSchema(ExampleSchema schema) {
        this.schema = schema;
    }

    public ExampleSchema getSchema() {
        return this.schema;
    }

    public int getHistorySize() {
        return this.histsize;
    }

    /** Sets the {@code maxIters} option that is handed to the CRF. */
    public void setMaxIters(int newMaxIters) {
        this.options.setProperty("maxIters", Integer.toString(newMaxIters));
    }

    /**
     * Returns the {@code maxIters} option, or 100 when it is unset.
     *
     * @throws NumberFormatException if the stored value is not an integer
     */
    public int getMaxIters() {
        String maxIters = this.options.getProperty("maxIters");
        return maxIters != null ? Integer.parseInt(maxIters) : 100;
    }

    public String getMaxItersHelp() {
        return this.maxItersHelp;
    }

    /** Returns true when the effective {@code trainer} option is {@code "ll"}. */
    public boolean getUseHighPrecisionArithmetic() {
        return "ll".equals(this.options.getProperty("trainer"));
    }

    /** Adds ({@code true}) or removes ({@code false}) the {@code trainer=ll} option. */
    public void setUseHighPrecisionArithmetic(boolean newUseHighPrecisionArithmetic) {
        if (newUseHighPrecisionArithmetic) {
            this.setLogSpaceOption();
        } else {
            this.removeLogSpaceOption();
        }
    }

    public String getUseHighPrecisionArithmeticHelp() {
        return this.useHighPrecisionArithmeticHelp;
    }

    /**
     * Builds the feature generator and a fresh CRF from the current schema
     * and options.
     *
     * @return an iterator over {@code dataset} in the form the CRF consumes
     */
    DataIter allocModel(SequenceDataset dataset) throws Exception {
        this.featureGen = new MTFeatureGenImpl(this.options.getProperty("modelGraph"), this.schema.getNumberOfClasses(), this.schema.validClassNames());
        System.out.println("Property: " + this.options.getProperty("trainer"));
        this.crfModel = new CRF(this.featureGen.numStates(), this.histsize, this.featureGen, this.options);
        return new CRFDataIter(dataset);
    }

    /**
     * Trains a CRF on {@code dataset} and returns the learned weights as a CMM.
     *
     * @throws IllegalStateException wrapping any exception raised during training
     */
    public SequenceClassifier batchTrain(SequenceDataset dataset) {
        try {
            this.schema = dataset.getSchema();
            return this.doTrain(this.allocModel(dataset));
        }
        catch (Exception e) {
            throw new IllegalStateException("error in CRF: " + e, e);
        }
    }

    /** Trains the feature generator and the CRF, then converts the weights to a CMM. */
    SequenceClassifier doTrain(DataIter trainData) throws Exception {
        this.featureGen.train(trainData);
        ProgressCounter pc = new ProgressCounter("training CRF", "iteration");
        try {
            this.crfWs = this.crfModel.train(trainData);
        } finally {
            pc.finished();
        }
        // The cached explanation classifier was built from the previous weights.
        this.cmmClassifier = null;
        return this.toMinorthirdClassifier();
    }

    /**
     * Converts {@link #crfWs} into a {@link CMM}. Each CRF weight is added to
     * the hyperplane of the state it is attached to, keyed by its feature.
     *
     * @throws IllegalStateException if the learner has not been trained
     */
    private SequenceClassifier toMinorthirdClassifier() {
        if (this.crfWs == null) {
            throw new IllegalStateException("CRF has not been trained");
        }
        int numClasses = this.schema.getNumberOfClasses();
        Hyperplane[] w_t = new Hyperplane[numClasses];
        for (int i = 0; i < numClasses; ++i) {
            w_t[i] = new Hyperplane();
            w_t[i].setBias(0.0);
        }
        for (int fIndex = 0; fIndex < this.crfWs.length; ++fIndex) {
            // NOTE(review): assumes every feature identifier name is a minorthird Feature
            // and every stateId is a valid class index, as produced by MTFeatureGenImpl.
            Feature feature = (Feature)this.featureGen.featureIdentifier(fIndex).name;
            int classIndex = this.featureGen.featureIdentifier(fIndex).stateId;
            w_t[classIndex].increment(feature, this.crfWs[fIndex]);
        }
        return new CMM(new SequenceUtils.MultiClassClassifier(this.schema, w_t), this.histsize, this.schema);
    }

    /** Returns the CMM view of the current weights, building it on first use. */
    private SequenceClassifier cachedCmmClassifier() {
        if (this.cmmClassifier == null) {
            this.cmmClassifier = this.toMinorthirdClassifier();
        }
        return this.cmmClassifier;
    }

    /**
     * Labels {@code sequence} with the trained CRF.
     *
     * @throws IllegalStateException if the learner has not been trained
     */
    public ClassLabel[] classification(Instance[] sequence) {
        if (this.crfModel == null || this.featureGen == null) {
            throw new IllegalStateException("CRF has not been trained");
        }
        TestDataSequenceC seq = new TestDataSequenceC(sequence);
        this.crfModel.apply(seq);
        this.featureGen.mapStatesToLabels(seq);
        return seq.getLabels();
    }

    /** Explains {@code sequence} using the CMM view of the learned weights. */
    public String explain(Instance[] sequence) {
        return this.cachedCmmClassifier().explain(sequence);
    }

    /**
     * Returns an explanation tree rooted at "CRF Explanation". It wraps the
     * CMM's explanation, or the CMM's text explanation when the CMM gives no
     * top node.
     */
    public Explanation getExplanation(Instance[] sequence) {
        SequenceClassifier cmm = this.cachedCmmClassifier();
        Explanation.Node top = new Explanation.Node("CRF Explanation");
        Explanation.Node cmmEx = cmm.getExplanation(sequence).getTopNode();
        if (cmmEx == null) {
            cmmEx = new Explanation.Node(cmm.explain(sequence));
        }
        top.add(cmmEx);
        return new Explanation(top);
    }

    /** Returns a viewer that shows the history size and the CMM view of the weights. */
    public Viewer toGUI() {
        ComponentViewer v = new ComponentViewer(){
            static final long serialVersionUID = 20080207L;

            public JComponent componentFor(Object o) {
                JPanel mainPanel = new JPanel();
                mainPanel.setLayout(new BorderLayout());
                mainPanel.add((Component)new JLabel("CRFLearner: historySize=" + CRFLearner.this.histsize), "North");
                SmartVanillaViewer subView = new SmartVanillaViewer(CRFLearner.this.toMinorthirdClassifier());
                subView.setSuperView(this);
                mainPanel.add((Component)subView, "South");
                mainPanel.setBorder(new TitledBorder("CRFLearner"));
                return new JScrollPane(mainPanel);
            }
        };
        v.setContent(this);
        return v;
    }

    /**
     * Feature generator that adds four kinds of features:
     * <ul>
     *   <li>edge features named {@code previousLabel 1 <label>};</li>
     *   <li>a start feature named {@code previousLabel 1 null};</li>
     *   <li>when {@code histsize > 1}, linear history features named
     *       {@code previousLabel k+1 <label>};</li>
     *   <li>the per-token minorthird features from {@link MTFeatureTypes}.</li>
     * </ul>
     */
    public class MTFeatureGenImpl
    extends FeatureGenImpl {
        static final long serialVersionUID = 20080207L;

        public MTFeatureGenImpl(String modelSpecs, int numLabels, String[] labelNames) throws Exception {
            super(modelSpecs, numLabels, false);
            Object[] features = new Feature[labelNames.length];
            for (int i = 0; i < labelNames.length; ++i) {
                features[i] = new Feature(new String[]{"previousLabel", "1", labelNames[i]});
            }
            this.addFeature(new EdgeFeatures(this, features));
            this.addFeature(new StartFeatures(this, new Feature(new String[]{"previousLabel", "1", "null"})));
            if (CRFLearner.this.histsize > 1) {
                // Row 0 is left unfilled; rows 1..histsize-1 name history offsets 2..histsize.
                Object[][] histFeatures = new Feature[CRFLearner.this.histsize][labelNames.length];
                for (int k = 1; k < CRFLearner.this.histsize; ++k) {
                    for (int i = 0; i < labelNames.length; ++i) {
                        histFeatures[k][i] = new Feature(new String[]{"previousLabel", Integer.toString(k + 1), labelNames[i]});
                    }
                }
                this.addFeature(new EdgeLinearHistFeatures(this, histFeatures, CRFLearner.this.histsize));
            }
            this.addFeature(new MTFeatureTypes(this));
        }
    }

    /**
     * Emits one CRF feature for every (token feature, state) pair at a
     * position. Each has {@code ystart = -1} and the token feature's weight
     * as its value.
     */
    class MTFeatureTypes
    extends FeatureTypes {
        static final long serialVersionUID = 20080207L;
        Iterator<Feature> featureLooper;
        Feature feature;
        int numStates;
        Instance example;
        int stateId;

        MTFeatureTypes(FeatureGenImpl gen) {
            super(gen);
            this.numStates = this.model.numStates();
        }

        /** Moves to the next state, or to state 0 of the next token feature. */
        void advance() {
            ++this.stateId;
            if (this.stateId < this.numStates) {
                return;
            }
            if (this.featureLooper.hasNext()) {
                this.feature = this.featureLooper.next();
                this.stateId = 0;
            } else {
                this.feature = null;
                this.featureLooper = null;
            }
        }

        /** Positions the scan on the first (feature, state) pair; false if there are no features. */
        boolean startScan() {
            this.stateId = -1;
            if (!this.featureLooper.hasNext()) {
                this.feature = null;
                return false;
            }
            this.feature = this.featureLooper.next();
            this.advance();
            return true;
        }

        public boolean startScanFeaturesAt(DataSequence data, int prevPos, int pos) {
            this.example = (Instance)data.x(pos);
            this.featureLooper = this.example.featureIterator();
            return this.startScan();
        }

        public boolean hasNext() {
            return this.stateId < this.numStates && this.feature != null;
        }

        public void next(FeatureImpl f) {
            f.yend = this.stateId;
            f.ystart = -1;
            f.val = (float)this.example.getWeight(this.feature);
            // Identifier combines feature ID and state; presumably unique if feature IDs are non-negative.
            this.setFeatureIdentifier(this.feature.getID() * this.numStates + this.stateId, this.stateId, this.feature, f);
            this.advance();
        }
    }

    /** Adapts a {@link SequenceDataset} to the CRF's {@link DataIter}, reusing one sequence object. */
    class CRFDataIter
    implements DataIter {
        Iterator<Example[]> iter;
        SequenceDataset dataset;
        TrainDataSequenceC sequence;
        int dataSize;

        CRFDataIter(SequenceDataset ds) {
            this.dataset = ds;
            this.dataSize = ds.size();
            this.sequence = new TrainDataSequenceC();
        }

        public void startScan() {
            this.iter = this.dataset.sequenceIterator();
        }

        public boolean hasNext() {
            return this.iter.hasNext();
        }

        /** Returns the shared sequence object, re-initialised with the next example sequence. */
        public DataSequence next() {
            this.sequence.init(this.iter.next());
            return this.sequence;
        }
    }

    /** Unlabeled sequence whose labels are filled in by the CRF. */
    class TestDataSequenceC
    extends DataSequenceC {
        TestDataSequenceC(Instance[] tokens) {
            this.init(tokens);
        }

        /** Maps each label index back to a class name in the schema. */
        ClassLabel[] getLabels() {
            ClassLabel[] clabels = new ClassLabel[this.sequence.length];
            for (int i = 0; i < this.sequence.length; ++i) {
                clabels[i] = new ClassLabel(CRFLearner.this.schema.getClassName(this.labels[i]));
            }
            return clabels;
        }
    }

    /** Training sequence whose labels come from each example's best class. */
    class TrainDataSequenceC
    extends DataSequenceC {
        TrainDataSequenceC() {
        }

        void init(Example[] tokens) {
            super.init(tokens);
            if (tokens != null) {
                for (int i = 0; i < this.sequence.length; ++i) {
                    this.labels[i] = CRFLearner.this.schema.getClassIndex(tokens[i].getLabel().bestClassName());
                }
            }
        }
    }

    /**
     * {@link DataSequence} backed by an {@link Instance} array. The label
     * array is reused, and it is reallocated only when a longer sequence
     * arrives.
     */
    class DataSequenceC
    implements DataSequence {
        Instance[] sequence;
        int[] labels;

        DataSequenceC() {
        }

        void init(Instance[] tokens) {
            this.sequence = tokens;
            if (tokens != null && (this.labels == null || tokens.length > this.labels.length)) {
                this.labels = new int[tokens.length];
            }
        }

        public int length() {
            return this.sequence.length;
        }

        public int y(int i) {
            return this.labels[i];
        }

        public Object x(int i) {
            return this.sequence[i];
        }

        public void set_y(int i, int label) {
            this.labels[i] = label;
        }
    }
}

