/*
 * Decompiled with CFR 0.152.
 */
package rl.test;

import rl.EpsilonGreedyStrategy;
import rl.Policy;
import rl.PolicyIteration;
import rl.QLambda;
import rl.SarsaLambda;
import rl.SimpleMarkovDecisionProcess;
import rl.ValueIteration;
import shared.FixedIterationTrainer;
import shared.ThresholdTrainer;

/**
 * Command-line driver that configures one {@link SimpleMarkovDecisionProcess}
 * and runs four learners against it in turn: {@link ValueIteration},
 * {@link PolicyIteration}, {@link QLambda} and {@link SarsaLambda}. After each
 * run the resulting policy and the elapsed wall-clock time are written to
 * standard output.
 */
public class MDPTest {
    /**
     * Builds the MDP, trains each learner and prints its results.
     *
     * @param args command-line arguments; not read
     */
    public static void main(String[] args) {
        SimpleMarkovDecisionProcess mdp = new SimpleMarkovDecisionProcess();

        // Four reward values and a 4 x 2 x 4 transition table. The axis order
        // is presumably [state][action][nextState] -- confirm against
        // SimpleMarkovDecisionProcess.
        final double[] rewards = {0.0, 0.0, 10.0, 10.0};
        final double[][][] transitions = {
            {{1.0, 0.0, 0.0, 0.0}, {0.5, 0.5, 0.0, 0.0}},
            {{0.5, 0.0, 0.0, 0.5}, {0.0, 1.0, 0.0, 0.0}},
            {{0.5, 0.0, 0.5, 0.0}, {0.5, 0.5, 0.0, 0.0}},
            {{0.0, 0.0, 0.5, 0.5}, {0.0, 1.0, 0.0, 0.0}}
        };
        mdp.setRewards(rewards);
        mdp.setTransitionMatrices(transitions);
        mdp.setInitialState(0);

        long begin;
        long end;

        // --- Value iteration, run through a ThresholdTrainer. ---
        // The 0.9 argument is presumably the discount factor -- confirm
        // against ValueIteration's constructor.
        final ValueIteration valueIteration = new ValueIteration(0.9, mdp);
        final ThresholdTrainer valueTrainer = new ThresholdTrainer(valueIteration);
        // The timed window covers train() and getPolicy(), not construction.
        begin = System.currentTimeMillis();
        valueTrainer.train();
        final Policy valuePolicy = valueIteration.getPolicy();
        end = System.currentTimeMillis();
        System.out.println("Value iteration learned : " + valuePolicy);
        System.out.println("in " + valueTrainer.getIterations() + " iterations");
        printElapsed(begin, end);

        // --- Policy iteration, same trainer type and same 0.9 argument. ---
        final PolicyIteration policyIteration = new PolicyIteration(0.9, mdp);
        final ThresholdTrainer policyTrainer = new ThresholdTrainer(policyIteration);
        begin = System.currentTimeMillis();
        policyTrainer.train();
        final Policy iteratedPolicy = policyIteration.getPolicy();
        end = System.currentTimeMillis();
        System.out.println("Policy iteration learned : " + iteratedPolicy);
        System.out.println("in " + policyTrainer.getIterations() + " iterations");
        printElapsed(begin, end);

        // --- QLambda with an EpsilonGreedyStrategy(0.3), fixed at 100 iterations. ---
        final QLambda qLambda =
                new QLambda(0.5, 0.9, 0.2, 0.995, new EpsilonGreedyStrategy(0.3), mdp);
        final FixedIterationTrainer qTrainer = new FixedIterationTrainer(qLambda, 100);
        begin = System.currentTimeMillis();
        qTrainer.train();
        final Policy qPolicy = qLambda.getPolicy();
        end = System.currentTimeMillis();
        System.out.println("Q lambda learned : " + qPolicy);
        System.out.println("in 100 iterations");
        printElapsed(begin, end);
        System.out.println("Acquiring " + qLambda.getTotalReward() + " reward");

        // --- SarsaLambda with the same numeric arguments and iteration count. ---
        final SarsaLambda sarsaLambda =
                new SarsaLambda(0.5, 0.9, 0.2, 0.995, new EpsilonGreedyStrategy(0.3), mdp);
        final FixedIterationTrainer sarsaTrainer = new FixedIterationTrainer(sarsaLambda, 100);
        begin = System.currentTimeMillis();
        sarsaTrainer.train();
        final Policy sarsaPolicy = sarsaLambda.getPolicy();
        end = System.currentTimeMillis();
        System.out.println("Sarsa lambda learned : " + sarsaPolicy);
        System.out.println("in 100 iterations");
        printElapsed(begin, end);
        System.out.println("Acquiring " + sarsaLambda.getTotalReward() + " reward");
    }

    /**
     * Prints the elapsed-time line shared by every learner's report.
     *
     * @param startMillis  timestamp taken just before training, in milliseconds
     * @param finishMillis timestamp taken just after the policy was fetched, in milliseconds
     */
    private static void printElapsed(long startMillis, long finishMillis) {
        System.out.println("and " + (finishMillis - startMillis) + " ms");
    }
}

