/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.minorthird.classify.sequential;

import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Explanation;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.sequential.InstanceFromSequence;
import edu.cmu.minorthird.classify.sequential.SequenceConstants;
import edu.cmu.minorthird.util.MathUtil;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.log4j.Logger;

public class BeamSearcher
implements SequenceConstants,
Serializable {
    private static final long serialVersionUID = 20080207L;
    private static boolean OLD_VERSION = false;
    private static Logger log = Logger.getLogger(BeamSearcher.class);
    private static final boolean DEBUG = false;
    private int historySize;
    private String[] possibleClassLabels;
    private Classifier classifier;
    private int beamSize = 10;
    private transient Beam beam = new Beam();
    private transient Instance[] instances;
    private transient String[] history;

    public BeamSearcher(Classifier classifier, int historySize, ExampleSchema schema) {
        this.classifier = classifier;
        this.historySize = historySize;
        this.possibleClassLabels = schema.validClassNames();
        if (this.possibleClassLabels.length < 2) {
            throw new IllegalArgumentException("possibleClassLabels.length=" + this.possibleClassLabels.length + " <2 ???");
        }
    }

    public int getMaxBeamSize() {
        return this.beamSize;
    }

    public void setMaxBeamSize(int n) {
        this.beamSize = n;
    }

    public ClassLabel[] bestLabelSequence(Instance[] instances) {
        this.doSearch(instances);
        return this.viterbi(0);
    }

    public static Instance getBeamInstance(Instance instance, int historySize) {
        String[] history = new String[historySize];
        InstanceFromSequence.fillHistory(history, new String[0], 0);
        return new InstanceFromSequence(instance, history);
    }

    public void doSearch(Instance[] sequence) {
        this.instances = sequence;
        if (this.possibleClassLabels.length < 2) {
            throw new IllegalStateException("possibleClassLabels.length=" + this.possibleClassLabels.length + " <2 ???");
        }
        this.history = new String[this.historySize];
        this.beam = new Beam();
        this.beam.add(new BeamEntry());
        for (int i = 0; i < this.instances.length; ++i) {
            Beam nextBeam = new Beam();
            for (int j = 0; j < Math.min(this.beam.size(), this.beamSize); ++j) {
                BeamEntry entry = this.beam.get(j);
                Instance beamInstance = entry.getBeamInstance(this.instances[i]);
                ClassLabel label = this.classifier.classification(beamInstance);
                for (int el = 0; el < this.possibleClassLabels.length; ++el) {
                    double w = label.getWeight(this.possibleClassLabels[el]);
                    nextBeam.add(entry.extend(this.possibleClassLabels[el], w));
                }
            }
            nextBeam.sort();
            this.beam = nextBeam;
        }
    }

    public void doSearch(Instance[] sequence, ClassLabel[] template) {
        this.instances = sequence;
        if (this.possibleClassLabels.length < 2) {
            throw new IllegalStateException("possibleClassLabels.length=" + this.possibleClassLabels.length + " <2 ???");
        }
        this.history = new String[this.historySize];
        this.beam = new Beam();
        this.beam.add(new BeamEntry());
        for (int i = 0; i < this.instances.length; ++i) {
            Beam nextBeam = new Beam();
            for (int j = 0; j < Math.min(this.beam.size(), this.beamSize); ++j) {
                BeamEntry entry = this.beam.get(j);
                Instance beamInstance = entry.getBeamInstance(this.instances[i]);
                ClassLabel label = this.classifier.classification(beamInstance);
                for (int el = 0; el < this.possibleClassLabels.length; ++el) {
                    if (template.length >= i + 1 && template[i] != null && !template[i].bestClassName().equals(this.possibleClassLabels[el])) continue;
                    double w = label.getWeight(this.possibleClassLabels[el]);
                    nextBeam.add(entry.extend(this.possibleClassLabels[el], w));
                }
            }
            nextBeam.sort();
            this.beam = nextBeam;
        }
    }

    public int getNumberOfSolutionsFound() {
        return this.beam.size();
    }

    public ClassLabel[] viterbi(int k) {
        ClassLabel[] result = new ClassLabel[this.instances.length];
        BeamEntry entry = this.beam.get(k);
        for (int i = 0; i < this.instances.length; ++i) {
            result[i] = entry.toClassLabel(i);
        }
        return result;
    }

    public float score(int k) {
        return (float)this.beam.get((int)k).score;
    }

    public String explain(Instance[] sequence) {
        StringBuffer buf = new StringBuffer("");
        this.doSearch(sequence);
        BeamEntry targetEntry = this.beam.get(0);
        BeamEntry entry = new BeamEntry();
        for (int i = 0; i < sequence.length; ++i) {
            buf.append("Classification for instance " + i + " is " + targetEntry.labels[i] + " (score " + targetEntry.scores[i] + "):\n");
            buf.append(this.classifier.explain(entry.getBeamInstance(sequence[i])));
            entry = entry.extend(targetEntry.labels[i], targetEntry.scores[i]);
            buf.append("\nRunning total score: " + entry.score + "\n\n");
        }
        return buf.toString();
    }

    public Explanation getExplanation(Instance[] sequence) {
        this.doSearch(sequence);
        BeamEntry targetEntry = this.beam.get(0);
        BeamEntry entry = new BeamEntry();
        Explanation.Node top = new Explanation.Node("BeamSearcher Classification");
        for (int i = 0; i < sequence.length; ++i) {
            Explanation.Node seqEx = new Explanation.Node("Classification for instance " + i + " is " + targetEntry.labels[i] + " (score " + targetEntry.scores[i] + "):\n");
            Explanation.Node explan = this.classifier.getExplanation(sequence[i]).getTopNode();
            if (explan == null) {
                explan = new Explanation.Node(this.classifier.explain(entry.getBeamInstance(sequence[i])));
            }
            seqEx.add(explan);
            entry = entry.extend(targetEntry.labels[i], targetEntry.scores[i]);
            Explanation.Node score = new Explanation.Node("\nRunning total score: " + entry.score + "\n\n");
            seqEx.add(score);
            top.add(seqEx);
        }
        Explanation ex = new Explanation(top);
        return ex;
    }

    private class BeamKey {
        private String[] keyHistory;

        public BeamKey(BeamEntry entry) {
            this.keyHistory = new String[BeamSearcher.this.historySize];
            entry.fillHistory(this.keyHistory);
        }

        public int hashCode() {
            int h = 73643674;
            for (int i = 0; i < this.keyHistory.length; ++i) {
                if (OLD_VERSION) {
                    h ^= this.keyHistory.hashCode();
                    continue;
                }
                h ^= this.keyHistory[i].hashCode();
            }
            return h;
        }

        public boolean equals(Object o) {
            if (!(o instanceof BeamKey)) {
                return false;
            }
            BeamKey b = (BeamKey)o;
            if (b.keyHistory.length != this.keyHistory.length) {
                return false;
            }
            for (int i = 0; i < b.keyHistory.length; ++i) {
                if (this.keyHistory[i].equals(b.keyHistory[i])) continue;
                return false;
            }
            return true;
        }

        public String toString() {
            String path = "[Key ";
            for (int i = 0; i < this.keyHistory.length; ++i) {
                path = path + this.keyHistory[i] + " ";
            }
            path = path + "]";
            return path;
        }
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    private class BeamEntry
    implements Comparable<BeamEntry> {
        public String[] labels = new String[0];
        public double[] scores = new double[0];
        public double score = 0.0;

        private BeamEntry() {
        }

        @Override
        public int compareTo(BeamEntry other) {
            return MathUtil.sign(other.score - this.score);
        }

        public ClassLabel toClassLabel(int i) {
            return new ClassLabel(this.labels[i], this.scores[i]);
        }

        public BeamEntry extend(String label, double labelScore) {
            BeamEntry result = new BeamEntry();
            result.labels = new String[this.labels.length + 1];
            result.scores = new double[this.labels.length + 1];
            for (int i = 0; i < this.labels.length; ++i) {
                result.labels[i] = this.labels[i];
                result.scores[i] = this.scores[i];
            }
            result.labels[this.labels.length] = label;
            result.scores[this.labels.length] = labelScore;
            result.score = this.score + labelScore;
            return result;
        }

        public Instance getBeamInstance(Instance instance) {
            this.fillHistory(BeamSearcher.this.history);
            return new InstanceFromSequence(instance, BeamSearcher.this.history);
        }

        public void fillHistory(String[] history) {
            InstanceFromSequence.fillHistory(history, this.labels, this.labels.length);
        }

        public String toString() {
            return "[entry: " + this.labels + ";" + this.scores + "; score:" + this.score + "]";
        }
    }

    private class Beam {
        private List<BeamEntry> list = new ArrayList<BeamEntry>();
        private Map<BeamKey, BeamEntry> keyMap = new HashMap<BeamKey, BeamEntry>();

        private Beam() {
        }

        public BeamEntry get(int i) {
            return this.list.get(i);
        }

        public void add(BeamEntry entry) {
            BeamKey key = new BeamKey(entry);
            BeamEntry existingEntry = this.keyMap.get(key);
            if (existingEntry == null || existingEntry.score < entry.score) {
                if (existingEntry != null) {
                    this.list.remove(existingEntry);
                }
                this.list.add(entry);
                this.keyMap.put(key, entry);
            }
        }

        public int size() {
            return this.list.size();
        }

        public void sort() {
            Collections.sort(this.list);
        }
    }
}

