/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify.algorithms.linear;

import edu.cmu.minorthird.classify.BatchClassifierLearner;
import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Explanation;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.sequential.BeamSearcher;
import edu.cmu.minorthird.classify.sequential.CMM;
import edu.cmu.minorthird.classify.sequential.CRFLearner;
import edu.cmu.minorthird.classify.sequential.SequenceDataset;
import edu.cmu.minorthird.util.gui.SmartVanillaViewer;
import edu.cmu.minorthird.util.gui.TransformedViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
import java.io.Serializable;
import java.util.Iterator;

/**
 * Learns a multi-class classifier by delegating to {@link CRFLearner}.
 *
 * <p>Training works as follows:
 * <ul>
 *   <li>Every example of the training {@link Dataset} is wrapped as a sequence of length one.</li>
 *   <li>The CRF learner is trained on the resulting {@link SequenceDataset}.</li>
 *   <li>The per-position classifier of the learned {@link CMM} is returned, wrapped in a
 *   {@link MyClassifier}.</li>
 * </ul>
 *
 * <p>If the option string passed to {@link #MaxEntLearner(String)} contains
 * {@code "scaleScores 1"}, the returned classifier rescales its scores. Each class score is
 * converted to the log-odds {@code log(p / (1 - p))}, where {@code p} is the softmax of the
 * raw scores over all classes of the schema.
 */
public class MaxEntLearner
extends BatchClassifierLearner {
    /**
     * Underlying sequential learner. It is constructed with a second argument of 1, presumably
     * the history size; confirm in CRFLearner.
     */
    private final CRFLearner crfLearner;
    /** Whether classifiers produced by {@link #batchTrain} rescale their scores to log-odds. */
    private boolean scaleScores = false;
    /**
     * Last value passed to {@link #setLogSpace}. It starts as true, but the constructors do not
     * call {@code setLogSpaceOption()} on the CRF learner.
     * NOTE(review): confirm that this matches CRFLearner's default.
     */
    public boolean logSpace = true;

    /** Creates a learner with default (empty) CRF options and score scaling disabled. */
    public MaxEntLearner() {
        this.crfLearner = new CRFLearner("", 1);
    }

    /**
     * Creates a learner whose CRF options come from {@code args}.
     *
     * @param args option string forwarded unchanged to {@link CRFLearner}. If it contains the
     *     substring {@code "scaleScores 1"}, classifiers produced by this learner rescale their
     *     scores.
     */
    public MaxEntLearner(String args) {
        this.crfLearner = new CRFLearner(args, 1);
        if (args != null && args.contains("scaleScores 1")) {
            this.scaleScores = true;
            System.out.println("scaleScores => true");
        }
    }

    /**
     * Turns the CRF learner's log-space option on or off and records the choice.
     *
     * @param b true to call {@code setLogSpaceOption()}, false to call {@code removeLogSpaceOption()}
     */
    public void setLogSpace(boolean b) {
        if (b) {
            this.crfLearner.setLogSpaceOption();
        } else {
            this.crfLearner.removeLogSpaceOption();
        }
        this.logSpace = b;
    }

    /** @return the last value recorded by {@link #setLogSpace} (initially true) */
    public boolean getLogSpace() {
        return this.logSpace;
    }

    /** Forwards {@code schema} to the underlying CRF learner. */
    public void setSchema(ExampleSchema schema) {
        this.crfLearner.setSchema(schema);
    }

    /** @return the schema held by the underlying CRF learner */
    public ExampleSchema getSchema() {
        return this.crfLearner.getSchema();
    }

    /**
     * Trains on {@code dataset} by treating each example as a one-element sequence.
     *
     * @param dataset training examples
     * @return a {@link MyClassifier} wrapping the per-position classifier of the learned
     *     {@link CMM}, together with the schema of the constructed sequence dataset
     */
    public Classifier batchTrain(Dataset dataset) {
        SequenceDataset seqData = new SequenceDataset();
        for (Iterator<Example> it = dataset.iterator(); it.hasNext();) {
            seqData.addSequence(new Example[]{it.next()});
        }
        // Relies on CRFLearner.batchTrain returning a CMM for sequence data.
        CMM cmm = (CMM) this.crfLearner.batchTrain(seqData);
        return new MyClassifier(cmm.getClassifier(), seqData.getSchema(), this.scaleScores);
    }

    /**
     * Classifier returned by {@link MaxEntLearner#batchTrain}.
     *
     * <p>Every instance is first passed through {@code BeamSearcher.getBeamInstance(instance, 1)}.
     * The result is then handed to the wrapped classifier. The scores are optionally rescaled to
     * per-class log-odds.
     *
     * <p>Field names are part of the serialized form of saved models and must not be renamed.
     */
    public static class MyClassifier
    implements Classifier,
    Serializable,
    Visible {
        private static final long serialVersionUID = 20080128L;
        /** The wrapped per-position classifier learned by the CRF. */
        private Classifier c;
        /** Schema whose class names are used when rescaling scores. */
        private ExampleSchema schema;
        /** Whether {@link #classification} rescales scores to log-odds. */
        private boolean scaleScores;

        /**
         * @param c wrapped classifier that receives the augmented instances
         * @param schema schema listing the classes to score when rescaling
         * @param scaleScores whether to rescale scores to log-odds
         */
        public MyClassifier(Classifier c, ExampleSchema schema, boolean scaleScores) {
            this.c = c;
            this.schema = schema;
            this.scaleScores = scaleScores;
        }

        /**
         * Classifies {@code instance} by augmenting it and delegating to the wrapped classifier.
         *
         * @return the wrapped classifier's label, rescaled to log-odds when scaling is enabled
         */
        public ClassLabel classification(Instance instance) {
            ClassLabel label = this.c.classification(BeamSearcher.getBeamInstance(instance, 1));
            return this.scaleScores ? this.transformScores(label) : label;
        }

        /**
         * Returns a textual explanation. It contains the augmented instance, the wrapped
         * classifier's explanation and, when scaling is enabled, the transformed score.
         */
        public String explain(Instance instance) {
            Instance augmentedInstance = BeamSearcher.getBeamInstance(instance, 1);
            String base = "Augmented instance: " + augmentedInstance + "\n" + this.c.explain(augmentedInstance);
            if (!this.scaleScores) {
                return base;
            }
            return base + "\nTransformed score: " + this.classification(instance);
        }

        /**
         * Builds a tree-structured explanation. It has a root node, a subtree for the augmented
         * instance and, when scaling is enabled, a node holding the transformed score.
         */
        public Explanation getExplanation(Instance instance) {
            Explanation.Node top = new Explanation.Node("MaxEntClassifier Explanation");
            Instance augmentedInstance = BeamSearcher.getBeamInstance(instance, 1);
            top.add(this.buildAugmentedNode(augmentedInstance));
            if (this.scaleScores) {
                top.add(new Explanation.Node("\nTransformed score: " + this.classification(instance)));
            }
            return new Explanation(top);
        }

        /**
         * Converts the wrapped classifier's text explanation into a two-level node tree.
         *
         * <p>Lines not starting with a space become children of the returned node. Lines starting
         * with a space are attached under the most recent such line. A space-indented line seen
         * before any unindented line goes directly under the returned node.
         */
        private Explanation.Node buildAugmentedNode(Instance augmentedInstance) {
            Explanation.Node root = new Explanation.Node("Augmented instance: " + augmentedInstance);
            Explanation.Node parent = root;
            for (String line : this.c.explain(augmentedInstance).split("\n")) {
                if (line.length() == 0) {
                    // Blank lines carry no content, and charAt(0) would throw on them.
                    continue;
                }
                Explanation.Node node = new Explanation.Node(line);
                if (line.charAt(0) == ' ') {
                    parent.add(node);
                } else {
                    root.add(node);
                    parent = node;
                }
            }
            return root;
        }

        /**
         * Rescales the scores of every schema class to log-odds.
         *
         * <p>With {@code p_i = exp(s_i) / sum_j exp(s_j)}, each class receives
         * {@code log(p_i / (1 - p_i)) = s_i - log(sum_{j != i} exp(s_j))}. Scores are shifted by
         * their maximum first, so large or very negative scores neither overflow nor underflow
         * {@code exp}. The shift cancels in the ratio.
         */
        private ClassLabel transformScores(ClassLabel label) {
            int numClasses = this.schema.getNumberOfClasses();
            String[] names = new String[numClasses];
            double[] scores = new double[numClasses];
            int best = -1;
            for (int i = 0; i < numClasses; ++i) {
                names[i] = this.schema.getClassName(i);
                scores[i] = label.getWeight(names[i]);
                if (best < 0 || scores[i] > scores[best]) {
                    best = i;
                }
            }
            double[] shifted = new double[numClasses];
            double total = 0.0;
            for (int i = 0; i < numClasses; ++i) {
                shifted[i] = Math.exp(scores[i] - scores[best]);
                total += shifted[i];
            }
            // For the top class, total - shifted[best] may cancel to zero, so sum the others directly.
            // Every other class has shifted[best] == 1 inside its "rest", so subtraction is safe there.
            double restOfBest = 0.0;
            for (int j = 0; j < numClasses; ++j) {
                if (j != best) {
                    restOfBest += shifted[j];
                }
            }
            ClassLabel transformed = new ClassLabel();
            for (int i = 0; i < numClasses; ++i) {
                double rest = (i == best) ? restOfBest : total - shifted[i];
                // log(shifted[i]) is exactly the shifted score, so no exp/log round trip is needed.
                transformed.add(names[i], (scores[i] - scores[best]) - Math.log(rest));
            }
            return transformed;
        }

        /** @return the wrapped classifier, without augmentation or score rescaling */
        public Classifier getRawClassifier() {
            return this.c;
        }

        /** Returns a viewer that displays the wrapped classifier rather than this wrapper. */
        public Viewer toGUI() {
            TransformedViewer v = new TransformedViewer(new SmartVanillaViewer()){
                static final long serialVersionUID = 20080128L;

                public Object transform(Object o) {
                    MyClassifier mycl = (MyClassifier)o;
                    return mycl.c;
                }
            };
            v.setContent(this);
            return v;
        }
    }
}

