/*
 * Decompiled with CFR 0.152.
 */
package dist.hmm;

import dist.hmm.ForwardBackwardProbabilityCalculator;
import dist.hmm.HiddenMarkovModel;
import shared.DataSet;
import shared.Instance;
import shared.Trainer;

/**
 * Trains a {@link HiddenMarkovModel} by expectation-maximization (Baum-Welch) re-estimation.
 *
 * <p>Each call to {@link #train()} runs a forward-backward pass over every observation sequence,
 * computes the posterior state and state-transition expectations, and then asks the model to
 * re-estimate its initial-state, transition and output distributions from them.
 *
 * <p>No synchronization is performed: training mutates the model and the weights of the training
 * instances in place.
 */
public class HiddenMarkovModelReestimator
implements Trainer {
    /** The training data: one time-ordered DataSet per observation sequence. */
    private final DataSet[] observationSequences;
    /** The model being trained; updated in place by {@link #train()}. */
    private HiddenMarkovModel model;
    /** transitionExpectations[k][t][i][j]: posterior of state i at t and state j at t+1 in sequence k. */
    private final double[][][][] transitionExpectations;
    /** stateExpectations[k][t][i]: posterior of being in state i at time t in sequence k. */
    private final double[][][] stateExpectations;
    /** Every instance of every sequence, concatenated; used to re-estimate output distributions. */
    private DataSet outputObservations;
    /** Every instance except the first of each sequence; used to re-estimate transitions. */
    private DataSet transitionObservations;
    /** The first instance of each sequence; used to re-estimate the initial-state distribution. */
    private DataSet initialObservations;

    /**
     * Creates a reestimator for {@code model} over the given training sequences.
     *
     * @param model the model to train (updated in place by {@link #train()})
     * @param observationSequences the training sequences; each must contain at least one instance
     * @throws IllegalArgumentException if {@code observationSequences} is null or empty, or if any
     *     sequence is null or empty
     */
    public HiddenMarkovModelReestimator(HiddenMarkovModel model, DataSet[] observationSequences) {
        if (observationSequences == null || observationSequences.length == 0) {
            throw new IllegalArgumentException("at least one observation sequence is required");
        }
        for (int k = 0; k < observationSequences.length; k++) {
            // Every sequence needs a first instance for the initial-state estimate.
            if (observationSequences[k] == null || observationSequences[k].size() == 0) {
                throw new IllegalArgumentException("observation sequence " + k + " is null or empty");
            }
        }
        this.model = model;
        this.observationSequences = observationSequences;
        this.stateExpectations = new double[observationSequences.length][][];
        this.transitionExpectations = new double[observationSequences.length][][][];
        this.initializeObservations();
    }

    /** Builds the output, transition and initial observation data sets from the training sequences. */
    public void initializeObservations() {
        this.initializeOutputObservations();
        this.initializeTransitionObservations();
        this.initializeInitialObservations();
    }

    /** Builds {@code outputObservations}: all instances of all sequences, in order. */
    public void initializeOutputObservations() {
        this.outputObservations = this.concatenateSequences(0);
    }

    /** Builds {@code initialObservations}: the first instance of every sequence. */
    public void initializeInitialObservations() {
        Instance[] firstInstances = new Instance[this.observationSequences.length];
        for (int k = 0; k < this.observationSequences.length; k++) {
            firstInstances[k] = this.observationSequences[k].get(0);
        }
        this.initialObservations = new DataSet(firstInstances, this.observationSequences[0].getDescription());
    }

    /**
     * Builds {@code transitionObservations}: every instance except the first of each sequence,
     * i.e. the observation at the destination time step t+1 of each transition.
     */
    public void initializeTransitionObservations() {
        this.transitionObservations = this.concatenateSequences(1);
    }

    /**
     * Concatenates the instances of all sequences, dropping the first {@code skip} of each.
     * The resulting DataSet shares the Instance objects of the original sequences.
     */
    private DataSet concatenateSequences(int skip) {
        int totalTime = 0;
        for (DataSet sequence : this.observationSequences) {
            totalTime += sequence.size() - skip;
        }
        Instance[] instances = new Instance[totalTime];
        int offset = 0;
        for (DataSet sequence : this.observationSequences) {
            Instance[] cur = sequence.getInstances();
            int length = cur.length - skip;
            System.arraycopy(cur, skip, instances, offset, length);
            offset += length;
        }
        return new DataSet(instances, this.observationSequences[0].getDescription());
    }

    /**
     * Performs one Baum-Welch iteration: computes expectations under the current model, then
     * re-estimates the initial-state, transition and output distributions.
     *
     * @return the mean per-sequence log probability under the model as it was BEFORE this update
     */
    public double train() {
        double logProbabilitySum = 0.0;
        for (int k = 0; k < this.observationSequences.length; k++) {
            DataSet sequence = this.observationSequences[k];
            ForwardBackwardProbabilityCalculator fbc = new ForwardBackwardProbabilityCalculator(this.model, sequence);
            double[][] forward = fbc.calculateForwardProbabilities();
            double[][] backward = fbc.calculateBackwardProbabilities();
            this.stateExpectations[k] = this.calculateStateExpectations(sequence, forward, backward);
            this.transitionExpectations[k] = this.calculateTransitionExpectations(sequence, forward, backward);
            logProbabilitySum += fbc.calculateLogProbability();
        }
        this.reestimateInitialStateDistribution();
        this.reestimateTransitionDistributions();
        this.reestimateOutputDistributions();
        return logProbabilitySum / (double) this.observationSequences.length;
    }

    /**
     * Computes, for each t, the posterior probability of being in state i at t and state j at t+1.
     * The unnormalized value is forward[t][i] * P(i -> j | o[t+1]) * P(o[t+1] | j) * backward[t+1][j],
     * normalized over all (i, j) at each t. Per-step normalization means any per-step scaling of
     * the forward/backward values (if the calculator applies one) cancels out.
     *
     * @return an array of shape [size - 1][stateCount][stateCount]; empty for sequences of
     *     length 0 or 1. A time step whose total mass is exactly zero is left all-zero rather
     *     than filled with NaN.
     */
    public double[][][] calculateTransitionExpectations(DataSet observationSequence, double[][] forwardProbabilities, double[][] backwardProbabilities) {
        int stateCount = this.model.getStateCount();
        int steps = Math.max(0, observationSequence.size() - 1);
        double[][][] transitions = new double[steps][stateCount][stateCount];
        double[] emission = new double[stateCount];
        for (int t = 0; t < steps; t++) {
            // The transition and emission both depend on the observation at t+1 only.
            Instance next = observationSequence.get(t + 1);
            for (int j = 0; j < stateCount; j++) {
                emission[j] = this.model.observationProbability(j, next);
            }
            double sum = 0.0;
            for (int i = 0; i < stateCount; i++) {
                double[] row = transitions[t][i];
                for (int j = 0; j < stateCount; j++) {
                    row[j] = forwardProbabilities[t][i] * this.model.transitionProbability(i, j, next) * emission[j] * backwardProbabilities[t + 1][j];
                    sum += row[j];
                }
            }
            for (int i = 0; i < stateCount; i++) {
                divideInPlace(transitions[t][i], sum);
            }
        }
        return transitions;
    }

    /**
     * Computes the posterior state probabilities: forward[t][i] * backward[t][i], normalized over
     * states at each time step t. A time step whose total mass is exactly zero is left all-zero
     * rather than filled with NaN.
     *
     * @return an array of shape [size][stateCount]
     */
    public double[][] calculateStateExpectations(DataSet observationSequence, double[][] forwardProbabilities, double[][] backwardProbabilities) {
        int stateCount = this.model.getStateCount();
        double[][] states = new double[observationSequence.size()][stateCount];
        for (int t = 0; t < states.length; t++) {
            double sum = 0.0;
            for (int i = 0; i < stateCount; i++) {
                states[t][i] = forwardProbabilities[t][i] * backwardProbabilities[t][i];
                sum += states[t][i];
            }
            divideInPlace(states[t], sum);
        }
        return states;
    }

    /**
     * Divides every entry of {@code values} by {@code sum}. When {@code sum} is exactly zero the
     * entries (all zero for non-negative inputs) are left unchanged instead of becoming NaN.
     */
    private static void divideInPlace(double[] values, double sum) {
        if (sum == 0.0) {
            return;
        }
        for (int i = 0; i < values.length; i++) {
            values[i] /= sum;
        }
    }

    /**
     * Re-estimates the initial-state distribution from the t = 0 state posteriors of each sequence.
     * Requires {@link #train()} (or equivalent) to have filled {@code stateExpectations}.
     */
    public void reestimateInitialStateDistribution() {
        int stateCount = this.model.getStateCount();
        double[][] initialStateProbabilities = new double[this.observationSequences.length][stateCount];
        for (int k = 0; k < this.observationSequences.length; k++) {
            System.arraycopy(this.stateExpectations[k][0], 0, initialStateProbabilities[k], 0, stateCount);
        }
        this.model.estimateIntialStateDistribution(initialStateProbabilities, this.initialObservations);
    }

    /**
     * Re-estimates each source state's transition distribution. For source state i, row r of the
     * matrix passed to the model holds the posteriors P(i at t, j at t+1) for the r-th transition
     * across all sequences (matching the order of {@code transitionObservations}).
     */
    public void reestimateTransitionDistributions() {
        int stateCount = this.model.getStateCount();
        for (int i = 0; i < stateCount; i++) {
            // A fresh matrix per source state, so a model that keeps the reference is not
            // affected when the next state's values are computed.
            double[][] probabilities = new double[this.transitionObservations.size()][stateCount];
            int row = 0;
            for (int k = 0; k < this.observationSequences.length; k++) {
                int steps = this.observationSequences[k].size() - 1;
                for (int t = 0; t < steps; t++) {
                    System.arraycopy(this.transitionExpectations[k][t][i], 0, probabilities[row], 0, stateCount);
                    row++;
                }
            }
            this.model.estimateTransitionDistribution(i, probabilities, this.transitionObservations);
        }
    }

    /**
     * Re-estimates each state's output distribution. For state i, every training instance's
     * weight is set to its posterior probability of being in state i, and the model then
     * estimates from {@code outputObservations}. This relies on {@code outputObservations}
     * sharing the same Instance objects as the sequences (it is built from getInstances()).
     * Side effect: after this call the instances keep the weights of the last state.
     */
    public void reestimateOutputDistributions() {
        int stateCount = this.model.getStateCount();
        for (int i = 0; i < stateCount; i++) {
            for (int k = 0; k < this.observationSequences.length; k++) {
                DataSet sequence = this.observationSequences[k];
                for (int t = 0; t < sequence.size(); t++) {
                    sequence.get(t).setWeight(this.stateExpectations[k][t][i]);
                }
            }
            this.model.estimateOutputDistribution(i, this.outputObservations);
        }
    }

    /** @return the model being trained */
    public HiddenMarkovModel getModel() {
        return this.model;
    }

    /** @param model the model to train from now on */
    public void setModel(HiddenMarkovModel model) {
        this.model = model;
    }
}

