/*
 * Decompiled with CFR 0.152.
 */
package shared.filt;

import dist.MultivariateGaussian;
import shared.DataSet;
import shared.DataSetDescription;
import shared.Instance;
import shared.filt.ReversibleFilter;
import util.linalg.CholeskyFactorization;
import util.linalg.LowerTriangularMatrix;
import util.linalg.Matrix;
import util.linalg.RectangularMatrix;
import util.linalg.SymmetricEigenvalueDecomposition;
import util.linalg.UpperTriangularMatrix;
import util.linalg.Vector;

/**
 * Reversible filter that centers instances on the training mean and projects
 * them onto directions obtained from a linear discriminant analysis of a
 * labeled data set.
 *
 * <p>The directions are computed once, in the constructor, from the
 * within-class scatter (weighted sum of per-class covariance matrices) and the
 * between-class scatter (weighted sum of outer products of the class-mean
 * offsets from the overall mean).
 */
public class LinearDiscriminantAnalysis
implements ReversibleFilter {
    /** Projection matrix: one unit-length discriminant direction per row. */
    private final Matrix projection;
    /** Mean estimated from the data set passed to the constructor. */
    private final Vector mean;

    /**
     * Builds the projection from the given labeled data set.
     *
     * <p>If the data set has no description, one is created and stored on the
     * data set as a side effect. The number of classes is read from the label
     * description's discrete range; every instance label is used directly as an
     * array index, so labels are expected to lie in {@code [0, classCount)}.
     *
     * <p>The number of directions kept is {@code min(classCount - 1, d)} where
     * {@code d} is the size of the mean vector: the eigen-decomposition is of a
     * {@code d x d} matrix, so no more than {@code d} directions can be read
     * from it.
     *
     * @param dataSet the labeled, weighted instances to analyze
     */
    public LinearDiscriminantAnalysis(DataSet dataSet) {
        MultivariateGaussian overall = new MultivariateGaussian();
        overall.estimate(dataSet);
        this.mean = overall.getMean();
        if (dataSet.getDescription() == null) {
            dataSet.setDescription(new DataSetDescription(dataSet));
        }
        int classCount = dataSet.getDescription().getLabelDescription().getDiscreteRange();
        int dimension = this.mean.size();
        // Cap at the dimension so we never ask for more eigenvector columns
        // than a dimension x dimension decomposition provides.
        int toKeep = Math.min(classCount - 1, dimension);

        // First pass: per-class instance counts and per-class weight totals.
        int[] classCounts = new int[classCount];
        double[] classWeights = new double[classCount];
        double totalWeight = 0.0;
        for (int i = 0; i < dataSet.size(); i++) {
            Instance instance = dataSet.get(i);
            int label = instance.getLabel().getDiscrete();
            double weight = instance.getWeight();
            classCounts[label]++;
            classWeights[label] += weight;
            totalWeight += weight;
        }
        // Turn the weight totals into fractions of the overall weight.
        for (int c = 0; c < classWeights.length; c++) {
            classWeights[c] /= totalWeight;
        }

        // Second pass: bucket the instances by class label.
        Instance[][] byClass = new Instance[classCount][];
        for (int c = 0; c < byClass.length; c++) {
            byClass[c] = new Instance[classCounts[c]];
        }
        int[] cursor = new int[classCount];
        for (int i = 0; i < dataSet.size(); i++) {
            Instance instance = dataSet.get(i);
            int label = instance.getLabel().getDiscrete();
            byClass[label][cursor[label]++] = instance;
        }

        // Accumulate between-class (sb) and within-class (sw) scatter.
        RectangularMatrix sb = new RectangularMatrix(dimension, dimension);
        RectangularMatrix sw = new RectangularMatrix(dimension, dimension);
        for (int c = 0; c < classCount; c++) {
            MultivariateGaussian classGaussian = new MultivariateGaussian();
            classGaussian.estimate(new DataSet(byClass[c]));
            sw.plusEquals(classGaussian.getCovarianceMatrix().times(classWeights[c]));
            Vector offset = classGaussian.getMean().minus(this.mean);
            sb.plusEquals(offset.outerProduct(offset).times(classWeights[c]));
        }

        // Reduce the generalized problem (sb, sw) to a symmetric one using the
        // Cholesky factor g of sw: c = g^-1 * sb * g^-T.
        CholeskyFactorization cf = new CholeskyFactorization(sw);
        LowerTriangularMatrix g = cf.getL();
        LowerTriangularMatrix gInverse = g.inverse();
        UpperTriangularMatrix gInverseTranspose = (UpperTriangularMatrix)gInverse.transpose();
        Matrix c = gInverse.times(sb).times(gInverseTranspose);
        SymmetricEigenvalueDecomposition sed = new SymmetricEigenvalueDecomposition(c);
        // Map the eigenvectors of c back to the original coordinates.
        Matrix eigenVectors = gInverseTranspose.times(sed.getU());

        // Keep the leading columns, normalized, as rows of the projection.
        // NOTE(review): assumes getU() orders its columns so that the most
        // discriminative directions come first - confirm against
        // SymmetricEigenvalueDecomposition.
        Matrix rows = new RectangularMatrix(toKeep, eigenVectors.m());
        for (int k = 0; k < toKeep; k++) {
            Vector v = eigenVectors.getColumn(k);
            rows.setRow(k, v.times(1.0 / v.norm()));
        }
        this.projection = rows;
    }

    /**
     * Replaces each instance's data with the projection of its mean-centered
     * data, then clears the data set's description.
     *
     * @param dataSet the data set to transform in place
     */
    public void filter(DataSet dataSet) {
        for (int i = 0; i < dataSet.size(); i++) {
            Instance instance = dataSet.get(i);
            Vector centered = instance.getData().minus(this.mean);
            instance.setData(this.projection.times(centered));
        }
        // Presumably the old description no longer matches the transformed data.
        dataSet.setDescription(null);
    }

    /**
     * Maps each instance's data back through the transpose of the projection
     * and adds the mean, then clears the data set's description.
     *
     * @param dataSet the data set to transform in place
     */
    public void reverse(DataSet dataSet) {
        // The transpose does not depend on the instance; compute it once.
        Matrix backProjection = this.projection.transpose();
        for (int i = 0; i < dataSet.size(); i++) {
            Instance instance = dataSet.get(i);
            Vector restored = backProjection.times(instance.getData());
            instance.setData(restored.plus(this.mean));
        }
        dataSet.setDescription(null);
    }

    /**
     * Returns the projection matrix built by the constructor.
     *
     * @return the projection matrix (the internal instance, not a copy)
     */
    public Matrix getProjection() {
        return this.projection;
    }

    /**
     * Returns the mean estimated from the constructor's data set.
     *
     * @return the mean vector (the internal instance, not a copy)
     */
    public Vector getMean() {
        return this.mean;
    }
}

