/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify.sequential;

import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Explanation;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.sequential.HMM;
import edu.cmu.minorthird.classify.sequential.SequenceClassifier;
import edu.cmu.minorthird.classify.sequential.SequenceConstants;
import edu.cmu.minorthird.classify.sequential.SequenceDataset;
import edu.cmu.minorthird.classify.sequential.Viterbi;
import edu.cmu.minorthird.util.gui.ComponentViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Hashtable;
import java.util.Iterator;
import javax.swing.JComponent;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import javax.swing.border.TitledBorder;

/**
 * A {@link SequenceClassifier} backed by an {@link HMM} whose states are the classes of the
 * training dataset's schema.
 * <p>
 * For every position of a sequence, the observation token is obtained from the first numeric
 * feature of the instance, taking its last part ({@code getPart(size() - 1)}). The emission
 * dictionary is every such token seen during training plus the pseudo-token {@code "UNSEEN"}.
 */
public class MultiClassHMMClassifier
implements SequenceClassifier,
SequenceConstants,
Visible,
Serializable {
    static final long serialVersionUID = 20080207L;

    /** Schema of the training dataset; its class names are used as the HMM state names. */
    private final ExampleSchema schema;

    /** The current model; replaced by {@link #baumwelch(double)}. */
    public HMM hmmModel;

    /** Number of HMM states, equal to the number of classes in {@link #schema}. */
    private final int numStates;

    /** Number of distinct emission tokens, including the {@code "UNSEEN"} pseudo-token. */
    private final int numEmissions;

    /** State names; {@code state[i]} is the schema's class name at index {@code i}. */
    String[] state;

    /** numStates x numStates matrix handed to the initial HMM (presumably transitions; zero-filled). */
    double[][] aprob;

    /** numStates x numEmissions matrix handed to the initial HMM (presumably emissions; zero-filled). */
    double[][] eprob;

    /** One array of observation tokens per training sequence, in dataset order. */
    ArrayList<String[]> training_seq;

    /** Token dictionary: token -> number of occurrences in the training data, as a decimal string. */
    private final Hashtable<String, String> dict_tok;

    /**
     * Builds the token dictionary and training token sequences from {@code dataset} and creates
     * an initial HMM over the schema's classes with zero-filled probability matrices.
     *
     * @param dataset the training data; every example is assumed to carry at least one numeric
     *                feature (the first one is used as the observation source)
     */
    public MultiClassHMMClassifier(SequenceDataset dataset) {
        this.schema = dataset.getSchema();
        this.numStates = this.schema.getNumberOfClasses();
        this.state = new String[this.numStates];
        for (int i = 0; i < this.numStates; ++i) {
            this.state[i] = this.schema.getClassName(i);
        }
        this.dict_tok = new Hashtable<String, String>();
        this.training_seq = new ArrayList<String[]>();
        Iterator<Example[]> it = dataset.sequenceIterator();
        while (it.hasNext()) {
            Example[] sequence = it.next();
            String[] tok = new String[sequence.length];
            for (int j = 0; j < sequence.length; ++j) {
                // The observation is the last part of the example's first numeric feature.
                int size = sequence[j].numericFeatureIterator().next().size();
                String token = sequence[j].numericFeatureIterator().next().getPart(size - 1);
                tok[j] = token;
                String seen = this.dict_tok.get(token);
                int count = (seen == null) ? 1 : Integer.parseInt(seen) + 1;
                this.dict_tok.put(token, String.valueOf(count));
            }
            this.training_seq.add(tok);
        }
        // Reserve a dictionary entry for tokens outside the training vocabulary
        // (NOTE(review): this overwrites the count if "UNSEEN" occurred as a real token).
        this.dict_tok.put("UNSEEN", "1");
        this.numEmissions = this.dict_tok.size();
        this.aprob = new double[this.numStates][this.numStates];
        this.eprob = new double[this.numStates][this.numEmissions];
        this.hmmModel = new HMM(this.state, this.aprob, this.dict_tok, this.eprob);
    }

    /**
     * Re-estimates the model from the stored training sequences and replaces {@link #hmmModel}.
     * Each training sequence is first converted with the current model's
     * {@code convert_Ob_seq} before being passed to {@code HMM.baumwelch}.
     *
     * @param threshold forwarded unchanged to {@code HMM.baumwelch} (presumably a convergence
     *                  threshold — confirm against HMM)
     */
    public void baumwelch(double threshold) {
        ArrayList<String[]> training_data = new ArrayList<String[]>(this.training_seq.size());
        for (String[] tokens : this.training_seq) {
            training_data.add(this.hmmModel.convert_Ob_seq(tokens));
        }
        this.hmmModel = HMM.baumwelch(training_data, this.state, this.dict_tok, threshold);
    }

    /**
     * Labels a sequence by running Viterbi decoding over its observation tokens.
     *
     * @param sequence instances to label; each is assumed to carry at least one numeric feature
     * @return an array with one slot per input instance, filled from the Viterbi path
     *         (assumes the path is as long as {@code sequence})
     */
    public ClassLabel[] classification(Instance[] sequence) {
        String[] ob_seq = new String[sequence.length];
        for (int i = 0; i < sequence.length; ++i) {
            int size = sequence[i].numericFeatureIterator().next().size();
            ob_seq[i] = sequence[i].numericFeatureIterator().next().getPart(size - 1);
        }
        String[] seq = this.hmmModel.convert_Ob_seq(ob_seq);
        String[] tag_seq = new Viterbi(this.hmmModel, seq).getPath();
        ClassLabel[] label = new ClassLabel[sequence.length];
        for (int i = 0; i < tag_seq.length; ++i) {
            label[i] = new ClassLabel(tag_seq[i]);
        }
        return label;
    }

    /**
     * Returns a textual listing of the classes; the {@code instance} argument is ignored.
     * NOTE(review): the "Hyperplane" wording looks copied from a linear classifier and no
     * per-class details are emitted.
     *
     * @param instance ignored
     * @return one header line plus a blank line per class
     */
    public String explain(Instance[] instance) {
        StringBuilder buf = new StringBuilder();
        for (int i = 0; i < this.numStates; ++i) {
            buf.append("Hyperplane for class ").append(this.schema.getClassName(i)).append(":\n");
            buf.append("\n");
        }
        return buf.toString();
    }

    /**
     * Returns an explanation tree with one child node per class; {@code instance} is ignored.
     *
     * @param instance ignored
     * @return the explanation tree
     */
    public Explanation getExplanation(Instance[] instance) {
        Explanation.Node top = new Explanation.Node("MultiClassHMM Explanation");
        for (int i = 0; i < this.numStates; ++i) {
            top.add(new Explanation.Node("Hyperplane for class " + this.schema.getClassName(i) + ":\n"));
        }
        return new Explanation(top);
    }

    /**
     * Builds a viewer showing one titled (empty) panel per class of the viewed classifier.
     *
     * @return the viewer, with this classifier as its initial content
     */
    public Viewer toGUI() {
        ComponentViewer gui = new ComponentViewer(){
            static final long serialVersionUID = 20080207L;

            public JComponent componentFor(Object o) {
                MultiClassHMMClassifier c = (MultiClassHMMClassifier)o;
                JPanel main = new JPanel();
                // Iterate over the viewed classifier's classes, not the enclosing instance's,
                // so the loop bound always matches c.schema.
                for (int i = 0; i < c.numStates; ++i) {
                    JPanel classPanel = new JPanel();
                    classPanel.setBorder(new TitledBorder("Class " + c.schema.getClassName(i)));
                    main.add(classPanel);
                }
                return new JScrollPane(main);
            }
        };
        gui.setContent(this);
        return gui;
    }

    @Override
    public String toString() {
        return "[MultiClassHMMClassifier]";
    }
}

