/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify.algorithms.trees;

import edu.cmu.minorthird.classify.BatchBinaryClassifierLearner;
import edu.cmu.minorthird.classify.BinaryClassifier;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.Explanation;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.algorithms.trees.DecisionTree;
import edu.cmu.minorthird.classify.algorithms.trees.FastRandomTreeLearner;
import edu.cmu.minorthird.util.ProgressCounter;
import edu.cmu.minorthird.util.StringUtil;
import edu.cmu.minorthird.util.gui.ComponentViewer;
import edu.cmu.minorthird.util.gui.VanillaViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
import java.awt.Component;
import java.awt.GridBagConstraints;
import java.awt.GridBagLayout;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.Vector;
import java.util.concurrent.Semaphore;
import javax.swing.JComponent;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import org.apache.log4j.Logger;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Random-forest learner for binary classification.
 *
 * <p>Trains {@code numComponents} trees with a {@link FastRandomTreeLearner}, each on a random
 * sample of the training examples. It combines them into a {@link VotingClassifier} whose score
 * is the sum of the individual tree scores.
 *
 * <p>By default trees are grown on separate threads, and at most
 * {@code Runtime.availableProcessors()} of them run at once. See {@link #setThreaded(boolean)}
 * and {@link #setThreadCount(int)}.
 */
public class RandomForests
extends BatchBinaryClassifierLearner {
    private static final Logger log = Logger.getLogger(RandomForests.class);
    /** Learner used to grow every individual tree (shared by all worker threads). */
    private FastRandomTreeLearner baseLearner;
    /** Number of trees in the forest. */
    private int numComponents;
    /** Value passed to {@code baseLearner.setSubsetSize} when {@link #selectSizeLog} is false. */
    private int selectSize = 1;
    /** When true, the subset size is max(1, floor(log2(number of features))). */
    private boolean selectSizeLog = true;
    /** When true, out-of-bag sets are recorded and summary statistics are logged after training. */
    private boolean collectStats = false;
    /** When true, trees are trained on separate threads. */
    private boolean isThreaded = true;
    /** Maximum number of trees trained concurrently in threaded mode. */
    private int threadCount = 4;
    /** Source of randomness for drawing each tree's training sample. */
    private Random rand;
    /** When true, example weights are divided by (approximately) their class's weight fraction. */
    private boolean scaleWeights = true;

    /** Creates a forest of 101 trees. */
    public RandomForests() {
        this(101);
    }

    /**
     * Creates a forest with a default {@link FastRandomTreeLearner}.
     *
     * @param numComponents number of trees to train
     */
    public RandomForests(int numComponents) {
        this(new FastRandomTreeLearner(), numComponents);
    }

    /**
     * Creates a forest using the given tree learner.
     *
     * <p>The thread count defaults to {@code Runtime.getRuntime().availableProcessors()}.
     *
     * @param baseLearner learner used to grow each tree
     * @param numComponents number of trees to train
     */
    public RandomForests(FastRandomTreeLearner baseLearner, int numComponents) {
        this.baseLearner = baseLearner;
        this.numComponents = numComponents;
        this.threadCount = Runtime.getRuntime().availableProcessors();
        this.rand = new Random();
        log.info("setting number of random forest threads to " + this.threadCount);
    }

    /** Enables or disables multi-threaded training; returns {@code this} for chaining. */
    public RandomForests setThreaded(boolean b) {
        this.isThreaded = b;
        return this;
    }

    /**
     * Sets the maximum number of trees trained concurrently; returns {@code this} for chaining.
     * Values below 1 are treated as 1 during training.
     */
    public RandomForests setThreadCount(int c) {
        this.threadCount = c;
        return this;
    }

    /** Enables or disables logging of tree-shape and out-of-bag statistics after training. */
    public RandomForests setCollectStats(boolean b) {
        this.collectStats = b;
        return this;
    }

    /** Enables or disables class-frequency rescaling of example weights. */
    public RandomForests setScaleWeights(boolean b) {
        this.scaleWeights = b;
        return this;
    }

    /** Uses max(1, floor(log2(number of features))) as the tree learner's subset size. */
    public RandomForests setSelectionSizeLog() {
        this.selectSizeLog = true;
        return this;
    }

    /** Uses a fixed subset size {@code c} for the tree learner instead of the log rule. */
    public RandomForests setSelectionSize(int c) {
        this.selectSize = c;
        this.selectSizeLog = false;
        return this;
    }

    /**
     * Creates a single-threaded, 101-tree forest.
     *
     * <p>Both the tree learner ({@code setRandomSeed(0L)}) and the sampling RNG are seeded with
     * 0, so repeated runs on the same data draw the same samples.
     */
    public static RandomForests RepeatableForest() {
        RandomForests rf = new RandomForests(new FastRandomTreeLearner().setRandomSeed(0L), 101);
        rf.rand = new Random(0L);
        rf.setThreaded(false);
        return rf;
    }

    /**
     * Trains the forest and returns it as a {@link VotingClassifier}.
     *
     * <p>The dataset's examples are copied before any weight rescaling, so only the copies'
     * weights change.
     *
     * @param dataset training data
     * @return a voting classifier over the trained trees
     * @throws IllegalStateException in threaded mode, if any worker thread failed to train its
     *     tree (the worker's own exception is reported by the thread's uncaught-exception handler)
     */
    @Override
    public Classifier batchTrain(Dataset dataset) {
        Vector<Example> examples = new Vector<Example>(dataset.size());
        Vector<Feature> allFeatures = RandomForests.getDatasetFeatures(dataset);
        int eSize = dataset.size();
        Iterator<Example> it = dataset.iterator();
        double pos = 0.0;
        double neg = 0.0;
        for (int i = 0; i < eSize; ++i) {
            Example e = it.next();
            examples.addElement(new Example(e.asInstance(), e.getLabel(), e.getWeight()));
            if (e.getLabel().numericLabel() > 0.0) {
                pos += e.getWeight();
            } else {
                neg += e.getWeight();
            }
        }
        if (this.scaleWeights) {
            this.scaleClassWeights(examples, pos, neg);
        }
        if (this.selectSizeLog) {
            this.baseLearner.setSubsetSize(Math.max((int)Math.floor(Math.log(allFeatures.size()) / Math.log(2.0)), 1));
        } else {
            this.baseLearner.setSubsetSize(this.selectSize);
        }
        Hashtable<Classifier, Set<Example>> oobMap = new Hashtable<Classifier, Set<Example>>();
        ArrayList<Classifier> classifiers = new ArrayList<Classifier>(this.numComponents);
        ProgressCounter pc = new ProgressCounter("RandomForest", "treecount", this.numComponents);
        boolean threaded = this.isThreaded;
        // A non-positive thread count would create a semaphore with no permits and deadlock.
        int numThreads = threaded ? Math.max(1, this.threadCount) : 1;
        Semaphore s = new Semaphore(numThreads);
        log.info("Random forests starting with " + dataset.size() + " elements, " + allFeatures.size() + " features");
        log.info("example size: " + examples.size());
        log.info("Learning classifier with " + this.baseLearner);
        for (int t = 0; t < this.numComponents; ++t) {
            if (threaded) {
                // Each running worker holds one permit, so at most numThreads trees train at once.
                s.acquireUninterruptibly();
            }
            LearnerThread runnerT = new LearnerThread(examples, new Vector<Feature>(allFeatures), classifiers, oobMap, threaded ? s : null);
            if (threaded) {
                runnerT.start();
            } else {
                runnerT.run();
            }
            pc.progress();
        }
        if (threaded) {
            // Wait for every outstanding worker by reclaiming all permits, then hand them back.
            s.acquireUninterruptibly(numThreads);
            s.release(numThreads);
            int built;
            synchronized (classifiers) {
                built = classifiers.size();
            }
            if (built != this.numComponents) {
                throw new IllegalStateException("RandomForests: only " + built + " of " + this.numComponents
                        + " trees were trained; a worker thread failed");
            }
        }
        pc.finished();
        if (this.collectStats) {
            this.printSomeStats(examples, oobMap);
        }
        return new VotingClassifier(classifiers);
    }

    /**
     * Divides each example's weight by its class's share of the total weight (plus 1e-4), so
     * positive and negative classes carry comparable total weight.
     *
     * @param examples examples whose weights are modified in place
     * @param pos total weight of examples with a positive numeric label
     * @param neg total weight of the remaining examples
     */
    private void scaleClassWeights(Vector<Example> examples, double pos, double neg) {
        double pRatio = pos / (pos + neg) + 1.0E-4;
        double nRatio = neg / (pos + neg) + 1.0E-4;
        for (Example e : examples) {
            double ratio = e.getLabel().numericLabel() > 0.0 ? pRatio : nRatio;
            e.setWeight(e.getWeight() / ratio);
        }
    }

    /** Logs tree-shape statistics followed by the out-of-bag error estimate. */
    private void printSomeStats(Vector<Example> examples, Hashtable<Classifier, Set<Example>> oobMap) {
        this.printTreeShapeInfo(oobMap);
        this.printOobErrorEstimate(examples, oobMap);
    }

    /**
     * Logs the average node count, average maximum depth and overall maximum depth of the trees
     * that are keys of {@code oobmap}. Averages are rounded to the nearest integer, and are 0 when
     * the map is empty.
     */
    private void printTreeShapeInfo(Hashtable<Classifier, Set<Example>> oobmap) {
        int treeCount = 0;
        long totalNodes = 0L;
        long totalDepth = 0L;
        int maxMaxDepth = 0;
        Enumeration<Classifier> e = oobmap.keys();
        while (e.hasMoreElements()) {
            DecisionTree t = (DecisionTree)e.nextElement();
            int depth = this.maxDepth(t);
            totalNodes += this.numNodes(t);
            totalDepth += depth;
            maxMaxDepth = Math.max(maxMaxDepth, depth);
            ++treeCount;
        }
        int avgNumNodes = treeCount == 0 ? 0 : (int)Math.round((double)totalNodes / (double)treeCount);
        int avgMaxDepth = treeCount == 0 ? 0 : (int)Math.round((double)totalDepth / (double)treeCount);
        log.info("Average Number of nodes: " + avgNumNodes);
        log.info("Average Max depth of tree: " + avgMaxDepth);
        log.info("Max Max depth of tree: " + maxMaxDepth);
    }

    /** Returns the number of nodes on the longest root-to-leaf path (a lone leaf has depth 1). */
    private int maxDepth(DecisionTree t) {
        if (t instanceof DecisionTree.Leaf) {
            return 1;
        }
        DecisionTree.InternalNode n = (DecisionTree.InternalNode)t;
        return Math.max(this.maxDepth(n.getTrueBranch()), this.maxDepth(n.getFalseBranch())) + 1;
    }

    /** Returns the total number of nodes (internal nodes plus leaves) in the tree. */
    private int numNodes(DecisionTree t) {
        if (t instanceof DecisionTree.Leaf) {
            return 1;
        }
        DecisionTree.InternalNode n = (DecisionTree.InternalNode)t;
        // Recurse with numNodes (the previous code mistakenly summed branch depths).
        return this.numNodes(n.getTrueBranch()) + this.numNodes(n.getFalseBranch()) + 1;
    }

    /**
     * Logs an out-of-bag error estimate.
     *
     * <p>Each example is scored by summing the scores of only those trees whose out-of-bag set
     * contains it. A score of exactly 0, including the case where no tree left the example out,
     * is counted as incorrect.
     */
    private void printOobErrorEstimate(Vector<Example> examples, Hashtable<Classifier, Set<Example>> oobMap) {
        int numCorrect = 0;
        int numIncorrect = 0;
        for (Example e : examples) {
            double score = 0.0;
            Enumeration<Classifier> trees = oobMap.keys();
            while (trees.hasMoreElements()) {
                DecisionTree t = (DecisionTree)trees.nextElement();
                Set<Example> oobData = oobMap.get(t);
                if (oobData.contains(e)) {
                    score += t.score(e.asInstance());
                }
            }
            double label = e.getLabel().numericLabel();
            if (label > 0.0 && score > 0.0 || label < 0.0 && score < 0.0) {
                ++numCorrect;
            } else {
                ++numIncorrect;
            }
        }
        log.info("out of bag num correct: " + numCorrect);
        log.info("out of bag num inCorrect: " + numIncorrect);
        log.info("out of bag estimated error: " + (double)numIncorrect / (double)(numCorrect + numIncorrect));
    }

    /**
     * Collects every distinct binary and numeric feature that occurs in {@code dataset}.
     *
     * @param dataset the dataset to scan
     * @return the distinct features, in {@link HashSet} iteration order
     */
    public static Vector<Feature> getDatasetFeatures(Dataset dataset) {
        HashSet<Feature> allFeatures = new HashSet<Feature>();
        Iterator<Example> it = dataset.iterator();
        while (it.hasNext()) {
            Example example = it.next();
            Iterator<Feature> j = example.binaryFeatureIterator();
            while (j.hasNext()) {
                allFeatures.add(j.next());
            }
            j = example.numericFeatureIterator();
            while (j.hasNext()) {
                allFeatures.add(j.next());
            }
        }
        return new Vector<Feature>(allFeatures);
    }

    /** Swing viewer that stacks the viewers of all component trees vertically in a scroll pane. */
    private static class VotingClassifierViewer
    extends ComponentViewer {
        static final long serialVersionUID = 20080609L;

        private VotingClassifierViewer() {
        }

        public JComponent componentFor(Object o) {
            VotingClassifier bc = (VotingClassifier)o;
            JPanel panel = new JPanel();
            panel.setLayout(new GridBagLayout());
            int ypos = 0;
            for (Classifier c : bc.classifiers) {
                GridBagConstraints gbc = new GridBagConstraints();
                gbc.fill = GridBagConstraints.HORIZONTAL;
                gbc.weighty = 0.0;
                gbc.weightx = 0.0;
                gbc.gridx = 0;
                gbc.gridy = ypos++;
                Viewer subview = c instanceof Visible ? ((Visible)c).toGUI() : new VanillaViewer(c);
                subview.setSuperView(this);
                panel.add((Component)subview, gbc);
            }
            JScrollPane scroller = new JScrollPane(panel);
            scroller.setHorizontalScrollBarPolicy(JScrollPane.HORIZONTAL_SCROLLBAR_AS_NEEDED);
            return scroller;
        }
    }

    /**
     * Ensemble classifier whose score is the sum of its component classifiers' scores.
     * Every component must be a {@link BinaryClassifier}.
     */
    public static class VotingClassifier
    extends BinaryClassifier
    implements Serializable,
    Visible {
        static final long serialVersionUID = 20080609L;
        private List<Classifier> classifiers;

        /** @param classifiers component classifiers; each must be a {@link BinaryClassifier} */
        public VotingClassifier(List<Classifier> classifiers) {
            this.classifiers = classifiers;
        }

        /** Returns the component classifiers (the live internal list, not a copy). */
        public List<Classifier> getClassifiers() {
            return this.classifiers;
        }

        /** Returns the sum of the component scores for {@code instance}. */
        @Override
        public double score(Instance instance) {
            double totalScore = 0.0;
            for (Classifier c : this.classifiers) {
                totalScore += ((BinaryClassifier)c).score(instance);
            }
            return totalScore;
        }

        /** Returns each component's score and indented explanation, followed by the total score. */
        @Override
        public String explain(Instance instance) {
            StringBuffer buf = new StringBuffer("");
            double totalScore = 0.0;
            for (Classifier c : this.classifiers) {
                BinaryClassifier bc = (BinaryClassifier)c;
                double treeScore = bc.score(instance);
                totalScore += treeScore;
                buf.append("score of " + bc + ": " + treeScore + "\n");
                buf.append(StringUtil.indent(1, bc.explain(instance)) + "\n");
            }
            buf.append("total score: " + totalScore);
            return buf.toString();
        }

        /**
         * Builds an explanation tree with one node per component (that component's own score and
         * its explanation), followed by a node holding the total score.
         */
        @Override
        public Explanation getExplanation(Instance instance) {
            Explanation.Node top = new Explanation.Node("Random Forest Explanation");
            double totalScore = 0.0;
            for (Classifier c : this.classifiers) {
                BinaryClassifier bc = (BinaryClassifier)c;
                double treeScore = bc.score(instance);
                totalScore += treeScore;
                Explanation.Node score = new Explanation.Node("score of " + bc);
                // Show this component's own score (not the running total) under its label.
                score.add(new Explanation.Node(treeScore + " "));
                score.add(bc.getExplanation(instance).getTopNode());
                top.add(score);
            }
            top.add(new Explanation.Node("total score: " + totalScore));
            return new Explanation(top);
        }

        public String toString() {
            StringBuffer buf = new StringBuffer("[voting classifiers:\n");
            for (Classifier c : this.classifiers) {
                buf.append(c.toString() + "\n");
            }
            buf.append("]");
            return buf.toString();
        }

        @Override
        public Viewer toGUI() {
            VotingClassifierViewer v = new VotingClassifierViewer();
            v.setContent(this);
            return v;
        }
    }

    /**
     * Trains one tree on a random sample of the examples and appends it to a shared list.
     *
     * <p>Runs either on its own thread ({@link #start()}) or inline ({@link #run()}). When a
     * semaphore is supplied, exactly one permit is released when the run finishes, even if
     * training throws. This lets the coordinating thread always drain the pool.
     *
     * <p>NOTE(review): all workers share the enclosing forest's {@code baseLearner}. This code
     * assumes its {@code batchTrain(List, Vector)} is safe to call concurrently; confirm.
     */
    private class LearnerThread
    extends Thread {
        Vector<Example> examples;
        Vector<Feature> features;
        List<Classifier> classifiers;
        Hashtable<Classifier, Set<Example>> results;
        /** Concurrency-limiting semaphore, or null when running single-threaded. */
        Semaphore s;

        public LearnerThread(Vector<Example> examples, Vector<Feature> features, List<Classifier> classifiers, Hashtable<Classifier, Set<Example>> results, Semaphore s) {
            this.examples = examples;
            this.features = features;
            this.classifiers = classifiers;
            this.results = results;
            this.s = s;
        }

        @Override
        public void run() {
            try {
                int n = this.examples.size();
                LinkedList<Example> newData = new LinkedList<Example>();
                HashSet<Example> sampled = new HashSet<Example>();
                for (int i = 0; i < n; ++i) {
                    // nextDouble() (not nextInt(n)) keeps the random stream unchanged, so seeded
                    // forests keep drawing the same samples.
                    Example e = this.examples.elementAt((int)Math.floor(RandomForests.this.rand.nextDouble() * (double)n));
                    // NOTE(review): repeated draws are skipped, so each tree trains on the distinct
                    // sampled examples rather than a bootstrap with duplicates; confirm intended.
                    if (sampled.add(e)) {
                        newData.add(e);
                    }
                }
                HashSet<Example> oobData = new HashSet<Example>();
                if (RandomForests.this.collectStats) {
                    for (Example e : this.examples) {
                        if (!sampled.contains(e)) {
                            oobData.add(e);
                        }
                    }
                }
                log.debug("RandomForest is building tree  with " + newData.size() + " elements");
                BinaryClassifier c = (BinaryClassifier)RandomForests.this.baseLearner.batchTrain(newData, this.features);
                // The list is a plain ArrayList shared by all workers; guard concurrent appends.
                synchronized (this.classifiers) {
                    this.classifiers.add(c);
                }
                if (RandomForests.this.collectStats) {
                    this.results.put(c, oobData);
                }
            } finally {
                // Always return the permit so the coordinator cannot block forever on failure.
                if (this.s != null) {
                    this.s.release();
                }
            }
        }
    }
}

