Logit.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math3.analysis.function;

  18. import org.apache.commons.math3.analysis.DifferentiableUnivariateFunction;
  19. import org.apache.commons.math3.analysis.FunctionUtils;
  20. import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
  21. import org.apache.commons.math3.analysis.UnivariateFunction;
  22. import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
  23. import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction;
  24. import org.apache.commons.math3.exception.DimensionMismatchException;
  25. import org.apache.commons.math3.exception.NullArgumentException;
  26. import org.apache.commons.math3.exception.OutOfRangeException;
  27. import org.apache.commons.math3.util.FastMath;

  28. /**
  29.  * <a href="http://en.wikipedia.org/wiki/Logit">
  30.  *  Logit</a> function.
  31.  * It is the inverse of the {@link Sigmoid sigmoid} function.
  32.  *
  33.  * @since 3.0
  34.  */
  35. public class Logit implements UnivariateDifferentiableFunction, DifferentiableUnivariateFunction {
  36.     /** Lower bound. */
  37.     private final double lo;
  38.     /** Higher bound. */
  39.     private final double hi;

  40.     /**
  41.      * Usual logit function, where the lower bound is 0 and the higher
  42.      * bound is 1.
  43.      */
  44.     public Logit() {
  45.         this(0, 1);
  46.     }

  47.     /**
  48.      * Logit function.
  49.      *
  50.      * @param lo Lower bound of the function domain.
  51.      * @param hi Higher bound of the function domain.
  52.      */
  53.     public Logit(double lo,
  54.                  double hi) {
  55.         this.lo = lo;
  56.         this.hi = hi;
  57.     }

  58.     /** {@inheritDoc} */
  59.     public double value(double x)
  60.         throws OutOfRangeException {
  61.         return value(x, lo, hi);
  62.     }

  63.     /** {@inheritDoc}
  64.      * @deprecated as of 3.1, replaced by {@link #value(DerivativeStructure)}
  65.      */
  66.     @Deprecated
  67.     public UnivariateFunction derivative() {
  68.         return FunctionUtils.toDifferentiableUnivariateFunction(this).derivative();
  69.     }

  70.     /**
  71.      * Parametric function where the input array contains the parameters of
  72.      * the logit function, ordered as follows:
  73.      * <ul>
  74.      *  <li>Lower bound</li>
  75.      *  <li>Higher bound</li>
  76.      * </ul>
  77.      */
  78.     public static class Parametric implements ParametricUnivariateFunction {
  79.         /**
  80.          * Computes the value of the logit at {@code x}.
  81.          *
  82.          * @param x Value for which the function must be computed.
  83.          * @param param Values of lower bound and higher bounds.
  84.          * @return the value of the function.
  85.          * @throws NullArgumentException if {@code param} is {@code null}.
  86.          * @throws DimensionMismatchException if the size of {@code param} is
  87.          * not 2.
  88.          */
  89.         public double value(double x, double ... param)
  90.             throws NullArgumentException,
  91.                    DimensionMismatchException {
  92.             validateParameters(param);
  93.             return Logit.value(x, param[0], param[1]);
  94.         }

  95.         /**
  96.          * Computes the value of the gradient at {@code x}.
  97.          * The components of the gradient vector are the partial
  98.          * derivatives of the function with respect to each of the
  99.          * <em>parameters</em> (lower bound and higher bound).
  100.          *
  101.          * @param x Value at which the gradient must be computed.
  102.          * @param param Values for lower and higher bounds.
  103.          * @return the gradient vector at {@code x}.
  104.          * @throws NullArgumentException if {@code param} is {@code null}.
  105.          * @throws DimensionMismatchException if the size of {@code param} is
  106.          * not 2.
  107.          */
  108.         public double[] gradient(double x, double ... param)
  109.             throws NullArgumentException,
  110.                    DimensionMismatchException {
  111.             validateParameters(param);

  112.             final double lo = param[0];
  113.             final double hi = param[1];

  114.             return new double[] { 1 / (lo - x), 1 / (hi - x) };
  115.         }

  116.         /**
  117.          * Validates parameters to ensure they are appropriate for the evaluation of
  118.          * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
  119.          * methods.
  120.          *
  121.          * @param param Values for lower and higher bounds.
  122.          * @throws NullArgumentException if {@code param} is {@code null}.
  123.          * @throws DimensionMismatchException if the size of {@code param} is
  124.          * not 2.
  125.          */
  126.         private void validateParameters(double[] param)
  127.             throws NullArgumentException,
  128.                    DimensionMismatchException {
  129.             if (param == null) {
  130.                 throw new NullArgumentException();
  131.             }
  132.             if (param.length != 2) {
  133.                 throw new DimensionMismatchException(param.length, 2);
  134.             }
  135.         }
  136.     }

  137.     /**
  138.      * @param x Value at which to compute the logit.
  139.      * @param lo Lower bound.
  140.      * @param hi Higher bound.
  141.      * @return the value of the logit function at {@code x}.
  142.      * @throws OutOfRangeException if {@code x < lo} or {@code x > hi}.
  143.      */
  144.     private static double value(double x,
  145.                                 double lo,
  146.                                 double hi)
  147.         throws OutOfRangeException {
  148.         if (x < lo || x > hi) {
  149.             throw new OutOfRangeException(x, lo, hi);
  150.         }
  151.         return FastMath.log((x - lo) / (hi - x));
  152.     }

  153.     /** {@inheritDoc}
  154.      * @since 3.1
  155.      * @exception OutOfRangeException if parameter is outside of function domain
  156.      */
  157.     public DerivativeStructure value(final DerivativeStructure t)
  158.         throws OutOfRangeException {
  159.         final double x = t.getValue();
  160.         if (x < lo || x > hi) {
  161.             throw new OutOfRangeException(x, lo, hi);
  162.         }
  163.         double[] f = new double[t.getOrder() + 1];

  164.         // function value
  165.         f[0] = FastMath.log((x - lo) / (hi - x));

  166.         if (Double.isInfinite(f[0])) {

  167.             if (f.length > 1) {
  168.                 f[1] = Double.POSITIVE_INFINITY;
  169.             }
  170.             // fill the array with infinities
  171.             // (for x close to lo the signs will flip between -inf and +inf,
  172.             //  for x close to hi the signs will always be +inf)
  173.             // this is probably overkill, since the call to compose at the end
  174.             // of the method will transform most infinities into NaN ...
  175.             for (int i = 2; i < f.length; ++i) {
  176.                 f[i] = f[i - 2];
  177.             }

  178.         } else {

  179.             // function derivatives
  180.             final double invL = 1.0 / (x - lo);
  181.             double xL = invL;
  182.             final double invH = 1.0 / (hi - x);
  183.             double xH = invH;
  184.             for (int i = 1; i < f.length; ++i) {
  185.                 f[i] = xL + xH;
  186.                 xL  *= -i * invL;
  187.                 xH  *=  i * invH;
  188.             }
  189.         }

  190.         return t.compose(f);
  191.     }
  192. }