/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify.algorithms.trees;

import edu.cmu.minorthird.classify.BasicDataset;
import edu.cmu.minorthird.classify.BatchBinaryClassifierLearner;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.algorithms.trees.DecisionTree;
import java.util.Iterator;
import java.util.TreeMap;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;

/**
 * Batch learner that grows a binary {@link DecisionTree} by recursive splitting.
 *
 * <p>At every node each binary and numeric feature seen in the data is scored with
 * {@link #schapireSingerValue(double, double, double, double)} and the feature/threshold
 * pair with the lowest score becomes the split test. Recursion stops when the node has
 * fewer than {@code minSplitCount} examples, when {@code maxDepth} is reached, when the
 * node is pure (one class has zero weight), or when no feature yields a usable split.
 * Leaves score examples with {@code 0.5 * ln((posWeight + epsilon) / (negWeight + epsilon))}.
 *
 * <p>The smoothing term {@code epsilon} is stored in an instance field that is rewritten
 * by every call to {@link #batchTrain(Dataset)}.
 */
public class DecisionTreeLearner
extends BatchBinaryClassifierLearner {
    private static final Logger log = Logger.getLogger(DecisionTreeLearner.class);
    /**
     * Whether debug logging was enabled when this class was loaded. Used to skip building
     * debug strings in the scoring loops. (The previous test,
     * {@code getEffectiveLevel().isGreaterOrEqual(Level.DEBUG)}, was also true for INFO,
     * WARN, etc., so it never skipped anything at normal log levels.)
     */
    private static final boolean DEBUG = log.isDebugEnabled();
    private int maxDepth = 5;
    private int minSplitCount = 2;
    private double epsilon = 0.001;

    /**
     * Creates a learner with explicit stopping parameters.
     *
     * @param maxDepth      depth at which a node is always turned into a leaf (the root is depth 0)
     * @param minSplitCount nodes holding fewer examples than this are turned into leaves
     */
    public DecisionTreeLearner(int maxDepth, int minSplitCount) {
        this.maxDepth = maxDepth;
        this.minSplitCount = minSplitCount;
    }

    /** Creates a learner with {@code maxDepth = 5} and {@code minSplitCount = 2}. */
    public DecisionTreeLearner() {
        this(5, 2);
    }

    /** @return the depth at which splitting stops */
    public int getMaxDepth() {
        return this.maxDepth;
    }

    /** @param d the depth at which splitting stops */
    public void setMaxDepth(int d) {
        this.maxDepth = d;
    }

    /** @return the minimum number of examples a node needs in order to be split */
    public int getMinSplitCount() {
        return this.minSplitCount;
    }

    /** @param c the minimum number of examples a node needs in order to be split */
    public void setMinSplitCount(int c) {
        this.minSplitCount = c;
    }

    /**
     * Trains a decision tree on the whole dataset.
     *
     * <p>Sets {@code epsilon} to {@code 0.5 / dataset.size()} before growing the tree.
     *
     * @param dataset the training examples
     * @return the learned tree
     */
    public Classifier batchTrain(Dataset dataset) {
        // Clamp the divisor: an empty dataset would otherwise give an infinite epsilon,
        // and the resulting leaf score log(Infinity / Infinity) is NaN.
        this.epsilon = 0.5 / (double)Math.max(1, dataset.size());
        DecisionTree tree = this.batchTrain(dataset, 0);
        log.info("built tree: " + tree);
        return tree;
    }

    /**
     * Recursively grows the (sub)tree for {@code dataset}.
     *
     * @param dataset examples that reached this node
     * @param depth   depth of this node; the root is 0
     * @return a leaf or an internal node with two recursively trained branches
     */
    private DecisionTree batchTrain(Dataset dataset, int depth) {
        double posWeight = 0.0;
        double negWeight = 0.0;
        Iterator<Example> weightPass = dataset.iterator();
        while (weightPass.hasNext()) {
            Example example = weightPass.next();
            if (example.getLabel().numericLabel() > 0.0) {
                posWeight += example.getWeight();
            } else {
                negWeight += example.getWeight();
            }
        }
        log.info("build (sub)tree with posWeight: " + posWeight + " negWeight: " + negWeight);

        boolean tooSmall = dataset.size() < this.minSplitCount;
        boolean tooDeep = depth >= this.maxDepth;
        boolean pure = negWeight == 0.0 || posWeight == 0.0;
        if (tooSmall || tooDeep || pure) {
            log.debug("leaf");
            return this.makeLeaf(posWeight, negWeight);
        }

        // Gather per-feature class-weight statistics for every feature present at this node.
        TreeMap<Feature, BinaryFeatureStats> binaryMap = new TreeMap<Feature, BinaryFeatureStats>();
        TreeMap<Feature, NumericFeatureStats> numericMap = new TreeMap<Feature, NumericFeatureStats>();
        Iterator<Example> statsPass = dataset.iterator();
        while (statsPass.hasNext()) {
            Example example = statsPass.next();
            Iterator<Feature> binaryFeatures = example.binaryFeatureIterator();
            while (binaryFeatures.hasNext()) {
                Feature f = binaryFeatures.next();
                BinaryFeatureStats stats = binaryMap.get(f);
                if (stats == null) {
                    stats = new BinaryFeatureStats();
                    binaryMap.put(f, stats);
                }
                stats.update(example);
            }
            Iterator<Feature> numericFeatures = example.numericFeatureIterator();
            while (numericFeatures.hasNext()) {
                Feature f = numericFeatures.next();
                NumericFeatureStats stats = numericMap.get(f);
                if (stats == null) {
                    stats = new NumericFeatureStats();
                    numericMap.put(f, stats);
                }
                stats.update(example, example.getWeight(f));
            }
        }

        // Pick the feature/threshold with the lowest score. The node totals are the
        // posWeight/negWeight sums computed above.
        double bestValue = Double.MAX_VALUE;
        double bestThreshold = -9999.0;
        Feature bestFeature = null;
        for (Feature f : binaryMap.keySet()) {
            BinaryFeatureStats stats = binaryMap.get(f);
            double v = stats.value(posWeight, negWeight);
            if (DEBUG) {
                log.debug("feature " + f + " stats: " + stats + " val: " + v);
            }
            if (v < bestValue) {
                bestValue = v;
                bestFeature = f;
                // Binary features are tested as "weight > 0.5", i.e. present vs. absent.
                bestThreshold = 0.5;
                if (DEBUG) {
                    log.debug(" ==> BEST");
                }
            }
        }
        for (Feature f : numericMap.keySet()) {
            NumericFeatureStats stats = numericMap.get(f);
            // value() must run before getBestThreshold(): it computes the threshold.
            double v = stats.value(posWeight, negWeight);
            double th = stats.getBestThreshold();
            if (DEBUG) {
                log.debug("feature " + f + "<" + th + " stats: " + stats + " val: " + v);
            }
            if (v < bestValue) {
                bestValue = v;
                bestFeature = f;
                bestThreshold = th;
                if (DEBUG) {
                    log.debug(" ==> BEST");
                }
            }
        }
        if (bestFeature == null) {
            log.debug("no good split found - leaf");
            return this.makeLeaf(posWeight, negWeight);
        }
        log.info("split on " + bestFeature + ">" + bestThreshold);

        // Partition the examples on the chosen test and recurse on each side.
        BasicDataset trueData = new BasicDataset();
        BasicDataset falseData = new BasicDataset();
        Iterator<Example> splitPass = dataset.iterator();
        while (splitPass.hasNext()) {
            Example example = splitPass.next();
            if (example.getWeight(bestFeature) > bestThreshold) {
                trueData.add(example, false);
            } else {
                falseData.add(example, false);
            }
        }
        DecisionTree trueBranch = this.batchTrain(trueData, depth + 1);
        DecisionTree falseBranch = this.batchTrain(falseData, depth + 1);
        return new DecisionTree.InternalNode(bestFeature, bestThreshold, trueBranch, falseBranch);
    }

    /**
     * Builds a leaf whose score is the epsilon-smoothed half log-odds of the two class weights.
     *
     * @param posWeight total weight of positive examples at the node
     * @param negWeight total weight of the remaining examples at the node
     * @return the leaf
     */
    private DecisionTree makeLeaf(double posWeight, double negWeight) {
        double score = 0.5 * Math.log((posWeight + this.epsilon) / (negWeight + this.epsilon));
        return new DecisionTree.Leaf(score);
    }

    /**
     * Scores a two-way split; callers treat lower values as better.
     *
     * <p>Computes {@code 2 * (sqrt(wp1 * wn1) + sqrt(wp0 * wn0))}, where the four terms are
     * the fractions of the node's total weight that are positive/negative on the "test
     * true" side (1) and on the "test false" side (0). Presumably this is the Z criterion
     * from Schapire and Singer's confidence-rated boosting, as the method name suggests.
     *
     * @param pos      positive weight on the "true" side of the test
     * @param neg      negative weight on the "true" side of the test
     * @param totalPos positive weight at the node
     * @param totalNeg negative weight at the node
     * @return the split score
     */
    private static double schapireSingerValue(double pos, double neg, double totalPos, double totalNeg) {
        double totalWeight = totalPos + totalNeg;
        double wp1 = pos / totalWeight;
        double wp0 = (totalPos - pos) / totalWeight;
        double wn1 = neg / totalWeight;
        double wn0 = (totalNeg - neg) / totalWeight;
        // This runs once per candidate threshold, so avoid building the messages when unused.
        if (DEBUG) {
            log.debug("pos, neg, total = " + pos + ", " + neg + ", " + totalWeight);
            log.debug("wp1,wp0,wn1,wn0 = " + wp1 + "," + wp0 + "," + wn1 + "," + wn0);
        }
        return 2.0 * (Math.sqrt(wp1 * wn1) + Math.sqrt(wp0 * wn0));
    }

    /**
     * Class-weight statistics for one numeric feature, bucketed by the distinct feature
     * values observed, from which the best split threshold is derived.
     */
    private static class NumericFeatureStats {
        /** Per-value class weights, ordered by feature value. */
        private final TreeMap<Double, BinaryFeatureStats> map = new TreeMap<Double, BinaryFeatureStats>();
        /** Positive weight of the examples passed to {@link #update}. */
        private double posNonZero = 0.0;
        /** Non-positive-label weight of the examples passed to {@link #update}. */
        private double negNonZero = 0.0;
        private double bestThreshold;
        private double bestThresholdValue;

        /**
         * Records one example that carries this feature.
         *
         * @param example       the example
         * @param featureWeight the example's value for this feature
         */
        public void update(Example example, double featureWeight) {
            Double key = Double.valueOf(featureWeight);
            BinaryFeatureStats bucket = this.map.get(key);
            if (bucket == null) {
                bucket = new BinaryFeatureStats();
                this.map.put(key, bucket);
            }
            bucket.update(example);
            if (example.getLabel().numericLabel() > 0.0) {
                this.posNonZero += example.getWeight();
            } else {
                this.negNonZero += example.getWeight();
            }
        }

        /**
         * Scans the midpoints between consecutive observed values and remembers the
         * threshold with the lowest split score.
         *
         * <p>Side effects: may insert a 0.0 bucket into the value map, and overwrites the
         * stored best threshold and its score.
         *
         * @param totalPosWeight positive weight at the node
         * @param totalNegWeight negative weight at the node
         * @return the lowest score found, or {@code Double.MAX_VALUE} when fewer than two
         *         distinct values exist (no threshold can be placed)
         */
        public double value(double totalPosWeight, double totalNegWeight) {
            // Examples that never reported this feature implicitly have value 0.0. They
            // exist exactly when the node's total weight exceeds the weight seen by update().
            if (totalPosWeight + totalNegWeight > this.posNonZero + this.negNonZero) {
                BinaryFeatureStats absent = new BinaryFeatureStats();
                absent.pos = totalPosWeight - this.posNonZero;
                absent.neg = totalNegWeight - this.negNonZero;
                // NOTE(review): this replaces any bucket created by an explicit 0.0 feature
                // weight - confirm whether such examples can occur and should be merged.
                this.map.put(Double.valueOf(0.0), absent);
            }
            Double lastKey = null;
            // Weight of examples whose value lies above the candidate threshold; starts at
            // everything and shrinks as buckets are passed in ascending order.
            double posGT = totalPosWeight;
            double negGT = totalNegWeight;
            this.bestThresholdValue = Double.MAX_VALUE;
            for (Double key : this.map.keySet()) {
                if (lastKey != null) {
                    double threshold = lastKey + 0.5 * (key - lastKey);
                    double value = DecisionTreeLearner.schapireSingerValue(posGT, negGT, totalPosWeight, totalNegWeight);
                    if (value < this.bestThresholdValue) {
                        this.bestThreshold = threshold;
                        this.bestThresholdValue = value;
                    }
                }
                lastKey = key;
                BinaryFeatureStats bucket = this.map.get(key);
                posGT -= bucket.pos;
                negGT -= bucket.neg;
            }
            return this.bestThresholdValue;
        }

        /** @return the threshold chosen by the most recent {@link #value} call */
        double getBestThreshold() {
            return this.bestThreshold;
        }

        @Override
        public String toString() {
            return "[pos: " + this.posNonZero + " neg: " + this.negNonZero + " map: " + this.map + "]";
        }
    }

    /** Accumulated positive and negative example weight for one feature (or one value bucket). */
    private static class BinaryFeatureStats {
        private double pos = 0.0;
        private double neg = 0.0;

        private BinaryFeatureStats() {
        }

        /**
         * Adds the example's weight to the positive total when its numeric label is
         * greater than zero, otherwise to the negative total.
         *
         * @param example the example to count
         */
        public void update(Example example) {
            if (example.getLabel().numericLabel() > 0.0) {
                this.pos += example.getWeight();
            } else {
                this.neg += example.getWeight();
            }
        }

        /**
         * Scores the split "feature present vs. absent".
         *
         * @param totalPosWeight positive weight at the node
         * @param totalNegWeight negative weight at the node
         * @return the split score; lower is better
         */
        public double value(double totalPosWeight, double totalNegWeight) {
            return DecisionTreeLearner.schapireSingerValue(this.pos, this.neg, totalPosWeight, totalNegWeight);
        }

        @Override
        public String toString() {
            return "[pos:" + this.pos + " neg:" + this.neg + "]";
        }
    }
}

