PowellOptimizer.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math3.optimization.direct;

  18. import org.apache.commons.math3.util.FastMath;
  19. import org.apache.commons.math3.util.MathArrays;
  20. import org.apache.commons.math3.analysis.UnivariateFunction;
  21. import org.apache.commons.math3.analysis.MultivariateFunction;
  22. import org.apache.commons.math3.exception.NumberIsTooSmallException;
  23. import org.apache.commons.math3.exception.NotStrictlyPositiveException;
  24. import org.apache.commons.math3.optimization.GoalType;
  25. import org.apache.commons.math3.optimization.PointValuePair;
  26. import org.apache.commons.math3.optimization.ConvergenceChecker;
  27. import org.apache.commons.math3.optimization.MultivariateOptimizer;
  28. import org.apache.commons.math3.optimization.univariate.BracketFinder;
  29. import org.apache.commons.math3.optimization.univariate.BrentOptimizer;
  30. import org.apache.commons.math3.optimization.univariate.UnivariatePointValuePair;
  31. import org.apache.commons.math3.optimization.univariate.SimpleUnivariateValueChecker;

  32. /**
  33.  * Powell algorithm.
  34.  * This code is translated and adapted from the Python version of this
  35.  * algorithm (as implemented in module {@code optimize.py} v0.5 of
  36.  * <em>SciPy</em>).
  37.  * <br/>
  38.  * The default stopping criterion is based on the differences of the
  39.  * function value between two successive iterations. It is however possible
  40.  * to define a custom convergence checker that might terminate the algorithm
  41.  * earlier.
  42.  * <br/>
  43.  * The internal line search optimizer is a {@link BrentOptimizer} with a
  44.  * convergence checker set to {@link SimpleUnivariateValueChecker}.
  45.  *
  46.  * @deprecated As of 3.1 (to be removed in 4.0).
  47.  * @since 2.2
  48.  */
  49. @Deprecated
  50. public class PowellOptimizer
  51.     extends BaseAbstractMultivariateOptimizer<MultivariateFunction>
  52.     implements MultivariateOptimizer {
  53.     /**
  54.      * Minimum relative tolerance.
  55.      */
  56.     private static final double MIN_RELATIVE_TOLERANCE = 2 * FastMath.ulp(1d);
  57.     /**
  58.      * Relative threshold.
  59.      */
  60.     private final double relativeThreshold;
  61.     /**
  62.      * Absolute threshold.
  63.      */
  64.     private final double absoluteThreshold;
  65.     /**
  66.      * Line search.
  67.      */
  68.     private final LineSearch line;

  69.     /**
  70.      * This constructor allows to specify a user-defined convergence checker,
  71.      * in addition to the parameters that control the default convergence
  72.      * checking procedure.
  73.      * <br/>
  74.      * The internal line search tolerances are set to the square-root of their
  75.      * corresponding value in the multivariate optimizer.
  76.      *
  77.      * @param rel Relative threshold.
  78.      * @param abs Absolute threshold.
  79.      * @param checker Convergence checker.
  80.      * @throws NotStrictlyPositiveException if {@code abs <= 0}.
  81.      * @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}.
  82.      */
  83.     public PowellOptimizer(double rel,
  84.                            double abs,
  85.                            ConvergenceChecker<PointValuePair> checker) {
  86.         this(rel, abs, FastMath.sqrt(rel), FastMath.sqrt(abs), checker);
  87.     }

  88.     /**
  89.      * This constructor allows to specify a user-defined convergence checker,
  90.      * in addition to the parameters that control the default convergence
  91.      * checking procedure and the line search tolerances.
  92.      *
  93.      * @param rel Relative threshold for this optimizer.
  94.      * @param abs Absolute threshold for this optimizer.
  95.      * @param lineRel Relative threshold for the internal line search optimizer.
  96.      * @param lineAbs Absolute threshold for the internal line search optimizer.
  97.      * @param checker Convergence checker.
  98.      * @throws NotStrictlyPositiveException if {@code abs <= 0}.
  99.      * @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}.
  100.      */
  101.     public PowellOptimizer(double rel,
  102.                            double abs,
  103.                            double lineRel,
  104.                            double lineAbs,
  105.                            ConvergenceChecker<PointValuePair> checker) {
  106.         super(checker);

  107.         if (rel < MIN_RELATIVE_TOLERANCE) {
  108.             throw new NumberIsTooSmallException(rel, MIN_RELATIVE_TOLERANCE, true);
  109.         }
  110.         if (abs <= 0) {
  111.             throw new NotStrictlyPositiveException(abs);
  112.         }
  113.         relativeThreshold = rel;
  114.         absoluteThreshold = abs;

  115.         // Create the line search optimizer.
  116.         line = new LineSearch(lineRel,
  117.                               lineAbs);
  118.     }

  119.     /**
  120.      * The parameters control the default convergence checking procedure.
  121.      * <br/>
  122.      * The internal line search tolerances are set to the square-root of their
  123.      * corresponding value in the multivariate optimizer.
  124.      *
  125.      * @param rel Relative threshold.
  126.      * @param abs Absolute threshold.
  127.      * @throws NotStrictlyPositiveException if {@code abs <= 0}.
  128.      * @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}.
  129.      */
  130.     public PowellOptimizer(double rel,
  131.                            double abs) {
  132.         this(rel, abs, null);
  133.     }

  134.     /**
  135.      * Builds an instance with the default convergence checking procedure.
  136.      *
  137.      * @param rel Relative threshold.
  138.      * @param abs Absolute threshold.
  139.      * @param lineRel Relative threshold for the internal line search optimizer.
  140.      * @param lineAbs Absolute threshold for the internal line search optimizer.
  141.      * @throws NotStrictlyPositiveException if {@code abs <= 0}.
  142.      * @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}.
  143.      * @since 3.1
  144.      */
  145.     public PowellOptimizer(double rel,
  146.                            double abs,
  147.                            double lineRel,
  148.                            double lineAbs) {
  149.         this(rel, abs, lineRel, lineAbs, null);
  150.     }

  151.     /** {@inheritDoc} */
  152.     @Override
  153.     protected PointValuePair doOptimize() {
  154.         final GoalType goal = getGoalType();
  155.         final double[] guess = getStartPoint();
  156.         final int n = guess.length;

  157.         final double[][] direc = new double[n][n];
  158.         for (int i = 0; i < n; i++) {
  159.             direc[i][i] = 1;
  160.         }

  161.         final ConvergenceChecker<PointValuePair> checker
  162.             = getConvergenceChecker();

  163.         double[] x = guess;
  164.         double fVal = computeObjectiveValue(x);
  165.         double[] x1 = x.clone();
  166.         int iter = 0;
  167.         while (true) {
  168.             ++iter;

  169.             double fX = fVal;
  170.             double fX2 = 0;
  171.             double delta = 0;
  172.             int bigInd = 0;
  173.             double alphaMin = 0;

  174.             for (int i = 0; i < n; i++) {
  175.                 final double[] d = MathArrays.copyOf(direc[i]);

  176.                 fX2 = fVal;

  177.                 final UnivariatePointValuePair optimum = line.search(x, d);
  178.                 fVal = optimum.getValue();
  179.                 alphaMin = optimum.getPoint();
  180.                 final double[][] result = newPointAndDirection(x, d, alphaMin);
  181.                 x = result[0];

  182.                 if ((fX2 - fVal) > delta) {
  183.                     delta = fX2 - fVal;
  184.                     bigInd = i;
  185.                 }
  186.             }

  187.             // Default convergence check.
  188.             boolean stop = 2 * (fX - fVal) <=
  189.                 (relativeThreshold * (FastMath.abs(fX) + FastMath.abs(fVal)) +
  190.                  absoluteThreshold);

  191.             final PointValuePair previous = new PointValuePair(x1, fX);
  192.             final PointValuePair current = new PointValuePair(x, fVal);
  193.             if (!stop && checker != null) {
  194.                 stop = checker.converged(iter, previous, current);
  195.             }
  196.             if (stop) {
  197.                 if (goal == GoalType.MINIMIZE) {
  198.                     return (fVal < fX) ? current : previous;
  199.                 } else {
  200.                     return (fVal > fX) ? current : previous;
  201.                 }
  202.             }

  203.             final double[] d = new double[n];
  204.             final double[] x2 = new double[n];
  205.             for (int i = 0; i < n; i++) {
  206.                 d[i] = x[i] - x1[i];
  207.                 x2[i] = 2 * x[i] - x1[i];
  208.             }

  209.             x1 = x.clone();
  210.             fX2 = computeObjectiveValue(x2);

  211.             if (fX > fX2) {
  212.                 double t = 2 * (fX + fX2 - 2 * fVal);
  213.                 double temp = fX - fVal - delta;
  214.                 t *= temp * temp;
  215.                 temp = fX - fX2;
  216.                 t -= delta * temp * temp;

  217.                 if (t < 0.0) {
  218.                     final UnivariatePointValuePair optimum = line.search(x, d);
  219.                     fVal = optimum.getValue();
  220.                     alphaMin = optimum.getPoint();
  221.                     final double[][] result = newPointAndDirection(x, d, alphaMin);
  222.                     x = result[0];

  223.                     final int lastInd = n - 1;
  224.                     direc[bigInd] = direc[lastInd];
  225.                     direc[lastInd] = result[1];
  226.                 }
  227.             }
  228.         }
  229.     }

  230.     /**
  231.      * Compute a new point (in the original space) and a new direction
  232.      * vector, resulting from the line search.
  233.      *
  234.      * @param p Point used in the line search.
  235.      * @param d Direction used in the line search.
  236.      * @param optimum Optimum found by the line search.
  237.      * @return a 2-element array containing the new point (at index 0) and
  238.      * the new direction (at index 1).
  239.      */
  240.     private double[][] newPointAndDirection(double[] p,
  241.                                             double[] d,
  242.                                             double optimum) {
  243.         final int n = p.length;
  244.         final double[] nP = new double[n];
  245.         final double[] nD = new double[n];
  246.         for (int i = 0; i < n; i++) {
  247.             nD[i] = d[i] * optimum;
  248.             nP[i] = p[i] + nD[i];
  249.         }

  250.         final double[][] result = new double[2][];
  251.         result[0] = nP;
  252.         result[1] = nD;

  253.         return result;
  254.     }

  255.     /**
  256.      * Class for finding the minimum of the objective function along a given
  257.      * direction.
  258.      */
  259.     private class LineSearch extends BrentOptimizer {
  260.         /**
  261.          * Value that will pass the precondition check for {@link BrentOptimizer}
  262.          * but will not pass the convergence check, so that the custom checker
  263.          * will always decide when to stop the line search.
  264.          */
  265.         private static final double REL_TOL_UNUSED = 1e-15;
  266.         /**
  267.          * Value that will pass the precondition check for {@link BrentOptimizer}
  268.          * but will not pass the convergence check, so that the custom checker
  269.          * will always decide when to stop the line search.
  270.          */
  271.         private static final double ABS_TOL_UNUSED = Double.MIN_VALUE;
  272.         /**
  273.          * Automatic bracketing.
  274.          */
  275.         private final BracketFinder bracket = new BracketFinder();

  276.         /**
  277.          * The "BrentOptimizer" default stopping criterion uses the tolerances
  278.          * to check the domain (point) values, not the function values.
  279.          * We thus create a custom checker to use function values.
  280.          *
  281.          * @param rel Relative threshold.
  282.          * @param abs Absolute threshold.
  283.          */
  284.         LineSearch(double rel,
  285.                    double abs) {
  286.             super(REL_TOL_UNUSED,
  287.                   ABS_TOL_UNUSED,
  288.                   new SimpleUnivariateValueChecker(rel, abs));
  289.         }

  290.         /**
  291.          * Find the minimum of the function {@code f(p + alpha * d)}.
  292.          *
  293.          * @param p Starting point.
  294.          * @param d Search direction.
  295.          * @return the optimum.
  296.          * @throws org.apache.commons.math3.exception.TooManyEvaluationsException
  297.          * if the number of evaluations is exceeded.
  298.          */
  299.         public UnivariatePointValuePair search(final double[] p, final double[] d) {
  300.             final int n = p.length;
  301.             final UnivariateFunction f = new UnivariateFunction() {
  302.                     public double value(double alpha) {
  303.                         final double[] x = new double[n];
  304.                         for (int i = 0; i < n; i++) {
  305.                             x[i] = p[i] + alpha * d[i];
  306.                         }
  307.                         final double obj = PowellOptimizer.this.computeObjectiveValue(x);
  308.                         return obj;
  309.                     }
  310.                 };

  311.             final GoalType goal = PowellOptimizer.this.getGoalType();
  312.             bracket.search(f, goal, 0, 1);
  313.             // Passing "MAX_VALUE" as a dummy value because it is the enclosing
  314.             // class that counts the number of evaluations (and will eventually
  315.             // generate the exception).
  316.             return optimize(Integer.MAX_VALUE, f, goal,
  317.                             bracket.getLo(), bracket.getHi(), bracket.getMid());
  318.         }
  319.     }
  320. }