/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify.algorithms.trees;

import edu.cmu.minorthird.classify.BatchBinaryClassifierLearner;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.algorithms.trees.DecisionTree;
import edu.cmu.minorthird.classify.algorithms.trees.RandomForests;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Vector;
import org.apache.log4j.Logger;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Batch learner that grows a binary {@link DecisionTree} by recursively
 * partitioning the training examples on a (feature, threshold) pair proposed
 * by a {@link TreeSplitter}.
 *
 * <p>The default splitter, {@link RandomTreeSplitter}, picks both the feature
 * and the threshold at random. {@link BestOfNRandomTreeSplitter} draws several
 * random candidates and keeps the highest-scoring one.
 */
public class RandomTreeLearner
extends BatchBinaryClassifierLearner {
    private static final Logger log = Logger.getLogger(RandomTreeLearner.class);

    /** Strategy used to pick the feature and threshold at every internal node. */
    private final TreeSplitter splitter;

    /** Creates a learner that uses a {@link RandomTreeSplitter}. */
    public RandomTreeLearner() {
        this.splitter = new RandomTreeSplitter();
    }

    /**
     * Creates a learner that uses the given split strategy.
     *
     * @param b the splitter consulted at every internal node
     */
    public RandomTreeLearner(TreeSplitter b) {
        this.splitter = b;
    }

    /**
     * Trains a tree on the given examples, starting at depth 0, and logs the
     * resulting tree at INFO level.
     *
     * @param dataset     the training examples
     * @param allFeatures the features that may be used for splitting
     * @return the learned tree
     */
    public Classifier batchTrain(List<Example> dataset, Vector<Feature> allFeatures) {
        DecisionTree tree = this.batchTrain(dataset, 0, allFeatures);
        log.info("built tree: " + tree);
        return tree;
    }

    /**
     * Copies the examples of {@code dataset} into a list and trains a tree on
     * them, using the features reported by
     * {@link RandomForests#getDatasetFeatures}.
     *
     * @param dataset the training data
     * @return the learned tree
     */
    @Override
    public Classifier batchTrain(Dataset dataset) {
        List<Example> examples = new LinkedList<Example>();
        Iterator<Example> it = dataset.iterator();
        while (it.hasNext()) {
            examples.add(it.next());
        }
        return this.batchTrain(examples, RandomForests.getDatasetFeatures(dataset));
    }

    /**
     * Recursively builds the (sub)tree for {@code dataset}.
     *
     * <p>A leaf is produced when one of the two classes carries no weight or
     * when no feature is left to split on; the leaf value is 1, 0 or -1
     * depending on whether the positive weight is greater than, equal to, or
     * less than the negative weight. Otherwise the splitter is asked for a
     * feature and threshold. If the splitter returns four elements, the last
     * two are used as the already-partitioned true/false example lists;
     * otherwise the partition is computed here. A split that leaves one side
     * empty is discarded and the same data is retried at the same depth
     * without that feature.
     *
     * @param dataset        the examples reaching this node
     * @param depth          the depth of this node (the root is 0)
     * @param unusedFeatures the features still available for splitting; this
     *                       vector is not modified
     * @return the root of the (sub)tree built for {@code dataset}
     */
    public DecisionTree batchTrain(List<Example> dataset, int depth, Vector<Feature> unusedFeatures) {
        double posWeight = 0.0;
        double negWeight = 0.0;
        for (Example example : dataset) {
            // Examples with a strictly positive numeric label count as positive.
            if (example.getLabel().numericLabel() > 0.0) {
                posWeight += example.getWeight();
            } else {
                negWeight += example.getWeight();
            }
        }
        log.debug("build (sub)tree with posWeight: " + posWeight + " negWeight: " + negWeight);

        if (negWeight == 0.0 || posWeight == 0.0 || unusedFeatures.isEmpty()) {
            int weight = posWeight > negWeight ? 1 : (posWeight == negWeight ? 0 : -1);
            log.debug("leaf");
            return new DecisionTree.Leaf(weight);
        }

        Object[] result = this.splitter.getSplit(dataset, depth, unusedFeatures);
        Feature bestFeature = (Feature)result[0];
        double bestThreshold = (Double)result[1];
        List<Example> trueData;
        List<Example> falseData;
        if (result.length == 4) {
            // The splitter already partitioned the data; the Object[] contract
            // erases the element type, hence the unchecked casts.
            @SuppressWarnings("unchecked")
            List<Example> precomputedTrue = (List<Example>)result[2];
            @SuppressWarnings("unchecked")
            List<Example> precomputedFalse = (List<Example>)result[3];
            trueData = precomputedTrue;
            falseData = precomputedFalse;
        } else {
            trueData = new LinkedList<Example>();
            falseData = new LinkedList<Example>();
            RandomTreeLearner.partition(dataset, bestFeature, bestThreshold, trueData, falseData);
        }
        log.debug("split on: " + bestFeature + " with threshold " + bestThreshold);
        log.debug("trueData size: " + trueData.size() + " falseData size: " + falseData.size());

        // Work on a copy so sibling branches and the caller keep their own view.
        Vector<Feature> newUnusedFeatures = new Vector<Feature>(unusedFeatures);
        newUnusedFeatures.removeElement(bestFeature);
        if (falseData.isEmpty() || trueData.isEmpty()) {
            log.debug("didn't split data with this feature");
            return this.batchTrain(dataset, depth, newUnusedFeatures);
        }
        DecisionTree trueBranch = this.batchTrain(trueData, depth + 1, newUnusedFeatures);
        DecisionTree falseBranch = this.batchTrain(falseData, depth + 1, newUnusedFeatures);
        return new DecisionTree.InternalNode(bestFeature, bestThreshold, trueBranch, falseBranch);
    }

    /** Returns a feature drawn uniformly at random (via {@link Math#random()}) from {@code features}. */
    private static Feature randomFeature(Vector<Feature> features) {
        int featureIndex = (int)Math.floor(Math.random() * (double)features.size());
        return features.get(featureIndex);
    }

    /**
     * Draws a random threshold between the smallest and largest value that
     * {@code feature} takes over {@code dataset}.
     *
     * <p>The running maximum starts at negative infinity. {@code Double.MIN_VALUE}
     * must not be used for that purpose: it is the smallest <em>positive</em>
     * double, so a feature whose values are all negative would never update
     * the maximum and the threshold would be drawn from the wrong range.
     *
     * @return the threshold, or 0.0 if no minimum/maximum could be determined
     *         (for example when {@code dataset} is empty)
     */
    private static double randomThreshold(List<Example> dataset, Feature feature) {
        double minValue = Double.POSITIVE_INFINITY;
        double maxValue = Double.NEGATIVE_INFINITY;
        for (Example example : dataset) {
            double val = example.getWeight(feature);
            if (val < minValue) {
                minValue = val;
            }
            if (val > maxValue) {
                maxValue = val;
            }
        }
        if (minValue > maxValue) {
            // Nothing was observed; avoid producing NaN from infinity arithmetic.
            return 0.0;
        }
        return Math.random() * (maxValue - minValue) + minValue;
    }

    /**
     * Appends each example to {@code trueData} if its value for
     * {@code feature} is at least {@code threshold}, and to {@code falseData}
     * otherwise.
     */
    private static void partition(List<Example> dataset, Feature feature, double threshold,
            List<Example> trueData, List<Example> falseData) {
        for (Example example : dataset) {
            if (example.getWeight(feature) >= threshold) {
                trueData.add(example);
            } else {
                falseData.add(example);
            }
        }
    }

    /**
     * Scores a pair of counts with a smoothed {@code -sum(w * ln w)} over four
     * terms: {@code pos/tot}, {@code neg/tot}, {@code (tot-pos)/tot} and
     * {@code (tot-neg)/tot}, each increased by {@code 0.1/tot}, where
     * {@code tot = totalPos + totalNeg}.
     *
     * @param pos      first count
     * @param neg      second count
     * @param totalPos contributes only to the total
     * @param totalNeg contributes only to the total
     * @return the score
     */
    private static double entropy(double pos, double neg, double totalPos, double totalNeg) {
        double tot = totalPos + totalNeg;
        // Smoothing term keeps every argument of Math.log strictly positive.
        double epsilon = 0.1 / tot;
        double w11 = pos / tot + epsilon;
        double w10 = neg / tot + epsilon;
        double w01 = (tot - pos) / tot + epsilon;
        double w00 = (tot - neg) / tot + epsilon;
        if (log.isDebugEnabled()) {
            log.debug("pos, neg, total = " + pos + ", " + neg + ", " + tot);
            log.debug("w11,w10,w01,w00 = " + w11 + "," + w10 + "," + w01 + "," + w00);
        }
        return -w11 * Math.log(w11) - w10 * Math.log(w10) - w01 * Math.log(w01) - w00 * Math.log(w00);
    }

    /**
     * Splitter that tries several random (feature, threshold) candidates and
     * returns the one with the highest score, together with the partition it
     * induces.
     *
     * <p>Features are drawn with replacement, so the same feature may be
     * tried more than once with different thresholds. The score passed to the
     * comparison depends only on the sizes of the two partitions, not on the
     * example labels.
     * NOTE(review): the original local was named {@code i_gain}, which
     * suggests a label-based information gain may have been intended - confirm.
     */
    public static class BestOfNRandomTreeSplitter
    implements TreeSplitter {
        /** Number of random candidates to evaluate per split. */
        int featureCount = 1;

        /**
         * @param fc number of random candidates to evaluate per split; values
         *           below 1 are treated as 1 when a split is requested
         */
        public BestOfNRandomTreeSplitter(int fc) {
            this.featureCount = fc;
        }

        /**
         * Evaluates up to {@code featureCount} random candidates (never more
         * than the number of unused features, and at least one when any
         * feature is available).
         *
         * @return a four-element array: the chosen feature, its threshold (a
         *         {@code Double}), the examples at or above the threshold, and
         *         the examples below it; the feature and both lists are
         *         {@code null} if {@code unusedFeatures} is empty
         */
        @Override
        public Object[] getSplit(List<Example> dataset, int depth, Vector<Feature> unusedFeatures) {
            Feature bestFeature = null;
            double bestScore = Double.NEGATIVE_INFINITY;
            double bestThreshold = 0.0;
            LinkedList<Example> bestTrueData = null;
            LinkedList<Example> bestFalseData = null;

            // A non-positive featureCount would otherwise evaluate nothing and
            // hand back a null feature and null partitions to the caller.
            int candidates = Math.min(Math.max(1, this.featureCount), unusedFeatures.size());
            for (int i = 0; i < candidates; ++i) {
                Feature f = RandomTreeLearner.randomFeature(unusedFeatures);
                double threshold = RandomTreeLearner.randomThreshold(dataset, f);
                LinkedList<Example> trueData = new LinkedList<Example>();
                LinkedList<Example> falseData = new LinkedList<Example>();
                RandomTreeLearner.partition(dataset, f, threshold, trueData, falseData);

                double score = RandomTreeLearner.entropy(trueData.size(), falseData.size(), trueData.size(), falseData.size());
                // The first candidate is always accepted so a feature is
                // returned even if its score is not comparable (e.g. NaN).
                if (bestFeature == null || score > bestScore) {
                    bestScore = score;
                    bestTrueData = trueData;
                    bestFalseData = falseData;
                    bestFeature = f;
                    bestThreshold = threshold;
                }
            }
            return new Object[]{bestFeature, Double.valueOf(bestThreshold), bestTrueData, bestFalseData};
        }
    }

    /**
     * Splitter that picks one unused feature at random and a random threshold
     * between the smallest and largest value of that feature in the data.
     */
    public static class RandomTreeSplitter
    implements TreeSplitter {
        /**
         * @return a two-element array: the chosen feature and its threshold
         *         (a {@code Double})
         */
        @Override
        public Object[] getSplit(List<Example> dataset, int depth, Vector<Feature> unusedFeatures) {
            Feature bestFeature = RandomTreeLearner.randomFeature(unusedFeatures);
            double bestThreshold = RandomTreeLearner.randomThreshold(dataset, bestFeature);
            return new Object[]{bestFeature, Double.valueOf(bestThreshold)};
        }
    }

    /**
     * Strategy for choosing the feature and threshold to split a node on.
     */
    public static interface TreeSplitter {
        /**
         * Proposes a split for the given examples.
         *
         * @param dataset        the examples reaching the node
         * @param depth          the depth of the node
         * @param unusedFeatures the features still available for splitting
         * @return an array whose element 0 is the {@code Feature} and element
         *         1 the {@code Double} threshold; an implementation may append
         *         the true-side and false-side example lists as elements 2
         *         and 3, in which case the array has exactly four elements
         */
        public Object[] getSplit(List<Example> dataset, int depth, Vector<Feature> unusedFeatures);
    }
}

