KMeansPlusPlusClusterer.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math3.stat.clustering;

  18. import java.util.ArrayList;
  19. import java.util.Collection;
  20. import java.util.Collections;
  21. import java.util.List;
  22. import java.util.Random;

  23. import org.apache.commons.math3.exception.ConvergenceException;
  24. import org.apache.commons.math3.exception.MathIllegalArgumentException;
  25. import org.apache.commons.math3.exception.NumberIsTooSmallException;
  26. import org.apache.commons.math3.exception.util.LocalizedFormats;
  27. import org.apache.commons.math3.stat.descriptive.moment.Variance;
  28. import org.apache.commons.math3.util.MathUtils;

  29. /**
  30.  * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
  31.  * @param <T> type of the points to cluster
  32.  * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
  33.  * @since 2.0
  34.  * @deprecated As of 3.2 (to be removed in 4.0),
  35.  * use {@link org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer} instead
  36.  */
  37. @Deprecated
  38. public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {

  39.     /** Strategies to use for replacing an empty cluster. */
  40.     public static enum EmptyClusterStrategy {

  41.         /** Split the cluster with largest distance variance. */
  42.         LARGEST_VARIANCE,

  43.         /** Split the cluster with largest number of points. */
  44.         LARGEST_POINTS_NUMBER,

  45.         /** Create a cluster around the point farthest from its centroid. */
  46.         FARTHEST_POINT,

  47.         /** Generate an error. */
  48.         ERROR

  49.     }

  50.     /** Random generator for choosing initial centers. */
  51.     private final Random random;

  52.     /** Selected strategy for empty clusters. */
  53.     private final EmptyClusterStrategy emptyStrategy;

  54.     /** Build a clusterer.
  55.      * <p>
  56.      * The default strategy for handling empty clusters that may appear during
  57.      * algorithm iterations is to split the cluster with largest distance variance.
  58.      * </p>
  59.      * @param random random generator to use for choosing initial centers
  60.      */
  61.     public KMeansPlusPlusClusterer(final Random random) {
  62.         this(random, EmptyClusterStrategy.LARGEST_VARIANCE);
  63.     }

  64.     /** Build a clusterer.
  65.      * @param random random generator to use for choosing initial centers
  66.      * @param emptyStrategy strategy to use for handling empty clusters that
  67.      * may appear during algorithm iterations
  68.      * @since 2.2
  69.      */
  70.     public KMeansPlusPlusClusterer(final Random random, final EmptyClusterStrategy emptyStrategy) {
  71.         this.random        = random;
  72.         this.emptyStrategy = emptyStrategy;
  73.     }

  74.     /**
  75.      * Runs the K-means++ clustering algorithm.
  76.      *
  77.      * @param points the points to cluster
  78.      * @param k the number of clusters to split the data into
  79.      * @param numTrials number of trial runs
  80.      * @param maxIterationsPerTrial the maximum number of iterations to run the algorithm
  81.      *     for at each trial run.  If negative, no maximum will be used
  82.      * @return a list of clusters containing the points
  83.      * @throws MathIllegalArgumentException if the data points are null or the number
  84.      *     of clusters is larger than the number of data points
  85.      * @throws ConvergenceException if an empty cluster is encountered and the
  86.      * {@link #emptyStrategy} is set to {@code ERROR}
  87.      */
  88.     public List<Cluster<T>> cluster(final Collection<T> points, final int k,
  89.                                     int numTrials, int maxIterationsPerTrial)
  90.         throws MathIllegalArgumentException, ConvergenceException {

  91.         // at first, we have not found any clusters list yet
  92.         List<Cluster<T>> best = null;
  93.         double bestVarianceSum = Double.POSITIVE_INFINITY;

  94.         // do several clustering trials
  95.         for (int i = 0; i < numTrials; ++i) {

  96.             // compute a clusters list
  97.             List<Cluster<T>> clusters = cluster(points, k, maxIterationsPerTrial);

  98.             // compute the variance of the current list
  99.             double varianceSum = 0.0;
  100.             for (final Cluster<T> cluster : clusters) {
  101.                 if (!cluster.getPoints().isEmpty()) {

  102.                     // compute the distance variance of the current cluster
  103.                     final T center = cluster.getCenter();
  104.                     final Variance stat = new Variance();
  105.                     for (final T point : cluster.getPoints()) {
  106.                         stat.increment(point.distanceFrom(center));
  107.                     }
  108.                     varianceSum += stat.getResult();

  109.                 }
  110.             }

  111.             if (varianceSum <= bestVarianceSum) {
  112.                 // this one is the best we have found so far, remember it
  113.                 best            = clusters;
  114.                 bestVarianceSum = varianceSum;
  115.             }

  116.         }

  117.         // return the best clusters list found
  118.         return best;

  119.     }

  120.     /**
  121.      * Runs the K-means++ clustering algorithm.
  122.      *
  123.      * @param points the points to cluster
  124.      * @param k the number of clusters to split the data into
  125.      * @param maxIterations the maximum number of iterations to run the algorithm
  126.      *     for.  If negative, no maximum will be used
  127.      * @return a list of clusters containing the points
  128.      * @throws MathIllegalArgumentException if the data points are null or the number
  129.      *     of clusters is larger than the number of data points
  130.      * @throws ConvergenceException if an empty cluster is encountered and the
  131.      * {@link #emptyStrategy} is set to {@code ERROR}
  132.      */
  133.     public List<Cluster<T>> cluster(final Collection<T> points, final int k,
  134.                                     final int maxIterations)
  135.         throws MathIllegalArgumentException, ConvergenceException {

  136.         // sanity checks
  137.         MathUtils.checkNotNull(points);

  138.         // number of clusters has to be smaller or equal the number of data points
  139.         if (points.size() < k) {
  140.             throw new NumberIsTooSmallException(points.size(), k, false);
  141.         }

  142.         // create the initial clusters
  143.         List<Cluster<T>> clusters = chooseInitialCenters(points, k, random);

  144.         // create an array containing the latest assignment of a point to a cluster
  145.         // no need to initialize the array, as it will be filled with the first assignment
  146.         int[] assignments = new int[points.size()];
  147.         assignPointsToClusters(clusters, points, assignments);

  148.         // iterate through updating the centers until we're done
  149.         final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
  150.         for (int count = 0; count < max; count++) {
  151.             boolean emptyCluster = false;
  152.             List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>();
  153.             for (final Cluster<T> cluster : clusters) {
  154.                 final T newCenter;
  155.                 if (cluster.getPoints().isEmpty()) {
  156.                     switch (emptyStrategy) {
  157.                         case LARGEST_VARIANCE :
  158.                             newCenter = getPointFromLargestVarianceCluster(clusters);
  159.                             break;
  160.                         case LARGEST_POINTS_NUMBER :
  161.                             newCenter = getPointFromLargestNumberCluster(clusters);
  162.                             break;
  163.                         case FARTHEST_POINT :
  164.                             newCenter = getFarthestPoint(clusters);
  165.                             break;
  166.                         default :
  167.                             throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
  168.                     }
  169.                     emptyCluster = true;
  170.                 } else {
  171.                     newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
  172.                 }
  173.                 newClusters.add(new Cluster<T>(newCenter));
  174.             }
  175.             int changes = assignPointsToClusters(newClusters, points, assignments);
  176.             clusters = newClusters;

  177.             // if there were no more changes in the point-to-cluster assignment
  178.             // and there are no empty clusters left, return the current clusters
  179.             if (changes == 0 && !emptyCluster) {
  180.                 return clusters;
  181.             }
  182.         }
  183.         return clusters;
  184.     }

  185.     /**
  186.      * Adds the given points to the closest {@link Cluster}.
  187.      *
  188.      * @param <T> type of the points to cluster
  189.      * @param clusters the {@link Cluster}s to add the points to
  190.      * @param points the points to add to the given {@link Cluster}s
  191.      * @param assignments points assignments to clusters
  192.      * @return the number of points assigned to different clusters as the iteration before
  193.      */
  194.     private static <T extends Clusterable<T>> int
  195.         assignPointsToClusters(final List<Cluster<T>> clusters, final Collection<T> points,
  196.                                final int[] assignments) {
  197.         int assignedDifferently = 0;
  198.         int pointIndex = 0;
  199.         for (final T p : points) {
  200.             int clusterIndex = getNearestCluster(clusters, p);
  201.             if (clusterIndex != assignments[pointIndex]) {
  202.                 assignedDifferently++;
  203.             }

  204.             Cluster<T> cluster = clusters.get(clusterIndex);
  205.             cluster.addPoint(p);
  206.             assignments[pointIndex++] = clusterIndex;
  207.         }

  208.         return assignedDifferently;
  209.     }

  210.     /**
  211.      * Use K-means++ to choose the initial centers.
  212.      *
  213.      * @param <T> type of the points to cluster
  214.      * @param points the points to choose the initial centers from
  215.      * @param k the number of centers to choose
  216.      * @param random random generator to use
  217.      * @return the initial centers
  218.      */
  219.     private static <T extends Clusterable<T>> List<Cluster<T>>
  220.         chooseInitialCenters(final Collection<T> points, final int k, final Random random) {

  221.         // Convert to list for indexed access. Make it unmodifiable, since removal of items
  222.         // would screw up the logic of this method.
  223.         final List<T> pointList = Collections.unmodifiableList(new ArrayList<T> (points));

  224.         // The number of points in the list.
  225.         final int numPoints = pointList.size();

  226.         // Set the corresponding element in this array to indicate when
  227.         // elements of pointList are no longer available.
  228.         final boolean[] taken = new boolean[numPoints];

  229.         // The resulting list of initial centers.
  230.         final List<Cluster<T>> resultSet = new ArrayList<Cluster<T>>();

  231.         // Choose one center uniformly at random from among the data points.
  232.         final int firstPointIndex = random.nextInt(numPoints);

  233.         final T firstPoint = pointList.get(firstPointIndex);

  234.         resultSet.add(new Cluster<T>(firstPoint));

  235.         // Must mark it as taken
  236.         taken[firstPointIndex] = true;

  237.         // To keep track of the minimum distance squared of elements of
  238.         // pointList to elements of resultSet.
  239.         final double[] minDistSquared = new double[numPoints];

  240.         // Initialize the elements.  Since the only point in resultSet is firstPoint,
  241.         // this is very easy.
  242.         for (int i = 0; i < numPoints; i++) {
  243.             if (i != firstPointIndex) { // That point isn't considered
  244.                 double d = firstPoint.distanceFrom(pointList.get(i));
  245.                 minDistSquared[i] = d*d;
  246.             }
  247.         }

  248.         while (resultSet.size() < k) {

  249.             // Sum up the squared distances for the points in pointList not
  250.             // already taken.
  251.             double distSqSum = 0.0;

  252.             for (int i = 0; i < numPoints; i++) {
  253.                 if (!taken[i]) {
  254.                     distSqSum += minDistSquared[i];
  255.                 }
  256.             }

  257.             // Add one new data point as a center. Each point x is chosen with
  258.             // probability proportional to D(x)2
  259.             final double r = random.nextDouble() * distSqSum;

  260.             // The index of the next point to be added to the resultSet.
  261.             int nextPointIndex = -1;

  262.             // Sum through the squared min distances again, stopping when
  263.             // sum >= r.
  264.             double sum = 0.0;
  265.             for (int i = 0; i < numPoints; i++) {
  266.                 if (!taken[i]) {
  267.                     sum += minDistSquared[i];
  268.                     if (sum >= r) {
  269.                         nextPointIndex = i;
  270.                         break;
  271.                     }
  272.                 }
  273.             }

  274.             // If it's not set to >= 0, the point wasn't found in the previous
  275.             // for loop, probably because distances are extremely small.  Just pick
  276.             // the last available point.
  277.             if (nextPointIndex == -1) {
  278.                 for (int i = numPoints - 1; i >= 0; i--) {
  279.                     if (!taken[i]) {
  280.                         nextPointIndex = i;
  281.                         break;
  282.                     }
  283.                 }
  284.             }

  285.             // We found one.
  286.             if (nextPointIndex >= 0) {

  287.                 final T p = pointList.get(nextPointIndex);

  288.                 resultSet.add(new Cluster<T> (p));

  289.                 // Mark it as taken.
  290.                 taken[nextPointIndex] = true;

  291.                 if (resultSet.size() < k) {
  292.                     // Now update elements of minDistSquared.  We only have to compute
  293.                     // the distance to the new center to do this.
  294.                     for (int j = 0; j < numPoints; j++) {
  295.                         // Only have to worry about the points still not taken.
  296.                         if (!taken[j]) {
  297.                             double d = p.distanceFrom(pointList.get(j));
  298.                             double d2 = d * d;
  299.                             if (d2 < minDistSquared[j]) {
  300.                                 minDistSquared[j] = d2;
  301.                             }
  302.                         }
  303.                     }
  304.                 }

  305.             } else {
  306.                 // None found --
  307.                 // Break from the while loop to prevent
  308.                 // an infinite loop.
  309.                 break;
  310.             }
  311.         }

  312.         return resultSet;
  313.     }

  314.     /**
  315.      * Get a random point from the {@link Cluster} with the largest distance variance.
  316.      *
  317.      * @param clusters the {@link Cluster}s to search
  318.      * @return a random point from the selected cluster
  319.      * @throws ConvergenceException if clusters are all empty
  320.      */
  321.     private T getPointFromLargestVarianceCluster(final Collection<Cluster<T>> clusters)
  322.     throws ConvergenceException {

  323.         double maxVariance = Double.NEGATIVE_INFINITY;
  324.         Cluster<T> selected = null;
  325.         for (final Cluster<T> cluster : clusters) {
  326.             if (!cluster.getPoints().isEmpty()) {

  327.                 // compute the distance variance of the current cluster
  328.                 final T center = cluster.getCenter();
  329.                 final Variance stat = new Variance();
  330.                 for (final T point : cluster.getPoints()) {
  331.                     stat.increment(point.distanceFrom(center));
  332.                 }
  333.                 final double variance = stat.getResult();

  334.                 // select the cluster with the largest variance
  335.                 if (variance > maxVariance) {
  336.                     maxVariance = variance;
  337.                     selected = cluster;
  338.                 }

  339.             }
  340.         }

  341.         // did we find at least one non-empty cluster ?
  342.         if (selected == null) {
  343.             throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
  344.         }

  345.         // extract a random point from the cluster
  346.         final List<T> selectedPoints = selected.getPoints();
  347.         return selectedPoints.remove(random.nextInt(selectedPoints.size()));

  348.     }

  349.     /**
  350.      * Get a random point from the {@link Cluster} with the largest number of points
  351.      *
  352.      * @param clusters the {@link Cluster}s to search
  353.      * @return a random point from the selected cluster
  354.      * @throws ConvergenceException if clusters are all empty
  355.      */
  356.     private T getPointFromLargestNumberCluster(final Collection<Cluster<T>> clusters) throws ConvergenceException {

  357.         int maxNumber = 0;
  358.         Cluster<T> selected = null;
  359.         for (final Cluster<T> cluster : clusters) {

  360.             // get the number of points of the current cluster
  361.             final int number = cluster.getPoints().size();

  362.             // select the cluster with the largest number of points
  363.             if (number > maxNumber) {
  364.                 maxNumber = number;
  365.                 selected = cluster;
  366.             }

  367.         }

  368.         // did we find at least one non-empty cluster ?
  369.         if (selected == null) {
  370.             throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
  371.         }

  372.         // extract a random point from the cluster
  373.         final List<T> selectedPoints = selected.getPoints();
  374.         return selectedPoints.remove(random.nextInt(selectedPoints.size()));

  375.     }

  376.     /**
  377.      * Get the point farthest to its cluster center
  378.      *
  379.      * @param clusters the {@link Cluster}s to search
  380.      * @return point farthest to its cluster center
  381.      * @throws ConvergenceException if clusters are all empty
  382.      */
  383.     private T getFarthestPoint(final Collection<Cluster<T>> clusters) throws ConvergenceException {

  384.         double maxDistance = Double.NEGATIVE_INFINITY;
  385.         Cluster<T> selectedCluster = null;
  386.         int selectedPoint = -1;
  387.         for (final Cluster<T> cluster : clusters) {

  388.             // get the farthest point
  389.             final T center = cluster.getCenter();
  390.             final List<T> points = cluster.getPoints();
  391.             for (int i = 0; i < points.size(); ++i) {
  392.                 final double distance = points.get(i).distanceFrom(center);
  393.                 if (distance > maxDistance) {
  394.                     maxDistance     = distance;
  395.                     selectedCluster = cluster;
  396.                     selectedPoint   = i;
  397.                 }
  398.             }

  399.         }

  400.         // did we find at least one non-empty cluster ?
  401.         if (selectedCluster == null) {
  402.             throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
  403.         }

  404.         return selectedCluster.getPoints().remove(selectedPoint);

  405.     }

  406.     /**
  407.      * Returns the nearest {@link Cluster} to the given point
  408.      *
  409.      * @param <T> type of the points to cluster
  410.      * @param clusters the {@link Cluster}s to search
  411.      * @param point the point to find the nearest {@link Cluster} for
  412.      * @return the index of the nearest {@link Cluster} to the given point
  413.      */
  414.     private static <T extends Clusterable<T>> int
  415.         getNearestCluster(final Collection<Cluster<T>> clusters, final T point) {
  416.         double minDistance = Double.MAX_VALUE;
  417.         int clusterIndex = 0;
  418.         int minCluster = 0;
  419.         for (final Cluster<T> c : clusters) {
  420.             final double distance = point.distanceFrom(c.getCenter());
  421.             if (distance < minDistance) {
  422.                 minDistance = distance;
  423.                 minCluster = clusterIndex;
  424.             }
  425.             clusterIndex++;
  426.         }
  427.         return minCluster;
  428.     }

  429. }