/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify.experiments;

import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.ClassifierLearner;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.DatasetClassifierTeacher;
import edu.cmu.minorthird.classify.DatasetIndex;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.SampleDatasets;
import edu.cmu.minorthird.classify.Splitter;
import edu.cmu.minorthird.classify.algorithms.trees.DecisionTreeLearner;
import edu.cmu.minorthird.classify.experiments.ClassifiedDataset;
import edu.cmu.minorthird.classify.experiments.CrossValSplitter;
import edu.cmu.minorthird.classify.experiments.Evaluation;
import edu.cmu.minorthird.util.ProgressCounter;
import edu.cmu.minorthird.util.gui.ParallelViewer;
import edu.cmu.minorthird.util.gui.TransformedViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.ViewerFrame;
import edu.cmu.minorthird.util.gui.Visible;
import java.text.DecimalFormat;
import org.apache.log4j.Logger;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Runs a cross-validation experiment and keeps the per-fold results so they can be inspected.
 *
 * <p>The dataset is split with a {@link Splitter}. For each partition, a classifier is trained
 * on the training portion and applied to the test portion. The resulting classified test
 * partitions, optionally the classified training partitions, and an {@link Evaluation}
 * accumulated over all test partitions are retained. They can be browsed via {@link #toGUI()}.
 */
public class CrossValidatedDataset
implements Visible {
    private static final Logger log = Logger.getLogger(CrossValidatedDataset.class);
    /** Classified test partition for each fold. */
    private final ClassifiedDataset[] cds;
    /** Classified training partition for each fold, or {@code null} if not requested. */
    private final ClassifiedDataset[] trainCds;
    /** Evaluation extended with every fold's test partition. */
    private final Evaluation v;

    /**
     * Cross-validates {@code learner} on {@code d} without retaining classified training partitions.
     *
     * @param learner  learner trained once per fold
     * @param d        dataset to split
     * @param splitter splitter that defines the train/test partitions
     */
    public CrossValidatedDataset(ClassifierLearner learner, Dataset d, Splitter<Example> splitter) {
        this(learner, d, splitter, false);
    }

    /**
     * Cross-validates {@code learner} on {@code d}.
     *
     * @param learner             learner trained once per fold
     * @param d                   dataset to split
     * @param splitter            splitter that defines the train/test partitions
     * @param saveTrainPartitions if true, also keep each fold's training data classified by that
     *                            fold's classifier, so it can be viewed in the GUI
     */
    public CrossValidatedDataset(ClassifierLearner learner, Dataset d, Splitter<Example> splitter, boolean saveTrainPartitions) {
        Dataset.Split s = d.split(splitter);
        int numFolds = s.getNumPartitions();
        this.cds = new ClassifiedDataset[numFolds];
        this.trainCds = saveTrainPartitions ? new ClassifiedDataset[numFolds] : null;
        this.v = new Evaluation(d.getSchema());
        ProgressCounter pc = new ProgressCounter("train/test", "fold", numFolds);
        try {
            for (int k = 0; k < numFolds; ++k) {
                Dataset trainData = s.getTrain(k);
                Dataset testData = s.getTest(k);
                log.info("splitting with " + splitter + ", preparing to train on " + trainData.size() + " and test on " + testData.size());
                Classifier c = new DatasetClassifierTeacher(trainData).train(learner);
                DatasetIndex testIndex = new DatasetIndex(testData);
                // Build the training index once; it serves both the classified training
                // partition and the class-distribution summary below.
                DatasetIndex trainIndex = new DatasetIndex(trainData);
                this.cds[k] = new ClassifiedDataset(c, testData, testIndex);
                if (this.trainCds != null) {
                    // Pair the training data with its own index, not the test partition's index.
                    this.trainCds[k] = new ClassifiedDataset(c, trainData, trainIndex);
                }
                this.v.extend(this.cds[k].getClassifier(), testData, k);
                this.v.setProperty("classesInFold" + (k + 1), "train: " + this.classDistributionString(trainData.getSchema(), trainIndex) + "     test: " + this.classDistributionString(testData.getSchema(), testIndex));
                log.info("splitting with " + splitter + ", stored classified dataset");
                pc.progress();
            }
        } finally {
            // Always close the progress counter, even if training or evaluation fails.
            pc.finished();
        }
    }

    /**
     * Formats the number of indexed examples per schema class, e.g. {@code "12 POS; 30 NEG"}.
     *
     * @param schema schema whose classes are listed, in schema order
     * @param index  index used to count examples carrying each class label
     * @return a {@code "; "}-separated list of {@code count label} pairs
     */
    private String classDistributionString(ExampleSchema schema, DatasetIndex index) {
        StringBuilder buf = new StringBuilder();
        DecimalFormat fmt = new DecimalFormat("#####");
        for (int i = 0; i < schema.getNumberOfClasses(); ++i) {
            if (buf.length() > 0) {
                buf.append("; ");
            }
            String label = schema.getClassName(i);
            buf.append(fmt.format(index.size(label))).append(" ").append(label);
        }
        return buf.toString();
    }

    /**
     * Builds a viewer with one sub-view per test partition, one per saved training partition
     * (if any), and one for the overall evaluation.
     */
    @Override
    public Viewer toGUI() {
        ParallelViewer main = new ParallelViewer();
        // Each per-fold sub-view wraps a viewer built from the first partition. transform()
        // then substitutes that fold's dataset. With no folds there is nothing to show.
        if (this.cds.length > 0) {
            for (int i = 0; i < this.cds.length; ++i) {
                // Anonymous classes may only capture final locals (Java 6), so take a per-iteration copy.
                final int k = i;
                main.addSubView("Test Partition " + (i + 1), new TransformedViewer(this.cds[0].toGUI()){
                    static final long serialVersionUID = 20080130L;

                    @Override
                    public Object transform(Object o) {
                        return CrossValidatedDataset.this.cds[k];
                    }
                });
            }
            if (this.trainCds != null) {
                for (int i = 0; i < this.trainCds.length; ++i) {
                    final int k = i;
                    main.addSubView("Train Partition " + (i + 1), new TransformedViewer(this.trainCds[0].toGUI()){
                        static final long serialVersionUID = 20080130L;

                        @Override
                        public Object transform(Object o) {
                            return CrossValidatedDataset.this.trainCds[k];
                        }
                    });
                }
            }
        }
        main.addSubView("Overall Evaluation", new TransformedViewer(this.v.toGUI()){
            static final long serialVersionUID = 20080130L;

            @Override
            public Object transform(Object o) {
                CrossValidatedDataset cvd = (CrossValidatedDataset)o;
                return cvd.v;
            }
        });
        main.setContent(this);
        return main;
    }

    /**
     * Returns the evaluation accumulated over all folds.
     *
     * @return the cross-validated evaluation
     */
    public Evaluation getEvaluation() {
        return this.v;
    }

    /**
     * Demo: 3-fold cross-validation of a decision tree on the "toy" sample dataset. The
     * training partitions are saved, and the results are shown in a viewer window.
     */
    public static void main(String[] args) {
        Dataset train = SampleDatasets.sampleData("toy", false);
        DecisionTreeLearner learner = new DecisionTreeLearner();
        CrossValidatedDataset cd = new CrossValidatedDataset(learner, train, new CrossValSplitter<Example>(3), true);
        new ViewerFrame("CrossValidatedDataset", cd.toGUI());
    }
}

