/*
 * Decompiled with CFR 0.152.
 */
package iitb.CRF;

import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import iitb.CRF.CRF;
import iitb.CRF.CandSegDataSequence;
import iitb.CRF.Constraint;
import iitb.CRF.CrfParams;
import iitb.CRF.DataIter;
import iitb.CRF.Feature;
import iitb.CRF.FeatureGeneratorNested;
import iitb.CRF.RestrictConstraint;
import iitb.CRF.RobustMath;
import iitb.CRF.SparseTrainer;
import iitb.CRF.Util;
import java.util.Iterator;

/**
 * Trainer for segment-level CRF models. Every training sequence is a {@link CandSegDataSequence}
 * that enumerates the candidate segments ending at each position; the penalized log-likelihood
 * and its gradient are accumulated by a backward and a forward pass over those candidate
 * segments. The trainer marks itself as a log-space trainer ({@code logTrainer = true}).
 */
public class SegmentTrainer extends SparseTrainer {
    /**
     * Forward (alpha) log-score vectors, shifted by one position: the entry for segmentations
     * that end at position {@code p} is stored at index {@code p + 1}, and index 0 is the empty
     * prefix (set to {@link #allZeroVector} for every sequence).
     */
    protected DoubleMatrix1D[] alpha_Y_Array;
    /**
     * Cached results of {@code Mi_YY.zMult(alpha_Y_Array[k], ..., transposeA = true)}, only used
     * when {@code reuseM} is set; indexed like {@link #alpha_Y_Array}.
     */
    protected DoubleMatrix1D[] alpha_Y_ArrayM;
    /** Marks which entries of {@link #alpha_Y_ArrayM} have been computed for the current sequence. */
    protected boolean[] initAlphaMDone;
    /**
     * Log-space vector of length {@code numY} filled with 0.0; used as the alpha of the empty
     * prefix and as the beta of the last position of every sequence.
     */
    protected DoubleMatrix1D allZeroVector;

    /**
     * Creates a segment trainer and marks it as a log-space trainer.
     *
     * @param p training parameters, passed unchanged to {@link SparseTrainer}
     */
    public SegmentTrainer(CrfParams p) {
        super(p);
        this.logTrainer = true;
    }

    /**
     * Initializes the superclass state and builds {@link #allZeroVector}.
     *
     * @param model the model being trained
     * @param data  the training data
     * @param l     the initial weight vector
     */
    protected void init(CRF model, DataIter data, double[] l) {
        super.init(model, data, l);
        this.allZeroVector = this.newLogDoubleMatrix1D(this.numY);
        this.allZeroVector.assign(0.0);
    }

    /**
     * Computes the penalized conditional log-likelihood of all training sequences and its
     * gradient.
     *
     * <p>The penalty term is {@code -sum_f lambda[f]^2 * invSigmaSquare / 2}. For each sequence,
     * a backward pass over the candidate segments fills {@code beta_Y}. A forward pass then fills
     * {@link #alpha_Y_Array} while accumulating the log expected feature values in {@code ExpF}.
     * It also accumulates the feature values of the labelled training segments. Empty sequences
     * contribute nothing and are skipped.
     *
     * <p>If any exception escapes, its stack trace is printed and the JVM exits with status 1.
     *
     * @param lambda current feature weights
     * @param grad   output array, overwritten with the gradient of the returned value
     * @return the penalized log-likelihood summed over all sequences of the data iterator
     */
    protected double computeFunctionGradient(double[] lambda, double[] grad) {
        try {
            FeatureGeneratorNested featureGenNested = (FeatureGeneratorNested) this.featureGenerator;
            double logli = 0.0;
            // Penalty term and its contribution to the gradient.
            for (int f = 0; f < lambda.length; ++f) {
                grad[f] = -1.0 * lambda[f] * this.params.invSigmaSquare;
                logli -= lambda[f] * lambda[f] * this.params.invSigmaSquare / 2.0;
            }
            this.diter.startScan();
            this.initMDone = false;
            if (this.featureGenCache != null) {
                this.featureGenCache.startDataScan();
            }
            int numRecord = 0;
            while (this.diter.hasNext()) {
                CandSegDataSequence dataSeq = (CandSegDataSequence) this.diter.next();
                if (this.featureGenCache != null) {
                    this.featureGenCache.nextDataIndex();
                }
                if (this.params.debugLvl > 1) {
                    Util.printDbg("Read next seq: " + numRecord + " logli " + logli);
                }
                int dataSize = dataSeq.length();
                if (dataSize == 0) {
                    // Nothing to label: contributes 0 to the log-likelihood and the gradient.
                    // (Indexing beta_Y[dataSize - 1] below would otherwise fail.)
                    ++numRecord;
                    continue;
                }
                for (int f = 0; f < lambda.length; ++f) {
                    this.ExpF[f] = RobustMath.LOG0;
                }
                // alpha_Y_Array is indexed by (position - base): position p lives at p + 1.
                int base = -1;
                if (this.alpha_Y_Array == null || this.alpha_Y_Array.length < dataSize - base) {
                    this.allocateAlphaBeta(2 * dataSize + 1);
                }
                if (this.reuseM) {
                    for (int i = dataSize; i >= 0; --i) {
                        this.initAlphaMDone[i] = false;
                    }
                }

                // ---- Backward pass over candidate segments ----
                // The last beta temporarily points at the shared all-zero vector; restored below.
                DoubleMatrix1D oldBeta = this.beta_Y[dataSize - 1];
                this.beta_Y[dataSize - 1] = this.allZeroVector;
                for (int i = dataSize - 2; i >= 0; --i) {
                    this.beta_Y[i].assign(RobustMath.LOG0);
                }
                for (int segEnd = dataSize - 1; segEnd >= 0; --segEnd) {
                    for (int nc = dataSeq.numCandSegmentsEndingAt(segEnd) - 1; nc >= 0; --nc) {
                        int segStart = dataSeq.candSegmentStart(segEnd, nc);
                        int ell = segEnd - segStart + 1;
                        int i = segStart - 1;
                        if (i < 0) {
                            // A segment starting at position 0 has no preceding beta to update.
                            continue;
                        }
                        this.initMDone = SegmentTrainer.computeLogMi(dataSeq, i, i + ell, featureGenNested, lambda, this.Mi_YY, this.Ri_Y, this.reuseM, this.initMDone);
                        this.tmp_Y.assign(this.Ri_Y);
                        if (i + ell < dataSize - 1) {
                            this.tmp_Y.assign(this.beta_Y[i + ell], sumFunc);
                        }
                        if (this.reuseM) {
                            // The transition product is applied once per position, below.
                            this.beta_Y[i].assign(this.tmp_Y, RobustMath.logSumExpFunc);
                        } else {
                            this.Mi_YY.zMult(this.tmp_Y, this.beta_Y[i], 1.0, 1.0, false);
                        }
                    }
                    if (this.reuseM && segEnd - 1 >= 0) {
                        // All segments starting at segEnd have now been summed into beta_Y[segEnd - 1].
                        this.tmp_Y.assign(this.beta_Y[segEnd - 1]);
                        this.Mi_YY.zMult(this.tmp_Y, this.beta_Y[segEnd - 1], 1.0, 0.0, false);
                    }
                }

                // ---- Forward pass, expected and empirical feature values ----
                double thisSeqLogli = 0.0;
                this.alpha_Y_Array[0] = this.allZeroVector;
                int trainingSegmentEnd = -1;
                int trainingSegmentStart = 0;
                boolean trainingSegmentFound = true;
                boolean noneFired = true;
                for (int segEnd = 0; segEnd < dataSize; ++segEnd) {
                    this.alpha_Y_Array[segEnd - base].assign(RobustMath.LOG0);
                    if (trainingSegmentEnd < segEnd) {
                        // Entering the next labelled segment: report if the previous one was never a candidate.
                        if (!trainingSegmentFound && noneFired) {
                            System.out.println("Error: Training segment (" + trainingSegmentStart + " " + trainingSegmentEnd + ") not found amongst candidate segments");
                        }
                        trainingSegmentFound = false;
                        trainingSegmentStart = segEnd;
                        trainingSegmentEnd = dataSeq.getSegmentEnd(segEnd);
                    }
                    for (int nc = dataSeq.numCandSegmentsEndingAt(segEnd) - 1; nc >= 0; --nc) {
                        int ell = segEnd - dataSeq.candSegmentStart(segEnd, nc) + 1;
                        // Position just before the candidate segment (-1 when it starts the sequence).
                        int prevEnd = segEnd - ell;
                        this.initMDone = SegmentTrainer.computeLogMi(dataSeq, prevEnd, segEnd, featureGenNested, lambda, this.Mi_YY, this.Ri_Y, this.reuseM, this.initMDone);
                        if (prevEnd >= 0) {
                            if (this.reuseM) {
                                if (!this.initAlphaMDone[prevEnd - base]) {
                                    this.alpha_Y_ArrayM[prevEnd - base].assign(RobustMath.LOG0);
                                    this.Mi_YY.zMult(this.alpha_Y_Array[prevEnd - base], this.alpha_Y_ArrayM[prevEnd - base], 1.0, 0.0, true);
                                    this.initAlphaMDone[prevEnd - base] = true;
                                }
                                this.newAlpha_Y.assign(this.alpha_Y_ArrayM[prevEnd - base]);
                            } else {
                                this.Mi_YY.zMult(this.alpha_Y_Array[prevEnd - base], this.newAlpha_Y, 1.0, 0.0, true);
                            }
                            this.newAlpha_Y.assign(this.Ri_Y, sumFunc);
                        } else {
                            this.newAlpha_Y.assign(this.Ri_Y);
                        }
                        this.alpha_Y_Array[segEnd - base].assign(this.newAlpha_Y, RobustMath.logSumExpFunc);

                        featureGenNested.startScanFeaturesAt(dataSeq, prevEnd, segEnd);
                        while (featureGenNested.hasNext()) {
                            Feature feature = featureGenNested.next();
                            int f = feature.index();
                            int yp = feature.y();
                            int yprev = feature.yprev();
                            float val = feature.value();
                            if (dataSeq.holdsInTrainingData(feature, prevEnd, segEnd)) {
                                // Empirical feature value of the labelled segmentation.
                                grad[f] += (double) val;
                                thisSeqLogli += (double) val * lambda[f];
                                noneFired = false;
                                if (this.params.debugLvl > 2) {
                                    System.out.println("Feature fired " + f + " " + feature);
                                }
                            }
                            if (yprev < 0) {
                                this.ExpF[f] = RobustMath.logSumExp(this.ExpF[f], this.newAlpha_Y.get(yp) + RobustMath.log(val) + this.beta_Y[segEnd].get(yp));
                            } else {
                                this.ExpF[f] = RobustMath.logSumExp(this.ExpF[f], this.alpha_Y_Array[prevEnd - base].get(yprev) + this.Ri_Y.get(yp) + this.Mi_YY.get(yprev, yp) + RobustMath.log(val) + this.beta_Y[segEnd].get(yp));
                            }
                        }

                        if (segEnd == trainingSegmentEnd && prevEnd + 1 == trainingSegmentStart) {
                            trainingSegmentFound = true;
                            double val1 = this.Ri_Y.get(dataSeq.y(trainingSegmentEnd));
                            double val2 = 0.0;
                            if (trainingSegmentStart > 0) {
                                val2 = this.Mi_YY.get(dataSeq.y(trainingSegmentStart - 1), dataSeq.y(trainingSegmentEnd));
                            }
                            if (val1 == RobustMath.LOG0 || val2 == RobustMath.LOG0) {
                                // Only read a previous label when one exists (y(-1) is out of range).
                                String yprevLabel = trainingSegmentStart > 0 ? String.valueOf(dataSeq.y(trainingSegmentStart - 1)) : "none";
                                System.out.println("Error: training labels not covered in generated features " + val1 + " " + val2 + " yprev " + yprevLabel + " y " + dataSeq.y(trainingSegmentEnd));
                                System.out.println(dataSeq);
                                featureGenNested.startScanFeaturesAt(dataSeq, prevEnd, segEnd);
                                while (featureGenNested.hasNext()) {
                                    Feature feature = featureGenNested.next();
                                    System.out.println(feature + " " + feature.yprev() + " " + feature.y());
                                }
                            }
                        }
                    }
                    if (this.params.debugLvl > 2) {
                        System.out.println("Alpha-i " + this.alpha_Y_Array[segEnd - base].toString());
                        System.out.println("Ri " + this.Ri_Y.toString());
                        System.out.println("Mi " + this.Mi_YY.toString());
                        System.out.println("Beta-i " + this.beta_Y[segEnd].toString());
                    }
                }
                // The in-loop check only runs when a new labelled segment begins; check the last one here.
                if (!trainingSegmentFound && noneFired) {
                    System.out.println("Error: Training segment (" + trainingSegmentStart + " " + trainingSegmentEnd + ") not found amongst candidate segments");
                }

                double lZx = this.alpha_Y_Array[dataSize - 1 - base].zSum();
                thisSeqLogli -= lZx;
                logli += thisSeqLogli;
                // Gradient = empirical values (added above) - expected values.
                for (int f = 0; f < grad.length; ++f) {
                    grad[f] -= SegmentTrainer.expLE(this.ExpF[f] - lZx);
                }
                if (noneFired) {
                    System.out.println("WARNING: no features fired in the training set");
                }
                if (thisSeqLogli > 0.0) {
                    System.out.println("ERROR: something is wrong Pr(y|x) > 1! for sequence " + numRecord);
                    System.out.println(dataSeq);
                }
                if (this.params.debugLvl > 1 || thisSeqLogli > 0.0) {
                    System.out.println("Sequence likelihood " + thisSeqLogli + " lZx " + lZx + " Zx " + Math.exp(lZx));
                    System.out.println("Last Alpha-i " + this.alpha_Y_Array[dataSize - 1 - base].toString());
                }
                this.beta_Y[dataSize - 1] = oldBeta;
                ++numRecord;
            }
            if (this.params.debugLvl > 2) {
                for (int f = 0; f < lambda.length; ++f) {
                    System.out.print(lambda[f] + " ");
                }
                System.out.println(" :x");
                for (int f = 0; f < lambda.length; ++f) {
                    System.out.println(f + " " + featureGenNested.featureName(f) + " " + grad[f] + " ");
                }
                System.out.println(" :g");
            }
            if (this.params.debugLvl > 0) {
                if (this.icall == 0) {
                    Util.printDbg("Number of training records " + numRecord);
                }
                Util.printDbg("Iter " + this.icall + " loglikelihood " + logli + " gnorm " + this.norm(grad) + " xnorm " + this.norm(lambda));
            }
            return logli;
        }
        catch (Exception e) {
            // Training cannot continue from here; terminate with a failure status, not 0 (success).
            e.printStackTrace();
            System.exit(1);
            return 0.0;
        }
    }

    /**
     * Replaces the alpha, beta and cached alpha-product arrays with arrays of {@code newSize}
     * fresh vectors of length {@code numY}, and {@link #initAlphaMDone} with a new all-false
     * array of the same size.
     *
     * @param newSize number of entries in each array
     */
    protected void allocateAlphaBeta(int newSize) {
        this.alpha_Y_Array = this.newLogVectorArray(newSize);
        this.beta_Y = this.newLogVectorArray(newSize);
        this.alpha_Y_ArrayM = this.newLogVectorArray(newSize);
        this.initAlphaMDone = new boolean[newSize];
    }

    /** Returns an array of {@code size} vectors, each created by {@code newLogDoubleMatrix1D(numY)}. */
    private DoubleMatrix1D[] newLogVectorArray(int size) {
        DoubleMatrix1D[] vectors = new DoubleMatrix1D[size];
        for (int i = 0; i < size; ++i) {
            vectors[i] = this.newLogDoubleMatrix1D(this.numY);
        }
        return vectors;
    }

    /**
     * Starts the feature scan for the span ({@code prevPos}, {@code pos}) and resets {@code Mi}
     * and {@code Ri}. Within this class it is called with {@code prevPos} = segment start - 1
     * and {@code pos} = segment end.
     *
     * <p>When {@code dataSeq} returns no constraint iterator, every entry is set to 0.0.
     * Otherwise every entry starts at {@link RobustMath#LOG0}, and only the labels / label pairs
     * enumerated by type-3 constraints (handled as {@link RestrictConstraint}) are set to 0.0.
     *
     * @param dataSeq          the sequence being scored
     * @param prevPos          position before the span
     * @param pos              last position of the span
     * @param featureGenNested feature generator whose scan is started over the span
     * @param lambda           not read by this method
     * @param Mi               transition scores to reset, or {@code null} to leave transitions alone
     * @param Ri               per-label scores to reset
     * @return the value given to entries not explicitly set: {@code LOG0} when constrained,
     *         0.0 otherwise
     */
    public static double initLogMi(CandSegDataSequence dataSeq, int prevPos, int pos, FeatureGeneratorNested featureGenNested, double[] lambda, DoubleMatrix2D Mi, DoubleMatrix1D Ri) {
        featureGenNested.startScanFeaturesAt(dataSeq, prevPos, pos);
        Iterator<?> constraints = dataSeq.constraints(prevPos, pos);
        if (constraints == null) {
            // Unconstrained: every label and label pair starts at 0.0.
            if (Mi != null) {
                Mi.assign(0.0);
            }
            Ri.assign(0.0);
            return 0.0;
        }
        // Constrained: start at LOG0 everywhere and open only what the constraints allow.
        if (Mi != null) {
            Mi.assign(RobustMath.LOG0);
        }
        Ri.assign(RobustMath.LOG0);
        while (constraints.hasNext()) {
            Constraint constraint = (Constraint) constraints.next();
            // NOTE(review): 3 is a constraint type inlined by the decompiler; the cast below
            // relies on type 3 being a RestrictConstraint. Replace with the named constant.
            if (constraint.type() != 3) {
                continue;
            }
            RestrictConstraint cons = (RestrictConstraint) constraint;
            cons.startScan();
            while (cons.hasNext()) {
                cons.advance();
                int y = cons.y();
                int yprev = cons.yprev();
                if (yprev < 0) {
                    Ri.set(y, 0.0);
                } else if (Mi != null) {
                    Mi.set(yprev, y, 0.0);
                }
            }
        }
        return RobustMath.LOG0;
    }

    /**
     * Fills {@code Ri}, and {@code Mi} unless a reused transition matrix has already been
     * computed, for the span ({@code prevPos}, {@code pos}).
     *
     * @param reuseM    whether {@code Mi} is computed once and then reused
     * @param initMDone whether the reusable {@code Mi} has already been computed
     * @return the updated {@code initMDone} flag: it becomes true once {@code Mi} has been
     *         handled for a span with {@code prevPos >= 0} in reuse mode
     */
    static boolean computeLogMi(CandSegDataSequence dataSeq, int prevPos, int pos, FeatureGeneratorNested featureGenNested, double[] lambda, DoubleMatrix2D Mi, DoubleMatrix1D Ri, boolean reuseM, boolean initMDone) {
        DoubleMatrix2D transitions = (reuseM && initMDone) ? null : Mi;
        SegmentTrainer.computeLogMi(dataSeq, prevPos, pos, featureGenNested, lambda, transitions, Ri);
        return initMDone || (reuseM && prevPos >= 0);
    }

    /**
     * Resets {@code Mi}/{@code Ri} via {@link #initLogMi} and then delegates to
     * {@code SparseTrainer.computeLogMiInitDone} with the resulting default value.
     */
    static void computeLogMi(CandSegDataSequence dataSeq, int prevPos, int pos, FeatureGeneratorNested featureGenNested, double[] lambda, DoubleMatrix2D Mi, DoubleMatrix1D Ri) {
        double defaultValue = SegmentTrainer.initLogMi(dataSeq, prevPos, pos, featureGenNested, lambda, Mi, Ri);
        SparseTrainer.computeLogMiInitDone(featureGenNested, lambda, Mi, Ri, defaultValue);
    }
}

