/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify.algorithms.svm;

import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Explanation;
import edu.cmu.minorthird.classify.FeatureFactory;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.algorithms.svm.SVMUtils;
import edu.cmu.minorthird.classify.algorithms.svm.VisibleSVM;
import edu.cmu.minorthird.util.gui.ComponentViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
import java.io.Serializable;
import javax.swing.JComponent;
import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;

/**
 * A {@link Classifier} backed by a trained libsvm {@link svm_model}.
 *
 * <p>Each instance is first compressed with the stored {@link FeatureFactory},
 * then converted to a libsvm node array by {@link SVMUtils#instanceToNodeArray},
 * and finally scored by libsvm.
 */
public class SVMClassifier implements Classifier, Serializable, Visible {
    static final long serialVersionUID = 20071130L;

    /** Trained libsvm model used for every prediction. */
    private final svm_model model;
    /** Schema used to decide binary vs. multi-class handling and to map class indices to names. */
    private final ExampleSchema schema;
    /** Factory used to compress instances before they are converted to libsvm nodes. */
    private final FeatureFactory featureFactory;

    /**
     * Creates a classifier around an already-trained libsvm model.
     *
     * @param model          trained libsvm model
     * @param schema         schema whose class names correspond to the model's labels
     * @param featureFactory factory used to compress instances prior to prediction
     */
    public SVMClassifier(svm_model model, ExampleSchema schema, FeatureFactory featureFactory) {
        this.model = model;
        this.schema = schema;
        this.featureFactory = featureFactory;
    }

    /**
     * Returns a fixed placeholder: this classifier does not produce per-instance explanations.
     *
     * @param instance ignored
     * @return always {@code "None"}
     */
    public String explain(Instance instance) {
        return "None";
    }

    /**
     * Wraps {@link #explain(Instance)} in an {@link Explanation}.
     *
     * @param instance the instance to explain
     * @return an explanation carrying the placeholder text
     */
    public Explanation getExplanation(Instance instance) {
        return new Explanation(this.explain(instance));
    }

    /** @return the underlying libsvm model (the held instance, not a copy) */
    public svm_model getSVMModel() {
        return this.model;
    }

    /** @return the schema this classifier was built with */
    public ExampleSchema getSchema() {
        return this.schema;
    }

    /** @return the feature factory used to compress instances */
    public FeatureFactory getFeatureFactory() {
        return this.featureFactory;
    }

    /**
     * Classifies an instance with the wrapped libsvm model.
     *
     * <p>When the model carries probability information
     * ({@code svm_check_probability_model > 0}):
     * <ul>
     *   <li>binary schema: the score is the libsvm predicted label multiplied by the
     *       log-odds {@code log(p / (1 - p))} of the more likely class; a score
     *       {@code >= 0} yields a positive label, otherwise a negative label;</li>
     *   <li>otherwise: every class is added with its estimated probability as weight.</li>
     * </ul>
     * Without probability information, the single predicted class is added with weight
     * {@code 1.0}. For a binary schema, a negative predicted value maps to {@code "NEG"}
     * and anything else to {@code "POS"}.
     *
     * @param instance the instance to classify
     * @return the resulting class label
     */
    public ClassLabel classification(Instance instance) {
        Instance compressed = this.featureFactory.compress(instance);
        svm_node[] nodes = SVMUtils.instanceToNodeArray(compressed);

        if (svm.svm_check_probability_model(this.model) > 0) {
            if (this.schema.equals(ExampleSchema.BINARY_EXAMPLE_SCHEMA)) {
                double[] binaryProbs = new double[2];
                double predictedLabel = svm.svm_predict_probability(this.model, nodes, binaryProbs);
                // The more likely class has p >= 0.5, so its log-odds are non-negative.
                // The sign of the score therefore comes from the predicted label value.
                double best = binaryProbs[0] > binaryProbs[1] ? binaryProbs[0] : binaryProbs[1];
                double score = predictedLabel * Math.log(best / (1.0 - best));
                return score >= 0.0 ? ClassLabel.positiveLabel(score) : ClassLabel.negativeLabel(score);
            }

            int numClasses = svm.svm_get_nr_class(this.model);
            double[] classProbs = new double[numClasses];
            svm.svm_predict_probability(this.model, nodes, classProbs);
            // libsvm reports probability estimates in the same order as svm_get_labels.
            int[] modelLabels = new int[numClasses];
            svm.svm_get_labels(this.model, modelLabels);
            ClassLabel multiLabel = new ClassLabel();
            for (int i = 0; i < modelLabels.length; ++i) {
                // NOTE: these weights are raw probabilities, not log-odds as in the binary case.
                multiLabel.add(this.schema.getClassName(modelLabels[i]), classProbs[i]);
            }
            return multiLabel;
        }

        double predicted = svm.svm_predict(this.model, nodes);
        ClassLabel hardLabel = new ClassLabel();
        if (this.schema.equals(ExampleSchema.BINARY_EXAMPLE_SCHEMA)) {
            hardLabel.add(predicted < 0.0 ? "NEG" : "POS", 1.0);
        } else {
            hardLabel.add(this.schema.getClassName((int) predicted), 1.0);
        }
        return hardLabel;
    }

    /**
     * Builds a viewer for this classifier.
     *
     * @return a viewer whose content is this classifier
     */
    public Viewer toGUI() {
        SVMViewer viewer = new SVMViewer();
        viewer.setContent(this);
        return viewer;
    }

    /** Viewer that renders an {@link SVMClassifier} through {@link VisibleSVM}. */
    private static class SVMViewer extends ComponentViewer {
        static final long serialVersionUID = 20071130L;

        private SVMViewer() {
        }

        /** Accepts only {@link SVMClassifier} instances. */
        public boolean canReceive(Object o) {
            return o instanceof SVMClassifier;
        }

        /**
         * Builds a {@link VisibleSVM} from the classifier's model and feature factory
         * and returns its GUI component.
         */
        public JComponent componentFor(Object o) {
            SVMClassifier classifier = (SVMClassifier) o;
            VisibleSVM visible = new VisibleSVM(classifier.getSVMModel(), classifier.getFeatureFactory());
            return visible.toGUI();
        }
    }
}

