CurveFitter.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.math3.fitting;

  18. import java.util.ArrayList;
  19. import java.util.List;
  20. import org.apache.commons.math3.analysis.MultivariateVectorFunction;
  21. import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
  22. import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
  23. import org.apache.commons.math3.optim.MaxEval;
  24. import org.apache.commons.math3.optim.InitialGuess;
  25. import org.apache.commons.math3.optim.PointVectorValuePair;
  26. import org.apache.commons.math3.optim.nonlinear.vector.MultivariateVectorOptimizer;
  27. import org.apache.commons.math3.optim.nonlinear.vector.ModelFunction;
  28. import org.apache.commons.math3.optim.nonlinear.vector.ModelFunctionJacobian;
  29. import org.apache.commons.math3.optim.nonlinear.vector.Target;
  30. import org.apache.commons.math3.optim.nonlinear.vector.Weight;

  31. /**
  32.  * Fitter for parametric univariate real functions y = f(x).
  33.  * <br/>
  34.  * When a univariate real function y = f(x) does depend on some
  35.  * unknown parameters p<sub>0</sub>, p<sub>1</sub> ... p<sub>n-1</sub>,
  36.  * this class can be used to find these parameters. It does this
  37.  * by <em>fitting</em> the curve so it remains very close to a set of
  38.  * observed points (x<sub>0</sub>, y<sub>0</sub>), (x<sub>1</sub>,
  39.  * y<sub>1</sub>) ... (x<sub>k-1</sub>, y<sub>k-1</sub>). This fitting
  40.  * is done by finding the parameters values that minimizes the objective
  41.  * function &sum;(y<sub>i</sub>-f(x<sub>i</sub>))<sup>2</sup>. This is
  42.  * really a least squares problem.
  43.  *
  44.  * @param <T> Function to use for the fit.
  45.  *
  46.  * @since 2.0
  47.  * @deprecated As of 3.3. Please use {@link AbstractCurveFitter} and
  48.  * {@link WeightedObservedPoints} instead.
  49.  */
  50. @Deprecated
  51. public class CurveFitter<T extends ParametricUnivariateFunction> {
  52.     /** Optimizer to use for the fitting. */
  53.     private final MultivariateVectorOptimizer optimizer;
  54.     /** Observed points. */
  55.     private final List<WeightedObservedPoint> observations;

  56.     /**
  57.      * Simple constructor.
  58.      *
  59.      * @param optimizer Optimizer to use for the fitting.
  60.      * @since 3.1
  61.      */
  62.     public CurveFitter(final MultivariateVectorOptimizer optimizer) {
  63.         this.optimizer = optimizer;
  64.         observations = new ArrayList<WeightedObservedPoint>();
  65.     }

  66.     /** Add an observed (x,y) point to the sample with unit weight.
  67.      * <p>Calling this method is equivalent to call
  68.      * {@code addObservedPoint(1.0, x, y)}.</p>
  69.      * @param x abscissa of the point
  70.      * @param y observed value of the point at x, after fitting we should
  71.      * have f(x) as close as possible to this value
  72.      * @see #addObservedPoint(double, double, double)
  73.      * @see #addObservedPoint(WeightedObservedPoint)
  74.      * @see #getObservations()
  75.      */
  76.     public void addObservedPoint(double x, double y) {
  77.         addObservedPoint(1.0, x, y);
  78.     }

  79.     /** Add an observed weighted (x,y) point to the sample.
  80.      * @param weight weight of the observed point in the fit
  81.      * @param x abscissa of the point
  82.      * @param y observed value of the point at x, after fitting we should
  83.      * have f(x) as close as possible to this value
  84.      * @see #addObservedPoint(double, double)
  85.      * @see #addObservedPoint(WeightedObservedPoint)
  86.      * @see #getObservations()
  87.      */
  88.     public void addObservedPoint(double weight, double x, double y) {
  89.         observations.add(new WeightedObservedPoint(weight, x, y));
  90.     }

  91.     /** Add an observed weighted (x,y) point to the sample.
  92.      * @param observed observed point to add
  93.      * @see #addObservedPoint(double, double)
  94.      * @see #addObservedPoint(double, double, double)
  95.      * @see #getObservations()
  96.      */
  97.     public void addObservedPoint(WeightedObservedPoint observed) {
  98.         observations.add(observed);
  99.     }

  100.     /** Get the observed points.
  101.      * @return observed points
  102.      * @see #addObservedPoint(double, double)
  103.      * @see #addObservedPoint(double, double, double)
  104.      * @see #addObservedPoint(WeightedObservedPoint)
  105.      */
  106.     public WeightedObservedPoint[] getObservations() {
  107.         return observations.toArray(new WeightedObservedPoint[observations.size()]);
  108.     }

  109.     /**
  110.      * Remove all observations.
  111.      */
  112.     public void clearObservations() {
  113.         observations.clear();
  114.     }

  115.     /**
  116.      * Fit a curve.
  117.      * This method compute the coefficients of the curve that best
  118.      * fit the sample of observed points previously given through calls
  119.      * to the {@link #addObservedPoint(WeightedObservedPoint)
  120.      * addObservedPoint} method.
  121.      *
  122.      * @param f parametric function to fit.
  123.      * @param initialGuess first guess of the function parameters.
  124.      * @return the fitted parameters.
  125.      * @throws org.apache.commons.math3.exception.DimensionMismatchException
  126.      * if the start point dimension is wrong.
  127.      */
  128.     public double[] fit(T f, final double[] initialGuess) {
  129.         return fit(Integer.MAX_VALUE, f, initialGuess);
  130.     }

  131.     /**
  132.      * Fit a curve.
  133.      * This method compute the coefficients of the curve that best
  134.      * fit the sample of observed points previously given through calls
  135.      * to the {@link #addObservedPoint(WeightedObservedPoint)
  136.      * addObservedPoint} method.
  137.      *
  138.      * @param f parametric function to fit.
  139.      * @param initialGuess first guess of the function parameters.
  140.      * @param maxEval Maximum number of function evaluations.
  141.      * @return the fitted parameters.
  142.      * @throws org.apache.commons.math3.exception.TooManyEvaluationsException
  143.      * if the number of allowed evaluations is exceeded.
  144.      * @throws org.apache.commons.math3.exception.DimensionMismatchException
  145.      * if the start point dimension is wrong.
  146.      * @since 3.0
  147.      */
  148.     public double[] fit(int maxEval, T f,
  149.                         final double[] initialGuess) {
  150.         // Prepare least squares problem.
  151.         double[] target  = new double[observations.size()];
  152.         double[] weights = new double[observations.size()];
  153.         int i = 0;
  154.         for (WeightedObservedPoint point : observations) {
  155.             target[i]  = point.getY();
  156.             weights[i] = point.getWeight();
  157.             ++i;
  158.         }

  159.         // Input to the optimizer: the model and its Jacobian.
  160.         final TheoreticalValuesFunction model = new TheoreticalValuesFunction(f);

  161.         // Perform the fit.
  162.         final PointVectorValuePair optimum
  163.             = optimizer.optimize(new MaxEval(maxEval),
  164.                                  model.getModelFunction(),
  165.                                  model.getModelFunctionJacobian(),
  166.                                  new Target(target),
  167.                                  new Weight(weights),
  168.                                  new InitialGuess(initialGuess));
  169.         // Extract the coefficients.
  170.         return optimum.getPointRef();
  171.     }

  172.     /** Vectorial function computing function theoretical values. */
  173.     private class TheoreticalValuesFunction {
  174.         /** Function to fit. */
  175.         private final ParametricUnivariateFunction f;

  176.         /**
  177.          * @param f function to fit.
  178.          */
  179.         public TheoreticalValuesFunction(final ParametricUnivariateFunction f) {
  180.             this.f = f;
  181.         }

  182.         /**
  183.          * @return the model function values.
  184.          */
  185.         public ModelFunction getModelFunction() {
  186.             return new ModelFunction(new MultivariateVectorFunction() {
  187.                     /** {@inheritDoc} */
  188.                     public double[] value(double[] point) {
  189.                         // compute the residuals
  190.                         final double[] values = new double[observations.size()];
  191.                         int i = 0;
  192.                         for (WeightedObservedPoint observed : observations) {
  193.                             values[i++] = f.value(observed.getX(), point);
  194.                         }

  195.                         return values;
  196.                     }
  197.                 });
  198.         }

  199.         /**
  200.          * @return the model function Jacobian.
  201.          */
  202.         public ModelFunctionJacobian getModelFunctionJacobian() {
  203.             return new ModelFunctionJacobian(new MultivariateMatrixFunction() {
  204.                     public double[][] value(double[] point) {
  205.                         final double[][] jacobian = new double[observations.size()][];
  206.                         int i = 0;
  207.                         for (WeightedObservedPoint observed : observations) {
  208.                             jacobian[i++] = f.gradient(observed.getX(), point);
  209.                         }
  210.                         return jacobian;
  211.                     }
  212.                 });
  213.         }
  214.     }
  215. }