/*
 * Decompiled with CFR 0.152.
 */
package dist;

import dist.AbstractDistribution;
import dist.Distribution;
import shared.Copyable;
import shared.DataSet;
import shared.Instance;
import util.linalg.CholeskyFactorization;
import util.linalg.DenseVector;
import util.linalg.Matrix;
import util.linalg.RectangularMatrix;
import util.linalg.Vector;

public class MultivariateGaussian
extends AbstractDistribution
implements Copyable {
    private static final double FLOOR = 0.01;
    private static final double FLOOR_CHANGE = 10.0;
    private Vector mean;
    private Matrix covarianceMatrix;
    private CholeskyFactorization decomposition;
    private double determinant;
    private double floor;
    private boolean debug;

    public MultivariateGaussian(Vector mean, Matrix covariance, double floor) {
        this.mean = mean;
        this.covarianceMatrix = covariance;
        this.floor = floor;
        this.decomposition = new CholeskyFactorization(covariance);
        this.determinant = this.decomposition.determinant();
    }

    public MultivariateGaussian(Vector mean, Matrix covariance) {
        this(mean, covariance, 0.0);
    }

    public MultivariateGaussian(double floor) {
        this.floor = floor;
    }

    public MultivariateGaussian() {
    }

    public double p(Instance i) {
        Vector d = i.getData();
        Vector dMinusMean = d.minus(this.mean);
        double p = 1.0 / Math.sqrt(Math.pow(Math.PI * 2, this.mean.size()) * this.determinant) * Math.exp(-0.5 * dMinusMean.dotProduct(this.decomposition.solve(dMinusMean)));
        return p;
    }

    public double logp(Instance i) {
        Vector d = i.getData();
        Vector dMinusMean = d.minus(this.mean);
        double p = Math.log(1.0 / Math.sqrt(Math.pow(Math.PI * 2, this.mean.size()) * this.determinant)) - 0.5 * dMinusMean.dotProduct(this.decomposition.solve(dMinusMean));
        return p;
    }

    public Instance sample(Instance ignored) {
        DenseVector r = new DenseVector(this.mean.size());
        int i = 0;
        while (i < ((Vector)r).size()) {
            ((Vector)r).set(i, Distribution.random.nextGaussian());
            ++i;
        }
        return new Instance(this.decomposition.getL().times(r).plus(this.mean));
    }

    public Instance mode(Instance ignored) {
        return new Instance((Vector)this.mean.copy());
    }

    public void estimate(DataSet observations) {
        double weightSum = 0.0;
        this.mean = new DenseVector(observations.get(0).size());
        int t = 0;
        while (t < observations.size()) {
            double weight = observations.get(t).getWeight();
            Vector d = observations.get(t).getData();
            int i = 0;
            while (i < this.mean.size()) {
                this.mean.set(i, this.mean.get(i) + d.get(i) * weight);
                ++i;
            }
            weightSum += weight;
            ++t;
        }
        this.mean.timesEquals(1.0 / weightSum);
        this.covarianceMatrix = new RectangularMatrix(this.mean.size(), this.mean.size());
        t = 0;
        while (t < observations.size()) {
            Vector d = observations.get(t).getData();
            double weight = observations.get(t).getWeight();
            Vector dMinusMean = d.minus(this.mean);
            int i = 0;
            while (i < this.covarianceMatrix.m()) {
                int j = 0;
                while (j < this.covarianceMatrix.n()) {
                    this.covarianceMatrix.set(i, j, this.covarianceMatrix.get(i, j) + dMinusMean.get(i) * dMinusMean.get(j) * weight);
                    ++j;
                }
                ++i;
            }
            ++t;
        }
        this.covarianceMatrix.timesEquals(1.0 / weightSum);
        boolean scale = false;
        int i = 0;
        while (i < this.covarianceMatrix.m()) {
            if (this.covarianceMatrix.get(i, i) < this.floor) {
                scale = true;
            }
            ++i;
        }
        if (scale) {
            i = 0;
            while (i < this.covarianceMatrix.m()) {
                this.covarianceMatrix.set(i, i, this.covarianceMatrix.get(i, i) + this.floor);
                ++i;
            }
        }
        this.decomposition = new CholeskyFactorization(this.covarianceMatrix);
        this.determinant = this.decomposition.determinant();
        if (this.determinant == 0.0 || Double.isNaN(this.determinant)) {
            if (this.debug) {
                System.out.println("Covariance matrix not positive, applying ridge adjustment");
                System.out.println(this.covarianceMatrix);
            }
            this.floor = this.floor == 0.0 ? 0.01 : (this.floor *= 10.0);
            this.estimate(observations);
        }
    }

    public String toString() {
        return "mean =\n" + this.mean.toString() + "\ncovariance matrix =\n" + this.covarianceMatrix.toString();
    }

    public Matrix getCovarianceMatrix() {
        return this.covarianceMatrix;
    }

    public Vector getMean() {
        return this.mean;
    }

    public void setCovarianceMatrix(Matrix matrix) {
        this.covarianceMatrix = matrix;
    }

    public void setMean(Vector vector) {
        this.mean = vector;
    }

    public boolean isDebug() {
        return this.debug;
    }

    public void setDebug(boolean b) {
        this.debug = b;
    }

    public Copyable copy() {
        return new MultivariateGaussian((Vector)this.mean.copy(), (Matrix)this.covarianceMatrix.copy(), this.floor);
    }
}

