EnumeratedIntegerDistribution.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.math3.distribution;

  import java.util.ArrayList;
  import java.util.Arrays;
  import java.util.List;

  import org.apache.commons.math3.exception.DimensionMismatchException;
  import org.apache.commons.math3.exception.MathArithmeticException;
  import org.apache.commons.math3.exception.NotANumberException;
  import org.apache.commons.math3.exception.NotFiniteNumberException;
  import org.apache.commons.math3.exception.NotPositiveException;
  import org.apache.commons.math3.random.RandomGenerator;
  import org.apache.commons.math3.random.Well19937c;
  import org.apache.commons.math3.util.Pair;

  28. /**
  29.  * <p>Implementation of an integer-valued {@link EnumeratedDistribution}.</p>
  30.  *
  31.  * <p>Values with zero-probability are allowed but they do not extend the
  32.  * support.<br/>
  33.  * Duplicate values are allowed. Probabilities of duplicate values are combined
  34.  * when computing cumulative probabilities and statistics.</p>
  35.  *
  36.  * @since 3.2
  37.  */
  38. public class EnumeratedIntegerDistribution extends AbstractIntegerDistribution {

  39.     /** Serializable UID. */
  40.     private static final long serialVersionUID = 20130308L;

  41.     /**
  42.      * {@link EnumeratedDistribution} instance (using the {@link Integer} wrapper)
  43.      * used to generate the pmf.
  44.      */
  45.     protected final EnumeratedDistribution<Integer> innerDistribution;

  46.     /**
  47.      * Create a discrete distribution using the given probability mass function
  48.      * definition.
  49.      * <p>
  50.      * <b>Note:</b> this constructor will implicitly create an instance of
  51.      * {@link Well19937c} as random generator to be used for sampling only (see
  52.      * {@link #sample()} and {@link #sample(int)}). In case no sampling is
  53.      * needed for the created distribution, it is advised to pass {@code null}
  54.      * as random generator via the appropriate constructors to avoid the
  55.      * additional initialisation overhead.
  56.      *
  57.      * @param singletons array of random variable values.
  58.      * @param probabilities array of probabilities.
  59.      * @throws DimensionMismatchException if
  60.      * {@code singletons.length != probabilities.length}
  61.      * @throws NotPositiveException if any of the probabilities are negative.
  62.      * @throws NotFiniteNumberException if any of the probabilities are infinite.
  63.      * @throws NotANumberException if any of the probabilities are NaN.
  64.      * @throws MathArithmeticException all of the probabilities are 0.
  65.      */
  66.     public EnumeratedIntegerDistribution(final int[] singletons, final double[] probabilities)
  67.     throws DimensionMismatchException, NotPositiveException, MathArithmeticException,
  68.            NotFiniteNumberException, NotANumberException{
  69.         this(new Well19937c(), singletons, probabilities);
  70.     }

  71.     /**
  72.      * Create a discrete distribution using the given random number generator
  73.      * and probability mass function definition.
  74.      *
  75.      * @param rng random number generator.
  76.      * @param singletons array of random variable values.
  77.      * @param probabilities array of probabilities.
  78.      * @throws DimensionMismatchException if
  79.      * {@code singletons.length != probabilities.length}
  80.      * @throws NotPositiveException if any of the probabilities are negative.
  81.      * @throws NotFiniteNumberException if any of the probabilities are infinite.
  82.      * @throws NotANumberException if any of the probabilities are NaN.
  83.      * @throws MathArithmeticException all of the probabilities are 0.
  84.      */
  85.     public EnumeratedIntegerDistribution(final RandomGenerator rng,
  86.                                        final int[] singletons, final double[] probabilities)
  87.         throws DimensionMismatchException, NotPositiveException, MathArithmeticException,
  88.                 NotFiniteNumberException, NotANumberException {
  89.         super(rng);
  90.         if (singletons.length != probabilities.length) {
  91.             throw new DimensionMismatchException(probabilities.length, singletons.length);
  92.         }

  93.         final List<Pair<Integer, Double>> samples = new ArrayList<Pair<Integer, Double>>(singletons.length);

  94.         for (int i = 0; i < singletons.length; i++) {
  95.             samples.add(new Pair<Integer, Double>(singletons[i], probabilities[i]));
  96.         }

  97.         innerDistribution = new EnumeratedDistribution<Integer>(rng, samples);
  98.     }

  99.     /**
  100.      * {@inheritDoc}
  101.      */
  102.     public double probability(final int x) {
  103.         return innerDistribution.probability(x);
  104.     }

  105.     /**
  106.      * {@inheritDoc}
  107.      */
  108.     public double cumulativeProbability(final int x) {
  109.         double probability = 0;

  110.         for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) {
  111.             if (sample.getKey() <= x) {
  112.                 probability += sample.getValue();
  113.             }
  114.         }

  115.         return probability;
  116.     }

  117.     /**
  118.      * {@inheritDoc}
  119.      *
  120.      * @return {@code sum(singletons[i] * probabilities[i])}
  121.      */
  122.     public double getNumericalMean() {
  123.         double mean = 0;

  124.         for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) {
  125.             mean += sample.getValue() * sample.getKey();
  126.         }

  127.         return mean;
  128.     }

  129.     /**
  130.      * {@inheritDoc}
  131.      *
  132.      * @return {@code sum((singletons[i] - mean) ^ 2 * probabilities[i])}
  133.      */
  134.     public double getNumericalVariance() {
  135.         double mean = 0;
  136.         double meanOfSquares = 0;

  137.         for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) {
  138.             mean += sample.getValue() * sample.getKey();
  139.             meanOfSquares += sample.getValue() * sample.getKey() * sample.getKey();
  140.         }

  141.         return meanOfSquares - mean * mean;
  142.     }

  143.     /**
  144.      * {@inheritDoc}
  145.      *
  146.      * Returns the lowest value with non-zero probability.
  147.      *
  148.      * @return the lowest value with non-zero probability.
  149.      */
  150.     public int getSupportLowerBound() {
  151.         int min = Integer.MAX_VALUE;
  152.         for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) {
  153.             if (sample.getKey() < min && sample.getValue() > 0) {
  154.                 min = sample.getKey();
  155.             }
  156.         }

  157.         return min;
  158.     }

  159.     /**
  160.      * {@inheritDoc}
  161.      *
  162.      * Returns the highest value with non-zero probability.
  163.      *
  164.      * @return the highest value with non-zero probability.
  165.      */
  166.     public int getSupportUpperBound() {
  167.         int max = Integer.MIN_VALUE;
  168.         for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) {
  169.             if (sample.getKey() > max && sample.getValue() > 0) {
  170.                 max = sample.getKey();
  171.             }
  172.         }

  173.         return max;
  174.     }

  175.     /**
  176.      * {@inheritDoc}
  177.      *
  178.      * The support of this distribution is connected.
  179.      *
  180.      * @return {@code true}
  181.      */
  182.     public boolean isSupportConnected() {
  183.         return true;
  184.     }

  185.     /**
  186.      * {@inheritDoc}
  187.      */
  188.     @Override
  189.     public int sample() {
  190.         return innerDistribution.sample();
  191.     }
  192. }