/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify.multi;

import edu.cmu.minorthird.classify.ClassifierLearner;
import edu.cmu.minorthird.classify.Splitter;
import edu.cmu.minorthird.classify.experiments.CrossValidatedDataset;
import edu.cmu.minorthird.classify.multi.MultiClassifiedDataset;
import edu.cmu.minorthird.classify.multi.MultiClassifier;
import edu.cmu.minorthird.classify.multi.MultiDataset;
import edu.cmu.minorthird.classify.multi.MultiDatasetClassifierTeacher;
import edu.cmu.minorthird.classify.multi.MultiDatasetIndex;
import edu.cmu.minorthird.classify.multi.MultiEvaluation;
import edu.cmu.minorthird.classify.multi.MultiExample;
import edu.cmu.minorthird.classify.transform.PredictedClassTransform;
import edu.cmu.minorthird.classify.transform.TransformingMultiClassifier;
import edu.cmu.minorthird.util.ProgressCounter;
import edu.cmu.minorthird.util.gui.ParallelViewer;
import edu.cmu.minorthird.util.gui.TransformedViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
import org.apache.log4j.Logger;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Runs a train/test cross-validation experiment over a {@link MultiDataset}.
 *
 * <p>The dataset is partitioned with a {@link Splitter}. For every fold a
 * {@link MultiClassifier} is trained on the training partition. That classifier
 * is used to build a {@link MultiClassifiedDataset} over the test partition and,
 * optionally, over the training partition. Results from all folds are
 * accumulated into a single {@link MultiEvaluation}.
 */
public class MultiCrossValidatedDataset
implements Visible {
    private static final Logger log = Logger.getLogger(MultiCrossValidatedDataset.class);

    /** Classified test partition for each fold, indexed by fold number. */
    private final MultiClassifiedDataset[] cds;
    /** Classified training partition for each fold, or {@code null} when not requested. */
    private final MultiClassifiedDataset[] trainCds;
    /** Evaluation accumulated across the test partitions of all folds. */
    private final MultiEvaluation v;

    /**
     * Cross-validates without saving training partitions and without the
     * {@code cross} transformation.
     *
     * @param learner  learner used to train one classifier per fold
     * @param d        dataset to split
     * @param splitter splitter that defines the train/test folds
     */
    public MultiCrossValidatedDataset(ClassifierLearner learner, MultiDataset d, Splitter<MultiExample> splitter) {
        this(learner, d, splitter, false, false);
    }

    /**
     * Cross-validates without the {@code cross} transformation.
     *
     * @param learner             learner used to train one classifier per fold
     * @param d                   dataset to split
     * @param splitter            splitter that defines the train/test folds
     * @param saveTrainPartitions if true, also keep a classified copy of each fold's training data
     */
    public MultiCrossValidatedDataset(ClassifierLearner learner, MultiDataset d, Splitter<MultiExample> splitter, boolean saveTrainPartitions) {
        this(learner, d, splitter, saveTrainPartitions, false);
    }

    /**
     * Splits {@code d} with {@code splitter} and, for each fold, trains a classifier
     * on the training partition and evaluates it on the test partition.
     *
     * @param learner             learner used to train one classifier per fold
     * @param d                   dataset to split
     * @param splitter            splitter that defines the train/test folds
     * @param saveTrainPartitions if true, also keep a classified copy of each fold's training data
     * @param cross               if true, the training data is replaced by
     *                            {@code trainData.annotateData()} before training. The trained
     *                            classifier is then wrapped in a {@link TransformingMultiClassifier}
     *                            driven by a {@link PredictedClassTransform} built from it.
     */
    public MultiCrossValidatedDataset(ClassifierLearner learner, MultiDataset d, Splitter<MultiExample> splitter, boolean saveTrainPartitions, boolean cross) {
        MultiDataset.MultiSplit split = d.MultiSplit(splitter);
        int numFolds = split.getNumPartitions();
        this.cds = new MultiClassifiedDataset[numFolds];
        this.trainCds = saveTrainPartitions ? new MultiClassifiedDataset[numFolds] : null;
        this.v = new MultiEvaluation(d.getMultiSchema());
        ProgressCounter pc = new ProgressCounter("train/test", "fold", numFolds);
        for (int k = 0; k < numFolds; ++k) {
            MultiDataset trainData = split.getTrain(k);
            if (cross) {
                trainData = trainData.annotateData();
            }
            MultiDataset testData = split.getTest(k);
            log.info("splitting with " + splitter + ", preparing to train on " + trainData.size() + " and test on " + testData.size());
            MultiClassifier c = new MultiDatasetClassifierTeacher(trainData).train(learner);
            if (cross) {
                // Wrap the trained classifier so its inputs pass through a
                // PredictedClassTransform built from the classifier itself.
                PredictedClassTransform transformer = new PredictedClassTransform(c);
                c = new TransformingMultiClassifier(c, transformer);
            }
            MultiDatasetIndex testIndex = new MultiDatasetIndex(testData);
            this.cds[k] = new MultiClassifiedDataset(c, testData, testIndex);
            if (this.trainCds != null) {
                // NOTE(review): the training partition is paired with the *test*
                // partition's index. Confirm whether an index over trainData was intended.
                this.trainCds[k] = new MultiClassifiedDataset(c, trainData, testIndex);
            }
            this.v.extend(c, testData);
            log.info("splitting with " + splitter + ", stored classified dataset");
            pc.progress();
        }
        pc.finished();
    }

    /**
     * Builds a viewer with one tab per test partition, one tab per saved training
     * partition (if any), and an "Overall Evaluation" tab. The viewer's content is
     * set to this object.
     *
     * @return the composed viewer
     */
    @Override
    public Viewer toGUI() {
        ParallelViewer main = new ParallelViewer();
        for (int i = 0; i < this.cds.length; ++i) {
            // Each anonymous viewer must capture its own fold index, so the loop
            // variable is copied into a fresh final local on every iteration.
            final int k = i;
            main.addSubView("Test Partition " + (i + 1), new TransformedViewer(this.cds[0].toGUI()) {
                private static final long serialVersionUID = 20080130L;

                @Override
                public Object transform(Object o) {
                    return MultiCrossValidatedDataset.this.cds[k];
                }
            });
        }
        if (this.trainCds != null) {
            for (int i = 0; i < this.trainCds.length; ++i) {
                final int k = i;
                main.addSubView("Train Partition " + (i + 1), new TransformedViewer(this.trainCds[0].toGUI()) {
                    private static final long serialVersionUID = 20080130L;

                    @Override
                    public Object transform(Object o) {
                        return MultiCrossValidatedDataset.this.trainCds[k];
                    }
                });
            }
        }
        main.addSubView("Overall Evaluation", new TransformedViewer(this.v.toGUI()) {
            private static final long serialVersionUID = 20080130L;

            @Override
            public Object transform(Object o) {
                // The content passed in is the MultiCrossValidatedDataset set below.
                MultiCrossValidatedDataset cvd = (MultiCrossValidatedDataset) o;
                return cvd.v;
            }
        });
        main.setContent(this);
        return main;
    }

    /**
     * Returns the evaluation accumulated across all folds.
     *
     * @return the overall evaluation
     */
    public MultiEvaluation getEvaluation() {
        return this.v;
    }

    /** Placeholder entry point; prints a fixed banner and does nothing else. */
    public static void main(String[] args) {
        System.out.println("CrossValidatedDataset");
    }
}

