/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify;

import edu.cmu.minorthird.classify.BatchBinaryClassifierLearner;
import edu.cmu.minorthird.classify.BinaryClassifier;
import edu.cmu.minorthird.classify.BinaryClassifierLearner;
import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.DatasetClassifierTeacher;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Explanation;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.algorithms.linear.NaiveBayes;
import edu.cmu.minorthird.classify.experiments.Evaluation;
import edu.cmu.minorthird.util.MathUtil;
import edu.cmu.minorthird.util.gui.ComponentViewer;
import edu.cmu.minorthird.util.gui.SmartVanillaViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
import java.awt.BorderLayout;
import java.awt.Component;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import javax.swing.JComponent;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import javax.swing.border.TitledBorder;
import org.apache.log4j.Logger;

public class TweakedLearner
extends BatchBinaryClassifierLearner {
    private BinaryClassifierLearner innerLearner;
    private double beta;
    private Dataset m_dataset;
    private ExampleSchema schema;
    private boolean isBinary = true;
    private static final int ILLEGAL_VALUE = -1;
    private static final double UNINITIALIZED = -1.0;
    private List<Row> tweakingTable = new ArrayList<Row>();
    Evaluation.Matrix cm = null;
    private static Logger log = Logger.getLogger(TweakedLearner.class);

    public TweakedLearner(BinaryClassifierLearner innerLearner, double beta) {
        this.beta = beta;
        this.innerLearner = innerLearner;
    }

    public Classifier batchTrain(Dataset dataset) {
        this.schema = dataset.getSchema();
        this.isBinary = this.schema.equals(ExampleSchema.BINARY_EXAMPLE_SCHEMA);
        if (!this.isBinary) {
            throw new IllegalArgumentException("Dataset given to TweakedLearner::batchTrain must be a binary dataset");
        }
        if (dataset.size() == 0) {
            throw new IllegalArgumentException("Dataset given to TweakedLearner::batchTrain is empty");
        }
        this.m_dataset = dataset;
        BinaryClassifier bc = (BinaryClassifier)new DatasetClassifierTeacher(this.m_dataset).train(this.innerLearner);
        this.initializeTable();
        double threshold = this.executeTweaking();
        return new TweakedClassifier(bc, threshold);
    }

    public double getBeta() {
        return this.beta;
    }

    public void setBeta(double beta) {
        this.beta = beta;
    }

    public BinaryClassifierLearner getInnerLearner() {
        return this.innerLearner;
    }

    public void setInnerLearner(BinaryClassifierLearner learner) {
        this.innerLearner = learner;
    }

    private void initializeTable() {
        int counter = 0;
        Iterator<Example> i = this.m_dataset.iterator();
        while (i.hasNext()) {
            Example ex = i.next();
            ClassLabel predicted = this.innerLearner.getBinaryClassifier().classification(ex);
            this.tweakingTable.add(new Row(ex.asInstance(), ex.getLabel(), predicted, ClassLabel.negativeLabel(-1.0)));
            ++counter;
        }
        this.sortByScore();
    }

    private double executeTweaking() {
        double threshold = -1.0;
        this.initConfusionMatrix();
        for (int i = 0; i < this.tweakingTable.size(); ++i) {
            this.getRow((int)i).tweak_predicted = ClassLabel.positiveLabel(1.0);
            this.updateConfusionMatrix(i);
            this.getRow((int)i).precision = this.getCurrentPrecision();
            this.getRow((int)i).recall = this.getCurrentRecall();
            this.getRow((int)i).F_beta = this.calculateFBeta(this.getRow((int)i).precision, this.getRow((int)i).recall);
        }
        int index = this.maxFBetaEntry();
        if (index + 1 == this.tweakingTable.size()) {
            threshold = this.getRow((int)index).orig_predicted.posWeight();
        } else {
            double maxRowScore = this.getRow((int)index).orig_predicted.posWeight();
            double nextRowScore = this.getRow((int)(index + 1)).orig_predicted.posWeight();
            threshold = (maxRowScore + nextRowScore) / 2.0;
        }
        log.debug("Threshold found: " + threshold + " (in row " + index + ")");
        return threshold;
    }

    private void initConfusionMatrix() {
        String[] classes = this.getClasses();
        double[][] confused = new double[classes.length][classes.length];
        for (int i = 0; i < this.tweakingTable.size(); ++i) {
            Row row = this.getRow(i);
            double[] dArray = confused[this.classIndexOf(row.actual)];
            int n = this.classIndexOf(row.tweak_predicted);
            dArray[n] = dArray[n] + 1.0;
        }
        this.cm = new Evaluation.Matrix(confused);
    }

    private void updateConfusionMatrix(int index) {
        Row row = this.getRow(index);
        int actual = this.classIndexOf(row.actual);
        int p = this.classIndexOf("POS");
        int n = this.classIndexOf("NEG");
        double[] dArray = this.cm.values[actual];
        int n2 = p;
        dArray[n2] = dArray[n2] + 1.0;
        double[] dArray2 = this.cm.values[actual];
        int n3 = n;
        dArray2[n3] = dArray2[n3] - 1.0;
    }

    private double calculateFBeta(double precision, double recall) {
        double divisor = this.beta * precision + recall;
        if (divisor == 0.0) {
            log.warn("TweakedLearner::calculateFBeta, divisor of F_beta is zero !!!");
            return 0.0;
        }
        if (new Double(divisor).isNaN()) {
            log.warn("TweakedLearner::calculateFBeta, divisor of F_beta is a NaN !!!");
            return 0.0;
        }
        return (this.beta + 1.0) * precision * recall / divisor;
    }

    private double getCurrentPrecision() {
        if (!this.isBinary) {
            return -1.0;
        }
        int p = this.classIndexOf("POS");
        int n = this.classIndexOf("NEG");
        return this.cm.values[p][p] / (this.cm.values[p][p] + this.cm.values[n][p]);
    }

    private double getCurrentRecall() {
        if (!this.isBinary) {
            return -1.0;
        }
        int p = this.classIndexOf("POS");
        int n = this.classIndexOf("NEG");
        return this.cm.values[p][p] / (this.cm.values[p][p] + this.cm.values[p][n]);
    }

    private void sortByScore() {
        Collections.sort(this.tweakingTable, new Comparator<Row>(){

            @Override
            public int compare(Row a, Row b) {
                return MathUtil.sign(b.orig_predicted.posWeight() - a.orig_predicted.posWeight());
            }
        });
    }

    private int maxFBetaEntry() {
        double maxFBeta = -1.0;
        int maxIndex = -1;
        for (int i = 0; i < this.tweakingTable.size(); ++i) {
            if (!(this.getRow((int)i).F_beta > maxFBeta)) continue;
            maxFBeta = this.getRow((int)i).F_beta;
            maxIndex = i;
        }
        if (maxFBeta == -1.0) {
            log.error("In TweakedLearner::maxFBetaEntry, maxFBeta has an illegal value");
        }
        return maxIndex;
    }

    private Row getRow(int i) {
        return this.tweakingTable.get(i);
    }

    private String[] getClasses() {
        return this.schema.validClassNames();
    }

    private int classIndexOf(ClassLabel classLabel) {
        return this.classIndexOf(classLabel.bestClassName());
    }

    private int classIndexOf(String classLabelName) {
        return this.schema.getClassIndex(classLabelName);
    }

    public static void main(String[] args) {
        System.out.println("Started the test program for TweakedLearner");
        NaiveBayes nb = new NaiveBayes();
        new TweakedLearner(nb, 3.0);
        System.out.println("Created a TweakedLearner");
    }

    public static class TweakedClassifier
    extends BinaryClassifier
    implements Serializable,
    Visible {
        private static final long serialVersionUID = 20080128L;
        private double m_threshold;
        private BinaryClassifier m_classifier;

        public TweakedClassifier(BinaryClassifier classifier, double threshold) {
            this.m_classifier = classifier;
            this.m_threshold = threshold;
        }

        public double score(Instance instance) {
            return this.m_classifier.score(instance) - this.m_threshold;
        }

        public Viewer toGUI() {
            ComponentViewer v = new ComponentViewer(){
                static final long serialVersionUID = 20080128L;

                public JComponent componentFor(Object o) {
                    TweakedClassifier c = (TweakedClassifier)o;
                    JPanel mainPanel = new JPanel();
                    mainPanel.setLayout(new BorderLayout());
                    mainPanel.add((Component)new JLabel("Optimal threshold for TweakedClassifier=" + c.m_threshold), "North");
                    mainPanel.add((Component)new JLabel("Original classifier before tweaking:"), "Center");
                    SmartVanillaViewer subView = new SmartVanillaViewer(c.m_classifier);
                    subView.setSuperView(this);
                    mainPanel.add((Component)subView, "South");
                    mainPanel.setBorder(new TitledBorder("TweakedClassifier class"));
                    return new JScrollPane(mainPanel);
                }
            };
            v.setContent(this);
            return v;
        }

        public String explain(Instance instance) {
            StringBuffer buf = new StringBuffer("");
            buf.append("Explanation of original untweaked classifier:\n");
            buf.append(this.m_classifier.explain(instance));
            buf.append("\nAdjusted score after tweaking = " + this.score(instance));
            return buf.toString();
        }

        public Explanation getExplanation(Instance instance) {
            Explanation.Node top = new Explanation.Node("TweakedLearner Explanation");
            Explanation.Node orig = new Explanation.Node("Explanation of original untweaked classifier");
            Explanation.Node origEx = this.m_classifier.getExplanation(instance).getTopNode();
            orig.add(origEx);
            top.add(orig);
            Explanation.Node adjusted = new Explanation.Node("\nAdjusted score after tweaking = " + this.score(instance));
            top.add(adjusted);
            Explanation ex = new Explanation(top);
            return ex;
        }
    }

    private static class Row
    implements Serializable {
        private static final long serialVersionUID = -4069980043842319180L;
        public transient Instance instance = null;
        public ClassLabel actual;
        public ClassLabel orig_predicted;
        public ClassLabel tweak_predicted;
        public double precision = -1.0;
        public double recall = -1.0;
        public double F_beta = -1.0;

        public Row(Instance i, ClassLabel a, ClassLabel orig_p, ClassLabel tweak_p) {
            this.instance = i;
            this.actual = a;
            this.orig_predicted = orig_p;
            this.tweak_predicted = tweak_p;
        }

        public String toString() {
            return this.orig_predicted + "\t" + this.actual + "\t" + this.instance;
        }
    }
}

