/*
 * Decompiled with CFR 0.152.
 */
package iitb.CRF;

import cern.colt.function.DoubleFunction;
import cern.colt.function.IntDoubleFunction;
import cern.colt.function.IntIntDoubleFunction;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.SparseDoubleMatrix1D;
import cern.colt.matrix.impl.SparseDoubleMatrix2D;
import iitb.CRF.CRF;
import iitb.CRF.CrfParams;
import iitb.CRF.DataIter;
import iitb.CRF.DataSequence;
import iitb.CRF.Evaluator;
import iitb.CRF.Feature;
import iitb.CRF.FeatureGenerator;
import iitb.CRF.LogDenseDoubleMatrix1D;
import iitb.CRF.LogDenseDoubleMatrix2D;
import iitb.CRF.LogSparseDoubleMatrix1D;
import iitb.CRF.LogSparseDoubleMatrix2D;
import iitb.CRF.RobustMath;
import iitb.CRF.Trainer;
import iitb.CRF.Util;

/**
 * CRF trainer that computes the penalized conditional log-likelihood of the training data and
 * its gradient with a backward pass followed by a forward pass over every training sequence.
 *
 * <p>Two numeric paths are provided, chosen by {@code params.trainerType}:
 * <ul>
 *   <li>{@code "ll"}: forward/backward vectors are kept in log space using the
 *       {@code Log*DoubleMatrix} classes (sparse or dense, depending on the {@code "sparse"}
 *       entry of {@code params.miscOptions}). This path is {@link #computeFunctionGradientLL}.</li>
 *   <li>anything else: vectors are kept in probability space in colt sparse matrices,
 *       optionally rescaled per position when {@code params.doScaling} is set.</li>
 * </ul>
 *
 * <p>If either gradient computation throws, the current matrices and the stack trace are
 * printed and the JVM is terminated with a non-zero exit status.
 */
public class SparseTrainer extends Trainer {
    /** True when {@code params.trainerType} equals {@code "ll"}; selects log-space matrices in {@link #initMatrices()}. */
    boolean logTrainer;
    /** Element-wise {@code Math.exp}, applied to the log-potentials built by {@link #computeLogMi}. */
    static ExpFunc expFunc = new ExpFunc();
    /** Indexed 1-D variant of {@code Math.exp} (ignores the index). */
    static IntDoubleFunction expFunc1D = new ExpFunc1D();
    /** Indexed 2-D variant of {@code Math.exp} (ignores both indices). */
    static IntIntDoubleFunction expFunc2D = new ExpFunc2D();

    /** Exit status used when a gradient computation fails; non-zero so callers see the failure. */
    private static final int FAILURE_EXIT_STATUS = 1;

    /**
     * Creates a log-space label vector of length {@code numY}.
     *
     * @param numY number of labels
     * @return a {@link LogSparseDoubleMatrix1D} if the {@code "sparse"} misc option is
     *     {@code "true"} (case-insensitive), otherwise a {@link LogDenseDoubleMatrix1D}
     */
    protected DoubleMatrix1D newLogDoubleMatrix1D(int numY) {
        if (this.useSparseLogMatrices()) {
            return new LogSparseDoubleMatrix1D(numY);
        }
        return new LogDenseDoubleMatrix1D(numY);
    }

    /**
     * Creates a log-space {@code numR x numC} matrix.
     *
     * @param numR number of rows
     * @param numC number of columns
     * @return a {@link LogSparseDoubleMatrix2D} if the {@code "sparse"} misc option is
     *     {@code "true"} (case-insensitive), otherwise a {@link LogDenseDoubleMatrix2D}
     */
    protected DoubleMatrix2D newLogDoubleMatrix2D(int numR, int numC) {
        if (this.useSparseLogMatrices()) {
            return new LogSparseDoubleMatrix2D(numR, numC);
        }
        return new LogDenseDoubleMatrix2D(numR, numC);
    }

    /** Reads the {@code "sparse"} misc option (default {@code "false"}). */
    private boolean useSparseLogMatrices() {
        return Boolean.parseBoolean(this.params.miscOptions.getProperty("sparse", "false"));
    }

    /**
     * @param p training parameters; {@code p.trainerType} of {@code "ll"} selects the log-space path
     */
    public SparseTrainer(CrfParams p) {
        super(p);
        this.params = p;
        this.logTrainer = this.params.trainerType.equals("ll");
    }

    /**
     * Initializes the trainer for {@code model}/{@code data} with weights {@code l}, records the
     * evaluator, and runs the inherited training loop ({@code doTrain}).
     */
    public void train(CRF model, DataIter data, double[] l, Evaluator eval) {
        this.init(model, data, l);
        this.evaluator = eval;
        if (this.params.debugLvl > 0) {
            Util.printDbg("Number of features :" + this.lambda.length);
        }
        this.doTrain();
    }

    /**
     * Allocates the per-position work matrices: colt sparse matrices for the probability-space
     * path, or log-space matrices (see {@link #newLogDoubleMatrix1D}) when {@link #logTrainer}.
     */
    void initMatrices() {
        if (!this.logTrainer) {
            this.Mi_YY = new SparseDoubleMatrix2D(this.numY, this.numY);
            this.Ri_Y = new SparseDoubleMatrix1D(this.numY);
            this.alpha_Y = new SparseDoubleMatrix1D(this.numY);
            this.newAlpha_Y = new SparseDoubleMatrix1D(this.numY);
            this.tmp_Y = new SparseDoubleMatrix1D(this.numY);
        } else {
            this.Mi_YY = this.newLogDoubleMatrix2D(this.numY, this.numY);
            this.Ri_Y = this.newLogDoubleMatrix1D(this.numY);
            this.alpha_Y = this.newLogDoubleMatrix1D(this.numY);
            this.newAlpha_Y = this.newLogDoubleMatrix1D(this.numY);
            this.tmp_Y = this.newLogDoubleMatrix1D(this.numY);
        }
    }

    /**
     * Computes the penalized log-likelihood of all training sequences and its gradient.
     *
     * <p>The objective is {@code sum_seq log P(y|x) - sum_f lambda[f]^2 * invSigmaSquare / 2}.
     * Per feature, the gradient is
     * {@code -lambda[f] * invSigmaSquare + observedCount[f] - expectedCount[f]}.
     * Delegates to {@link #computeFunctionGradientLL} when {@code trainerType} is {@code "ll"}.
     *
     * @param lambda current feature weights
     * @param grad output array, overwritten with the gradient (same length as {@code lambda})
     * @return the penalized log-likelihood
     */
    protected double computeFunctionGradient(double[] lambda, double[] grad) {
        if (this.params.trainerType.equals("ll")) {
            return this.computeFunctionGradientLL(lambda, grad);
        }
        double logli = 0.0;
        try {
            // Quadratic penalty on the weights.
            for (int f = 0; f < lambda.length; ++f) {
                grad[f] = -1.0 * lambda[f] * this.params.invSigmaSquare;
                logli -= lambda[f] * lambda[f] * this.params.invSigmaSquare / 2.0;
            }
            boolean doScaling = this.params.doScaling;
            this.diter.startScan();
            if (this.featureGenCache != null) {
                this.featureGenCache.startDataScan();
            }
            int numRecord = 0;
            while (this.diter.hasNext()) {
                DataSequence dataSeq = this.diter.next();
                if (this.featureGenCache != null) {
                    this.featureGenCache.nextDataIndex();
                }
                if (this.params.debugLvl > 1) {
                    Util.printDbg("Read next seq: " + numRecord + " logli " + logli);
                }
                this.alpha_Y.assign(1.0);
                for (int f = 0; f < lambda.length; ++f) {
                    this.ExpF[f] = 0.0;
                }
                // Grow the backward vectors (and the matching scale array) on demand.
                if (this.beta_Y == null || this.beta_Y.length < dataSeq.length()) {
                    this.beta_Y = new DoubleMatrix1D[2 * dataSeq.length()];
                    for (int i = 0; i < this.beta_Y.length; ++i) {
                        this.beta_Y[i] = new SparseDoubleMatrix1D(this.numY);
                    }
                    this.scale = new double[2 * dataSeq.length()];
                }

                // Backward pass: beta[i-1] = Mi * (beta[i] .* Ri), normalized by scale[i-1].
                this.scale[dataSeq.length() - 1] = doScaling ? (double)this.numY : 1.0;
                this.beta_Y[dataSeq.length() - 1].assign(1.0 / this.scale[dataSeq.length() - 1]);
                for (int i = dataSeq.length() - 1; i > 0; --i) {
                    if (this.params.debugLvl > 2) {
                        Util.printDbg("Features fired");
                    }
                    SparseTrainer.computeMi(this.featureGenerator, lambda, dataSeq, i, this.Mi_YY, this.Ri_Y);
                    this.tmp_Y.assign(this.beta_Y[i]);
                    this.tmp_Y.assign(this.Ri_Y, multFunc);
                    this.Mi_YY.zMult(this.tmp_Y, this.beta_Y[i - 1]);
                    this.scale[i - 1] = doScaling ? this.beta_Y[i - 1].zSum() : 1.0;
                    // Scale factors with magnitude below 1 are replaced by 1 (presumably so the
                    // division below never amplifies beta).
                    if (this.scale[i - 1] < 1.0 && this.scale[i - 1] > -1.0) {
                        this.scale[i - 1] = 1.0;
                    }
                    this.constMultiplier.multiplicator = 1.0 / this.scale[i - 1];
                    this.beta_Y[i - 1].assign(this.constMultiplier);
                }

                // Forward pass: accumulate observed counts into grad and unnormalized expected
                // counts into ExpF, scaling alpha by the same factors used for beta.
                double thisSeqLogli = 0.0;
                for (int i = 0; i < dataSeq.length(); ++i) {
                    SparseTrainer.computeMi(this.featureGenerator, lambda, dataSeq, i, this.Mi_YY, this.Ri_Y);
                    // computeMi consumed the feature scan; restart it for the count accumulation.
                    this.featureGenerator.startScanFeaturesAt(dataSeq, i);
                    if (i > 0) {
                        this.Mi_YY.zMult(this.alpha_Y, this.newAlpha_Y, 1.0, 0.0, true);
                        this.newAlpha_Y.assign(this.Ri_Y, multFunc);
                    } else {
                        this.newAlpha_Y.assign(this.Ri_Y);
                    }
                    while (this.featureGenerator.hasNext()) {
                        Feature feature = this.featureGenerator.next();
                        int f = feature.index();
                        int yp = feature.y();
                        int yprev = feature.yprev();
                        float val = feature.value();
                        // Feature fires on the gold labeling: node features (yprev < 0) need only
                        // the current label to match; edge features also need the previous label.
                        if (dataSeq.y(i) == yp && ((i - 1 >= 0 && yprev == dataSeq.y(i - 1)) || yprev < 0)) {
                            grad[f] += (double)val;
                            thisSeqLogli += (double)val * lambda[f];
                        }
                        if (yprev < 0) {
                            this.ExpF[f] += this.newAlpha_Y.get(yp) * (double)val * this.beta_Y[i].get(yp);
                        } else {
                            this.ExpF[f] += this.alpha_Y.get(yprev) * this.Ri_Y.get(yp) * this.Mi_YY.get(yprev, yp) * (double)val * this.beta_Y[i].get(yp);
                        }
                    }
                    this.alpha_Y.assign(this.newAlpha_Y);
                    this.constMultiplier.multiplicator = 1.0 / this.scale[i];
                    this.alpha_Y.assign(this.constMultiplier);
                    if (this.params.debugLvl > 2) {
                        this.printPositionState(i);
                    }
                }

                // log P(y|x) = score(y) - log Zx - sum_i log scale[i] (undoing the rescaling).
                double Zx = this.alpha_Y.zSum();
                thisSeqLogli -= SparseTrainer.log(Zx);
                for (int i = 0; i < dataSeq.length(); ++i) {
                    thisSeqLogli -= SparseTrainer.log(this.scale[i]);
                }
                if (thisSeqLogli > 0.0) {
                    System.out.println("This is shady: something is wrong Pr(y|x) > 1!");
                }
                logli += thisSeqLogli;
                for (int f = 0; f < grad.length; ++f) {
                    grad[f] -= this.ExpF[f] / Zx;
                }
                if (this.params.debugLvl > 1) {
                    System.out.println("Sequence " + thisSeqLogli + " " + logli);
                }
                ++numRecord;
            }
            if (this.params.debugLvl > 2) {
                this.printDebugVectors(lambda, grad);
            }
            if (this.params.debugLvl > 0) {
                Util.printDbg("Iter " + this.icall + " log likelihood " + logli + " norm(grad logli) " + this.norm(grad) + " norm(x) " + this.norm(lambda));
            }
        }
        catch (Exception e) {
            this.dumpStateAndExit(e);
        }
        return logli;
    }

    /**
     * Resets {@code Mi_YY} and {@code Ri_Y} to 0 and accumulates {@code lambda[f] * value} for
     * every feature remaining in {@code featureGen}'s current scan (log-potentials, not exponentiated).
     */
    static void computeLogMi(FeatureGenerator featureGen, double[] lambda, DoubleMatrix2D Mi_YY, DoubleMatrix1D Ri_Y) {
        double DEFAULT_VALUE = 0.0;
        Mi_YY.assign(DEFAULT_VALUE);
        Ri_Y.assign(DEFAULT_VALUE);
        SparseTrainer.computeLogMiInitDone(featureGen, lambda, Mi_YY, Ri_Y, DEFAULT_VALUE);
    }

    /**
     * Drains {@code featureGen}, adding {@code lambda[f] * value} to {@code Ri_Y[y]} for node
     * features ({@code yprev == -1}) and to {@code Mi_YY[yprev][y]} for edge features.
     *
     * <p>Entries still equal to {@code DEFAULT_VALUE} are treated as 0 before accumulating.
     * Edge features are skipped when {@code Mi_YY} is null. When an edge feature touches a
     * default-valued {@code Mi_YY} entry, {@code Ri_Y[y]} is also set to 0 if it is still at
     * the default. This presumably marks label {@code y} as present when the default is not 0.
     *
     * @param DEFAULT_VALUE the value the caller pre-filled the matrices with
     */
    static void computeLogMiInitDone(FeatureGenerator featureGen, double[] lambda, DoubleMatrix2D Mi_YY, DoubleMatrix1D Ri_Y, double DEFAULT_VALUE) {
        while (featureGen.hasNext()) {
            double oldVal;
            Feature feature = featureGen.next();
            int f = feature.index();
            int yp = feature.y();
            int yprev = feature.yprev();
            float val = feature.value();
            if (yprev == -1) {
                oldVal = Ri_Y.get(yp);
                if (oldVal == DEFAULT_VALUE) {
                    oldVal = 0.0;
                }
                Ri_Y.set(yp, oldVal + lambda[f] * (double)val);
                continue;
            }
            if (Mi_YY == null) continue;
            oldVal = Mi_YY.get(yprev, yp);
            if (oldVal == DEFAULT_VALUE) {
                oldVal = 0.0;
                if (Ri_Y.get(yp) == DEFAULT_VALUE) {
                    Ri_Y.set(yp, 0.0);
                }
            }
            Mi_YY.set(yprev, yp, oldVal + lambda[f] * (double)val);
        }
    }

    /**
     * Builds the position-{@code i} potentials in probability space: log-potentials from
     * {@link #computeLogMi} followed by element-wise {@code exp}.
     */
    static void computeMi(FeatureGenerator featureGen, double[] lambda, DataSequence dataSeq, int i, DoubleMatrix2D Mi_YY, DoubleMatrix1D Ri_Y) {
        featureGen.startScanFeaturesAt(dataSeq, i);
        SparseTrainer.computeLogMi(featureGen, lambda, Mi_YY, Ri_Y);
        Ri_Y.assign(expFunc);
        Mi_YY.assign(expFunc);
    }

    /** Builds the position-{@code i} log-potentials (no exponentiation). */
    static void computeLogMi(FeatureGenerator featureGen, double[] lambda, DataSequence dataSeq, int i, DoubleMatrix2D Mi_YY, DoubleMatrix1D Ri_Y) {
        featureGen.startScanFeaturesAt(dataSeq, i);
        SparseTrainer.computeLogMi(featureGen, lambda, Mi_YY, Ri_Y);
    }

    /**
     * Log-space variant of {@link #computeFunctionGradient}. The objective and gradient are the same.
     *
     * <p>alpha, beta and the expected counts are kept as logarithms. Combining them relies on
     * the {@code Log*DoubleMatrix} implementations of {@code zMult}/{@code zSum}, presumably
     * log-sum-exp based (TODO confirm in those classes), and on {@link RobustMath#logSumExp}.
     *
     * @param lambda current feature weights
     * @param grad output array, overwritten with the gradient
     * @return the penalized log-likelihood
     */
    protected double computeFunctionGradientLL(double[] lambda, double[] grad) {
        double logli = 0.0;
        try {
            // Quadratic penalty on the weights.
            for (int f = 0; f < lambda.length; ++f) {
                grad[f] = -1.0 * lambda[f] * this.params.invSigmaSquare;
                logli -= lambda[f] * lambda[f] * this.params.invSigmaSquare / 2.0;
            }
            this.diter.startScan();
            if (this.featureGenCache != null) {
                this.featureGenCache.startDataScan();
            }
            int numRecord = 0;
            while (this.diter.hasNext()) {
                DataSequence dataSeq = this.diter.next();
                if (this.featureGenCache != null) {
                    this.featureGenCache.nextDataIndex();
                }
                if (this.params.debugLvl > 1) {
                    Util.printDbg("Read next seq: " + numRecord + " logli " + logli);
                }
                // log(1) = 0 for alpha; expected counts start at log(0).
                this.alpha_Y.assign(0.0);
                for (int f = 0; f < lambda.length; ++f) {
                    this.ExpF[f] = RobustMath.LOG0;
                }
                if (this.beta_Y == null || this.beta_Y.length < dataSeq.length()) {
                    this.beta_Y = new DoubleMatrix1D[2 * dataSeq.length()];
                    for (int i = 0; i < this.beta_Y.length; ++i) {
                        this.beta_Y[i] = this.newLogDoubleMatrix1D(this.numY);
                    }
                }

                // Backward pass in log space, starting from log(1) = 0 at the last position.
                this.beta_Y[dataSeq.length() - 1].assign(0.0);
                for (int i = dataSeq.length() - 1; i > 0; --i) {
                    if (this.params.debugLvl > 3) {
                        Util.printDbg("Features fired");
                        this.featureGenerator.startScanFeaturesAt(dataSeq, i);
                        while (this.featureGenerator.hasNext()) {
                            Feature feature = this.featureGenerator.next();
                            Util.printDbg(feature.toString());
                        }
                    }
                    SparseTrainer.computeLogMi(this.featureGenerator, lambda, dataSeq, i, this.Mi_YY, this.Ri_Y);
                    this.tmp_Y.assign(this.beta_Y[i]);
                    this.tmp_Y.assign(this.Ri_Y, sumFunc);
                    this.Mi_YY.zMult(this.tmp_Y, this.beta_Y[i - 1], 1.0, 0.0, false);
                }

                // Forward pass: observed counts into grad, log expected counts into ExpF.
                double thisSeqLogli = 0.0;
                for (int i = 0; i < dataSeq.length(); ++i) {
                    SparseTrainer.computeLogMi(this.featureGenerator, lambda, dataSeq, i, this.Mi_YY, this.Ri_Y);
                    // computeLogMi consumed the feature scan; restart it for the count accumulation.
                    this.featureGenerator.startScanFeaturesAt(dataSeq, i);
                    if (i > 0) {
                        this.Mi_YY.zMult(this.alpha_Y, this.newAlpha_Y, 1.0, 0.0, true);
                        this.newAlpha_Y.assign(this.Ri_Y, sumFunc);
                    } else {
                        this.newAlpha_Y.assign(this.Ri_Y);
                    }
                    while (this.featureGenerator.hasNext()) {
                        Feature feature = this.featureGenerator.next();
                        int f = feature.index();
                        int yp = feature.y();
                        int yprev = feature.yprev();
                        float val = feature.value();
                        // Feature fires on the gold labeling: node features (yprev < 0) need only
                        // the current label to match; edge features also need the previous label.
                        if (dataSeq.y(i) == yp && ((i - 1 >= 0 && yprev == dataSeq.y(i - 1)) || yprev < 0)) {
                            grad[f] += (double)val;
                            thisSeqLogli += (double)val * lambda[f];
                        }
                        if (yprev < 0) {
                            this.ExpF[f] = RobustMath.logSumExp(this.ExpF[f], this.newAlpha_Y.get(yp) + RobustMath.log(val) + this.beta_Y[i].get(yp));
                        } else {
                            this.ExpF[f] = RobustMath.logSumExp(this.ExpF[f], this.alpha_Y.get(yprev) + this.Ri_Y.get(yp) + this.Mi_YY.get(yprev, yp) + RobustMath.log(val) + this.beta_Y[i].get(yp));
                        }
                    }
                    this.alpha_Y.assign(this.newAlpha_Y);
                    if (this.params.debugLvl > 2) {
                        this.printPositionState(i);
                    }
                }

                // alpha_Y.zSum() is used directly as log Zx in this path.
                double lZx = this.alpha_Y.zSum();
                thisSeqLogli -= lZx;
                logli += thisSeqLogli;
                for (int f = 0; f < grad.length; ++f) {
                    grad[f] -= RobustMath.exp(this.ExpF[f] - lZx);
                }
                if (this.params.debugLvl > 1) {
                    System.out.println("Sequence " + thisSeqLogli + " " + logli);
                }
                if (thisSeqLogli > 0.0) {
                    System.out.println("This is shady: something is wrong Pr(y|x) > 1!");
                }
                ++numRecord;
            }
            if (this.params.debugLvl > 2) {
                this.printDebugVectors(lambda, grad);
            }
            if (this.params.debugLvl > 0) {
                Util.printDbg("Iteration " + this.icall + " log-likelihood " + logli + " norm(grad logli) " + this.norm(grad) + " norm(x) " + this.norm(lambda));
            }
        }
        catch (Exception e) {
            this.dumpStateAndExit(e);
        }
        return logli;
    }

    /** Prints alpha, Ri, Mi and beta for position {@code i} to standard output (debug aid). */
    private void printPositionState(int i) {
        System.out.println("Alpha-i " + this.alpha_Y.toString());
        System.out.println("Ri " + this.Ri_Y.toString());
        System.out.println("Mi " + this.Mi_YY.toString());
        System.out.println("Beta-i " + this.beta_Y[i].toString());
    }

    /** Prints the weight vector followed by the gradient vector to standard output (debug aid). */
    private void printDebugVectors(double[] lambda, double[] grad) {
        for (int f = 0; f < lambda.length; ++f) {
            System.out.print(lambda[f] + " ");
        }
        System.out.println(" :x");
        for (int f = 0; f < lambda.length; ++f) {
            System.out.print(grad[f] + " ");
        }
        System.out.println(" :g");
    }

    /**
     * Reports a fatal failure of a gradient computation and terminates the JVM with a non-zero
     * status. The stack trace is printed first so it survives even if dumping a matrix fails.
     * {@code String.valueOf} tolerates matrices that were never allocated.
     */
    private void dumpStateAndExit(Exception e) {
        e.printStackTrace();
        System.out.println("Alpha-i " + String.valueOf(this.alpha_Y));
        System.out.println("Ri " + String.valueOf(this.Ri_Y));
        System.out.println("Mi " + String.valueOf(this.Mi_YY));
        System.exit(FAILURE_EXIT_STATUS);
    }

    /** {@code (index, x) -> Math.exp(x)}; the index is ignored. */
    static class ExpFunc1D
    implements IntDoubleFunction {
        public double apply(int first, double third) {
            return Math.exp(third);
        }
    }

    /** {@code (row, col, x) -> Math.exp(x)}; the indices are ignored. */
    static class ExpFunc2D
    implements IntIntDoubleFunction {
        public double apply(int first, int second, double third) {
            return Math.exp(third);
        }
    }

    /** {@code x -> Math.exp(x)}. */
    static class ExpFunc
    implements DoubleFunction {
        public double apply(double a) {
            return Math.exp(a);
        }
    }
}

