CurveFitter.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math3.optimization.fitting;

  18. import java.util.ArrayList;
  19. import java.util.List;

  20. import org.apache.commons.math3.analysis.DifferentiableMultivariateVectorFunction;
  21. import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
  22. import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
  23. import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
  24. import org.apache.commons.math3.analysis.differentiation.MultivariateDifferentiableVectorFunction;
  25. import org.apache.commons.math3.optimization.DifferentiableMultivariateVectorOptimizer;
  26. import org.apache.commons.math3.optimization.MultivariateDifferentiableVectorOptimizer;
  27. import org.apache.commons.math3.optimization.PointVectorValuePair;

  28. /** Fitter for parametric univariate real functions y = f(x).
  29.  * <br/>
  30.  * When a univariate real function y = f(x) does depend on some
  31.  * unknown parameters p<sub>0</sub>, p<sub>1</sub> ... p<sub>n-1</sub>,
  32.  * this class can be used to find these parameters. It does this
  33.  * by <em>fitting</em> the curve so it remains very close to a set of
  34.  * observed points (x<sub>0</sub>, y<sub>0</sub>), (x<sub>1</sub>,
  35.  * y<sub>1</sub>) ... (x<sub>k-1</sub>, y<sub>k-1</sub>). This fitting
  36.  * is done by finding the parameters values that minimizes the objective
  37.  * function &sum;(y<sub>i</sub>-f(x<sub>i</sub>))<sup>2</sup>. This is
  38.  * really a least squares problem.
  39.  *
  40.  * @param <T> Function to use for the fit.
  41.  *
  42.  * @deprecated As of 3.1 (to be removed in 4.0).
  43.  * @since 2.0
  44.  */
  45. @Deprecated
  46. public class CurveFitter<T extends ParametricUnivariateFunction> {

  47.     /** Optimizer to use for the fitting.
  48.      * @deprecated as of 3.1 replaced by {@link #optimizer}
  49.      */
  50.     @Deprecated
  51.     private final DifferentiableMultivariateVectorOptimizer oldOptimizer;

  52.     /** Optimizer to use for the fitting. */
  53.     private final MultivariateDifferentiableVectorOptimizer optimizer;

  54.     /** Observed points. */
  55.     private final List<WeightedObservedPoint> observations;

  56.     /** Simple constructor.
  57.      * @param optimizer optimizer to use for the fitting
  58.      * @deprecated as of 3.1 replaced by {@link #CurveFitter(MultivariateDifferentiableVectorOptimizer)}
  59.      */
  60.     @Deprecated
  61.     public CurveFitter(final DifferentiableMultivariateVectorOptimizer optimizer) {
  62.         this.oldOptimizer = optimizer;
  63.         this.optimizer    = null;
  64.         observations      = new ArrayList<WeightedObservedPoint>();
  65.     }

  66.     /** Simple constructor.
  67.      * @param optimizer optimizer to use for the fitting
  68.      * @since 3.1
  69.      */
  70.     public CurveFitter(final MultivariateDifferentiableVectorOptimizer optimizer) {
  71.         this.oldOptimizer = null;
  72.         this.optimizer    = optimizer;
  73.         observations      = new ArrayList<WeightedObservedPoint>();
  74.     }

  75.     /** Add an observed (x,y) point to the sample with unit weight.
  76.      * <p>Calling this method is equivalent to call
  77.      * {@code addObservedPoint(1.0, x, y)}.</p>
  78.      * @param x abscissa of the point
  79.      * @param y observed value of the point at x, after fitting we should
  80.      * have f(x) as close as possible to this value
  81.      * @see #addObservedPoint(double, double, double)
  82.      * @see #addObservedPoint(WeightedObservedPoint)
  83.      * @see #getObservations()
  84.      */
  85.     public void addObservedPoint(double x, double y) {
  86.         addObservedPoint(1.0, x, y);
  87.     }

  88.     /** Add an observed weighted (x,y) point to the sample.
  89.      * @param weight weight of the observed point in the fit
  90.      * @param x abscissa of the point
  91.      * @param y observed value of the point at x, after fitting we should
  92.      * have f(x) as close as possible to this value
  93.      * @see #addObservedPoint(double, double)
  94.      * @see #addObservedPoint(WeightedObservedPoint)
  95.      * @see #getObservations()
  96.      */
  97.     public void addObservedPoint(double weight, double x, double y) {
  98.         observations.add(new WeightedObservedPoint(weight, x, y));
  99.     }

  100.     /** Add an observed weighted (x,y) point to the sample.
  101.      * @param observed observed point to add
  102.      * @see #addObservedPoint(double, double)
  103.      * @see #addObservedPoint(double, double, double)
  104.      * @see #getObservations()
  105.      */
  106.     public void addObservedPoint(WeightedObservedPoint observed) {
  107.         observations.add(observed);
  108.     }

  109.     /** Get the observed points.
  110.      * @return observed points
  111.      * @see #addObservedPoint(double, double)
  112.      * @see #addObservedPoint(double, double, double)
  113.      * @see #addObservedPoint(WeightedObservedPoint)
  114.      */
  115.     public WeightedObservedPoint[] getObservations() {
  116.         return observations.toArray(new WeightedObservedPoint[observations.size()]);
  117.     }

  118.     /**
  119.      * Remove all observations.
  120.      */
  121.     public void clearObservations() {
  122.         observations.clear();
  123.     }

  124.     /**
  125.      * Fit a curve.
  126.      * This method compute the coefficients of the curve that best
  127.      * fit the sample of observed points previously given through calls
  128.      * to the {@link #addObservedPoint(WeightedObservedPoint)
  129.      * addObservedPoint} method.
  130.      *
  131.      * @param f parametric function to fit.
  132.      * @param initialGuess first guess of the function parameters.
  133.      * @return the fitted parameters.
  134.      * @throws org.apache.commons.math3.exception.DimensionMismatchException
  135.      * if the start point dimension is wrong.
  136.      */
  137.     public double[] fit(T f, final double[] initialGuess) {
  138.         return fit(Integer.MAX_VALUE, f, initialGuess);
  139.     }

  140.     /**
  141.      * Fit a curve.
  142.      * This method compute the coefficients of the curve that best
  143.      * fit the sample of observed points previously given through calls
  144.      * to the {@link #addObservedPoint(WeightedObservedPoint)
  145.      * addObservedPoint} method.
  146.      *
  147.      * @param f parametric function to fit.
  148.      * @param initialGuess first guess of the function parameters.
  149.      * @param maxEval Maximum number of function evaluations.
  150.      * @return the fitted parameters.
  151.      * @throws org.apache.commons.math3.exception.TooManyEvaluationsException
  152.      * if the number of allowed evaluations is exceeded.
  153.      * @throws org.apache.commons.math3.exception.DimensionMismatchException
  154.      * if the start point dimension is wrong.
  155.      * @since 3.0
  156.      */
  157.     public double[] fit(int maxEval, T f,
  158.                         final double[] initialGuess) {
  159.         // prepare least squares problem
  160.         double[] target  = new double[observations.size()];
  161.         double[] weights = new double[observations.size()];
  162.         int i = 0;
  163.         for (WeightedObservedPoint point : observations) {
  164.             target[i]  = point.getY();
  165.             weights[i] = point.getWeight();
  166.             ++i;
  167.         }

  168.         // perform the fit
  169.         final PointVectorValuePair optimum;
  170.         if (optimizer == null) {
  171.             // to be removed in 4.0
  172.             optimum = oldOptimizer.optimize(maxEval, new OldTheoreticalValuesFunction(f),
  173.                                             target, weights, initialGuess);
  174.         } else {
  175.             optimum = optimizer.optimize(maxEval, new TheoreticalValuesFunction(f),
  176.                                          target, weights, initialGuess);
  177.         }

  178.         // extract the coefficients
  179.         return optimum.getPointRef();
  180.     }

  181.     /** Vectorial function computing function theoretical values. */
  182.     @Deprecated
  183.     private class OldTheoreticalValuesFunction
  184.         implements DifferentiableMultivariateVectorFunction {
  185.         /** Function to fit. */
  186.         private final ParametricUnivariateFunction f;

  187.         /** Simple constructor.
  188.          * @param f function to fit.
  189.          */
  190.         public OldTheoreticalValuesFunction(final ParametricUnivariateFunction f) {
  191.             this.f = f;
  192.         }

  193.         /** {@inheritDoc} */
  194.         public MultivariateMatrixFunction jacobian() {
  195.             return new MultivariateMatrixFunction() {
  196.                 public double[][] value(double[] point) {
  197.                     final double[][] jacobian = new double[observations.size()][];

  198.                     int i = 0;
  199.                     for (WeightedObservedPoint observed : observations) {
  200.                         jacobian[i++] = f.gradient(observed.getX(), point);
  201.                     }

  202.                     return jacobian;
  203.                 }
  204.             };
  205.         }

  206.         /** {@inheritDoc} */
  207.         public double[] value(double[] point) {
  208.             // compute the residuals
  209.             final double[] values = new double[observations.size()];
  210.             int i = 0;
  211.             for (WeightedObservedPoint observed : observations) {
  212.                 values[i++] = f.value(observed.getX(), point);
  213.             }

  214.             return values;
  215.         }
  216.     }

  217.     /** Vectorial function computing function theoretical values. */
  218.     private class TheoreticalValuesFunction implements MultivariateDifferentiableVectorFunction {

  219.         /** Function to fit. */
  220.         private final ParametricUnivariateFunction f;

  221.         /** Simple constructor.
  222.          * @param f function to fit.
  223.          */
  224.         public TheoreticalValuesFunction(final ParametricUnivariateFunction f) {
  225.             this.f = f;
  226.         }

  227.         /** {@inheritDoc} */
  228.         public double[] value(double[] point) {
  229.             // compute the residuals
  230.             final double[] values = new double[observations.size()];
  231.             int i = 0;
  232.             for (WeightedObservedPoint observed : observations) {
  233.                 values[i++] = f.value(observed.getX(), point);
  234.             }

  235.             return values;
  236.         }

  237.         /** {@inheritDoc} */
  238.         public DerivativeStructure[] value(DerivativeStructure[] point) {

  239.             // extract parameters
  240.             final double[] parameters = new double[point.length];
  241.             for (int k = 0; k < point.length; ++k) {
  242.                 parameters[k] = point[k].getValue();
  243.             }

  244.             // compute the residuals
  245.             final DerivativeStructure[] values = new DerivativeStructure[observations.size()];
  246.             int i = 0;
  247.             for (WeightedObservedPoint observed : observations) {

  248.                 // build the DerivativeStructure by adding first the value as a constant
  249.                 // and then adding derivatives
  250.                 DerivativeStructure vi = new DerivativeStructure(point.length, 1, f.value(observed.getX(), parameters));
  251.                 for (int k = 0; k < point.length; ++k) {
  252.                     vi = vi.add(new DerivativeStructure(point.length, 1, k, 0.0));
  253.                 }

  254.                 values[i++] = vi;

  255.             }

  256.             return values;
  257.         }

  258.     }

  259. }