/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify.algorithms.knn;

import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.DatasetIndex;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Explanation;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.util.MathUtil;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeSet;
import org.apache.log4j.Logger;

/**
 * A k-nearest-neighbor {@link Classifier}.
 *
 * <p>Candidate neighbors are obtained from a {@link DatasetIndex} and ranked by
 * cosine similarity to the instance being classified. The {@code k} most similar
 * candidates vote for their best class, each vote weighted by
 * {@code exampleWeight * similarity}. When the total vote weight is zero or NaN,
 * the classifier falls back to class priors built from {@code index.size(className)}
 * for every class in the schema.
 */
class KnnClassifier
implements Classifier,
Serializable {
    private static final long serialVersionUID = 1L;
    private static final Logger log = Logger.getLogger(KnnClassifier.class);
    // Sampled once when the class is loaded; later log-level changes are not seen here.
    private static final boolean DEBUG = log.isDebugEnabled();
    private final DatasetIndex index;
    private final ExampleSchema schema;
    private final int k;

    /**
     * @param index  source of candidate neighbors (and of per-class sizes for the prior fallback)
     * @param schema class names used when falling back to priors
     * @param k      maximum number of nearest neighbors that vote
     */
    public KnnClassifier(DatasetIndex index, ExampleSchema schema, int k) {
        this.index = index;
        this.schema = schema;
        this.k = k;
        if (DEBUG) {
            log.debug("knn classifier for index:\n" + index);
        }
    }

    /**
     * Classifies an instance by a similarity-weighted vote of its k nearest neighbors.
     *
     * <p>Each class score is the smoothed log-odds
     * {@code log(d/tot + 0.001) - log((tot-d)/tot + 0.001)}, where {@code d} is the
     * class's vote weight and {@code tot} the total weight. The 0.001 term keeps both
     * logarithms finite when a class gets all or none of the weight.
     */
    public ClassLabel classification(Instance instance) {
        if (DEBUG) {
            log.debug("classifying: " + instance);
        }
        List<Neighbor> neighbors = this.collectNeighbors(instance);
        HashMap<String, Double> classCounts = new HashMap<String, Double>();
        double tot = 0.0;
        // A non-positive k yields no votes and therefore the prior fallback below.
        int limit = Math.min(this.k, neighbors.size());
        for (int idx = 0; idx < limit; ++idx) {
            Neighbor n = neighbors.get(idx);
            String s = n.e.getLabel().bestClassName();
            double w = n.e.getWeight() * n.sim;
            Double prev = classCounts.get(s);
            classCounts.put(s, (prev == null ? 0.0 : prev.doubleValue()) + w);
            tot += w;
            if (DEBUG) {
                log.debug("neighbor: " + n.e + " distance: " + n.sim + " weight: " + w + " count[" + s + "]: " + classCounts.get(s));
            }
        }
        if (tot == 0.0 || Double.isNaN(tot)) {
            if (Double.isNaN(tot)) {
                log.warn("total similarity to neighbors is not defined for: " + instance);
            }
            // Drop any neighbor votes (possibly NaN, possibly for labels outside the
            // schema) so the result is built from priors only.
            classCounts.clear();
            tot = 0.0;
            for (int c = 0; c < this.schema.getNumberOfClasses(); ++c) {
                String s = this.schema.getClassName(c);
                double d = this.index.size(s);
                classCounts.put(s, d);
                tot += d;
            }
        }
        ClassLabel result = new ClassLabel();
        for (Map.Entry<String, Double> entry : classCounts.entrySet()) {
            double d = entry.getValue().doubleValue();
            double score;
            if (tot == 0.0) {
                // No evidence at all (e.g. an empty index): give every class the same
                // finite score instead of the NaN that 0/0 would produce.
                score = Math.log(0.001) - Math.log(1.0 + 0.001);
            } else {
                score = Math.log(d / tot + 0.001) - Math.log((tot - d) / tot + 0.001);
            }
            result.add(entry.getKey(), score);
        }
        return result;
    }

    /**
     * Scores every candidate neighbor the index returns and sorts them, most similar first.
     *
     * <p>A list with a stable sort is used instead of a {@code TreeSet}. A sorted set
     * treats neighbors with equal similarity as duplicates and keeps only one of them,
     * which would silently remove tied neighbors from the vote.
     */
    private List<Neighbor> collectNeighbors(Instance instance) {
        List<Neighbor> neighbors = new ArrayList<Neighbor>();
        Iterator<Example> it = this.index.getNeighbors(instance);
        while (it.hasNext()) {
            Example e = it.next();
            neighbors.add(new Neighbor(e, this.computeSimilarity(instance, e)));
        }
        // Collections.sort is stable, so ties keep the index's iteration order.
        Collections.sort(neighbors);
        return neighbors;
    }

    public String explain(Instance instance) {
        return "not implemented";
    }

    public Explanation getExplanation(Instance instance) {
        Explanation ex = new Explanation(this.explain(instance));
        return ex;
    }

    /**
     * Cosine similarity between the feature vectors of {@code a} and {@code b}.
     *
     * <p>The dot product iterates over {@code a}'s features only; {@code b}'s norm
     * iterates over {@code b}'s features. Returns 0 when either vector has zero norm,
     * rather than the NaN that 0/0 would give.
     */
    private double computeSimilarity(Instance a, Instance b) {
        double aNorm = 0.0;
        double dotProd = 0.0;
        Iterator<Feature> ia = a.featureIterator();
        while (ia.hasNext()) {
            Feature f = ia.next();
            double aw = a.getWeight(f);
            aNorm += aw * aw;
            dotProd += aw * b.getWeight(f);
        }
        double bNorm = 0.0;
        Iterator<Feature> ib = b.featureIterator();
        while (ib.hasNext()) {
            double bw = b.getWeight(ib.next());
            bNorm += bw * bw;
        }
        double denom = Math.sqrt(aNorm) * Math.sqrt(bNorm);
        return denom == 0.0 ? 0.0 : dotProd / denom;
    }

    /**
     * A candidate neighbor paired with its similarity to the instance being classified.
     *
     * <p>The natural ordering is descending similarity. It is not consistent with
     * {@code equals}, so do not put these in sorted sets or maps.
     */
    private static class Neighbor
    implements Comparable<Neighbor> {
        final Example e;
        final double sim;

        public Neighbor(Example e, double sim) {
            this.e = e;
            this.sim = sim;
        }

        @Override
        public int compareTo(Neighbor n) {
            // Arguments are swapped to sort in descending order. Double.compare is a
            // total order (NaN included), as the sort's comparison contract requires.
            return Double.compare(n.sim, this.sim);
        }
    }
}

