BrentOptimizer.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.math3.optim.univariate;

  18. import org.apache.commons.math3.util.Precision;
  19. import org.apache.commons.math3.util.FastMath;
  20. import org.apache.commons.math3.exception.NumberIsTooSmallException;
  21. import org.apache.commons.math3.exception.NotStrictlyPositiveException;
  22. import org.apache.commons.math3.optim.ConvergenceChecker;
  23. import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;

  24. /**
  25.  * For a function defined on some interval {@code (lo, hi)}, this class
  26.  * finds an approximation {@code x} to the point at which the function
  27.  * attains its minimum.
  28.  * It implements Richard Brent's algorithm (from his book "Algorithms for
  29.  * Minimization without Derivatives", p. 79) for finding minima of real
  30.  * univariate functions.
  31.  * <br/>
  32.  * This code is an adaptation, partly based on the Python code from SciPy
  33.  * (module "optimize.py" v0.5); the original algorithm is also modified
  34.  * <ul>
  35.  *  <li>to use an initial guess provided by the user,</li>
  36.  *  <li>to ensure that the best point encountered is the one returned.</li>
  37.  * </ul>
  38.  *
  39.  * @since 2.0
  40.  */
  41. public class BrentOptimizer extends UnivariateOptimizer {
  42.     /**
  43.      * Golden section.
  44.      */
  45.     private static final double GOLDEN_SECTION = 0.5 * (3 - FastMath.sqrt(5));
  46.     /**
  47.      * Minimum relative tolerance.
  48.      */
  49.     private static final double MIN_RELATIVE_TOLERANCE = 2 * FastMath.ulp(1d);
  50.     /**
  51.      * Relative threshold.
  52.      */
  53.     private final double relativeThreshold;
  54.     /**
  55.      * Absolute threshold.
  56.      */
  57.     private final double absoluteThreshold;

  58.     /**
  59.      * The arguments are used implement the original stopping criterion
  60.      * of Brent's algorithm.
  61.      * {@code abs} and {@code rel} define a tolerance
  62.      * {@code tol = rel |x| + abs}. {@code rel} should be no smaller than
  63.      * <em>2 macheps</em> and preferably not much less than <em>sqrt(macheps)</em>,
  64.      * where <em>macheps</em> is the relative machine precision. {@code abs} must
  65.      * be positive.
  66.      *
  67.      * @param rel Relative threshold.
  68.      * @param abs Absolute threshold.
  69.      * @param checker Additional, user-defined, convergence checking
  70.      * procedure.
  71.      * @throws NotStrictlyPositiveException if {@code abs <= 0}.
  72.      * @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}.
  73.      */
  74.     public BrentOptimizer(double rel,
  75.                           double abs,
  76.                           ConvergenceChecker<UnivariatePointValuePair> checker) {
  77.         super(checker);

  78.         if (rel < MIN_RELATIVE_TOLERANCE) {
  79.             throw new NumberIsTooSmallException(rel, MIN_RELATIVE_TOLERANCE, true);
  80.         }
  81.         if (abs <= 0) {
  82.             throw new NotStrictlyPositiveException(abs);
  83.         }

  84.         relativeThreshold = rel;
  85.         absoluteThreshold = abs;
  86.     }

  87.     /**
  88.      * The arguments are used for implementing the original stopping criterion
  89.      * of Brent's algorithm.
  90.      * {@code abs} and {@code rel} define a tolerance
  91.      * {@code tol = rel |x| + abs}. {@code rel} should be no smaller than
  92.      * <em>2 macheps</em> and preferably not much less than <em>sqrt(macheps)</em>,
  93.      * where <em>macheps</em> is the relative machine precision. {@code abs} must
  94.      * be positive.
  95.      *
  96.      * @param rel Relative threshold.
  97.      * @param abs Absolute threshold.
  98.      * @throws NotStrictlyPositiveException if {@code abs <= 0}.
  99.      * @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}.
  100.      */
  101.     public BrentOptimizer(double rel,
  102.                           double abs) {
  103.         this(rel, abs, null);
  104.     }

  105.     /** {@inheritDoc} */
  106.     @Override
  107.     protected UnivariatePointValuePair doOptimize() {
  108.         final boolean isMinim = getGoalType() == GoalType.MINIMIZE;
  109.         final double lo = getMin();
  110.         final double mid = getStartValue();
  111.         final double hi = getMax();

  112.         // Optional additional convergence criteria.
  113.         final ConvergenceChecker<UnivariatePointValuePair> checker
  114.             = getConvergenceChecker();

  115.         double a;
  116.         double b;
  117.         if (lo < hi) {
  118.             a = lo;
  119.             b = hi;
  120.         } else {
  121.             a = hi;
  122.             b = lo;
  123.         }

  124.         double x = mid;
  125.         double v = x;
  126.         double w = x;
  127.         double d = 0;
  128.         double e = 0;
  129.         double fx = computeObjectiveValue(x);
  130.         if (!isMinim) {
  131.             fx = -fx;
  132.         }
  133.         double fv = fx;
  134.         double fw = fx;

  135.         UnivariatePointValuePair previous = null;
  136.         UnivariatePointValuePair current
  137.             = new UnivariatePointValuePair(x, isMinim ? fx : -fx);
  138.         // Best point encountered so far (which is the initial guess).
  139.         UnivariatePointValuePair best = current;

  140.         while (true) {
  141.             final double m = 0.5 * (a + b);
  142.             final double tol1 = relativeThreshold * FastMath.abs(x) + absoluteThreshold;
  143.             final double tol2 = 2 * tol1;

  144.             // Default stopping criterion.
  145.             final boolean stop = FastMath.abs(x - m) <= tol2 - 0.5 * (b - a);
  146.             if (!stop) {
  147.                 double p = 0;
  148.                 double q = 0;
  149.                 double r = 0;
  150.                 double u = 0;

  151.                 if (FastMath.abs(e) > tol1) { // Fit parabola.
  152.                     r = (x - w) * (fx - fv);
  153.                     q = (x - v) * (fx - fw);
  154.                     p = (x - v) * q - (x - w) * r;
  155.                     q = 2 * (q - r);

  156.                     if (q > 0) {
  157.                         p = -p;
  158.                     } else {
  159.                         q = -q;
  160.                     }

  161.                     r = e;
  162.                     e = d;

  163.                     if (p > q * (a - x) &&
  164.                         p < q * (b - x) &&
  165.                         FastMath.abs(p) < FastMath.abs(0.5 * q * r)) {
  166.                         // Parabolic interpolation step.
  167.                         d = p / q;
  168.                         u = x + d;

  169.                         // f must not be evaluated too close to a or b.
  170.                         if (u - a < tol2 || b - u < tol2) {
  171.                             if (x <= m) {
  172.                                 d = tol1;
  173.                             } else {
  174.                                 d = -tol1;
  175.                             }
  176.                         }
  177.                     } else {
  178.                         // Golden section step.
  179.                         if (x < m) {
  180.                             e = b - x;
  181.                         } else {
  182.                             e = a - x;
  183.                         }
  184.                         d = GOLDEN_SECTION * e;
  185.                     }
  186.                 } else {
  187.                     // Golden section step.
  188.                     if (x < m) {
  189.                         e = b - x;
  190.                     } else {
  191.                         e = a - x;
  192.                     }
  193.                     d = GOLDEN_SECTION * e;
  194.                 }

  195.                 // Update by at least "tol1".
  196.                 if (FastMath.abs(d) < tol1) {
  197.                     if (d >= 0) {
  198.                         u = x + tol1;
  199.                     } else {
  200.                         u = x - tol1;
  201.                     }
  202.                 } else {
  203.                     u = x + d;
  204.                 }

  205.                 double fu = computeObjectiveValue(u);
  206.                 if (!isMinim) {
  207.                     fu = -fu;
  208.                 }

  209.                 // User-defined convergence checker.
  210.                 previous = current;
  211.                 current = new UnivariatePointValuePair(u, isMinim ? fu : -fu);
  212.                 best = best(best,
  213.                             best(previous,
  214.                                  current,
  215.                                  isMinim),
  216.                             isMinim);

  217.                 if (checker != null && checker.converged(getIterations(), previous, current)) {
  218.                     return best;
  219.                 }

  220.                 // Update a, b, v, w and x.
  221.                 if (fu <= fx) {
  222.                     if (u < x) {
  223.                         b = x;
  224.                     } else {
  225.                         a = x;
  226.                     }
  227.                     v = w;
  228.                     fv = fw;
  229.                     w = x;
  230.                     fw = fx;
  231.                     x = u;
  232.                     fx = fu;
  233.                 } else {
  234.                     if (u < x) {
  235.                         a = u;
  236.                     } else {
  237.                         b = u;
  238.                     }
  239.                     if (fu <= fw ||
  240.                         Precision.equals(w, x)) {
  241.                         v = w;
  242.                         fv = fw;
  243.                         w = u;
  244.                         fw = fu;
  245.                     } else if (fu <= fv ||
  246.                                Precision.equals(v, x) ||
  247.                                Precision.equals(v, w)) {
  248.                         v = u;
  249.                         fv = fu;
  250.                     }
  251.                 }
  252.             } else { // Default termination (Brent's criterion).
  253.                 return best(best,
  254.                             best(previous,
  255.                                  current,
  256.                                  isMinim),
  257.                             isMinim);
  258.             }

  259.             incrementIterationCount();
  260.         }
  261.     }

  262.     /**
  263.      * Selects the best of two points.
  264.      *
  265.      * @param a Point and value.
  266.      * @param b Point and value.
  267.      * @param isMinim {@code true} if the selected point must be the one with
  268.      * the lowest value.
  269.      * @return the best point, or {@code null} if {@code a} and {@code b} are
  270.      * both {@code null}. When {@code a} and {@code b} have the same function
  271.      * value, {@code a} is returned.
  272.      */
  273.     private UnivariatePointValuePair best(UnivariatePointValuePair a,
  274.                                           UnivariatePointValuePair b,
  275.                                           boolean isMinim) {
  276.         if (a == null) {
  277.             return b;
  278.         }
  279.         if (b == null) {
  280.             return a;
  281.         }

  282.         if (isMinim) {
  283.             return a.getValue() <= b.getValue() ? a : b;
  284.         } else {
  285.             return a.getValue() >= b.getValue() ? a : b;
  286.         }
  287.     }
  288. }