/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify.semisupervised;

import edu.cmu.minorthird.classify.BasicDataset;
import edu.cmu.minorthird.classify.BasicFeatureIndex;
import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.ClassifierLearner;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.SampleDatasets;
import edu.cmu.minorthird.classify.semisupervised.MultinomialClassifier;
import edu.cmu.minorthird.classify.semisupervised.SemiSupervisedBatchClassifierLearner;
import edu.cmu.minorthird.classify.semisupervised.SemiSupervisedDataset;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
public class SemiSupervisedNaiveBayesLearner
extends SemiSupervisedBatchClassifierLearner {
    private int MAX_ITER = 1000;
    private Iterator<Instance> iteratorOverUnlabeled;

    public SemiSupervisedNaiveBayesLearner() {
    }

    public SemiSupervisedNaiveBayesLearner(int iterations) {
        this.MAX_ITER = iterations;
    }

    @Override
    public void setSchema(ExampleSchema schema) {
    }

    @Override
    public void setInstancePool(Iterator<Instance> i) {
        this.iteratorOverUnlabeled = i;
    }

    @Override
    public ExampleSchema getSchema() {
        return null;
    }

    @Override
    public ClassifierLearner copy() {
        ClassifierLearner learner = null;
        try {
            learner = (ClassifierLearner)this.clone();
            learner.reset();
        }
        catch (Exception e) {
            e.printStackTrace();
        }
        return learner;
    }

    @Override
    public Classifier batchTrain(SemiSupervisedDataset dataset) {
        Iterator<Instance> il;
        MultinomialClassifier mc = new MultinomialClassifier();
        int numberOfClasses = 0;
        Iterator<Example> i = dataset.iterator();
        while (i.hasNext()) {
            Example ex = i.next();
            if (mc.isPresent(ex.getLabel())) continue;
            mc.addValidLabel(ex.getLabel());
            ++numberOfClasses;
        }
        BasicFeatureIndex index = new BasicFeatureIndex(dataset);
        double[] countsGivenClass = new double[numberOfClasses];
        double[] examplesGivenClass = new double[numberOfClasses];
        double numberOfExamples = dataset.size();
        double numberOfFeatures = index.numberOfFeatures();
        Iterator<Example> i2 = dataset.iterator();
        while (i2.hasNext()) {
            int classIndex;
            Example ex = i2.next();
            int n = classIndex = mc.indexOf(ex.getLabel());
            examplesGivenClass[n] = examplesGivenClass[n] + 1.0;
            Iterator<Feature> j = index.featureIterator();
            while (j.hasNext()) {
                Feature f = j.next();
                int n2 = classIndex;
                countsGivenClass[n2] = countsGivenClass[n2] + ex.getWeight(f);
            }
        }
        for (int j = 0; j < numberOfClasses; ++j) {
            double probabilityOfOccurrence = this.estimateClassProbMLE(1.0, numberOfClasses, examplesGivenClass[j], numberOfExamples);
            mc.setClassParameter(j, probabilityOfOccurrence);
        }
        Iterator<Feature> i22 = index.featureIterator();
        while (i22.hasNext()) {
            int j;
            Feature f = i22.next();
            double[] countsFeatureGivenClass = new double[numberOfClasses];
            for (j = 0; j < index.size(f); ++j) {
                int classIndex;
                Example ex = index.getExample(f, j);
                int n = classIndex = mc.indexOf(ex.getLabel());
                countsFeatureGivenClass[n] = countsFeatureGivenClass[n] + ex.getWeight(f);
            }
            for (j = 0; j < numberOfClasses; ++j) {
                double probabilityOfOccurrence = this.estimateFeatureProbMLE(1.0, numberOfFeatures, countsFeatureGivenClass[j], countsGivenClass[j]);
                mc.setFeatureGivenClassParameter(f, j, probabilityOfOccurrence);
            }
            mc.setFeatureModel(f, "Binomial");
        }
        BasicDataset unlabeledDataset = new BasicDataset();
        Iterator<Instance> i3 = il = this.iteratorOverUnlabeled;
        while (i3.hasNext()) {
            Instance mi = i3.next();
            System.out.println(mi);
            ClassLabel estimatedClassLabel = mc.classification(mi);
            unlabeledDataset.add(new Example(mi, estimatedClassLabel));
        }
        double logLik = Double.NEGATIVE_INFINITY;
        int iter = 0;
        boolean hasConverged = false;
        while (iter < this.MAX_ITER & !hasConverged) {
            double previousLogLik = logLik;
            logLik = 0.0;
            BasicDataset combinedDataset = new BasicDataset();
            Iterator<Object> i4 = dataset.iterator();
            while (i4.hasNext()) {
                combinedDataset.add(i4.next());
            }
            i4 = unlabeledDataset.iterator();
            while (i4.hasNext()) {
                combinedDataset.add(i4.next());
            }
            mc.reset();
            index = new BasicFeatureIndex(combinedDataset);
            countsGivenClass = new double[numberOfClasses];
            examplesGivenClass = new double[numberOfClasses];
            numberOfExamples = combinedDataset.size();
            numberOfFeatures = index.numberOfFeatures();
            i4 = dataset.iterator();
            while (i4.hasNext()) {
                int classIndex;
                Example ex = i4.next();
                int n = classIndex = mc.indexOf(ex.getLabel());
                examplesGivenClass[n] = examplesGivenClass[n] + 1.0;
                Iterator<Feature> j = index.featureIterator();
                while (j.hasNext()) {
                    Feature f = j.next();
                    int n3 = classIndex;
                    countsGivenClass[n3] = countsGivenClass[n3] + ex.getWeight(f);
                }
            }
            for (int j = 0; j < numberOfClasses; ++j) {
                double probabilityOfOccurrence = this.estimateClassProbMLE(1.0, numberOfClasses, examplesGivenClass[j], numberOfExamples);
                mc.setClassParameter(j, probabilityOfOccurrence);
            }
            i4 = index.featureIterator();
            while (i4.hasNext()) {
                int j;
                Feature f = (Feature)i4.next();
                double[] countsFeatureGivenClass = new double[numberOfClasses];
                for (j = 0; j < index.size(f); ++j) {
                    int classIndex;
                    Example ex = index.getExample(f, j);
                    int n = classIndex = mc.indexOf(ex.getLabel());
                    countsFeatureGivenClass[n] = countsFeatureGivenClass[n] + ex.getWeight(f);
                }
                for (j = 0; j < numberOfClasses; ++j) {
                    double probabilityOfOccurrence = this.estimateFeatureProbMLE(1.0, numberOfFeatures, countsFeatureGivenClass[j], countsGivenClass[j]);
                    mc.setFeatureGivenClassParameter(f, j, probabilityOfOccurrence);
                }
                mc.setFeatureModel(f, "Binomial");
            }
            il = this.iteratorOverUnlabeled;
            i4 = il;
            while (i4.hasNext()) {
                Instance mi = (Instance)i4.next();
                System.out.println(mi);
                ClassLabel estimatedClassLabel = mc.classification(mi);
                unlabeledDataset.add(new Example(mi, estimatedClassLabel));
            }
            logLik = 0.0;
            Iterator<Example> eloo = combinedDataset.iterator();
            while (eloo.hasNext()) {
                Example example = eloo.next();
                logLik += mc.getLogLikelihood(example);
            }
            if (this.EMconverged(logLik, previousLogLik, 1.0E-7, true)) {
                hasConverged = true;
                System.out.println("EM converged!");
            } else {
                System.out.println("iteration=" + (iter + 1) + " log-likelihood=" + logLik);
            }
            ++iter;
        }
        return mc;
    }

    private double estimateClassProbMLE(double classPrior, double numberOfClasses, double observedCounts, double totalCounts) {
        return (classPrior + observedCounts) / (numberOfClasses + totalCounts);
    }

    private double estimateFeatureProbMLE(double featurePrior, double numberOfFeatures, double observedCounts, double totalCounts) {
        return (featurePrior + observedCounts) / (numberOfFeatures + totalCounts);
    }

    private boolean EMconverged(double loglik, double previousLoglik, double threshold, boolean checkIncreased) {
        double avgLoglik;
        double deltaLoglik;
        double epsilon = 2.2204E-16;
        boolean converged = false;
        if (checkIncreased && loglik - previousLoglik < -0.001) {
            System.out.println("******likelihood decreased from " + previousLoglik + " to " + loglik);
        }
        if ((deltaLoglik = Math.abs(loglik - previousLoglik)) / (avgLoglik = (Math.abs(loglik) + Math.abs(previousLoglik) + epsilon) / 2.0) < threshold) {
            converged = true;
        }
        return converged;
    }

    public static void main(String[] args) {
        Dataset dataset = new BasicDataset();
        dataset = SampleDatasets.sampleData("bayesUnlabeled", false);
        System.out.println("SampleDatasets (bayesUnlabeled):\n" + dataset);
    }
}

