/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify.algorithms.linear;

import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.OnlineBinaryClassifierLearner;
import edu.cmu.minorthird.classify.algorithms.linear.Hyperplane;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.ListIterator;
import java.util.Set;
import java.util.TreeSet;
import org.apache.log4j.Logger;

public class BBMira
extends OnlineBinaryClassifierLearner {
    private static Logger log = Logger.getLogger(BBMira.class);
    private List<WeightedExample> cache;
    private Hyperplane w_t;
    private double minimalMargin = 1.0;
    private boolean useBudget = true;
    private Set<Feature> usedFeatures;

    public BBMira(boolean useBudget, double minimalMargin) {
        this.useBudget = useBudget;
        this.minimalMargin = minimalMargin;
        this.reset();
    }

    public BBMira() {
        this(true, 1.0);
    }

    public void reset() {
        this.cache = new LinkedList<WeightedExample>();
        this.w_t = new Hyperplane();
        this.usedFeatures = new TreeSet<Feature>();
    }

    public void addExample(Example example) {
        double y = example.getLabel().numericLabel();
        Instance x = example.asInstance();
        Iterator<Feature> i = x.featureIterator();
        while (i.hasNext()) {
            Feature f = i.next();
            if (!this.usedFeatures.add(f)) continue;
            this.w_t.increment(f, 1.0);
        }
        double s = this.w_t.score(x);
        if (log.isDebugEnabled()) {
            log.debug("y=" + y + " s=" + s + " for " + x);
        }
        if (y * s <= this.minimalMargin) {
            double tau_t = this.truncateG(-y * s / BBMira.kernel(x, x));
            log.debug("update: y*s = " + y * s + " ||x||^2 = " + BBMira.kernel(x, x) + " tau_t = " + tau_t);
            if (tau_t != 0.0) {
                this.w_t.increment(example, y * tau_t);
                this.cache.add(new WeightedExample(example, tau_t));
                if (log.isDebugEnabled()) {
                    log.debug("into cache, useBudget=" + this.useBudget + " tau=" + tau_t + " :" + x);
                }
                if (this.useBudget) {
                    this.distillCache();
                }
            }
        }
    }

    public Classifier getClassifier() {
        return this.w_t;
    }

    private static double kernel(Instance x1, Instance x2) {
        double result = 0.0;
        Iterator<Feature> i = x1.featureIterator();
        while (i.hasNext()) {
            Feature f = i.next();
            result += x1.getWeight(f) * x2.getWeight(f);
        }
        return result + 1.0;
    }

    private double truncateG(double z) {
        if (z < 0.0) {
            return 0.0;
        }
        if (z > 1.0) {
            return 1.0;
        }
        return z;
    }

    private void distillCache() {
        log.info("distilling cache, size=" + this.cache.size());
        boolean somethingRemoved = true;
        while (somethingRemoved) {
            somethingRemoved = false;
            ListIterator<WeightedExample> i = this.cache.listIterator();
            while (i.hasNext()) {
                double wxContribution;
                WeightedExample wx = i.next();
                double y = wx.example.getLabel().numericLabel();
                Instance x = wx.example.asInstance();
                double currentPrediction = this.w_t.score(x);
                if (!((currentPrediction - (wxContribution = BBMira.kernel(x, x) * y * wx.alpha)) * y >= this.minimalMargin)) continue;
                i.remove();
                somethingRemoved = true;
                this.w_t.increment(x, -y * wx.alpha);
                log.info("reduced cache to " + this.cache.size() + " entries");
            }
        }
    }

    public String toString() {
        return "[BBMira " + this.useBudget + ";" + this.minimalMargin + "]";
    }

    private static class WeightedExample {
        public Example example;
        public double alpha;

        public WeightedExample(Example example, double alpha) {
            this.example = example;
            this.alpha = alpha;
        }

        public String toString() {
            return "[WX: " + this.example + " alpha=" + this.alpha + "]";
        }
    }
}

