/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify.sequential;

import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.Splitter;
import edu.cmu.minorthird.classify.experiments.Evaluation;
import edu.cmu.minorthird.classify.sequential.ClassifiedSequenceDataset;
import edu.cmu.minorthird.classify.sequential.DatasetSequenceClassifierTeacher;
import edu.cmu.minorthird.classify.sequential.SequenceClassifier;
import edu.cmu.minorthird.classify.sequential.SequenceClassifierLearner;
import edu.cmu.minorthird.classify.sequential.SequenceDataset;
import edu.cmu.minorthird.util.ProgressCounter;
import edu.cmu.minorthird.util.gui.ParallelViewer;
import edu.cmu.minorthird.util.gui.TransformedViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
import org.apache.log4j.Logger;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Cross-validates a {@link SequenceClassifierLearner} on a {@link SequenceDataset}.
 * <p>
 * The dataset is split with the supplied {@link Splitter}. For every partition a
 * classifier is trained on the training part and used to classify the test part.
 * The classified test partitions (and, optionally, the classified training
 * partitions) are kept for display, and the results on every test partition are
 * accumulated into a single {@link Evaluation}.
 */
public class CrossValidatedSequenceDataset
implements Visible {
    private static final Logger log = Logger.getLogger(CrossValidatedSequenceDataset.class);

    /** Classified test data, one entry per partition. */
    private final ClassifiedSequenceDataset[] cds;
    /** Classified training data, one entry per partition; null unless requested at construction. */
    private final ClassifiedSequenceDataset[] trainCds;
    /** Evaluation accumulated over the test data of every partition. */
    private final Evaluation v;

    /**
     * Runs cross-validation without keeping the classified training partitions.
     *
     * @param learner  learner trained once per partition
     * @param d        the sequence dataset to split
     * @param splitter splitter used to partition {@code d} into train/test folds
     */
    public CrossValidatedSequenceDataset(SequenceClassifierLearner learner, SequenceDataset d, Splitter<Example[]> splitter) {
        this(learner, d, splitter, false);
    }

    /**
     * Runs cross-validation: for each partition produced by {@code splitter}, trains a
     * classifier on the training data, classifies the test data with it, and extends
     * the overall evaluation with the test results.
     *
     * @param learner             learner trained once per partition
     * @param d                   the sequence dataset to split
     * @param splitter            splitter used to partition {@code d} into train/test folds
     * @param saveTrainPartitions if true, each partition's training data is also classified
     *                            with that partition's classifier and kept for display
     */
    public CrossValidatedSequenceDataset(SequenceClassifierLearner learner, SequenceDataset d, Splitter<Example[]> splitter, boolean saveTrainPartitions) {
        Dataset.Split s = d.splitSequence(splitter);
        int numPartitions = s.getNumPartitions();
        this.cds = new ClassifiedSequenceDataset[numPartitions];
        this.trainCds = saveTrainPartitions ? new ClassifiedSequenceDataset[numPartitions] : null;
        this.v = new Evaluation(d.getSchema());
        ProgressCounter pc = new ProgressCounter("train/test", "fold", numPartitions);
        for (int k = 0; k < numPartitions; ++k) {
            SequenceDataset trainData = (SequenceDataset)s.getTrain(k);
            SequenceDataset testData = (SequenceDataset)s.getTest(k);
            log.info("splitting with " + splitter + ", preparing to train on " + trainData.size() + " and test on " + testData.size());
            SequenceClassifier c = new DatasetSequenceClassifierTeacher(trainData).train(learner);
            this.cds[k] = new ClassifiedSequenceDataset(c, testData);
            if (this.trainCds != null) {
                this.trainCds[k] = new ClassifiedSequenceDataset(c, trainData);
            }
            // NOTE(review): every partition passes 0 as the last extend() argument. If that
            // argument identifies the partition, k may have been intended -- confirm against
            // Evaluation.extend before changing.
            this.v.extend(this.cds[k].getClassifier(), testData, 0);
            log.info("splitting with " + splitter + ", stored classified dataset");
            pc.progress();
        }
        pc.finished();
    }

    /**
     * Returns the evaluation accumulated over the test data of all partitions.
     *
     * @return the overall cross-validation evaluation
     */
    public Evaluation getEvaluation() {
        return this.v;
    }

    /**
     * Builds a viewer with one sub-view per classified test partition, one per
     * classified training partition (only when they were saved), and one for the
     * overall evaluation.
     *
     * @return the combined viewer, with this object set as its content
     */
    @Override
    public Viewer toGUI() {
        ParallelViewer main = new ParallelViewer();
        for (int i = 0; i < this.cds.length; ++i) {
            main.addSubView("Test Partition " + (i + 1), partitionViewer(this.cds[i]));
        }
        if (this.trainCds != null) {
            for (int i = 0; i < this.trainCds.length; ++i) {
                main.addSubView("Train Partition " + (i + 1), partitionViewer(this.trainCds[i]));
            }
        }
        main.addSubView("Overall Evaluation", new TransformedViewer(this.v.toGUI()){
            static final long serialVersionUID = 20080207L;

            public Object transform(Object o) {
                CrossValidatedSequenceDataset cvd = (CrossValidatedSequenceDataset)o;
                return cvd.v;
            }
        });
        main.setContent(this);
        return main;
    }

    /**
     * Wraps a single classified partition in a viewer that always displays that
     * partition, whatever content the enclosing viewer is given.
     * <p>
     * Declared static and capturing only the {@code final} parameter, so the
     * anonymous viewer neither depends on a mutable loop variable nor holds a hidden
     * reference to the enclosing {@code CrossValidatedSequenceDataset}.
     *
     * @param partition the classified partition to display
     * @return a viewer for {@code partition}
     */
    private static TransformedViewer partitionViewer(final ClassifiedSequenceDataset partition) {
        return new TransformedViewer(partition.toGUI()){
            static final long serialVersionUID = 20080207L;

            public Object transform(Object o) {
                return partition;
            }
        };
    }

    /**
     * Does nothing; the body is empty.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
    }
}

