/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify.ranking;

import edu.cmu.minorthird.classify.BinaryClassifier;
import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.MutableInstance;
import edu.cmu.minorthird.classify.ranking.BatchRankingLearner;
import edu.cmu.minorthird.util.BasicCommandLineProcessor;
import edu.cmu.minorthird.util.Saveable;
import edu.cmu.minorthird.util.StringUtil;
import edu.cmu.minorthird.util.gui.ComponentViewer;
import edu.cmu.minorthird.util.gui.LineCharter;
import edu.cmu.minorthird.util.gui.ParallelViewer;
import edu.cmu.minorthird.util.gui.VanillaViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.ViewerFrame;
import edu.cmu.minorthird.util.gui.Visible;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.LineNumberReader;
import java.io.PrintStream;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import javax.swing.JComponent;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Evaluates one or more rankings of binary-labeled examples.
 *
 * <p>Each ranking is identified by a string id. For every ranking the class reports
 * average (non-interpolated) precision, maximum F1, maximum recall, 11-point interpolated
 * precision, and (across rankings) average recall as a function of the rank cutoff.
 *
 * <p>A ranking may also have "unranked" positives: ids of positive examples that do not
 * appear in the ranked list at all. They count towards the number of positives (so they
 * lower recall) but can never contribute to precision.
 *
 * <p>An evaluation can be saved to / restored from a tab-separated text file
 * (extension ".gsev") containing two kinds of lines:
 * <ul>
 * <li>{@code rankingId TAB exampleId TAB rank TAB score} for every ranked example, in rank order;</li>
 * <li>{@code rankingId TAB exampleId} for every positive example, ranked or unranked.</li>
 * </ul>
 */
public class RankingEvaluation
implements Visible,
Saveable {
    private static final int GRAPHS_PER_PAGE = 10;
    private static final int NUM_TOP_TO_SHOW = 50;
    /** Number of recall levels (0.0, 0.1, ..., 1.0) used for interpolated precision. */
    private static final int NUM_RECALL_LEVELS = 11;

    /** Examples of each ranking, in rank order (index 0 is rank 1). */
    private final TreeMap<String, List<Example>> rankedListMap = new TreeMap<String, List<Example>>();
    /** Ids of positive examples that are missing from each ranking. */
    private final TreeMap<String, Set<String>> unrankedMap = new TreeMap<String, Set<String>>();
    /** Scores of each ranking, parallel to the lists in {@link #rankedListMap}. */
    private final TreeMap<String, List<Double>> scoreMap = new TreeMap<String, List<Double>>();
    /** Total number of positives (ranked + unranked) for each ranking. */
    private final TreeMap<String, Integer> numPosExamples = new TreeMap<String, Integer>();
    private boolean guiFlag = false;
    private String loadedFile = null;
    private static final String EVAL_FORMAT_NAME = "Graph Searcher Evaluation";
    private static final String EVAL_EXT = ".gsev";

    /**
     * Adds (or replaces) a ranking that has no unranked positives.
     *
     * @param rankingId  id of the ranking
     * @param ranking    examples to rank; sorted in place by the classifier's score
     * @param classifier classifier that scores the examples
     */
    public void extend(String rankingId, List<Example> ranking, BinaryClassifier classifier) {
        this.extend(rankingId, ranking, classifier, Collections.<String>emptySet());
    }

    /**
     * Adds (or replaces) a ranking.
     *
     * <p>The list is sorted in place with {@link BatchRankingLearner#sortByScore} and then
     * stored by reference, together with each example's score.
     *
     * @param rankingId   id of the ranking; an existing ranking with this id is replaced
     * @param ranking     examples to rank; sorted in place by the classifier's score
     * @param classifier  classifier that scores the examples
     * @param unrankedPos ids of positive examples that are absent from {@code ranking};
     *                    {@code null} is treated as empty. The set is copied.
     */
    public void extend(String rankingId, List<Example> ranking, BinaryClassifier classifier, Set<String> unrankedPos) {
        BatchRankingLearner.sortByScore(classifier, ranking);
        // add(), not set(): a freshly created ArrayList has capacity but size 0
        List<Double> scores = new ArrayList<Double>(ranking.size());
        int numRankedPos = 0;
        for (Example ex : ranking) {
            if (this.isPositive(rankingId, ex)) {
                ++numRankedPos;
            }
            scores.add(classifier.score(ex));
        }
        Set<String> unranked = unrankedPos == null ? new TreeSet<String>() : new TreeSet<String>(unrankedPos);
        this.rankedListMap.put(rankingId, ranking);
        this.scoreMap.put(rankingId, scores);
        this.unrankedMap.put(rankingId, unranked);
        // put (not increment) so that re-extending an id does not double-count its positives
        this.numPosExamples.put(rankingId, Integer.valueOf(numRankedPos + unranked.size()));
    }

    /** Adds {@code delta} to the count stored under {@code key}, treating a missing entry as 0. */
    private void increment(TreeMap<String, Integer> map, String key, int delta) {
        Integer old = map.get(key);
        map.put(key, Integer.valueOf(old == null ? delta : old.intValue() + delta));
    }

    private List<Example> getRanking(String rankingId) {
        return this.rankedListMap.get(rankingId);
    }

    /** Returns the score of the example at 1-based position {@code rank}. */
    private double getScore(String rankingId, int rank) {
        return this.scoreMap.get(rankingId).get(rank - 1);
    }

    /** Returns the total number of positives for a ranking, or 0 if none were recorded. */
    private int numPosExamples(String rankingId) {
        Integer n = this.numPosExamples.get(rankingId);
        return n == null ? 0 : n.intValue();
    }

    private boolean isPositive(String rankingId, Example ex) {
        return ex.getLabel().isPositive();
    }

    private int numRankings() {
        return this.rankedListMap.size();
    }

    /** Returns a read-only view of the ids of positives missing from a ranking (never null). */
    private Set<String> getUnrankedPositives(String rankingId) {
        Set<String> pos = this.unrankedMap.get(rankingId);
        return pos == null ? Collections.<String>emptySet() : Collections.unmodifiableSet(pos);
    }

    /**
     * Splits the ranking ids (in sorted order) into consecutive groups of {@code groupSize};
     * the last group holds the remainder, if any.
     */
    private String[][] exampleGroups(int groupSize) {
        int total = this.numRankings();
        int numGroups = (total + groupSize - 1) / groupSize;
        String[][] groups = new String[numGroups][];
        int g = 0;
        int k = 0;
        for (String name : this.rankedListMap.keySet()) {
            if (k == 0) {
                groups[g] = new String[Math.min(groupSize, total - g * groupSize)];
            }
            groups[g][k++] = name;
            if (k == groups[g].length) {
                ++g;
                k = 0;
            }
        }
        return groups;
    }

    /**
     * Computes recall and precision at every cutoff k = 1..n of a ranking.
     *
     * @return {@code {recall, precision}}, each of length n+1 with index 0 unused; when the
     *         ranking has no positives both are 1.0 at every cutoff
     */
    private double[][] recallAndPrecisionForEachK(String rankingId) {
        List<Example> ranking = this.getRanking(rankingId);
        int totalPos = this.numPosExamples(rankingId);
        double[] recall = new double[ranking.size() + 1];
        double[] precision = new double[ranking.size() + 1];
        int rank = 0;
        double numPosAboveRank = 0.0;
        for (Example ex : ranking) {
            ++rank;
            if (this.isPositive(rankingId, ex)) {
                numPosAboveRank += 1.0;
            }
            if (totalPos > 0) {
                recall[rank] = numPosAboveRank / (double)totalPos;
                precision[rank] = numPosAboveRank / (double)rank;
            } else {
                precision[rank] = 1.0;
                recall[rank] = 1.0;
            }
        }
        return new double[][]{recall, precision};
    }

    /**
     * Returns the (non-interpolated) average precision of a ranking: the mean, over all
     * positives, of the precision at the rank of each ranked positive. Unranked positives
     * contribute 0. Returns 1.0 when the ranking has no positives.
     */
    public double averagePrecision(String rankingId) {
        int totalPos = this.numPosExamples(rankingId);
        if (totalPos == 0) {
            return 1.0;
        }
        int rank = 0;
        int numPosAboveRank = 0;
        double totPrec = 0.0;
        for (Example ex : this.getRanking(rankingId)) {
            ++rank;
            if (this.isPositive(rankingId, ex)) {
                ++numPosAboveRank;
                totPrec += (double)numPosAboveRank / (double)rank;
            }
        }
        return totPrec / (double)totalPos;
    }

    /**
     * Returns the largest F1 over all cutoffs of a ranking, or 1.0 when it has no positives.
     */
    public double maxF1(String rankingId) {
        int totalPos = this.numPosExamples(rankingId);
        if (totalPos == 0) {
            return 1.0;
        }
        int rank = 0;
        int numPosAboveRank = 0;
        double maxF1 = 0.0;
        for (Example ex : this.getRanking(rankingId)) {
            ++rank;
            if (this.isPositive(rankingId, ex)) {
                ++numPosAboveRank;
            }
            double precision = (double)numPosAboveRank / (double)rank;
            double recall = (double)numPosAboveRank / (double)totalPos;
            if (precision + recall > 0.0) {
                maxF1 = Math.max(maxF1, 2.0 * precision * recall / (precision + recall));
            }
        }
        return maxF1;
    }

    /**
     * Returns the fraction of a ranking's positives that appear anywhere in the ranked list,
     * or 1.0 when it has no positives.
     */
    public double maxRecall(String rankingId) {
        int totalPos = this.numPosExamples(rankingId);
        if (totalPos == 0) {
            return 1.0;
        }
        int numRankedPos = 0;
        for (Example ex : this.getRanking(rankingId)) {
            if (this.isPositive(rankingId, ex)) {
                ++numRankedPos;
            }
        }
        return (double)numRankedPos / (double)totalPos;
    }

    /**
     * Averages {@link #elevenPointPrecision} over all rankings.
     *
     * @return 11 values for recall levels 0.0, 0.1, ..., 1.0; all zero if there are no rankings
     */
    public double[] averageElevenPointPrecision() {
        double[] averagePrecision = new double[NUM_RECALL_LEVELS];
        int n = this.numRankings();
        if (n == 0) {
            return averagePrecision;  // avoid 0/0 = NaN
        }
        for (String name : this.rankedListMap.keySet()) {
            double[] precision = this.elevenPointPrecision(name);
            for (int j = 0; j < NUM_RECALL_LEVELS; ++j) {
                averagePrecision[j] += precision[j];
            }
        }
        for (int j = 0; j < NUM_RECALL_LEVELS; ++j) {
            averagePrecision[j] /= (double)n;
        }
        return averagePrecision;
    }

    /**
     * Returns interpolated precision at recall levels 0.0, 0.1, ..., 1.0: for each level, the
     * maximum precision at any cutoff whose recall is at least that level (0 if none).
     */
    public double[] elevenPointPrecision(String rankingId) {
        double[][] a = this.recallAndPrecisionForEachK(rankingId);
        double[] recall = a[0];
        double[] precision = a[1];
        double[] interpolatedPrecision = new double[NUM_RECALL_LEVELS];
        for (int k = 1; k < recall.length; ++k) {
            for (int j = 0; j < NUM_RECALL_LEVELS; ++j) {
                if (recall[k] >= (double)j / 10.0) {
                    interpolatedPrecision[j] = Math.max(interpolatedPrecision[j], precision[k]);
                }
            }
        }
        return interpolatedPrecision;
    }

    /**
     * Returns a tab-separated summary with one row per ranking (avg precision, max F1,
     * max recall, number of positives, id) followed by a row of averages.
     */
    public String toTable() {
        if (this.rankedListMap.isEmpty()) {
            return "no examples?\n";
        }
        StringBuilder buf = new StringBuilder();
        DecimalFormat fmt = new DecimalFormat("0.000");
        DecimalFormat fmt2 = new DecimalFormat("0.0");
        buf.append("avgPr\tmaxF1\tmaxRec\t#pos\n");
        double totMaxF1 = 0.0;
        double totAvgPrec = 0.0;
        double totPos = 0.0;
        double totMaxRec = 0.0;
        for (String name : this.rankedListMap.keySet()) {
            double ap = this.averagePrecision(name);
            double maxf = this.maxF1(name);
            double maxr = this.maxRecall(name);
            int np = this.numPosExamples(name);
            buf.append(fmt.format(ap) + "\t");
            buf.append(fmt.format(maxf) + "\t");
            buf.append(fmt.format(maxr) + "\t");
            buf.append(np + "\t");
            buf.append(name + "\n");
            totAvgPrec += ap;
            totMaxF1 += maxf;
            totMaxRec += maxr;
            totPos += (double)np;
        }
        double n = (double)this.numRankings();
        buf.append("\n");
        buf.append(fmt.format(totAvgPrec / n) + "\t");
        buf.append(fmt.format(totMaxF1 / n) + "\t");
        buf.append(fmt.format(totMaxRec / n) + "\t");
        buf.append(fmt2.format(totPos / n) + "\t");
        buf.append("average\n");
        return buf.toString();
    }

    /**
     * Computes recall at every cutoff k averaged over all rankings. A ranking shorter than k
     * keeps its final recall; a ranking without positives has recall 1.0 everywhere.
     *
     * @return array indexed by k (1..longest ranking length); index 0 holds the sentinel -1.0
     *         so that {@link #averageRecallAsFunctionOfK} always prints k = 1
     */
    private double[] averageRecallAtEachK() {
        int longestRankedList = 0;
        for (List<Example> ranking : this.rankedListMap.values()) {
            longestRankedList = Math.max(ranking.size(), longestRankedList);
        }
        double[] recall = new double[longestRankedList + 1];
        for (String name : this.rankedListMap.keySet()) {
            int totalPos = this.numPosExamples(name);
            int rank = 0;
            double numPosAboveRank = 0.0;
            double currentRecall = totalPos > 0 ? 0.0 : 1.0;
            for (Example ex : this.getRanking(name)) {
                ++rank;
                if (this.isPositive(name, ex)) {
                    numPosAboveRank += 1.0;
                }
                if (totalPos > 0) {
                    currentRecall = numPosAboveRank / (double)totalPos;
                }
                // accumulate: other rankings contribute to the same cutoff
                recall[rank] += currentRecall;
            }
            // this ranking's recall stays at its final value for deeper cutoffs
            for (int k = rank + 1; k < recall.length; ++k) {
                recall[k] += currentRecall;
            }
        }
        for (int k = 1; k < recall.length; ++k) {
            recall[k] /= (double)this.numRankings();
        }
        recall[0] = -1.0;
        return recall;
    }

    /**
     * Returns a tab-separated table of average recall versus cutoff k, listing only the
     * cutoffs at which average recall changes.
     */
    public String averageRecallAsFunctionOfK() {
        DecimalFormat fmt = new DecimalFormat("0.000");
        StringBuilder buf = new StringBuilder();
        buf.append("K\tAvgRecall\n");
        double[] recall = this.averageRecallAtEachK();
        for (int k = 1; k < recall.length; ++k) {
            if (recall[k] == recall[k - 1]) continue;
            buf.append(k + "\t" + fmt.format(recall[k]) + "\n");
        }
        return buf.toString();
    }

    /**
     * Returns a tab-separated listing of one ranking: rank, score, "+"/"-" and the example.
     * The first {@code numToShowAllEntries} entries are always listed; below that only the
     * positives are. Unranked positives are appended as {@code ">n 0 + id"}, where n is the
     * length of the ranking.
     */
    public String toTable(String name, int numToShowAllEntries) {
        List<Example> ranking = this.getRanking(name);
        StringBuilder buf = new StringBuilder();
        DecimalFormat fmt = new DecimalFormat("0.000");
        int rank = 0;
        for (Example ex : ranking) {
            ++rank;
            boolean positive = this.isPositive(name, ex);
            if (rank > numToShowAllEntries && !positive) continue;
            String tag = positive ? "+" : "-";
            buf.append(rank + "\t" + fmt.format(this.getScore(name, rank)) + "\t" + tag + "\t" + ex + "\n");
        }
        for (String id : this.getUnrankedPositives(name)) {
            buf.append(">" + rank + "\t0\t+\t" + id + "\n");
        }
        return buf.toString();
    }

    @Override
    public Viewer toGUI() {
        ParallelViewer v = new ParallelViewer();
        v.addSubView("Summary Table", new ComponentViewer(){
            static final long serialVersionUID = 20080206L;

            public JComponent componentFor(Object o) {
                RankingEvaluation gsEval = (RankingEvaluation)o;
                return new VanillaViewer(gsEval.toTable());
            }
        });
        ParallelViewer v2 = new ParallelViewer();
        v.addSubView("11-Pt Precision", v2);
        v2.addSubView("Averaged", new ComponentViewer(){
            static final long serialVersionUID = 20080206L;

            public JComponent componentFor(Object o) {
                RankingEvaluation gsEval = (RankingEvaluation)o;
                double[] avgPrec = gsEval.averageElevenPointPrecision();
                LineCharter lc = new LineCharter();
                lc.startCurve("11-Pt Avg Prec");
                for (int j = 0; j < NUM_RECALL_LEVELS; ++j) {
                    lc.addPoint((double)j / 10.0, avgPrec[j]);
                }
                return lc.getPanel("11-Pt Average Interpolated Precision", "Recall", "Precision");
            }
        });
        // one chart per page of GRAPHS_PER_PAGE rankings, to keep the charts readable
        String[][] groups = this.exampleGroups(GRAPHS_PER_PAGE);
        for (int g = 0; g < groups.length; ++g) {
            String tag = groups.length == 1 ? "Details" : "Details: Group " + (g + 1);
            final String[] group = groups[g];
            v2.addSubView(tag, new ComponentViewer(){
                static final long serialVersionUID = 20080206L;

                public JComponent componentFor(Object o) {
                    RankingEvaluation gsEval = (RankingEvaluation)o;
                    LineCharter lc = new LineCharter();
                    for (String name : group) {
                        double[] prec = gsEval.elevenPointPrecision(name);
                        lc.startCurve(name);
                        for (int j = 0; j < NUM_RECALL_LEVELS; ++j) {
                            lc.addPoint((double)j / 10.0, prec[j]);
                        }
                    }
                    return lc.getPanel("11-Pt Interpolated Precision", "Recall", "Precision");
                }
            });
        }
        v.addSubView("AvgRecall vs Rank", new ComponentViewer(){
            static final long serialVersionUID = 20080206L;

            public JComponent componentFor(Object o) {
                double[] avgRec = RankingEvaluation.this.averageRecallAtEachK();
                LineCharter lc = new LineCharter();
                lc.startCurve("Recall vs Rank");
                for (int k = 1; k < avgRec.length; ++k) {
                    lc.addPoint(k, avgRec[k]);
                }
                return lc.getPanel("AvgRecall vs Rank", "Rank", "AvgRecall");
            }
        });
        ParallelViewer v3 = new ParallelViewer();
        v3.putTabsOnLeft();
        v.addSubView("Details", v3);
        for (final String name : this.rankedListMap.keySet()) {
            v3.addSubView(name, new ComponentViewer(){
                static final long serialVersionUID = 20080206L;

                public JComponent componentFor(Object o) {
                    return new VanillaViewer(RankingEvaluation.this.toTable(name, NUM_TOP_TO_SHOW));
                }
            });
        }
        v.setContent(this);
        return v;
    }

    @Override
    public String[] getFormatNames() {
        return new String[]{EVAL_FORMAT_NAME};
    }

    @Override
    public String getExtensionFor(String format) {
        return EVAL_EXT;
    }

    @Override
    public void saveAs(File file, String formatName) throws IOException {
        this.save(file);
    }

    @Override
    public Object restore(File file) throws IOException {
        return RankingEvaluation.load(file);
    }

    /**
     * Writes this evaluation in the tab-separated format described on the class.
     *
     * @throws IOException if the file cannot be opened or a write fails
     */
    private void save(File file) throws IOException {
        PrintStream out = new PrintStream(new FileOutputStream(file));
        try {
            for (String name : this.rankedListMap.keySet()) {
                List<Example> ranking = this.getRanking(name);
                int rank = 0;
                for (Example ex : ranking) {
                    ++rank;
                    double weight = this.getScore(name, rank);
                    out.println(name + "\t" + ex.getSource() + "\t" + rank + "\t" + weight);
                }
                for (Example ex : ranking) {
                    if (!this.isPositive(name, ex)) continue;
                    out.println(name + "\t" + ex.getSource());
                }
                for (String id : this.getUnrankedPositives(name)) {
                    out.println(name + "\t" + id);
                }
            }
            // PrintStream never throws on write failure; surface it here instead
            if (out.checkError()) {
                throw new IOException("error writing " + file);
            }
        } finally {
            out.close();
        }
    }

    /** Creates a new evaluation from a file written by {@link #saveAs}. */
    public static RankingEvaluation load(File file) throws IOException {
        RankingEvaluation eval = new RankingEvaluation();
        eval.loadFromFile(file);
        return eval;
    }

    /**
     * Reads the tab-separated format described on the class into this evaluation.
     * Blank lines are ignored.
     *
     * @throws IllegalArgumentException if a line has neither 2 nor 4 fields
     * @throws IOException if the file cannot be read
     */
    private void loadFromFile(File file) throws IOException {
        TreeMap<String, List<String>> idsByRanking = new TreeMap<String, List<String>>();
        LineNumberReader in = new LineNumberReader(new InputStreamReader(new FileInputStream(file)));
        try {
            String line;
            while ((line = in.readLine()) != null) {
                if (line.length() == 0) continue;
                String[] parts = line.split("\t");
                if (parts.length == 2) {
                    // a positive example; whether it is ranked is resolved below
                    Set<String> pos = this.unrankedMap.get(parts[0]);
                    if (pos == null) {
                        pos = new TreeSet<String>();
                        this.unrankedMap.put(parts[0], pos);
                    }
                    pos.add(parts[1]);
                    this.increment(this.numPosExamples, parts[0], 1);
                } else if (parts.length == 4) {
                    // a ranked example; lines are expected to appear in rank order
                    List<String> ids = idsByRanking.get(parts[0]);
                    if (ids == null) {
                        ids = new ArrayList<String>();
                        idsByRanking.put(parts[0], ids);
                    }
                    List<Double> scores = this.scoreMap.get(parts[0]);
                    if (scores == null) {
                        scores = new ArrayList<Double>();
                        this.scoreMap.put(parts[0], scores);
                    }
                    scores.add(Double.valueOf(StringUtil.atof(parts[3])));
                    ids.add(parts[1]);
                } else {
                    throw new IllegalArgumentException(file + " line " + in.getLineNumber() + ": illegal format");
                }
            }
        } finally {
            in.close();
        }
        // rankings may consist solely of unranked positives, so consider both key sets
        Set<String> rankingIds = new TreeSet<String>(idsByRanking.keySet());
        rankingIds.addAll(this.unrankedMap.keySet());
        for (String rankingId : rankingIds) {
            List<String> ids = idsByRanking.get(rankingId);
            if (ids == null) {
                if (this.rankedListMap.containsKey(rankingId)) continue;
                ids = Collections.<String>emptyList();
            }
            Set<String> pos = this.unrankedMap.get(rankingId);
            List<Example> ranking = new ArrayList<Example>(ids.size());
            for (String exId : ids) {
                // a ranked positive is removed so that only unranked positives remain in the set
                boolean positive = pos != null && pos.remove(exId);
                ranking.add(new Example(new MutableInstance(exId), ClassLabel.binaryLabel(positive ? 1.0 : -1.0)));
            }
            this.rankedListMap.put(rankingId, ranking);
            if (!this.scoreMap.containsKey(rankingId)) {
                this.scoreMap.put(rankingId, new ArrayList<Double>());
            }
        }
    }

    public void processArguments(String[] args) {
        new MyCLP().processArguments(args);
    }

    /**
     * Command-line entry point: processes the arguments (see {@link MyCLP}), then shows the
     * evaluation in a window if {@code gui} was requested, else prints the summary table.
     */
    public static void main(String[] args) throws IOException {
        RankingEvaluation eval = new RankingEvaluation();
        eval.processArguments(args);
        if (eval.guiFlag) {
            new ViewerFrame(eval.loadedFile, eval.toGUI());
        } else {
            System.out.println(eval.toTable());
        }
    }

    /** Command-line options for {@link RankingEvaluation#main}. */
    public class MyCLP
    extends BasicCommandLineProcessor {
        /** Requests the GUI instead of the text summary. */
        public void gui() {
            RankingEvaluation.this.guiFlag = true;
        }

        /** Loads a saved evaluation; read errors are reported and the evaluation stays as-is. */
        public void loadFrom(String s) {
            RankingEvaluation.this.loadedFile = s;
            try {
                RankingEvaluation.this.loadFromFile(new File(s));
            }
            catch (IOException ex) {
                System.err.println("cannot load evaluation from " + s);
                ex.printStackTrace();
            }
        }
    }
}

