/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify.ranking;

import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.algorithms.linear.Hyperplane;
import edu.cmu.minorthird.classify.ranking.BatchRankingLearner;
import edu.cmu.minorthird.util.ProgressCounter;
import edu.cmu.minorthird.util.gui.ViewerFrame;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Ranking learner that boosts a linear {@link Hyperplane} over binary features.
 *
 * <p>Each ranking produced by {@code splitIntoRankings} is reordered so that examples whose
 * label string ends with {@code "POS 1.0]"} come first. Only the first {@code exampleSize}
 * examples are kept. Position 0 is treated as the correct candidate; every other position
 * {@code j} is a competitor.
 *
 * <p>Training proceeds in three steps:
 * <ol>
 *   <li>Start from a hyperplane whose only weight is on the {@code "walkerScore"} feature,
 *       chosen by grid search.</li>
 *   <li>Run {@code numEpochs} rounds. Each round increments the weight of at most one binary
 *       feature.</li>
 *   <li>Return the resulting hyperplane.</li>
 * </ol>
 *
 * <p>NOTE(review): the round update picks the feature maximizing
 * {@code |sqrt(W+) - sqrt(W-)|} and steps by {@code 0.5 * ln((W+ + eps*Z) / (W- + eps*Z))}.
 * This appears to follow the Collins &amp; Koo boosting reranker; confirm before relying on
 * that correspondence.
 *
 * <p>Training state is kept in mutable instance fields, so one instance should not be
 * trained from several threads at once.
 */
public class RankingBoosted
extends BatchRankingLearner {
    /** Number of boosting rounds run by {@link #batchTrain(Dataset)}. */
    private int numEpochs;
    /** Number of examples kept from each ranking; position 0 is the correct one. */
    private int exampleSize = 20;
    /**
     * Feature -> (ranking, position) pairs where the feature is present on the correct
     * example but absent from the competitor.
     */
    private final Map<Feature, Set<Index>> A_pos = new HashMap<Feature, Set<Index>>();
    /**
     * Feature -> (ranking, position) pairs where the feature is present on the competitor
     * but absent from the correct example.
     */
    private final Map<Feature, Set<Index>> A_neg = new HashMap<Feature, Set<Index>>();
    /** Every binary feature seen on a correct example or a competitor; the candidates per round. */
    private final Set<Feature> features = new HashSet<Feature>();
    /** Smoothing factor; times the total exponential loss, it is added to both sides of the step ratio. */
    private final double SMOOTH_PARAM = 0.005;
    /**
     * margins[i][j]: margin of the correct example of ranking i over example j.
     * It is initialised from log walkerScore values and moved by +/- delta whenever a
     * feature listed in A_pos / A_neg is boosted.
     */
    private double[][] margins;
    /** Real-valued feature whose logarithm seeds the initial hyperplane weight. */
    private final Feature score = new Feature("walkerScore");

    /** Creates a learner with 500 boosting rounds and 20 examples per ranking. */
    public RankingBoosted() {
        this(500, 20);
    }

    /**
     * @param numEpochs   number of boosting rounds
     * @param exampleSize number of examples taken from each ranking; every ranking must
     *                    contain at least this many examples
     */
    public RankingBoosted(int numEpochs, int exampleSize) {
        this.numEpochs = numEpochs;
        this.exampleSize = exampleSize;
    }

    /**
     * Trains a hyperplane on the rankings contained in {@code data}.
     *
     * @param data dataset that is split into rankings by {@code splitIntoRankings}
     * @return the boosted hyperplane
     * @throws IllegalArgumentException if some ranking has fewer than {@code exampleSize}
     *                                  examples
     */
    @Override
    public Classifier batchTrain(Dataset data) {
        // Index sets from an earlier call point into the old dataset's rankings; they
        // must not be applied to the margins array built for this dataset.
        this.A_pos.clear();
        this.A_neg.clear();
        this.features.clear();

        Map<String, List<Example>> rankingMap = RankingBoosted.splitIntoRankings(data);
        Example[][] rankedExamples = new Example[rankingMap.size()][this.exampleSize];
        int index = 0;
        for (Map.Entry<String, List<Example>> entry : rankingMap.entrySet()) {
            List<Example> ranking = this.orderExamplesList(entry.getValue());
            if (ranking.size() < this.exampleSize) {
                throw new IllegalArgumentException("ranking '" + entry.getKey() + "' has "
                        + ranking.size() + " examples but exampleSize is " + this.exampleSize);
            }
            for (int j = 0; j < this.exampleSize; ++j) {
                rankedExamples[index][j] = ranking.get(j);
            }
            ++index;
        }
        Hyperplane s = this.populate_A(rankedExamples, new Hyperplane());
        s.increment(this.score, this.best_w0(rankedExamples));
        this.margins = this.initializeMargins(rankedExamples, s);
        ProgressCounter pc = new ProgressCounter("boosted perceptron training", "epoch", this.numEpochs);
        for (int e = 0; e < this.numEpochs; ++e) {
            s = this.batchTrain(s);
            pc.progress();
        }
        pc.finished();
        return s;
    }

    /**
     * Fills {@link #A_pos}, {@link #A_neg} and {@link #features} by comparing, for every
     * ranking i, the binary features of the correct example (position 0) with those of each
     * competitor j >= 1.
     *
     * @return {@code s} after {@code s.multiply(0.0)}
     */
    private Hyperplane populate_A(Example[][] rankedExamples, Hyperplane s) {
        for (int i = 0; i < rankedExamples.length; ++i) {
            Set<Feature> correctFtrs = binaryFeatures(rankedExamples[i][0]);
            for (int j = 1; j < this.exampleSize; ++j) {
                Set<Feature> actualFtrs = binaryFeatures(rankedExamples[i][j]);
                for (Feature ftr : actualFtrs) {
                    if (!correctFtrs.contains(ftr)) {
                        this.update_A(this.A_neg, ftr, i, j);
                    }
                    this.features.add(ftr);
                }
                for (Feature ftr : correctFtrs) {
                    if (!actualFtrs.contains(ftr)) {
                        this.update_A(this.A_pos, ftr, i, j);
                    }
                    this.features.add(ftr);
                }
            }
        }
        // Presumably zeroes every weight so training starts from an empty hyperplane -- TODO confirm
        // against Hyperplane.multiply.
        s.multiply(0.0);
        return s;
    }

    /** Collects the binary features of {@code ex} into a set. */
    private static Set<Feature> binaryFeatures(Example ex) {
        Set<Feature> ftrs = new HashSet<Feature>();
        for (Iterator<Feature> it = ex.binaryFeatureIterator(); it.hasNext();) {
            ftrs.add(it.next());
        }
        return ftrs;
    }

    /** Records the pair (i, j) under {@code ftr} in {@code map}, creating the set on first use. */
    private Map<Feature, Set<Index>> update_A(Map<Feature, Set<Index>> map, Feature ftr, int i, int j) {
        Set<Index> set = map.get(ftr);
        if (set == null) {
            set = new HashSet<Index>();
            map.put(ftr, set);
        }
        set.add(new Index(i, j));
        return map;
    }

    /**
     * Grid-searches the initial walkerScore weight over 0.001, 0.002, ..., 9.999 and returns
     * the value with the smallest {@link #initialExpLoss}. If no loss is below 1.0E8, the
     * result is 0.001.
     */
    private double best_w0(Example[][] rankedExamples) {
        final double step = 0.001;
        double w0 = step;
        double minExpLoss = 1.0E8;
        // An integer counter avoids the rounding drift of repeatedly adding 0.001 to a double.
        for (int k = 1; k < 10000; ++k) {
            double w = k * step;
            double expLoss = this.initialExpLoss(w, rankedExamples);
            if (expLoss < minExpLoss) {
                w0 = w;
                minExpLoss = expLoss;
            }
        }
        return w0;
    }

    /**
     * Exponential loss of the walkerScore-only model with weight {@code w0}.
     *
     * <p>For every example j of ranking i whose label string ends with {@code "NEG 1.0]"}, the
     * loss adds {@code exp(-w0 * (ln score(correct_i) - ln score(example_ij)))}.
     */
    public double initialExpLoss(double w0, Example[][] rankedExamples) {
        double expLoss = 0.0;
        for (int i = 0; i < rankedExamples.length; ++i) {
            double correctLogScore = Math.log(rankedExamples[i][0].getWeight(this.score));
            for (int j = 0; j < this.exampleSize; ++j) {
                Example ex = rankedExamples[i][j];
                if (!isNegative(ex)) {
                    continue;
                }
                expLoss += Math.exp(-w0 * (correctLogScore - Math.log(ex.getWeight(this.score))));
            }
        }
        return expLoss;
    }

    /** Sums {@code exp(-margin)} over every entry of {@code marginTable}, including position 0. */
    private double expLoss(double[][] marginTable) {
        double expLoss = 0.0;
        for (int i = 0; i < marginTable.length; ++i) {
            for (int j = 0; j < this.exampleSize; ++j) {
                expLoss += Math.exp(-marginTable[i][j]);
            }
        }
        return expLoss;
    }

    /**
     * Builds the initial margin table: the walkerScore weight of {@code s} times
     * {@code ln score(correct_i) - ln score(example_ij)}.
     */
    private double[][] initializeMargins(Example[][] rankedExamples, Hyperplane s) {
        double w0 = s.featureScore(this.score);
        double[][] result = new double[rankedExamples.length][this.exampleSize];
        for (int i = 0; i < result.length; ++i) {
            double correctLogScore = Math.log(rankedExamples[i][0].getWeight(this.score));
            for (int j = 0; j < this.exampleSize; ++j) {
                result[i][j] = w0 * (correctLogScore - Math.log(rankedExamples[i][j].getWeight(this.score)));
            }
        }
        return result;
    }

    /**
     * Runs one boosting round on {@code s}.
     *
     * <p>The round picks the feature with the largest {@code |sqrt(W+) - sqrt(W-)|}, where
     * W+/W- sum {@code exp(-margin)} over the feature's A_pos/A_neg pairs. It then increments
     * that feature's weight by the smoothed step and updates the margins to match. Nothing
     * changes when no feature has a positive gain.
     */
    private Hyperplane batchTrain(Hyperplane s) {
        Feature bestFeature = null;
        double maxGain = 0.0;
        double W_Pos = 0.0;
        double W_Neg = 0.0;
        for (Feature ftr : this.features) {
            double cur_W_Pos = this.sumExpNegMargins(this.A_pos.get(ftr));
            double cur_W_Neg = this.sumExpNegMargins(this.A_neg.get(ftr));
            double gain = Math.abs(Math.sqrt(cur_W_Pos) - Math.sqrt(cur_W_Neg));
            if (gain > maxGain) {
                maxGain = gain;
                bestFeature = ftr;
                W_Pos = cur_W_Pos;
                W_Neg = cur_W_Neg;
            }
        }
        if (bestFeature != null) {
            double Z = this.expLoss(this.margins);
            // The smoothing term keeps the ratio finite when W_Pos or W_Neg is zero.
            double delta = 0.5 * Math.log((W_Pos + this.SMOOTH_PARAM * Z) / (W_Neg + this.SMOOTH_PARAM * Z));
            this.updateMargins(bestFeature, delta);
            s.increment(bestFeature, delta);
        }
        return s;
    }

    /** Sums {@code exp(-margins[i][j])} over {@code indices}; a null set sums to 0. */
    private double sumExpNegMargins(Set<Index> indices) {
        double sum = 0.0;
        if (indices != null) {
            for (Index ix : indices) {
                sum += Math.exp(-this.margins[ix.i][ix.j]);
            }
        }
        return sum;
    }

    /** Adds {@code delta} to the margins listed under {@code feature} in A_pos and subtracts it for A_neg. */
    private void updateMargins(Feature feature, double delta) {
        Set<Index> pos = this.A_pos.get(feature);
        Set<Index> neg = this.A_neg.get(feature);
        if (pos != null) {
            for (Index ij : pos) {
                this.margins[ij.i][ij.j] += delta;
            }
        }
        if (neg != null) {
            for (Index ij : neg) {
                this.margins[ij.i][ij.j] -= delta;
            }
        }
    }

    /**
     * Returns the examples of {@code ranking} with positives first. A positive is an example
     * whose label string ends with {@code "POS 1.0]"}. The input order is kept within each
     * group.
     */
    private List<Example> orderExamplesList(List<Example> ranking) {
        List<Example> correct = new ArrayList<Example>();
        List<Example> incorrect = new ArrayList<Example>();
        for (Example ex : ranking) {
            if (isPositive(ex)) {
                correct.add(ex);
            } else {
                incorrect.add(ex);
            }
        }
        List<Example> ordered = new ArrayList<Example>(ranking.size());
        ordered.addAll(correct);
        ordered.addAll(incorrect);
        return ordered;
    }

    /** True when the example's label string ends with {@code "POS 1.0]"}. */
    private static boolean isPositive(Example ex) {
        // NOTE(review): matching on the label's toString() is fragile; confirm the label format.
        return ex.getLabel().toString().endsWith("POS 1.0]");
    }

    /** True when the example's label string ends with {@code "NEG 1.0]"}. */
    private static boolean isNegative(Example ex) {
        return ex.getLabel().toString().endsWith("NEG 1.0]");
    }

    /** Immutable (ranking, position) pair with value semantics, usable as a hash-set element. */
    private static final class Index {
        final int i;
        final int j;

        Index(int i, int j) {
            this.i = i;
            this.j = j;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof Index)) {
                return false;
            }
            Index other = (Index) o;
            return this.i == other.i && this.j == other.j;
        }

        @Override
        public int hashCode() {
            return 31 * this.i + this.j;
        }
    }
}

