/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify;

import edu.cmu.minorthird.classify.BasicDataset;
import edu.cmu.minorthird.classify.BatchClassifierLearner;
import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.ClassifierLearner;
import edu.cmu.minorthird.classify.ClassifierLearnerFactory;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.OneVsAllClassifier;
import edu.cmu.minorthird.classify.OneVsAllLearner;
import edu.cmu.minorthird.classify.experiments.CrossValSplitter;
import edu.cmu.minorthird.classify.experiments.Evaluation;
import edu.cmu.minorthird.classify.experiments.Tester;
import java.util.ArrayList;
import java.util.List;

public class CascadingBinaryLearner
extends OneVsAllLearner {
    public String[] sortedClassNames;
    private List<Dataset> data = null;
    private List<Evaluation> eval = null;

    public CascadingBinaryLearner() {
    }

    public CascadingBinaryLearner(ClassifierLearnerFactory learnerFactory) {
        super(learnerFactory);
    }

    public CascadingBinaryLearner(String l) {
        super(l);
    }

    public CascadingBinaryLearner(BatchClassifierLearner learner) {
        this.learner = learner;
        this.learnerName = learner.toString();
        this.learnerFactory = new ClassifierLearnerFactory(this.learnerName);
    }

    public void setSchema(ExampleSchema schema) {
        this.schema = schema;
        this.innerLearner = new ArrayList();
        this.data = new ArrayList<Dataset>();
        for (int i = 0; i < schema.getNumberOfClasses(); ++i) {
            this.innerLearner.add(this.learner.copy());
            ((ClassifierLearner)this.innerLearner.get(i)).setSchema(ExampleSchema.BINARY_EXAMPLE_SCHEMA);
            this.data.add(new BasicDataset());
        }
    }

    private void createRankings() {
        CrossValSplitter<Example> splitter = new CrossValSplitter<Example>(9);
        this.eval = new ArrayList<Evaluation>();
        for (int i = 0; i < this.innerLearner.size(); ++i) {
            Evaluation evaluation = Tester.evaluate((ClassifierLearner)this.innerLearner.get(i), this.data.get(i), splitter);
            this.eval.add(evaluation);
        }
    }

    private void sortLearners() {
        ArrayList<BatchClassifierLearner> unsortedLearners = new ArrayList<BatchClassifierLearner>();
        String[] classNames = this.schema.validClassNames();
        ArrayList<String> unsortedClassNames = new ArrayList<String>();
        this.sortedClassNames = new String[this.schema.getNumberOfClasses()];
        for (int i = 0; i < this.innerLearner.size(); ++i) {
            unsortedLearners.add((BatchClassifierLearner)this.innerLearner.get(i));
            unsortedClassNames.add(classNames[i]);
        }
        this.innerLearner.clear();
        int position = 0;
        while (!unsortedLearners.isEmpty()) {
            String className;
            double maxKappa = -10.0;
            int learnerIndex = -1;
            for (int j = 0; j < unsortedLearners.size(); ++j) {
                try {
                    Evaluation evaluation = this.eval.get(j);
                    double kappa = evaluation.kappa();
                    if (!(kappa >= maxKappa)) continue;
                    maxKappa = kappa;
                    learnerIndex = j;
                    continue;
                }
                catch (Exception e) {
                    e.printStackTrace();
                }
            }
            ClassifierLearner learner = (ClassifierLearner)unsortedLearners.remove(learnerIndex);
            this.innerLearner.add(learner);
            this.sortedClassNames[position] = className = (String)unsortedClassNames.remove(learnerIndex);
            ++position;
        }
    }

    public void addExample(Example answeredQuery) {
        int classIndex = this.schema.getClassIndex(answeredQuery.getLabel().bestClassName());
        for (int i = 0; i < this.innerLearner.size(); ++i) {
            ClassLabel label = classIndex == i ? ClassLabel.positiveLabel(1.0) : ClassLabel.negativeLabel(-1.0);
            Example example = new Example(answeredQuery.asInstance(), label);
            ((ClassifierLearner)this.innerLearner.get(i)).addExample(example);
            this.data.get(i).add(example);
        }
    }

    public void completeTraining() {
        for (int i = 0; i < this.innerLearner.size(); ++i) {
            ((ClassifierLearner)this.innerLearner.get(i)).completeTraining();
        }
        this.createRankings();
        this.sortLearners();
    }

    public Classifier getClassifier() {
        Classifier[] classifiers = new Classifier[this.innerLearner.size()];
        for (int i = 0; i < this.innerLearner.size(); ++i) {
            classifiers[i] = ((ClassifierLearner)this.innerLearner.get(i)).getClassifier();
        }
        return new OneVsAllClassifier(this.sortedClassNames, classifiers);
    }
}

