EnumeratedDistribution.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.math3.distribution;

  18. import java.io.Serializable;
  19. import java.lang.reflect.Array;
  20. import java.util.ArrayList;
  21. import java.util.Arrays;
  22. import java.util.List;

  23. import org.apache.commons.math3.exception.MathArithmeticException;
  24. import org.apache.commons.math3.exception.NotANumberException;
  25. import org.apache.commons.math3.exception.NotFiniteNumberException;
  26. import org.apache.commons.math3.exception.NotPositiveException;
  27. import org.apache.commons.math3.exception.NotStrictlyPositiveException;
  28. import org.apache.commons.math3.exception.NullArgumentException;
  29. import org.apache.commons.math3.exception.util.LocalizedFormats;
  30. import org.apache.commons.math3.random.RandomGenerator;
  31. import org.apache.commons.math3.random.Well19937c;
  32. import org.apache.commons.math3.util.MathArrays;
  33. import org.apache.commons.math3.util.Pair;

  34. /**
  35.  * <p>A generic implementation of a
  36.  * <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution">
  37.  * discrete probability distribution (Wikipedia)</a> over a finite sample space,
  38.  * based on an enumerated list of &lt;value, probability&gt; pairs.  Input probabilities must all be non-negative,
  39.  * but zero values are allowed and their sum does not have to equal one. Constructors will normalize input
  40.  * probabilities to make them sum to one.</p>
  41.  *
  42.  * <p>The list of <value, probability> pairs does not, strictly speaking, have to be a function and it can
  43.  * contain null values.  The pmf created by the constructor will combine probabilities of equal values and
  44.  * will treat null values as equal.  For example, if the list of pairs &lt;"dog", 0.2&gt;, &lt;null, 0.1&gt;,
  45.  * &lt;"pig", 0.2&gt;, &lt;"dog", 0.1&gt;, &lt;null, 0.4&gt; is provided to the constructor, the resulting
  46.  * pmf will assign mass of 0.5 to null, 0.3 to "dog" and 0.2 to null.</p>
  47.  *
  48.  * @param <T> type of the elements in the sample space.
  49.  * @since 3.2
  50.  */
  51. public class EnumeratedDistribution<T> implements Serializable {

  52.     /** Serializable UID. */
  53.     private static final long serialVersionUID = 20123308L;

  54.     /**
  55.      * RNG instance used to generate samples from the distribution.
  56.      */
  57.     protected final RandomGenerator random;

  58.     /**
  59.      * List of random variable values.
  60.      */
  61.     private final List<T> singletons;

  62.     /**
  63.      * Probabilities of respective random variable values. For i = 0, ..., singletons.size() - 1,
  64.      * probability[i] is the probability that a random variable following this distribution takes
  65.      * the value singletons[i].
  66.      */
  67.     private final double[] probabilities;

  68.     /**
  69.      * Cumulative probabilities, cached to speed up sampling.
  70.      */
  71.     private final double[] cumulativeProbabilities;

  72.     /**
  73.      * Create an enumerated distribution using the given probability mass function
  74.      * enumeration.
  75.      * <p>
  76.      * <b>Note:</b> this constructor will implicitly create an instance of
  77.      * {@link Well19937c} as random generator to be used for sampling only (see
  78.      * {@link #sample()} and {@link #sample(int)}). In case no sampling is
  79.      * needed for the created distribution, it is advised to pass {@code null}
  80.      * as random generator via the appropriate constructors to avoid the
  81.      * additional initialisation overhead.
  82.      *
  83.      * @param pmf probability mass function enumerated as a list of <T, probability>
  84.      * pairs.
  85.      * @throws NotPositiveException if any of the probabilities are negative.
  86.      * @throws NotFiniteNumberException if any of the probabilities are infinite.
  87.      * @throws NotANumberException if any of the probabilities are NaN.
  88.      * @throws MathArithmeticException all of the probabilities are 0.
  89.      */
  90.     public EnumeratedDistribution(final List<Pair<T, Double>> pmf)
  91.         throws NotPositiveException, MathArithmeticException, NotFiniteNumberException, NotANumberException {
  92.         this(new Well19937c(), pmf);
  93.     }

  94.     /**
  95.      * Create an enumerated distribution using the given random number generator
  96.      * and probability mass function enumeration.
  97.      *
  98.      * @param rng random number generator.
  99.      * @param pmf probability mass function enumerated as a list of <T, probability>
  100.      * pairs.
  101.      * @throws NotPositiveException if any of the probabilities are negative.
  102.      * @throws NotFiniteNumberException if any of the probabilities are infinite.
  103.      * @throws NotANumberException if any of the probabilities are NaN.
  104.      * @throws MathArithmeticException all of the probabilities are 0.
  105.      */
  106.     public EnumeratedDistribution(final RandomGenerator rng, final List<Pair<T, Double>> pmf)
  107.         throws NotPositiveException, MathArithmeticException, NotFiniteNumberException, NotANumberException {
  108.         random = rng;

  109.         singletons = new ArrayList<T>(pmf.size());
  110.         final double[] probs = new double[pmf.size()];

  111.         for (int i = 0; i < pmf.size(); i++) {
  112.             final Pair<T, Double> sample = pmf.get(i);
  113.             singletons.add(sample.getKey());
  114.             final double p = sample.getValue();
  115.             if (p < 0) {
  116.                 throw new NotPositiveException(sample.getValue());
  117.             }
  118.             if (Double.isInfinite(p)) {
  119.                 throw new NotFiniteNumberException(p);
  120.             }
  121.             if (Double.isNaN(p)) {
  122.                 throw new NotANumberException();
  123.             }
  124.             probs[i] = p;
  125.         }

  126.         probabilities = MathArrays.normalizeArray(probs, 1.0);

  127.         cumulativeProbabilities = new double[probabilities.length];
  128.         double sum = 0;
  129.         for (int i = 0; i < probabilities.length; i++) {
  130.             sum += probabilities[i];
  131.             cumulativeProbabilities[i] = sum;
  132.         }
  133.     }

  134.     /**
  135.      * Reseed the random generator used to generate samples.
  136.      *
  137.      * @param seed the new seed
  138.      */
  139.     public void reseedRandomGenerator(long seed) {
  140.         random.setSeed(seed);
  141.     }

  142.     /**
  143.      * <p>For a random variable {@code X} whose values are distributed according to
  144.      * this distribution, this method returns {@code P(X = x)}. In other words,
  145.      * this method represents the probability mass function (PMF) for the
  146.      * distribution.</p>
  147.      *
  148.      * <p>Note that if {@code x1} and {@code x2} satisfy {@code x1.equals(x2)},
  149.      * or both are null, then {@code probability(x1) = probability(x2)}.</p>
  150.      *
  151.      * @param x the point at which the PMF is evaluated
  152.      * @return the value of the probability mass function at {@code x}
  153.      */
  154.     double probability(final T x) {
  155.         double probability = 0;

  156.         for (int i = 0; i < probabilities.length; i++) {
  157.             if ((x == null && singletons.get(i) == null) ||
  158.                 (x != null && x.equals(singletons.get(i)))) {
  159.                 probability += probabilities[i];
  160.             }
  161.         }

  162.         return probability;
  163.     }

  164.     /**
  165.      * <p>Return the probability mass function as a list of <value, probability> pairs.</p>
  166.      *
  167.      * <p>Note that if duplicate and / or null values were provided to the constructor
  168.      * when creating this EnumeratedDistribution, the returned list will contain these
  169.      * values.  If duplicates values exist, what is returned will not represent
  170.      * a pmf (i.e., it is up to the caller to consolidate duplicate mass points).</p>
  171.      *
  172.      * @return the probability mass function.
  173.      */
  174.     public List<Pair<T, Double>> getPmf() {
  175.         final List<Pair<T, Double>> samples = new ArrayList<Pair<T, Double>>(probabilities.length);

  176.         for (int i = 0; i < probabilities.length; i++) {
  177.             samples.add(new Pair<T, Double>(singletons.get(i), probabilities[i]));
  178.         }

  179.         return samples;
  180.     }

  181.     /**
  182.      * Generate a random value sampled from this distribution.
  183.      *
  184.      * @return a random value.
  185.      */
  186.     public T sample() {
  187.         final double randomValue = random.nextDouble();

  188.         int index = Arrays.binarySearch(cumulativeProbabilities, randomValue);
  189.         if (index < 0) {
  190.             index = -index-1;
  191.         }

  192.         if (index >= 0 && index < probabilities.length) {
  193.             if (randomValue < cumulativeProbabilities[index]) {
  194.                 return singletons.get(index);
  195.             }
  196.         }

  197.         /* This should never happen, but it ensures we will return a correct
  198.          * object in case there is some floating point inequality problem
  199.          * wrt the cumulative probabilities. */
  200.         return singletons.get(singletons.size() - 1);
  201.     }

  202.     /**
  203.      * Generate a random sample from the distribution.
  204.      *
  205.      * @param sampleSize the number of random values to generate.
  206.      * @return an array representing the random sample.
  207.      * @throws NotStrictlyPositiveException if {@code sampleSize} is not
  208.      * positive.
  209.      */
  210.     public Object[] sample(int sampleSize) throws NotStrictlyPositiveException {
  211.         if (sampleSize <= 0) {
  212.             throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES,
  213.                     sampleSize);
  214.         }

  215.         final Object[] out = new Object[sampleSize];

  216.         for (int i = 0; i < sampleSize; i++) {
  217.             out[i] = sample();
  218.         }

  219.         return out;

  220.     }

  221.     /**
  222.      * Generate a random sample from the distribution.
  223.      * <p>
  224.      * If the requested samples fit in the specified array, it is returned
  225.      * therein. Otherwise, a new array is allocated with the runtime type of
  226.      * the specified array and the size of this collection.
  227.      *
  228.      * @param sampleSize the number of random values to generate.
  229.      * @param array the array to populate.
  230.      * @return an array representing the random sample.
  231.      * @throws NotStrictlyPositiveException if {@code sampleSize} is not positive.
  232.      * @throws NullArgumentException if {@code array} is null
  233.      */
  234.     public T[] sample(int sampleSize, final T[] array) throws NotStrictlyPositiveException {
  235.         if (sampleSize <= 0) {
  236.             throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, sampleSize);
  237.         }

  238.         if (array == null) {
  239.             throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY);
  240.         }

  241.         T[] out;
  242.         if (array.length < sampleSize) {
  243.             @SuppressWarnings("unchecked") // safe as both are of type T
  244.             final T[] unchecked = (T[]) Array.newInstance(array.getClass().getComponentType(), sampleSize);
  245.             out = unchecked;
  246.         } else {
  247.             out = array;
  248.         }

  249.         for (int i = 0; i < sampleSize; i++) {
  250.             out[i] = sample();
  251.         }

  252.         return out;

  253.     }

  254. }