SimpleCurveFitter.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.math3.fitting;

  18. import java.util.Collection;

  19. import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
  20. import org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder;
  21. import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem;
  22. import org.apache.commons.math3.linear.DiagonalMatrix;

  23. /**
  24.  * Fits points to a user-defined {@link ParametricUnivariateFunction function}.
  25.  *
  26.  * @since 3.4
  27.  */
  28. public class SimpleCurveFitter extends AbstractCurveFitter {
  29.     /** Function to fit. */
  30.     private final ParametricUnivariateFunction function;
  31.     /** Initial guess for the parameters. */
  32.     private final double[] initialGuess;
  33.     /** Maximum number of iterations of the optimization algorithm. */
  34.     private final int maxIter;

  35.     /**
  36.      * Contructor used by the factory methods.
  37.      *
  38.      * @param function Function to fit.
  39.      * @param initialGuess Initial guess. Cannot be {@code null}. Its length must
  40.      * be consistent with the number of parameters of the {@code function} to fit.
  41.      * @param maxIter Maximum number of iterations of the optimization algorithm.
  42.      */
  43.     private SimpleCurveFitter(ParametricUnivariateFunction function,
  44.                               double[] initialGuess,
  45.                               int maxIter) {
  46.         this.function = function;
  47.         this.initialGuess = initialGuess;
  48.         this.maxIter = maxIter;
  49.     }

  50.     /**
  51.      * Creates a curve fitter.
  52.      * The maximum number of iterations of the optimization algorithm is set
  53.      * to {@link Integer#MAX_VALUE}.
  54.      *
  55.      * @param f Function to fit.
  56.      * @param start Initial guess for the parameters.  Cannot be {@code null}.
  57.      * Its length must be consistent with the number of parameters of the
  58.      * function to fit.
  59.      * @return a curve fitter.
  60.      *
  61.      * @see #withStartPoint(double[])
  62.      * @see #withMaxIterations(int)
  63.      */
  64.     public static SimpleCurveFitter create(ParametricUnivariateFunction f,
  65.                                            double[] start) {
  66.         return new SimpleCurveFitter(f, start, Integer.MAX_VALUE);
  67.     }

  68.     /**
  69.      * Configure the start point (initial guess).
  70.      * @param newStart new start point (initial guess)
  71.      * @return a new instance.
  72.      */
  73.     public SimpleCurveFitter withStartPoint(double[] newStart) {
  74.         return new SimpleCurveFitter(function,
  75.                                      newStart.clone(),
  76.                                      maxIter);
  77.     }

  78.     /**
  79.      * Configure the maximum number of iterations.
  80.      * @param newMaxIter maximum number of iterations
  81.      * @return a new instance.
  82.      */
  83.     public SimpleCurveFitter withMaxIterations(int newMaxIter) {
  84.         return new SimpleCurveFitter(function,
  85.                                      initialGuess,
  86.                                      newMaxIter);
  87.     }

  88.     /** {@inheritDoc} */
  89.     @Override
  90.     protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) {
  91.         // Prepare least-squares problem.
  92.         final int len = observations.size();
  93.         final double[] target  = new double[len];
  94.         final double[] weights = new double[len];

  95.         int count = 0;
  96.         for (WeightedObservedPoint obs : observations) {
  97.             target[count]  = obs.getY();
  98.             weights[count] = obs.getWeight();
  99.             ++count;
  100.         }

  101.         final AbstractCurveFitter.TheoreticalValuesFunction model
  102.             = new AbstractCurveFitter.TheoreticalValuesFunction(function,
  103.                                                                 observations);

  104.         // Create an optimizer for fitting the curve to the observed points.
  105.         return new LeastSquaresBuilder().
  106.                 maxEvaluations(Integer.MAX_VALUE).
  107.                 maxIterations(maxIter).
  108.                 start(initialGuess).
  109.                 target(target).
  110.                 weight(new DiagonalMatrix(weights)).
  111.                 model(model.getModelFunction(), model.getModelFunctionJacobian()).
  112.                 build();
  113.     }
  114. }