AbstractIntegerDistribution.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.math3.distribution;

  18. import java.io.Serializable;

  19. import org.apache.commons.math3.exception.MathInternalError;
  20. import org.apache.commons.math3.exception.NotStrictlyPositiveException;
  21. import org.apache.commons.math3.exception.NumberIsTooLargeException;
  22. import org.apache.commons.math3.exception.OutOfRangeException;
  23. import org.apache.commons.math3.exception.util.LocalizedFormats;
  24. import org.apache.commons.math3.random.RandomGenerator;
  25. import org.apache.commons.math3.random.RandomDataImpl;
  26. import org.apache.commons.math3.util.FastMath;

  27. /**
  28.  * Base class for integer-valued discrete distributions.  Default
  29.  * implementations are provided for some of the methods that do not vary
  30.  * from distribution to distribution.
  31.  *
  32.  */
  33. public abstract class AbstractIntegerDistribution implements IntegerDistribution, Serializable {

  34.     /** Serializable version identifier */
  35.     private static final long serialVersionUID = -1146319659338487221L;

  36.     /**
  37.      * RandomData instance used to generate samples from the distribution.
  38.      * @deprecated As of 3.1, to be removed in 4.0. Please use the
  39.      * {@link #random} instance variable instead.
  40.      */
  41.     @Deprecated
  42.     protected final RandomDataImpl randomData = new RandomDataImpl();

  43.     /**
  44.      * RNG instance used to generate samples from the distribution.
  45.      * @since 3.1
  46.      */
  47.     protected final RandomGenerator random;

  48.     /**
  49.      * @deprecated As of 3.1, to be removed in 4.0. Please use
  50.      * {@link #AbstractIntegerDistribution(RandomGenerator)} instead.
  51.      */
  52.     @Deprecated
  53.     protected AbstractIntegerDistribution() {
  54.         // Legacy users are only allowed to access the deprecated "randomData".
  55.         // New users are forbidden to use this constructor.
  56.         random = null;
  57.     }

  58.     /**
  59.      * @param rng Random number generator.
  60.      * @since 3.1
  61.      */
  62.     protected AbstractIntegerDistribution(RandomGenerator rng) {
  63.         random = rng;
  64.     }

  65.     /**
  66.      * {@inheritDoc}
  67.      *
  68.      * The default implementation uses the identity
  69.      * <p>{@code P(x0 < X <= x1) = P(X <= x1) - P(X <= x0)}</p>
  70.      */
  71.     public double cumulativeProbability(int x0, int x1) throws NumberIsTooLargeException {
  72.         if (x1 < x0) {
  73.             throw new NumberIsTooLargeException(LocalizedFormats.LOWER_ENDPOINT_ABOVE_UPPER_ENDPOINT,
  74.                     x0, x1, true);
  75.         }
  76.         return cumulativeProbability(x1) - cumulativeProbability(x0);
  77.     }

  78.     /**
  79.      * {@inheritDoc}
  80.      *
  81.      * The default implementation returns
  82.      * <ul>
  83.      * <li>{@link #getSupportLowerBound()} for {@code p = 0},</li>
  84.      * <li>{@link #getSupportUpperBound()} for {@code p = 1}, and</li>
  85.      * <li>{@link #solveInverseCumulativeProbability(double, int, int)} for
  86.      *     {@code 0 < p < 1}.</li>
  87.      * </ul>
  88.      */
  89.     public int inverseCumulativeProbability(final double p) throws OutOfRangeException {
  90.         if (p < 0.0 || p > 1.0) {
  91.             throw new OutOfRangeException(p, 0, 1);
  92.         }

  93.         int lower = getSupportLowerBound();
  94.         if (p == 0.0) {
  95.             return lower;
  96.         }
  97.         if (lower == Integer.MIN_VALUE) {
  98.             if (checkedCumulativeProbability(lower) >= p) {
  99.                 return lower;
  100.             }
  101.         } else {
  102.             lower -= 1; // this ensures cumulativeProbability(lower) < p, which
  103.                         // is important for the solving step
  104.         }

  105.         int upper = getSupportUpperBound();
  106.         if (p == 1.0) {
  107.             return upper;
  108.         }

  109.         // use the one-sided Chebyshev inequality to narrow the bracket
  110.         // cf. AbstractRealDistribution.inverseCumulativeProbability(double)
  111.         final double mu = getNumericalMean();
  112.         final double sigma = FastMath.sqrt(getNumericalVariance());
  113.         final boolean chebyshevApplies = !(Double.isInfinite(mu) || Double.isNaN(mu) ||
  114.                 Double.isInfinite(sigma) || Double.isNaN(sigma) || sigma == 0.0);
  115.         if (chebyshevApplies) {
  116.             double k = FastMath.sqrt((1.0 - p) / p);
  117.             double tmp = mu - k * sigma;
  118.             if (tmp > lower) {
  119.                 lower = ((int) FastMath.ceil(tmp)) - 1;
  120.             }
  121.             k = 1.0 / k;
  122.             tmp = mu + k * sigma;
  123.             if (tmp < upper) {
  124.                 upper = ((int) FastMath.ceil(tmp)) - 1;
  125.             }
  126.         }

  127.         return solveInverseCumulativeProbability(p, lower, upper);
  128.     }

  129.     /**
  130.      * This is a utility function used by {@link
  131.      * #inverseCumulativeProbability(double)}. It assumes {@code 0 < p < 1} and
  132.      * that the inverse cumulative probability lies in the bracket {@code
  133.      * (lower, upper]}. The implementation does simple bisection to find the
  134.      * smallest {@code p}-quantile <code>inf{x in Z | P(X<=x) >= p}</code>.
  135.      *
  136.      * @param p the cumulative probability
  137.      * @param lower a value satisfying {@code cumulativeProbability(lower) < p}
  138.      * @param upper a value satisfying {@code p <= cumulativeProbability(upper)}
  139.      * @return the smallest {@code p}-quantile of this distribution
  140.      */
  141.     protected int solveInverseCumulativeProbability(final double p, int lower, int upper) {
  142.         while (lower + 1 < upper) {
  143.             int xm = (lower + upper) / 2;
  144.             if (xm < lower || xm > upper) {
  145.                 /*
  146.                  * Overflow.
  147.                  * There will never be an overflow in both calculation methods
  148.                  * for xm at the same time
  149.                  */
  150.                 xm = lower + (upper - lower) / 2;
  151.             }

  152.             double pm = checkedCumulativeProbability(xm);
  153.             if (pm >= p) {
  154.                 upper = xm;
  155.             } else {
  156.                 lower = xm;
  157.             }
  158.         }
  159.         return upper;
  160.     }

  161.     /** {@inheritDoc} */
  162.     public void reseedRandomGenerator(long seed) {
  163.         random.setSeed(seed);
  164.         randomData.reSeed(seed);
  165.     }

  166.     /**
  167.      * {@inheritDoc}
  168.      *
  169.      * The default implementation uses the
  170.      * <a href="http://en.wikipedia.org/wiki/Inverse_transform_sampling">
  171.      * inversion method</a>.
  172.      */
  173.     public int sample() {
  174.         return inverseCumulativeProbability(random.nextDouble());
  175.     }

  176.     /**
  177.      * {@inheritDoc}
  178.      *
  179.      * The default implementation generates the sample by calling
  180.      * {@link #sample()} in a loop.
  181.      */
  182.     public int[] sample(int sampleSize) {
  183.         if (sampleSize <= 0) {
  184.             throw new NotStrictlyPositiveException(
  185.                     LocalizedFormats.NUMBER_OF_SAMPLES, sampleSize);
  186.         }
  187.         int[] out = new int[sampleSize];
  188.         for (int i = 0; i < sampleSize; i++) {
  189.             out[i] = sample();
  190.         }
  191.         return out;
  192.     }

  193.     /**
  194.      * Computes the cumulative probability function and checks for {@code NaN}
  195.      * values returned. Throws {@code MathInternalError} if the value is
  196.      * {@code NaN}. Rethrows any exception encountered evaluating the cumulative
  197.      * probability function. Throws {@code MathInternalError} if the cumulative
  198.      * probability function returns {@code NaN}.
  199.      *
  200.      * @param argument input value
  201.      * @return the cumulative probability
  202.      * @throws MathInternalError if the cumulative probability is {@code NaN}
  203.      */
  204.     private double checkedCumulativeProbability(int argument)
  205.         throws MathInternalError {
  206.         double result = Double.NaN;
  207.         result = cumulativeProbability(argument);
  208.         if (Double.isNaN(result)) {
  209.             throw new MathInternalError(LocalizedFormats
  210.                     .DISCRETE_CUMULATIVE_PROBABILITY_RETURNED_NAN, argument);
  211.         }
  212.         return result;
  213.     }

  214.     /**
  215.      * For a random variable {@code X} whose values are distributed according to
  216.      * this distribution, this method returns {@code log(P(X = x))}, where
  217.      * {@code log} is the natural logarithm. In other words, this method
  218.      * represents the logarithm of the probability mass function (PMF) for the
  219.      * distribution. Note that due to the floating point precision and
  220.      * under/overflow issues, this method will for some distributions be more
  221.      * precise and faster than computing the logarithm of
  222.      * {@link #probability(int)}.
  223.      * <p>
  224.      * The default implementation simply computes the logarithm of {@code probability(x)}.</p>
  225.      *
  226.      * @param x the point at which the PMF is evaluated
  227.      * @return the logarithm of the value of the probability mass function at {@code x}
  228.      */
  229.     public double logProbability(int x) {
  230.         return FastMath.log(probability(x));
  231.     }
  232. }