/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify.algorithms.trees;

import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.algorithms.trees.CompactDecisionTree;
import edu.cmu.minorthird.classify.algorithms.trees.RandomForests;
import edu.cmu.minorthird.classify.algorithms.trees.RandomTreeLearner;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.TreeMap;
import java.util.Vector;
import org.apache.log4j.Logger;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Learner that builds a single randomized decision tree, stored as a
 * {@link CompactDecisionTree}.
 *
 * <p>At every node a random handful of candidate features is sampled (see
 * {@link #setSubsetSize(int)}), the best numeric threshold is searched for each
 * candidate, and the examples in the current range are partitioned in place
 * around the winning feature/threshold pair. Recursion stops when a range
 * contains only one class, when no candidate features remain, or when the
 * depth limit is exceeded.
 *
 * <p>Labels are treated as binary: an example counts as positive when
 * {@code getLabel().numericLabel() > 0.0}, and as negative otherwise.
 */
public class FastRandomTreeLearner extends RandomTreeLearner {
    private static final Logger log = Logger.getLogger(FastRandomTreeLearner.class);

    /** Nodes deeper than this are always turned into leaves. */
    private static final int MAX_DEPTH = 500;

    /** Threshold reported when no sampled feature offered a usable split point. */
    private static final double NO_THRESHOLD = -9999.0;

    private Random rand = new Random();

    /**
     * Number of random feature draws made at each node. Values below 1 are
     * treated as 1 when sampling, so a learner that was never configured still
     * considers one feature per node.
     */
    public int subsetSize;

    /**
     * Replaces the random generator with one seeded by {@code seed}, making the
     * feature sampling reproducible.
     *
     * @param seed seed for the new generator
     * @return this learner, for call chaining
     */
    public FastRandomTreeLearner setRandomSeed(long seed) {
        this.rand = new Random(seed);
        return this;
    }

    /**
     * Sets the number of random feature draws made at each node.
     *
     * @param subsetSize number of draws; values below 1 behave like 1
     * @return this learner, for call chaining
     */
    public FastRandomTreeLearner setSubsetSize(int subsetSize) {
        this.subsetSize = subsetSize;
        return this;
    }

    /**
     * Trains a tree on the given examples using the given candidate features.
     *
     * <p>The examples are copied into a private vector before training, so the
     * caller's list is not reordered. {@code allFeatures} is passed through as-is
     * and its elements are swapped around during training.
     *
     * @param dataset     training examples
     * @param allFeatures features that may be used for splitting
     * @return the trained tree
     */
    @Override
    public Classifier batchTrain(List<Example> dataset, Vector<Feature> allFeatures) {
        Vector<Example> working = new Vector<Example>(dataset);
        int lastFeature = allFeatures.size() - 1;
        Classifier c = this.batchTrain(working, 0, allFeatures, lastFeature, 0, dataset.size());
        log.info("built tree: " + c);
        return c;
    }

    /**
     * Trains a tree on every example of {@code dataset}, using the features
     * reported by {@link RandomForests#getDatasetFeatures}.
     *
     * @param dataset training data
     * @return the trained tree
     */
    @Override
    public Classifier batchTrain(Dataset dataset) {
        LinkedList<Example> examples = new LinkedList<Example>();
        for (Iterator<Example> it = dataset.iterator(); it.hasNext(); ) {
            examples.add(it.next());
        }
        return this.batchTrain(examples, RandomForests.getDatasetFeatures(dataset));
    }

    /**
     * Builds a complete tree over {@code dataset[from, to)} and compacts its storage.
     *
     * @param dataset        examples; elements inside the range are reordered in place
     * @param depth          depth assigned to the root node
     * @param unusedFeatures candidate features; elements are swapped in place
     * @param lastFeature    index of the last usable entry of {@code unusedFeatures}
     * @param from           first example index (inclusive)
     * @param to             last example index (exclusive)
     * @return the trained tree
     */
    public Classifier batchTrain(Vector<Example> dataset, int depth, Vector<Feature> unusedFeatures, int lastFeature, int from, int to) {
        CompactDecisionTree tree = new CompactDecisionTree();
        tree.setRoot(this.batchTrain(dataset, depth, unusedFeatures, lastFeature, from, to, tree));
        tree.compactStorage();
        return tree;
    }

    /**
     * Samples candidate features and picks the feature/threshold pair with the
     * lowest split score over {@code dataset[from, to)}.
     *
     * <p>If no candidate yields a score below {@code Double.MAX_VALUE} (for
     * example when every sampled feature has a single distinct value in the
     * range), the first sampled candidate is returned with {@link #NO_THRESHOLD}.
     *
     * @param posWeight total weight of positive examples in the range
     * @param negWeight total weight of negative examples in the range
     * @return index into {@code unusedFeatures} of the chosen feature and its threshold
     */
    private Split getSplit(Vector<Example> dataset, int from, int to, Vector<Feature> unusedFeatures, int lastFeature, double posWeight, double negWeight) {
        HashMap<Feature, Candidate> candidates = new HashMap<Feature, Candidate>();
        // Always draw at least once; with zero draws there would be nothing to fall back on.
        int draws = Math.max(1, this.subsetSize);
        for (int i = 0; i < draws; ++i) {
            // NOTE(review): this yields an index in [0, lastFeature), so the entry at
            // index lastFeature itself is only ever picked when lastFeature == 0.
            // Kept as-is because changing it would alter every seeded tree; confirm intent.
            int featureIndex = (int)Math.floor(this.rand.nextDouble() * (double)lastFeature);
            // Repeated draws of the same feature simply replace the earlier candidate.
            candidates.put(unusedFeatures.get(featureIndex), new Candidate(featureIndex));
        }

        // Accumulate per-value label statistics for each candidate in one pass over the range.
        for (int i = from; i < to; ++i) {
            Example example = dataset.get(i);
            for (Feature f : candidates.keySet()) {
                candidates.get(f).stats.update(example, example.getWeight(f));
            }
        }

        double bestValue = Double.MAX_VALUE;
        double bestThreshold = NO_THRESHOLD;
        int bestFeatureIndex = -1;
        for (Feature f : candidates.keySet()) {
            Candidate candidate = candidates.get(f);
            double v = candidate.stats.value(posWeight, negWeight);
            if (v < bestValue) {
                bestValue = v;
                bestThreshold = candidate.stats.getBestThreshold();
                bestFeatureIndex = candidate.featureIndex;
            }
        }
        if (bestFeatureIndex < 0) {
            // Nothing scored; hand back an arbitrary candidate so the caller can retire it.
            bestFeatureIndex = candidates.values().iterator().next().featureIndex;
        }
        return new Split(bestFeatureIndex, bestThreshold);
    }

    /**
     * Recursively builds the subtree for {@code dataset[from, to)} inside {@code tree}.
     *
     * <p>A leaf carrying {@code posWeight - negWeight} is added when the range
     * holds only one class, when {@code lastFeature < 0}, or when {@code depth}
     * exceeds the depth limit. Otherwise the range is partitioned around a
     * sampled feature. If that feature fails to separate the range, it is
     * retired (moved past {@code lastFeature}) and the node is retried one level
     * deeper with the remaining features.
     *
     * @param dataset        examples; elements inside the range are reordered in place
     * @param depth          depth of the node being built
     * @param unusedFeatures candidate features; elements are swapped in place
     * @param lastFeature    index of the last usable entry of {@code unusedFeatures}
     * @param from           first example index (inclusive)
     * @param to             last example index (exclusive)
     * @param tree           tree receiving the new nodes
     * @return the value returned by the tree's {@code addLeafNode} /
     *         {@code addInternalNode} call for the node created (presumably its
     *         node index; verify against {@link CompactDecisionTree})
     */
    public int batchTrain(Vector<Example> dataset, int depth, Vector<Feature> unusedFeatures, int lastFeature, int from, int to, CompactDecisionTree tree) {
        double posWeight = 0.0;
        double negWeight = 0.0;
        for (int i = from; i < to; ++i) {
            Example example = dataset.get(i);
            if (example.getLabel().numericLabel() > 0.0) {
                posWeight += example.getWeight();
            } else {
                negWeight += example.getWeight();
            }
        }

        boolean pure = negWeight == 0.0 || posWeight == 0.0;
        if (pure || lastFeature < 0 || depth > MAX_DEPTH) {
            return tree.addLeafNode(posWeight - negWeight);
        }

        Split split = this.getSplit(dataset, from, to, unusedFeatures, lastFeature, posWeight, negWeight);
        Feature bestFeature = unusedFeatures.get(split.featureIndex);
        int storeIndex = FastRandomTreeLearner.partition(dataset, from, to, bestFeature, split.threshold);

        // Park the chosen feature at position lastFeature so that decrementing
        // lastFeature (below) removes it from the candidate pool.
        unusedFeatures.setElementAt(unusedFeatures.get(lastFeature), split.featureIndex);
        unusedFeatures.setElementAt(bestFeature, lastFeature);

        if (storeIndex == from || storeIndex == to) {
            log.debug("didn't split data with this feature");
            return this.batchTrain(dataset, depth + 1, unusedFeatures, lastFeature - 1, from, to, tree);
        }
        int trueBranch = this.batchTrain(dataset, depth + 1, unusedFeatures, lastFeature, from, storeIndex, tree);
        int falseBranch = this.batchTrain(dataset, depth + 1, unusedFeatures, lastFeature, storeIndex, to, tree);
        return tree.addInternalNode(bestFeature, split.threshold, trueBranch, falseBranch);
    }

    /**
     * Reorders {@code dataset[from, to)} in place so that every example whose
     * weight for {@code feature} is {@code >= threshold} precedes the others.
     *
     * @return the index of the first example that does not satisfy the test
     */
    private static int partition(Vector<Example> dataset, int from, int to, Feature feature, double threshold) {
        int storeIndex = from;
        for (int i = from; i < to; ++i) {
            Example current = dataset.get(i);
            if (current.getWeight(feature) >= threshold) {
                dataset.setElementAt(dataset.get(storeIndex), i);
                dataset.setElementAt(current, storeIndex);
                ++storeIndex;
            }
        }
        return storeIndex;
    }

    /**
     * Scores a binary split as
     * {@code 2 * (sqrt(wp1 * wn1) + sqrt(wp0 * wn0))}, where each term is the
     * positive/negative weight on one side of the split divided by the total
     * weight. Callers minimise this value.
     *
     * @param pos      positive weight on the first side of the split
     * @param neg      negative weight on the first side of the split
     * @param totalPos total positive weight
     * @param totalNeg total negative weight
     */
    private static double schapireSingerValue(double pos, double neg, double totalPos, double totalNeg) {
        double totalWeight = totalPos + totalNeg;
        double wp1 = pos / totalWeight;
        double wp0 = (totalPos - pos) / totalWeight;
        double wn1 = neg / totalWeight;
        double wn0 = (totalNeg - neg) / totalWeight;
        // This runs once per candidate threshold; skip building the messages unless they will be logged.
        if (log.isDebugEnabled()) {
            log.debug("pos, neg, total = " + pos + ", " + neg + ", " + totalWeight);
            log.debug("wp1,wp0,wn1,wn0 = " + wp1 + "," + wp0 + "," + wn1 + "," + wn0);
        }
        return 2.0 * (Math.sqrt(wp1 * wn1) + Math.sqrt(wp0 * wn0));
    }

    /** Result of a split search: the chosen feature's index and the threshold to test against. */
    private static final class Split {
        private final int featureIndex;
        private final double threshold;

        private Split(int featureIndex, double threshold) {
            this.featureIndex = featureIndex;
            this.threshold = threshold;
        }
    }

    /** A sampled feature: its index in the feature vector and the statistics gathered for it. */
    private static final class Candidate {
        private final int featureIndex;
        private final NumericFeatureStats stats = new NumericFeatureStats();

        private Candidate(int featureIndex) {
            this.featureIndex = featureIndex;
        }
    }

    /**
     * Label statistics of one feature, bucketed by distinct feature value and
     * kept in ascending value order so thresholds can be scanned in one sweep.
     */
    private static class NumericFeatureStats {
        private TreeMap<Double, BinaryFeatureStats> map = new TreeMap<Double, BinaryFeatureStats>();
        /** Positive weight already accounted for in {@link #map}. */
        private double posNonZero = 0.0;
        /** Negative weight already accounted for in {@link #map}. */
        private double negNonZero = 0.0;
        private double bestThreshold;
        private double bestThresholdValue;

        /**
         * Records one example under the bucket for {@code featureWeight}.
         *
         * @param example       the example whose label and weight are accumulated
         * @param featureWeight the example's value for this feature
         */
        public void update(Example example, double featureWeight) {
            Double key = Double.valueOf(featureWeight);
            BinaryFeatureStats bucket = this.map.get(key);
            if (bucket == null) {
                bucket = new BinaryFeatureStats();
                this.map.put(key, bucket);
            }
            bucket.update(example);
            if (example.getLabel().numericLabel() > 0.0) {
                this.posNonZero += example.getWeight();
            } else {
                this.negNonZero += example.getWeight();
            }
        }

        /**
         * Finds the best threshold for this feature and returns its score.
         *
         * <p>Candidate thresholds are the midpoints between consecutive distinct
         * feature values. The best one is remembered and available through
         * {@link #getBestThreshold()}.
         *
         * @param totalPosWeight total positive weight of the examples being split
         * @param totalNegWeight total negative weight of the examples being split
         * @return the lowest score found, or {@code Double.MAX_VALUE} when there
         *         are fewer than two distinct feature values
         */
        public double value(double totalPosWeight, double totalNegWeight) {
            // Any weight never passed to update() is attributed to feature value 0.
            // It is merged into an existing 0.0 bucket rather than replacing it, and
            // counted as seen so that a repeated call does not add it twice.
            if (totalPosWeight + totalNegWeight > this.posNonZero + this.negNonZero) {
                Double zero = Double.valueOf(0.0);
                BinaryFeatureStats implied = this.map.get(zero);
                if (implied == null) {
                    implied = new BinaryFeatureStats();
                    this.map.put(zero, implied);
                }
                implied.pos += totalPosWeight - this.posNonZero;
                implied.neg += totalNegWeight - this.negNonZero;
                this.posNonZero = totalPosWeight;
                this.negNonZero = totalNegWeight;
            }

            // posGT/negGT hold the weight of examples whose value is >= the current key.
            double posGT = totalPosWeight;
            double negGT = totalNegWeight;
            this.bestThresholdValue = Double.MAX_VALUE;
            Double lastKey = null;
            for (Double key : this.map.keySet()) {
                if (lastKey != null) {
                    double threshold = lastKey + 0.5 * (key - lastKey);
                    double score = FastRandomTreeLearner.schapireSingerValue(posGT, negGT, totalPosWeight, totalNegWeight);
                    if (score < this.bestThresholdValue) {
                        this.bestThreshold = threshold;
                        this.bestThresholdValue = score;
                    }
                }
                lastKey = key;
                BinaryFeatureStats bucket = this.map.get(key);
                posGT -= bucket.pos;
                negGT -= bucket.neg;
            }
            return this.bestThresholdValue;
        }

        /** Returns the threshold chosen by the most recent {@link #value} call. */
        double getBestThreshold() {
            return this.bestThreshold;
        }

        @Override
        public String toString() {
            return "[pos: " + this.posNonZero + " neg: " + this.negNonZero + " map: " + this.map + "]";
        }
    }

    /** Accumulated positive and negative example weight for one bucket. */
    private static class BinaryFeatureStats {
        private double pos = 0.0;
        private double neg = 0.0;

        private BinaryFeatureStats() {
        }

        /** Adds the example's weight to the positive or negative total according to its label. */
        public void update(Example example) {
            if (example.getLabel().numericLabel() > 0.0) {
                this.pos += example.getWeight();
            } else {
                this.neg += example.getWeight();
            }
        }

        /** Scores a split that puts exactly this bucket on one side. */
        public double value(double totalPosWeight, double totalNegWeight) {
            return FastRandomTreeLearner.schapireSingerValue(this.pos, this.neg, totalPosWeight, totalNegWeight);
        }

        @Override
        public String toString() {
            return "[pos:" + this.pos + " neg:" + this.neg + "]";
        }
    }
}

