/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify.transform;

import edu.cmu.minorthird.classify.BasicFeatureIndex;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.transform.AbstractInstanceTransform;
import edu.cmu.minorthird.classify.transform.InstanceTransform;
import edu.cmu.minorthird.classify.transform.InstanceTransformLearner;
import edu.cmu.minorthird.classify.transform.MaskedInstance;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;

public class InfoGainTransformLearner2
implements InstanceTransformLearner {
    private String frequencyModel;
    private int numFeatures;
    private ExampleSchema schema;

    public InfoGainTransformLearner2() {
        this(100, "document");
    }

    public InfoGainTransformLearner2(int numFeatures) {
        this(numFeatures, "document");
    }

    public InfoGainTransformLearner2(int numFeatures, String frequencyModel) {
        this.frequencyModel = frequencyModel;
        this.numFeatures = numFeatures;
    }

    public void setSchema(ExampleSchema schema) {
    }

    public InstanceTransform batchTrain(Dataset dataset) {
        this.schema = dataset.getSchema();
        int N = this.schema.getNumberOfClasses();
        BasicFeatureIndex index = new BasicFeatureIndex(dataset);
        ArrayList<IGPair> igValues = new ArrayList<IGPair>();
        if (this.frequencyModel.equals("document")) {
            double[] classCnt = new double[N];
            double totalCnt = 0.0;
            for (int c = 0; c < N; ++c) {
                classCnt[c] = index.size(this.schema.getClassName(c));
                totalCnt += classCnt[c];
            }
            double totalEntropy = this.Entropy(classCnt, totalCnt);
            Iterator<Feature> i = index.featureIterator();
            while (i.hasNext()) {
                Feature f = i.next();
                double[] featureCntWithF = new double[N];
                double[] featureCntWithoutF = new double[N];
                double totalCntWithF = 0.0;
                double totalCntWithoutF = 0.0;
                for (int c = 0; c < N; ++c) {
                    featureCntWithF[c] = index.size(f, this.schema.getClassName(c));
                    featureCntWithoutF[c] = classCnt[c] - featureCntWithF[c];
                    totalCntWithF += featureCntWithF[c];
                    totalCntWithoutF += featureCntWithoutF[c];
                }
                double entropyWithF = this.Entropy(featureCntWithF, totalCntWithF);
                double entropyWithoutF = this.Entropy(featureCntWithoutF, totalCntWithoutF);
                double wf = totalCntWithF / totalCnt;
                double infoGain = totalEntropy - wf * entropyWithF - (1.0 - wf) * entropyWithoutF;
                igValues.add(new IGPair(infoGain, f));
            }
        } else {
            if (this.frequencyModel.equals("word")) {
                throw new UnsupportedOperationException("not implemented");
            }
            System.out.println("warning: " + this.frequencyModel + " is an unknown model for frequency!");
            System.exit(1);
        }
        Collections.sort(igValues);
        final HashSet<Feature> activeFeatureSet = new HashSet<Feature>();
        for (int i = 0; i < this.numFeatures; ++i) {
            activeFeatureSet.add(((IGPair)igValues.get((int)i)).feature);
        }
        return new AbstractInstanceTransform(){

            public Instance transform(Instance instance) {
                return new MaskedInstance(instance, activeFeatureSet);
            }

            public String toString() {
                return "[InstanceTransform: model = " + InfoGainTransformLearner2.this.frequencyModel + ", top " + InfoGainTransformLearner2.this.numFeatures + " by InfoGain]";
            }
        };
    }

    public double entropy(double P1, double P2) {
        double entropy = P1 == 0.0 | P2 == 0.0 ? 0.0 : -P1 * Math.log(P1) / Math.log(2.0) - P2 * Math.log(P2) / Math.log(2.0);
        return entropy;
    }

    public double Entropy(double[] p, double tot) {
        double entropy = 0.0;
        for (int i = 0; i < p.length; ++i) {
            if (!(p[i] > 0.0)) continue;
            entropy += -p[i] / tot * Math.log(p[i] / tot) / Math.log(2.0);
        }
        return entropy;
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    private class IGPair
    implements Comparable<IGPair> {
        double value;
        Feature feature;

        public IGPair(double v, Feature f) {
            this.value = v;
            this.feature = f;
        }

        @Override
        public int compareTo(IGPair ig2) {
            if (this.value < ig2.value) {
                return 1;
            }
            if (this.value > ig2.value) {
                return -1;
            }
            return this.feature.compareTo(ig2.feature);
        }

        public String toString() {
            return "[ " + this.value + "," + this.feature + " ]";
        }
    }
}

