/*
 * Decompiled with CFR 0.152.
 */
package iitb.CRF;

import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix1D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import iitb.CRF.CRF;
import iitb.CRF.DataSequence;
import iitb.CRF.Soln;
import iitb.CRF.Trainer;
import java.io.Serializable;

/**
 * Beam-search Viterbi decoder for a linear-chain CRF.
 *
 * <p>For every (label, position) cell the decoder keeps an {@link Entry}: a fixed-size beam of
 * {@link Soln} objects linked through {@code prevSoln}. The best label sequences are recovered by
 * following those links from {@link #finalSoln}. Per-position edge scores ({@link #Mi}) and state
 * scores ({@link #Ri}) come from {@code Trainer.computeLogMi}. Which labels and transitions are
 * visited is decided by {@code model.edgeGen}.
 *
 * <p>The scratch buffers below are overwritten on every call without synchronization, so a single
 * instance should not decode two sequences concurrently.
 */
public class Viterbi implements Serializable {
    private static final long serialVersionUID = 8122L;

    /** Model supplying the feature generator, edge generator and label count. */
    protected CRF model;
    /** Number of partial solutions kept per (label, position) cell. */
    protected int beamsize;
    /** {@code winningLabel[y][i]} is the beam of partial solutions for label y at position i. */
    Entry[][] winningLabel;
    /** Beam collecting the overall best solutions after a search. */
    protected Entry finalSoln;
    /** Scratch numY x numY edge-score matrix for the position being processed. */
    protected DoubleMatrix2D Mi;
    /** Scratch numY state-score vector for the position being processed. */
    protected DoubleMatrix1D Ri;

    /**
     * Creates a decoder for {@code model}.
     *
     * @param model the CRF to decode with
     * @param bs    default beam width; replaced by the integer value of the {@code beamSize}
     *              entry of {@code model.params.miscOptions} when that entry is present
     */
    Viterbi(CRF model, int bs) {
        this.model = model;
        this.beamsize = bs;
        String beamProp = model.params.miscOptions.getProperty("beamSize");
        if (beamProp != null) {
            this.beamsize = Integer.parseInt(beamProp);
        }
    }

    /**
     * Allocates the scratch matrices, the per-label rows of {@link #winningLabel} (their columns
     * are allocated lazily by the search methods) and {@link #finalSoln}.
     *
     * @param numY number of labels
     */
    void allocateScratch(int numY) {
        this.Mi = new DenseDoubleMatrix2D(numY, numY);
        this.Ri = new DenseDoubleMatrix1D(numY);
        this.winningLabel = new Entry[numY][];
        this.finalSoln = new Entry(this.beamsize, 0, 0);
    }

    /**
     * Ensures that every row of {@link #winningLabel} has at least {@code length} cells. When the
     * existing rows are missing or too short, all rows are reallocated.
     *
     * @param length    number of positions required
     * @param firstBeam beam width for the cells at position 0 (other positions use
     *                  {@link #beamsize})
     */
    private void ensureColumns(int length, int firstBeam) {
        if (this.winningLabel[0] != null && this.winningLabel[0].length >= length) {
            return;
        }
        for (int yi = 0; yi < this.winningLabel.length; ++yi) {
            Entry[] row = new Entry[length];
            for (int pos = 0; pos < length; ++pos) {
                row[pos] = new Entry(pos == 0 ? firstBeam : this.beamsize, yi, pos);
            }
            this.winningLabel[yi] = row;
        }
    }

    /**
     * Forward pass: fills {@code winningLabel[y][i]} with the best partial sequences that end in
     * label y at position i.
     *
     * @param dataSeq   sequence to decode
     * @param lambda    feature weights passed to {@code Trainer.computeLogMi}
     * @param calcScore when true, also sums the scores of the labels already stored in
     *                  {@code dataSeq}
     * @return that summed score when {@code calcScore} is true, otherwise 0
     */
    double fillArray(DataSequence dataSeq, double[] lambda, boolean calcScore) {
        double corrScore = 0.0;
        int numY = this.model.numY;
        for (int i = 0; i < dataSeq.length(); ++i) {
            // Expected to fill this.Mi / this.Ri with the scores of position i.
            Trainer.computeLogMi(this.model.featureGenerator, lambda, dataSeq, i, this.Mi, this.Ri, false);
            for (int yi = 0; yi < numY; ++yi) {
                this.winningLabel[yi][i].clear();
                this.winningLabel[yi][i].valid = true;
            }
            for (int yi = this.model.edgeGen.firstY(i); yi < numY; yi = this.model.edgeGen.nextY(yi, i)) {
                if (i == 0) {
                    // First position: no predecessor, seed with the state score only.
                    this.winningLabel[yi][0].add((float) this.Ri.get(yi));
                } else {
                    for (int yp = this.model.edgeGen.first(yi); yp < numY; yp = this.model.edgeGen.next(yi, yp)) {
                        double val = this.Mi.get(yp, yi) + this.Ri.get(yi);
                        this.winningLabel[yi][i].add(this.winningLabel[yp][i - 1], (float) val);
                    }
                }
            }
            if (calcScore) {
                corrScore += this.Ri.get(dataSeq.y(i)) + (i > 0 ? this.Mi.get(dataSeq.y(i - 1), dataSeq.y(i)) : 0.0);
            }
        }
        return corrScore;
    }

    /**
     * Runs the search from the last position back to the first. In the resulting solutions,
     * {@code prevSoln} points toward later positions.
     *
     * @param dataSeq          sequence to decode
     * @param lambda           feature weights
     * @param Mis              caller-supplied array with at least {@code dataSeq.length()} elements;
     *                         element i receives a copy of the edge scores of position i
     * @param Ris              same as {@code Mis} for the state scores
     * @param calcCorrectScore when true, also sums the scores of the labels stored in
     *                         {@code dataSeq}
     * @return that summed score when requested, otherwise 0 (also 0 for an empty sequence)
     */
    public double viterbiSearchBackward(DataSequence dataSeq, double[] lambda, DoubleMatrix2D[] Mis, DoubleMatrix1D[] Ris, boolean calcCorrectScore) {
        if (this.Mi == null) {
            this.allocateScratch(this.model.numY);
        }
        int length = dataSeq.length();
        if (length == 0) {
            // Nothing to decode; leave finalSoln empty instead of indexing position 0.
            this.finalSoln.clear();
            this.finalSoln.valid = true;
            return 0.0;
        }
        this.ensureColumns(length, this.beamsize);
        Entry[] firstEntries = new Entry[this.model.numY];
        for (int yi = 0; yi < this.winningLabel.length; ++yi) {
            firstEntries[yi] = new Entry(1, yi, 0);
        }
        double corrScore = this.fillArrayBackward(dataSeq, lambda, firstEntries, Mis, Ris, calcCorrectScore);
        this.finalSoln.clear();
        this.finalSoln.valid = true;
        for (int yi = 0; yi < this.model.numY; ++yi) {
            this.finalSoln.add(firstEntries[yi], 0.0f);
        }
        return corrScore;
    }

    /**
     * Backward pass behind {@link #viterbiSearchBackward}. It fills {@code winningLabel} from the
     * last position down to position 0 and copies each position's Mi/Ri into {@code Mis}/{@code Ris}.
     * It then stores in {@code firstEntries[y]} the best solution starting with label y, scored with
     * position 0's state score.
     *
     * @return the score of the labels stored in {@code dataSeq} when {@code calcScore}, else 0
     */
    double fillArrayBackward(DataSequence dataSeq, double[] lambda, Entry[] firstEntries, DoubleMatrix2D[] Mis, DoubleMatrix1D[] Ris, boolean calcScore) {
        double corrScore = 0.0;
        int numY = this.model.numY;
        int last = dataSeq.length() - 1;
        if (last < 0) {
            for (int yi = 0; yi < numY; ++yi) {
                firstEntries[yi].clear();
                firstEntries[yi].valid = true;
            }
            return 0.0;
        }
        // Reset every cell; the last column is seeded with a zero-score terminal solution.
        for (int i = last; i >= 0; --i) {
            for (int yi = 0; yi < numY; ++yi) {
                this.winningLabel[yi][i].clear();
                this.winningLabel[yi][i].valid = true;
                if (i == last) {
                    this.winningLabel[yi][i].add(0.0f);
                }
            }
        }
        for (int i = last; i >= 0; --i) {
            Trainer.computeLogMi(this.model.featureGenerator, lambda, dataSeq, i, this.Mi, this.Ri, false);
            Mis[i].assign(this.Mi);
            Ris[i].assign(this.Ri);
            if (calcScore) {
                // Accumulated for every position, including 0, so the result matches fillArray.
                corrScore += this.Ri.get(dataSeq.y(i)) + (i > 0 ? this.Mi.get(dataSeq.y(i - 1), dataSeq.y(i)) : 0.0);
            }
            if (i > 0) {
                // Extend each cell at i-1 with the best continuations starting at position i.
                for (int yi = this.model.edgeGen.firstY(i); yi < numY; yi = this.model.edgeGen.nextY(yi, i)) {
                    for (int yp = this.model.edgeGen.first(yi); yp < numY; yp = this.model.edgeGen.next(yi, yp)) {
                        double val = this.Mi.get(yp, yi) + this.Ri.get(yi);
                        this.winningLabel[yp][i - 1].add(this.winningLabel[yi][i], (float) val);
                    }
                }
            }
        }
        // Ri now holds position 0's state scores (last iteration was i == 0).
        for (int yi = 0; yi < numY; ++yi) {
            firstEntries[yi].clear();
            firstEntries[yi].valid = true;
            firstEntries[yi].add(this.winningLabel[yi][0], (float) this.Ri.get(yi));
        }
        return corrScore;
    }

    /**
     * Hook that writes one decoded label into {@code dataSeq}. This implementation ignores
     * {@code prevPos}; presumably it is there for segment-based subclasses.
     */
    protected void setSegment(DataSequence dataSeq, int prevPos, int pos, int label) {
        dataSeq.set_y(pos, label);
    }

    /**
     * Decodes {@code dataSeq} with the forward search and writes the best labels into it.
     *
     * @param dataSeq sequence to label in place
     * @param lambda  feature weights
     */
    public void bestLabelSequence(DataSequence dataSeq, double[] lambda) {
        this.viterbiSearch(dataSeq, lambda, false);
        this.assignLabels(dataSeq);
    }

    /**
     * Walks the best solution in {@link #finalSoln} through its {@code prevSoln} links and writes
     * each label into {@code dataSeq} via {@link #setSegment}. An empty sequence is left untouched.
     */
    void assignLabels(DataSequence dataSeq) {
        if (dataSeq.length() == 0) {
            return;
        }
        int pos = -1;
        // finalSoln's own entries are sentinels; the real solution starts at their prevSoln.
        for (Soln s = this.finalSoln.get(0).prevSoln; s != null; s = s.prevSoln) {
            pos = s.pos;
            this.setSegment(dataSeq, s.prevPos(), s.pos, s.label);
        }
        assert pos >= 0;
    }

    /**
     * Runs the forward Viterbi search and collects into {@link #finalSoln} the beams of every
     * label at the last position.
     *
     * @param dataSeq          sequence to decode
     * @param lambda           feature weights
     * @param calcCorrectScore when true, also sums the scores of the labels stored in
     *                         {@code dataSeq}
     * @return that summed score when requested, otherwise 0 (also 0 for an empty sequence)
     */
    public double viterbiSearch(DataSequence dataSeq, double[] lambda, boolean calcCorrectScore) {
        if (this.Mi == null) {
            this.allocateScratch(this.model.numY);
        }
        int length = dataSeq.length();
        if (length == 0) {
            // Nothing to decode; leave finalSoln empty instead of indexing position -1.
            this.finalSoln.clear();
            this.finalSoln.valid = true;
            return 0.0;
        }
        // Position 0 only ever receives one seed per label, so a beam of 1 suffices there.
        this.ensureColumns(length, 1);
        double corrScore = this.fillArray(dataSeq, lambda, calcCorrectScore);
        this.finalSoln.clear();
        this.finalSoln.valid = true;
        for (int yi = 0; yi < this.model.numY; ++yi) {
            this.finalSoln.add(this.winningLabel[yi][length - 1], 0.0f);
        }
        return corrScore;
    }

    /** @return number of non-cleared solutions in {@link #finalSoln} */
    int numSolutions() {
        return this.finalSoln.numSolns();
    }

    /** @return the k-th best solution (the predecessor of finalSoln's k-th sentinel entry) */
    Soln getBestSoln(int k) {
        return this.finalSoln.get(k).prevSoln;
    }

    /**
     * A fixed-capacity beam of {@link Soln} objects. Insertion places a new score before the
     * first stored solution it is at least as large as, so higher scores move toward index 0.
     */
    protected class Entry {
        public Soln[] solns;
        boolean valid = true;

        protected Entry() {
        }

        /**
         * @param beamsize capacity of the beam
         * @param id       label passed to every pre-allocated solution
         * @param pos      position passed to every pre-allocated solution
         */
        protected Entry(int beamsize, int id, int pos) {
            this.solns = new Soln[beamsize];
            for (int i = 0; i < this.solns.length; ++i) {
                this.solns[i] = this.newSoln(id, pos);
            }
        }

        /** Factory for the pre-allocated solutions; subclasses may supply a Soln subtype. */
        protected Soln newSoln(int label, int pos) {
            return new Soln(label, pos);
        }

        /** Clears every solution and marks the entry invalid until {@code valid} is set again. */
        protected void clear() {
            this.valid = false;
            for (int i = 0; i < this.solns.length; ++i) {
                this.solns[i].clear();
            }
        }

        protected int size() {
            return this.solns.length;
        }

        protected Soln get(int i) {
            return this.solns[i];
        }

        /** Shifts solutions i..size-2 down by one (dropping the last) and writes the new one at i. */
        protected void insert(int i, float score, Soln prev) {
            for (int k = this.size() - 1; k > i; --k) {
                this.solns[k].copy(this.solns[k - 1]);
            }
            this.solns[i].setPrevSoln(prev, score);
        }

        /**
         * Merges the solutions of {@code e}, each extended by {@code thisScore}, into this beam.
         * The insertion cursor only moves forward, so this relies on {@code e} being ordered
         * best-first. A null {@code e} behaves like {@link #add(float)}.
         */
        protected void add(Entry e, float thisScore) {
            assert this.valid;
            if (e == null) {
                this.add(thisScore);
                return;
            }
            int insertPos = 0;
            for (int i = 0; i < e.size() && insertPos < this.size(); ++i) {
                Soln prev = e.get(i);
                insertPos = this.findInsert(insertPos, prev.score + thisScore, prev);
            }
        }

        /**
         * Inserts {@code score} at the first index at or after {@code insertPos} whose stored
         * score it is not below.
         *
         * @return the index just after the insertion, or {@code size()} if nothing was inserted
         */
        protected int findInsert(int insertPos, float score, Soln prev) {
            while (insertPos < this.size()) {
                if (score >= this.get(insertPos).score) {
                    this.insert(insertPos, score, prev);
                    return insertPos + 1;
                }
                ++insertPos;
            }
            return insertPos;
        }

        /** Adds a solution with no predecessor. */
        protected void add(float thisScore) {
            this.findInsert(0, thisScore, null);
        }

        /** @return index of the first cleared solution, or {@code size()} if none is cleared */
        protected int numSolns() {
            for (int i = 0; i < this.solns.length; ++i) {
                if (this.solns[i].isClear()) {
                    return i;
                }
            }
            return this.size();
        }

        public void setValid() {
            this.valid = true;
        }

        /** Prints every slot of the beam on one line to standard output (debugging aid). */
        void print() {
            StringBuilder sb = new StringBuilder();
            for (int i = 0; i < this.size(); ++i) {
                Soln s = this.solns[i];
                sb.append('[').append(i).append(' ').append(s.score)
                        .append(" i:").append(s.pos).append(" y:").append(s.label).append(']');
            }
            System.out.println(sb.toString());
        }

        /** Describes the best solution and, when present, its predecessor. */
        @Override
        public String toString() {
            assert this.solns != null && this.solns[0] != null;
            Soln best = this.solns[0];
            StringBuilder sb = new StringBuilder("[");
            sb.append(best.pos).append(' ').append(best.label).append(' ').append(best.score);
            Soln prev = best.prevSoln;
            if (prev != null) {
                sb.append(" : ").append(prev.pos).append(' ').append(prev.label).append(' ').append(prev.score);
            }
            return sb.append(']').toString();
        }
    }
}

