HypergeometricDistribution.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math3.distribution;

  18. import org.apache.commons.math3.exception.NotPositiveException;
  19. import org.apache.commons.math3.exception.NotStrictlyPositiveException;
  20. import org.apache.commons.math3.exception.NumberIsTooLargeException;
  21. import org.apache.commons.math3.exception.util.LocalizedFormats;
  22. import org.apache.commons.math3.random.RandomGenerator;
  23. import org.apache.commons.math3.random.Well19937c;
  24. import org.apache.commons.math3.util.FastMath;

  25. /**
  26.  * Implementation of the hypergeometric distribution.
  27.  *
  28.  * @see <a href="http://en.wikipedia.org/wiki/Hypergeometric_distribution">Hypergeometric distribution (Wikipedia)</a>
  29.  * @see <a href="http://mathworld.wolfram.com/HypergeometricDistribution.html">Hypergeometric distribution (MathWorld)</a>
  30.  */
  31. public class HypergeometricDistribution extends AbstractIntegerDistribution {
  32.     /** Serializable version identifier. */
  33.     private static final long serialVersionUID = -436928820673516179L;
  34.     /** The number of successes in the population. */
  35.     private final int numberOfSuccesses;
  36.     /** The population size. */
  37.     private final int populationSize;
  38.     /** The sample size. */
  39.     private final int sampleSize;
  40.     /** Cached numerical variance */
  41.     private double numericalVariance = Double.NaN;
  42.     /** Whether or not the numerical variance has been calculated */
  43.     private boolean numericalVarianceIsCalculated = false;

  44.     /**
  45.      * Construct a new hypergeometric distribution with the specified population
  46.      * size, number of successes in the population, and sample size.
  47.      * <p>
  48.      * <b>Note:</b> this constructor will implicitly create an instance of
  49.      * {@link Well19937c} as random generator to be used for sampling only (see
  50.      * {@link #sample()} and {@link #sample(int)}). In case no sampling is
  51.      * needed for the created distribution, it is advised to pass {@code null}
  52.      * as random generator via the appropriate constructors to avoid the
  53.      * additional initialisation overhead.
  54.      *
  55.      * @param populationSize Population size.
  56.      * @param numberOfSuccesses Number of successes in the population.
  57.      * @param sampleSize Sample size.
  58.      * @throws NotPositiveException if {@code numberOfSuccesses < 0}.
  59.      * @throws NotStrictlyPositiveException if {@code populationSize <= 0}.
  60.      * @throws NumberIsTooLargeException if {@code numberOfSuccesses > populationSize},
  61.      * or {@code sampleSize > populationSize}.
  62.      */
  63.     public HypergeometricDistribution(int populationSize, int numberOfSuccesses, int sampleSize)
  64.     throws NotPositiveException, NotStrictlyPositiveException, NumberIsTooLargeException {
  65.         this(new Well19937c(), populationSize, numberOfSuccesses, sampleSize);
  66.     }

  67.     /**
  68.      * Creates a new hypergeometric distribution.
  69.      *
  70.      * @param rng Random number generator.
  71.      * @param populationSize Population size.
  72.      * @param numberOfSuccesses Number of successes in the population.
  73.      * @param sampleSize Sample size.
  74.      * @throws NotPositiveException if {@code numberOfSuccesses < 0}.
  75.      * @throws NotStrictlyPositiveException if {@code populationSize <= 0}.
  76.      * @throws NumberIsTooLargeException if {@code numberOfSuccesses > populationSize},
  77.      * or {@code sampleSize > populationSize}.
  78.      * @since 3.1
  79.      */
  80.     public HypergeometricDistribution(RandomGenerator rng,
  81.                                       int populationSize,
  82.                                       int numberOfSuccesses,
  83.                                       int sampleSize)
  84.     throws NotPositiveException, NotStrictlyPositiveException, NumberIsTooLargeException {
  85.         super(rng);

  86.         if (populationSize <= 0) {
  87.             throw new NotStrictlyPositiveException(LocalizedFormats.POPULATION_SIZE,
  88.                                                    populationSize);
  89.         }
  90.         if (numberOfSuccesses < 0) {
  91.             throw new NotPositiveException(LocalizedFormats.NUMBER_OF_SUCCESSES,
  92.                                            numberOfSuccesses);
  93.         }
  94.         if (sampleSize < 0) {
  95.             throw new NotPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES,
  96.                                            sampleSize);
  97.         }

  98.         if (numberOfSuccesses > populationSize) {
  99.             throw new NumberIsTooLargeException(LocalizedFormats.NUMBER_OF_SUCCESS_LARGER_THAN_POPULATION_SIZE,
  100.                                                 numberOfSuccesses, populationSize, true);
  101.         }
  102.         if (sampleSize > populationSize) {
  103.             throw new NumberIsTooLargeException(LocalizedFormats.SAMPLE_SIZE_LARGER_THAN_POPULATION_SIZE,
  104.                                                 sampleSize, populationSize, true);
  105.         }

  106.         this.numberOfSuccesses = numberOfSuccesses;
  107.         this.populationSize = populationSize;
  108.         this.sampleSize = sampleSize;
  109.     }

  110.     /** {@inheritDoc} */
  111.     public double cumulativeProbability(int x) {
  112.         double ret;

  113.         int[] domain = getDomain(populationSize, numberOfSuccesses, sampleSize);
  114.         if (x < domain[0]) {
  115.             ret = 0.0;
  116.         } else if (x >= domain[1]) {
  117.             ret = 1.0;
  118.         } else {
  119.             ret = innerCumulativeProbability(domain[0], x, 1);
  120.         }

  121.         return ret;
  122.     }

  123.     /**
  124.      * Return the domain for the given hypergeometric distribution parameters.
  125.      *
  126.      * @param n Population size.
  127.      * @param m Number of successes in the population.
  128.      * @param k Sample size.
  129.      * @return a two element array containing the lower and upper bounds of the
  130.      * hypergeometric distribution.
  131.      */
  132.     private int[] getDomain(int n, int m, int k) {
  133.         return new int[] { getLowerDomain(n, m, k), getUpperDomain(m, k) };
  134.     }

  135.     /**
  136.      * Return the lowest domain value for the given hypergeometric distribution
  137.      * parameters.
  138.      *
  139.      * @param n Population size.
  140.      * @param m Number of successes in the population.
  141.      * @param k Sample size.
  142.      * @return the lowest domain value of the hypergeometric distribution.
  143.      */
  144.     private int getLowerDomain(int n, int m, int k) {
  145.         return FastMath.max(0, m - (n - k));
  146.     }

  147.     /**
  148.      * Access the number of successes.
  149.      *
  150.      * @return the number of successes.
  151.      */
  152.     public int getNumberOfSuccesses() {
  153.         return numberOfSuccesses;
  154.     }

  155.     /**
  156.      * Access the population size.
  157.      *
  158.      * @return the population size.
  159.      */
  160.     public int getPopulationSize() {
  161.         return populationSize;
  162.     }

  163.     /**
  164.      * Access the sample size.
  165.      *
  166.      * @return the sample size.
  167.      */
  168.     public int getSampleSize() {
  169.         return sampleSize;
  170.     }

  171.     /**
  172.      * Return the highest domain value for the given hypergeometric distribution
  173.      * parameters.
  174.      *
  175.      * @param m Number of successes in the population.
  176.      * @param k Sample size.
  177.      * @return the highest domain value of the hypergeometric distribution.
  178.      */
  179.     private int getUpperDomain(int m, int k) {
  180.         return FastMath.min(k, m);
  181.     }

  182.     /** {@inheritDoc} */
  183.     public double probability(int x) {
  184.         final double logProbability = logProbability(x);
  185.         return logProbability == Double.NEGATIVE_INFINITY ? 0 : FastMath.exp(logProbability);
  186.     }

  187.     /** {@inheritDoc} */
  188.     @Override
  189.     public double logProbability(int x) {
  190.         double ret;

  191.         int[] domain = getDomain(populationSize, numberOfSuccesses, sampleSize);
  192.         if (x < domain[0] || x > domain[1]) {
  193.             ret = Double.NEGATIVE_INFINITY;
  194.         } else {
  195.             double p = (double) sampleSize / (double) populationSize;
  196.             double q = (double) (populationSize - sampleSize) / (double) populationSize;
  197.             double p1 = SaddlePointExpansion.logBinomialProbability(x,
  198.                     numberOfSuccesses, p, q);
  199.             double p2 =
  200.                     SaddlePointExpansion.logBinomialProbability(sampleSize - x,
  201.                             populationSize - numberOfSuccesses, p, q);
  202.             double p3 =
  203.                     SaddlePointExpansion.logBinomialProbability(sampleSize, populationSize, p, q);
  204.             ret = p1 + p2 - p3;
  205.         }

  206.         return ret;
  207.     }

  208.     /**
  209.      * For this distribution, {@code X}, this method returns {@code P(X >= x)}.
  210.      *
  211.      * @param x Value at which the CDF is evaluated.
  212.      * @return the upper tail CDF for this distribution.
  213.      * @since 1.1
  214.      */
  215.     public double upperCumulativeProbability(int x) {
  216.         double ret;

  217.         final int[] domain = getDomain(populationSize, numberOfSuccesses, sampleSize);
  218.         if (x <= domain[0]) {
  219.             ret = 1.0;
  220.         } else if (x > domain[1]) {
  221.             ret = 0.0;
  222.         } else {
  223.             ret = innerCumulativeProbability(domain[1], x, -1);
  224.         }

  225.         return ret;
  226.     }

  227.     /**
  228.      * For this distribution, {@code X}, this method returns
  229.      * {@code P(x0 <= X <= x1)}.
  230.      * This probability is computed by summing the point probabilities for the
  231.      * values {@code x0, x0 + 1, x0 + 2, ..., x1}, in the order directed by
  232.      * {@code dx}.
  233.      *
  234.      * @param x0 Inclusive lower bound.
  235.      * @param x1 Inclusive upper bound.
  236.      * @param dx Direction of summation (1 indicates summing from x0 to x1, and
  237.      * 0 indicates summing from x1 to x0).
  238.      * @return {@code P(x0 <= X <= x1)}.
  239.      */
  240.     private double innerCumulativeProbability(int x0, int x1, int dx) {
  241.         double ret = probability(x0);
  242.         while (x0 != x1) {
  243.             x0 += dx;
  244.             ret += probability(x0);
  245.         }
  246.         return ret;
  247.     }

  248.     /**
  249.      * {@inheritDoc}
  250.      *
  251.      * For population size {@code N}, number of successes {@code m}, and sample
  252.      * size {@code n}, the mean is {@code n * m / N}.
  253.      */
  254.     public double getNumericalMean() {
  255.         return getSampleSize() * (getNumberOfSuccesses() / (double) getPopulationSize());
  256.     }

  257.     /**
  258.      * {@inheritDoc}
  259.      *
  260.      * For population size {@code N}, number of successes {@code m}, and sample
  261.      * size {@code n}, the variance is
  262.      * {@code [n * m * (N - n) * (N - m)] / [N^2 * (N - 1)]}.
  263.      */
  264.     public double getNumericalVariance() {
  265.         if (!numericalVarianceIsCalculated) {
  266.             numericalVariance = calculateNumericalVariance();
  267.             numericalVarianceIsCalculated = true;
  268.         }
  269.         return numericalVariance;
  270.     }

  271.     /**
  272.      * Used by {@link #getNumericalVariance()}.
  273.      *
  274.      * @return the variance of this distribution
  275.      */
  276.     protected double calculateNumericalVariance() {
  277.         final double N = getPopulationSize();
  278.         final double m = getNumberOfSuccesses();
  279.         final double n = getSampleSize();
  280.         return (n * m * (N - n) * (N - m)) / (N * N * (N - 1));
  281.     }

  282.     /**
  283.      * {@inheritDoc}
  284.      *
  285.      * For population size {@code N}, number of successes {@code m}, and sample
  286.      * size {@code n}, the lower bound of the support is
  287.      * {@code max(0, n + m - N)}.
  288.      *
  289.      * @return lower bound of the support
  290.      */
  291.     public int getSupportLowerBound() {
  292.         return FastMath.max(0,
  293.                             getSampleSize() + getNumberOfSuccesses() - getPopulationSize());
  294.     }

  295.     /**
  296.      * {@inheritDoc}
  297.      *
  298.      * For number of successes {@code m} and sample size {@code n}, the upper
  299.      * bound of the support is {@code min(m, n)}.
  300.      *
  301.      * @return upper bound of the support
  302.      */
  303.     public int getSupportUpperBound() {
  304.         return FastMath.min(getNumberOfSuccesses(), getSampleSize());
  305.     }

  306.     /**
  307.      * {@inheritDoc}
  308.      *
  309.      * The support of this distribution is connected.
  310.      *
  311.      * @return {@code true}
  312.      */
  313.     public boolean isSupportConnected() {
  314.         return true;
  315.     }
  316. }