/*
 * Decompiled with CFR 0.152.
 */
package rl;

import rl.MarkovDecisionProcess;
import rl.Policy;
import rl.PolicyLearner;

/**
 * Learns a policy for a {@link MarkovDecisionProcess} by policy iteration.
 *
 * <p>Each call to {@link #train()} performs one round of policy iteration. First, the current
 * policy is evaluated by repeated in-place sweeps over the states until no state value moves by
 * more than {@code TOLERANCE}. Then the policy is replaced, state by state, with the action that
 * maximises the one-step lookahead value under those evaluated values.
 */
public class PolicyIteration implements PolicyLearner {
    /** Maximum per-state change in value for policy evaluation to be considered converged. */
    private static final double TOLERANCE = 1.0E-6;

    /** The policy being improved; one action per state. */
    private final Policy policy;
    /** The decision process supplying states, actions, transitions and rewards. */
    private final MarkovDecisionProcess process;
    /** Discount factor applied to future value. */
    private final double gamma;

    /**
     * Creates a learner over {@code process} with discount factor {@code gamma}.
     *
     * @param gamma the discount factor applied to successor-state values
     * @param process the decision process to learn a policy for
     */
    public PolicyIteration(double gamma, MarkovDecisionProcess process) {
        this.gamma = gamma;
        this.process = process;
        this.policy = new Policy(process.getStateCount(), process.getActionCount());
    }

    /**
     * Runs one round of policy evaluation followed by greedy policy improvement.
     *
     * <p>NOTE(review): evaluation has no iteration cap. It terminates only if the value sweeps
     * converge to within {@code TOLERANCE}, which presumably requires {@code gamma < 1} or a
     * suitably absorbing process. Confirm this for callers.
     *
     * @return the number of states whose action changed in this round; {@code 0} means the
     *     improvement step left every action unchanged
     */
    public double train() {
        double[] values = evaluatePolicy();
        return improvePolicy(values);
    }

    /**
     * Returns the learned policy.
     *
     * @return the live policy object; it is updated in place by subsequent {@link #train()} calls
     */
    public Policy getPolicy() {
        return this.policy;
    }

    /**
     * Evaluates the current policy by in-place sweeps, starting every state at value 0.
     *
     * @return the state values under the current policy, converged to within {@code TOLERANCE}
     */
    private double[] evaluatePolicy() {
        int stateCount = process.getStateCount();
        double[] values = new double[stateCount];
        boolean valuesChanged;
        do {
            valuesChanged = false;
            for (int state = 0; state < stateCount; state++) {
                // Updated in place: states later in the sweep already see this new value.
                double updated = actionValue(state, policy.getAction(state), values);
                if (Math.abs(values[state] - updated) > TOLERANCE) {
                    valuesChanged = true;
                }
                values[state] = updated;
            }
        } while (valuesChanged);
        return values;
    }

    /**
     * Makes the policy greedy with respect to {@code values}.
     *
     * <p>Ties are broken in favour of the lowest-numbered action.
     *
     * @param values the state values to act greedily on
     * @return the number of states whose chosen action changed
     */
    private int improvePolicy(double[] values) {
        int stateCount = process.getStateCount();
        int actionCount = process.getActionCount();
        int changed = 0;
        for (int state = 0; state < stateCount; state++) {
            int bestAction = 0;
            // Start below every real value so that all-negative action values still select the
            // true maximum rather than defaulting to action 0.
            double bestValue = Double.NEGATIVE_INFINITY;
            for (int action = 0; action < actionCount; action++) {
                double candidate = actionValue(state, action, values);
                if (candidate > bestValue) {
                    bestValue = candidate;
                    bestAction = action;
                }
            }
            if (policy.getAction(state) != bestAction) {
                ++changed;
            }
            policy.setAction(state, bestAction);
        }
        return changed;
    }

    /**
     * Computes the one-step lookahead value of taking {@code action} in {@code state}.
     *
     * <p>The value is {@code reward(state, action) + gamma * sum_j P(state -> j | action) * values[j]}.
     *
     * @param state the state the action is taken in
     * @param action the action taken
     * @param values the current value estimate for every state
     * @return the immediate reward plus the discounted expected successor value
     */
    private double actionValue(int state, int action, double[] values) {
        int stateCount = values.length;
        double expected = 0.0;
        for (int next = 0; next < stateCount; next++) {
            expected += process.transitionProbability(state, next, action) * values[next];
        }
        return process.reward(state, action) + gamma * expected;
    }
}

