AbstractCurveFitter.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.math3.fitting;

  18. import java.util.Collection;

  19. import org.apache.commons.math3.analysis.MultivariateVectorFunction;
  20. import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
  21. import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
  22. import org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer;
  23. import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem;
  24. import org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer;

  25. /**
  26.  * Base class that contains common code for fitting parametric univariate
  27.  * real functions <code>y = f(p<sub>i</sub>;x)</code>, where {@code x} is
  28.  * the independent variable and the <code>p<sub>i</sub></code> are the
  29.  * <em>parameters</em>.
  30.  * <br/>
  31.  * A fitter will find the optimal values of the parameters by
  32.  * <em>fitting</em> the curve so it remains very close to a set of
  33.  * {@code N} observed points <code>(x<sub>k</sub>, y<sub>k</sub>)</code>,
  34.  * {@code 0 <= k < N}.
  35.  * <br/>
  36.  * An algorithm usually performs the fit by finding the parameter
  37.  * values that minimizes the objective function
  38.  * <pre><code>
  39.  *  &sum;y<sub>k</sub> - f(x<sub>k</sub>)<sup>2</sup>,
  40.  * </code></pre>
  41.  * which is actually a least-squares problem.
  42.  * This class contains boilerplate code for calling the
  43.  * {@link #fit(Collection)} method for obtaining the parameters.
  44.  * The problem setup, such as the choice of optimization algorithm
  45.  * for fitting a specific function is delegated to subclasses.
  46.  *
  47.  * @since 3.3
  48.  */
  49. public abstract class AbstractCurveFitter {
  50.     /**
  51.      * Fits a curve.
  52.      * This method computes the coefficients of the curve that best
  53.      * fit the sample of observed points.
  54.      *
  55.      * @param points Observations.
  56.      * @return the fitted parameters.
  57.      */
  58.     public double[] fit(Collection<WeightedObservedPoint> points) {
  59.         // Perform the fit.
  60.         return getOptimizer().optimize(getProblem(points)).getPoint().toArray();
  61.     }

  62.     /**
  63.      * Creates an optimizer set up to fit the appropriate curve.
  64.      * <p>
  65.      * The default implementation uses a {@link LevenbergMarquardtOptimizer
  66.      * Levenberg-Marquardt} optimizer.
  67.      * </p>
  68.      * @return the optimizer to use for fitting the curve to the
  69.      * given {@code points}.
  70.      */
  71.     protected LeastSquaresOptimizer getOptimizer() {
  72.         return new LevenbergMarquardtOptimizer();
  73.     }

  74.     /**
  75.      * Creates a least squares problem corresponding to the appropriate curve.
  76.      *
  77.      * @param points Sample points.
  78.      * @return the least squares problem to use for fitting the curve to the
  79.      * given {@code points}.
  80.      */
  81.     protected abstract LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> points);

  82.     /**
  83.      * Vector function for computing function theoretical values.
  84.      */
  85.     protected static class TheoreticalValuesFunction {
  86.         /** Function to fit. */
  87.         private final ParametricUnivariateFunction f;
  88.         /** Observations. */
  89.         private final double[] points;

  90.         /**
  91.          * @param f function to fit.
  92.          * @param observations Observations.
  93.          */
  94.         public TheoreticalValuesFunction(final ParametricUnivariateFunction f,
  95.                                          final Collection<WeightedObservedPoint> observations) {
  96.             this.f = f;

  97.             final int len = observations.size();
  98.             this.points = new double[len];
  99.             int i = 0;
  100.             for (WeightedObservedPoint obs : observations) {
  101.                 this.points[i++] = obs.getX();
  102.             }
  103.         }

  104.         /**
  105.          * @return the model function values.
  106.          */
  107.         public MultivariateVectorFunction getModelFunction() {
  108.             return new MultivariateVectorFunction() {
  109.                 /** {@inheritDoc} */
  110.                 public double[] value(double[] p) {
  111.                     final int len = points.length;
  112.                     final double[] values = new double[len];
  113.                     for (int i = 0; i < len; i++) {
  114.                         values[i] = f.value(points[i], p);
  115.                     }

  116.                     return values;
  117.                 }
  118.             };
  119.         }

  120.         /**
  121.          * @return the model function Jacobian.
  122.          */
  123.         public MultivariateMatrixFunction getModelFunctionJacobian() {
  124.             return new MultivariateMatrixFunction() {
  125.                 public double[][] value(double[] p) {
  126.                     final int len = points.length;
  127.                     final double[][] jacobian = new double[len][];
  128.                     for (int i = 0; i < len; i++) {
  129.                         jacobian[i] = f.gradient(points[i], p);
  130.                     }
  131.                     return jacobian;
  132.                 }
  133.             };
  134.         }
  135.     }
  136. }