/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify.semisupervised;

import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.Explanation;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.WeightedSet;
import edu.cmu.minorthird.classify.semisupervised.SemiSupervisedClassifier;
import edu.cmu.minorthird.util.MathUtil;
import edu.cmu.minorthird.util.gui.ComponentViewer;
import edu.cmu.minorthird.util.gui.Controllable;
import edu.cmu.minorthird.util.gui.ControlledViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.ViewerControls;
import edu.cmu.minorthird.util.gui.Visible;
import gnu.trove.TObjectDoubleHashMap;
import gnu.trove.TObjectDoubleIterator;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.swing.ButtonGroup;
import javax.swing.JComponent;
import javax.swing.JLabel;
import javax.swing.JRadioButton;
import javax.swing.JScrollPane;
import javax.swing.JTable;
import org.apache.log4j.Logger;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Semi-supervised classifier whose score for a class is the log of that
 * class's prior parameter plus a sum of per-feature log terms.
 *
 * <p>Each feature is scored under a named model registered through
 * {@link #setFeatureModel(Feature, String)}: {@code "Poisson"},
 * {@code "Binomial"}, or {@code "unseen"} (the default for features with no
 * registered model, contributing nothing to the score). Class indices are
 * positions in the label list built by {@link #addValidLabel(ClassLabel)};
 * callers are expected to keep class and feature parameters aligned with
 * those indices.
 */
public class MultinomialClassifier
implements SemiSupervisedClassifier,
Classifier,
Visible,
Serializable {
    static final long serialVersionUID = 20080207L;
    static final Logger log = Logger.getLogger(MultinomialClassifier.class);

    /** Model name for features scored with the Poisson term. */
    private static final String POISSON = "Poisson";
    /** Model name for features scored with the Binomial term. */
    private static final String BINOMIAL = "Binomial";
    /** Model name for features with no registered model; they contribute 0 to scores. */
    private static final String UNSEEN = "unseen";

    // NOTE: the instance fields below form the serialized state of this class
    // (serialVersionUID is fixed), so their names and types must not change.

    /** Divisor of the Poisson term in {@link #score}; 0.0 until {@link #setScale} is called. */
    private double SCALE;
    /** Class names; a class's position in this list is its class index. */
    private List<String> classNames = new ArrayList<String>();
    /** Per-class prior parameter, indexed by class index; {@link #score} takes its log. */
    private List<Double> classParameters = new ArrayList<Double>();
    /** Model name for each feature that has one registered. */
    private Map<Feature, String> featureModels = new HashMap<Feature, String>();
    /** Per-class feature parameters, indexed by class index. */
    private List<WeightedSet<Feature>> featureGivenClassParameters = new ArrayList<WeightedSet<Feature>>();
    private double featurePrior;
    private String unseenModel;

    /** Creates a classifier with one empty feature-parameter set (for class index 0). */
    public MultinomialClassifier() {
        this.featureGivenClassParameters.add(new WeightedSet<Feature>());
        this.featurePrior = 0.0;
        this.unseenModel = null;
    }

    /**
     * Returns the label of the highest-scoring class (the first one on ties).
     *
     * @throws IllegalStateException if no class labels have been registered
     */
    @Override
    public ClassLabel classification(Instance instance) {
        double[] score = this.score(instance);
        if (score.length == 0) {
            throw new IllegalStateException("MultinomialClassifier has no class labels");
        }
        int maxIndex = 0;
        for (int i = 1; i < score.length; ++i) {
            if (score[i] > score[maxIndex]) {
                maxIndex = i;
            }
        }
        return new ClassLabel(this.classNames.get(maxIndex));
    }

    /**
     * Computes one log-score per registered class, indexed by class index.
     *
     * <p>For class {@code i}: {@code log(classParameter[i])} plus, for each
     * feature {@code f} with weight {@code c} and class parameter {@code p}:
     * <ul>
     *   <li>Poisson: {@code -p * total / SCALE + c * log(p)}, where
     *       {@code total} is the sum of all feature weights of the instance;</li>
     *   <li>Binomial: {@code c * log(p)};</li>
     *   <li>unseen: 0.</li>
     * </ul>
     * If {@link #setScale} was never called, SCALE is 0.0 and the Poisson term
     * divides by zero.
     *
     * @throws IllegalStateException if a feature is mapped to an unknown model name
     */
    public double[] score(Instance instance) {
        double total = 0.0;
        Iterator<Feature> j = instance.featureIterator();
        while (j.hasNext()) {
            total += instance.getWeight(j.next());
        }
        int numClasses = this.classNames.size();
        double[] score = new double[numClasses];
        for (int i = 0; i < numClasses; ++i) {
            WeightedSet<Feature> params = this.featureGivenClassParameters.get(i);
            double s = 0.0;
            Iterator<Feature> j2 = instance.featureIterator();
            while (j2.hasNext()) {
                Feature f = j2.next();
                double featureCounts = instance.getWeight(f);
                double featureProb = params.getWeight(f);
                String model = this.getFeatureModel(f);
                if (POISSON.equals(model)) {
                    s += -featureProb * total / this.SCALE + featureCounts * Math.log(featureProb);
                } else if (BINOMIAL.equals(model)) {
                    s += featureCounts * Math.log(featureProb);
                } else if (!UNSEEN.equals(model)) {
                    throw unknownModel(model);
                }
            }
            score[i] = s + Math.log(this.classParameters.get(i));
        }
        return score;
    }

    /**
     * Returns a plain-text explanation: one line per feature with its weight
     * in the instance, followed by the per-class score vector.
     */
    @Override
    public String explain(Instance instance) {
        StringBuilder buf = new StringBuilder();
        Iterator<Feature> j = instance.featureIterator();
        while (j.hasNext()) {
            Feature f = j.next();
            buf.append(buf.length() > 0 ? "\n + " : "   ");
            buf.append(f).append("<").append(instance.getWeight(f)).append(">");
        }
        buf.append("\n = ").append(Arrays.toString(this.score(instance)));
        return buf.toString();
    }

    /** Returns a tree explanation listing the instance's features and the per-class scores. */
    @Override
    public Explanation getExplanation(Instance instance) {
        Explanation.Node top = new Explanation.Node("MultinomialClassifier Explanation");
        Explanation.Node features = new Explanation.Node("Features");
        Iterator<Feature> j = instance.featureIterator();
        while (j.hasNext()) {
            Feature f = j.next();
            features.add(new Explanation.Node(f + "<" + instance.getWeight(f) + ">"));
        }
        features.add(new Explanation.Node("bias"));
        top.add(features);
        top.add(new Explanation.Node("\n = " + Arrays.toString(this.score(instance))));
        return new Explanation(top);
    }

    /** Sets the divisor used by the Poisson term in {@link #score}. */
    public void setScale(double value) {
        this.SCALE = value;
    }

    public double getPrior() {
        return this.featurePrior;
    }

    public void setPrior(double pi) {
        this.featurePrior = pi;
    }

    public String getUnseenModel() {
        return this.unseenModel;
    }

    public void setUnseenModel(String str) {
        this.unseenModel = str;
    }

    /**
     * Returns the log-likelihood of the example's features under the
     * parameters of the example's own class. Class priors are not included.
     *
     * <p>NOTE(review): the Poisson term here is {@code -p + c * log(p)}, while
     * {@link #score} uses {@code -p * total / SCALE + c * log(p)}; confirm which
     * one is intended.
     *
     * @throws IllegalArgumentException if the example's label is not registered
     * @throws IllegalStateException if a feature is mapped to an unknown model name
     */
    public double getLogLikelihood(Example example) {
        String label = example.getLabel().bestClassName();
        int idx = this.classNames.indexOf(label);
        if (idx < 0) {
            throw new IllegalArgumentException("unknown class label: " + label);
        }
        WeightedSet<Feature> params = this.featureGivenClassParameters.get(idx);
        Instance instance = example.asInstance();
        double loglik = 0.0;
        Iterator<Feature> j = instance.featureIterator();
        while (j.hasNext()) {
            Feature f = j.next();
            double featureCounts = instance.getWeight(f);
            double featureProb = params.getWeight(f);
            String model = this.getFeatureModel(f);
            if (POISSON.equals(model)) {
                loglik += -featureProb + featureCounts * Math.log(featureProb);
            } else if (BINOMIAL.equals(model)) {
                loglik += featureCounts * Math.log(featureProb);
            } else if (UNSEEN.equals(model)) {
                log.debug("unseen: " + f);
            } else {
                throw unknownModel(model);
            }
        }
        return loglik;
    }

    /** Discards all class and feature parameters; labels and feature models are kept. */
    public void reset() {
        this.classParameters = new ArrayList<Double>();
        this.featureGivenClassParameters = new ArrayList<WeightedSet<Feature>>();
    }

    /** Returns true if the label's best class name has been registered. */
    public boolean isPresent(ClassLabel label) {
        return this.classNames.contains(label.bestClassName());
    }

    /** Registers the label's best class name as the next class index. */
    public void addValidLabel(ClassLabel label) {
        this.classNames.add(label.bestClassName());
    }

    public ClassLabel getLabel(int i) {
        return new ClassLabel(this.classNames.get(i));
    }

    /** Returns the class index of the label, or -1 if it is not registered. */
    public int indexOf(ClassLabel label) {
        return this.classNames.indexOf(label.bestClassName());
    }

    /**
     * Adds {@code probabilityOfOccurrence} for feature {@code f} to the
     * parameter set of class {@code j}. Empty parameter sets are created as
     * needed, so {@code j} may be at or beyond the current number of sets.
     *
     * @throws IndexOutOfBoundsException if {@code j} is negative
     */
    public void setFeatureGivenClassParameter(Feature f, int j, double probabilityOfOccurrence) {
        if (j < 0) {
            throw new IndexOutOfBoundsException("class index: " + j);
        }
        while (this.featureGivenClassParameters.size() <= j) {
            this.featureGivenClassParameters.add(new WeightedSet<Feature>());
        }
        this.featureGivenClassParameters.get(j).add(f, probabilityOfOccurrence);
    }

    /**
     * Sets the prior parameter of class {@code j}, replacing any existing
     * value at that index. When {@code j} equals the current number of class
     * parameters, the value is appended.
     *
     * @throws IndexOutOfBoundsException if {@code j} is negative or greater
     *         than the current number of class parameters
     */
    public void setClassParameter(int j, double probabilityOfOccurrence) {
        Double value = Double.valueOf(probabilityOfOccurrence);
        if (j < this.classParameters.size()) {
            this.classParameters.set(j, value);
        } else {
            this.classParameters.add(j, value);
        }
    }

    /** Registers the model name ("Poisson", "Binomial", ...) used to score {@code feature}. */
    public void setFeatureModel(Feature feature, String model) {
        this.featureModels.put(feature, model);
    }

    /** Returns the feature's registered model name, or "unseen" if none is registered. */
    public String getFeatureModel(Feature feature) {
        String model = this.featureModels.get(feature);
        return model != null ? model : UNSEEN;
    }

    /**
     * Iterates over the distinct features that have a parameter in any
     * registered class. The iterator works on a snapshot, so
     * {@code remove()} does not modify this classifier.
     */
    public Iterator<Feature> featureIterator() {
        return this.collectFeatures().iterator();
    }

    /** Returns the distinct features that have a parameter in any registered class. */
    public Object[] keys() {
        return this.collectFeatures().toArray();
    }

    /** Collects a fresh set of the features present in any registered class's parameter set. */
    private Set<Feature> collectFeatures() {
        Set<Feature> features = new LinkedHashSet<Feature>();
        for (int i = 0; i < this.classNames.size(); ++i) {
            Iterator<Feature> j = this.featureGivenClassParameters.get(i).iterator();
            while (j.hasNext()) {
                features.add(j.next());
            }
        }
        return features;
    }

    /** Builds the error raised when a feature is mapped to an unsupported model name. */
    private static IllegalStateException unknownModel(String model) {
        return new IllegalStateException("error: model " + model + " not found!");
    }

    @Override
    public Viewer toGUI() {
        ControlledViewer gui = new ControlledViewer(new MyViewer(), new MultinomialClassifierControls());
        gui.setContent(this);
        return gui;
    }

    @Override
    public String toString() {
        return "[MultinomialClassifier: classes=" + this.classNames + "]";
    }

    /** Table view of per-class feature weights, optionally sorted via the controls. */
    private static class MyViewer
    extends ComponentViewer
    implements Controllable {
        static final long serialVersionUID = 20080207L;
        private MultinomialClassifierControls controls = null;
        private MultinomialClassifier h = null;

        private MyViewer() {
        }

        public void applyControls(ViewerControls controls) {
            this.controls = (MultinomialClassifierControls)controls;
            this.setContent(this.h, true);
            this.revalidate();
        }

        public boolean canReceive(Object o) {
            return o instanceof MultinomialClassifier;
        }

        /**
         * Builds a table with one row per feature. Column 0 holds the feature
         * and column {@code l + 1} holds its weight for class {@code l}.
         */
        public JComponent componentFor(Object o) {
            this.h = (MultinomialClassifier)o;
            final int numClasses = this.h.classNames.size();
            Object[] keys = this.h.keys();
            Object[][] tableData = new Object[keys.length][numClasses + 1];
            for (int k = 0; k < keys.length; ++k) {
                Feature f = (Feature)keys[k];
                tableData[k][0] = f;
                for (int l = 0; l < numClasses; ++l) {
                    tableData[k][l + 1] = Double.valueOf(this.h.featureGivenClassParameters.get(l).getWeight(f));
                }
            }
            if (this.controls != null) {
                Arrays.sort(tableData, new Comparator<Object[]>(){

                    @Override
                    public int compare(Object[] ra, Object[] rb) {
                        MultinomialClassifierControls c = MyViewer.this.controls;
                        // Sort by name when requested, or when there is no weight column to sort on.
                        if (c.nameButton.isSelected() || ra.length < 2) {
                            return ra[0].toString().compareTo(rb[0].toString());
                        }
                        double da = (Double)ra[1];
                        double db = (Double)rb[1];
                        // Descending by the first class's weight (or by its magnitude).
                        if (c.valueButton.isSelected()) {
                            return Double.compare(db, da);
                        }
                        return Double.compare(Math.abs(db), Math.abs(da));
                    }
                });
            }
            String[] columnNames = new String[numClasses + 1];
            columnNames[0] = "Feature Name";
            for (int l = 0; l < numClasses; ++l) {
                columnNames[l + 1] = "Wgt " + this.h.classNames.get(l);
            }
            JTable table = new JTable(tableData, columnNames);
            this.monitorSelections(table, 0);
            return new JScrollPane(table);
        }
    }

    /** "Sort by" radio buttons (name / weight) for {@link MyViewer}. */
    private static class MultinomialClassifierControls
    extends ViewerControls {
        static final long serialVersionUID = 20080207L;
        // Assigned in initialize(); intentionally no field initializers so the
        // assignments made there are not overwritten afterwards.
        private JRadioButton valueButton;
        private JRadioButton nameButton;

        private MultinomialClassifierControls() {
        }

        public void initialize() {
            this.add(new JLabel("Sort by"));
            ButtonGroup group = new ButtonGroup();
            this.nameButton = this.addButton("name", group, true);
            this.valueButton = this.addButton("weight", group, false);
        }

        /** Creates a radio button, adds it to the group and panel, and wires it to this listener. */
        private JRadioButton addButton(String s, ButtonGroup group, boolean selected) {
            JRadioButton button = new JRadioButton(s, selected);
            group.add(button);
            this.add(button);
            button.addActionListener(this);
            return button;
        }
    }
}

