/*
 * Decompiled with CFR 0.152.
 */
package dist;

import dist.AbstractDistribution;
import dist.DiscreteDependencyTreeRootNode;
import shared.DataSet;
import shared.DataSetDescription;
import shared.Instance;
import util.graph.DFSTree;
import util.graph.Graph;
import util.graph.KruskalsMST;
import util.graph.Node;
import util.graph.Tree;
import util.graph.WeightedEdge;
import util.linalg.DenseVector;

/**
 * A tree-structured distribution over discrete attributes.
 *
 * <p>{@link #estimate(DataSet)} computes the (weighted) mutual information between every pair of
 * attributes. It then takes a spanning tree that maximises total mutual information: Kruskal's
 * minimum spanning tree over the <em>negated</em> mutual information, i.e. the Chow–Liu
 * construction. That tree is re-rooted with a depth-first traversal, and the conditional tables
 * are fitted through {@link DiscreteDependencyTreeRootNode}.
 *
 * <p>{@link #p(Instance)}, {@link #sample(Instance)}, {@link #mode(Instance)} and
 * {@link #toString()} dereference the fitted tree, so {@code estimate} must be called first.
 */
public class DiscreteDependencyTree
extends AbstractDistribution {
    /** Root of the fitted dependency tree; {@code null} until {@link #estimate(DataSet)} runs. */
    private DiscreteDependencyTreeRootNode root;
    /** Tree wrapper whose root is {@link #root}; {@code null} until {@link #estimate(DataSet)} runs. */
    private Tree dt;
    /** Passed unchanged to the tree nodes; presumably an m-estimate smoothing weight — TODO confirm. */
    private final double m;
    /** Fixed attribute ranges to impose on estimated data, or {@code null} to use the data's own. */
    private final DataSetDescription description;

    /**
     * Creates an unfitted tree whose attribute ranges are taken from the data passed to
     * {@link #estimate(DataSet)}.
     *
     * @param m value forwarded to every tree node when the tree is estimated
     */
    public DiscreteDependencyTree(double m) {
        this.m = m;
        this.description = null;
    }

    /**
     * Creates an unfitted tree with fixed attribute ranges.
     *
     * <p>Attribute {@code i}'s description gets a maximum of {@code ranges[i] - 1}. The minimum
     * vector is a fresh {@code DenseVector} of the same length (presumably all zeros, i.e. values
     * {@code 0..ranges[i]-1} — confirm against {@code DenseVector}).
     *
     * @param m value forwarded to every tree node when the tree is estimated
     * @param ranges number of distinct values of each attribute
     */
    public DiscreteDependencyTree(double m, int[] ranges) {
        this.m = m;
        this.description = new DataSetDescription();
        this.description.setMinVector(new DenseVector(ranges.length));
        DenseVector max = new DenseVector(ranges.length);
        for (int i = 0; i < max.size(); i++) {
            max.set(i, ranges[i] - 1);
        }
        this.description.setMaxVector(max);
    }

    /**
     * Returns the value the fitted tree assigns to {@code i} (delegates to the root node).
     *
     * @param i the instance to evaluate
     * @return {@code root.probabilityOf(i)}
     */
    public double p(Instance i) {
        return this.root.probabilityOf(i);
    }

    /**
     * Draws a random instance from the fitted tree.
     *
     * @param ignored unused
     * @return a new instance with one coordinate per tree node, filled by the root node
     */
    public Instance sample(Instance ignored) {
        Instance i = new Instance(new DenseVector(this.dt.getNodeCount()));
        this.root.generateRandom(i);
        return i;
    }

    /**
     * Builds the most likely instance according to the fitted tree.
     *
     * @param ignored unused
     * @return a new instance with one coordinate per tree node, filled by the root node
     */
    public Instance mode(Instance ignored) {
        Instance i = new Instance(new DenseVector(this.dt.getNodeCount()));
        this.root.generateMostLikely(i);
        return i;
    }

    /**
     * Fits the tree structure and node tables to {@code observations}.
     *
     * <p>Side effect: if this tree was built with fixed ranges, that description is installed on
     * {@code observations}. Otherwise, if the data set has no description, one is derived from
     * the data and installed on it.
     *
     * @param observations the training data; every instance is assumed to have the same size
     * @throws IllegalArgumentException if {@code observations} is empty
     */
    public void estimate(DataSet observations) {
        if (observations.size() == 0) {
            throw new IllegalArgumentException(
                    "cannot estimate a dependency tree from an empty data set");
        }
        if (this.description != null) {
            observations.setDescription(this.description);
        } else if (observations.getDescription() == null) {
            observations.setDescription(new DataSetDescription(observations));
        }
        double[][] mutualI = this.calculateMutualInformation(observations);
        Tree rg = this.buildDirectedMST(observations, mutualI);
        this.dt = new Tree();
        this.root = new DiscreteDependencyTreeRootNode(observations, rg.getRoot(), this.m, this.dt);
        this.dt.setRoot(this.root);
    }

    /**
     * Builds the maximum-mutual-information spanning tree over the attributes and roots it.
     *
     * @param observations the data (only used for the attribute count)
     * @param mutualI lower-triangular pairwise mutual information, {@code mutualI[i][j]} for
     *     {@code j < i}
     * @return the spanning tree, rooted by a depth-first traversal
     */
    private Tree buildDirectedMST(DataSet observations, double[][] mutualI) {
        int attributeCount = observations.get(0).size();
        Graph g = new Graph();
        for (int i = 0; i < attributeCount; i++) {
            g.addNode(new Node(i));
        }
        for (int i = 0; i < attributeCount; i++) {
            Node a = g.getNode(i);
            for (int j = 0; j < i; j++) {
                // Kruskal finds a *minimum* spanning tree; negating the mutual information
                // makes that the spanning tree with *maximum* total mutual information.
                a.connect(g.getNode(j), new WeightedEdge(-mutualI[i][j]));
            }
        }
        Graph mst = new KruskalsMST().transform(g);
        return (Tree) new DFSTree().transform(mst);
    }

    /**
     * Computes the weighted mutual information (in nats) between every pair of attributes as
     * {@code H(i) + H(j) - H(i, j)}.
     *
     * @param observations non-empty data set with a description installed
     * @return a lower-triangular array: row {@code i} has length {@code i} and holds
     *     {@code I(i; j)} for {@code j < i}
     */
    private double[][] calculateMutualInformation(DataSet observations) {
        DataSetDescription dsd = observations.getDescription();
        int attributeCount = observations.get(0).size();
        double weightSum = totalWeight(observations);
        double[][] probs = marginalProbabilities(observations, dsd, attributeCount, weightSum);

        double[] entropies = new double[attributeCount];
        for (int i = 0; i < attributeCount; i++) {
            entropies[i] = entropy(probs[i]);
        }

        double[][] mutualI = new double[attributeCount][];
        for (int i = 0; i < attributeCount; i++) {
            mutualI[i] = new double[i];
            for (int j = 0; j < i; j++) {
                double[][] joints = jointProbabilities(observations, dsd, i, j, weightSum);
                double jointEntropy = 0.0;
                for (double[] row : joints) {
                    jointEntropy += entropy(row);
                }
                mutualI[i][j] = entropies[i] + entropies[j] - jointEntropy;
            }
        }
        return mutualI;
    }

    /** Sums the weights of all instances in {@code observations}. */
    private static double totalWeight(DataSet observations) {
        double weightSum = 0.0;
        for (int k = 0; k < observations.size(); k++) {
            weightSum += observations.get(k).getWeight();
        }
        return weightSum;
    }

    /** Weighted marginal distribution of each attribute, indexed {@code [attribute][value]}. */
    private static double[][] marginalProbabilities(DataSet observations, DataSetDescription dsd,
            int attributeCount, double weightSum) {
        double[][] probs = new double[attributeCount][];
        for (int i = 0; i < attributeCount; i++) {
            probs[i] = new double[dsd.getDiscreteRange(i)];
        }
        for (int k = 0; k < observations.size(); k++) {
            Instance instance = observations.get(k);
            double weight = instance.getWeight();
            for (int j = 0; j < instance.size(); j++) {
                probs[j][instance.getDiscrete(j)] += weight;
            }
        }
        for (double[] row : probs) {
            for (int v = 0; v < row.length; v++) {
                row[v] /= weightSum;
            }
        }
        return probs;
    }

    /** Weighted joint distribution of attributes {@code a} and {@code b}, indexed {@code [a][b]}. */
    private static double[][] jointProbabilities(DataSet observations, DataSetDescription dsd,
            int a, int b, double weightSum) {
        double[][] joints = new double[dsd.getDiscreteRange(a)][dsd.getDiscreteRange(b)];
        for (int k = 0; k < observations.size(); k++) {
            Instance instance = observations.get(k);
            // Weight each instance exactly as the marginals are weighted; counting 1 per
            // instance while dividing by the weight sum would not give a probability table.
            joints[instance.getDiscrete(a)][instance.getDiscrete(b)] += instance.getWeight();
        }
        for (double[] row : joints) {
            for (int v = 0; v < row.length; v++) {
                row[v] /= weightSum;
            }
        }
        return joints;
    }

    /** Shannon entropy (natural log) of {@code probs}, treating {@code 0 * log 0} as 0. */
    private static double entropy(double[] probs) {
        double h = 0.0;
        for (double p : probs) {
            if (p != 0.0) {
                h -= p * Math.log(p);
            }
        }
        return h;
    }

    /**
     * Returns the string form of the fitted tree; requires a prior {@link #estimate(DataSet)}.
     */
    public String toString() {
        return this.dt.toString();
    }
}

