/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify;

import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Explanation;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.util.gui.ComponentViewer;
import edu.cmu.minorthird.util.gui.SmartVanillaViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
import java.io.Serializable;
import javax.swing.JComponent;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.JScrollPane;

public class OneVsAllClassifier
implements Classifier,
Visible,
Serializable {
    private static final long serialVersionUID = 1L;
    private String[] classNames;
    private Classifier[] binaryClassifiers;

    public OneVsAllClassifier(String[] classNames, Classifier[] binaryClassifiers) {
        if (classNames.length != binaryClassifiers.length) {
            throw new IllegalArgumentException("arrays must be parallel");
        }
        this.classNames = classNames;
        this.binaryClassifiers = binaryClassifiers;
    }

    public Classifier[] getBinaryClassifiers() {
        return this.binaryClassifiers;
    }

    public ClassLabel classification(Instance instance) {
        ClassLabel classLabel = new ClassLabel();
        for (int i = 0; i < this.classNames.length; ++i) {
            classLabel.add(this.classNames[i], this.binaryClassifiers[i].classification(instance).posWeight());
        }
        return classLabel;
    }

    public String explain(Instance instance) {
        StringBuffer buf = new StringBuffer("");
        for (int i = 0; i < this.binaryClassifiers.length; ++i) {
            buf.append("score for " + this.classNames[i] + ": ");
            buf.append(this.binaryClassifiers[i].explain(instance));
            buf.append("\n");
        }
        buf.append("classification = " + this.classification(instance).toString());
        return buf.toString();
    }

    public Explanation getExplanation(Instance instance) {
        Explanation.Node top = new Explanation.Node("OneVsAll Explanation");
        for (int i = 0; i < this.binaryClassifiers.length; ++i) {
            Explanation.Node binClassifierNode = new Explanation.Node(this.classNames[i] + " Tree");
            Explanation.Node explanation = this.binaryClassifiers[i].getExplanation(instance).getTopNode();
            binClassifierNode.add(explanation);
            top.add(binClassifierNode);
        }
        Explanation ex = new Explanation(top);
        return ex;
    }

    public String[] getClassNames() {
        return this.classNames;
    }

    public String toString() {
        StringBuffer buf = new StringBuffer("[OneVsAllClassifier:\n");
        for (int i = 0; i < this.classNames.length; ++i) {
            buf.append(this.classNames[i] + ": " + this.binaryClassifiers[i] + "\n");
        }
        buf.append("end OneVsAllClassifier]\n");
        return buf.toString();
    }

    public Viewer toGUI() {
        ComponentViewer v = new ComponentViewer(){
            static final long serialVersionUID = 20071015L;

            public JComponent componentFor(Object o) {
                OneVsAllClassifier c = (OneVsAllClassifier)o;
                JPanel panel = new JPanel();
                for (int i = 0; i < c.classNames.length; ++i) {
                    panel.add(new JLabel(c.classNames[i]));
                    SmartVanillaViewer subView = new SmartVanillaViewer();
                    subView.setContent(c.binaryClassifiers[i]);
                    subView.setSuperView(this);
                    panel.add(subView);
                }
                return new JScrollPane(panel);
            }
        };
        v.setContent(this);
        return v;
    }
}

