MixtureMultivariateRealDistribution.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.math3.distribution;

  18. import java.util.ArrayList;
  19. import java.util.List;

  20. import org.apache.commons.math3.exception.DimensionMismatchException;
  21. import org.apache.commons.math3.exception.MathArithmeticException;
  22. import org.apache.commons.math3.exception.NotPositiveException;
  23. import org.apache.commons.math3.exception.util.LocalizedFormats;
  24. import org.apache.commons.math3.random.RandomGenerator;
  25. import org.apache.commons.math3.random.Well19937c;
  26. import org.apache.commons.math3.util.Pair;

  27. /**
  28.  * Class for representing <a href="http://en.wikipedia.org/wiki/Mixture_model">
  29.  * mixture model</a> distributions.
  30.  *
  31.  * @param <T> Type of the mixture components.
  32.  *
  33.  * @since 3.1
  34.  */
  35. public class MixtureMultivariateRealDistribution<T extends MultivariateRealDistribution>
  36.     extends AbstractMultivariateRealDistribution {
  37.     /** Normalized weight of each mixture component. */
  38.     private final double[] weight;
  39.     /** Mixture components. */
  40.     private final List<T> distribution;

  41.     /**
  42.      * Creates a mixture model from a list of distributions and their
  43.      * associated weights.
  44.      * <p>
  45.      * <b>Note:</b> this constructor will implicitly create an instance of
  46.      * {@link Well19937c} as random generator to be used for sampling only (see
  47.      * {@link #sample()} and {@link #sample(int)}). In case no sampling is
  48.      * needed for the created distribution, it is advised to pass {@code null}
  49.      * as random generator via the appropriate constructors to avoid the
  50.      * additional initialisation overhead.
  51.      *
  52.      * @param components List of (weight, distribution) pairs from which to sample.
  53.      */
  54.     public MixtureMultivariateRealDistribution(List<Pair<Double, T>> components) {
  55.         this(new Well19937c(), components);
  56.     }

  57.     /**
  58.      * Creates a mixture model from a list of distributions and their
  59.      * associated weights.
  60.      *
  61.      * @param rng Random number generator.
  62.      * @param components Distributions from which to sample.
  63.      * @throws NotPositiveException if any of the weights is negative.
  64.      * @throws DimensionMismatchException if not all components have the same
  65.      * number of variables.
  66.      */
  67.     public MixtureMultivariateRealDistribution(RandomGenerator rng,
  68.                                                List<Pair<Double, T>> components) {
  69.         super(rng, components.get(0).getSecond().getDimension());

  70.         final int numComp = components.size();
  71.         final int dim = getDimension();
  72.         double weightSum = 0;
  73.         for (int i = 0; i < numComp; i++) {
  74.             final Pair<Double, T> comp = components.get(i);
  75.             if (comp.getSecond().getDimension() != dim) {
  76.                 throw new DimensionMismatchException(comp.getSecond().getDimension(), dim);
  77.             }
  78.             if (comp.getFirst() < 0) {
  79.                 throw new NotPositiveException(comp.getFirst());
  80.             }
  81.             weightSum += comp.getFirst();
  82.         }

  83.         // Check for overflow.
  84.         if (Double.isInfinite(weightSum)) {
  85.             throw new MathArithmeticException(LocalizedFormats.OVERFLOW);
  86.         }

  87.         // Store each distribution and its normalized weight.
  88.         distribution = new ArrayList<T>();
  89.         weight = new double[numComp];
  90.         for (int i = 0; i < numComp; i++) {
  91.             final Pair<Double, T> comp = components.get(i);
  92.             weight[i] = comp.getFirst() / weightSum;
  93.             distribution.add(comp.getSecond());
  94.         }
  95.     }

  96.     /** {@inheritDoc} */
  97.     public double density(final double[] values) {
  98.         double p = 0;
  99.         for (int i = 0; i < weight.length; i++) {
  100.             p += weight[i] * distribution.get(i).density(values);
  101.         }
  102.         return p;
  103.     }

  104.     /** {@inheritDoc} */
  105.     @Override
  106.     public double[] sample() {
  107.         // Sampled values.
  108.         double[] vals = null;

  109.         // Determine which component to sample from.
  110.         final double randomValue = random.nextDouble();
  111.         double sum = 0;

  112.         for (int i = 0; i < weight.length; i++) {
  113.             sum += weight[i];
  114.             if (randomValue <= sum) {
  115.                 // pick model i
  116.                 vals = distribution.get(i).sample();
  117.                 break;
  118.             }
  119.         }

  120.         if (vals == null) {
  121.             // This should never happen, but it ensures we won't return a null in
  122.             // case the loop above has some floating point inequality problem on
  123.             // the final iteration.
  124.             vals = distribution.get(weight.length - 1).sample();
  125.         }

  126.         return vals;
  127.     }

  128.     /** {@inheritDoc} */
  129.     @Override
  130.     public void reseedRandomGenerator(long seed) {
  131.         // Seed needs to be propagated to underlying components
  132.         // in order to maintain consistency between runs.
  133.         super.reseedRandomGenerator(seed);

  134.         for (int i = 0; i < distribution.size(); i++) {
  135.             // Make each component's seed different in order to avoid
  136.             // using the same sequence of random numbers.
  137.             distribution.get(i).reseedRandomGenerator(i + 1 + seed);
  138.         }
  139.     }

  140.     /**
  141.      * Gets the distributions that make up the mixture model.
  142.      *
  143.      * @return the component distributions and associated weights.
  144.      */
  145.     public List<Pair<Double, T>> getComponents() {
  146.         final List<Pair<Double, T>> list = new ArrayList<Pair<Double, T>>(weight.length);

  147.         for (int i = 0; i < weight.length; i++) {
  148.             list.add(new Pair<Double, T>(weight[i], distribution.get(i)));
  149.         }

  150.         return list;
  151.     }
  152. }