/*
 * Decompiled with CFR 0.152.
 */
package rl;

import dist.Distribution;
import java.util.Arrays;
import rl.ExplorationStrategy;
import rl.GreedyStrategy;
import rl.MarkovDecisionProcess;
import rl.Policy;
import rl.PolicyLearner;

/**
 * Watkins's Q(lambda) learner with accumulating eligibility traces.
 *
 * <p>Each call to {@link #train()} performs one environment step. It collects the reward for the
 * current state-action pair, samples the successor state and picks the next action with the
 * exploration strategy. It then updates every state-action pair whose eligibility trace is
 * non-negligible, using the TD error against the greedy (max-valued) successor action.
 *
 * <p>Traces decay by {@code gamma * lambda} while the agent follows a greedy action. They are cut
 * to zero as soon as it takes an exploratory action. They are also cleared when a terminal state
 * is reached and a new episode begins.
 */
public class QLambda implements PolicyLearner {
    /** Eligibility traces below this threshold are treated as zero and skipped. */
    private static final double ZERO = 1.0E-6;
    /** Trace-decay parameter; traces shrink by {@code gamma * lambda} on greedy steps. */
    private final double lambda;
    /** Discount factor applied to the successor state's value. */
    private final double gamma;
    /** Current learning rate; multiplied by {@link #decay} after every training step. */
    private double alpha;
    /** Multiplicative factor applied to {@link #alpha} after every training step. */
    private final double decay;
    /** Strategy used to choose the action actually taken in each state. */
    private final ExplorationStrategy strategy;
    /** The environment being learned. */
    private final MarkovDecisionProcess process;
    /** Q-value table indexed as {@code values[state][action]}. */
    private final double[][] values;
    /** Eligibility traces, indexed like {@link #values}. */
    private final double[][] eligibility;
    /** State the agent is currently in. */
    private int state;
    /** Action the agent will take from {@link #state} on the next training step. */
    private int action;
    /** Number of training steps performed so far. */
    private int iteration;
    /** Number of terminal states reached so far. */
    private int episode;
    /** Sum of all rewards collected across every training step (never reset). */
    private double totalReward;

    /**
     * Creates a learner with an all-zero Q-table.
     *
     * <p>The initial state is sampled from the process. The initial action is drawn uniformly at
     * random.
     *
     * @param lambda trace-decay parameter
     * @param gamma discount factor
     * @param alpha initial learning rate
     * @param decay factor multiplied into the learning rate after each step
     * @param strategy exploration strategy used to pick actions
     * @param process the Markov decision process to learn
     */
    public QLambda(double lambda, double gamma, double alpha, double decay, ExplorationStrategy strategy, MarkovDecisionProcess process) {
        this.lambda = lambda;
        this.gamma = gamma;
        this.alpha = alpha;
        this.decay = decay;
        this.strategy = strategy;
        this.process = process;
        this.values = new double[process.getStateCount()][process.getActionCount()];
        this.eligibility = new double[process.getStateCount()][process.getActionCount()];
        this.state = process.sampleInitialState();
        this.action = Distribution.random.nextInt(process.getActionCount());
    }

    /**
     * Performs one Q(lambda) step and advances the agent to the successor state.
     *
     * <p>If the successor is terminal, a new episode is started. The traces are then cleared and
     * the agent is placed in a freshly sampled initial state.
     *
     * @return the largest absolute change made to any Q-value during this step
     */
    public double train() {
        double reward = this.process.reward(this.state, this.action);
        this.totalReward += reward;
        int nextState = this.process.sampleState(this.state, this.action);
        double[] nextValues = this.values[nextState];
        int nextAction = this.strategy.action(nextValues);
        int nextBestAction = new GreedyStrategy().action(nextValues);
        // Watkins's rule: an action that ties the maximum counts as greedy. Without this, an
        // exploratory pick of an equal-valued action (e.g. with the all-zero initial table)
        // would needlessly cut every trace.
        if (nextValues[nextAction] == nextValues[nextBestAction]) {
            nextBestAction = nextAction;
        }
        boolean greedyStep = nextAction == nextBestAction;
        double delta = reward + this.gamma * nextValues[nextBestAction] - this.values[this.state][this.action];

        // Accumulating trace for the pair just visited.
        this.eligibility[this.state][this.action] += 1.0;

        double traceDecay = this.gamma * this.lambda;
        double difference = 0.0;
        for (int s = 0; s < this.eligibility.length; s++) {
            double[] traces = this.eligibility[s];
            double[] q = this.values[s];
            for (int a = 0; a < traces.length; a++) {
                if (traces[a] < ZERO) {
                    continue;
                }
                double newValue = q[a] + this.alpha * delta * traces[a];
                difference = Math.max(difference, Math.abs(q[a] - newValue));
                q[a] = newValue;
                // Off-policy (exploratory) steps break the greedy trajectory, so traces are reset.
                if (greedyStep) {
                    traces[a] *= traceDecay;
                } else {
                    traces[a] = 0.0;
                }
            }
        }

        this.state = nextState;
        this.action = nextAction;
        if (this.process.isTerminalState(this.state)) {
            ++this.episode;
            this.state = this.process.sampleInitialState();
            this.action = this.strategy.action(this.values[this.state]);
            for (double[] traces : this.eligibility) {
                Arrays.fill(traces, 0.0);
            }
        }
        ++this.iteration;
        this.alpha *= this.decay;
        return difference;
    }

    /**
     * Builds the greedy policy with respect to the current Q-values.
     *
     * <p>Each state maps to the lowest-indexed action with the highest Q-value. This holds even
     * when all of that state's values are negative.
     *
     * @return the greedy policy
     */
    public Policy getPolicy() {
        int[] policy = new int[this.values.length];
        for (int s = 0; s < policy.length; s++) {
            double[] q = this.values[s];
            int best = 0;
            for (int a = 1; a < q.length; a++) {
                if (q[a] > q[best]) {
                    best = a;
                }
            }
            policy[s] = best;
        }
        return new Policy(policy);
    }

    /**
     * Returns the step and episode counters.
     *
     * @return {@code "<iteration>, <episode>"}
     */
    @Override
    public String toString() {
        return String.valueOf(this.iteration) + ", " + this.episode;
    }

    /**
     * Returns the sum of rewards collected over every call to {@link #train()}.
     *
     * @return the cumulative reward
     */
    public double getTotalReward() {
        return this.totalReward;
    }
}

