package edu.mit.csail.cgs.clustering.kmeans;

import edu.mit.csail.cgs.clustering.Cluster;
import edu.mit.csail.cgs.clustering.ClusterRepresentative;
import edu.mit.csail.cgs.clustering.ClusteringMethod;
import edu.mit.csail.cgs.clustering.DefaultCluster;
import edu.mit.csail.cgs.clustering.PairwiseElementMetric;
import java.util.Collection;
import java.util.Vector;

/* loaded from: input_file:edu/mit/csail/cgs/clustering/kmeans/KMeansClustering.class */
/**
 * Iterative k-means style clustering over elements of type {@code X}.
 *
 * <p>The number of clusters is fixed at construction time by the number of starting
 * means supplied. Each iteration assigns every element to the cluster whose current
 * mean has the smallest {@link PairwiseElementMetric} value, then replaces each mean
 * with the value returned by the supplied {@link ClusterRepresentative} for that
 * cluster's members.
 *
 * @param <X> the type of the elements being clustered
 */
public class KMeansClustering<X> implements ClusteringMethod<X> {
    /** Iteration cap used until {@link #setIterations(int)} is called. */
    private static final int DEFAULT_ITERATIONS = 10;

    private final PairwiseElementMetric<X> metric;
    private final ClusterRepresentative<X> repr;
    /** Means every clustering run is reset to; copied from the constructor argument. */
    private final Vector<X> startMeans;
    private final int numClusters;
    private int iterations;
    /** Elements of the most recent clustering run (empty before the first run). */
    private Vector<X> elmts;
    private final Vector<DefaultCluster<X>> clusters = new Vector<>();
    private final Vector<X> clusterMeans;

    /**
     * Creates a clustering with one cluster per starting mean.
     *
     * @param metric     distance used to compare an element with a mean; smaller
     *                   values are treated as closer
     * @param repr       computes the new mean of a cluster from its members
     * @param startMeans initial means; its size fixes the number of clusters. The
     *                   collection is copied, so later changes to it have no effect.
     */
    public KMeansClustering(PairwiseElementMetric<X> metric, ClusterRepresentative<X> repr, Collection<X> startMeans) {
        this.metric = metric;
        this.repr = repr;
        this.numClusters = startMeans.size();
        for (int c = 0; c < this.numClusters; c++) {
            this.clusters.add(new DefaultCluster<>());
        }
        this.clusterMeans = new Vector<>(startMeans);
        this.startMeans = new Vector<>(startMeans);
        this.iterations = DEFAULT_ITERATIONS;
        this.elmts = new Vector<>();
    }

    /**
     * Sets the maximum number of assign/update iterations per clustering run.
     *
     * @param iterations the iteration cap; a value of zero or less means no
     *                   iteration is performed
     */
    public void setIterations(int iterations) {
        this.iterations = iterations;
    }

    /**
     * Clusters the given elements with a convergence tolerance of {@code 0.0}.
     *
     * @param elements the elements to cluster
     * @return the resulting clusters, one per starting mean
     */
    @Override // edu.mit.csail.cgs.clustering.ClusteringMethod
    public Collection<Cluster<X>> clusterElements(Collection<X> elements) {
        return clusterElements(elements, 0.0d);
    }

    /**
     * Clusters the given elements, starting again from the starting means.
     *
     * <p>Iteration stops when the iteration cap is reached, or as soon as the summed
     * absolute metric value between each previous mean and its updated mean is less
     * than or equal to {@code tolerance}.
     *
     * @param elements  the elements to cluster
     * @param tolerance total mean movement at or below which the run is considered
     *                  converged
     * @return a new collection holding the clusters, one per starting mean
     * @throws IllegalStateException if there are elements to assign but no starting
     *                               means were supplied at construction
     */
    public Collection<Cluster<X>> clusterElements(Collection<X> elements, double tolerance) {
        init(elements);
        for (int iteration = 0; iteration < this.iterations; iteration++) {
            // Shallow snapshot of the means so their movement can be measured below.
            Vector<X> previousMeans = new Vector<>(this.clusterMeans);
            assignToClusters();
            getClusterMeans();

            double totalShift = 0.0d;
            for (int c = 0; c < this.numClusters; c++) {
                totalShift += Math.abs(this.metric.evaluate(previousMeans.get(c), this.clusterMeans.get(c)));
            }
            if (totalShift <= tolerance) {
                break;
            }
        }
        return new Vector<Cluster<X>>(this.clusters);
    }

    /** Empties every cluster, then adds each element to the cluster with the nearest mean. */
    private void assignToClusters() {
        for (int c = 0; c < this.numClusters; c++) {
            this.clusters.get(c).clear();
        }
        if (this.numClusters == 0 && !this.elmts.isEmpty()) {
            // Without this guard the lookup below would index the cluster list at -1.
            throw new IllegalStateException("Cannot assign elements: no starting means were supplied");
        }
        for (X element : this.elmts) {
            this.clusters.get(nearestMeanIndex(element)).addElement(element);
        }
    }

    /**
     * Returns the index of the mean with the smallest metric value to {@code element},
     * or {@code -1} when there are no means. Ties keep the lowest index.
     */
    private int nearestMeanIndex(X element) {
        int nearest = -1;
        double nearestDistance = 0.0d;
        for (int c = 0; c < this.numClusters; c++) {
            double distance = this.metric.evaluate(element, this.clusterMeans.get(c));
            if (nearest == -1 || distance < nearestDistance) {
                nearestDistance = distance;
                nearest = c;
            }
        }
        return nearest;
    }

    /**
     * Recomputes every mean from the current members of its cluster and returns the
     * means.
     *
     * <p>Note that this is not a pure getter: each call asks the
     * {@link ClusterRepresentative} for a fresh value per cluster and stores it.
     *
     * @return the internal vector of means (not a copy), indexed by cluster
     */
    public Vector<X> getClusterMeans() {
        for (int c = 0; c < this.numClusters; c++) {
            this.clusterMeans.set(c, this.repr.getRepresentative(this.clusters.get(c)));
        }
        return this.clusterMeans;
    }

    /** Stores a copy of the elements and resets clusters and means for a fresh run. */
    private void init(Collection<X> elements) {
        this.elmts = new Vector<>(elements);
        for (int c = 0; c < this.numClusters; c++) {
            // Fresh cluster objects, so collections returned by earlier runs are not altered.
            this.clusters.set(c, new DefaultCluster<>());
            this.clusterMeans.set(c, this.startMeans.get(c));
        }
    }

    /**
     * Sums, over the elements of the most recent run, the squared metric value
     * between each element and its nearest current mean.
     *
     * @return the sum of squared nearest-mean distances; {@code 0.0} when there are
     *         no elements or no means
     */
    public double sumOfSquaredDistance() {
        double total = 0.0d;
        for (X element : this.elmts) {
            double nearestDistance = 0.0d;
            for (int c = 0; c < this.numClusters; c++) {
                double distance = this.metric.evaluate(element, this.clusterMeans.get(c));
                // The first mean always seeds the minimum; later ones must be strictly closer.
                if (c == 0 || distance < nearestDistance) {
                    nearestDistance = distance;
                }
            }
            total += nearestDistance * nearestDistance;
        }
        return total;
    }
}
