Gaussian.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math3.analysis.function;

  18. import java.util.Arrays;

  19. import org.apache.commons.math3.analysis.FunctionUtils;
  20. import org.apache.commons.math3.analysis.UnivariateFunction;
  21. import org.apache.commons.math3.analysis.DifferentiableUnivariateFunction;
  22. import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
  23. import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
  24. import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction;
  25. import org.apache.commons.math3.exception.NotStrictlyPositiveException;
  26. import org.apache.commons.math3.exception.NullArgumentException;
  27. import org.apache.commons.math3.exception.DimensionMismatchException;
  28. import org.apache.commons.math3.util.FastMath;
  29. import org.apache.commons.math3.util.Precision;

  30. /**
  31.  * <a href="http://en.wikipedia.org/wiki/Gaussian_function">
  32.  *  Gaussian</a> function.
  33.  *
  34.  * @since 3.0
  35.  */
  36. public class Gaussian implements UnivariateDifferentiableFunction, DifferentiableUnivariateFunction {
  37.     /** Mean. */
  38.     private final double mean;
  39.     /** Inverse of the standard deviation. */
  40.     private final double is;
  41.     /** Inverse of twice the square of the standard deviation. */
  42.     private final double i2s2;
  43.     /** Normalization factor. */
  44.     private final double norm;

  45.     /**
  46.      * Gaussian with given normalization factor, mean and standard deviation.
  47.      *
  48.      * @param norm Normalization factor.
  49.      * @param mean Mean.
  50.      * @param sigma Standard deviation.
  51.      * @throws NotStrictlyPositiveException if {@code sigma <= 0}.
  52.      */
  53.     public Gaussian(double norm,
  54.                     double mean,
  55.                     double sigma)
  56.         throws NotStrictlyPositiveException {
  57.         if (sigma <= 0) {
  58.             throw new NotStrictlyPositiveException(sigma);
  59.         }

  60.         this.norm = norm;
  61.         this.mean = mean;
  62.         this.is   = 1 / sigma;
  63.         this.i2s2 = 0.5 * is * is;
  64.     }

  65.     /**
  66.      * Normalized gaussian with given mean and standard deviation.
  67.      *
  68.      * @param mean Mean.
  69.      * @param sigma Standard deviation.
  70.      * @throws NotStrictlyPositiveException if {@code sigma <= 0}.
  71.      */
  72.     public Gaussian(double mean,
  73.                     double sigma)
  74.         throws NotStrictlyPositiveException {
  75.         this(1 / (sigma * FastMath.sqrt(2 * Math.PI)), mean, sigma);
  76.     }

  77.     /**
  78.      * Normalized gaussian with zero mean and unit standard deviation.
  79.      */
  80.     public Gaussian() {
  81.         this(0, 1);
  82.     }

  83.     /** {@inheritDoc} */
  84.     public double value(double x) {
  85.         return value(x - mean, norm, i2s2);
  86.     }

  87.     /** {@inheritDoc}
  88.      * @deprecated as of 3.1, replaced by {@link #value(DerivativeStructure)}
  89.      */
  90.     @Deprecated
  91.     public UnivariateFunction derivative() {
  92.         return FunctionUtils.toDifferentiableUnivariateFunction(this).derivative();
  93.     }

  94.     /**
  95.      * Parametric function where the input array contains the parameters of
  96.      * the Gaussian, ordered as follows:
  97.      * <ul>
  98.      *  <li>Norm</li>
  99.      *  <li>Mean</li>
  100.      *  <li>Standard deviation</li>
  101.      * </ul>
  102.      */
  103.     public static class Parametric implements ParametricUnivariateFunction {
  104.         /**
  105.          * Computes the value of the Gaussian at {@code x}.
  106.          *
  107.          * @param x Value for which the function must be computed.
  108.          * @param param Values of norm, mean and standard deviation.
  109.          * @return the value of the function.
  110.          * @throws NullArgumentException if {@code param} is {@code null}.
  111.          * @throws DimensionMismatchException if the size of {@code param} is
  112.          * not 3.
  113.          * @throws NotStrictlyPositiveException if {@code param[2]} is negative.
  114.          */
  115.         public double value(double x, double ... param)
  116.             throws NullArgumentException,
  117.                    DimensionMismatchException,
  118.                    NotStrictlyPositiveException {
  119.             validateParameters(param);

  120.             final double diff = x - param[1];
  121.             final double i2s2 = 1 / (2 * param[2] * param[2]);
  122.             return Gaussian.value(diff, param[0], i2s2);
  123.         }

  124.         /**
  125.          * Computes the value of the gradient at {@code x}.
  126.          * The components of the gradient vector are the partial
  127.          * derivatives of the function with respect to each of the
  128.          * <em>parameters</em> (norm, mean and standard deviation).
  129.          *
  130.          * @param x Value at which the gradient must be computed.
  131.          * @param param Values of norm, mean and standard deviation.
  132.          * @return the gradient vector at {@code x}.
  133.          * @throws NullArgumentException if {@code param} is {@code null}.
  134.          * @throws DimensionMismatchException if the size of {@code param} is
  135.          * not 3.
  136.          * @throws NotStrictlyPositiveException if {@code param[2]} is negative.
  137.          */
  138.         public double[] gradient(double x, double ... param)
  139.             throws NullArgumentException,
  140.                    DimensionMismatchException,
  141.                    NotStrictlyPositiveException {
  142.             validateParameters(param);

  143.             final double norm = param[0];
  144.             final double diff = x - param[1];
  145.             final double sigma = param[2];
  146.             final double i2s2 = 1 / (2 * sigma * sigma);

  147.             final double n = Gaussian.value(diff, 1, i2s2);
  148.             final double m = norm * n * 2 * i2s2 * diff;
  149.             final double s = m * diff / sigma;

  150.             return new double[] { n, m, s };
  151.         }

  152.         /**
  153.          * Validates parameters to ensure they are appropriate for the evaluation of
  154.          * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
  155.          * methods.
  156.          *
  157.          * @param param Values of norm, mean and standard deviation.
  158.          * @throws NullArgumentException if {@code param} is {@code null}.
  159.          * @throws DimensionMismatchException if the size of {@code param} is
  160.          * not 3.
  161.          * @throws NotStrictlyPositiveException if {@code param[2]} is negative.
  162.          */
  163.         private void validateParameters(double[] param)
  164.             throws NullArgumentException,
  165.                    DimensionMismatchException,
  166.                    NotStrictlyPositiveException {
  167.             if (param == null) {
  168.                 throw new NullArgumentException();
  169.             }
  170.             if (param.length != 3) {
  171.                 throw new DimensionMismatchException(param.length, 3);
  172.             }
  173.             if (param[2] <= 0) {
  174.                 throw new NotStrictlyPositiveException(param[2]);
  175.             }
  176.         }
  177.     }

  178.     /**
  179.      * @param xMinusMean {@code x - mean}.
  180.      * @param norm Normalization factor.
  181.      * @param i2s2 Inverse of twice the square of the standard deviation.
  182.      * @return the value of the Gaussian at {@code x}.
  183.      */
  184.     private static double value(double xMinusMean,
  185.                                 double norm,
  186.                                 double i2s2) {
  187.         return norm * FastMath.exp(-xMinusMean * xMinusMean * i2s2);
  188.     }

  189.     /** {@inheritDoc}
  190.      * @since 3.1
  191.      */
  192.     public DerivativeStructure value(final DerivativeStructure t)
  193.         throws DimensionMismatchException {

  194.         final double u = is * (t.getValue() - mean);
  195.         double[] f = new double[t.getOrder() + 1];

  196.         // the nth order derivative of the Gaussian has the form:
  197.         // dn(g(x)/dxn = (norm / s^n) P_n(u) exp(-u^2/2) with u=(x-m)/s
  198.         // where P_n(u) is a degree n polynomial with same parity as n
  199.         // P_0(u) = 1, P_1(u) = -u, P_2(u) = u^2 - 1, P_3(u) = -u^3 + 3 u...
  200.         // the general recurrence relation for P_n is:
  201.         // P_n(u) = P_(n-1)'(u) - u P_(n-1)(u)
  202.         // as per polynomial parity, we can store coefficients of both P_(n-1) and P_n in the same array
  203.         final double[] p = new double[f.length];
  204.         p[0] = 1;
  205.         final double u2 = u * u;
  206.         double coeff = norm * FastMath.exp(-0.5 * u2);
  207.         if (coeff <= Precision.SAFE_MIN) {
  208.             Arrays.fill(f, 0.0);
  209.         } else {
  210.             f[0] = coeff;
  211.             for (int n = 1; n < f.length; ++n) {

  212.                 // update and evaluate polynomial P_n(x)
  213.                 double v = 0;
  214.                 p[n] = -p[n - 1];
  215.                 for (int k = n; k >= 0; k -= 2) {
  216.                     v = v * u2 + p[k];
  217.                     if (k > 2) {
  218.                         p[k - 2] = (k - 1) * p[k - 1] - p[k - 3];
  219.                     } else if (k == 2) {
  220.                         p[0] = p[1];
  221.                     }
  222.                 }
  223.                 if ((n & 0x1) == 1) {
  224.                     v *= u;
  225.                 }

  226.                 coeff *= is;
  227.                 f[n] = coeff * v;

  228.             }
  229.         }

  230.         return t.compose(f);

  231.     }

  232. }