/*
 * Decompiled with CFR 0.152.
 */
package func.nn.backprop;

import func.nn.NetworkTrainer;
import func.nn.backprop.BackPropagationNetwork;
import func.nn.backprop.WeightUpdateRule;
import shared.DataSet;
import shared.GradientErrorMeasure;
import shared.Instance;
import shared.filt.RandomOrderFilter;

/**
 * Trains a {@link BackPropagationNetwork} one instance at a time: for every
 * instance of the data set the network is run forward, the output gradient is
 * back-propagated, and the weights are immediately updated with the supplied
 * {@link WeightUpdateRule} (i.e. weights change after each instance rather than
 * once per pass over the data).
 */
public class StochasticBackPropagationTrainer
extends NetworkTrainer {
    /** Rule applied to the network's weights after each individual instance. */
    private final WeightUpdateRule rule;

    /**
     * Creates a trainer.
     *
     * @param patterns the training instances; {@link #train()} passes this set to a
     *        {@link RandomOrderFilter} on every call (presumably reordering it in
     *        place — TODO confirm against RandomOrderFilter)
     * @param network the network whose weights are trained
     * @param errorMeasure supplies both the per-instance error value and its gradient
     * @param rule the weight update rule applied after each instance
     */
    public StochasticBackPropagationTrainer(DataSet patterns, BackPropagationNetwork network, GradientErrorMeasure errorMeasure, WeightUpdateRule rule) {
        super(patterns, network, errorMeasure);
        this.rule = rule;
    }

    /**
     * Performs one pass over the data set, updating the weights after every instance.
     *
     * @return the mean of the error measure's value over all instances in this pass
     *         (computed on each instance's output before that instance's update);
     *         {@code NaN} if the data set is empty
     */
    public double train() {
        BackPropagationNetwork net = (BackPropagationNetwork) getNetwork();
        GradientErrorMeasure errorMeasure = (GradientErrorMeasure) getErrorMeasure();
        DataSet data = getDataSet();

        // Visit the instances in a randomized order each pass.
        new RandomOrderFilter().filter(data);

        double totalError = 0.0;
        for (int idx = 0; idx < data.size(); idx++) {
            totalError += trainOnPattern(net, errorMeasure, data.get(idx));
        }
        return totalError / (double) data.size();
    }

    /**
     * Runs a single forward/backward step on one instance and applies the weight update.
     *
     * @return the error measure's value for this instance, taken before the update
     */
    private double trainOnPattern(BackPropagationNetwork net, GradientErrorMeasure errorMeasure, Instance pattern) {
        net.setInputValues(pattern.getData());
        net.run();
        Instance prediction = new Instance(net.getOutputValues());
        // The gradient is queried before the value, matching the original call order.
        double[] outputGradient = errorMeasure.gradient(prediction, pattern);
        double patternError = errorMeasure.value(prediction, pattern);
        net.setOutputErrors(outputGradient);
        net.backpropagate();
        net.updateWeights(this.rule);
        // Reset accumulated error state so the next instance starts clean.
        net.clearError();
        return patternError;
    }
}

