/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify.ranking;

import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.algorithms.linear.Hyperplane;
import edu.cmu.minorthird.classify.experiments.CrossValSplitter;
import edu.cmu.minorthird.classify.ranking.BatchRankingLearner;
import edu.cmu.minorthird.util.ProgressCounter;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Batch ranking learner that fits a linear scoring function ({@link Hyperplane})
 * by gradient descent on a listwise cross-entropy loss.
 *
 * <p>For every ranking (the groups returned by {@code splitIntoRankings}) two
 * probability distributions over its examples are built: a target distribution
 * from the labels (softmax of +1 for positive, -1 for non-positive examples)
 * and a model distribution (softmax of the hyperplane scores). Training moves
 * the weights against the gradient of the cross entropy between the two.
 * This appears to follow the ListNet algorithm of Cao et al. (2007).
 *
 * <p>After each epoch the loss is measured on held-out data; the best weights
 * seen so far are remembered and returned, and training stops early once too
 * many epochs fail to improve the held-out loss sufficiently.
 *
 * <p>Instances keep per-ranking scratch state in fields and are therefore not
 * safe for concurrent use.
 */
public class ListNet
extends BatchRankingLearner {
    /** Maximum number of stalled epochs, and of reduced-rate retries per epoch. */
    private static final int MAX_BAD_STEPS = 6;
    /** Smallest held-out cross-entropy improvement that counts as progress. */
    private static final double MIN_CE_IMPROVEMENT = 0.005;
    /** Factor by which the learning rate is divided on every retry. */
    private static final double LEARN_RATE_DECAY = 5.0;
    /** Target score assigned to positively labelled examples. */
    private static final double RELEVANT = 1.0;
    /** Target score assigned to all other examples. */
    private static final double NON_RELEVANT = -1.0;
    /** Number of folds used when held-out data has to be carved from the training set. */
    private static final int DEV_FOLDS = 5;
    /** Index of the fold used as held-out data. */
    private static final int DEV_FOLD_INDEX = 1;

    private final int numEpochs;
    private final double initLearnRate;
    /** Current step size; reset at the start of every epoch, shrunk during retries. */
    private double learnRate;
    /** Caller-supplied held-out data, or {@code null} to split it off the training data. */
    private Dataset devData = null;
    /** Scratch: model distribution of the ranking most recently passed to {@code initialize}. */
    private double[] pz;
    /** Scratch: target distribution of the ranking most recently passed to {@code initialize}. */
    private double[] py;

    /** Creates a learner with 15 epochs and an initial learning rate of 0.05. */
    public ListNet() {
        this(15, 0.05);
    }

    /**
     * Creates a learner with an initial learning rate of 0.05.
     *
     * @param numEpochs maximum number of training epochs
     */
    public ListNet(int numEpochs) {
        this(numEpochs, 0.05);
    }

    /**
     * @param epochs maximum number of training epochs
     * @param rate   learning rate each epoch starts with
     */
    public ListNet(int epochs, double rate) {
        this.numEpochs = epochs;
        this.initLearnRate = rate;
    }

    /**
     * Sets the held-out data used to monitor the loss during training.
     *
     * @param data held-out dataset; {@code null} makes {@link #batchTrain}
     *             split one fold off the data it is given
     */
    public void setDevData(Dataset data) {
        this.devData = data;
    }

    /**
     * Trains a linear ranker on {@code data}.
     *
     * <p>If no held-out data was supplied through {@link #setDevData}, one fold
     * of a 5-fold split of {@code data} is held out and the rest is used for
     * training. The split is local to this call, so a later call on other data
     * produces its own split.
     *
     * @param data training examples
     * @return the hyperplane with the lowest held-out cross entropy observed
     */
    @Override
    public Classifier batchTrain(Dataset data) {
        Dataset trainData;
        Dataset heldOut;
        if (this.devData == null) {
            Dataset.Split split = data.split(new CrossValSplitter<Example>(DEV_FOLDS));
            trainData = split.getTrain(DEV_FOLD_INDEX);
            heldOut = split.getTest(DEV_FOLD_INDEX);
        } else {
            trainData = data;
            heldOut = this.devData;
        }
        // Both groupings are loop-invariant, so compute them once.
        Map<String, List<Example>> trainRankings = ListNet.splitIntoRankings(trainData);
        Map<String, List<Example>> devRankings = ListNet.splitIntoRankings(heldOut);

        Hyperplane w = new Hyperplane();
        // Until the first snapshot is taken the best model is the live one
        // (this is what gets returned when numEpochs is 0).
        Hyperplane best = w;
        double smallestCE = Double.MAX_VALUE;
        int stalledEpochs = 0;

        ProgressCounter pc = new ProgressCounter("ListNet training", "epoch", this.numEpochs);
        try {
            for (int e = 0; e < this.numEpochs; ++e) {
                this.setLearnRate();
                this.learnStep(trainRankings, w);
                double curCE = this.calculateLoss(devRankings, w);
                double improvement = smallestCE - curCE;
                if (e == 0 || improvement > 0.0) {
                    best = ListNet.copyOf(w);
                    smallestCE = curCE;
                }
                if (improvement < MIN_CE_IMPROVEMENT) {
                    if (++stalledEpochs > MAX_BAD_STEPS) {
                        return best;
                    }
                    // The step made things worse: retry from the best weights
                    // with progressively smaller learning rates.
                    int retries = 0;
                    while (improvement < 0.0 && retries++ < MAX_BAD_STEPS) {
                        this.learnRate /= LEARN_RATE_DECAY;
                        Hyperplane candidate = ListNet.copyOf(best);
                        this.learnStep(trainRankings, candidate);
                        curCE = this.calculateLoss(devRankings, candidate);
                        improvement = smallestCE - curCE;
                        if (improvement > 0.0) {
                            best = candidate;
                            // Continue training on a separate copy so that later
                            // in-place updates cannot corrupt the saved snapshot.
                            w = ListNet.copyOf(candidate);
                            smallestCE = curCE;
                        }
                    }
                } else {
                    stalledEpochs = 0;
                }
                pc.progress();
            }
            return best;
        } finally {
            // Also reached on early stopping, which previously left the counter open.
            pc.finished();
        }
    }

    /** Returns a new hyperplane carrying the same weights as {@code source}. */
    private static Hyperplane copyOf(Hyperplane source) {
        Hyperplane copy = new Hyperplane();
        copy.increment(source);
        return copy;
    }

    /** Performs one gradient step per ranking, updating {@code w} in place. */
    private void learnStep(Map<String, List<Example>> queryMap, Hyperplane w) {
        for (List<Example> ranking : queryMap.values()) {
            this.batchTrainSubPop(w, ranking);
        }
    }

    /** Moves {@code w} against the loss gradient of a single ranking, scaled by the learning rate. */
    private void batchTrainSubPop(Hyperplane w, List<Example> ranking) {
        this.initialize(ranking, w);
        Hyperplane deltaW = this.calculateGradient(ranking);
        w.increment(deltaW, -this.learnRate);
    }

    /**
     * Builds the gradient for {@code list} from the distributions last computed
     * by {@code initialize}: each feature accumulates
     * {@code (pz[i] - py[i]) * weight} over all examples.
     */
    private Hyperplane calculateGradient(List<Example> list) {
        Hyperplane hyp = new Hyperplane();
        for (int i = 0; i < list.size(); ++i) {
            Instance ins = list.get(i);
            Iterator<Feature> loop = ins.featureIterator();
            while (loop.hasNext()) {
                Feature f = loop.next();
                double weight = ins.getWeight(f);
                hyp.increment(f, -(this.py[i] * weight));
                hyp.increment(f, this.pz[i] * weight);
            }
        }
        return hyp;
    }

    /**
     * Computes the cross entropy {@code -sum(base[i] * log(b[i]))}.
     * Terms with {@code base[i] == 0} contribute nothing, so a zero in
     * {@code b} at the same position does not turn the result into NaN.
     *
     * @param base reference distribution
     * @param b    distribution being evaluated
     * @return the cross entropy of {@code b} relative to {@code base}
     * @throws IllegalArgumentException if the arrays differ in length
     */
    public double crossEntropy(double[] base, double[] b) {
        if (base.length != b.length) {
            throw new IllegalArgumentException("Probability distributions of different sizes!");
        }
        double sum = 0.0;
        for (int i = 0; i < base.length; ++i) {
            if (base[i] == 0.0) {
                continue;
            }
            sum += base[i] * Math.log(b[i]);
        }
        return -sum;
    }

    /**
     * Fills {@code py} (target) and {@code pz} (model) with the softmax
     * distributions for {@code list} under the weights {@code w}.
     */
    private void initialize(List<Example> list, Hyperplane w) {
        int n = list.size();
        this.py = new double[n];
        this.pz = new double[n];
        double[] scores = new double[n];
        double maxScore = Double.NEGATIVE_INFINITY;
        double sumY = 0.0;
        for (int i = 0; i < n; ++i) {
            Example ex = list.get(i);
            this.py[i] = Math.exp(ex.getLabel().isPositive() ? RELEVANT : NON_RELEVANT);
            sumY += this.py[i];
            scores[i] = w.score(ex);
            if (scores[i] > maxScore) {
                maxScore = scores[i];
            }
        }
        // Shifting every score by the maximum leaves the softmax unchanged but
        // keeps Math.exp from overflowing to Infinity on large scores.
        double sumZ = 0.0;
        for (int i = 0; i < n; ++i) {
            this.pz[i] = Math.exp(scores[i] - maxScore);
            sumZ += this.pz[i];
        }
        for (int i = 0; i < n; ++i) {
            this.py[i] /= sumY;
            this.pz[i] /= sumZ;
        }
    }

    /** Resets the current learning rate to the initial rate given at construction. */
    public void setLearnRate() {
        this.learnRate = this.initLearnRate;
    }

    /**
     * Sums, over all rankings in {@code queryMap}, the cross entropy between
     * the target distribution and the model distribution under {@code w}.
     */
    double calculateLoss(Map<String, List<Example>> queryMap, Hyperplane w) {
        double ce = 0.0;
        for (List<Example> ranking : queryMap.values()) {
            this.initialize(ranking, w);
            ce += this.crossEntropy(this.py, this.pz);
        }
        return ce;
    }
}

