/*
 * Decompiled with CFR 0.152.
 */
package iitb.CRF;

import cern.colt.matrix.impl.DenseDoubleMatrix1D;
import iitb.CRF.CrfParams;
import iitb.CRF.Feature;
import iitb.CRF.FeatureGenerator;
import iitb.CRF.FeatureGeneratorNested;
import iitb.CRF.RobustMath;
import iitb.CRF.SegmentDataSequence;
import iitb.CRF.Trainer;
import iitb.CRF.Util;

/**
 * Trainer whose objective/gradient computation scans features over spans of
 * positions instead of single positions.
 *
 * <p>For every end position {@code i} and every span length {@code ell} from 1 up to
 * {@code FeatureGeneratorNested.maxMemory()}, features are requested through
 * {@code startScanFeaturesAt(dataSeq, i - ell, i)}. The training sequences are cast
 * to {@link SegmentDataSequence}; records containing a labelled segment longer than
 * {@code maxMemory()} are skipped (they contribute nothing to the objective or gradient).
 *
 * <p>Two implementations are provided: {@link #computeFunctionGradient} works with
 * plain products, {@link #computeFunctionGradientLL} works with log-space values via
 * {@link RobustMath}. The former delegates to the latter when
 * {@code params.doScaling} is set.
 */
class NestedTrainer
extends Trainer {
    /**
     * Forward vectors. With {@code base = -1}, slot {@code i - base} holds the vector
     * for position {@code i}; slot 0 is the initial vector used before position 0.
     */
    DenseDoubleMatrix1D[] alpha_Y_Array;

    public NestedTrainer(CrfParams p) {
        super(p);
    }

    /**
     * Computes the regularised objective and its gradient using plain (non-log) products.
     *
     * <p>{@code grad} is overwritten: it starts at {@code -lambda[f] * invSigmaSquare},
     * gets the feature values observed on the labelled segmentation added, and gets
     * {@code ExpF[f] / Zx} subtracted for each valid record.
     *
     * <p>If {@code params.doScaling} is set, the call is forwarded to
     * {@link #computeFunctionGradientLL(double[], double[])}.
     *
     * <p>Any exception is printed and terminates the JVM with a non-zero exit status.
     *
     * @param lambda current parameter vector (read only)
     * @param grad   output gradient, same length as {@code lambda}
     * @return the accumulated objective value over all valid records
     */
    protected double computeFunctionGradient(double[] lambda, double[] grad) {
        if (this.params.doScaling) {
            return this.computeFunctionGradientLL(lambda, grad);
        }
        try {
            FeatureGeneratorNested featureGenNested = (FeatureGeneratorNested)this.featureGenerator;
            double logli = 0.0;
            // Regulariser contribution to the gradient and to the objective.
            for (int f = 0; f < lambda.length; ++f) {
                grad[f] = -1.0 * lambda[f] * this.params.invSigmaSquare;
                logli -= lambda[f] * lambda[f] * this.params.invSigmaSquare / 2.0;
            }
            // Scaling of the backward vectors is switched off in this code path, so every
            // scale[] entry read below is 1.0 and the rescaling steps are no-ops.
            boolean doScaling = false;
            this.diter.startScan();
            if (this.featureGenCache != null) {
                this.featureGenCache.startDataScan();
            }
            int numRecord = 0;
            while (this.diter.hasNext()) {
                int i;
                SegmentDataSequence dataSeq = (SegmentDataSequence)this.diter.next();
                if (this.featureGenCache != null) {
                    this.featureGenCache.nextDataIndex();
                }
                if (this.params.debugLvl > 1) {
                    Util.printDbg("Read next seq: " + numRecord + " logli " + logli);
                }
                for (int f = 0; f < lambda.length; ++f) {
                    this.ExpF[f] = 0.0;
                }
                // alpha_Y_Array is shifted by one slot so that index 0 can hold the start vector.
                int base = -1;
                if (this.alpha_Y_Array == null || this.alpha_Y_Array.length < dataSeq.length() - base) {
                    this.alpha_Y_Array = new DenseDoubleMatrix1D[2 * dataSeq.length()];
                    for (i = 0; i < this.alpha_Y_Array.length; ++i) {
                        this.alpha_Y_Array[i] = new DenseDoubleMatrix1D(this.numY);
                    }
                }
                if (this.beta_Y == null || this.beta_Y.length < dataSeq.length()) {
                    this.beta_Y = new DenseDoubleMatrix1D[2 * dataSeq.length()];
                    for (i = 0; i < this.beta_Y.length; ++i) {
                        this.beta_Y[i] = new DenseDoubleMatrix1D(this.numY);
                    }
                    this.scale = new double[2 * dataSeq.length()];
                } else if (this.scale == null || this.scale.length < dataSeq.length()) {
                    // beta_Y may already be large enough (the log-space method sizes it without
                    // touching scale), so scale has to be sized on its own.
                    this.scale = new double[2 * dataSeq.length()];
                }

                // Backward pass.
                this.beta_Y[dataSeq.length() - 1].assign(1.0);
                this.scale[dataSeq.length() - 1] = 1.0;
                for (i = dataSeq.length() - 2; i >= 0; --i) {
                    if (doScaling && i + featureGenNested.maxMemory() < dataSeq.length()) {
                        int iL = i + featureGenNested.maxMemory();
                        this.scale[iL] = this.beta_Y[iL].zSum();
                        this.constMultiplier.multiplicator = 1.0 / this.scale[iL];
                        for (int j = i + 1; j <= iL; ++j) {
                            this.beta_Y[j].assign(this.constMultiplier);
                        }
                    }
                    this.beta_Y[i].assign(0.0);
                    this.scale[i] = 1.0;
                    for (int ell = 1; ell <= featureGenNested.maxMemory() && i + ell < dataSeq.length(); ++ell) {
                        featureGenNested.startScanFeaturesAt(dataSeq, i, i + ell);
                        this.initMDone = NestedTrainer.computeLogMi((FeatureGenerator)featureGenNested, lambda, this.Mi_YY, this.Ri_Y, true, this.reuseM, this.initMDone);
                        this.tmp_Y.assign(this.beta_Y[i + ell]);
                        this.tmp_Y.assign(this.Ri_Y, multFunc);
                        // Accumulates into beta_Y[i] (beta coefficient 1.0) across all span lengths.
                        this.Mi_YY.zMult(this.tmp_Y, this.beta_Y[i], 1.0, 1.0, false);
                    }
                }

                // Forward pass, interleaved with gradient accumulation.
                double thisSeqLogli = 0.0;
                this.alpha_Y_Array[0].assign(1.0);
                int segmentStart = 0;
                int segmentEnd = -1;
                boolean invalid = false;
                for (int i2 = 0; i2 < dataSeq.length(); ++i2) {
                    if (segmentEnd < i2) {
                        segmentStart = i2;
                        segmentEnd = dataSeq.getSegmentEnd(i2);
                    }
                    if (segmentEnd - segmentStart + 1 > featureGenNested.maxMemory()) {
                        // NOTE(review): the log-space method tests icall == 0 here instead of
                        // icall <= 1 -- confirm which threshold is intended.
                        if (this.icall <= 1) {
                            System.out.println("Ignoring record with segment length greater than maxMemory " + numRecord);
                        }
                        invalid = true;
                        break;
                    }
                    this.alpha_Y_Array[i2 - base].assign(0.0);
                    // Kept in double: scale[] holds doubles, and a float accumulator would lose
                    // precision and range.
                    double scaleProduct = 1.0;
                    for (int j = i2 - featureGenNested.maxMemory() - base; j <= i2 - 1; ++j) {
                        if (j < 0) {
                            continue;
                        }
                        scaleProduct *= this.scale[j];
                    }
                    for (int ell = 1; ell <= featureGenNested.maxMemory() && i2 - ell >= base; ++ell) {
                        featureGenNested.startScanFeaturesAt(dataSeq, i2 - ell, i2);
                        this.initMDone = NestedTrainer.computeLogMi((FeatureGenerator)featureGenNested, lambda, this.Mi_YY, this.Ri_Y, true, this.reuseM, this.initMDone);
                        // Restart the scan so the features can be walked a second time below.
                        featureGenNested.startScanFeaturesAt(dataSeq, i2 - ell, i2);
                        // True when the span (i2-ell, i2] is exactly the current labelled segment.
                        boolean isSegment = i2 - ell + 1 == segmentStart && i2 == segmentEnd;
                        while (featureGenNested.hasNext()) {
                            Feature feature = featureGenNested.next();
                            int f = feature.index();
                            int yp = feature.y();
                            int yprev = feature.yprev();
                            float val = feature.value();
                            boolean allEllMatch = isSegment && dataSeq.y(i2) == yp;
                            // Feature fires on the labelled segmentation: add it to the gradient
                            // and to this record's score.
                            if (allEllMatch && (i2 - ell >= 0 && yprev == dataSeq.y(i2 - ell) || yprev < 0)) {
                                grad[f] += (double)val;
                                thisSeqLogli += (double)val * lambda[f];
                            }
                            if (yprev < 0) {
                                // A negative yprev is expanded over every previous label.
                                for (int yAny = 0; yAny < this.Mi_YY.rows(); ++yAny) {
                                    this.ExpF[f] += this.alpha_Y_Array[i2 - ell - base].get(yAny) * this.Ri_Y.get(yp) * this.Mi_YY.get(yAny, yp) * (double)val * this.beta_Y[i2].get(yp) / scaleProduct;
                                }
                            } else {
                                this.ExpF[f] += this.alpha_Y_Array[i2 - ell - base].get(yprev) * this.Ri_Y.get(yp) * this.Mi_YY.get(yprev, yp) * (double)val * this.beta_Y[i2].get(yp) / scaleProduct;
                            }
                        }
                        this.Mi_YY.zMult(this.alpha_Y_Array[i2 - ell - base], this.tmp_Y, 1.0, 0.0, true);
                        this.tmp_Y.assign(this.Ri_Y, multFunc);
                        this.alpha_Y_Array[i2 - base].assign(this.tmp_Y, sumFunc);
                    }
                    if (i2 - base - featureGenNested.maxMemory() >= 0) {
                        int iL = i2 - base - featureGenNested.maxMemory();
                        this.constMultiplier.multiplicator = 1.0 / this.scale[iL];
                        for (int j = iL; j <= i2 - base; ++j) {
                            this.alpha_Y_Array[j].assign(this.constMultiplier);
                        }
                    }
                    if (this.params.debugLvl > 2) {
                        System.out.println("Alpha-i " + this.alpha_Y_Array[i2 - base].toString());
                        System.out.println("Ri " + this.Ri_Y.toString());
                        System.out.println("Mi " + this.Mi_YY.toString());
                        System.out.println("Beta-i " + this.beta_Y[i2].toString());
                    }
                    if (this.params.debugLvl > 1) {
                        System.out.println(" pos " + i2 + " " + thisSeqLogli);
                    }
                }
                if (!invalid) {
                    double Zx = this.alpha_Y_Array[dataSeq.length() - 1 - base].zSum();
                    thisSeqLogli -= NestedTrainer.log(Zx);
                    for (int i3 = 0; i3 < dataSeq.length() - base - featureGenNested.maxMemory(); ++i3) {
                        thisSeqLogli -= NestedTrainer.log(this.scale[i3]);
                    }
                    logli += thisSeqLogli;
                    for (int f = 0; f < grad.length; ++f) {
                        grad[f] -= this.ExpF[f] / Zx;
                    }
                    if (this.params.debugLvl > 1) {
                        System.out.println("Sequence " + thisSeqLogli + " " + logli + " " + Zx);
                        System.out.println("Last Alpha-i " + this.alpha_Y_Array[dataSeq.length() - 1 - base].toString());
                    }
                }
                ++numRecord;
            }
            if (this.params.debugLvl > 2) {
                for (int f = 0; f < lambda.length; ++f) {
                    System.out.print(lambda[f] + " ");
                }
                System.out.println(" :x");
                for (int f = 0; f < lambda.length; ++f) {
                    System.out.print(grad[f] + " ");
                }
                System.out.println(" :g");
            }
            if (this.params.debugLvl > 0) {
                Util.printDbg("Iter " + this.icall + " loglikelihood " + logli + " gnorm " + this.norm(grad) + " xnorm " + this.norm(lambda));
            }
            return logli;
        }
        catch (Exception e) {
            e.printStackTrace();
            // Fatal: report failure to the calling process instead of a success status.
            System.exit(1);
            return 0.0;
        }
    }

    /**
     * Log-space counterpart of {@link #computeFunctionGradient(double[], double[])}.
     *
     * <p>Forward/backward vectors and {@code ExpF} hold log values
     * ({@code RobustMath.LOG0} is the starting value for sums); they are combined with
     * {@code RobustMath.logSumExp} / {@code RobustMath.logMult}. {@code grad} is
     * overwritten in the same way as in the non-log method, with
     * {@code RobustMath.exp(ExpF[f] - lZx)} subtracted per valid record.
     *
     * <p>Any exception is printed and terminates the JVM with a non-zero exit status.
     *
     * @param lambda current parameter vector (read only)
     * @param grad   output gradient, same length as {@code lambda}
     * @return the accumulated objective value over all valid records
     */
    protected double computeFunctionGradientLL(double[] lambda, double[] grad) {
        try {
            FeatureGeneratorNested featureGenNested = (FeatureGeneratorNested)this.featureGenerator;
            double logli = 0.0;
            // Regulariser contribution to the gradient and to the objective.
            for (int f = 0; f < lambda.length; ++f) {
                grad[f] = -1.0 * lambda[f] * this.params.invSigmaSquare;
                logli -= lambda[f] * lambda[f] * this.params.invSigmaSquare / 2.0;
            }
            this.diter.startScan();
            if (this.featureGenCache != null) {
                this.featureGenCache.startDataScan();
            }
            int numRecord = 0;
            while (this.diter.hasNext()) {
                int i;
                SegmentDataSequence dataSeq = (SegmentDataSequence)this.diter.next();
                if (this.featureGenCache != null) {
                    this.featureGenCache.nextDataIndex();
                }
                if (this.params.debugLvl > 1) {
                    Util.printDbg("Read next seq: " + numRecord + " logli " + logli);
                }
                for (int f = 0; f < lambda.length; ++f) {
                    this.ExpF[f] = RobustMath.LOG0;
                }
                // alpha_Y_Array is shifted by one slot so that index 0 can hold the start vector.
                int base = -1;
                if (this.alpha_Y_Array == null || this.alpha_Y_Array.length < dataSeq.length() - base) {
                    this.alpha_Y_Array = new DenseDoubleMatrix1D[2 * dataSeq.length()];
                    for (i = 0; i < this.alpha_Y_Array.length; ++i) {
                        this.alpha_Y_Array[i] = new DenseDoubleMatrix1D(this.numY);
                    }
                }
                if (this.beta_Y == null || this.beta_Y.length < dataSeq.length()) {
                    this.beta_Y = new DenseDoubleMatrix1D[2 * dataSeq.length()];
                    for (i = 0; i < this.beta_Y.length; ++i) {
                        this.beta_Y[i] = new DenseDoubleMatrix1D(this.numY);
                    }
                }

                // Backward pass in log space.
                this.beta_Y[dataSeq.length() - 1].assign(0.0);
                for (i = dataSeq.length() - 2; i >= 0; --i) {
                    this.beta_Y[i].assign(RobustMath.LOG0);
                    for (int ell = 1; ell <= featureGenNested.maxMemory() && i + ell < dataSeq.length(); ++ell) {
                        featureGenNested.startScanFeaturesAt(dataSeq, i, i + ell);
                        this.initMDone = NestedTrainer.computeLogMi((FeatureGenerator)featureGenNested, lambda, this.Mi_YY, this.Ri_Y, false, this.reuseM, this.initMDone);
                        this.tmp_Y.assign(this.beta_Y[i + ell]);
                        this.tmp_Y.assign(this.Ri_Y, sumFunc);
                        RobustMath.logMult(this.Mi_YY, this.tmp_Y, this.beta_Y[i], 1.0, 1.0, false, this.edgeGen);
                    }
                }

                // Forward pass, interleaved with gradient accumulation.
                double thisSeqLogli = 0.0;
                this.alpha_Y_Array[0].assign(0.0);
                int segmentStart = 0;
                int segmentEnd = -1;
                boolean invalid = false;
                for (int i2 = 0; i2 < dataSeq.length(); ++i2) {
                    if (segmentEnd < i2) {
                        segmentStart = i2;
                        segmentEnd = dataSeq.getSegmentEnd(i2);
                    }
                    if (segmentEnd - segmentStart + 1 > featureGenNested.maxMemory()) {
                        // NOTE(review): the non-log method tests icall <= 1 here instead of
                        // icall == 0 -- confirm which threshold is intended.
                        if (this.icall == 0) {
                            System.out.println("Ignoring record with segment length greater than maxMemory " + numRecord);
                        }
                        invalid = true;
                        break;
                    }
                    this.alpha_Y_Array[i2 - base].assign(RobustMath.LOG0);
                    for (int ell = 1; ell <= featureGenNested.maxMemory() && i2 - ell >= base; ++ell) {
                        featureGenNested.startScanFeaturesAt(dataSeq, i2 - ell, i2);
                        this.initMDone = NestedTrainer.computeLogMi((FeatureGenerator)featureGenNested, lambda, this.Mi_YY, this.Ri_Y, false, this.reuseM, this.initMDone);
                        // Restart the scan so the features can be walked a second time below.
                        featureGenNested.startScanFeaturesAt(dataSeq, i2 - ell, i2);
                        // True when the span (i2-ell, i2] is exactly the current labelled segment.
                        boolean isSegment = i2 - ell + 1 == segmentStart && i2 == segmentEnd;
                        while (featureGenNested.hasNext()) {
                            Feature feature = featureGenNested.next();
                            int f = feature.index();
                            int yp = feature.y();
                            int yprev = feature.yprev();
                            float val = feature.value();
                            boolean allEllMatch = isSegment && dataSeq.y(i2) == yp;
                            // Feature fires on the labelled segmentation: add it to the gradient
                            // and to this record's score.
                            if (allEllMatch && (i2 - ell >= 0 && yprev == dataSeq.y(i2 - ell) || yprev < 0)) {
                                grad[f] += (double)val;
                                thisSeqLogli += (double)val * lambda[f];
                            }
                            if (yprev < 0 && i2 - ell >= 0) {
                                // A negative yprev is expanded over every previous label.
                                for (int yAny = 0; yAny < this.Mi_YY.rows(); ++yAny) {
                                    this.ExpF[f] = RobustMath.logSumExp(this.ExpF[f], this.alpha_Y_Array[i2 - ell - base].get(yAny) + this.Ri_Y.get(yp) + this.Mi_YY.get(yAny, yp) + RobustMath.log(val) + this.beta_Y[i2].get(yp));
                                }
                            } else if (i2 - ell < 0) {
                                // Span starts before position 0: no alpha or Mi term is used.
                                this.ExpF[f] = RobustMath.logSumExp(this.ExpF[f], this.Ri_Y.get(yp) + RobustMath.log(val) + this.beta_Y[i2].get(yp));
                            } else {
                                this.ExpF[f] = RobustMath.logSumExp(this.ExpF[f], this.alpha_Y_Array[i2 - ell - base].get(yprev) + this.Ri_Y.get(yp) + this.Mi_YY.get(yprev, yp) + RobustMath.log(val) + this.beta_Y[i2].get(yp));
                            }
                        }
                        if (i2 - ell >= 0) {
                            RobustMath.logMult(this.Mi_YY, this.alpha_Y_Array[i2 - ell - base], this.tmp_Y, 1.0, 0.0, true, this.edgeGen);
                            this.tmp_Y.assign(this.Ri_Y, sumFunc);
                            RobustMath.logSumExp(this.alpha_Y_Array[i2 - base], this.tmp_Y);
                        } else {
                            RobustMath.logSumExp(this.alpha_Y_Array[i2 - base], this.Ri_Y);
                        }
                    }
                    if (this.params.debugLvl > 2) {
                        System.out.println("Alpha-i " + this.alpha_Y_Array[i2 - base].toString());
                        System.out.println("Ri " + this.Ri_Y.toString());
                        System.out.println("Mi " + this.Mi_YY.toString());
                        System.out.println("Beta-i " + this.beta_Y[i2].toString());
                    }
                    if (this.params.debugLvl > 1) {
                        System.out.println(" pos " + i2 + " " + thisSeqLogli);
                    }
                }
                if (!invalid) {
                    double lZx = RobustMath.logSumExp(this.alpha_Y_Array[dataSeq.length() - 1 - base]);
                    thisSeqLogli -= lZx;
                    logli += thisSeqLogli;
                    for (int f = 0; f < grad.length; ++f) {
                        grad[f] -= RobustMath.exp(this.ExpF[f] - lZx);
                    }
                    if (this.params.debugLvl > 1) {
                        System.out.println("Sequence " + thisSeqLogli + " " + logli + " " + Math.exp(lZx));
                        System.out.println("Last Alpha-i " + this.alpha_Y_Array[dataSeq.length() - 1 - base].toString());
                    }
                }
                ++numRecord;
            }
            if (this.params.debugLvl > 2) {
                for (int f = 0; f < lambda.length; ++f) {
                    System.out.print(lambda[f] + " ");
                }
                System.out.println(" :x");
                for (int f = 0; f < lambda.length; ++f) {
                    System.out.print(grad[f] + " ");
                }
                System.out.println(" :g");
            }
            if (this.params.debugLvl > 0) {
                Util.printDbg("Iter " + this.icall + " loglikelihood " + logli + " gnorm " + this.norm(grad) + " xnorm " + this.norm(lambda));
            }
            return logli;
        }
        catch (Exception e) {
            e.printStackTrace();
            // Fatal: report failure to the calling process instead of a success status.
            System.exit(1);
            return 0.0;
        }
    }
}

