MapUtils.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math3.ml.neuralnet;

  18. import java.util.HashMap;
  19. import java.util.Collection;
  20. import org.apache.commons.math3.ml.distance.DistanceMeasure;
  21. import org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D;
  22. import org.apache.commons.math3.exception.NoDataException;
  23. import org.apache.commons.math3.util.Pair;

  24. /**
  25.  * Utilities for network maps.
  26.  *
  27.  * @since 3.3
  28.  */
  29. public class MapUtils {
  30.     /**
  31.      * Class contains only static methods.
  32.      */
  33.     private MapUtils() {}

  34.     /**
  35.      * Finds the neuron that best matches the given features.
  36.      *
  37.      * @param features Data.
  38.      * @param neurons List of neurons to scan. If the list is empty
  39.      * {@code null} will be returned.
  40.      * @param distance Distance function. The neuron's features are
  41.      * passed as the first argument to {@link DistanceMeasure#compute(double[],double[])}.
  42.      * @return the neuron whose features are closest to the given data.
  43.      * @throws org.apache.commons.math3.exception.DimensionMismatchException
  44.      * if the size of the input is not compatible with the neurons features
  45.      * size.
  46.      */
  47.     public static Neuron findBest(double[] features,
  48.                                   Iterable<Neuron> neurons,
  49.                                   DistanceMeasure distance) {
  50.         Neuron best = null;
  51.         double min = Double.POSITIVE_INFINITY;
  52.         for (final Neuron n : neurons) {
  53.             final double d = distance.compute(n.getFeatures(), features);
  54.             if (d < min) {
  55.                 min = d;
  56.                 best = n;
  57.             }
  58.         }

  59.         return best;
  60.     }

  61.     /**
  62.      * Finds the two neurons that best match the given features.
  63.      *
  64.      * @param features Data.
  65.      * @param neurons List of neurons to scan. If the list is empty
  66.      * {@code null} will be returned.
  67.      * @param distance Distance function. The neuron's features are
  68.      * passed as the first argument to {@link DistanceMeasure#compute(double[],double[])}.
  69.      * @return the two neurons whose features are closest to the given data.
  70.      * @throws org.apache.commons.math3.exception.DimensionMismatchException
  71.      * if the size of the input is not compatible with the neurons features
  72.      * size.
  73.      */
  74.     public static Pair<Neuron, Neuron> findBestAndSecondBest(double[] features,
  75.                                                              Iterable<Neuron> neurons,
  76.                                                              DistanceMeasure distance) {
  77.         Neuron[] best = { null, null };
  78.         double[] min = { Double.POSITIVE_INFINITY,
  79.                          Double.POSITIVE_INFINITY };
  80.         for (final Neuron n : neurons) {
  81.             final double d = distance.compute(n.getFeatures(), features);
  82.             if (d < min[0]) {
  83.                 // Replace second best with old best.
  84.                 min[1] = min[0];
  85.                 best[1] = best[0];

  86.                 // Store current as new best.
  87.                 min[0] = d;
  88.                 best[0] = n;
  89.             } else if (d < min[1]) {
  90.                 // Replace old second best with current.
  91.                 min[1] = d;
  92.                 best[1] = n;
  93.             }
  94.         }

  95.         return new Pair<Neuron, Neuron>(best[0], best[1]);
  96.     }

  97.     /**
  98.      * Computes the <a href="http://en.wikipedia.org/wiki/U-Matrix">
  99.      *  U-matrix</a> of a two-dimensional map.
  100.      *
  101.      * @param map Network.
  102.      * @param distance Function to use for computing the average
  103.      * distance from a neuron to its neighbours.
  104.      * @return the matrix of average distances.
  105.      */
  106.     public static double[][] computeU(NeuronSquareMesh2D map,
  107.                                       DistanceMeasure distance) {
  108.         final int numRows = map.getNumberOfRows();
  109.         final int numCols = map.getNumberOfColumns();
  110.         final double[][] uMatrix = new double[numRows][numCols];

  111.         final Network net = map.getNetwork();

  112.         for (int i = 0; i < numRows; i++) {
  113.             for (int j = 0; j < numCols; j++) {
  114.                 final Neuron neuron = map.getNeuron(i, j);
  115.                 final Collection<Neuron> neighbours = net.getNeighbours(neuron);
  116.                 final double[] features = neuron.getFeatures();

  117.                 double d = 0;
  118.                 int count = 0;
  119.                 for (Neuron n : neighbours) {
  120.                     ++count;
  121.                     d += distance.compute(features, n.getFeatures());
  122.                 }

  123.                 uMatrix[i][j] = d / count;
  124.             }
  125.         }

  126.         return uMatrix;
  127.     }

  128.     /**
  129.      * Computes the "hit" histogram of a two-dimensional map.
  130.      *
  131.      * @param data Feature vectors.
  132.      * @param map Network.
  133.      * @param distance Function to use for determining the best matching unit.
  134.      * @return the number of hits for each neuron in the map.
  135.      */
  136.     public static int[][] computeHitHistogram(Iterable<double[]> data,
  137.                                               NeuronSquareMesh2D map,
  138.                                               DistanceMeasure distance) {
  139.         final HashMap<Neuron, Integer> hit = new HashMap<Neuron, Integer>();
  140.         final Network net = map.getNetwork();

  141.         for (double[] f : data) {
  142.             final Neuron best = findBest(f, net, distance);
  143.             final Integer count = hit.get(best);
  144.             if (count == null) {
  145.                 hit.put(best, 1);
  146.             } else {
  147.                 hit.put(best, count + 1);
  148.             }
  149.         }

  150.         // Copy the histogram data into a 2D map.
  151.         final int numRows = map.getNumberOfRows();
  152.         final int numCols = map.getNumberOfColumns();
  153.         final int[][] histo = new int[numRows][numCols];

  154.         for (int i = 0; i < numRows; i++) {
  155.             for (int j = 0; j < numCols; j++) {
  156.                 final Neuron neuron = map.getNeuron(i, j);
  157.                 final Integer count = hit.get(neuron);
  158.                 if (count == null) {
  159.                     histo[i][j] = 0;
  160.                 } else {
  161.                     histo[i][j] = count;
  162.                 }
  163.             }
  164.         }

  165.         return histo;
  166.     }

  167.     /**
  168.      * Computes the quantization error.
  169.      * The quantization error is the average distance between a feature vector
  170.      * and its "best matching unit" (closest neuron).
  171.      *
  172.      * @param data Feature vectors.
  173.      * @param neurons List of neurons to scan.
  174.      * @param distance Distance function.
  175.      * @return the error.
  176.      * @throws NoDataException if {@code data} is empty.
  177.      */
  178.     public static double computeQuantizationError(Iterable<double[]> data,
  179.                                                   Iterable<Neuron> neurons,
  180.                                                   DistanceMeasure distance) {
  181.         double d = 0;
  182.         int count = 0;
  183.         for (double[] f : data) {
  184.             ++count;
  185.             d += distance.compute(f, findBest(f, neurons, distance).getFeatures());
  186.         }

  187.         if (count == 0) {
  188.             throw new NoDataException();
  189.         }

  190.         return d / count;
  191.     }

  192.     /**
  193.      * Computes the topographic error.
  194.      * The topographic error is the proportion of data for which first and
  195.      * second best matching units are not adjacent in the map.
  196.      *
  197.      * @param data Feature vectors.
  198.      * @param net Network.
  199.      * @param distance Distance function.
  200.      * @return the error.
  201.      * @throws NoDataException if {@code data} is empty.
  202.      */
  203.     public static double computeTopographicError(Iterable<double[]> data,
  204.                                                  Network net,
  205.                                                  DistanceMeasure distance) {
  206.         int notAdjacentCount = 0;
  207.         int count = 0;
  208.         for (double[] f : data) {
  209.             ++count;
  210.             final Pair<Neuron, Neuron> p = findBestAndSecondBest(f, net, distance);
  211.             if (!net.getNeighbours(p.getFirst()).contains(p.getSecond())) {
  212.                 // Increment count if first and second best matching units
  213.                 // are not neighbours.
  214.                 ++notAdjacentCount;
  215.             }
  216.         }

  217.         if (count == 0) {
  218.             throw new NoDataException();
  219.         }

  220.         return ((double) notAdjacentCount) / count;
  221.     }
  222. }