/*
 * Decompiled with CFR 0.152.
 */
package func;

import dist.AbstractConditionalDistribution;
import dist.DiscreteDistribution;
import dist.Distribution;
import func.FunctionApproximater;
import func.inst.KDTree;
import shared.DataSet;
import shared.DataSetDescription;
import shared.DistanceMeasure;
import shared.EuclideanDistance;
import shared.Instance;

/**
 * A k-nearest-neighbor classifier over discretely labeled instances.
 *
 * <p>{@link #estimate(DataSet)} builds a {@link KDTree} over the training set. Then
 * {@link #distributionFor(Instance)} collects the nearest neighbors of a query and
 * turns their labels into a normalized {@link DiscreteDistribution}. A neighbor's vote
 * is its instance weight, optionally divided by its distance to the query. When
 * {@code range > 0}, only neighbors within that radius are considered.
 */
public class KNNClassifier
extends AbstractConditionalDistribution
implements FunctionApproximater {
    /** Measure used to build the KD-tree and, when enabled, to weight votes. */
    private DistanceMeasure distanceMeasure;
    /** Search radius for neighbor lookup; a value {@code <= 0} disables the radius limit. */
    private double range;
    /** Number of neighbors requested from the tree. */
    private int k;
    /** Whether each neighbor's vote is divided by its distance to the query. */
    private boolean weightByDistance;
    /** Number of distinct discrete label values, read from the training set description. */
    private int classRange;
    /** Neighbor-search structure; {@code null} until {@link #estimate(DataSet)} is called. */
    private KDTree tree;

    /** Creates a 1-nearest-neighbor classifier using Euclidean distance. */
    public KNNClassifier() {
        this(1, new EuclideanDistance());
    }

    /**
     * Creates an unweighted classifier with no range limit.
     *
     * @param k the number of neighbors to consult
     * @param measure the distance measure
     */
    public KNNClassifier(int k, DistanceMeasure measure) {
        this(k, false, measure, -1.0);
    }

    /**
     * Creates a classifier with no range limit.
     *
     * @param k the number of neighbors to consult
     * @param weight whether to weight votes by inverse distance
     * @param measure the distance measure
     */
    public KNNClassifier(int k, boolean weight, DistanceMeasure measure) {
        this(k, weight, measure, -1.0);
    }

    /**
     * Creates a fully configured classifier.
     *
     * @param k the number of neighbors to consult
     * @param weight whether to weight votes by inverse distance
     * @param measure the distance measure
     * @param range the search radius; a value {@code <= 0} disables the radius limit
     */
    public KNNClassifier(int k, boolean weight, DistanceMeasure measure, double range) {
        this.k = k;
        this.weightByDistance = weight;
        this.range = range;
        this.distanceMeasure = measure;
    }

    /**
     * Trains the classifier by indexing the examples in a KD-tree.
     * If the data set has no description yet, one is computed and attached to it.
     *
     * @param examples the labeled training data
     */
    public void estimate(DataSet examples) {
        if (examples.getDescription() == null) {
            examples.setDescription(new DataSetDescription(examples));
        }
        this.classRange = examples.getDescription().getLabelDescription().getDiscreteRange();
        this.tree = new KDTree(examples, this.distanceMeasure);
    }

    /**
     * Computes the label distribution for a query from its nearest neighbors.
     *
     * <p>If any vote is infinite (for example, an exact match under distance weighting),
     * the infinite classes share the probability mass equally. If no neighbor votes at
     * all (for example, none lies within {@code range}), the result is uniform rather
     * than NaN.
     *
     * @param data the query instance
     * @return the normalized distribution over the discrete labels
     * @throws IllegalStateException if {@link #estimate(DataSet)} has not been called
     */
    public Distribution distributionFor(Instance data) {
        if (this.tree == null) {
            throw new IllegalStateException("estimate(DataSet) must be called before classifying");
        }
        double[] distribution = new double[this.classRange];
        Instance[] neighbors = this.range > 0.0
                ? this.tree.knnrange(data, this.k, this.range)
                : this.tree.knn(data, this.k);
        for (Instance neighbor : neighbors) {
            int label = neighbor.getLabel().getDiscrete();
            distribution[label] += this.vote(data, neighbor);
        }
        normalize(distribution);
        return new DiscreteDistribution(distribution);
    }

    /** Returns the vote one neighbor casts for its label. */
    private double vote(Instance query, Instance neighbor) {
        double weight = neighbor.getWeight();
        if (!this.weightByDistance) {
            return weight;
        }
        double distance = this.distanceMeasure.value(query, neighbor);
        if (distance == 0.0 && weight == 0.0) {
            // 0.0 / 0.0 is NaN and would poison every probability; a zero-weight vote counts for nothing.
            return 0.0;
        }
        // A positive weight at distance 0 yields +Infinity; normalize() handles that case.
        return weight / distance;
    }

    /** Normalizes per-class vote totals in place so they sum to 1. */
    private static void normalize(double[] distribution) {
        double sum = 0.0;
        for (double votes : distribution) {
            sum += votes;
        }
        if (Double.isInfinite(sum)) {
            // At least one exact match: split the mass equally among the infinite classes only.
            sum = 0.0;
            for (int i = 0; i < distribution.length; i++) {
                if (Double.isInfinite(distribution[i])) {
                    distribution[i] = 1.0;
                    sum += 1.0;
                } else {
                    distribution[i] = 0.0;
                }
            }
        }
        if (sum == 0.0) {
            // No votes (e.g. nothing within range): fall back to uniform instead of 0/0 = NaN.
            double uniform = 1.0 / distribution.length;
            for (int i = 0; i < distribution.length; i++) {
                distribution[i] = uniform;
            }
            return;
        }
        for (int i = 0; i < distribution.length; i++) {
            distribution[i] /= sum;
        }
    }

    /**
     * Classifies a query as the mode of its label distribution.
     *
     * @param data the query instance
     * @return the most probable label
     */
    public Instance value(Instance data) {
        return this.distributionFor(data).mode();
    }

    /** @return the distance measure used for the tree and for weighting */
    public DistanceMeasure getDistanceMeasure() {
        return this.distanceMeasure;
    }

    /** @return the number of neighbors consulted */
    public int getK() {
        return this.k;
    }

    /** @return whether votes are weighted by inverse distance */
    public boolean isWeightByDistance() {
        return this.weightByDistance;
    }

    /** @return the search radius; a value {@code <= 0} means no radius limit */
    public double getRange() {
        return this.range;
    }

    /**
     * Sets the distance measure. The KD-tree is built in {@link #estimate(DataSet)}, so a
     * tree that already exists keeps the measure it was built with until re-estimated.
     *
     * @param measure the new distance measure
     */
    public void setDistanceMeasure(DistanceMeasure measure) {
        this.distanceMeasure = measure;
    }

    /** @param i the number of neighbors to consult */
    public void setK(int i) {
        this.k = i;
    }

    /** @param b whether to weight votes by inverse distance */
    public void setWeightByDistance(boolean b) {
        this.weightByDistance = b;
    }

    /** @param range the search radius; a value {@code <= 0} disables the radius limit */
    public void setRange(double range) {
        this.range = range;
    }
}

