/*
 * Decompiled with CFR 0.152.
 */
package opt.test;

import func.nn.backprop.BackPropagationNetwork;
import func.nn.backprop.BackPropagationNetworkFactory;
import opt.OptimizationAlgorithm;
import opt.RandomizedHillClimbing;
import opt.example.NeuralNetworkOptimizationProblem;
import shared.DataSet;
import shared.FixedIterationTrainer;
import shared.Instance;
import shared.SumOfSquaresError;

/**
 * Demonstration driver that fits a small classification network to the XOR truth table.
 *
 * <p>Instead of gradient-based back-propagation, the network weights are searched with
 * {@link RandomizedHillClimbing}, wrapped in a {@link NeuralNetworkOptimizationProblem} that
 * scores candidate weight vectors with {@link SumOfSquaresError}. After training, the best
 * weights found are loaded into the network. For every XOR pattern, the program then prints a
 * {@code ~~} separator, the expected label, and the network's output.
 */
public class XORTest {
    /** Number of hill-climbing iterations performed by the fixed-iteration trainer. */
    private static final int TRAINING_ITERATIONS = 5000;

    /**
     * Builds the XOR data set, trains the network, and prints expected versus actual outputs.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        BackPropagationNetworkFactory factory = new BackPropagationNetworkFactory();

        // XOR truth table: each row is {input vector, expected output vector}.
        double[][][] xorTable = {
            {{1.0, 1.0}, {0.0}},
            {{1.0, 0.0}, {1.0}},
            {{0.0, 1.0}, {1.0}},
            {{0.0, 0.0}, {0.0}},
        };

        Instance[] patterns = new Instance[xorTable.length];
        for (int row = 0; row < xorTable.length; row++) {
            patterns[row] = new Instance(xorTable[row][0]);
            patterns[row].setLabel(new Instance(xorTable[row][1]));
        }

        // Layer sizes: 2 inputs, 3 hidden units, 1 output.
        BackPropagationNetwork network = factory.createClassificationNetwork(new int[] {2, 3, 1});
        SumOfSquaresError measure = new SumOfSquaresError();
        DataSet set = new DataSet(patterns);
        NeuralNetworkOptimizationProblem problem =
                new NeuralNetworkOptimizationProblem(set, network, measure);

        // Typed as the abstraction so getOptimal() needs no cast.
        OptimizationAlgorithm algorithm = new RandomizedHillClimbing(problem);
        FixedIterationTrainer trainer = new FixedIterationTrainer(algorithm, TRAINING_ITERATIONS);
        trainer.train();

        // Load the best weight vector found by the search back into the network.
        Instance optimal = algorithm.getOptimal();
        network.setWeights(optimal.getData());

        for (Instance pattern : patterns) {
            network.setInputValues(pattern.getData());
            network.run();
            System.out.println("~~");
            System.out.println(pattern.getLabel());
            System.out.println(network.getOutputValues());
        }
    }
}

