/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify.algorithms.linear;

import edu.cmu.minorthird.classify.BasicFeatureIndex;
import edu.cmu.minorthird.classify.BatchBinaryClassifierLearner;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.algorithms.linear.NegativeBinomialClassifier;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.SortedMap;
import java.util.TreeMap;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Batch learner that trains a {@link NegativeBinomialClassifier} on a binary dataset whose
 * examples are labeled {@code "POS"} or {@code "NEG"}.
 *
 * <p>For every feature in the dataset's {@link BasicFeatureIndex}, and separately for each
 * class, it estimates a mean ({@code "mu"}) and an overdispersion parameter ({@code "delta"})
 * by moment matching. Each example is weighted by the sum of its feature weights divided by
 * the configured scale.
 */
public class NegativeBinomialLearner
extends BatchBinaryClassifierLearner {
    /** Scale used by the no-argument constructor. */
    private static final double DEFAULT_SCALE = 10.0;

    /** Divisor applied to each example's total feature weight; also passed to the classifier. */
    private final double scale;

    /** Creates a learner with the default scale of 10. */
    public NegativeBinomialLearner() {
        this(DEFAULT_SCALE);
    }

    /**
     * Creates a learner with the given scale.
     *
     * @param scale divisor applied to each example's total feature weight; it is also handed
     *     to the trained classifier via {@code setScale}
     */
    public NegativeBinomialLearner(double scale) {
        this.scale = scale;
        this.reset();
    }

    /**
     * Trains a negative-binomial classifier on a binary dataset.
     *
     * <p>Class priors are set from the summed (unscaled) feature weights of each class.
     * Per-feature {@code mu}/{@code delta} parameters are then estimated separately for the
     * POS and NEG examples.
     *
     * @param data the training examples; each must have best class name "POS" or "NEG"
     * @return the trained {@link NegativeBinomialClassifier}
     * @throws IllegalArgumentException if an example is labeled neither "POS" nor "NEG"
     */
    @Override
    public Classifier batchTrain(Dataset data) {
        BasicFeatureIndex index = new BasicFeatureIndex(data);
        NegativeBinomialClassifier c = new NegativeBinomialClassifier();
        c.setScale(this.scale);

        // Split the examples by class once, preserving dataset order, so the per-feature
        // pass below neither re-inspects labels nor depends on separately computed counts.
        List<Example> posExamples = new ArrayList<Example>();
        List<Example> negExamples = new ArrayList<Example>();
        for (Iterator<Example> it = data.iterator(); it.hasNext();) {
            Example e = it.next();
            String label = e.getLabel().bestClassName();
            if ("POS".equals(label)) {
                posExamples.add(e);
            } else if ("NEG".equals(label)) {
                negExamples.add(e);
            } else {
                // Previously System.exit(1): a library must never terminate the host JVM.
                throw new IllegalArgumentException("no POS/NEG class found for example: " + e);
            }
        }

        double[] wgtPos = new double[posExamples.size()];
        double[] wgtNeg = new double[negExamples.size()];
        double numPos = fillScaledWeights(posExamples, wgtPos);
        double numNeg = fillScaledWeights(negExamples, wgtNeg);

        double featurePrior = 1.0 / (double) index.numberOfFeatures();
        c.setPriorPos(numPos, numPos + numNeg, 0.5, 1.0);
        c.setPriorNeg(numNeg, numPos + numNeg, 0.5, 1.0);

        // Reused per feature; estimateNegBinMOME only reads them.
        double[] vPos = new double[posExamples.size()];
        double[] vNeg = new double[negExamples.size()];
        for (Iterator<Feature> fi = index.featureIterator(); fi.hasNext();) {
            Feature f = fi.next();
            fillFeatureWeights(posExamples, f, vPos);
            fillFeatureWeights(negExamples, f, vNeg);
            SortedMap<String, Double> mudeltaNeg = estimateNegBinMOME(vNeg, wgtNeg, featurePrior);
            SortedMap<String, Double> mudeltaPos = estimateNegBinMOME(vPos, wgtPos, featurePrior);
            c.setPmsNeg(f, mudeltaNeg);
            c.setPmsPos(f, mudeltaPos);
        }
        return c;
    }

    /**
     * Stores each example's total feature weight divided by the scale into {@code out}.
     *
     * @param examples examples of one class, in dataset order
     * @param out destination array, same length as {@code examples}
     * @return the unscaled sum of all examples' total feature weights
     */
    private double fillScaledWeights(List<Example> examples, double[] out) {
        double sum = 0.0;
        for (int i = 0; i < out.length; i++) {
            double total = totalWeight(examples.get(i));
            out[i] = total / this.scale;
            sum += total;
        }
        return sum;
    }

    /** Returns the sum of the weights of all features present in {@code e}. */
    private static double totalWeight(Example e) {
        double total = 0.0;
        for (Iterator<Feature> it = e.featureIterator(); it.hasNext();) {
            total += e.getWeight(it.next());
        }
        return total;
    }

    /** Stores the weight of feature {@code f} in each example into {@code out} (same order). */
    private static void fillFeatureWeights(List<Example> examples, Feature f, double[] out) {
        for (int i = 0; i < out.length; i++) {
            out[i] = examples.get(i).getWeight(f);
        }
    }

    /**
     * Estimates negative-binomial parameters for one feature within one class.
     *
     * <p>{@code mu} is {@code (sum(x) + prior/scale) / (sum(w) + 1/scale)}, the weighted rate
     * smoothed toward {@code prior} by a pseudo-example of weight {@code 1/scale}.
     * {@code delta} is {@code max(0, (v - mu) / (r * mu))}, where {@code v} is the weighted
     * sample variance of the per-example rates {@code x_i / w_i} and
     * {@code r = (sum(w) - sum(w^2)/sum(w)) / (N - 1)}. {@code v} and {@code r} are 0 when
     * there is at most one example. A NaN result is replaced by {@code 1e-7}.
     *
     * @param vCnt weight of the feature in each example of the class
     * @param vWgt scaled total weight of each example, parallel to {@code vCnt}
     * @param prior prior rate for the feature
     * @return a map with keys {@code "mu"} and {@code "delta"}
     */
    private SortedMap<String, Double> estimateNegBinMOME(double[] vCnt, double[] vWgt, double prior) {
        int n = vCnt.length;
        double sumX = 0.0;
        double sumWgt = 0.0;
        double sumWgt2 = 0.0;
        for (int i = 0; i < n; ++i) {
            sumX += vCnt[i];
            sumWgt += vWgt[i];
            sumWgt2 += vWgt[i] * vWgt[i];
        }
        double m = (sumX + prior / this.scale) / (sumWgt + 1.0 / this.scale);

        double r = 0.0;
        double v = 0.0;
        if (n > 1) {
            r = (sumWgt - sumWgt2 / sumWgt) / (n - 1.0);
            for (int i = 0; i < n; ++i) {
                // An example with zero total weight (e.g. no features) carries no weight in
                // the weighted variance; evaluating x/w for it yields 0/0 = NaN, which used to
                // poison v and force delta to the 1e-7 fallback for every feature.
                if (vWgt[i] == 0.0) {
                    continue;
                }
                double dev = vCnt[i] / vWgt[i] - m;
                v += vWgt[i] * dev * dev / (n - 1.0);
            }
        }

        double d = Math.max(0.0, (v - m) / (r * m));
        if (Double.isNaN(d)) {
            // Math.max propagates NaN (e.g. 0/0 when r * m == 0 and v == m).
            d = 1.0E-7;
        }
        SortedMap<String, Double> mudelta = new TreeMap<String, Double>();
        mudelta.put("mu", m);
        mudelta.put("delta", d);
        return mudelta;
    }
}

