KMeansPlusPlusClusterer.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math3.ml.clustering;

  18. import java.util.ArrayList;
  19. import java.util.Collection;
  20. import java.util.Collections;
  21. import java.util.List;

  22. import org.apache.commons.math3.exception.ConvergenceException;
  23. import org.apache.commons.math3.exception.MathIllegalArgumentException;
  24. import org.apache.commons.math3.exception.NumberIsTooSmallException;
  25. import org.apache.commons.math3.exception.util.LocalizedFormats;
  26. import org.apache.commons.math3.ml.distance.DistanceMeasure;
  27. import org.apache.commons.math3.ml.distance.EuclideanDistance;
  28. import org.apache.commons.math3.random.JDKRandomGenerator;
  29. import org.apache.commons.math3.random.RandomGenerator;
  30. import org.apache.commons.math3.stat.descriptive.moment.Variance;
  31. import org.apache.commons.math3.util.MathUtils;

  32. /**
  33.  * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
  34.  * @param <T> type of the points to cluster
  35.  * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
  36.  * @since 3.2
  37.  */
  38. public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T> {

  39.     /** Strategies to use for replacing an empty cluster. */
  40.     public static enum EmptyClusterStrategy {

  41.         /** Split the cluster with largest distance variance. */
  42.         LARGEST_VARIANCE,

  43.         /** Split the cluster with largest number of points. */
  44.         LARGEST_POINTS_NUMBER,

  45.         /** Create a cluster around the point farthest from its centroid. */
  46.         FARTHEST_POINT,

  47.         /** Generate an error. */
  48.         ERROR

  49.     }

  50.     /** The number of clusters. */
  51.     private final int k;

  52.     /** The maximum number of iterations. */
  53.     private final int maxIterations;

  54.     /** Random generator for choosing initial centers. */
  55.     private final RandomGenerator random;

  56.     /** Selected strategy for empty clusters. */
  57.     private final EmptyClusterStrategy emptyStrategy;

  58.     /** Build a clusterer.
  59.      * <p>
  60.      * The default strategy for handling empty clusters that may appear during
  61.      * algorithm iterations is to split the cluster with largest distance variance.
  62.      * <p>
  63.      * The euclidean distance will be used as default distance measure.
  64.      *
  65.      * @param k the number of clusters to split the data into
  66.      */
  67.     public KMeansPlusPlusClusterer(final int k) {
  68.         this(k, -1);
  69.     }

  70.     /** Build a clusterer.
  71.      * <p>
  72.      * The default strategy for handling empty clusters that may appear during
  73.      * algorithm iterations is to split the cluster with largest distance variance.
  74.      * <p>
  75.      * The euclidean distance will be used as default distance measure.
  76.      *
  77.      * @param k the number of clusters to split the data into
  78.      * @param maxIterations the maximum number of iterations to run the algorithm for.
  79.      *   If negative, no maximum will be used.
  80.      */
  81.     public KMeansPlusPlusClusterer(final int k, final int maxIterations) {
  82.         this(k, maxIterations, new EuclideanDistance());
  83.     }

  84.     /** Build a clusterer.
  85.      * <p>
  86.      * The default strategy for handling empty clusters that may appear during
  87.      * algorithm iterations is to split the cluster with largest distance variance.
  88.      *
  89.      * @param k the number of clusters to split the data into
  90.      * @param maxIterations the maximum number of iterations to run the algorithm for.
  91.      *   If negative, no maximum will be used.
  92.      * @param measure the distance measure to use
  93.      */
  94.     public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure) {
  95.         this(k, maxIterations, measure, new JDKRandomGenerator());
  96.     }

  97.     /** Build a clusterer.
  98.      * <p>
  99.      * The default strategy for handling empty clusters that may appear during
  100.      * algorithm iterations is to split the cluster with largest distance variance.
  101.      *
  102.      * @param k the number of clusters to split the data into
  103.      * @param maxIterations the maximum number of iterations to run the algorithm for.
  104.      *   If negative, no maximum will be used.
  105.      * @param measure the distance measure to use
  106.      * @param random random generator to use for choosing initial centers
  107.      */
  108.     public KMeansPlusPlusClusterer(final int k, final int maxIterations,
  109.                                    final DistanceMeasure measure,
  110.                                    final RandomGenerator random) {
  111.         this(k, maxIterations, measure, random, EmptyClusterStrategy.LARGEST_VARIANCE);
  112.     }

  113.     /** Build a clusterer.
  114.      *
  115.      * @param k the number of clusters to split the data into
  116.      * @param maxIterations the maximum number of iterations to run the algorithm for.
  117.      *   If negative, no maximum will be used.
  118.      * @param measure the distance measure to use
  119.      * @param random random generator to use for choosing initial centers
  120.      * @param emptyStrategy strategy to use for handling empty clusters that
  121.      * may appear during algorithm iterations
  122.      */
  123.     public KMeansPlusPlusClusterer(final int k, final int maxIterations,
  124.                                    final DistanceMeasure measure,
  125.                                    final RandomGenerator random,
  126.                                    final EmptyClusterStrategy emptyStrategy) {
  127.         super(measure);
  128.         this.k             = k;
  129.         this.maxIterations = maxIterations;
  130.         this.random        = random;
  131.         this.emptyStrategy = emptyStrategy;
  132.     }

  133.     /**
  134.      * Return the number of clusters this instance will use.
  135.      * @return the number of clusters
  136.      */
  137.     public int getK() {
  138.         return k;
  139.     }

  140.     /**
  141.      * Returns the maximum number of iterations this instance will use.
  142.      * @return the maximum number of iterations, or -1 if no maximum is set
  143.      */
  144.     public int getMaxIterations() {
  145.         return maxIterations;
  146.     }

  147.     /**
  148.      * Returns the random generator this instance will use.
  149.      * @return the random generator
  150.      */
  151.     public RandomGenerator getRandomGenerator() {
  152.         return random;
  153.     }

  154.     /**
  155.      * Returns the {@link EmptyClusterStrategy} used by this instance.
  156.      * @return the {@link EmptyClusterStrategy}
  157.      */
  158.     public EmptyClusterStrategy getEmptyClusterStrategy() {
  159.         return emptyStrategy;
  160.     }

  161.     /**
  162.      * Runs the K-means++ clustering algorithm.
  163.      *
  164.      * @param points the points to cluster
  165.      * @return a list of clusters containing the points
  166.      * @throws MathIllegalArgumentException if the data points are null or the number
  167.      *     of clusters is larger than the number of data points
  168.      * @throws ConvergenceException if an empty cluster is encountered and the
  169.      * {@link #emptyStrategy} is set to {@code ERROR}
  170.      */
  171.     @Override
  172.     public List<CentroidCluster<T>> cluster(final Collection<T> points)
  173.         throws MathIllegalArgumentException, ConvergenceException {

  174.         // sanity checks
  175.         MathUtils.checkNotNull(points);

  176.         // number of clusters has to be smaller or equal the number of data points
  177.         if (points.size() < k) {
  178.             throw new NumberIsTooSmallException(points.size(), k, false);
  179.         }

  180.         // create the initial clusters
  181.         List<CentroidCluster<T>> clusters = chooseInitialCenters(points);

  182.         // create an array containing the latest assignment of a point to a cluster
  183.         // no need to initialize the array, as it will be filled with the first assignment
  184.         int[] assignments = new int[points.size()];
  185.         assignPointsToClusters(clusters, points, assignments);

  186.         // iterate through updating the centers until we're done
  187.         final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
  188.         for (int count = 0; count < max; count++) {
  189.             boolean emptyCluster = false;
  190.             List<CentroidCluster<T>> newClusters = new ArrayList<CentroidCluster<T>>();
  191.             for (final CentroidCluster<T> cluster : clusters) {
  192.                 final Clusterable newCenter;
  193.                 if (cluster.getPoints().isEmpty()) {
  194.                     switch (emptyStrategy) {
  195.                         case LARGEST_VARIANCE :
  196.                             newCenter = getPointFromLargestVarianceCluster(clusters);
  197.                             break;
  198.                         case LARGEST_POINTS_NUMBER :
  199.                             newCenter = getPointFromLargestNumberCluster(clusters);
  200.                             break;
  201.                         case FARTHEST_POINT :
  202.                             newCenter = getFarthestPoint(clusters);
  203.                             break;
  204.                         default :
  205.                             throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
  206.                     }
  207.                     emptyCluster = true;
  208.                 } else {
  209.                     newCenter = centroidOf(cluster.getPoints(), cluster.getCenter().getPoint().length);
  210.                 }
  211.                 newClusters.add(new CentroidCluster<T>(newCenter));
  212.             }
  213.             int changes = assignPointsToClusters(newClusters, points, assignments);
  214.             clusters = newClusters;

  215.             // if there were no more changes in the point-to-cluster assignment
  216.             // and there are no empty clusters left, return the current clusters
  217.             if (changes == 0 && !emptyCluster) {
  218.                 return clusters;
  219.             }
  220.         }
  221.         return clusters;
  222.     }

  223.     /**
  224.      * Adds the given points to the closest {@link Cluster}.
  225.      *
  226.      * @param clusters the {@link Cluster}s to add the points to
  227.      * @param points the points to add to the given {@link Cluster}s
  228.      * @param assignments points assignments to clusters
  229.      * @return the number of points assigned to different clusters as the iteration before
  230.      */
  231.     private int assignPointsToClusters(final List<CentroidCluster<T>> clusters,
  232.                                        final Collection<T> points,
  233.                                        final int[] assignments) {
  234.         int assignedDifferently = 0;
  235.         int pointIndex = 0;
  236.         for (final T p : points) {
  237.             int clusterIndex = getNearestCluster(clusters, p);
  238.             if (clusterIndex != assignments[pointIndex]) {
  239.                 assignedDifferently++;
  240.             }

  241.             CentroidCluster<T> cluster = clusters.get(clusterIndex);
  242.             cluster.addPoint(p);
  243.             assignments[pointIndex++] = clusterIndex;
  244.         }

  245.         return assignedDifferently;
  246.     }

  247.     /**
  248.      * Use K-means++ to choose the initial centers.
  249.      *
  250.      * @param points the points to choose the initial centers from
  251.      * @return the initial centers
  252.      */
  253.     private List<CentroidCluster<T>> chooseInitialCenters(final Collection<T> points) {

  254.         // Convert to list for indexed access. Make it unmodifiable, since removal of items
  255.         // would screw up the logic of this method.
  256.         final List<T> pointList = Collections.unmodifiableList(new ArrayList<T> (points));

  257.         // The number of points in the list.
  258.         final int numPoints = pointList.size();

  259.         // Set the corresponding element in this array to indicate when
  260.         // elements of pointList are no longer available.
  261.         final boolean[] taken = new boolean[numPoints];

  262.         // The resulting list of initial centers.
  263.         final List<CentroidCluster<T>> resultSet = new ArrayList<CentroidCluster<T>>();

  264.         // Choose one center uniformly at random from among the data points.
  265.         final int firstPointIndex = random.nextInt(numPoints);

  266.         final T firstPoint = pointList.get(firstPointIndex);

  267.         resultSet.add(new CentroidCluster<T>(firstPoint));

  268.         // Must mark it as taken
  269.         taken[firstPointIndex] = true;

  270.         // To keep track of the minimum distance squared of elements of
  271.         // pointList to elements of resultSet.
  272.         final double[] minDistSquared = new double[numPoints];

  273.         // Initialize the elements.  Since the only point in resultSet is firstPoint,
  274.         // this is very easy.
  275.         for (int i = 0; i < numPoints; i++) {
  276.             if (i != firstPointIndex) { // That point isn't considered
  277.                 double d = distance(firstPoint, pointList.get(i));
  278.                 minDistSquared[i] = d*d;
  279.             }
  280.         }

  281.         while (resultSet.size() < k) {

  282.             // Sum up the squared distances for the points in pointList not
  283.             // already taken.
  284.             double distSqSum = 0.0;

  285.             for (int i = 0; i < numPoints; i++) {
  286.                 if (!taken[i]) {
  287.                     distSqSum += minDistSquared[i];
  288.                 }
  289.             }

  290.             // Add one new data point as a center. Each point x is chosen with
  291.             // probability proportional to D(x)2
  292.             final double r = random.nextDouble() * distSqSum;

  293.             // The index of the next point to be added to the resultSet.
  294.             int nextPointIndex = -1;

  295.             // Sum through the squared min distances again, stopping when
  296.             // sum >= r.
  297.             double sum = 0.0;
  298.             for (int i = 0; i < numPoints; i++) {
  299.                 if (!taken[i]) {
  300.                     sum += minDistSquared[i];
  301.                     if (sum >= r) {
  302.                         nextPointIndex = i;
  303.                         break;
  304.                     }
  305.                 }
  306.             }

  307.             // If it's not set to >= 0, the point wasn't found in the previous
  308.             // for loop, probably because distances are extremely small.  Just pick
  309.             // the last available point.
  310.             if (nextPointIndex == -1) {
  311.                 for (int i = numPoints - 1; i >= 0; i--) {
  312.                     if (!taken[i]) {
  313.                         nextPointIndex = i;
  314.                         break;
  315.                     }
  316.                 }
  317.             }

  318.             // We found one.
  319.             if (nextPointIndex >= 0) {

  320.                 final T p = pointList.get(nextPointIndex);

  321.                 resultSet.add(new CentroidCluster<T> (p));

  322.                 // Mark it as taken.
  323.                 taken[nextPointIndex] = true;

  324.                 if (resultSet.size() < k) {
  325.                     // Now update elements of minDistSquared.  We only have to compute
  326.                     // the distance to the new center to do this.
  327.                     for (int j = 0; j < numPoints; j++) {
  328.                         // Only have to worry about the points still not taken.
  329.                         if (!taken[j]) {
  330.                             double d = distance(p, pointList.get(j));
  331.                             double d2 = d * d;
  332.                             if (d2 < minDistSquared[j]) {
  333.                                 minDistSquared[j] = d2;
  334.                             }
  335.                         }
  336.                     }
  337.                 }

  338.             } else {
  339.                 // None found --
  340.                 // Break from the while loop to prevent
  341.                 // an infinite loop.
  342.                 break;
  343.             }
  344.         }

  345.         return resultSet;
  346.     }

  347.     /**
  348.      * Get a random point from the {@link Cluster} with the largest distance variance.
  349.      *
  350.      * @param clusters the {@link Cluster}s to search
  351.      * @return a random point from the selected cluster
  352.      * @throws ConvergenceException if clusters are all empty
  353.      */
  354.     private T getPointFromLargestVarianceCluster(final Collection<CentroidCluster<T>> clusters)
  355.             throws ConvergenceException {

  356.         double maxVariance = Double.NEGATIVE_INFINITY;
  357.         Cluster<T> selected = null;
  358.         for (final CentroidCluster<T> cluster : clusters) {
  359.             if (!cluster.getPoints().isEmpty()) {

  360.                 // compute the distance variance of the current cluster
  361.                 final Clusterable center = cluster.getCenter();
  362.                 final Variance stat = new Variance();
  363.                 for (final T point : cluster.getPoints()) {
  364.                     stat.increment(distance(point, center));
  365.                 }
  366.                 final double variance = stat.getResult();

  367.                 // select the cluster with the largest variance
  368.                 if (variance > maxVariance) {
  369.                     maxVariance = variance;
  370.                     selected = cluster;
  371.                 }

  372.             }
  373.         }

  374.         // did we find at least one non-empty cluster ?
  375.         if (selected == null) {
  376.             throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
  377.         }

  378.         // extract a random point from the cluster
  379.         final List<T> selectedPoints = selected.getPoints();
  380.         return selectedPoints.remove(random.nextInt(selectedPoints.size()));

  381.     }

  382.     /**
  383.      * Get a random point from the {@link Cluster} with the largest number of points
  384.      *
  385.      * @param clusters the {@link Cluster}s to search
  386.      * @return a random point from the selected cluster
  387.      * @throws ConvergenceException if clusters are all empty
  388.      */
  389.     private T getPointFromLargestNumberCluster(final Collection<? extends Cluster<T>> clusters)
  390.             throws ConvergenceException {

  391.         int maxNumber = 0;
  392.         Cluster<T> selected = null;
  393.         for (final Cluster<T> cluster : clusters) {

  394.             // get the number of points of the current cluster
  395.             final int number = cluster.getPoints().size();

  396.             // select the cluster with the largest number of points
  397.             if (number > maxNumber) {
  398.                 maxNumber = number;
  399.                 selected = cluster;
  400.             }

  401.         }

  402.         // did we find at least one non-empty cluster ?
  403.         if (selected == null) {
  404.             throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
  405.         }

  406.         // extract a random point from the cluster
  407.         final List<T> selectedPoints = selected.getPoints();
  408.         return selectedPoints.remove(random.nextInt(selectedPoints.size()));

  409.     }

  410.     /**
  411.      * Get the point farthest to its cluster center
  412.      *
  413.      * @param clusters the {@link Cluster}s to search
  414.      * @return point farthest to its cluster center
  415.      * @throws ConvergenceException if clusters are all empty
  416.      */
  417.     private T getFarthestPoint(final Collection<CentroidCluster<T>> clusters) throws ConvergenceException {

  418.         double maxDistance = Double.NEGATIVE_INFINITY;
  419.         Cluster<T> selectedCluster = null;
  420.         int selectedPoint = -1;
  421.         for (final CentroidCluster<T> cluster : clusters) {

  422.             // get the farthest point
  423.             final Clusterable center = cluster.getCenter();
  424.             final List<T> points = cluster.getPoints();
  425.             for (int i = 0; i < points.size(); ++i) {
  426.                 final double distance = distance(points.get(i), center);
  427.                 if (distance > maxDistance) {
  428.                     maxDistance     = distance;
  429.                     selectedCluster = cluster;
  430.                     selectedPoint   = i;
  431.                 }
  432.             }

  433.         }

  434.         // did we find at least one non-empty cluster ?
  435.         if (selectedCluster == null) {
  436.             throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
  437.         }

  438.         return selectedCluster.getPoints().remove(selectedPoint);

  439.     }

  440.     /**
  441.      * Returns the nearest {@link Cluster} to the given point
  442.      *
  443.      * @param clusters the {@link Cluster}s to search
  444.      * @param point the point to find the nearest {@link Cluster} for
  445.      * @return the index of the nearest {@link Cluster} to the given point
  446.      */
  447.     private int getNearestCluster(final Collection<CentroidCluster<T>> clusters, final T point) {
  448.         double minDistance = Double.MAX_VALUE;
  449.         int clusterIndex = 0;
  450.         int minCluster = 0;
  451.         for (final CentroidCluster<T> c : clusters) {
  452.             final double distance = distance(point, c.getCenter());
  453.             if (distance < minDistance) {
  454.                 minDistance = distance;
  455.                 minCluster = clusterIndex;
  456.             }
  457.             clusterIndex++;
  458.         }
  459.         return minCluster;
  460.     }

  461.     /**
  462.      * Computes the centroid for a set of points.
  463.      *
  464.      * @param points the set of points
  465.      * @param dimension the point dimension
  466.      * @return the computed centroid for the set of points
  467.      */
  468.     private Clusterable centroidOf(final Collection<T> points, final int dimension) {
  469.         final double[] centroid = new double[dimension];
  470.         for (final T p : points) {
  471.             final double[] point = p.getPoint();
  472.             for (int i = 0; i < centroid.length; i++) {
  473.                 centroid[i] += point[i];
  474.             }
  475.         }
  476.         for (int i = 0; i < centroid.length; i++) {
  477.             centroid[i] /= points.size();
  478.         }
  479.         return new DoublePoint(centroid);
  480.     }

  481. }