MultiKMeansPlusPlusClusterer.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math3.ml.clustering;

  18. import java.util.Collection;
  19. import java.util.List;

  20. import org.apache.commons.math3.exception.ConvergenceException;
  21. import org.apache.commons.math3.exception.MathIllegalArgumentException;
  22. import org.apache.commons.math3.ml.clustering.evaluation.ClusterEvaluator;
  23. import org.apache.commons.math3.ml.clustering.evaluation.SumOfClusterVariances;

  24. /**
  25.  * A wrapper around a k-means++ clustering algorithm which performs multiple trials
  26.  * and returns the best solution.
  27.  * @param <T> type of the points to cluster
  28.  * @since 3.2
  29.  */
  30. public class MultiKMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T> {

  31.     /** The underlying k-means clusterer. */
  32.     private final KMeansPlusPlusClusterer<T> clusterer;

  33.     /** The number of trial runs. */
  34.     private final int numTrials;

  35.     /** The cluster evaluator to use. */
  36.     private final ClusterEvaluator<T> evaluator;

  37.     /** Build a clusterer.
  38.      * @param clusterer the k-means clusterer to use
  39.      * @param numTrials number of trial runs
  40.      */
  41.     public MultiKMeansPlusPlusClusterer(final KMeansPlusPlusClusterer<T> clusterer,
  42.                                         final int numTrials) {
  43.         this(clusterer, numTrials, new SumOfClusterVariances<T>(clusterer.getDistanceMeasure()));
  44.     }

  45.     /** Build a clusterer.
  46.      * @param clusterer the k-means clusterer to use
  47.      * @param numTrials number of trial runs
  48.      * @param evaluator the cluster evaluator to use
  49.      * @since 3.3
  50.      */
  51.     public MultiKMeansPlusPlusClusterer(final KMeansPlusPlusClusterer<T> clusterer,
  52.                                         final int numTrials,
  53.                                         final ClusterEvaluator<T> evaluator) {
  54.         super(clusterer.getDistanceMeasure());
  55.         this.clusterer = clusterer;
  56.         this.numTrials = numTrials;
  57.         this.evaluator = evaluator;
  58.     }

  59.     /**
  60.      * Returns the embedded k-means clusterer used by this instance.
  61.      * @return the embedded clusterer
  62.      */
  63.     public KMeansPlusPlusClusterer<T> getClusterer() {
  64.         return clusterer;
  65.     }

  66.     /**
  67.      * Returns the number of trials this instance will do.
  68.      * @return the number of trials
  69.      */
  70.     public int getNumTrials() {
  71.         return numTrials;
  72.     }

  73.     /**
  74.      * Returns the {@link ClusterEvaluator} used to determine the "best" clustering.
  75.      * @return the used {@link ClusterEvaluator}
  76.      * @since 3.3
  77.      */
  78.     public ClusterEvaluator<T> getClusterEvaluator() {
  79.        return evaluator;
  80.     }

  81.     /**
  82.      * Runs the K-means++ clustering algorithm.
  83.      *
  84.      * @param points the points to cluster
  85.      * @return a list of clusters containing the points
  86.      * @throws MathIllegalArgumentException if the data points are null or the number
  87.      *   of clusters is larger than the number of data points
  88.      * @throws ConvergenceException if an empty cluster is encountered and the
  89.      *   underlying {@link KMeansPlusPlusClusterer} has its
  90.      *   {@link KMeansPlusPlusClusterer.EmptyClusterStrategy} is set to {@code ERROR}.
  91.      */
  92.     @Override
  93.     public List<CentroidCluster<T>> cluster(final Collection<T> points)
  94.         throws MathIllegalArgumentException, ConvergenceException {

  95.         // at first, we have not found any clusters list yet
  96.         List<CentroidCluster<T>> best = null;
  97.         double bestVarianceSum = Double.POSITIVE_INFINITY;

  98.         // do several clustering trials
  99.         for (int i = 0; i < numTrials; ++i) {

  100.             // compute a clusters list
  101.             List<CentroidCluster<T>> clusters = clusterer.cluster(points);

  102.             // compute the variance of the current list
  103.             final double varianceSum = evaluator.score(clusters);

  104.             if (evaluator.isBetterScore(varianceSum, bestVarianceSum)) {
  105.                 // this one is the best we have found so far, remember it
  106.                 best            = clusters;
  107.                 bestVarianceSum = varianceSum;
  108.             }

  109.         }

  110.         // return the best clusters list found
  111.         return best;

  112.     }

  113. }