/*
 * Decompiled with CFR 0.152.
 */
package dist.hmm;

import dist.hmm.HiddenMarkovModel;
import shared.DataSet;

/**
 * Computes the most likely hidden-state sequence for a sequence of observations under a
 * {@link HiddenMarkovModel}, using the Viterbi algorithm in log space.
 *
 * <p>Each call to {@link #calculateStateSequence()} rebuilds per-call working tables stored in
 * mutable fields, so one instance should not be shared between concurrently running threads.
 */
public class StateSequenceCalculator {
    /** Supplies the initial, transition and observation probabilities. */
    private final HiddenMarkovModel model;
    /** The observations to decode; element {@code t} is the observation at time step t. */
    private final DataSet observations;
    /** probabilities[t][s]: log-probability of the best path that ends in state s at time t. */
    private double[][] probabilities;
    /** chain[t][s]: state at time t-1 on the best path that ends in state s at time t. */
    private int[][] chain;

    /**
     * @param model the hidden Markov model to decode with
     * @param observations the observation sequence to decode
     */
    public StateSequenceCalculator(HiddenMarkovModel model, DataSet observations) {
        this.model = model;
        this.observations = observations;
    }

    /**
     * Runs Viterbi decoding over the whole observation sequence.
     *
     * @return an array whose element {@code t} is the index of the most likely state at time t;
     *     an empty array when there are no observations. When no path has a probability greater
     *     than zero at some step, state 0 is used there instead of an invalid index.
     * @throws IllegalStateException if there is at least one observation but the model reports
     *     no states
     */
    public int[] calculateStateSequence() {
        int length = this.observations.size();
        if (length == 0) {
            // Nothing to decode; indexing row 0 below would otherwise fail.
            return new int[0];
        }
        int stateCount = this.model.getStateCount();
        if (stateCount <= 0) {
            throw new IllegalStateException("model has no states: " + stateCount);
        }
        this.probabilities = new double[length][stateCount];
        this.chain = new int[length][stateCount];
        this.calculateForward();
        return this.calculateBackward();
    }

    /**
     * Follows the back-pointers in {@link #chain} from the best final state to time 0.
     * Requires {@link #calculateForward()} to have filled the tables for a non-empty sequence.
     */
    private int[] calculateBackward() {
        int last = this.observations.size() - 1;
        int[] states = new int[last + 1];
        states[last] = argMax(this.probabilities[last]);
        for (int t = last; t > 0; t--) {
            states[t - 1] = this.chain[t][states[t]];
        }
        return states;
    }

    /**
     * Fills {@link #probabilities} and {@link #chain} with the Viterbi recursion.
     * Time 0 uses the model's initial and observation probabilities for observation 0.
     * Each later step takes the best predecessor score plus the log transition probability,
     * then adds the log observation probability for that step.
     */
    private void calculateForward() {
        int length = this.observations.size();
        int stateCount = this.model.getStateCount();
        for (int s = 0; s < stateCount; s++) {
            this.probabilities[0][s] =
                    Math.log(this.model.initialStateProbability(s, this.observations.get(0)))
                    + Math.log(this.model.observationProbability(s, this.observations.get(0)));
            this.chain[0][s] = 0; // no predecessor exists at t = 0
        }
        for (int t = 1; t < length; t++) {
            for (int s = 0; s < stateCount; s++) {
                double bestScore = Double.NEGATIVE_INFINITY;
                // Default to state 0 so an all -inf (or NaN) column never stores an invalid index.
                int bestPrev = 0;
                for (int prev = 0; prev < stateCount; prev++) {
                    double score = this.probabilities[t - 1][prev]
                            + Math.log(this.model.transitionProbability(prev, s, this.observations.get(t)));
                    if (score > bestScore) {
                        bestScore = score;
                        bestPrev = prev;
                    }
                }
                this.probabilities[t][s] =
                        bestScore + Math.log(this.model.observationProbability(s, this.observations.get(t)));
                this.chain[t][s] = bestPrev;
            }
        }
    }

    /**
     * Returns the index of the first strictly largest value greater than -infinity, or 0 when
     * no such value exists (all entries are -infinity or NaN).
     */
    private static int argMax(double[] values) {
        int best = 0;
        double bestValue = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < values.length; i++) {
            if (values[i] > bestValue) {
                bestValue = values[i];
                best = i;
            }
        }
        return best;
    }
}

