NonLinearConjugateGradientOptimizer.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math3.optimization.general;

  18. import org.apache.commons.math3.exception.MathIllegalStateException;
  19. import org.apache.commons.math3.analysis.UnivariateFunction;
  20. import org.apache.commons.math3.analysis.solvers.BrentSolver;
  21. import org.apache.commons.math3.analysis.solvers.UnivariateSolver;
  22. import org.apache.commons.math3.exception.util.LocalizedFormats;
  23. import org.apache.commons.math3.optimization.GoalType;
  24. import org.apache.commons.math3.optimization.PointValuePair;
  25. import org.apache.commons.math3.optimization.SimpleValueChecker;
  26. import org.apache.commons.math3.optimization.ConvergenceChecker;
  27. import org.apache.commons.math3.util.FastMath;

  28. /**
  29.  * Non-linear conjugate gradient optimizer.
  30.  * <p>
  31.  * This class supports both the Fletcher-Reeves and the Polak-Ribi&egrave;re
  32.  * update formulas for the conjugate search directions. It also supports
  33.  * optional preconditioning.
  34.  * </p>
  35.  *
  36.  * @deprecated As of 3.1 (to be removed in 4.0).
  37.  * @since 2.0
  38.  *
  39.  */
  40. @Deprecated
  41. public class NonLinearConjugateGradientOptimizer
  42.     extends AbstractScalarDifferentiableOptimizer {
  43.     /** Update formula for the beta parameter. */
  44.     private final ConjugateGradientFormula updateFormula;
  45.     /** Preconditioner (may be null). */
  46.     private final Preconditioner preconditioner;
  47.     /** solver to use in the line search (may be null). */
  48.     private final UnivariateSolver solver;
  49.     /** Initial step used to bracket the optimum in line search. */
  50.     private double initialStep;
  51.     /** Current point. */
  52.     private double[] point;

  53.     /**
  54.      * Constructor with default {@link SimpleValueChecker checker},
  55.      * {@link BrentSolver line search solver} and
  56.      * {@link IdentityPreconditioner preconditioner}.
  57.      *
  58.      * @param updateFormula formula to use for updating the &beta; parameter,
  59.      * must be one of {@link ConjugateGradientFormula#FLETCHER_REEVES} or {@link
  60.      * ConjugateGradientFormula#POLAK_RIBIERE}.
  61.      * @deprecated See {@link SimpleValueChecker#SimpleValueChecker()}
  62.      */
  63.     @Deprecated
  64.     public NonLinearConjugateGradientOptimizer(final ConjugateGradientFormula updateFormula) {
  65.         this(updateFormula,
  66.              new SimpleValueChecker());
  67.     }

  68.     /**
  69.      * Constructor with default {@link BrentSolver line search solver} and
  70.      * {@link IdentityPreconditioner preconditioner}.
  71.      *
  72.      * @param updateFormula formula to use for updating the &beta; parameter,
  73.      * must be one of {@link ConjugateGradientFormula#FLETCHER_REEVES} or {@link
  74.      * ConjugateGradientFormula#POLAK_RIBIERE}.
  75.      * @param checker Convergence checker.
  76.      */
  77.     public NonLinearConjugateGradientOptimizer(final ConjugateGradientFormula updateFormula,
  78.                                                ConvergenceChecker<PointValuePair> checker) {
  79.         this(updateFormula,
  80.              checker,
  81.              new BrentSolver(),
  82.              new IdentityPreconditioner());
  83.     }


  84.     /**
  85.      * Constructor with default {@link IdentityPreconditioner preconditioner}.
  86.      *
  87.      * @param updateFormula formula to use for updating the &beta; parameter,
  88.      * must be one of {@link ConjugateGradientFormula#FLETCHER_REEVES} or {@link
  89.      * ConjugateGradientFormula#POLAK_RIBIERE}.
  90.      * @param checker Convergence checker.
  91.      * @param lineSearchSolver Solver to use during line search.
  92.      */
  93.     public NonLinearConjugateGradientOptimizer(final ConjugateGradientFormula updateFormula,
  94.                                                ConvergenceChecker<PointValuePair> checker,
  95.                                                final UnivariateSolver lineSearchSolver) {
  96.         this(updateFormula,
  97.              checker,
  98.              lineSearchSolver,
  99.              new IdentityPreconditioner());
  100.     }

  101.     /**
  102.      * @param updateFormula formula to use for updating the &beta; parameter,
  103.      * must be one of {@link ConjugateGradientFormula#FLETCHER_REEVES} or {@link
  104.      * ConjugateGradientFormula#POLAK_RIBIERE}.
  105.      * @param checker Convergence checker.
  106.      * @param lineSearchSolver Solver to use during line search.
  107.      * @param preconditioner Preconditioner.
  108.      */
  109.     public NonLinearConjugateGradientOptimizer(final ConjugateGradientFormula updateFormula,
  110.                                                ConvergenceChecker<PointValuePair> checker,
  111.                                                final UnivariateSolver lineSearchSolver,
  112.                                                final Preconditioner preconditioner) {
  113.         super(checker);

  114.         this.updateFormula = updateFormula;
  115.         solver = lineSearchSolver;
  116.         this.preconditioner = preconditioner;
  117.         initialStep = 1.0;
  118.     }

  119.     /**
  120.      * Set the initial step used to bracket the optimum in line search.
  121.      * <p>
  122.      * The initial step is a factor with respect to the search direction,
  123.      * which itself is roughly related to the gradient of the function
  124.      * </p>
  125.      * @param initialStep initial step used to bracket the optimum in line search,
  126.      * if a non-positive value is used, the initial step is reset to its
  127.      * default value of 1.0
  128.      */
  129.     public void setInitialStep(final double initialStep) {
  130.         if (initialStep <= 0) {
  131.             this.initialStep = 1.0;
  132.         } else {
  133.             this.initialStep = initialStep;
  134.         }
  135.     }

  136.     /** {@inheritDoc} */
  137.     @Override
  138.     protected PointValuePair doOptimize() {
  139.         final ConvergenceChecker<PointValuePair> checker = getConvergenceChecker();
  140.         point = getStartPoint();
  141.         final GoalType goal = getGoalType();
  142.         final int n = point.length;
  143.         double[] r = computeObjectiveGradient(point);
  144.         if (goal == GoalType.MINIMIZE) {
  145.             for (int i = 0; i < n; ++i) {
  146.                 r[i] = -r[i];
  147.             }
  148.         }

  149.         // Initial search direction.
  150.         double[] steepestDescent = preconditioner.precondition(point, r);
  151.         double[] searchDirection = steepestDescent.clone();

  152.         double delta = 0;
  153.         for (int i = 0; i < n; ++i) {
  154.             delta += r[i] * searchDirection[i];
  155.         }

  156.         PointValuePair current = null;
  157.         int iter = 0;
  158.         int maxEval = getMaxEvaluations();
  159.         while (true) {
  160.             ++iter;

  161.             final double objective = computeObjectiveValue(point);
  162.             PointValuePair previous = current;
  163.             current = new PointValuePair(point, objective);
  164.             if (previous != null && checker.converged(iter, previous, current)) {
  165.                 // We have found an optimum.
  166.                 return current;
  167.             }

  168.             // Find the optimal step in the search direction.
  169.             final UnivariateFunction lsf = new LineSearchFunction(searchDirection);
  170.             final double uB = findUpperBound(lsf, 0, initialStep);
  171.             // XXX Last parameters is set to a value close to zero in order to
  172.             // work around the divergence problem in the "testCircleFitting"
  173.             // unit test (see MATH-439).
  174.             final double step = solver.solve(maxEval, lsf, 0, uB, 1e-15);
  175.             maxEval -= solver.getEvaluations(); // Subtract used up evaluations.

  176.             // Validate new point.
  177.             for (int i = 0; i < point.length; ++i) {
  178.                 point[i] += step * searchDirection[i];
  179.             }

  180.             r = computeObjectiveGradient(point);
  181.             if (goal == GoalType.MINIMIZE) {
  182.                 for (int i = 0; i < n; ++i) {
  183.                     r[i] = -r[i];
  184.                 }
  185.             }

  186.             // Compute beta.
  187.             final double deltaOld = delta;
  188.             final double[] newSteepestDescent = preconditioner.precondition(point, r);
  189.             delta = 0;
  190.             for (int i = 0; i < n; ++i) {
  191.                 delta += r[i] * newSteepestDescent[i];
  192.             }

  193.             final double beta;
  194.             if (updateFormula == ConjugateGradientFormula.FLETCHER_REEVES) {
  195.                 beta = delta / deltaOld;
  196.             } else {
  197.                 double deltaMid = 0;
  198.                 for (int i = 0; i < r.length; ++i) {
  199.                     deltaMid += r[i] * steepestDescent[i];
  200.                 }
  201.                 beta = (delta - deltaMid) / deltaOld;
  202.             }
  203.             steepestDescent = newSteepestDescent;

  204.             // Compute conjugate search direction.
  205.             if (iter % n == 0 ||
  206.                 beta < 0) {
  207.                 // Break conjugation: reset search direction.
  208.                 searchDirection = steepestDescent.clone();
  209.             } else {
  210.                 // Compute new conjugate search direction.
  211.                 for (int i = 0; i < n; ++i) {
  212.                     searchDirection[i] = steepestDescent[i] + beta * searchDirection[i];
  213.                 }
  214.             }
  215.         }
  216.     }

  217.     /**
  218.      * Find the upper bound b ensuring bracketing of a root between a and b.
  219.      *
  220.      * @param f function whose root must be bracketed.
  221.      * @param a lower bound of the interval.
  222.      * @param h initial step to try.
  223.      * @return b such that f(a) and f(b) have opposite signs.
  224.      * @throws MathIllegalStateException if no bracket can be found.
  225.      */
  226.     private double findUpperBound(final UnivariateFunction f,
  227.                                   final double a, final double h) {
  228.         final double yA = f.value(a);
  229.         double yB = yA;
  230.         for (double step = h; step < Double.MAX_VALUE; step *= FastMath.max(2, yA / yB)) {
  231.             final double b = a + step;
  232.             yB = f.value(b);
  233.             if (yA * yB <= 0) {
  234.                 return b;
  235.             }
  236.         }
  237.         throw new MathIllegalStateException(LocalizedFormats.UNABLE_TO_BRACKET_OPTIMUM_IN_LINE_SEARCH);
  238.     }

  239.     /** Default identity preconditioner. */
  240.     public static class IdentityPreconditioner implements Preconditioner {

  241.         /** {@inheritDoc} */
  242.         public double[] precondition(double[] variables, double[] r) {
  243.             return r.clone();
  244.         }
  245.     }

  246.     /** Internal class for line search.
  247.      * <p>
  248.      * The function represented by this class is the dot product of
  249.      * the objective function gradient and the search direction. Its
  250.      * value is zero when the gradient is orthogonal to the search
  251.      * direction, i.e. when the objective function value is a local
  252.      * extremum along the search direction.
  253.      * </p>
  254.      */
  255.     private class LineSearchFunction implements UnivariateFunction {
  256.         /** Search direction. */
  257.         private final double[] searchDirection;

  258.         /** Simple constructor.
  259.          * @param searchDirection search direction
  260.          */
  261.         public LineSearchFunction(final double[] searchDirection) {
  262.             this.searchDirection = searchDirection;
  263.         }

  264.         /** {@inheritDoc} */
  265.         public double value(double x) {
  266.             // current point in the search direction
  267.             final double[] shiftedPoint = point.clone();
  268.             for (int i = 0; i < shiftedPoint.length; ++i) {
  269.                 shiftedPoint[i] += x * searchDirection[i];
  270.             }

  271.             // gradient of the objective function
  272.             final double[] gradient = computeObjectiveGradient(shiftedPoint);

  273.             // dot product with the search direction
  274.             double dotProduct = 0;
  275.             for (int i = 0; i < gradient.length; ++i) {
  276.                 dotProduct += gradient[i] * searchDirection[i];
  277.             }

  278.             return dotProduct;
  279.         }
  280.     }
  281. }