NonLinearConjugateGradientOptimizer.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math3.optim.nonlinear.scalar.gradient;

  18. import org.apache.commons.math3.analysis.solvers.UnivariateSolver;
  19. import org.apache.commons.math3.exception.MathInternalError;
  20. import org.apache.commons.math3.exception.TooManyEvaluationsException;
  21. import org.apache.commons.math3.exception.MathUnsupportedOperationException;
  22. import org.apache.commons.math3.exception.util.LocalizedFormats;
  23. import org.apache.commons.math3.optim.OptimizationData;
  24. import org.apache.commons.math3.optim.PointValuePair;
  25. import org.apache.commons.math3.optim.ConvergenceChecker;
  26. import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
  27. import org.apache.commons.math3.optim.nonlinear.scalar.GradientMultivariateOptimizer;
  28. import org.apache.commons.math3.optim.nonlinear.scalar.LineSearch;


  29. /**
  30.  * Non-linear conjugate gradient optimizer.
  31.  * <br/>
  32.  * This class supports both the Fletcher-Reeves and the Polak-Ribière
  33.  * update formulas for the conjugate search directions.
  34.  * It also supports optional preconditioning.
  35.  * <br/>
  36.  * Constraints are not supported: the call to
  37.  * {@link #optimize(OptimizationData[]) optimize} will throw
  38.  * {@link MathUnsupportedOperationException} if bounds are passed to it.
  39.  *
  40.  * @since 2.0
  41.  */
  42. public class NonLinearConjugateGradientOptimizer
  43.     extends GradientMultivariateOptimizer {
  44.     /** Update formula for the beta parameter. */
  45.     private final Formula updateFormula;
  46.     /** Preconditioner applied to the gradient to derive the search directions. */
  47.     private final Preconditioner preconditioner;
  48.     /** Line search algorithm. */
  49.     private final LineSearch line;

  50.     /**
  51.      * Available choices of update formulas for the updating the parameter
  52.      * that is used to compute the successive conjugate search directions.
  53.      * For non-linear conjugate gradients, there are
  54.      * two formulas:
  55.      * <ul>
  56.      *   <li>Fletcher-Reeves formula</li>
  57.      *   <li>Polak-Ribière formula</li>
  58.      * </ul>
  59.      *
  60.      * On the one hand, the Fletcher-Reeves formula is guaranteed to converge
  61.      * if the start point is close enough of the optimum whether the
  62.      * Polak-Ribière formula may not converge in rare cases. On the
  63.      * other hand, the Polak-Ribière formula is often faster when it
  64.      * does converge. Polak-Ribière is often used.
  65.      *
  66.      * @since 2.0
  67.      */
  68.     public static enum Formula {
  69.         /** Fletcher-Reeves formula. */
  70.         FLETCHER_REEVES,
  71.         /** Polak-Ribière formula. */
  72.         POLAK_RIBIERE
  73.     }

  74.     /**
  75.      * The initial step is a factor with respect to the search direction
  76.      * (which itself is roughly related to the gradient of the function).
  77.      * <br/>
  78.      * It is used to find an interval that brackets the optimum in line
  79.      * search.
  80.      *
  81.      * @since 3.1
  82.      * @deprecated As of v3.3, this class is not used anymore.
  83.      * This setting is replaced by the {@code initialBracketingRange}
  84.      * argument to the new constructors.
  85.      */
  86.     @Deprecated
  87.     public static class BracketingStep implements OptimizationData {
  88.         /** Initial step. */
  89.         private final double initialStep;

  90.         /**
  91.          * @param step Initial step for the bracket search.
  92.          */
  93.         public BracketingStep(double step) {
  94.             initialStep = step;
  95.         }

  96.         /**
  97.          * Gets the initial step.
  98.          *
  99.          * @return the initial step.
  100.          */
  101.         public double getBracketingStep() {
  102.             return initialStep;
  103.         }
  104.     }

  105.     /**
  106.      * Constructor with default tolerances for the line search (1e-8) and
  107.      * {@link IdentityPreconditioner preconditioner}.
  108.      *
  109.      * @param updateFormula formula to use for updating the &beta; parameter,
  110.      * must be one of {@link Formula#FLETCHER_REEVES} or
  111.      * {@link Formula#POLAK_RIBIERE}.
  112.      * @param checker Convergence checker.
  113.      */
  114.     public NonLinearConjugateGradientOptimizer(final Formula updateFormula,
  115.                                                ConvergenceChecker<PointValuePair> checker) {
  116.         this(updateFormula,
  117.              checker,
  118.              1e-8,
  119.              1e-8,
  120.              1e-8,
  121.              new IdentityPreconditioner());
  122.     }

  123.     /**
  124.      * Constructor with default {@link IdentityPreconditioner preconditioner}.
  125.      *
  126.      * @param updateFormula formula to use for updating the &beta; parameter,
  127.      * must be one of {@link Formula#FLETCHER_REEVES} or
  128.      * {@link Formula#POLAK_RIBIERE}.
  129.      * @param checker Convergence checker.
  130.      * @param lineSearchSolver Solver to use during line search.
  131.      * @deprecated as of 3.3. Please use
  132.      * {@link #NonLinearConjugateGradientOptimizer(Formula,ConvergenceChecker,double,double,double)} instead.
  133.      */
  134.     @Deprecated
  135.     public NonLinearConjugateGradientOptimizer(final Formula updateFormula,
  136.                                                ConvergenceChecker<PointValuePair> checker,
  137.                                                final UnivariateSolver lineSearchSolver) {
  138.         this(updateFormula,
  139.              checker,
  140.              lineSearchSolver,
  141.              new IdentityPreconditioner());
  142.     }

  143.     /**
  144.      * Constructor with default {@link IdentityPreconditioner preconditioner}.
  145.      *
  146.      * @param updateFormula formula to use for updating the &beta; parameter,
  147.      * must be one of {@link Formula#FLETCHER_REEVES} or
  148.      * {@link Formula#POLAK_RIBIERE}.
  149.      * @param checker Convergence checker.
  150.      * @param relativeTolerance Relative threshold for line search.
  151.      * @param absoluteTolerance Absolute threshold for line search.
  152.      * @param initialBracketingRange Extent of the initial interval used to
  153.      * find an interval that brackets the optimum in order to perform the
  154.      * line search.
  155.      *
  156.      * @see LineSearch#LineSearch(MultivariateOptimizer,double,double,double)
  157.      * @since 3.3
  158.      */
  159.     public NonLinearConjugateGradientOptimizer(final Formula updateFormula,
  160.                                                ConvergenceChecker<PointValuePair> checker,
  161.                                                double relativeTolerance,
  162.                                                double absoluteTolerance,
  163.                                                double initialBracketingRange) {
  164.         this(updateFormula,
  165.              checker,
  166.              relativeTolerance,
  167.              absoluteTolerance,
  168.              initialBracketingRange,
  169.              new IdentityPreconditioner());
  170.     }

  171.     /**
  172.      * @param updateFormula formula to use for updating the &beta; parameter,
  173.      * must be one of {@link Formula#FLETCHER_REEVES} or
  174.      * {@link Formula#POLAK_RIBIERE}.
  175.      * @param checker Convergence checker.
  176.      * @param lineSearchSolver Solver to use during line search.
  177.      * @param preconditioner Preconditioner.
  178.      * @deprecated as of 3.3. Please use
  179.      * {@link #NonLinearConjugateGradientOptimizer(Formula,ConvergenceChecker,double,double,double,Preconditioner)} instead.
  180.      */
  181.     @Deprecated
  182.     public NonLinearConjugateGradientOptimizer(final Formula updateFormula,
  183.                                                ConvergenceChecker<PointValuePair> checker,
  184.                                                final UnivariateSolver lineSearchSolver,
  185.                                                final Preconditioner preconditioner) {
  186.         this(updateFormula,
  187.              checker,
  188.              lineSearchSolver.getRelativeAccuracy(),
  189.              lineSearchSolver.getAbsoluteAccuracy(),
  190.              lineSearchSolver.getAbsoluteAccuracy(),
  191.              preconditioner);
  192.     }

  193.     /**
  194.      * @param updateFormula formula to use for updating the &beta; parameter,
  195.      * must be one of {@link Formula#FLETCHER_REEVES} or
  196.      * {@link Formula#POLAK_RIBIERE}.
  197.      * @param checker Convergence checker.
  198.      * @param preconditioner Preconditioner.
  199.      * @param relativeTolerance Relative threshold for line search.
  200.      * @param absoluteTolerance Absolute threshold for line search.
  201.      * @param initialBracketingRange Extent of the initial interval used to
  202.      * find an interval that brackets the optimum in order to perform the
  203.      * line search.
  204.      *
  205.      * @see LineSearch#LineSearch(MultivariateOptimizer,double,double,double)
  206.      * @since 3.3
  207.      */
  208.     public NonLinearConjugateGradientOptimizer(final Formula updateFormula,
  209.                                                ConvergenceChecker<PointValuePair> checker,
  210.                                                double relativeTolerance,
  211.                                                double absoluteTolerance,
  212.                                                double initialBracketingRange,
  213.                                                final Preconditioner preconditioner) {
  214.         super(checker);

  215.         this.updateFormula = updateFormula;
  216.         this.preconditioner = preconditioner;
  217.         line = new LineSearch(this,
  218.                               relativeTolerance,
  219.                               absoluteTolerance,
  220.                               initialBracketingRange);
  221.     }

  222.     /**
  223.      * {@inheritDoc}
  224.      */
  225.     @Override
  226.     public PointValuePair optimize(OptimizationData... optData)
  227.         throws TooManyEvaluationsException {
  228.         // Set up base class and perform computation.
  229.         return super.optimize(optData);
  230.     }

  231.     /** {@inheritDoc} */
  232.     @Override
  233.     protected PointValuePair doOptimize() {
  234.         final ConvergenceChecker<PointValuePair> checker = getConvergenceChecker();
  235.         final double[] point = getStartPoint();
  236.         final GoalType goal = getGoalType();
  237.         final int n = point.length;
  238.         double[] r = computeObjectiveGradient(point);
  239.         if (goal == GoalType.MINIMIZE) {
  240.             for (int i = 0; i < n; i++) {
  241.                 r[i] = -r[i];
  242.             }
  243.         }

  244.         // Initial search direction.
  245.         double[] steepestDescent = preconditioner.precondition(point, r);
  246.         double[] searchDirection = steepestDescent.clone();

  247.         double delta = 0;
  248.         for (int i = 0; i < n; ++i) {
  249.             delta += r[i] * searchDirection[i];
  250.         }

  251.         PointValuePair current = null;
  252.         while (true) {
  253.             incrementIterationCount();

  254.             final double objective = computeObjectiveValue(point);
  255.             PointValuePair previous = current;
  256.             current = new PointValuePair(point, objective);
  257.             if (previous != null && checker.converged(getIterations(), previous, current)) {
  258.                 // We have found an optimum.
  259.                 return current;
  260.             }

  261.             final double step = line.search(point, searchDirection).getPoint();

  262.             // Validate new point.
  263.             for (int i = 0; i < point.length; ++i) {
  264.                 point[i] += step * searchDirection[i];
  265.             }

  266.             r = computeObjectiveGradient(point);
  267.             if (goal == GoalType.MINIMIZE) {
  268.                 for (int i = 0; i < n; ++i) {
  269.                     r[i] = -r[i];
  270.                 }
  271.             }

  272.             // Compute beta.
  273.             final double deltaOld = delta;
  274.             final double[] newSteepestDescent = preconditioner.precondition(point, r);
  275.             delta = 0;
  276.             for (int i = 0; i < n; ++i) {
  277.                 delta += r[i] * newSteepestDescent[i];
  278.             }

  279.             final double beta;
  280.             switch (updateFormula) {
  281.             case FLETCHER_REEVES:
  282.                 beta = delta / deltaOld;
  283.                 break;
  284.             case POLAK_RIBIERE:
  285.                 double deltaMid = 0;
  286.                 for (int i = 0; i < r.length; ++i) {
  287.                     deltaMid += r[i] * steepestDescent[i];
  288.                 }
  289.                 beta = (delta - deltaMid) / deltaOld;
  290.                 break;
  291.             default:
  292.                 // Should never happen.
  293.                 throw new MathInternalError();
  294.             }
  295.             steepestDescent = newSteepestDescent;

  296.             // Compute conjugate search direction.
  297.             if (getIterations() % n == 0 ||
  298.                 beta < 0) {
  299.                 // Break conjugation: reset search direction.
  300.                 searchDirection = steepestDescent.clone();
  301.             } else {
  302.                 // Compute new conjugate search direction.
  303.                 for (int i = 0; i < n; ++i) {
  304.                     searchDirection[i] = steepestDescent[i] + beta * searchDirection[i];
  305.                 }
  306.             }
  307.         }
  308.     }

  309.     /**
  310.      * {@inheritDoc}
  311.      */
  312.     @Override
  313.     protected void parseOptimizationData(OptimizationData... optData) {
  314.         // Allow base class to register its own data.
  315.         super.parseOptimizationData(optData);

  316.         checkParameters();
  317.     }

  318.     /** Default identity preconditioner. */
  319.     public static class IdentityPreconditioner implements Preconditioner {
  320.         /** {@inheritDoc} */
  321.         public double[] precondition(double[] variables, double[] r) {
  322.             return r.clone();
  323.         }
  324.     }

  325.     // This class is not used anymore (cf. MATH-1092). However, it might
  326.     // be interesting to create a class similar to "LineSearch" that
  327.     // takes advantage of the fact that the model's gradient is available.
  328. //     /**
  329. //      * Internal class for line search.
  330. //      * <p>
  331. //      * The function represented by this class is the dot product of
  332. //      * the objective function gradient and the search direction. Its
  333. //      * value is zero when the gradient is orthogonal to the search
  334. //      * direction, i.e. when the objective function value is a local
  335. //      * extremum along the search direction.
  336. //      * </p>
  337. //      */
  338. //     private class LineSearchFunction implements UnivariateFunction {
  339. //         /** Current point. */
  340. //         private final double[] currentPoint;
  341. //         /** Search direction. */
  342. //         private final double[] searchDirection;

  343. //         /**
  344. //          * @param point Current point.
  345. //          * @param direction Search direction.
  346. //          */
  347. //         public LineSearchFunction(double[] point,
  348. //                                   double[] direction) {
  349. //             currentPoint = point.clone();
  350. //             searchDirection = direction.clone();
  351. //         }

  352. //         /** {@inheritDoc} */
  353. //         public double value(double x) {
  354. //             // current point in the search direction
  355. //             final double[] shiftedPoint = currentPoint.clone();
  356. //             for (int i = 0; i < shiftedPoint.length; ++i) {
  357. //                 shiftedPoint[i] += x * searchDirection[i];
  358. //             }

  359. //             // gradient of the objective function
  360. //             final double[] gradient = computeObjectiveGradient(shiftedPoint);

  361. //             // dot product with the search direction
  362. //             double dotProduct = 0;
  363. //             for (int i = 0; i < gradient.length; ++i) {
  364. //                 dotProduct += gradient[i] * searchDirection[i];
  365. //             }

  366. //             return dotProduct;
  367. //         }
  368. //     }

  369.     /**
  370.      * @throws MathUnsupportedOperationException if bounds were passed to the
  371.      * {@link #optimize(OptimizationData[]) optimize} method.
  372.      */
  373.     private void checkParameters() {
  374.         if (getLowerBound() != null ||
  375.             getUpperBound() != null) {
  376.             throw new MathUnsupportedOperationException(LocalizedFormats.CONSTRAINT);
  377.         }
  378.     }
  379. }