/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify.transform;

import edu.cmu.minorthird.classify.BasicFeatureIndex;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.SampleDatasets;
import edu.cmu.minorthird.classify.transform.ChiSquareInstanceTransform;
import edu.cmu.minorthird.classify.transform.ContingencyTable;
import edu.cmu.minorthird.classify.transform.InstanceTransform;
import edu.cmu.minorthird.classify.transform.InstanceTransformLearner;
import java.util.Iterator;

/**
 * Learns a {@link ChiSquareInstanceTransform} from a binary-labelled dataset by
 * scoring every feature in the dataset with a chi-squared statistic.
 *
 * <p>The learner is configured with a "frequency model" name. Only the
 * {@code "document"} model is implemented; {@code "word"} is recognised but not
 * implemented, and any other value is rejected when {@link #batchTrain} is called.
 */
public class ChiSquareTransformLearner
implements InstanceTransformLearner {
    /** Frequency model under which features are counted per example. */
    private static final String DOCUMENT_MODEL = "document";
    /** Frequency model name that is recognised but has no implementation yet. */
    private static final String WORD_MODEL = "word";

    /** Name of the frequency model used by {@link #batchTrain}; may be null if the caller passed null. */
    private final String frequencyModel;

    /** Creates a learner that uses the {@code "document"} frequency model. */
    public ChiSquareTransformLearner() {
        this.frequencyModel = DOCUMENT_MODEL;
    }

    /**
     * Creates a learner that uses the given frequency model.
     *
     * @param model frequency model name; it is not validated here, an unsupported
     *              value is reported when {@link #batchTrain} is called
     */
    public ChiSquareTransformLearner(String model) {
        this.frequencyModel = model;
    }

    /**
     * Checks that the schema describes binary examples.
     *
     * @param schema schema of the data that will be learned from
     * @throws IllegalStateException if {@code schema} is not equal to
     *         {@link ExampleSchema#BINARY_EXAMPLE_SCHEMA}
     */
    public void setSchema(ExampleSchema schema) {
        if (!ExampleSchema.BINARY_EXAMPLE_SCHEMA.equals(schema)) {
            throw new IllegalStateException("can only learn binary example data");
        }
    }

    /**
     * Builds a chi-squared transform by scoring every feature of the dataset.
     *
     * <p>For each feature a 2x2 {@link ContingencyTable} is filled from the feature
     * index counts for the {@code "POS"} and {@code "NEG"} labels, and the resulting
     * chi-squared value is registered with the transform.
     *
     * @param dataset the examples to learn from
     * @return a {@link ChiSquareInstanceTransform} holding one score per feature
     * @throws UnsupportedOperationException if the frequency model is {@code "word"},
     *         which is not implemented
     * @throws IllegalArgumentException if the frequency model is unknown (including null)
     * @throws IllegalStateException if the POS and NEG counts in the index do not add
     *         up to the dataset size
     */
    public InstanceTransform batchTrain(Dataset dataset) {
        // Reject unsupported models up front with an exception rather than
        // terminating the JVM, so callers can handle the failure themselves.
        if (WORD_MODEL.equals(this.frequencyModel)) {
            throw new UnsupportedOperationException(this.frequencyModel + " not implemented yet!");
        }
        if (!DOCUMENT_MODEL.equals(this.frequencyModel)) {
            throw new IllegalArgumentException(this.frequencyModel + " is an unknown model for frequency!");
        }

        ChiSquareInstanceTransform filter = new ChiSquareInstanceTransform();
        BasicFeatureIndex index = new BasicFeatureIndex(dataset);

        int totalPos = index.size("POS");
        int totalNeg = index.size("NEG");
        // Sanity check: the two label counts must account for the whole dataset.
        if (totalPos + totalNeg != dataset.size()) {
            throw new IllegalStateException("ERROR - Dataset size and index size do not match");
        }

        Iterator<Feature> features = index.featureIterator();
        while (features.hasNext()) {
            Feature f = features.next();
            // a/b: index counts for this feature under POS/NEG
            // (presumably the number of examples containing it -- confirm against BasicFeatureIndex);
            // c/d: the remainder of each label total.
            int a = index.size(f, "POS");
            int b = index.size(f, "NEG");
            int c = totalPos - a;
            int d = totalNeg - b;
            ContingencyTable ct = new ContingencyTable(a, b, c, d);
            filter.addFeature(ct.getChiSquared(), f);
        }
        return filter;
    }

    /**
     * Small demo: trains on the "toy" sample dataset, sets the number of features
     * on the transform to 10, and prints the dataset before and after the transform
     * together with a summary of the transform.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        Dataset dataset = SampleDatasets.sampleData("toy", false);
        System.out.println("old data:\n" + dataset);
        ChiSquareTransformLearner learner = new ChiSquareTransformLearner();
        ChiSquareInstanceTransform filter = (ChiSquareInstanceTransform)learner.batchTrain(dataset);
        filter.setNumberOfFeatures(10);
        dataset = filter.transform(dataset);
        System.out.println("new data:\n" + dataset);
        System.out.println("\n\n\n " + filter.toString(8));
    }
}

