/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify.algorithms.linear;

import edu.cmu.minorthird.classify.BasicFeatureIndex;
import edu.cmu.minorthird.classify.BatchBinaryClassifierLearner;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.algorithms.linear.PoissonClassifier;
import java.util.Iterator;

/**
 * Batch learner that trains a {@link PoissonClassifier} for binary classification.
 *
 * <p>Training builds a {@link BasicFeatureIndex} over the dataset, totals the feature weight
 * observed in positive and in negative examples, then gives every feature a smoothed linear
 * weight and a smoothed log weight on the classifier, and finally sets a bias derived from the
 * two per-class totals.
 */
public class PoissonLearner extends BatchBinaryClassifierLearner {
    /** Scale used by the no-argument constructor. */
    private static final double DEFAULT_SCALE = 10.0;

    /** Scale handed to each trained classifier via {@code setScale}. */
    private final double scale;

    /** Creates a learner using the default scale of 10.0. */
    public PoissonLearner() {
        this(DEFAULT_SCALE);
    }

    /**
     * Creates a learner using the given scale.
     *
     * @param scale value passed to the trained classifier's {@code setScale}; the reciprocal of
     *     the classifier's scale is used as the smoothing pseudo-count during training.
     *     NOTE(review): not validated here — a non-positive value presumably yields
     *     meaningless (infinite/NaN) weights; confirm the intended range.
     */
    public PoissonLearner(double scale) {
        this.scale = scale;
        // Inherited hook; invoked from the constructor exactly as before.
        this.reset();
    }

    /**
     * Trains a Poisson classifier on {@code data}.
     *
     * @param data training examples; each example's label decides whether its feature weights
     *     are counted toward the positive or the negative totals
     * @return a {@link PoissonClassifier} configured with this learner's scale
     */
    public Classifier batchTrain(Dataset data) {
        BasicFeatureIndex index = new BasicFeatureIndex(data);
        PoissonClassifier c = new PoissonClassifier();
        c.setScale(this.scale);
        // Read the scale back once: it does not change for the rest of training.
        double classifierScale = c.getScale();
        double pseudoCounts = 1.0 / classifierScale;

        // Total feature weight ("mass") seen in positive and in negative examples: the sum of
        // ex.getWeight(f) over every (feature, example) pair the index reports.
        double posMass = 0.0;
        double negMass = 0.0;
        Iterator<Feature> massIt = index.featureIterator();
        while (massIt.hasNext()) {
            Feature f = massIt.next();
            for (int j = 0; j < index.size(f); ++j) {
                Example ex = index.getExample(f, j);
                if (ex.getLabel().isPositive()) {
                    posMass += ex.getWeight(f);
                } else {
                    negMass += ex.getWeight(f);
                }
            }
        }

        // Uniform prior over features. If the index has no features this is infinite, but the
        // per-feature loop below then never runs, so it is never used.
        double featurePrior = 1.0 / (double) index.numberOfFeatures();
        double posTotal = posMass / classifierScale;
        double negTotal = negMass / classifierScale;

        Iterator<Feature> weightIt = index.featureIterator();
        while (weightIt.hasNext()) {
            Feature f = weightIt.next();
            double ngp = index.getCounts(f, "POS");
            double ngn = index.getCounts(f, "NEG");

            // Linear term: smoothed negative estimate minus smoothed positive estimate.
            double pweight = estimatedProb(ngp, posTotal, featurePrior, pseudoCounts);
            double nweight = estimatedProb(ngn, negTotal, featurePrior, pseudoCounts);
            c.increment(f, -pweight + nweight);

            // Log term: log positive estimate minus log negative estimate. The trailing `true`
            // presumably selects the classifier's log-weight vector — confirm in
            // PoissonClassifier.
            double logPos = logEstimatedProb(ngp, posTotal, featurePrior, pseudoCounts);
            double logNeg = logEstimatedProb(ngn, negTotal, featurePrior, pseudoCounts);
            c.increment(f, logPos - logNeg, true);
        }

        // Bias: log of the smoothed positive share of total mass minus log of the smoothed
        // negative share (prior 0.5, one pseudo-count).
        // NOTE(review): the shares come from summed feature weight, not example counts —
        // confirm that this is the intended class prior.
        double totalMass = posMass + negMass;
        c.incrementBias(logEstimatedProb(posMass, totalMass, 0.5, 1.0));
        c.incrementBias(-logEstimatedProb(negMass, totalMass, 0.5, 1.0));
        return c;
    }

    /**
     * Returns the smoothed ratio {@code (k + prior * pseudoCounts) / (n + pseudoCounts)}.
     *
     * @param k observed count for the event
     * @param n observed total
     * @param prior prior probability mixed in with weight {@code pseudoCounts}
     * @param pseudoCounts smoothing strength
     * @return the smoothed estimate
     */
    private static double estimatedProb(double k, double n, double prior, double pseudoCounts) {
        return (k + prior * pseudoCounts) / (n + pseudoCounts);
    }

    /**
     * Returns the natural logarithm of {@link #estimatedProb(double, double, double, double)}.
     *
     * <p>Replaces an overload that took an ignored {@code boolean log} flag and always returned
     * the logarithm.
     */
    private static double logEstimatedProb(double k, double n, double prior, double pseudoCounts) {
        return Math.log(estimatedProb(k, n, prior, pseudoCounts));
    }
}

