GaussNewtonOptimizer.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math3.optimization.general;

  18. import org.apache.commons.math3.exception.ConvergenceException;
  19. import org.apache.commons.math3.exception.NullArgumentException;
  20. import org.apache.commons.math3.exception.MathInternalError;
  21. import org.apache.commons.math3.exception.util.LocalizedFormats;
  22. import org.apache.commons.math3.linear.ArrayRealVector;
  23. import org.apache.commons.math3.linear.BlockRealMatrix;
  24. import org.apache.commons.math3.linear.DecompositionSolver;
  25. import org.apache.commons.math3.linear.LUDecomposition;
  26. import org.apache.commons.math3.linear.QRDecomposition;
  27. import org.apache.commons.math3.linear.RealMatrix;
  28. import org.apache.commons.math3.linear.SingularMatrixException;
  29. import org.apache.commons.math3.optimization.ConvergenceChecker;
  30. import org.apache.commons.math3.optimization.SimpleVectorValueChecker;
  31. import org.apache.commons.math3.optimization.PointVectorValuePair;

  32. /**
  33.  * Gauss-Newton least-squares solver.
  34.  * <p>
  35.  * This class solve a least-square problem by solving the normal equations
  36.  * of the linearized problem at each iteration. Either LU decomposition or
  37.  * QR decomposition can be used to solve the normal equations. LU decomposition
  38.  * is faster but QR decomposition is more robust for difficult problems.
  39.  * </p>
  40.  *
  41.  * @deprecated As of 3.1 (to be removed in 4.0).
  42.  * @since 2.0
  43.  *
  44.  */
  45. @Deprecated
  46. public class GaussNewtonOptimizer extends AbstractLeastSquaresOptimizer {
  47.     /** Indicator for using LU decomposition. */
  48.     private final boolean useLU;

  49.     /**
  50.      * Simple constructor with default settings.
  51.      * The normal equations will be solved using LU decomposition and the
  52.      * convergence check is set to a {@link SimpleVectorValueChecker}
  53.      * with default tolerances.
  54.      * @deprecated See {@link SimpleVectorValueChecker#SimpleVectorValueChecker()}
  55.      */
  56.     @Deprecated
  57.     public GaussNewtonOptimizer() {
  58.         this(true);
  59.     }

  60.     /**
  61.      * Simple constructor with default settings.
  62.      * The normal equations will be solved using LU decomposition.
  63.      *
  64.      * @param checker Convergence checker.
  65.      */
  66.     public GaussNewtonOptimizer(ConvergenceChecker<PointVectorValuePair> checker) {
  67.         this(true, checker);
  68.     }

  69.     /**
  70.      * Simple constructor with default settings.
  71.      * The convergence check is set to a {@link SimpleVectorValueChecker}
  72.      * with default tolerances.
  73.      *
  74.      * @param useLU If {@code true}, the normal equations will be solved
  75.      * using LU decomposition, otherwise they will be solved using QR
  76.      * decomposition.
  77.      * @deprecated See {@link SimpleVectorValueChecker#SimpleVectorValueChecker()}
  78.      */
  79.     @Deprecated
  80.     public GaussNewtonOptimizer(final boolean useLU) {
  81.         this(useLU, new SimpleVectorValueChecker());
  82.     }

  83.     /**
  84.      * @param useLU If {@code true}, the normal equations will be solved
  85.      * using LU decomposition, otherwise they will be solved using QR
  86.      * decomposition.
  87.      * @param checker Convergence checker.
  88.      */
  89.     public GaussNewtonOptimizer(final boolean useLU,
  90.                                 ConvergenceChecker<PointVectorValuePair> checker) {
  91.         super(checker);
  92.         this.useLU = useLU;
  93.     }

  94.     /** {@inheritDoc} */
  95.     @Override
  96.     public PointVectorValuePair doOptimize() {
  97.         final ConvergenceChecker<PointVectorValuePair> checker
  98.             = getConvergenceChecker();

  99.         // Computation will be useless without a checker (see "for-loop").
  100.         if (checker == null) {
  101.             throw new NullArgumentException();
  102.         }

  103.         final double[] targetValues = getTarget();
  104.         final int nR = targetValues.length; // Number of observed data.

  105.         final RealMatrix weightMatrix = getWeight();
  106.         // Diagonal of the weight matrix.
  107.         final double[] residualsWeights = new double[nR];
  108.         for (int i = 0; i < nR; i++) {
  109.             residualsWeights[i] = weightMatrix.getEntry(i, i);
  110.         }

  111.         final double[] currentPoint = getStartPoint();
  112.         final int nC = currentPoint.length;

  113.         // iterate until convergence is reached
  114.         PointVectorValuePair current = null;
  115.         int iter = 0;
  116.         for (boolean converged = false; !converged;) {
  117.             ++iter;

  118.             // evaluate the objective function and its jacobian
  119.             PointVectorValuePair previous = current;
  120.             // Value of the objective function at "currentPoint".
  121.             final double[] currentObjective = computeObjectiveValue(currentPoint);
  122.             final double[] currentResiduals = computeResiduals(currentObjective);
  123.             final RealMatrix weightedJacobian = computeWeightedJacobian(currentPoint);
  124.             current = new PointVectorValuePair(currentPoint, currentObjective);

  125.             // build the linear problem
  126.             final double[]   b = new double[nC];
  127.             final double[][] a = new double[nC][nC];
  128.             for (int i = 0; i < nR; ++i) {

  129.                 final double[] grad   = weightedJacobian.getRow(i);
  130.                 final double weight   = residualsWeights[i];
  131.                 final double residual = currentResiduals[i];

  132.                 // compute the normal equation
  133.                 final double wr = weight * residual;
  134.                 for (int j = 0; j < nC; ++j) {
  135.                     b[j] += wr * grad[j];
  136.                 }

  137.                 // build the contribution matrix for measurement i
  138.                 for (int k = 0; k < nC; ++k) {
  139.                     double[] ak = a[k];
  140.                     double wgk = weight * grad[k];
  141.                     for (int l = 0; l < nC; ++l) {
  142.                         ak[l] += wgk * grad[l];
  143.                     }
  144.                 }
  145.             }

  146.             try {
  147.                 // solve the linearized least squares problem
  148.                 RealMatrix mA = new BlockRealMatrix(a);
  149.                 DecompositionSolver solver = useLU ?
  150.                         new LUDecomposition(mA).getSolver() :
  151.                         new QRDecomposition(mA).getSolver();
  152.                 final double[] dX = solver.solve(new ArrayRealVector(b, false)).toArray();
  153.                 // update the estimated parameters
  154.                 for (int i = 0; i < nC; ++i) {
  155.                     currentPoint[i] += dX[i];
  156.                 }
  157.             } catch (SingularMatrixException e) {
  158.                 throw new ConvergenceException(LocalizedFormats.UNABLE_TO_SOLVE_SINGULAR_PROBLEM);
  159.             }

  160.             // Check convergence.
  161.             if (previous != null) {
  162.                 converged = checker.converged(iter, previous, current);
  163.                 if (converged) {
  164.                     cost = computeCost(currentResiduals);
  165.                     // Update (deprecated) "point" field.
  166.                     point = current.getPoint();
  167.                     return current;
  168.                 }
  169.             }
  170.         }
  171.         // Must never happen.
  172.         throw new MathInternalError();
  173.     }
  174. }