/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify.algorithms.linear;

import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.Explanation;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.WeightedSet;
import edu.cmu.minorthird.classify.algorithms.random.Arithmetic;
import edu.cmu.minorthird.classify.algorithms.random.Estimate;
import edu.cmu.minorthird.util.MathUtil;
import edu.cmu.minorthird.util.gui.ComponentViewer;
import edu.cmu.minorthird.util.gui.Controllable;
import edu.cmu.minorthird.util.gui.ControlledViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.ViewerControls;
import edu.cmu.minorthird.util.gui.Visible;
import gnu.trove.TObjectDoubleHashMap;
import gnu.trove.TObjectDoubleIterator;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;
import javax.swing.ButtonGroup;
import javax.swing.JComponent;
import javax.swing.JLabel;
import javax.swing.JRadioButton;
import javax.swing.JScrollPane;
import javax.swing.JTable;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * A multinomial-style classifier whose class-conditional feature likelihoods
 * are described by {@link Estimate} objects.
 *
 * <p>Each class {@code i} has a prior ({@code classParameters.get(i)}) and a
 * map from {@link Feature} to {@link Estimate}
 * ({@code featureGivenClassParameters.get(i)}). An estimate names its model via
 * {@link Estimate#getModel()} ("Poisson", "Naive-Bayes", "Negative-Binomial",
 * "Binomial", "Dirichlet-Poisson MCMC" or "unseen") and its parameterization via
 * {@link Estimate#getParameterization()}. {@link #score(Instance)} returns, per
 * class, the log prior plus the sum of per-feature log-likelihood terms; these
 * scores are not normalized.
 *
 * <p>Parameters are filled in from outside this class (presumably by a learner;
 * not visible here) through {@link #addValidLabel}, {@link #setClassParameter},
 * {@link #setFeatureGivenClassParameter(Feature, int, Estimate)} and
 * {@link #setScale}.
 */
public class MultinomialClassifier
implements Classifier,
Visible,
Serializable {
    static final long serialVersionUID = 20080128L;
    /**
     * Divisor applied to an instance's total feature weight before it enters the
     * length-dependent model terms. The name is kept for serialization
     * compatibility. NOTE(review): defaults to 0.0, which makes those terms
     * infinite; callers presumably call {@link #setScale} first. Confirm.
     */
    private double SCALE;
    /** Class names; index {@code i} here matches index {@code i} in the parameter lists. */
    private List<String> classNames = new ArrayList<String>();
    /** Per-class prior probabilities, indexed like {@link #classNames}. */
    private List<Double> classParameters = new ArrayList<Double>();
    /** Model name recorded per feature; read back through {@link #getFeatureModel}. */
    private Map<Feature, String> featureModels = new HashMap<Feature, String>();
    /**
     * Per-class feature parameters. Entries created by
     * {@link #setFeatureGivenClassParameter(Feature, int, Estimate)} are
     * {@code Map<Feature, Estimate>}; the constructor additionally seeds a
     * {@link WeightedSet}. The raw element type is kept for serialization
     * compatibility.
     */
    private List<Object> featureGivenClassParameters = new ArrayList<Object>();
    /** Stored and returned by {@link #setPrior}/{@link #getPrior}; not read by scoring here. */
    private double featurePrior;
    /** Stored and returned by {@link #setUnseenModel}/{@link #getUnseenModel}; not read by scoring here. */
    private String unseenModel;

    public MultinomialClassifier() {
        this.featureGivenClassParameters.add(new WeightedSet());
        this.featurePrior = 0.0;
        this.unseenModel = null;
    }

    /**
     * Returns the label of the highest-scoring class.
     *
     * @param instance the instance to classify
     * @return the best class; when several classes tie, the last one wins
     * @throws IllegalStateException if no classes have been registered
     */
    @Override
    public ClassLabel classification(Instance instance) {
        double[] score = this.score(instance);
        if (score.length == 0) {
            throw new IllegalStateException("MultinomialClassifier has no classes to choose from");
        }
        int maxIndex = 0;
        for (int i = 1; i < score.length; ++i) {
            // ">=" keeps the historical tie-breaking: the last of equal maxima wins.
            if (score[i] >= score[maxIndex]) {
                maxIndex = i;
            }
        }
        return new ClassLabel(this.classNames.get(maxIndex));
    }

    /**
     * Computes one unnormalized log score per class: the log of the class prior
     * plus, for every feature of the instance, the log-likelihood term given by
     * that class's estimate for the feature.
     *
     * @param instance the instance to score
     * @return scores indexed like the class list
     * @throws IllegalStateException if an estimate reports an unknown model name
     */
    public double[] score(Instance instance) {
        int numClasses = this.classNames.size();
        // Total feature weight of the instance; used by length-dependent models.
        double exampleWeight = 0.0;
        Iterator<Feature> j = instance.featureIterator();
        while (j.hasNext()) {
            Feature f = j.next();
            exampleWeight += instance.getWeight(f);
        }
        double[] score = new double[numClasses];
        for (int i = 0; i < numClasses; ++i) {
            double classProb = this.classParameters.get(i);
            score[i] = Math.log(classProb);
        }
        Iterator<Feature> j2 = instance.featureIterator();
        while (j2.hasNext()) {
            Feature f = j2.next();
            double featureCounts = instance.getWeight(f);
            for (int i = 0; i < numClasses; ++i) {
                Estimate featureProb = (Estimate) this.featureParametersForClass(i).get(f);
                score[i] += this.logProbOfFeature(featureProb, featureCounts, exampleWeight);
            }
        }
        return score;
    }

    /**
     * Returns the log-likelihood contribution of one feature under one class's
     * estimate. A missing estimate, the "unseen" model, and any parameterization
     * not handled below all contribute 0.
     *
     * @param featureProb the class's estimate for the feature, or null if absent
     * @param featureCounts the feature's weight in the instance
     * @param exampleWeight the instance's total feature weight (not yet divided by SCALE)
     * @throws IllegalStateException if the estimate names an unknown model
     */
    private double logProbOfFeature(Estimate featureProb, double featureCounts, double exampleWeight) {
        if (featureProb == null) {
            return 0.0;
        }
        String model = featureProb.getModel();
        if ("unseen".equals(model)) {
            return 0.0;
        }
        if ("Poisson".equals(model) || "Dirichlet-Poisson MCMC".equals(model)) {
            String parameterization = featureProb.getParameterization();
            if (!"weighted-lambda".equals(parameterization) && !"lambda".equals(parameterization)) {
                return 0.0;
            }
            SortedMap<String, Double> pms = featureProb.getPms();
            double lambda = pms.get("lambda");
            return -lambda * exampleWeight / this.SCALE + featureCounts * Math.log(lambda);
        }
        if ("Naive-Bayes".equals(model)) {
            String parameterization = featureProb.getParameterization();
            if (!"weighted-mean".equals(parameterization) && !"mean".equals(parameterization)) {
                return 0.0;
            }
            SortedMap<String, Double> pms = featureProb.getPms();
            double mean = pms.get("mean");
            return featureCounts * Math.log(mean);
        }
        if ("Negative-Binomial".equals(model)) {
            if (!"mu/delta".equals(featureProb.getParameterization())) {
                return 0.0;
            }
            return this.logProbNegativeBinomialMuDelta(featureCounts, exampleWeight / this.SCALE, featureProb.getPms());
        }
        if ("Binomial".equals(model)) {
            String parameterization = featureProb.getParameterization();
            if ("p/N".equals(parameterization)) {
                return this.logProbBinomialPN(featureCounts, exampleWeight / this.SCALE, featureProb.getPms());
            }
            if ("mu/delta".equals(parameterization)) {
                return this.logProbBinomialMuDelta(featureCounts, exampleWeight / this.SCALE, featureProb.getPms());
            }
            return 0.0;
        }
        throw new IllegalStateException("error: model " + model + " not found!");
    }

    /**
     * Log-probability (up to class-independent terms) of count {@code x} under a
     * negative binomial with mean {@code w*mu} and dispersion {@code delta}. This
     * is a Poisson whose rate is Gamma(shape = mu/delta, scale = delta) and is
     * then multiplied by {@code w}. When {@code delta == 0} it reduces to the
     * Poisson term {@code x*log(mu) - w*mu}. Returns 0 if parameters are missing
     * or evaluation throws (deliberate best effort).
     */
    private double logProbNegativeBinomialMuDelta(double x, double w, SortedMap<String, Double> pms) {
        double logProb;
        try {
            double m = pms.get("mu");
            double d = pms.get("delta");
            if (d == 0.0) {
                logProb = x * Math.log(m) - w * m;
            } else {
                double r = m / d;
                // The "- r * log(1 + w*d)" part of the (x + r) factor depends on the
                // class. It is also what makes this branch tend to the d == 0 branch.
                logProb = Arithmetic.logGamma(x + r) - Arithmetic.logGamma(r)
                        + x * Math.log(d) - (x + r) * Math.log(1.0 + w * d);
            }
        }
        catch (Exception ignored) {
            logProb = 0.0;
        }
        return logProb;
    }

    /**
     * Binomial log-probability of {@code x} successes in {@code N} trials with
     * success probability {@code p}, omitting the {@code log x!} term. When
     * {@code N == 0} it falls back to {@code x*log(p) - w*p}. Returns 0 if
     * parameters are missing or evaluation throws.
     */
    private double logProbBinomialPN(double x, double w, SortedMap<String, Double> pms) {
        double logProb = 0.0;
        try {
            double p = pms.get("p");
            double N = pms.get("N");
            logProb = N == 0.0 ? x * Math.log(p) - w * p : Arithmetic.logFactorial((int)N) - Arithmetic.logFactorial((int)N - (int)x) + x * Math.log(p) + (N - x) * Math.log(1.0 - p);
        }
        catch (Exception ignored) {
            logProb = 0.0;
        }
        return logProb;
    }

    /**
     * Binomial log-probability parameterized by mean {@code mu} and
     * {@code delta}. It uses {@code N = round(max(mu/delta, x))} trials and a
     * success probability of {@code w*delta}, clamped to (1e-7, 1-1e-7). When
     * {@code delta == 0} it uses the Poisson term {@code x*log(mu) - w*mu}.
     * Returns 0 if parameters are missing or evaluation throws.
     */
    private double logProbBinomialMuDelta(double x, double w, SortedMap<String, Double> pms) {
        double logProb = 0.0;
        try {
            double m = pms.get("mu");
            double d = pms.get("delta");
            if (d == 0.0) {
                logProb = x * Math.log(m) - w * m;
            } else {
                double N = Math.round(Math.max(m / d, x));
                double p = Math.min(Math.max(1.0E-7, w * d), 0.9999999);
                logProb = Arithmetic.logGamma(N + 1.0) - Arithmetic.logGamma(N - x + 1.0) + x * Math.log(d) - x * Math.log(1.0 - p) + N * Math.log(1.0 - p);
            }
        }
        catch (Exception ignored) {
            logProb = 0.0;
        }
        return logProb;
    }

    /**
     * Returns a plain-text listing of the instance's features with their weights,
     * followed by the per-class scores.
     */
    @Override
    public String explain(Instance instance) {
        StringBuilder buf = new StringBuilder();
        Iterator<Feature> j = instance.featureIterator();
        while (j.hasNext()) {
            Feature f = j.next();
            buf.append(buf.length() > 0 ? "\n + " : "   ");
            buf.append(f).append("<").append(instance.getWeight(f));
        }
        buf.append("\n = ").append(Arrays.toString(this.score(instance)));
        return buf.toString();
    }

    /**
     * Returns a tree-shaped explanation: one node per feature (with its weight),
     * a "bias" node, and a node holding the per-class scores.
     */
    @Override
    public Explanation getExplanation(Instance instance) {
        Explanation.Node top = new Explanation.Node("MultinomialClassifier Explanation");
        Explanation.Node features = new Explanation.Node("Features");
        Iterator<Feature> j = instance.featureIterator();
        while (j.hasNext()) {
            Feature f = j.next();
            Explanation.Node featureEx = new Explanation.Node(f + "<" + instance.getWeight(f));
            features.add(featureEx);
        }
        Explanation.Node bias = new Explanation.Node("bias");
        features.add(bias);
        top.add(features);
        Explanation.Node score = new Explanation.Node("\n = " + Arrays.toString(this.score(instance)));
        top.add(score);
        return new Explanation(top);
    }

    /** Sets the divisor applied to an instance's total weight during scoring. */
    public void setScale(double value) {
        this.SCALE = value;
    }

    public void setPrior(double pi) {
        this.featurePrior = pi;
    }

    public double getPrior() {
        return this.featurePrior;
    }

    public void setUnseenModel(String str) {
        this.unseenModel = str;
    }

    public String getUnseenModel() {
        return this.unseenModel;
    }

    /**
     * Computes a log-likelihood for the example under its own labelled class,
     * using the per-feature model recorded with {@link #setFeatureModel}. Only
     * "Poisson" and "Naive-Bayes" are handled; "unseen" features are reported on
     * stdout and skipped.
     *
     * <p>NOTE(review): this reads the class entry as a {@link WeightedSet}, while
     * {@link #setFeatureGivenClassParameter(Feature, int, Estimate)} stores maps.
     * Confirm which callers still populate WeightedSets.
     *
     * @throws IllegalArgumentException if the example's label is not a known class
     * @throws IllegalStateException if a feature has an unknown model name
     */
    public double getLogLikelihood(Example example) {
        int idx = this.indexOf(example.getLabel());
        if (idx < 0) {
            throw new IllegalArgumentException("unknown class label: " + example.getLabel().bestClassName());
        }
        Instance instance = example.asInstance();
        double loglik = 0.0;
        Iterator<Feature> j = instance.featureIterator();
        while (j.hasNext()) {
            Feature f = j.next();
            double featureCounts = instance.getWeight(f);
            double featureProb = ((WeightedSet)this.featureGivenClassParameters.get(idx)).getWeight(f);
            String model = this.getFeatureModel(f);
            if (model.equals("Poisson")) {
                loglik += -featureProb + featureCounts * Math.log(featureProb);
            } else if (model.equals("Naive-Bayes")) {
                loglik += featureCounts * Math.log(featureProb);
            } else if (model.equals("unseen")) {
                System.out.println("unseen: " + f);
            } else {
                throw new IllegalStateException("error: model " + model + " not found!");
            }
        }
        return loglik;
    }

    /** Clears the class priors and per-class feature parameters (class names and feature models are kept). */
    public void reset() {
        this.classParameters = new ArrayList<Double>();
        this.featureGivenClassParameters = new ArrayList<Object>();
    }

    /** Returns true if the label's best class name has been registered. */
    public boolean isPresent(ClassLabel label) {
        return this.classNames.contains(label.bestClassName());
    }

    /** Registers the label's best class name as a new class (duplicates are not checked). */
    public void addValidLabel(ClassLabel label) {
        this.classNames.add(label.bestClassName());
    }

    public ClassLabel getLabel(int i) {
        return new ClassLabel(this.classNames.get(i));
    }

    /** Returns the index of the label's best class name, or -1 if unknown. */
    public int indexOf(ClassLabel label) {
        return this.classNames.indexOf(label.bestClassName());
    }

    /**
     * Records the estimate for feature {@code f} in class {@code j}. If slot
     * {@code j} does not already hold a map, a new map is inserted at {@code j}.
     * This matches the historical behaviour, which also covers the WeightedSet
     * seeded by the constructor.
     *
     * @throws IndexOutOfBoundsException if {@code j} is negative or beyond the list end
     */
    @SuppressWarnings("unchecked")
    public void setFeatureGivenClassParameter(Feature f, int j, Estimate pms) {
        Object existing = j >= 0 && j < this.featureGivenClassParameters.size()
                ? this.featureGivenClassParameters.get(j) : null;
        if (existing instanceof Map) {
            ((Map<Feature, Estimate>) existing).put(f, pms);
            return;
        }
        Map<Feature, Estimate> hmap = new HashMap<Feature, Estimate>();
        hmap.put(f, pms);
        this.featureGivenClassParameters.add(j, hmap);
    }

    /** Unsupported scalar form; only prints a warning and stores nothing. */
    public void setFeatureGivenClassParameter(Feature f, int j, double probabilityOfOccurrence) {
        System.out.println("Should not happen!");
    }

    /**
     * Sets the prior of class {@code j}. Overwrites an existing value, or appends
     * when {@code j} equals the current number of priors.
     *
     * @throws IndexOutOfBoundsException if {@code j} is negative or beyond the list end
     */
    public void setClassParameter(int j, double probabilityOfOccurrence) {
        Double value = Double.valueOf(probabilityOfOccurrence);
        if (j >= 0 && j < this.classParameters.size()) {
            this.classParameters.set(j, value);
        } else {
            this.classParameters.add(j, value);
        }
    }

    public void setFeatureModel(Feature feature, String model) {
        this.featureModels.put(feature, model);
    }

    /** Returns the model recorded for the feature, or "unseen" if none was recorded. */
    public String getFeatureModel(Feature feature) {
        String model = this.featureModels.get(feature);
        return model == null ? "unseen" : model;
    }

    /** Iterates over every feature that has parameters in at least one class. */
    public Iterator<Feature> featureIterator() {
        final TObjectDoubleIterator ti = this.collectFeatures().iterator();
        return new Iterator<Feature>(){

            @Override
            public boolean hasNext() {
                return ti.hasNext();
            }

            @Override
            public Feature next() {
                ti.advance();
                return (Feature)ti.key();
            }

            @Override
            public void remove() {
                ti.remove();
            }
        };
    }

    /** Returns every feature that has parameters in at least one class. */
    public Object[] keys() {
        return this.collectFeatures().keys();
    }

    /** Builds a set (as a trove map with dummy 0.0 values) of the features of all classes. */
    private TObjectDoubleHashMap collectFeatures() {
        TObjectDoubleHashMap map = new TObjectDoubleHashMap();
        for (int i = 0; i < this.classNames.size(); ++i) {
            for (Object f : this.featureParametersForClass(i).keySet()) {
                map.put(f, 0.0);
            }
        }
        return map;
    }

    /** Returns the feature-to-estimate map of class {@code classIndex}. */
    private Map<?, ?> featureParametersForClass(int classIndex) {
        return (Map<?, ?>) this.featureGivenClassParameters.get(classIndex);
    }

    @Override
    public Viewer toGUI() {
        ControlledViewer gui = new ControlledViewer(new MyViewer(), new MultinomialClassifierControls());
        gui.setContent(this);
        return gui;
    }

    @Override
    public String toString() {
        return "[MultinomialClassifier classes=" + this.classNames + "]";
    }

    /** Table view: one row per feature, one column per class showing its estimate. */
    private static class MyViewer
    extends ComponentViewer
    implements Controllable {
        static final long serialVersionUID = 20080128L;
        private MultinomialClassifierControls controls = null;
        private MultinomialClassifier h = null;

        private MyViewer() {
        }

        public void applyControls(ViewerControls controls) {
            this.controls = (MultinomialClassifierControls)controls;
            this.setContent(this.h, true);
            this.revalidate();
        }

        public boolean canReceive(Object o) {
            return o instanceof MultinomialClassifier;
        }

        public JComponent componentFor(Object o) {
            this.h = (MultinomialClassifier)o;
            Object[] keys = this.h.keys();
            int numClasses = this.h.classNames.size();
            Object[][] tableData = new Object[keys.length][numClasses + 1];
            for (int k = 0; k < keys.length; ++k) {
                Object f = keys[k];
                tableData[k][0] = f;
                for (int l = 0; l < numClasses; ++l) {
                    Estimate estimate = (Estimate) this.h.featureParametersForClass(l).get(f);
                    // A feature may have no estimate in some class; show an empty cell.
                    tableData[k][l + 1] = estimate == null ? "" : estimate.toTableInViewer();
                }
            }
            if (this.controls != null) {
                Arrays.sort(tableData, new Comparator<Object[]>(){

                    @Override
                    public int compare(Object[] ra, Object[] rb) {
                        if (MyViewer.this.controls.nameButton.isSelected()) {
                            return ra[0].toString().compareTo(rb[0].toString());
                        }
                        // Cells normally hold text from Estimate.toTableInViewer, so they
                        // are only compared numerically when both parse as numbers.
                        Double da = numericCell(ra);
                        Double db = numericCell(rb);
                        if (da != null && db != null) {
                            if (MyViewer.this.controls.valueButton.isSelected()) {
                                return Double.compare(db, da);
                            }
                            return Double.compare(Math.abs(db), Math.abs(da));
                        }
                        // Numeric rows first, then the rest in text order (keeps the order transitive).
                        if (da != null) {
                            return -1;
                        }
                        if (db != null) {
                            return 1;
                        }
                        return cellText(ra).compareTo(cellText(rb));
                    }
                });
            }
            String[] columnNames = new String[numClasses + 1];
            columnNames[0] = "Feature Name";
            for (int i2 = 0; i2 < numClasses; ++i2) {
                columnNames[i2 + 1] = "Class " + this.h.classNames.get(i2);
            }
            JTable table = new JTable(tableData, columnNames);
            this.monitorSelections(table, 0);
            return new JScrollPane(table);
        }

        /** Text of the first class column of a row ("" if absent). */
        private static String cellText(Object[] row) {
            return row.length > 1 && row[1] != null ? row[1].toString() : "";
        }

        /** Numeric value of the first class column of a row, or null if it is not a number. */
        private static Double numericCell(Object[] row) {
            if (row.length < 2 || row[1] == null) {
                return null;
            }
            if (row[1] instanceof Number) {
                return Double.valueOf(((Number) row[1]).doubleValue());
            }
            try {
                return Double.valueOf(row[1].toString().trim());
            }
            catch (NumberFormatException notNumeric) {
                return null;
            }
        }
    }

    /** "Sort by" radio buttons for the table view. */
    private static class MultinomialClassifierControls
    extends ViewerControls {
        static final long serialVersionUID = 20080128L;
        private JRadioButton valueButton;
        private JRadioButton nameButton;

        private MultinomialClassifierControls() {
        }

        public void initialize() {
            this.add(new JLabel("Sort by"));
            ButtonGroup group = new ButtonGroup();
            this.nameButton = this.addButton("name", group, true);
            this.valueButton = this.addButton("weight", group, false);
        }

        private JRadioButton addButton(String s, ButtonGroup group, boolean selected) {
            JRadioButton button = new JRadioButton(s, selected);
            group.add(button);
            this.add(button);
            button.addActionListener(this);
            return button;
        }
    }
}

