/*
 * Decompiled with CFR 0.152.
 */
package iitb.CRF;

import iitb.CRF.CRF;
import iitb.CRF.CrfParams;
import iitb.CRF.DataIter;
import iitb.CRF.DataSequence;
import iitb.CRF.Evaluator;
import iitb.CRF.Feature;
import iitb.CRF.FeatureGenerator;
import iitb.CRF.Soln;
import iitb.CRF.Trainer;
import iitb.CRF.Util;
import iitb.CRF.Viterbi;
import java.util.ArrayList;
import java.util.List;
import java.util.Vector;

class CollinsTrainer
extends Trainer {
    int beamsize = 3;
    double beta = 0.05;
    boolean useUpdated = false;
    boolean voted = true;
    Soln[] solnPool;

    public CollinsTrainer(CrfParams p) {
        super(p);
        if (this.params.miscOptions.getProperty("beamSize") != null) {
            this.beamsize = Integer.parseInt(this.params.miscOptions.getProperty("beamSize"));
        }
        if (this.params.miscOptions.getProperty("beta") != null) {
            this.beta = Double.parseDouble(this.params.miscOptions.getProperty("beta"));
        }
        if (this.params.miscOptions.getProperty("UpdatedViterbi") != null) {
            this.useUpdated = this.params.miscOptions.getProperty("UpdatedViterbi").equalsIgnoreCase("true");
        }
        if (this.params.miscOptions.getProperty("voted") != null) {
            this.voted = this.params.miscOptions.getProperty("voted").equalsIgnoreCase("true");
        }
    }

    public void train(CRF model, DataIter data, double[] l, Evaluator eval) {
        this.init(model, data, l);
        double[] grad = this.gradLogli;
        Viterbi viterbiSearcher = model.getViterbi(this.beamsize);
        for (int i = 0; i < this.lambda.length; ++i) {
            grad[i] = 0.0;
            this.lambda[i] = 0.0;
        }
        Vector<Soln> viterbiS = new Vector<Soln>();
        for (int t = 0; t < this.params.maxIters; ++t) {
            int numErrs = 0;
            this.diter.startScan();
            int numRecord = 0;
            while (this.diter.hasNext()) {
                DataSequence dataSeq = this.diter.next();
                viterbiSearcher.viterbiSearch(dataSeq, this.useUpdated ? this.lambda : grad, false);
                Soln corrSoln = this.getCorrectSoln(dataSeq, this.useUpdated ? this.lambda : grad);
                double corrScore = corrSoln.score;
                int maxNum = viterbiSearcher.numSolutions();
                viterbiS.clear();
                for (int k = 0; k < maxNum; ++k) {
                    Soln viterbi = viterbiSearcher.getBestSoln(k);
                    if ((double)viterbi.score < corrScore * (1.0 - this.beta)) break;
                    if (this.isCorrect(viterbi, corrSoln)) continue;
                    viterbiS.add(viterbi);
                }
                if (viterbiS.size() > 0) {
                    while (corrSoln != null) {
                        Soln viterbi;
                        int s;
                        boolean differenceAtI = false;
                        for (s = 0; s < viterbiS.size(); ++s) {
                            viterbi = (Soln)viterbiS.elementAt(s);
                            if (viterbi != null && corrSoln.equals(viterbi)) continue;
                            differenceAtI = true;
                            break;
                        }
                        if (differenceAtI) {
                            ++numErrs;
                            this.updateWeights(corrSoln, 1.0, grad, dataSeq);
                            for (s = 0; s < viterbiS.size(); ++s) {
                                viterbi = (Soln)viterbiS.elementAt(s);
                                while (viterbi != null && viterbi.pos > corrSoln.prevPos()) {
                                    this.updateWeights(viterbi, -1.0 / (double)viterbiS.size(), grad, dataSeq);
                                    viterbi = viterbi.prevSoln;
                                }
                            }
                        }
                        for (s = 0; s < viterbiS.size(); ++s) {
                            viterbi = (Soln)viterbiS.elementAt(s);
                            while (viterbi != null && viterbi.pos > corrSoln.prevPos()) {
                                viterbi = viterbi.prevSoln;
                            }
                            viterbiS.set(s, viterbi);
                        }
                        corrSoln = corrSoln.prevSoln;
                    }
                }
                for (int f = 0; f < this.lambda.length; ++f) {
                    int n = f;
                    this.lambda[n] = this.lambda[n] + grad[f];
                }
                ++numRecord;
            }
            if (this.params.debugLvl > 0) {
                Util.printDbg("Iteration " + t + " numErrs " + numErrs);
            }
            if (numErrs == 0) break;
        }
    }

    boolean isCorrect(Soln viterbi, Soln corr) {
        while (viterbi != null && corr != null) {
            if (!viterbi.equals(corr)) {
                return false;
            }
            corr = corr.prevSoln;
            viterbi = viterbi.prevSoln;
        }
        return viterbi == null && corr == null;
    }

    int getSegmentEnd(DataSequence dataSeq, int ss) {
        return ss;
    }

    void startFeatureGenerator(FeatureGenerator _featureGenerator, DataSequence dataSeq, Soln soln) {
        _featureGenerator.startScanFeaturesAt(dataSeq, soln.pos);
    }

    void updateWeights(Soln soln, double wt, double[] grad, DataSequence dataSeq) {
        this.startFeatureGenerator(this.featureGenerator, dataSeq, soln);
        while (this.featureGenerator.hasNext()) {
            Feature feature = this.featureGenerator.next();
            int f = feature.index();
            int yp = feature.y();
            int yprev = feature.yprev();
            float val = feature.value();
            if (soln.label != yp || (soln.prevPos() < 0 || yprev != soln.prevSoln.label) && yprev >= 0) continue;
            int n = f;
            grad[n] = grad[n] + wt * (double)val;
        }
    }

    Soln getCorrectSoln(DataSequence dataSeq, double[] grad) {
        int se = 0;
        Soln prevSoln = null;
        if (this.solnPool == null || this.solnPool.length < dataSeq.length()) {
            this.solnPool = new Soln[dataSeq.length()];
            int i = 0;
            while (i < dataSeq.length()) {
                this.solnPool[i++] = new Soln(0, 0);
            }
        }
        int ss = 0;
        while (ss < dataSeq.length()) {
            se = this.getSegmentEnd(dataSeq, ss);
            Soln soln = this.solnPool[ss];
            soln.pos = se;
            soln.label = dataSeq.y(ss);
            soln.prevSoln = prevSoln;
            soln.score = prevSoln == null ? 0.0f : prevSoln.score;
            this.startFeatureGenerator(this.featureGenerator, dataSeq, soln);
            while (this.featureGenerator.hasNext()) {
                Feature feature = this.featureGenerator.next();
                int f = feature.index();
                int yp = feature.y();
                int yprev = feature.yprev();
                float val = feature.value();
                if (soln.label != yp || (soln.prevPos() < 0 || yprev != soln.prevSoln.label) && yprev >= 0) continue;
                soln.score = (float)((double)soln.score + grad[f] * (double)val);
            }
            prevSoln = soln;
            ss = se + 1;
        }
        return prevSoln;
    }
}

