/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify.relational;

import edu.cmu.minorthird.classify.BatchClassifierLearner;
import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.DatasetClassifierTeacher;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Explanation;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.SGMExample;
import edu.cmu.minorthird.classify.Splitter;
import edu.cmu.minorthird.classify.algorithms.linear.MaxEntLearner;
import edu.cmu.minorthird.classify.experiments.CrossValSplitter;
import edu.cmu.minorthird.classify.relational.RealRelationalDataset;
import edu.cmu.minorthird.classify.relational.StackedBatchClassifierLearner;
import edu.cmu.minorthird.classify.transform.AugmentedInstance;
import edu.cmu.minorthird.util.ProgressCounter;
import edu.cmu.minorthird.util.gui.ParallelViewer;
import edu.cmu.minorthird.util.gui.SmartVanillaViewer;
import edu.cmu.minorthird.util.gui.TransformedViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import org.apache.log4j.Logger;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Stacked graphical learner for relational data.
 *
 * <p>Trains {@code stackingDepth + 1} copies of a base learner. Level 0 sees the
 * original examples. Each later level sees the examples augmented with
 * aggregate features ({@code COUNT} / {@code EXISTS}) computed from the
 * previous level's predictions for each example's linked neighbours. The
 * links and aggregator definitions come from
 * {@link RealRelationalDataset#getLinksMap()} and
 * {@link RealRelationalDataset#getAggregators()}.
 */
public class StackedGraphicalLearner
extends StackedBatchClassifierLearner {
    private static final Logger log = Logger.getLogger(StackedGraphicalLearner.class);
    private ExampleSchema schema;
    private BatchClassifierLearner baseLearner = new MaxEntLearner();
    private StackingParams params = new StackingParams();

    /**
     * Returns the live (mutable) stacking parameters of this learner.
     *
     * @return the parameter object; changes to it affect subsequent training
     */
    public StackingParams getParams() {
        return this.params;
    }

    /** Creates a learner with a {@link MaxEntLearner} base learner and default parameters. */
    public StackedGraphicalLearner() {
    }

    /**
     * Creates a learner with the given base learner and stacking depth 1.
     *
     * @param baseLearner learner used at every stacking level
     */
    public StackedGraphicalLearner(BatchClassifierLearner baseLearner) {
        this();
        this.baseLearner = baseLearner;
        this.params.setStackingDepth(1);
    }

    /**
     * Creates a learner with the given base learner and stacking depth.
     *
     * @param baseLearner learner used at every stacking level
     * @param depth number of stacked levels on top of the base level
     */
    public StackedGraphicalLearner(BatchClassifierLearner baseLearner, int depth) {
        this();
        this.baseLearner = baseLearner;
        this.params.setStackingDepth(depth);
    }

    /**
     * Creates a learner with a {@link MaxEntLearner} base learner and the given depth.
     *
     * @param depth number of stacked levels on top of the base level
     */
    public StackedGraphicalLearner(int depth) {
        this();
        this.params.setStackingDepth(depth);
    }

    @Override
    public final void setSchema(ExampleSchema schema) {
        this.schema = schema;
    }

    @Override
    public final ExampleSchema getSchema() {
        return this.schema;
    }

    /**
     * Trains one base classifier per stacking level.
     *
     * <p>Level 0 is trained on {@code dataset} as given. Each subsequent level is
     * trained on the output of {@link #stackDataset} applied to the previous
     * level's data.
     *
     * @param dataset relational training data
     * @return a {@link StackedGraphicalClassifier} holding the per-level classifiers
     */
    @Override
    public Classifier batchTrain(RealRelationalDataset dataset) {
        int depth = this.params.stackingDepth;
        Classifier[] levels = new Classifier[depth + 1];
        RealRelationalDataset stackedDataset = dataset;
        ProgressCounter pc = new ProgressCounter("training stacked learner", "stacking level", depth + 1);
        for (int d = 0; d <= depth; ++d) {
            levels[d] = new DatasetClassifierTeacher(stackedDataset).train(this.baseLearner);
            // Only build the next level's data if another level will consume it.
            if (d < depth) {
                stackedDataset = this.stackDataset(stackedDataset);
            }
            pc.progress();
        }
        pc.finished();
        return new StackedGraphicalClassifier(levels, this.params, dataset);
    }

    /**
     * Builds the training data for the next stacking level.
     *
     * <p>For each partition of {@code params.splitter}, the base learner is
     * trained on the partition's training part. The resulting classifier's
     * predictions for the partition's test part are recorded by example id.
     * With a cross-validation splitter (the default), each example is therefore
     * predicted by a classifier that did not train on it. Every example of
     * {@code dataset} is then augmented with aggregate features over its
     * neighbours' recorded predictions.
     *
     * <p>Side effect: sets this learner's schema to {@code dataset.getSchema()}.
     *
     * @param dataset the current level's data
     * @return a new dataset containing the augmented examples
     */
    public RealRelationalDataset stackDataset(RealRelationalDataset dataset) {
        RealRelationalDataset result = new RealRelationalDataset();
        Dataset.Split s = dataset.split(this.params.splitter);
        this.schema = dataset.getSchema();
        ProgressCounter pc = new ProgressCounter("stack-labeling", "fold", s.getNumPartitions());
        Map<String, ClassLabel> predictions = new HashMap<String, ClassLabel>();
        for (int k = 0; k < s.getNumPartitions(); ++k) {
            RealRelationalDataset trainData = (RealRelationalDataset) s.getTrain(k);
            RealRelationalDataset testData = (RealRelationalDataset) s.getTest(k);
            log.info("splitting with " + this.params.splitter + ", preparing to train on " + trainData.size() + " and test on " + testData.size());
            Classifier c = new DatasetClassifierTeacher(trainData).train(this.baseLearner);
            for (Iterator<Example> it = testData.iterator(); it.hasNext();) {
                SGMExample ex = (SGMExample) it.next();
                predictions.put(ex.getExampleID(), c.classification(ex));
            }
            log.info("splitting with " + this.params.splitter + ", stored classified dataset");
            pc.progress();
        }
        Map<String, Map<String, Set<String>>> linksMap = RealRelationalDataset.getLinksMap();
        Map<String, Set<String>> aggregators = RealRelationalDataset.getAggregators();
        for (Iterator<Example> it = dataset.iterator(); it.hasNext();) {
            SGMExample ex = (SGMExample) it.next();
            result.add(this.augmentExample(ex, linksMap, aggregators, predictions));
        }
        pc.finished();
        return result;
    }

    /**
     * Returns a copy of {@code ex} augmented with neighbour-prediction features.
     *
     * <p>For every link type in {@code aggregators} that {@code ex} has links of,
     * the predicted best class of each linked neighbour is tallied per class.
     * Then, for each aggregator name of that type and each schema class, one
     * feature is emitted:
     * <ul>
     *   <li>{@code "COUNT"}: the number of neighbours predicted as that class;</li>
     *   <li>{@code "EXISTS"}: 1.0 if at least one neighbour was, else 0.0;</li>
     *   <li>any other aggregator name: 0.0.</li>
     * </ul>
     * Neighbours without a prediction are ignored.
     *
     * @param ex the example to augment
     * @param linksMap example id -> link type -> neighbour ids
     * @param aggregators link type -> aggregator names to apply
     * @param predictedRlt example id -> predicted label
     * @return {@code ex} unchanged if it has no entry in {@code linksMap};
     *         otherwise a new example wrapping an {@link AugmentedInstance}
     */
    private SGMExample augmentExample(SGMExample ex, Map<String, Map<String, Set<String>>> linksMap, Map<String, Set<String>> aggregators, Map<String, ClassLabel> predictedRlt) {
        String egID = ex.getExampleID();
        Map<String, Set<String>> links = linksMap.get(egID);
        if (links == null) {
            return ex;
        }
        int numClasses = this.schema.getNumberOfClasses();

        // Upper bound on the number of new features; the arrays are trimmed afterwards.
        int capacity = 0;
        for (Set<String> ops : aggregators.values()) {
            capacity += ops.size() * numClasses;
        }
        String[] features = new String[capacity];
        double[] values = new double[capacity];
        int index = 0;

        for (Map.Entry<String, Set<String>> entry : aggregators.entrySet()) {
            String type = entry.getKey();
            Set<String> neighbours = links.get(type);
            if (neighbours == null) {
                continue;
            }
            // The per-class histogram depends only on the link type, so compute it once per type.
            int[] counts = new int[numClasses];
            for (String ngbID : neighbours) {
                ClassLabel predicted = predictedRlt.get(ngbID);
                if (predicted == null) {
                    continue;
                }
                int classIdx = this.schema.getClassIndex(predicted.bestClassName());
                // Defensive: skip labels the schema does not know instead of
                // indexing out of bounds (assumes a negative index signals "unknown" - TODO confirm).
                if (classIdx < 0 || classIdx >= numClasses) {
                    continue;
                }
                counts[classIdx]++;
            }
            for (String agr : entry.getValue()) {
                for (int c = 0; c < numClasses; ++c) {
                    features[index] = StackedGraphicalLearner.stackFeatureName(type, agr, this.schema.getClassName(c));
                    if (agr.equals("COUNT")) {
                        values[index] = counts[c];
                    } else if (agr.equals("EXISTS") && counts[c] > 0) {
                        values[index] = 1.0;
                    }
                    ++index;
                }
            }
        }
        AugmentedInstance stackedInstance = new AugmentedInstance(ex.asInstance(), Arrays.copyOf(features, index), Arrays.copyOf(values, index));
        return new SGMExample((Instance) stackedInstance, ex.getLabel(), ex.getExampleID());
    }

    /**
     * Builds the name of a stacked feature.
     *
     * @param type link type
     * @param agr aggregator name (e.g. {@code "COUNT"})
     * @param predictedClassName class whose neighbour predictions are aggregated
     * @return {@code "pred.<type>.<agr>.<predictedClassName>"}
     */
    private static String stackFeatureName(String type, String agr, String predictedClassName) {
        return "pred." + type + "." + agr + "." + predictedClassName;
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    /**
     * Classifier produced by {@link StackedGraphicalLearner#batchTrain}.
     *
     * <p>This is a non-static inner class: stacking test data reuses the
     * enclosing learner's schema, read at classification time. Retraining the
     * learner can therefore affect an existing classifier.
     */
    public class StackedGraphicalClassifier
    implements Classifier,
    Visible {
        private Classifier[] m;
        private StackingParams params;

        /**
         * @param m one trained classifier per stacking level, level 0 first
         * @param params stacking parameters (depth is read from here)
         * @param ds training dataset (currently unused)
         */
        public StackedGraphicalClassifier(Classifier[] m, StackingParams params, RealRelationalDataset ds) {
            this.m = m;
            this.params = params;
        }

        /**
         * Classifies a single instance using only the level-0 classifier. No
         * relational context is available for a lone instance, so no stacking
         * is done.
         */
        @Override
        public ClassLabel classification(Instance instance) {
            return this.m[0].classification(instance);
        }

        /**
         * Collectively classifies a relational dataset by running every
         * stacking level. Each level's predictions feed the next level's
         * neighbour features.
         *
         * @param dataset examples to classify
         * @return example id -> final-level predicted label
         */
        public Map<String, ClassLabel> classification(RealRelationalDataset dataset) {
            int depth = this.params.stackingDepth;
            Map<String, ClassLabel> predictions = new HashMap<String, ClassLabel>();
            RealRelationalDataset current = dataset;
            for (int d = 0; d <= depth; ++d) {
                for (Iterator<Example> it = current.iterator(); it.hasNext();) {
                    SGMExample ex = (SGMExample) it.next();
                    predictions.put(ex.getExampleID(), this.m[d].classification(ex));
                }
                if (d < depth) {
                    current = this.stackTestDataset(current, predictions);
                }
            }
            return predictions;
        }

        /**
         * Augments every example of {@code dataset} with features aggregated
         * from {@code predictions} of its linked neighbours.
         *
         * @param dataset examples of the current level
         * @param predictions example id -> predicted label from the current level
         * @return a new dataset containing the augmented examples
         */
        public RealRelationalDataset stackTestDataset(RealRelationalDataset dataset, Map<String, ClassLabel> predictions) {
            RealRelationalDataset result = new RealRelationalDataset();
            Map<String, Map<String, Set<String>>> linksMap = RealRelationalDataset.getLinksMap();
            Map<String, Set<String>> aggregators = RealRelationalDataset.getAggregators();
            for (Iterator<Example> it = dataset.iterator(); it.hasNext();) {
                SGMExample ex = (SGMExample) it.next();
                // NOTE(review): training-side stackDataset uses add(...), this uses addSGM(...);
                // confirm the two are meant to differ.
                result.addSGM(StackedGraphicalLearner.this.augmentExample(ex, linksMap, aggregators, predictions));
            }
            return result;
        }

        /** Returns the weight of {@code classLabelName} in the level-0 classification of {@code instance}. */
        public double score(Instance instance, String classLabelName) {
            return this.classification(instance).getWeight(classLabelName);
        }

        @Override
        public String explain(Instance instance) {
            return "sorry, not implemented yet";
        }

        @Override
        public Explanation getExplanation(Instance instance) {
            return new Explanation(this.explain(instance));
        }

        /** Returns a viewer with one sub-view per stacking-level classifier. */
        @Override
        public Viewer toGUI() {
            ParallelViewer v = new ParallelViewer();
            for (int i = 0; i < this.m.length; ++i) {
                final int k = i;
                v.addSubView("Level " + k + " classifier", new TransformedViewer(new SmartVanillaViewer(this.m[k])){
                    static final long serialVersionUID = 20080202L;

                    public Object transform(Object o) {
                        StackedGraphicalClassifier s = (StackedGraphicalClassifier) o;
                        return s.m[k];
                    }
                });
            }
            v.setContent(this);
            return v;
        }
    }

    /**
     * Mutable parameter bundle for {@link StackedGraphicalLearner}.
     *
     * <p>NOTE(review): {@code useLogistic}, {@code useTargetPrediction} and
     * {@code useConfidence} are not read by {@link StackedGraphicalLearner}
     * itself. They are presumably kept for API parity with other stacked
     * learners; TODO confirm.
     */
    public static class StackingParams {
        /** Number of stacked levels on top of the base level. */
        public int stackingDepth = 1;
        public boolean useLogistic = true;
        public boolean useTargetPrediction = true;
        public boolean useConfidence = true;
        /**
         * Splitter used to produce held-out predictions while stacking. Assigning
         * this field directly does not update {@link #getCrossValSplits()}.
         */
        public Splitter<Example> splitter = new CrossValSplitter<Example>(5);
        int crossValSplits = 5;

        public boolean getUseLogisticOnConfidences() {
            return this.useLogistic;
        }

        public void setUseLogisticOnConfidences(boolean flag) {
            this.useLogistic = flag;
        }

        public boolean getUseConfidences() {
            return this.useConfidence;
        }

        public void setUseConfidences(boolean flag) {
            this.useConfidence = flag;
        }

        public boolean getUseTargetPrediction() {
            return this.useTargetPrediction;
        }

        public void setUseTargetPrediction(boolean flag) {
            this.useTargetPrediction = flag;
        }

        public int getStackingDepth() {
            return this.stackingDepth;
        }

        public void setStackingDepth(int newStackingDepth) {
            this.stackingDepth = newStackingDepth;
        }

        public int getCrossValSplits() {
            return this.crossValSplits;
        }

        /** Sets the fold count and replaces {@link #splitter} with a matching {@code CrossValSplitter}. */
        public void setCrossValSplits(int newCrossValSplits) {
            this.splitter = new CrossValSplitter<Example>(newCrossValSplits);
            this.crossValSplits = newCrossValSplits;
        }
    }
}

