/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify.sequential;

import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.ClassifierLearner;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Explanation;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.Splitter;
import edu.cmu.minorthird.classify.algorithms.linear.VotedPerceptron;
import edu.cmu.minorthird.classify.experiments.CrossValSplitter;
import edu.cmu.minorthird.classify.sequential.BatchSequenceClassifierLearner;
import edu.cmu.minorthird.classify.sequential.CMMLearner;
import edu.cmu.minorthird.classify.sequential.DatasetSequenceClassifierTeacher;
import edu.cmu.minorthird.classify.sequential.SequenceClassifier;
import edu.cmu.minorthird.classify.sequential.SequenceClassifierLearner;
import edu.cmu.minorthird.classify.sequential.SequenceDataset;
import edu.cmu.minorthird.classify.transform.AugmentedInstance;
import edu.cmu.minorthird.util.MathUtil;
import edu.cmu.minorthird.util.ProgressCounter;
import edu.cmu.minorthird.util.gui.ParallelViewer;
import edu.cmu.minorthird.util.gui.SmartVanillaViewer;
import edu.cmu.minorthird.util.gui.TransformedViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
import java.util.Iterator;
import org.apache.log4j.Logger;

/**
 * Learns a stacked sequence classifier.
 *
 * <p>Level 0 is the base sequence learner trained on the original data. Every
 * higher level is trained, with the same base learner, on instances augmented by
 * features that describe the previous level's predicted labels:
 * <ul>
 *   <li>at up to {@code historySize} earlier positions;</li>
 *   <li>at up to {@code futureSize} later positions;</li>
 *   <li>optionally at the target position itself.</li>
 * </ul>
 * At training time those predictions are made on held-out partitions produced by
 * {@link StackingParams#splitter} (a 5-fold {@code CrossValSplitter} by default).
 */
public class StackedSequenceLearner
implements BatchSequenceClassifierLearner {
    private static final Logger log = Logger.getLogger(StackedSequenceLearner.class);

    /** Learner trained at every stacking level. */
    private SequenceClassifierLearner baseLearner = new CMMLearner(new VotedPerceptron(), 0);

    /** Live stacking configuration; each trained classifier receives a snapshot copy. */
    private StackingParams params = new StackingParams();

    /** @return number of preceding positions whose predictions become features */
    public int getHistorySize() {
        return this.params.historySize;
    }

    /** Sets the number of preceding positions whose predictions become features. */
    public void setHistorySize(int newHistorySize) {
        this.params.setHistorySize(newHistorySize);
    }

    /** @return the live, mutable parameter object used by this learner */
    public StackingParams getParams() {
        return this.params;
    }

    /** Creates a learner with the default base learner and default parameters. */
    public StackedSequenceLearner() {
    }

    /**
     * @param baseLearner sequence learner trained at every level
     * @param depth number of stacked levels above level 0
     */
    public StackedSequenceLearner(SequenceClassifierLearner baseLearner, int depth) {
        this();
        this.baseLearner = baseLearner;
        this.params.setStackingDepth(depth);
    }

    /**
     * @param baseLearner non-sequential learner, wrapped as {@code new CMMLearner(baseLearner, 0)}
     * @param depth number of stacked levels above level 0
     */
    public StackedSequenceLearner(ClassifierLearner baseLearner, int depth) {
        this();
        this.baseLearner = new CMMLearner(baseLearner, 0);
        this.params.setStackingDepth(depth);
    }

    /**
     * @param baseLearner sequence learner trained at every level
     * @param depth number of stacked levels above level 0
     * @param windowSize used for both the history size and the future size
     */
    public StackedSequenceLearner(SequenceClassifierLearner baseLearner, int depth, int windowSize) {
        this();
        this.baseLearner = baseLearner;
        this.params.setStackingDepth(depth);
        this.params.setHistorySize(windowSize);
        this.params.setFutureSize(windowSize);
    }

    /**
     * @param baseLearner non-sequential learner, wrapped as {@code new CMMLearner(baseLearner, 0)}
     * @param depth number of stacked levels above level 0
     * @param windowSize used for both the history size and the future size
     */
    public StackedSequenceLearner(ClassifierLearner baseLearner, int depth, int windowSize) {
        this();
        this.baseLearner = new CMMLearner(baseLearner, 0);
        this.params.setStackingDepth(depth);
        this.params.setHistorySize(windowSize);
        this.params.setFutureSize(windowSize);
    }

    /** Does nothing: this learner does not use the schema. */
    public void setSchema(ExampleSchema schema) {
    }

    /**
     * Trains one base-learner classifier per stacking level (levels 0..stackingDepth).
     *
     * <p>Side effect: calls {@code dataset.setHistorySize(0)} on the caller's dataset.
     *
     * @param dataset training sequences
     * @return a classifier that applies all levels in order, using a snapshot of the
     *     current parameters so later changes to this learner do not affect it
     * @throws IllegalArgumentException if the stacking depth (or, when depth > 0,
     *     a window size) is negative
     */
    public SequenceClassifier batchTrain(SequenceDataset dataset) {
        int depth = this.params.stackingDepth;
        if (depth < 0) {
            throw new IllegalArgumentException("stackingDepth must be >= 0, got " + depth);
        }
        SequenceClassifier[] levels = new SequenceClassifier[depth + 1];
        SequenceDataset stackedDataset = dataset;
        stackedDataset.setHistorySize(0);
        ProgressCounter pc = new ProgressCounter("training stacked learner", "stacking level", depth + 1);
        for (int d = 0; d <= depth; ++d) {
            levels[d] = new DatasetSequenceClassifierTeacher(stackedDataset).train(this.baseLearner);
            // The top level needs no augmented dataset above it.
            if (d < depth) {
                stackedDataset = this.stackDataset(stackedDataset);
            }
            pc.progress();
        }
        pc.finished();
        // Snapshot params: the classifier must keep extracting features exactly as
        // during training even if this learner's params are edited afterwards.
        return new StackedSequenceClassifier(levels, this.params.copy());
    }

    /**
     * Builds the next level's training set from held-out predictions.
     *
     * <p>For each partition produced by {@code params.splitter}, the base learner is
     * trained on the partition's training part. It then labels the test part. Every
     * test sequence is added to the result with its instances augmented by
     * predicted-label features; the true labels are kept.
     *
     * @param dataset current-level training sequences (instances may already be augmented)
     * @return the augmented sequences, with history size set to 0
     * @throws IllegalArgumentException if historySize or futureSize is negative
     */
    public SequenceDataset stackDataset(SequenceDataset dataset) {
        validateWindow(this.params);
        SequenceDataset result = new SequenceDataset();
        Dataset.Split s = dataset.splitSequence(this.params.splitter);
        ProgressCounter pc = new ProgressCounter("labeling for stacking", "fold", s.getNumPartitions());
        for (int k = 0; k < s.getNumPartitions(); ++k) {
            SequenceDataset trainData = (SequenceDataset)s.getTrain(k);
            SequenceDataset testData = (SequenceDataset)s.getTest(k);
            log.info("splitting with " + this.params.splitter + ", preparing to train on " + trainData.size() + " and test on " + testData.size());
            SequenceClassifier c = new DatasetSequenceClassifierTeacher(trainData).train(this.baseLearner);
            Iterator<Example[]> i = testData.sequenceIterator();
            while (i.hasNext()) {
                Example[] seq = i.next();
                ClassLabel[] labels = c.classification(seq);
                Example[] stackSeq = new Example[seq.length];
                for (int j = 0; j < seq.length; ++j) {
                    Instance stackInstance = stackInstance(j, seq[j].asInstance(), labels, this.params);
                    stackSeq[j] = new Example(stackInstance, seq[j].getLabel());
                }
                result.addSequence(stackSeq);
            }
            log.info("splitting with " + this.params.splitter + ", stored classified dataset");
            pc.progress();
        }
        pc.finished();
        result.setHistorySize(0);
        return result;
    }

    /** Rejects negative window sizes, which would mis-size the stacked feature arrays. */
    private static void validateWindow(StackingParams params) {
        if (params.historySize < 0) {
            throw new IllegalArgumentException("historySize must be >= 0, got " + params.historySize);
        }
        if (params.futureSize < 0) {
            throw new IllegalArgumentException("futureSize must be >= 0, got " + params.futureSize);
        }
    }

    /**
     * Wraps {@code instancej} with one feature per window position around {@code j}.
     *
     * <p>Positions j-historySize .. j+futureSize are covered. Position j itself is
     * skipped unless {@code useTargetPrediction} is set.
     * <ul>
     *   <li>A position inside the sequence yields a feature naming its best predicted
     *       class. The value is 1.0, or the best weight when {@code useConfidence} is
     *       set, passed through {@code MathUtil.logistic} if {@code useLogistic} is set.</li>
     *   <li>A position outside the sequence yields a "NULL" class feature with value 1.0.</li>
     * </ul>
     */
    private static Instance stackInstance(int j, Instance instancej, ClassLabel[] labels, StackingParams params) {
        int numNewFeatures = params.historySize + params.futureSize + (params.useTargetPrediction ? 1 : 0);
        String[] features = new String[numNewFeatures];
        double[] values = new double[numNewFeatures];
        int index = 0;
        for (int pos = j - params.historySize; pos <= j + params.futureSize; ++pos) {
            if (pos == j && !params.useTargetPrediction) {
                continue;
            }
            int offset = pos - j;
            if (pos >= 0 && pos < labels.length) {
                features[index] = stackFeatureName(offset, labels[pos].bestClassName());
                double value = 1.0;
                if (params.useConfidence) {
                    double w = labels[pos].bestWeight();
                    value = params.useLogistic ? MathUtil.logistic(w) : w;
                }
                values[index] = value;
            } else {
                // Window position falls off one end of the sequence.
                features[index] = stackFeatureName(offset, "NULL");
                values[index] = 1.0;
            }
            ++index;
        }
        return new AugmentedInstance(instancej, features, values);
    }

    /**
     * Names a stacked feature, for example "pred.prev.2.X", "pred.next.1.X" or
     * "pred.here.X".
     */
    private static String stackFeatureName(int offsetFromTarget, String predictedClassName) {
        if (offsetFromTarget < 0) {
            return "pred.prev." + -offsetFromTarget + "." + predictedClassName;
        }
        if (offsetFromTarget > 0) {
            return "pred.next." + offsetFromTarget + "." + predictedClassName;
        }
        return "pred.here." + predictedClassName;
    }

    /**
     * Applies the trained levels in order.
     *
     * <p>Level d sees the instances that level d-1 saw, augmented by level d-1's
     * predictions. This mirrors {@link StackedSequenceLearner#stackDataset}, which
     * augments the previous level's (already augmented) training instances.
     */
    private static class StackedSequenceClassifier
    implements SequenceClassifier,
    Visible {
        private SequenceClassifier[] levels;
        private StackingParams params;

        public StackedSequenceClassifier(SequenceClassifier[] levels, StackingParams params) {
            this.levels = levels;
            this.params = params;
        }

        public ClassLabel[] classification(Instance[] sequence) {
            ClassLabel[] labels = this.levels[0].classification(sequence);
            Instance[] current = sequence;
            for (int d = 1; d < this.levels.length; ++d) {
                // Augment cumulatively so features match what level d was trained on.
                Instance[] augmented = new Instance[current.length];
                for (int j = 0; j < current.length; ++j) {
                    augmented[j] = stackInstance(j, current[j], labels, this.params);
                }
                labels = this.levels[d].classification(augmented);
                current = augmented;
            }
            return labels;
        }

        public String explain(Instance[] sequence) {
            return "not implemented";
        }

        public Explanation getExplanation(Instance[] sequence) {
            return new Explanation(this.explain(sequence));
        }

        /** Shows one sub-view per stacking level. */
        public Viewer toGUI() {
            ParallelViewer v = new ParallelViewer();
            for (int i = 0; i < this.levels.length; ++i) {
                final int k = i;
                v.addSubView("Level " + k + " classifier", new TransformedViewer(new SmartVanillaViewer(this.levels[k])){
                    static final long serialVersionUID = 20080207L;

                    public Object transform(Object o) {
                        StackedSequenceClassifier s = (StackedSequenceClassifier)o;
                        return s.levels[k];
                    }
                });
            }
            v.setContent(this);
            return v;
        }
    }

    /** Bean-style configuration for {@link StackedSequenceLearner}. */
    public static class StackingParams {
        /** Number of preceding positions whose predictions become features. */
        public int historySize = 5;
        /** Number of following positions whose predictions become features. */
        public int futureSize = 5;
        /** Number of stacked levels above level 0. */
        public int stackingDepth = 1;
        /** If set (and useConfidence is set), feature values are MathUtil.logistic(weight). */
        public boolean useLogistic = true;
        /** If set, the prediction at the target position itself is also a feature. */
        public boolean useTargetPrediction = true;
        /** If set, feature values are prediction weights rather than 1.0. */
        public boolean useConfidence = true;
        /** Splitter producing the held-out predictions used to build stacked training data. */
        public Splitter<Example[]> splitter = new CrossValSplitter<Example[]>(5);
        int crossValSplits = 5;

        public int getHistorySize() {
            return this.historySize;
        }

        public void setHistorySize(int newHistorySize) {
            this.historySize = newHistorySize;
        }

        public int getFutureSize() {
            return this.futureSize;
        }

        public void setFutureSize(int newFutureSize) {
            this.futureSize = newFutureSize;
        }

        public boolean getUseLogisticOnConfidences() {
            return this.useLogistic;
        }

        public void setUseLogisticOnConfidences(boolean flag) {
            this.useLogistic = flag;
        }

        public boolean getUseConfidences() {
            return this.useConfidence;
        }

        public void setUseConfidences(boolean flag) {
            this.useConfidence = flag;
        }

        public boolean getUseTargetPrediction() {
            return this.useTargetPrediction;
        }

        public void setUseTargetPrediction(boolean flag) {
            this.useTargetPrediction = flag;
        }

        public int getStackingDepth() {
            return this.stackingDepth;
        }

        public void setStackingDepth(int newStackingDepth) {
            this.stackingDepth = newStackingDepth;
        }

        public int getCrossValSplits() {
            return this.crossValSplits;
        }

        /** Replaces the splitter with a new CrossValSplitter using the given fold count. */
        public void setCrossValSplits(int newCrossValSplits) {
            this.splitter = new CrossValSplitter<Example[]>(newCrossValSplits);
            this.crossValSplits = newCrossValSplits;
        }

        /** Returns a field-by-field copy; the splitter object itself is shared. */
        private StackingParams copy() {
            StackingParams c = new StackingParams();
            c.historySize = this.historySize;
            c.futureSize = this.futureSize;
            c.stackingDepth = this.stackingDepth;
            c.useLogistic = this.useLogistic;
            c.useTargetPrediction = this.useTargetPrediction;
            c.useConfidence = this.useConfidence;
            c.splitter = this.splitter;
            c.crossValSplits = this.crossValSplits;
            return c;
        }
    }
}

