BaseAbstractMultivariateVectorOptimizer.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math3.optimization.direct;

  18. import org.apache.commons.math3.util.Incrementor;
  19. import org.apache.commons.math3.exception.MaxCountExceededException;
  20. import org.apache.commons.math3.exception.TooManyEvaluationsException;
  21. import org.apache.commons.math3.exception.DimensionMismatchException;
  22. import org.apache.commons.math3.exception.NullArgumentException;
  23. import org.apache.commons.math3.analysis.MultivariateVectorFunction;
  24. import org.apache.commons.math3.optimization.OptimizationData;
  25. import org.apache.commons.math3.optimization.InitialGuess;
  26. import org.apache.commons.math3.optimization.Target;
  27. import org.apache.commons.math3.optimization.Weight;
  28. import org.apache.commons.math3.optimization.BaseMultivariateVectorOptimizer;
  29. import org.apache.commons.math3.optimization.ConvergenceChecker;
  30. import org.apache.commons.math3.optimization.PointVectorValuePair;
  31. import org.apache.commons.math3.optimization.SimpleVectorValueChecker;
  32. import org.apache.commons.math3.linear.RealMatrix;

  33. /**
  34.  * Base class for implementing optimizers for multivariate scalar functions.
  35.  * This base class handles the boiler-plate methods associated to thresholds
  36.  * settings, iterations and evaluations counting.
  37.  *
  38.  * @param <FUNC> the type of the objective function to be optimized
  39.  *
  40.  * @deprecated As of 3.1 (to be removed in 4.0).
  41.  * @since 3.0
  42.  */
  43. @Deprecated
  44. public abstract class BaseAbstractMultivariateVectorOptimizer<FUNC extends MultivariateVectorFunction>
  45.     implements BaseMultivariateVectorOptimizer<FUNC> {
  46.     /** Evaluations counter. */
  47.     protected final Incrementor evaluations = new Incrementor();
  48.     /** Convergence checker. */
  49.     private ConvergenceChecker<PointVectorValuePair> checker;
  50.     /** Target value for the objective functions at optimum. */
  51.     private double[] target;
  52.     /** Weight matrix. */
  53.     private RealMatrix weightMatrix;
  54.     /** Weight for the least squares cost computation.
  55.      * @deprecated
  56.      */
  57.     @Deprecated
  58.     private double[] weight;
  59.     /** Initial guess. */
  60.     private double[] start;
  61.     /** Objective function. */
  62.     private FUNC function;

  63.     /**
  64.      * Simple constructor with default settings.
  65.      * The convergence check is set to a {@link SimpleVectorValueChecker}.
  66.      * @deprecated See {@link SimpleVectorValueChecker#SimpleVectorValueChecker()}
  67.      */
  68.     @Deprecated
  69.     protected BaseAbstractMultivariateVectorOptimizer() {
  70.         this(new SimpleVectorValueChecker());
  71.     }
  72.     /**
  73.      * @param checker Convergence checker.
  74.      */
  75.     protected BaseAbstractMultivariateVectorOptimizer(ConvergenceChecker<PointVectorValuePair> checker) {
  76.         this.checker = checker;
  77.     }

  78.     /** {@inheritDoc} */
  79.     public int getMaxEvaluations() {
  80.         return evaluations.getMaximalCount();
  81.     }

  82.     /** {@inheritDoc} */
  83.     public int getEvaluations() {
  84.         return evaluations.getCount();
  85.     }

  86.     /** {@inheritDoc} */
  87.     public ConvergenceChecker<PointVectorValuePair> getConvergenceChecker() {
  88.         return checker;
  89.     }

  90.     /**
  91.      * Compute the objective function value.
  92.      *
  93.      * @param point Point at which the objective function must be evaluated.
  94.      * @return the objective function value at the specified point.
  95.      * @throws TooManyEvaluationsException if the maximal number of evaluations is
  96.      * exceeded.
  97.      */
  98.     protected double[] computeObjectiveValue(double[] point) {
  99.         try {
  100.             evaluations.incrementCount();
  101.         } catch (MaxCountExceededException e) {
  102.             throw new TooManyEvaluationsException(e.getMax());
  103.         }
  104.         return function.value(point);
  105.     }

  106.     /** {@inheritDoc}
  107.      *
  108.      * @deprecated As of 3.1. Please use
  109.      * {@link #optimize(int,MultivariateVectorFunction,OptimizationData[])}
  110.      * instead.
  111.      */
  112.     @Deprecated
  113.     public PointVectorValuePair optimize(int maxEval, FUNC f, double[] t, double[] w,
  114.                                          double[] startPoint) {
  115.         return optimizeInternal(maxEval, f, t, w, startPoint);
  116.     }

  117.     /**
  118.      * Optimize an objective function.
  119.      *
  120.      * @param maxEval Allowed number of evaluations of the objective function.
  121.      * @param f Objective function.
  122.      * @param optData Optimization data. The following data will be looked for:
  123.      * <ul>
  124.      *  <li>{@link Target}</li>
  125.      *  <li>{@link Weight}</li>
  126.      *  <li>{@link InitialGuess}</li>
  127.      * </ul>
  128.      * @return the point/value pair giving the optimal value of the objective
  129.      * function.
  130.      * @throws TooManyEvaluationsException if the maximal number of
  131.      * evaluations is exceeded.
  132.      * @throws DimensionMismatchException if the initial guess, target, and weight
  133.      * arguments have inconsistent dimensions.
  134.      *
  135.      * @since 3.1
  136.      */
  137.     protected PointVectorValuePair optimize(int maxEval,
  138.                                             FUNC f,
  139.                                             OptimizationData... optData)
  140.         throws TooManyEvaluationsException,
  141.                DimensionMismatchException {
  142.         return optimizeInternal(maxEval, f, optData);
  143.     }

  144.     /**
  145.      * Optimize an objective function.
  146.      * Optimization is considered to be a weighted least-squares minimization.
  147.      * The cost function to be minimized is
  148.      * <code>&sum;weight<sub>i</sub>(objective<sub>i</sub> - target<sub>i</sub>)<sup>2</sup></code>
  149.      *
  150.      * @param f Objective function.
  151.      * @param t Target value for the objective functions at optimum.
  152.      * @param w Weights for the least squares cost computation.
  153.      * @param startPoint Start point for optimization.
  154.      * @return the point/value pair giving the optimal value for objective
  155.      * function.
  156.      * @param maxEval Maximum number of function evaluations.
  157.      * @throws org.apache.commons.math3.exception.DimensionMismatchException
  158.      * if the start point dimension is wrong.
  159.      * @throws org.apache.commons.math3.exception.TooManyEvaluationsException
  160.      * if the maximal number of evaluations is exceeded.
  161.      * @throws org.apache.commons.math3.exception.NullArgumentException if
  162.      * any argument is {@code null}.
  163.      * @deprecated As of 3.1. Please use
  164.      * {@link #optimizeInternal(int,MultivariateVectorFunction,OptimizationData[])}
  165.      * instead.
  166.      */
  167.     @Deprecated
  168.     protected PointVectorValuePair optimizeInternal(final int maxEval, final FUNC f,
  169.                                                     final double[] t, final double[] w,
  170.                                                     final double[] startPoint) {
  171.         // Checks.
  172.         if (f == null) {
  173.             throw new NullArgumentException();
  174.         }
  175.         if (t == null) {
  176.             throw new NullArgumentException();
  177.         }
  178.         if (w == null) {
  179.             throw new NullArgumentException();
  180.         }
  181.         if (startPoint == null) {
  182.             throw new NullArgumentException();
  183.         }
  184.         if (t.length != w.length) {
  185.             throw new DimensionMismatchException(t.length, w.length);
  186.         }

  187.         return optimizeInternal(maxEval, f,
  188.                                 new Target(t),
  189.                                 new Weight(w),
  190.                                 new InitialGuess(startPoint));
  191.     }

  192.     /**
  193.      * Optimize an objective function.
  194.      *
  195.      * @param maxEval Allowed number of evaluations of the objective function.
  196.      * @param f Objective function.
  197.      * @param optData Optimization data. The following data will be looked for:
  198.      * <ul>
  199.      *  <li>{@link Target}</li>
  200.      *  <li>{@link Weight}</li>
  201.      *  <li>{@link InitialGuess}</li>
  202.      * </ul>
  203.      * @return the point/value pair giving the optimal value of the objective
  204.      * function.
  205.      * @throws TooManyEvaluationsException if the maximal number of
  206.      * evaluations is exceeded.
  207.      * @throws DimensionMismatchException if the initial guess, target, and weight
  208.      * arguments have inconsistent dimensions.
  209.      *
  210.      * @since 3.1
  211.      */
  212.     protected PointVectorValuePair optimizeInternal(int maxEval,
  213.                                                     FUNC f,
  214.                                                     OptimizationData... optData)
  215.         throws TooManyEvaluationsException,
  216.                DimensionMismatchException {
  217.         // Set internal state.
  218.         evaluations.setMaximalCount(maxEval);
  219.         evaluations.resetCount();
  220.         function = f;
  221.         // Retrieve other settings.
  222.         parseOptimizationData(optData);
  223.         // Check input consistency.
  224.         checkParameters();
  225.         // Allow subclasses to reset their own internal state.
  226.         setUp();
  227.         // Perform computation.
  228.         return doOptimize();
  229.     }

  230.     /**
  231.      * Gets the initial values of the optimized parameters.
  232.      *
  233.      * @return the initial guess.
  234.      */
  235.     public double[] getStartPoint() {
  236.         return start.clone();
  237.     }

  238.     /**
  239.      * Gets the weight matrix of the observations.
  240.      *
  241.      * @return the weight matrix.
  242.      * @since 3.1
  243.      */
  244.     public RealMatrix getWeight() {
  245.         return weightMatrix.copy();
  246.     }
  247.     /**
  248.      * Gets the observed values to be matched by the objective vector
  249.      * function.
  250.      *
  251.      * @return the target values.
  252.      * @since 3.1
  253.      */
  254.     public double[] getTarget() {
  255.         return target.clone();
  256.     }

  257.     /**
  258.      * Gets the objective vector function.
  259.      * Note that this access bypasses the evaluation counter.
  260.      *
  261.      * @return the objective vector function.
  262.      * @since 3.1
  263.      */
  264.     protected FUNC getObjectiveFunction() {
  265.         return function;
  266.     }

  267.     /**
  268.      * Perform the bulk of the optimization algorithm.
  269.      *
  270.      * @return the point/value pair giving the optimal value for the
  271.      * objective function.
  272.      */
  273.     protected abstract PointVectorValuePair doOptimize();

  274.     /**
  275.      * @return a reference to the {@link #target array}.
  276.      * @deprecated As of 3.1.
  277.      */
  278.     @Deprecated
  279.     protected double[] getTargetRef() {
  280.         return target;
  281.     }
  282.     /**
  283.      * @return a reference to the {@link #weight array}.
  284.      * @deprecated As of 3.1.
  285.      */
  286.     @Deprecated
  287.     protected double[] getWeightRef() {
  288.         return weight;
  289.     }

  290.     /**
  291.      * Method which a subclass <em>must</em> override whenever its internal
  292.      * state depend on the {@link OptimizationData input} parsed by this base
  293.      * class.
  294.      * It will be called after the parsing step performed in the
  295.      * {@link #optimize(int,MultivariateVectorFunction,OptimizationData[])
  296.      * optimize} method and just before {@link #doOptimize()}.
  297.      *
  298.      * @since 3.1
  299.      */
  300.     protected void setUp() {
  301.         // XXX Temporary code until the new internal data is used everywhere.
  302.         final int dim = target.length;
  303.         weight = new double[dim];
  304.         for (int i = 0; i < dim; i++) {
  305.             weight[i] = weightMatrix.getEntry(i, i);
  306.         }
  307.     }

  308.     /**
  309.      * Scans the list of (required and optional) optimization data that
  310.      * characterize the problem.
  311.      *
  312.      * @param optData Optimization data. The following data will be looked for:
  313.      * <ul>
  314.      *  <li>{@link Target}</li>
  315.      *  <li>{@link Weight}</li>
  316.      *  <li>{@link InitialGuess}</li>
  317.      * </ul>
  318.      */
  319.     private void parseOptimizationData(OptimizationData... optData) {
  320.         // The existing values (as set by the previous call) are reused if
  321.         // not provided in the argument list.
  322.         for (OptimizationData data : optData) {
  323.             if (data instanceof Target) {
  324.                 target = ((Target) data).getTarget();
  325.                 continue;
  326.             }
  327.             if (data instanceof Weight) {
  328.                 weightMatrix = ((Weight) data).getWeight();
  329.                 continue;
  330.             }
  331.             if (data instanceof InitialGuess) {
  332.                 start = ((InitialGuess) data).getInitialGuess();
  333.                 continue;
  334.             }
  335.         }
  336.     }

  337.     /**
  338.      * Check parameters consistency.
  339.      *
  340.      * @throws DimensionMismatchException if {@link #target} and
  341.      * {@link #weightMatrix} have inconsistent dimensions.
  342.      */
  343.     private void checkParameters() {
  344.         if (target.length != weightMatrix.getColumnDimension()) {
  345.             throw new DimensionMismatchException(target.length,
  346.                                                  weightMatrix.getColumnDimension());
  347.         }
  348.     }
  349. }