Logistic.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math3.analysis.function;

  18. import org.apache.commons.math3.analysis.FunctionUtils;
  19. import org.apache.commons.math3.analysis.UnivariateFunction;
  20. import org.apache.commons.math3.analysis.DifferentiableUnivariateFunction;
  21. import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
  22. import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
  23. import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction;
  24. import org.apache.commons.math3.exception.NotStrictlyPositiveException;
  25. import org.apache.commons.math3.exception.NullArgumentException;
  26. import org.apache.commons.math3.exception.DimensionMismatchException;
  27. import org.apache.commons.math3.util.FastMath;

  28. /**
  29.  * <a href="http://en.wikipedia.org/wiki/Generalised_logistic_function">
  30.  *  Generalised logistic</a> function.
  31.  *
  32.  * @since 3.0
  33.  */
  34. public class Logistic implements UnivariateDifferentiableFunction, DifferentiableUnivariateFunction {
  35.     /** Lower asymptote. */
  36.     private final double a;
  37.     /** Upper asymptote. */
  38.     private final double k;
  39.     /** Growth rate. */
  40.     private final double b;
  41.     /** Parameter that affects near which asymptote maximum growth occurs. */
  42.     private final double oneOverN;
  43.     /** Parameter that affects the position of the curve along the ordinate axis. */
  44.     private final double q;
  45.     /** Abscissa of maximum growth. */
  46.     private final double m;

  47.     /**
  48.      * @param k If {@code b > 0}, value of the function for x going towards +&infin;.
  49.      * If {@code b < 0}, value of the function for x going towards -&infin;.
  50.      * @param m Abscissa of maximum growth.
  51.      * @param b Growth rate.
  52.      * @param q Parameter that affects the position of the curve along the
  53.      * ordinate axis.
  54.      * @param a If {@code b > 0}, value of the function for x going towards -&infin;.
  55.      * If {@code b < 0}, value of the function for x going towards +&infin;.
  56.      * @param n Parameter that affects near which asymptote the maximum
  57.      * growth occurs.
  58.      * @throws NotStrictlyPositiveException if {@code n <= 0}.
  59.      */
  60.     public Logistic(double k,
  61.                     double m,
  62.                     double b,
  63.                     double q,
  64.                     double a,
  65.                     double n)
  66.         throws NotStrictlyPositiveException {
  67.         if (n <= 0) {
  68.             throw new NotStrictlyPositiveException(n);
  69.         }

  70.         this.k = k;
  71.         this.m = m;
  72.         this.b = b;
  73.         this.q = q;
  74.         this.a = a;
  75.         oneOverN = 1 / n;
  76.     }

  77.     /** {@inheritDoc} */
  78.     public double value(double x) {
  79.         return value(m - x, k, b, q, a, oneOverN);
  80.     }

  81.     /** {@inheritDoc}
  82.      * @deprecated as of 3.1, replaced by {@link #value(DerivativeStructure)}
  83.      */
  84.     @Deprecated
  85.     public UnivariateFunction derivative() {
  86.         return FunctionUtils.toDifferentiableUnivariateFunction(this).derivative();
  87.     }

  88.     /**
  89.      * Parametric function where the input array contains the parameters of
  90.      * the {@link Logistic#Logistic(double,double,double,double,double,double)
  91.      * logistic function}, ordered as follows:
  92.      * <ul>
  93.      *  <li>k</li>
  94.      *  <li>m</li>
  95.      *  <li>b</li>
  96.      *  <li>q</li>
  97.      *  <li>a</li>
  98.      *  <li>n</li>
  99.      * </ul>
  100.      */
  101.     public static class Parametric implements ParametricUnivariateFunction {
  102.         /**
  103.          * Computes the value of the sigmoid at {@code x}.
  104.          *
  105.          * @param x Value for which the function must be computed.
  106.          * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
  107.          * {@code a} and  {@code n}.
  108.          * @return the value of the function.
  109.          * @throws NullArgumentException if {@code param} is {@code null}.
  110.          * @throws DimensionMismatchException if the size of {@code param} is
  111.          * not 6.
  112.          * @throws NotStrictlyPositiveException if {@code param[5] <= 0}.
  113.          */
  114.         public double value(double x, double ... param)
  115.             throws NullArgumentException,
  116.                    DimensionMismatchException,
  117.                    NotStrictlyPositiveException {
  118.             validateParameters(param);
  119.             return Logistic.value(param[1] - x, param[0],
  120.                                   param[2], param[3],
  121.                                   param[4], 1 / param[5]);
  122.         }

  123.         /**
  124.          * Computes the value of the gradient at {@code x}.
  125.          * The components of the gradient vector are the partial
  126.          * derivatives of the function with respect to each of the
  127.          * <em>parameters</em>.
  128.          *
  129.          * @param x Value at which the gradient must be computed.
  130.          * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
  131.          * {@code a} and  {@code n}.
  132.          * @return the gradient vector at {@code x}.
  133.          * @throws NullArgumentException if {@code param} is {@code null}.
  134.          * @throws DimensionMismatchException if the size of {@code param} is
  135.          * not 6.
  136.          * @throws NotStrictlyPositiveException if {@code param[5] <= 0}.
  137.          */
  138.         public double[] gradient(double x, double ... param)
  139.             throws NullArgumentException,
  140.                    DimensionMismatchException,
  141.                    NotStrictlyPositiveException {
  142.             validateParameters(param);

  143.             final double b = param[2];
  144.             final double q = param[3];

  145.             final double mMinusX = param[1] - x;
  146.             final double oneOverN = 1 / param[5];
  147.             final double exp = FastMath.exp(b * mMinusX);
  148.             final double qExp = q * exp;
  149.             final double qExp1 = qExp + 1;
  150.             final double factor1 = (param[0] - param[4]) * oneOverN / FastMath.pow(qExp1, oneOverN);
  151.             final double factor2 = -factor1 / qExp1;

  152.             // Components of the gradient.
  153.             final double gk = Logistic.value(mMinusX, 1, b, q, 0, oneOverN);
  154.             final double gm = factor2 * b * qExp;
  155.             final double gb = factor2 * mMinusX * qExp;
  156.             final double gq = factor2 * exp;
  157.             final double ga = Logistic.value(mMinusX, 0, b, q, 1, oneOverN);
  158.             final double gn = factor1 * FastMath.log(qExp1) * oneOverN;

  159.             return new double[] { gk, gm, gb, gq, ga, gn };
  160.         }

  161.         /**
  162.          * Validates parameters to ensure they are appropriate for the evaluation of
  163.          * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
  164.          * methods.
  165.          *
  166.          * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
  167.          * {@code a} and {@code n}.
  168.          * @throws NullArgumentException if {@code param} is {@code null}.
  169.          * @throws DimensionMismatchException if the size of {@code param} is
  170.          * not 6.
  171.          * @throws NotStrictlyPositiveException if {@code param[5] <= 0}.
  172.          */
  173.         private void validateParameters(double[] param)
  174.             throws NullArgumentException,
  175.                    DimensionMismatchException,
  176.                    NotStrictlyPositiveException {
  177.             if (param == null) {
  178.                 throw new NullArgumentException();
  179.             }
  180.             if (param.length != 6) {
  181.                 throw new DimensionMismatchException(param.length, 6);
  182.             }
  183.             if (param[5] <= 0) {
  184.                 throw new NotStrictlyPositiveException(param[5]);
  185.             }
  186.         }
  187.     }

  188.     /**
  189.      * @param mMinusX {@code m - x}.
  190.      * @param k {@code k}.
  191.      * @param b {@code b}.
  192.      * @param q {@code q}.
  193.      * @param a {@code a}.
  194.      * @param oneOverN {@code 1 / n}.
  195.      * @return the value of the function.
  196.      */
  197.     private static double value(double mMinusX,
  198.                                 double k,
  199.                                 double b,
  200.                                 double q,
  201.                                 double a,
  202.                                 double oneOverN) {
  203.         return a + (k - a) / FastMath.pow(1 + q * FastMath.exp(b * mMinusX), oneOverN);
  204.     }

  205.     /** {@inheritDoc}
  206.      * @since 3.1
  207.      */
  208.     public DerivativeStructure value(final DerivativeStructure t) {
  209.         return t.negate().add(m).multiply(b).exp().multiply(q).add(1).pow(oneOverN).reciprocal().multiply(k - a).add(a);
  210.     }

  211. }