/*
 * Decompiled with CFR 0.152.
 */
package func;

import dist.AbstractConditionalDistribution;
import dist.DiscreteDistribution;
import dist.Distribution;
import func.FunctionApproximater;
import func.nn.activation.DifferentiableActivationFunction;
import func.nn.activation.HyperbolicTangentSigmoid;
import func.nn.backprop.BackPropagationNetwork;
import func.nn.backprop.BackPropagationNetworkFactory;
import func.nn.backprop.BatchBackPropagationTrainer;
import func.nn.backprop.RPROPUpdateRule;
import func.nn.backprop.WeightUpdateRule;
import shared.ConvergenceTrainer;
import shared.DataSet;
import shared.DataSetDescription;
import shared.Instance;
import shared.SumOfSquaresError;

/**
 * A classifier backed by a back-propagation neural network.
 *
 * <p>{@link #estimate(DataSet)} builds a fresh network sized from the data set's
 * description and trains it. {@link #distributionFor(Instance)} and
 * {@link #value(Instance)} then run that trained network on new inputs.
 */
public class NeuralNetworkClassifier
extends AbstractConditionalDistribution
implements FunctionApproximater {
    /** Activation function handed to the network factory when the network is built. */
    private final DifferentiableActivationFunction activationFunction;
    /** Number of nodes in the hidden layer; 0 means the network is built without a hidden layer. */
    private final int hiddenNodeCount;
    /** Weight update rule used by the batch back-propagation trainer. */
    private final WeightUpdateRule updateRule;
    /** The trained network; {@code null} until {@link #estimate(DataSet)} has been called. */
    private BackPropagationNetwork network;

    /**
     * Creates a classifier with the given network configuration.
     *
     * @param hiddenNodeCount number of hidden nodes, or 0 to build the network without a hidden layer
     * @param activationFunction activation function passed to the network factory
     * @param updateRule weight update rule used during training
     */
    public NeuralNetworkClassifier(int hiddenNodeCount, DifferentiableActivationFunction activationFunction, WeightUpdateRule updateRule) {
        this.hiddenNodeCount = hiddenNodeCount;
        this.activationFunction = activationFunction;
        this.updateRule = updateRule;
    }

    /**
     * Creates a classifier with 3 hidden nodes, a {@link HyperbolicTangentSigmoid}
     * activation function and an {@link RPROPUpdateRule}.
     */
    public NeuralNetworkClassifier() {
        this(3, new HyperbolicTangentSigmoid(), new RPROPUpdateRule());
    }

    /**
     * Builds a new network sized for {@code set} and trains it, replacing any
     * previously trained network.
     *
     * <p>If {@code set} has no description, one is computed and stored on the set,
     * so this method may mutate its argument. The input layer gets one node per
     * attribute type. The output layer gets a single node when the label has exactly
     * two values, and one node per label value otherwise. Training uses a
     * {@link ConvergenceTrainer} wrapping a {@link BatchBackPropagationTrainer} with a
     * {@link SumOfSquaresError} error measure.
     *
     * @param set the training data
     */
    public void estimate(DataSet set) {
        if (set.getDescription() == null) {
            set.setDescription(new DataSetDescription(set));
        }
        int inputCount = set.getDescription().getAttributeTypes().length;
        int labelRange = set.getDescription().getLabelDescription().getDiscreteRange();
        // Two classes are modelled with a single output node; distributionFor expands it back to two entries.
        int outputCount = labelRange == 2 ? 1 : labelRange;
        int[] topology = this.hiddenNodeCount != 0
            ? new int[] {inputCount, this.hiddenNodeCount, outputCount}
            : new int[] {inputCount, outputCount};
        this.network = new BackPropagationNetworkFactory().createClassificationNetwork(topology, this.activationFunction);
        SumOfSquaresError errorMeasure = new SumOfSquaresError();
        ConvergenceTrainer trainer = new ConvergenceTrainer(
            new BatchBackPropagationTrainer(set, this.network, errorMeasure, this.updateRule));
        trainer.train();
    }

    /**
     * Runs the trained network on {@code input} and returns the resulting class distribution.
     *
     * <p>When the network has more than one output node, the output values are used
     * directly as the distribution. When it has a single output node (the two-class
     * case), that output {@code o} is taken as the probability of class 1, giving
     * {@code [1 - o, o]}.
     * NOTE(review): outputs are not normalised here; this presumably relies on the
     * factory's classification output layer producing valid probabilities. Confirm.
     * NOTE(review): this sets inputs on and runs the shared {@code network} field, so
     * concurrent calls on one instance are presumably unsafe.
     *
     * @param input the instance to classify
     * @return the distribution over class labels
     * @throws IllegalStateException if {@link #estimate(DataSet)} has not been called yet
     */
    public Distribution distributionFor(Instance input) {
        if (this.network == null) {
            throw new IllegalStateException("estimate(DataSet) must be called before classifying instances");
        }
        this.network.setInputValues(input.getData());
        this.network.run();
        if (this.network.getOutputLayer().getNodeCount() > 1) {
            return new DiscreteDistribution(this.network.getOutputValues());
        }
        double probabilityOfClassOne = this.network.getOutputValues().get(0);
        return new DiscreteDistribution(new double[] {1.0 - probabilityOfClassOne, probabilityOfClassOne});
    }

    /**
     * Classifies {@code i} by returning the mode of {@link #distributionFor(Instance)}.
     *
     * @param i the instance to classify
     * @return the mode of the predicted class distribution
     * @throws IllegalStateException if {@link #estimate(DataSet)} has not been called yet
     */
    public Instance value(Instance i) {
        return this.distributionFor(i).mode();
    }
}

