/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify.algorithms.linear;

import edu.cmu.minorthird.classify.BasicDataset;
import edu.cmu.minorthird.classify.BasicFeatureIndex;
import edu.cmu.minorthird.classify.BatchClassifierLearner;
import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.SampleDatasets;
import edu.cmu.minorthird.classify.algorithms.linear.MultinomialClassifier;
import edu.cmu.minorthird.classify.algorithms.random.Estimate;
import edu.cmu.minorthird.classify.algorithms.random.Estimators;
import java.util.ArrayList;
import java.util.Iterator;

/**
 * Batch learner for non-binary (k-way) example data that trains a {@link MultinomialClassifier}.
 *
 * <p>The per-feature/per-class parameters are estimated with {@link Estimators} according to the
 * configured model and parameterization:
 * <ul>
 *   <li>{@code "Naive-Bayes"}: {@code "default"}/{@code "mean"}, {@code "weighted-mean"}
 *       (feature weights are clipped to at most 1.0)</li>
 *   <li>{@code "Binomial"}: {@code "default"}/{@code "p/N"}, {@code "mu/delta"}</li>
 *   <li>{@code "Poisson"} (the default model): {@code "default"}/{@code "weighted-lambda"},
 *       {@code "lambda"}</li>
 *   <li>{@code "Negative-Binomial"}: {@code "default"}/{@code "mu/delta"}</li>
 *   <li>{@code "Mixture"}: the parameterization is ignored; Binomial or Negative-Binomial is
 *       chosen per feature from the sample mean vs. variance</li>
 *   <li>{@code "Dirichlet-Poisson MCMC"}: {@code "default"}/{@code "weighted-lambda"},
 *       {@code "lambda"}; only the statistics of classes 0 and 1 are fed to the sampler</li>
 * </ul>
 * An unrecognized model name results in a classifier with only its valid labels set.
 */
public class KWayMixtureLearner extends BatchClassifierLearner {
    private static final double DEFAULT_SCALE = 10.0;
    private static final String DEFAULT_MODEL = "Poisson";
    private static final String DEFAULT_PARAMETERIZATION = "default";

    /** Scale passed to the classifier and to the estimators. */
    private final double scale;
    /** Name of the count model used for every feature. */
    private final String model;
    /** Name of the parameterization used by the chosen model. */
    private final String parameterization;

    /** Creates a Poisson learner with the default parameterization and a scale of 10. */
    public KWayMixtureLearner() {
        this(DEFAULT_MODEL);
    }

    /**
     * Creates a learner for the given model with the default parameterization and a scale of 10.
     *
     * @param model name of the count model (see the class documentation)
     */
    public KWayMixtureLearner(String model) {
        this(model, DEFAULT_PARAMETERIZATION);
    }

    /**
     * Creates a learner for the given model and parameterization with a scale of 10.
     *
     * @param model            name of the count model
     * @param parameterization name of the model's parameterization
     */
    public KWayMixtureLearner(String model, String parameterization) {
        this(model, parameterization, DEFAULT_SCALE);
    }

    /**
     * Creates a fully configured learner.
     *
     * @param model            name of the count model
     * @param parameterization name of the model's parameterization
     * @param scale            scale passed to the classifier and the estimators
     */
    public KWayMixtureLearner(String model, String parameterization, double scale) {
        this.scale = scale;
        this.model = model;
        this.parameterization = parameterization;
    }

    /**
     * Accepts any schema except the binary one; the schema itself is not stored.
     *
     * @throws IllegalStateException if {@code schema} is the binary example schema
     */
    public void setSchema(ExampleSchema schema) {
        if (ExampleSchema.BINARY_EXAMPLE_SCHEMA.equals(schema)) {
            throw new IllegalStateException("can only learn non-binary example data");
        }
    }

    /**
     * Returns the binary example schema.
     *
     * <p>NOTE(review): this contradicts {@link #setSchema}, which rejects exactly this schema;
     * kept unchanged for compatibility with existing callers — confirm the intended value.
     */
    public ExampleSchema getSchema() {
        return ExampleSchema.BINARY_EXAMPLE_SCHEMA;
    }

    /**
     * Trains a {@link MultinomialClassifier} on {@code dataset}.
     *
     * @param dataset labelled training data; its schema defines the classes
     * @return the trained classifier
     * @throws IllegalStateException if a label's schema index disagrees with the classifier's
     *         index for it, or if the Dirichlet-Poisson MCMC model is used with fewer than two
     *         classes
     */
    public Classifier batchTrain(Dataset dataset) {
        MultinomialClassifier mc = new MultinomialClassifier();
        mc.setScale(this.scale);
        ExampleSchema schema = dataset.getSchema();
        BasicFeatureIndex index = new BasicFeatureIndex(dataset);
        int numberOfClasses = schema.getNumberOfClasses();

        // featureMatrix.get(c)[e]: weight of the feature currently being estimated in the e-th
        // example of class c (refilled for every feature).
        // exampleWeightMatrix.get(c)[e]: summed weight of all features of the e-th example of c.
        // Both are sized with index.size(className), assumed to be the per-class example count.
        ArrayList<double[]> featureMatrix = new ArrayList<double[]>();
        ArrayList<double[]> exampleWeightMatrix = new ArrayList<double[]>();
        for (int c = 0; c < numberOfClasses; ++c) {
            String className = schema.getClassName(c);
            mc.addValidLabel(new ClassLabel(className));
            int classSize = index.size(className);
            featureMatrix.add(new double[classSize]);
            exampleWeightMatrix.add(new double[classSize]);
        }

        double numberOfExamples = dataset.size();
        double numberOfFeatures = index.numberOfFeatures();
        double prior = 1.0 / numberOfFeatures;
        double[] countsGivenClass = new double[numberOfClasses];
        double[] examplesGivenClass = new double[numberOfClasses];
        int[] exampleCounter = new int[numberOfClasses];

        // First pass: per-class example counts, per-class total weight, per-example total weight.
        for (Iterator<Example> it = dataset.iterator(); it.hasNext(); ) {
            Example ex = it.next();
            int idx = schema.getClassIndex(ex.getLabel().bestClassName().toString());
            int classIndex = mc.indexOf(ex.getLabel());
            if (idx != classIndex) {
                throw new IllegalStateException("incompatible class indices: schema index " + idx
                        + " vs. classifier index " + classIndex);
            }
            examplesGivenClass[idx] += 1.0;
            double[] exampleWeights = exampleWeightMatrix.get(idx);
            int row = exampleCounter[idx];
            for (Iterator<Feature> fit = index.featureIterator(); fit.hasNext(); ) {
                double weight = ex.getWeight(fit.next());
                countsGivenClass[idx] += weight;
                exampleWeights[row] += weight;
            }
            exampleCounter[idx]++;
        }

        // Second pass: estimate the parameters of each feature separately.
        boolean binarize = this.model.equals("Naive-Bayes");
        for (Iterator<Feature> fit = index.featureIterator(); fit.hasNext(); ) {
            Feature ft = fit.next();

            int[] counter = new int[numberOfClasses];
            for (Iterator<Example> eit = dataset.iterator(); eit.hasNext(); ) {
                Example ex = eit.next();
                int idx = schema.getClassIndex(ex.getLabel().bestClassName().toString());
                double weight = ex.getWeight(ft);
                // Naive-Bayes only models presence, so counts are clipped to 1.
                featureMatrix.get(idx)[counter[idx]++] = binarize ? Math.min(1.0, weight) : weight;
            }

            if (this.model.equals("Naive-Bayes")) {
                mc.setPrior(prior);
                mc.setUnseenModel("Naive-Bayes");
                for (int j = 0; j < numberOfClasses; ++j) {
                    mc.setClassParameter(j, this.estimateClassProbMLE(1.0, numberOfClasses, examplesGivenClass[j], numberOfExamples));
                    double[] countsFeatureGivenClass = featureMatrix.get(j);
                    if (this.parameterization.equals("default") || this.parameterization.equals("mean")) {
                        Estimate mean = Estimators.estimateNaiveBayesMean(1.0, numberOfFeatures, this.sum(countsFeatureGivenClass), examplesGivenClass[j]);
                        mc.setFeatureGivenClassParameter(ft, j, mean);
                    } else if (this.parameterization.equals("weighted-mean")) {
                        Estimate mean = Estimators.estimateNaiveBayesWeightedMean(countsFeatureGivenClass, exampleWeightMatrix.get(j), prior, this.scale);
                        mc.setFeatureGivenClassParameter(ft, j, mean);
                    }
                }
                mc.setFeatureModel(ft, "Naive-Bayes");
            } else if (this.model.equals("Binomial")) {
                mc.setPrior(prior);
                mc.setUnseenModel("Binomial");
                for (int j = 0; j < numberOfClasses; ++j) {
                    mc.setClassParameter(j, this.estimateClassProbMLE(1.0, numberOfClasses, examplesGivenClass[j], numberOfExamples));
                    double[] countsFeatureGivenClass = featureMatrix.get(j);
                    double[] countsGivenExample = exampleWeightMatrix.get(j);
                    if (this.parameterization.equals("default") || this.parameterization.equals("p/N")) {
                        Estimate pn = Estimators.estimateBinomialPN(countsFeatureGivenClass, countsGivenExample, prior, this.scale);
                        mc.setFeatureGivenClassParameter(ft, j, pn);
                    } else if (this.parameterization.equals("mu/delta")) {
                        Estimate mudelta = Estimators.estimateBinomialMuDelta(countsFeatureGivenClass, countsGivenExample, prior, this.scale);
                        mc.setFeatureGivenClassParameter(ft, j, mudelta);
                    }
                }
                mc.setFeatureModel(ft, "Binomial");
            } else if (this.model.equals("Poisson")) {
                mc.setPrior(prior);
                mc.setUnseenModel("Poisson");
                for (int j = 0; j < numberOfClasses; ++j) {
                    mc.setClassParameter(j, this.estimateClassProbMLE(1.0, numberOfClasses, examplesGivenClass[j], numberOfExamples));
                    double[] countsFeatureGivenClass = featureMatrix.get(j);
                    if (this.parameterization.equals("default") || this.parameterization.equals("weighted-lambda")) {
                        Estimate lambda = Estimators.estimatePoissonWeightedLambda(countsFeatureGivenClass, exampleWeightMatrix.get(j), prior, this.scale);
                        mc.setFeatureGivenClassParameter(ft, j, lambda);
                    } else if (this.parameterization.equals("lambda")) {
                        Estimate lambda = Estimators.estimatePoissonLambda(1.0 / this.scale, numberOfFeatures, this.sum(countsFeatureGivenClass), countsGivenClass[j] / this.scale);
                        mc.setFeatureGivenClassParameter(ft, j, lambda);
                    }
                }
                mc.setFeatureModel(ft, "Poisson");
            } else if (this.model.equals("Negative-Binomial")) {
                mc.setPrior(prior);
                mc.setUnseenModel("Negative-Binomial");
                for (int j = 0; j < numberOfClasses; ++j) {
                    mc.setClassParameter(j, this.estimateClassProbMLE(1.0, numberOfClasses, examplesGivenClass[j], numberOfExamples));
                    if (this.parameterization.equals("default") || this.parameterization.equals("mu/delta")) {
                        Estimate mudelta = Estimators.estimateNegativeBinomialMuDelta(featureMatrix.get(j), exampleWeightMatrix.get(j), prior, this.scale);
                        mc.setFeatureGivenClassParameter(ft, j, mudelta);
                    }
                }
                mc.setFeatureModel(ft, "Negative-Binomial");
            } else if (this.model.equals("Mixture")) {
                mc.setPrior(prior);
                mc.setUnseenModel("Mixture");
                for (int j = 0; j < numberOfClasses; ++j) {
                    mc.setClassParameter(j, this.estimateClassProbMLE(1.0, numberOfClasses, examplesGivenClass[j], numberOfExamples));
                    double[] countsFeatureGivenClass = featureMatrix.get(j);
                    double mean = Estimators.estimateMean(countsFeatureGivenClass);
                    double var = Estimators.estimateVar(countsFeatureGivenClass);
                    // Under-dispersed counts -> Binomial, over-dispersed -> Negative-Binomial.
                    // If mean or var is NaN neither comparison holds and no parameter is set.
                    String chosenModel = "";
                    if (mean > var) {
                        chosenModel = "Binomial";
                    } else if (mean <= var) {
                        chosenModel = "Negative-Binomial";
                    }
                    // NOTE(review): the feature model is reset for every class, so the choice made
                    // for the last class is the one the classifier keeps.
                    mc.setFeatureModel(ft, chosenModel);
                    if (chosenModel.equals("Binomial")) {
                        Estimate mudelta = Estimators.estimateBinomialMuDelta(countsFeatureGivenClass, exampleWeightMatrix.get(j), prior, this.scale);
                        mc.setFeatureGivenClassParameter(ft, j, mudelta);
                    } else if (chosenModel.equals("Negative-Binomial")) {
                        Estimate mudelta = Estimators.estimateNegativeBinomialMuDelta(countsFeatureGivenClass, exampleWeightMatrix.get(j), prior, this.scale);
                        mc.setFeatureGivenClassParameter(ft, j, mudelta);
                    }
                }
            } else if (this.model.equals("Dirichlet-Poisson MCMC")) {
                mc.setPrior(prior);
                mc.setUnseenModel("Dirichlet-Poisson MCMC");
                boolean weighted = this.parameterization.equals("default") || this.parameterization.equals("weighted-lambda");
                if (weighted || this.parameterization.equals("lambda")) {
                    // The sampler below is fed the statistics of classes 0 and 1 only.
                    if (numberOfClasses < 2) {
                        throw new IllegalStateException("Dirichlet-Poisson MCMC needs at least 2 classes, got " + numberOfClasses);
                    }
                    Estimate[] lambda = new Estimate[numberOfClasses];
                    double[] sumCountsFeatureGivenClass = new double[numberOfClasses];
                    double[] sumCountsGivenExample = new double[numberOfClasses];
                    for (int j = 0; j < numberOfClasses; ++j) {
                        mc.setClassParameter(j, this.estimateClassProbMLE(1.0, numberOfClasses, examplesGivenClass[j], numberOfExamples));
                        double[] countsFeatureGivenClass = featureMatrix.get(j);
                        double[] countsGivenExample = exampleWeightMatrix.get(j);
                        sumCountsFeatureGivenClass[j] = this.sum(countsFeatureGivenClass);
                        sumCountsGivenExample[j] = this.sum(countsGivenExample);
                        lambda[j] = weighted
                                ? Estimators.estimatePoissonWeightedLambda(countsFeatureGivenClass, countsGivenExample, prior, this.scale)
                                : Estimators.estimatePoissonLambda(1.0, numberOfFeatures, sumCountsFeatureGivenClass[j], countsGivenClass[j]);
                    }
                    Estimate[] postLambdas;
                    if (weighted) {
                        double sigSD = (Double) lambda[0].getPms().get("lambda") + (Double) lambda[1].getPms().get("lambda");
                        postLambdas = Estimators.mcmcEstimateDirichletPoissonTauSigma(lambda, new double[]{1.0E-7, 1.0E-7}, new double[]{1.0, 150.0}, sumCountsFeatureGivenClass[0], sumCountsFeatureGivenClass[1], sumCountsGivenExample[0], sumCountsGivenExample[1], new double[]{2.0, 1.0}, 0.075, sigSD / 10.0, 100);
                    } else {
                        postLambdas = Estimators.mcmcEstimateDirichletPoissonTauSigma(lambda, new double[]{1.0E-7, 1.0E-7}, new double[]{1.0E-7, 150.0}, sumCountsFeatureGivenClass[0], sumCountsFeatureGivenClass[1], sumCountsGivenExample[0], sumCountsGivenExample[1], new double[]{2.0, 1.0}, 0.1, 0.5, 100);
                    }
                    // Each class gets its own posterior (previously the "lambda" branch wrote
                    // every posterior into class 0).
                    for (int j = 0; j < numberOfClasses; ++j) {
                        mc.setFeatureGivenClassParameter(ft, j, postLambdas[j]);
                    }
                }
                mc.setFeatureModel(ft, "Dirichlet-Poisson MCMC");
            }
        }
        return mc;
    }

    /**
     * Smoothed maximum-likelihood estimate of a class prior:
     * {@code (classPrior / numberOfClasses + observedCounts) / (1 + totalCounts)}.
     */
    private double estimateClassProbMLE(double classPrior, double numberOfClasses, double observedCounts, double totalCounts) {
        return (classPrior / numberOfClasses + observedCounts) / (1.0 + totalCounts);
    }

    /** Returns the sum of all entries of {@code vec} (0.0 for an empty array). */
    private double sum(double[] vec) {
        double total = 0.0;
        for (double v : vec) {
            total += v;
        }
        return total;
    }

    /** Prints the {@code "bayesUnlabeled"} sample dataset. */
    public static void main(String[] args) {
        Dataset dataset = SampleDatasets.sampleData("bayesUnlabeled", false);
        System.out.println("SampleDatasets (bayesUnlabeled):\n" + dataset);
    }
}

