/*
 * Decompiled with CFR 0.152.
 */
package iitb.CRF;

import cern.colt.function.DoubleDoubleFunction;
import cern.colt.function.DoubleFunction;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix1D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import iitb.CRF.CRF;
import iitb.CRF.CrfParams;
import iitb.CRF.DataIter;
import iitb.CRF.DataSequence;
import iitb.CRF.EdgeGenerator;
import iitb.CRF.Evaluator;
import iitb.CRF.Feature;
import iitb.CRF.FeatureGenCache;
import iitb.CRF.FeatureGenerator;
import iitb.CRF.RobustMath;
import iitb.CRF.Util;
import riso.numerical.LBFGS;

public class Trainer {
    /** Number of features; set from {@code featureGenerator.numFeatures()} in {@code init}. */
    protected int numF;
    /** Number of labels; copied from {@code model.numY} in {@code init}. */
    protected int numY;
    /** Gradient buffer of length {@code numF}; filled by computeFunctionGradient and handed to LBFGS. */
    double[] gradLogli;
    /** Work array of length {@code numF} passed as the {@code diag} argument of {@code LBFGS.lbfgs}. */
    double[] diag;
    /** Weight vector supplied by the caller of {@code train}; it is overwritten in place during training. */
    double[] lambda;
    /** Copied from {@code params.reuseM}; forwarded to computeLogMi so an already filled Mi_YY is not rebuilt. */
    protected boolean reuseM;
    /** Flag returned by computeLogMi; reset to false at the start of every computeFunctionGradient call. */
    protected boolean initMDone = false;
    /** Per-sequence accumulator of (unnormalised) expected feature values; kept in log-space by the "ll" variant. */
    protected double[] ExpF;
    /** Per-position scaling factors used by the non-log-space gradient computation. */
    double[] scale;
    /** Not referenced anywhere in the visible code of this class. */
    double[] rLogScale;
    /** Matrix filled by computeLogMi from features with {@code yprev >= 0}, indexed as (yprev, y). */
    protected DoubleMatrix2D Mi_YY;
    /** Vector filled by computeLogMi from features with {@code yprev < 0}, indexed by y. */
    protected DoubleMatrix1D Ri_Y;
    /** Forward vector for the previous position. */
    protected DoubleMatrix1D alpha_Y;
    /** Forward vector being built for the current position. */
    protected DoubleMatrix1D newAlpha_Y;
    /** Backward vectors, one per sequence position; allocated lazily while scanning sequences. */
    protected DoubleMatrix1D[] beta_Y;
    /** Scratch vector of length {@code numY}. */
    protected DoubleMatrix1D tmp_Y;
    /** Shared element-wise product function. */
    static MultFunc multFunc = new MultFunc();
    /** Shared element-wise sum function. */
    protected static SumFunc sumFunc = new SumFunc();
    /** Reusable "multiply by a constant" function; its {@code multiplicator} is set before each use. */
    MultSingle constMultiplier = new MultSingle();
    /** Iterator over the training sequences. */
    protected DataIter diter;
    /** Feature source; replaced by {@code featureGenCache} when the "cache" option is "true". */
    FeatureGenerator featureGenerator;
    /** Training parameters supplied to the constructor. */
    protected CrfParams params;
    /** Copied from {@code model.edgeGen}; passed through to the RobustMath multiply routines. */
    EdgeGenerator edgeGen;
    /** Number of completed LBFGS calls in the current {@code doTrain} run. */
    protected int icall;
    /** Optional callback consulted once per iteration; training stops when it returns false. */
    Evaluator evaluator = null;
    /** Caching wrapper around the feature generator, or null when caching is disabled. */
    FeatureGenCache featureGenCache;

    /**
     * Computes the Euclidean (L2) norm of a vector.
     *
     * @param ar the vector whose norm is wanted
     * @return the square root of the sum of the squared entries; 0.0 for an empty array
     */
    protected double norm(double[] ar) {
        double sumOfSquares = 0.0;
        for (double entry : ar) {
            sumOfSquares += entry * entry;
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * Creates a trainer that reads all of its settings from the given parameter object.
     *
     * @param p training parameters; stored by reference, not copied
     */
    public Trainer(CrfParams p) {
        params = p;
    }

    /**
     * Initialises the trainer for the given model and data, then runs the optimisation loop.
     *
     * @param model the CRF whose label count, edge generator and feature generator are used
     * @param data  iterator over the training sequences
     * @param l     weight vector; it is overwritten in place by the optimiser
     * @param eval  optional per-iteration callback (may be null); training stops when it returns false
     */
    public void train(CRF model, DataIter data, double[] l, Evaluator eval) {
        init(model, data, l);
        evaluator = eval;
        boolean verbose = params.debugLvl > 0;
        if (verbose) {
            Util.printDbg("Number of features :" + lambda.length);
        }
        doTrain();
    }

    /**
     * Supplies the value each weight is set to before optimisation starts.
     *
     * @return {@code params.initValue}
     */
    double getInitValue() {
        return params.initValue;
    }

    /**
     * Binds the trainer to a model, a data source and a weight vector, and allocates
     * the work buffers sized by the feature and label counts.
     *
     * <p>When the "cache" entry of {@code params.miscOptions} equals "true", the
     * feature generator is wrapped in a {@link FeatureGenCache}; otherwise the cache
     * reference is cleared.
     *
     * @param model source of the edge generator, label count and feature generator
     * @param data  iterator over the training sequences
     * @param l     weight vector, stored by reference
     */
    protected void init(CRF model, DataIter data, double[] l) {
        edgeGen = model.edgeGen;
        lambda = l;
        numY = model.numY;
        diter = data;
        featureGenerator = model.featureGenerator;

        // The feature count is read from the raw generator, before any cache wrapper is installed.
        numF = featureGenerator.numFeatures();
        gradLogli = new double[numF];
        diag = new double[numF];
        ExpF = new double[lambda.length];
        initMatrices();
        reuseM = params.reuseM;

        boolean cacheFeatures = "true".equals(params.miscOptions.getProperty("cache", "false"));
        featureGenCache = cacheFeatures ? new FeatureGenCache(featureGenerator) : null;
        if (featureGenCache != null) {
            featureGenerator = featureGenCache;
        }
    }

    /**
     * Allocates the per-label work matrices and vectors for the current {@code numY}.
     *
     * <p>The lazily grown {@code beta_Y} and {@code scale} buffers are dropped as well.
     * Their vectors were sized with the label count in force when they were created,
     * so keeping them across a re-initialisation with a different {@code numY} would
     * leave vectors of the wrong length. Both are rebuilt on demand by the gradient
     * computation.
     */
    void initMatrices() {
        this.Mi_YY = new DenseDoubleMatrix2D(this.numY, this.numY);
        this.Ri_Y = new DenseDoubleMatrix1D(this.numY);
        this.alpha_Y = new DenseDoubleMatrix1D(this.numY);
        this.newAlpha_Y = new DenseDoubleMatrix1D(this.numY);
        this.tmp_Y = new DenseDoubleMatrix1D(this.numY);
        // Force re-allocation with the current label count on the next sequence scan.
        this.beta_Y = null;
        this.scale = null;
    }

    /**
     * Runs the LBFGS optimisation loop over {@code lambda}.
     *
     * <p>Every weight is first set to {@link #getInitValue()}. Each iteration evaluates
     * the objective and its gradient, negates both (the value is computed as a
     * log-likelihood while {@code LBFGS.lbfgs} is handed its negation), consults the
     * optional evaluator, and performs one LBFGS step. The loop ends when the evaluator
     * returns false, when LBFGS sets {@code iflag[0]} to 0, when {@code icall} exceeds
     * {@code params.maxIters}, or when LBFGS throws (reported on stderr).
     */
    void doTrain() {
        final double xtol = 1.0E-16;
        icall = 0;
        final int[] iprint = {params.debugLvl - 2, params.debugLvl - 1};
        final int[] iflag = {0};

        // getInitValue() is called once per element, as it may be overridden.
        for (int k = 0; k < lambda.length; k++) {
            lambda[k] = getInitValue();
        }

        while (true) {
            double objective = -1.0 * computeFunctionGradient(lambda, gradLogli);
            for (int k = 0; k < lambda.length; k++) {
                gradLogli[k] *= -1.0;
            }
            if (evaluator != null && !evaluator.evaluate()) {
                break;
            }
            try {
                LBFGS.lbfgs(numF, params.mForHessian, lambda, objective, gradLogli, false, diag, iprint, params.epsForConvergence, xtol, iflag);
            } catch (LBFGS.ExceptionWithIflag e) {
                System.err.println("CRF: lbfgs failed.\n" + e);
                if (e.iflag == -1) {
                    System.err.println("Possible reasons could be: \n \t 1. Bug in the feature generation or data handling code\n\t 2. Not enough features to make observed feature value==expected value\n");
                }
                return;
            }
            icall++;
            if (iflag[0] == 0 || icall > params.maxIters) {
                break;
            }
        }
    }

    /**
     * Computes the penalised log-likelihood of the training data and its gradient.
     *
     * <p>Delegates to {@link #computeFunctionGradientLL} when {@code params.trainerType}
     * is "ll". Otherwise, for each training sequence, a backward pass fills
     * {@code beta_Y} (optionally rescaled per position), then a forward pass updates
     * {@code alpha_Y} while accumulating observed feature values into {@code grad} and
     * expected feature values into {@code ExpF}. Finally {@code ExpF / Zx} is subtracted
     * from the gradient.
     *
     * <p>Any exception raised during the scan is treated as fatal: the current work
     * matrices and the stack trace are printed and the JVM exits with a non-zero status.
     *
     * @param lambda current weight vector
     * @param grad   output array; overwritten with the gradient at {@code lambda}
     * @return the penalised log-likelihood at {@code lambda}
     */
    protected double computeFunctionGradient(double[] lambda, double[] grad) {
        this.initMDone = false;
        if (this.params.trainerType.equals("ll")) {
            return this.computeFunctionGradientLL(lambda, grad);
        }
        double logli = 0.0;
        try {
            // Quadratic penalty on the weights, scaled by params.invSigmaSquare.
            for (int f = 0; f < lambda.length; ++f) {
                grad[f] = -1.0 * lambda[f] * this.params.invSigmaSquare;
                logli -= lambda[f] * lambda[f] * this.params.invSigmaSquare / 2.0;
            }
            boolean doScaling = this.params.doScaling;
            this.diter.startScan();
            if (this.featureGenCache != null) {
                this.featureGenCache.startDataScan();
            }
            int numRecord = 0;
            while (this.diter.hasNext()) {
                DataSequence dataSeq = this.diter.next();
                if (this.featureGenCache != null) {
                    this.featureGenCache.nextDataIndex();
                }
                if (this.params.debugLvl > 1) {
                    Util.printDbg("Read next seq: " + numRecord + " logli " + logli);
                }
                this.alpha_Y.assign(1.0);
                for (int f = 0; f < lambda.length; ++f) {
                    this.ExpF[f] = 0.0;
                }
                // scale is checked on its own: the log-space variant can allocate or grow
                // beta_Y without ever allocating scale.
                if (this.beta_Y == null || this.scale == null
                        || this.beta_Y.length < dataSeq.length()
                        || this.scale.length < dataSeq.length()) {
                    this.beta_Y = new DenseDoubleMatrix1D[2 * dataSeq.length()];
                    for (int i = 0; i < this.beta_Y.length; ++i) {
                        this.beta_Y[i] = new DenseDoubleMatrix1D(this.numY);
                    }
                    this.scale = new double[2 * dataSeq.length()];
                }

                // Backward pass: fill beta_Y from the last position down to position 0.
                this.scale[dataSeq.length() - 1] = doScaling ? (double)this.numY : 1.0;
                this.beta_Y[dataSeq.length() - 1].assign(1.0 / this.scale[dataSeq.length() - 1]);
                for (int i = dataSeq.length() - 1; i > 0; --i) {
                    if (this.params.debugLvl > 2) {
                        Util.printDbg("Features fired");
                    }
                    this.initMDone = Trainer.computeLogMi(this.featureGenerator, lambda, dataSeq, i, this.Mi_YY, this.Ri_Y, true, this.reuseM, this.initMDone);
                    this.tmp_Y.assign(this.beta_Y[i]);
                    this.tmp_Y.assign(this.Ri_Y, multFunc);
                    RobustMath.Mult(this.Mi_YY, this.tmp_Y, this.beta_Y[i - 1], 1.0, 0.0, false, this.edgeGen);
                    this.scale[i - 1] = doScaling ? this.beta_Y[i - 1].zSum() : 1.0;
                    // Scale factors with magnitude below 1 are replaced by 1.
                    if (this.scale[i - 1] < 1.0 && this.scale[i - 1] > -1.0) {
                        this.scale[i - 1] = 1.0;
                    }
                    this.constMultiplier.multiplicator = 1.0 / this.scale[i - 1];
                    this.beta_Y[i - 1].assign(this.constMultiplier);
                }

                // Forward pass: update alpha and accumulate observed and expected feature values.
                double thisSeqLogli = 0.0;
                for (int i = 0; i < dataSeq.length(); ++i) {
                    this.initMDone = Trainer.computeLogMi(this.featureGenerator, lambda, dataSeq, i, this.Mi_YY, this.Ri_Y, true, this.reuseM, this.initMDone);
                    // computeLogMi consumed the feature scan; restart it for the loop below.
                    this.featureGenerator.startScanFeaturesAt(dataSeq, i);
                    if (i > 0) {
                        this.tmp_Y.assign(this.alpha_Y);
                        RobustMath.Mult(this.Mi_YY, this.tmp_Y, this.newAlpha_Y, 1.0, 0.0, true, this.edgeGen);
                        this.newAlpha_Y.assign(this.Ri_Y, multFunc);
                    } else {
                        this.newAlpha_Y.assign(this.Ri_Y);
                    }
                    while (this.featureGenerator.hasNext()) {
                        Feature feature = this.featureGenerator.next();
                        int fIdx = feature.index();
                        int yp = feature.y();
                        int yprev = feature.yprev();
                        float val = feature.value();
                        // Observed count: the feature agrees with the labels of this sequence.
                        if (dataSeq.y(i) == yp && (i - 1 >= 0 && yprev == dataSeq.y(i - 1) || yprev < 0)) {
                            grad[fIdx] += (double)val;
                            thisSeqLogli += (double)val * lambda[fIdx];
                        }
                        if (yprev < 0) {
                            this.ExpF[fIdx] += this.newAlpha_Y.get(yp) * (double)val * this.beta_Y[i].get(yp);
                        } else {
                            this.ExpF[fIdx] += this.alpha_Y.get(yprev) * this.Ri_Y.get(yp) * this.Mi_YY.get(yprev, yp) * (double)val * this.beta_Y[i].get(yp);
                        }
                    }
                    this.alpha_Y.assign(this.newAlpha_Y);
                    this.constMultiplier.multiplicator = 1.0 / this.scale[i];
                    this.alpha_Y.assign(this.constMultiplier);
                    if (this.params.debugLvl > 2) {
                        System.out.println("Alpha-i " + this.alpha_Y.toString());
                        System.out.println("Ri " + this.Ri_Y.toString());
                        System.out.println("Mi " + this.Mi_YY.toString());
                        System.out.println("Beta-i " + this.beta_Y[i].toString());
                    }
                }

                double Zx = this.alpha_Y.zSum();
                thisSeqLogli -= Trainer.log(Zx);
                // Remove the effect of the per-position scaling factors.
                for (int i = 0; i < dataSeq.length(); ++i) {
                    thisSeqLogli -= Trainer.log(this.scale[i]);
                }
                logli += thisSeqLogli;
                for (int f = 0; f < grad.length; ++f) {
                    grad[f] -= this.ExpF[f] / Zx;
                }
                if (this.params.debugLvl > 1) {
                    System.out.println("Sequence " + thisSeqLogli + " logli " + logli + " log(Zx) " + Math.log(Zx) + " Zx " + Zx);
                }
                ++numRecord;
            }
            if (this.params.debugLvl > 2) {
                for (int f = 0; f < lambda.length; ++f) {
                    System.out.print(lambda[f] + " ");
                }
                System.out.println(" :x");
                for (int f = 0; f < lambda.length; ++f) {
                    System.out.println(this.featureGenerator.featureName(f) + " " + grad[f] + " ");
                }
                System.out.println(" :g");
            }
            if (this.params.debugLvl > 0) {
                Util.printDbg("Iter " + this.icall + " log likelihood " + logli + " norm(grad logli) " + this.norm(grad) + " norm(x) " + this.norm(lambda));
            }
            if (this.icall == 0) {
                System.out.println("Number of training records" + numRecord);
            }
        }
        catch (Exception e) {
            System.out.println("Alpha-i " + this.alpha_Y.toString());
            System.out.println("Ri " + this.Ri_Y.toString());
            System.out.println("Mi " + this.Mi_YY.toString());
            e.printStackTrace();
            // Fatal error: exit with a non-zero status so callers do not see it as success.
            System.exit(1);
        }
        return logli;
    }

    /**
     * Convenience overload that always rebuilds {@code Mi_YY} (no matrix reuse) and
     * discards the returned flag.
     *
     * @param featureGen feature scan positioned by the caller
     * @param lambda     weight vector
     * @param Mi_YY      output matrix for features with a previous label
     * @param Ri_Y       output vector for features without a previous label
     * @param takeExp    whether the accumulated sums are exponentiated
     */
    static void computeLogMi(FeatureGenerator featureGen, double[] lambda, DoubleMatrix2D Mi_YY, DoubleMatrix1D Ri_Y, boolean takeExp) {
        final boolean reuseM = false;
        final boolean initMDone = false;
        computeLogMi(featureGen, lambda, Mi_YY, Ri_Y, takeExp, reuseM, initMDone);
    }

    /**
     * Accumulates the weighted feature values of the current feature scan into
     * {@code Ri_Y} (features with {@code yprev < 0}) and {@code Mi_YY} (all other
     * features), optionally exponentiating every entry afterwards.
     *
     * <p>When both {@code reuseM} and {@code initMDone} are true, {@code Mi_YY} is left
     * untouched and only {@code Ri_Y} is rebuilt.
     *
     * @param featureGen feature scan positioned by the caller; it is consumed to the end
     * @param lambda     weight vector indexed by feature index
     * @param Mi_YY      output matrix indexed as (yprev, y); may be null
     * @param Ri_Y       output vector indexed by y
     * @param takeExp    whether entries are passed through {@code expE} after accumulation
     * @param reuseM     whether an already initialised {@code Mi_YY} may be kept
     * @param initMDone  whether {@code Mi_YY} was initialised by an earlier call
     * @return true if {@code Mi_YY} was kept, or if at least one feature was written into it
     */
    static boolean computeLogMi(FeatureGenerator featureGen, double[] lambda, DoubleMatrix2D Mi_YY, DoubleMatrix1D Ri_Y, boolean takeExp, boolean reuseM, boolean initMDone) {
        final boolean keepEdgeMatrix = reuseM && initMDone;
        // A null alias means "do not touch the edge matrix during this call".
        final DoubleMatrix2D edgeMatrix = keepEdgeMatrix ? null : Mi_YY;
        boolean edgeMatrixReady = keepEdgeMatrix;

        if (edgeMatrix != null) {
            edgeMatrix.assign(0.0);
        }
        Ri_Y.assign(0.0);

        while (featureGen.hasNext()) {
            Feature feature = featureGen.next();
            int index = feature.index();
            int label = feature.y();
            int prevLabel = feature.yprev();
            float value = feature.value();
            if (prevLabel < 0) {
                Ri_Y.setQuick(label, Ri_Y.getQuick(label) + lambda[index] * (double)value);
            } else if (edgeMatrix != null) {
                edgeMatrix.setQuick(prevLabel, label, edgeMatrix.getQuick(prevLabel, label) + lambda[index] * (double)value);
                edgeMatrixReady = true;
            }
        }

        if (!takeExp) {
            return edgeMatrixReady;
        }
        for (int row = Ri_Y.size() - 1; row >= 0; --row) {
            Ri_Y.setQuick(row, expE(Ri_Y.getQuick(row)));
            if (edgeMatrix != null) {
                for (int col = edgeMatrix.columns() - 1; col >= 0; --col) {
                    edgeMatrix.setQuick(row, col, expE(edgeMatrix.getQuick(row, col)));
                }
            }
        }
        return edgeMatrixReady;
    }

    /**
     * Convenience overload for position {@code i} of {@code dataSeq} that always
     * rebuilds {@code Mi_YY} (no matrix reuse) and discards the returned flag.
     *
     * @param featureGen feature generator to scan
     * @param lambda     weight vector
     * @param dataSeq    sequence being processed
     * @param i          position within {@code dataSeq}
     * @param Mi_YY      output matrix for features with a previous label
     * @param Ri_Y       output vector for features without a previous label
     * @param takeExp    whether the accumulated sums are exponentiated
     */
    static void computeLogMi(FeatureGenerator featureGen, double[] lambda, DataSequence dataSeq, int i, DoubleMatrix2D Mi_YY, DoubleMatrix1D Ri_Y, boolean takeExp) {
        final boolean reuseM = false;
        final boolean initMDone = false;
        computeLogMi(featureGen, lambda, dataSeq, i, Mi_YY, Ri_Y, takeExp, reuseM, initMDone);
    }

    /**
     * Starts a feature scan at position {@code i} of {@code dataSeq} and accumulates
     * the resulting features into {@code Mi_YY} and {@code Ri_Y}.
     *
     * @param featureGen feature generator to scan
     * @param lambda     weight vector
     * @param dataSeq    sequence being processed
     * @param i          position within {@code dataSeq}
     * @param Mi_YY      output matrix for features with a previous label
     * @param Ri_Y       output vector for features without a previous label
     * @param takeExp    whether the accumulated sums are exponentiated
     * @param reuseM     whether an already initialised {@code Mi_YY} may be kept
     * @param initMDone  whether {@code Mi_YY} was initialised by an earlier call
     * @return the flag produced by the accumulation overload
     */
    static boolean computeLogMi(FeatureGenerator featureGen, double[] lambda, DataSequence dataSeq, int i, DoubleMatrix2D Mi_YY, DoubleMatrix1D Ri_Y, boolean takeExp, boolean reuseM, boolean initMDone) {
        featureGen.startScanFeaturesAt(dataSeq, i);
        boolean edgeMatrixReady = computeLogMi(featureGen, lambda, Mi_YY, Ri_Y, takeExp, reuseM, initMDone);
        return edgeMatrixReady;
    }

    /**
     * Log-space variant of the penalised log-likelihood and gradient computation.
     *
     * <p>The structure mirrors the non-log variant, but {@code alpha_Y}, {@code beta_Y}
     * and {@code ExpF} hold logarithms. The potentials from {@code computeLogMi} are
     * not exponentiated, and sums are combined with {@code RobustMath.logSumExp} and
     * {@code RobustMath.logMult}. No per-position scaling is used.
     *
     * <p>Any exception raised during the scan is treated as fatal: the current work
     * matrices and the stack trace are printed and the JVM exits with a non-zero status.
     *
     * @param lambda current weight vector
     * @param grad   output array; overwritten with the gradient at {@code lambda}
     * @return the penalised log-likelihood at {@code lambda}
     */
    protected double computeFunctionGradientLL(double[] lambda, double[] grad) {
        double logli = 0.0;
        try {
            // Quadratic penalty on the weights, scaled by params.invSigmaSquare.
            for (int f = 0; f < lambda.length; ++f) {
                grad[f] = -1.0 * lambda[f] * this.params.invSigmaSquare;
                logli -= lambda[f] * lambda[f] * this.params.invSigmaSquare / 2.0;
            }
            this.diter.startScan();
            if (this.featureGenCache != null) {
                this.featureGenCache.startDataScan();
            }
            int numRecord = 0;
            while (this.diter.hasNext()) {
                DataSequence dataSeq = this.diter.next();
                if (this.featureGenCache != null) {
                    this.featureGenCache.nextDataIndex();
                }
                if (this.params.debugLvl > 1) {
                    Util.printDbg("Read next seq: " + numRecord + " logli " + logli);
                }
                this.alpha_Y.assign(0.0);
                for (int f = 0; f < lambda.length; ++f) {
                    this.ExpF[f] = RobustMath.LOG0;
                }
                if (this.beta_Y == null || this.beta_Y.length < dataSeq.length()) {
                    this.beta_Y = new DenseDoubleMatrix1D[2 * dataSeq.length()];
                    for (int i = 0; i < this.beta_Y.length; ++i) {
                        this.beta_Y[i] = new DenseDoubleMatrix1D(this.numY);
                    }
                }

                // Backward pass in log-space: fill beta_Y from the last position down to 0.
                this.beta_Y[dataSeq.length() - 1].assign(0.0);
                for (int i = dataSeq.length() - 1; i > 0; --i) {
                    this.initMDone = Trainer.computeLogMi(this.featureGenerator, lambda, dataSeq, i, this.Mi_YY, this.Ri_Y, false, this.reuseM, this.initMDone);
                    this.tmp_Y.assign(this.beta_Y[i]);
                    this.tmp_Y.assign(this.Ri_Y, sumFunc);
                    RobustMath.logMult(this.Mi_YY, this.tmp_Y, this.beta_Y[i - 1], 1.0, 0.0, false, this.edgeGen);
                }

                // Forward pass: update alpha and accumulate observed and expected feature values.
                double thisSeqLogli = 0.0;
                for (int i = 0; i < dataSeq.length(); ++i) {
                    this.initMDone = Trainer.computeLogMi(this.featureGenerator, lambda, dataSeq, i, this.Mi_YY, this.Ri_Y, false, this.reuseM, this.initMDone);
                    // computeLogMi consumed the feature scan; restart it for the loop below.
                    this.featureGenerator.startScanFeaturesAt(dataSeq, i);
                    if (i > 0) {
                        this.tmp_Y.assign(this.alpha_Y);
                        RobustMath.logMult(this.Mi_YY, this.tmp_Y, this.newAlpha_Y, 1.0, 0.0, true, this.edgeGen);
                        this.newAlpha_Y.assign(this.Ri_Y, sumFunc);
                    } else {
                        this.newAlpha_Y.assign(this.Ri_Y);
                    }
                    while (this.featureGenerator.hasNext()) {
                        Feature feature = this.featureGenerator.next();
                        int fIdx = feature.index();
                        int yp = feature.y();
                        int yprev = feature.yprev();
                        float val = feature.value();
                        // Observed count: the feature agrees with the labels of this sequence.
                        if (dataSeq.y(i) == yp && (i - 1 >= 0 && yprev == dataSeq.y(i - 1) || yprev < 0)) {
                            grad[fIdx] += (double)val;
                            thisSeqLogli += (double)val * lambda[fIdx];
                            if (this.params.debugLvl > 2) {
                                System.out.println("Feature fired " + fIdx + " " + feature);
                            }
                        }
                        if (yprev < 0) {
                            this.ExpF[fIdx] = RobustMath.logSumExp(this.ExpF[fIdx], this.newAlpha_Y.get(yp) + RobustMath.log(val) + this.beta_Y[i].get(yp));
                        } else {
                            this.ExpF[fIdx] = RobustMath.logSumExp(this.ExpF[fIdx], this.alpha_Y.get(yprev) + this.Ri_Y.get(yp) + this.Mi_YY.get(yprev, yp) + RobustMath.log(val) + this.beta_Y[i].get(yp));
                        }
                    }
                    this.alpha_Y.assign(this.newAlpha_Y);
                    if (this.params.debugLvl > 2) {
                        System.out.println("Alpha-i " + this.alpha_Y.toString());
                        System.out.println("Ri " + this.Ri_Y.toString());
                        System.out.println("Mi " + this.Mi_YY.toString());
                        System.out.println("Beta-i " + this.beta_Y[i].toString());
                    }
                }

                double lZx = RobustMath.logSumExp(this.alpha_Y);
                thisSeqLogli -= lZx;
                logli += thisSeqLogli;
                for (int f = 0; f < grad.length; ++f) {
                    grad[f] -= RobustMath.exp(this.ExpF[f] - lZx);
                }
                if (this.params.debugLvl > 1) {
                    System.out.println("Sequence " + thisSeqLogli + " logli " + logli + " log(Zx) " + lZx + " Zx " + Math.exp(lZx));
                }
                ++numRecord;
            }
            if (this.params.debugLvl > 2) {
                for (int f = 0; f < lambda.length; ++f) {
                    System.out.print(lambda[f] + " ");
                }
                System.out.println(" :x");
                for (int f = 0; f < lambda.length; ++f) {
                    System.out.print(grad[f] + " ");
                }
                System.out.println(" :g");
            }
            if (this.params.debugLvl > 0) {
                Util.printDbg("Iteration " + this.icall + " log-likelihood " + logli + " norm(grad logli) " + this.norm(grad) + " norm(x) " + this.norm(lambda));
            }
        }
        catch (Exception e) {
            System.out.println("Alpha-i " + this.alpha_Y.toString());
            System.out.println("Ri " + this.Ri_Y.toString());
            System.out.println("Mi " + this.Mi_YY.toString());
            e.printStackTrace();
            // Fatal error: exit with a non-zero status so callers do not see it as success.
            System.exit(1);
        }
        return logli;
    }

    /**
     * Natural logarithm that never throws: when {@code logE} rejects the input, the
     * failure is printed (message to stdout, stack trace via {@code printStackTrace})
     * and the most negative finite double is returned instead.
     *
     * @param val value to take the logarithm of
     * @return {@code ln(val)}, or {@code -Double.MAX_VALUE} when the result is NaN or infinite
     */
    static double log(double val) {
        double result;
        try {
            result = logE(val);
        } catch (Exception e) {
            System.out.println(e.getMessage());
            e.printStackTrace();
            result = -Double.MAX_VALUE;
        }
        return result;
    }

    /**
     * Natural logarithm that rejects non-finite results.
     *
     * @param val value to take the logarithm of
     * @return {@code Math.log(val)} when that value is finite
     * @throws Exception when {@code Math.log(val)} is NaN or infinite
     */
    static double logE(double val) throws Exception {
        final double pr = Math.log(val);
        boolean finite = !Double.isNaN(pr) && !Double.isInfinite(pr);
        if (finite) {
            return pr;
        }
        throw new Exception("Overflow error when taking log of " + val);
    }

    /**
     * Exponential with a finite fallback: when {@code RobustMath.exp} yields NaN or an
     * infinite value, a diagnostic suggesting the log-space trainer is printed together
     * with a stack trace, and {@code Double.MAX_VALUE} is returned.
     *
     * @param val exponent
     * @return {@code RobustMath.exp(val)}, or {@code Double.MAX_VALUE} when that is not finite
     */
    static double expE(double val) {
        final double pr = RobustMath.exp(val);
        boolean finite = !Double.isNaN(pr) && !Double.isInfinite(pr);
        if (finite) {
            return pr;
        }
        // The exception is only a carrier for the message and stack trace; it is never thrown.
        Exception overflow = new Exception("Overflow error when taking exp of " + val + "\n Try running the CRF with the following option \"trainer ll\" to perform computations in the log-space.");
        System.out.println(overflow.getMessage());
        overflow.printStackTrace();
        return Double.MAX_VALUE;
    }

    /**
     * Exponential with a finite fallback: when {@code RobustMath.exp} yields NaN or an
     * infinite value, a diagnostic suggesting smaller feature values is printed together
     * with a stack trace, and {@code Double.MAX_VALUE} is returned.
     *
     * @param val exponent
     * @return {@code RobustMath.exp(val)}, or {@code Double.MAX_VALUE} when that is not finite
     */
    static double expLE(double val) {
        final double pr = RobustMath.exp(val);
        boolean finite = !Double.isNaN(pr) && !Double.isInfinite(pr);
        if (finite) {
            return pr;
        }
        // The exception is only a carrier for the message and stack trace; it is never thrown.
        Exception overflow = new Exception("Overflow error when taking exp of " + val + " you might need to redesign feature values so as to not reach such high values");
        System.out.println(overflow.getMessage());
        overflow.printStackTrace();
        return Double.MAX_VALUE;
    }

    /**
     * Unary function that multiplies its argument by a mutable constant.
     * Callers set {@code multiplicator} before handing the instance to a matrix operation.
     */
    class MultSingle implements DoubleFunction {
        /** Factor applied by {@link #apply(double)}; defaults to 1.0. */
        public double multiplicator = 1.0;

        /** Returns {@code a} scaled by the current {@code multiplicator}. */
        public double apply(double a) {
            return a * multiplicator;
        }
    }

    /** Binary function returning the sum of its two arguments. */
    static class SumFunc implements DoubleDoubleFunction {
        /** Returns {@code a + b}. */
        public double apply(double a, double b) {
            double sum = a + b;
            return sum;
        }
    }

    /** Binary function returning the product of its two arguments. */
    static class MultFunc implements DoubleDoubleFunction {
        /** Returns {@code a * b}. */
        public double apply(double a, double b) {
            double product = a * b;
            return product;
        }
    }
}

