GaussianCurveFitter.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.math3.fitting;

  18. import java.util.ArrayList;
  19. import java.util.Collection;
  20. import java.util.Collections;
  21. import java.util.Comparator;
  22. import java.util.List;

  23. import org.apache.commons.math3.analysis.function.Gaussian;
  24. import org.apache.commons.math3.exception.NotStrictlyPositiveException;
  25. import org.apache.commons.math3.exception.NullArgumentException;
  26. import org.apache.commons.math3.exception.NumberIsTooSmallException;
  27. import org.apache.commons.math3.exception.OutOfRangeException;
  28. import org.apache.commons.math3.exception.ZeroException;
  29. import org.apache.commons.math3.exception.util.LocalizedFormats;
  30. import org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder;
  31. import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem;
  32. import org.apache.commons.math3.linear.DiagonalMatrix;
  33. import org.apache.commons.math3.util.FastMath;

  34. /**
  35.  * Fits points to a {@link
  36.  * org.apache.commons.math3.analysis.function.Gaussian.Parametric Gaussian}
  37.  * function.
  38.  * <br/>
  39.  * The {@link #withStartPoint(double[]) initial guess values} must be passed
  40.  * in the following order:
  41.  * <ul>
  42.  *  <li>Normalization</li>
  43.  *  <li>Mean</li>
  44.  *  <li>Sigma</li>
  45.  * </ul>
  46.  * The optimal values will be returned in the same order.
  47.  *
  48.  * <p>
  49.  * Usage example:
  50.  * <pre>
  51.  *   WeightedObservedPoints obs = new WeightedObservedPoints();
  52.  *   obs.add(4.0254623,  531026.0);
  53.  *   obs.add(4.03128248, 984167.0);
  54.  *   obs.add(4.03839603, 1887233.0);
  55.  *   obs.add(4.04421621, 2687152.0);
  56.  *   obs.add(4.05132976, 3461228.0);
  57.  *   obs.add(4.05326982, 3580526.0);
  58.  *   obs.add(4.05779662, 3439750.0);
  59.  *   obs.add(4.0636168,  2877648.0);
  60.  *   obs.add(4.06943698, 2175960.0);
  61.  *   obs.add(4.07525716, 1447024.0);
  62.  *   obs.add(4.08237071, 717104.0);
  63.  *   obs.add(4.08366408, 620014.0);
  64.  *   double[] parameters = GaussianCurveFitter.create().fit(obs.toList());
  65.  * </pre>
  66.  *
  67.  * @since 3.3
  68.  */
  69. public class GaussianCurveFitter extends AbstractCurveFitter {
  70.     /** Parametric function to be fitted. */
  71.     private static final Gaussian.Parametric FUNCTION = new Gaussian.Parametric() {
  72.             @Override
  73.             public double value(double x, double ... p) {
  74.                 double v = Double.POSITIVE_INFINITY;
  75.                 try {
  76.                     v = super.value(x, p);
  77.                 } catch (NotStrictlyPositiveException e) { // NOPMD
  78.                     // Do nothing.
  79.                 }
  80.                 return v;
  81.             }

  82.             @Override
  83.             public double[] gradient(double x, double ... p) {
  84.                 double[] v = { Double.POSITIVE_INFINITY,
  85.                                Double.POSITIVE_INFINITY,
  86.                                Double.POSITIVE_INFINITY };
  87.                 try {
  88.                     v = super.gradient(x, p);
  89.                 } catch (NotStrictlyPositiveException e) { // NOPMD
  90.                     // Do nothing.
  91.                 }
  92.                 return v;
  93.             }
  94.         };
  95.     /** Initial guess. */
  96.     private final double[] initialGuess;
  97.     /** Maximum number of iterations of the optimization algorithm. */
  98.     private final int maxIter;

  99.     /**
  100.      * Contructor used by the factory methods.
  101.      *
  102.      * @param initialGuess Initial guess. If set to {@code null}, the initial guess
  103.      * will be estimated using the {@link ParameterGuesser}.
  104.      * @param maxIter Maximum number of iterations of the optimization algorithm.
  105.      */
  106.     private GaussianCurveFitter(double[] initialGuess,
  107.                                 int maxIter) {
  108.         this.initialGuess = initialGuess;
  109.         this.maxIter = maxIter;
  110.     }

  111.     /**
  112.      * Creates a default curve fitter.
  113.      * The initial guess for the parameters will be {@link ParameterGuesser}
  114.      * computed automatically, and the maximum number of iterations of the
  115.      * optimization algorithm is set to {@link Integer#MAX_VALUE}.
  116.      *
  117.      * @return a curve fitter.
  118.      *
  119.      * @see #withStartPoint(double[])
  120.      * @see #withMaxIterations(int)
  121.      */
  122.     public static GaussianCurveFitter create() {
  123.         return new GaussianCurveFitter(null, Integer.MAX_VALUE);
  124.     }

  125.     /**
  126.      * Configure the start point (initial guess).
  127.      * @param newStart new start point (initial guess)
  128.      * @return a new instance.
  129.      */
  130.     public GaussianCurveFitter withStartPoint(double[] newStart) {
  131.         return new GaussianCurveFitter(newStart.clone(),
  132.                                        maxIter);
  133.     }

  134.     /**
  135.      * Configure the maximum number of iterations.
  136.      * @param newMaxIter maximum number of iterations
  137.      * @return a new instance.
  138.      */
  139.     public GaussianCurveFitter withMaxIterations(int newMaxIter) {
  140.         return new GaussianCurveFitter(initialGuess,
  141.                                        newMaxIter);
  142.     }

  143.     /** {@inheritDoc} */
  144.     @Override
  145.     protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) {

  146.         // Prepare least-squares problem.
  147.         final int len = observations.size();
  148.         final double[] target  = new double[len];
  149.         final double[] weights = new double[len];

  150.         int i = 0;
  151.         for (WeightedObservedPoint obs : observations) {
  152.             target[i]  = obs.getY();
  153.             weights[i] = obs.getWeight();
  154.             ++i;
  155.         }

  156.         final AbstractCurveFitter.TheoreticalValuesFunction model =
  157.                 new AbstractCurveFitter.TheoreticalValuesFunction(FUNCTION, observations);

  158.         final double[] startPoint = initialGuess != null ?
  159.             initialGuess :
  160.             // Compute estimation.
  161.             new ParameterGuesser(observations).guess();

  162.         // Return a new least squares problem set up to fit a Gaussian curve to the
  163.         // observed points.
  164.         return new LeastSquaresBuilder().
  165.                 maxEvaluations(Integer.MAX_VALUE).
  166.                 maxIterations(maxIter).
  167.                 start(startPoint).
  168.                 target(target).
  169.                 weight(new DiagonalMatrix(weights)).
  170.                 model(model.getModelFunction(), model.getModelFunctionJacobian()).
  171.                 build();

  172.     }

  173.     /**
  174.      * Guesses the parameters {@code norm}, {@code mean}, and {@code sigma}
  175.      * of a {@link org.apache.commons.math3.analysis.function.Gaussian.Parametric}
  176.      * based on the specified observed points.
  177.      */
  178.     public static class ParameterGuesser {
  179.         /** Normalization factor. */
  180.         private final double norm;
  181.         /** Mean. */
  182.         private final double mean;
  183.         /** Standard deviation. */
  184.         private final double sigma;

  185.         /**
  186.          * Constructs instance with the specified observed points.
  187.          *
  188.          * @param observations Observed points from which to guess the
  189.          * parameters of the Gaussian.
  190.          * @throws NullArgumentException if {@code observations} is
  191.          * {@code null}.
  192.          * @throws NumberIsTooSmallException if there are less than 3
  193.          * observations.
  194.          */
  195.         public ParameterGuesser(Collection<WeightedObservedPoint> observations) {
  196.             if (observations == null) {
  197.                 throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY);
  198.             }
  199.             if (observations.size() < 3) {
  200.                 throw new NumberIsTooSmallException(observations.size(), 3, true);
  201.             }

  202.             final List<WeightedObservedPoint> sorted = sortObservations(observations);
  203.             final double[] params = basicGuess(sorted.toArray(new WeightedObservedPoint[0]));

  204.             norm = params[0];
  205.             mean = params[1];
  206.             sigma = params[2];
  207.         }

  208.         /**
  209.          * Gets an estimation of the parameters.
  210.          *
  211.          * @return the guessed parameters, in the following order:
  212.          * <ul>
  213.          *  <li>Normalization factor</li>
  214.          *  <li>Mean</li>
  215.          *  <li>Standard deviation</li>
  216.          * </ul>
  217.          */
  218.         public double[] guess() {
  219.             return new double[] { norm, mean, sigma };
  220.         }

  221.         /**
  222.          * Sort the observations.
  223.          *
  224.          * @param unsorted Input observations.
  225.          * @return the input observations, sorted.
  226.          */
  227.         private List<WeightedObservedPoint> sortObservations(Collection<WeightedObservedPoint> unsorted) {
  228.             final List<WeightedObservedPoint> observations = new ArrayList<WeightedObservedPoint>(unsorted);

  229.             final Comparator<WeightedObservedPoint> cmp = new Comparator<WeightedObservedPoint>() {
  230.                 public int compare(WeightedObservedPoint p1,
  231.                                    WeightedObservedPoint p2) {
  232.                     if (p1 == null && p2 == null) {
  233.                         return 0;
  234.                     }
  235.                     if (p1 == null) {
  236.                         return -1;
  237.                     }
  238.                     if (p2 == null) {
  239.                         return 1;
  240.                     }
  241.                     if (p1.getX() < p2.getX()) {
  242.                         return -1;
  243.                     }
  244.                     if (p1.getX() > p2.getX()) {
  245.                         return 1;
  246.                     }
  247.                     if (p1.getY() < p2.getY()) {
  248.                         return -1;
  249.                     }
  250.                     if (p1.getY() > p2.getY()) {
  251.                         return 1;
  252.                     }
  253.                     if (p1.getWeight() < p2.getWeight()) {
  254.                         return -1;
  255.                     }
  256.                     if (p1.getWeight() > p2.getWeight()) {
  257.                         return 1;
  258.                     }
  259.                     return 0;
  260.                 }
  261.             };

  262.             Collections.sort(observations, cmp);
  263.             return observations;
  264.         }

  265.         /**
  266.          * Guesses the parameters based on the specified observed points.
  267.          *
  268.          * @param points Observed points, sorted.
  269.          * @return the guessed parameters (normalization factor, mean and
  270.          * sigma).
  271.          */
  272.         private double[] basicGuess(WeightedObservedPoint[] points) {
  273.             final int maxYIdx = findMaxY(points);
  274.             final double n = points[maxYIdx].getY();
  275.             final double m = points[maxYIdx].getX();

  276.             double fwhmApprox;
  277.             try {
  278.                 final double halfY = n + ((m - n) / 2);
  279.                 final double fwhmX1 = interpolateXAtY(points, maxYIdx, -1, halfY);
  280.                 final double fwhmX2 = interpolateXAtY(points, maxYIdx, 1, halfY);
  281.                 fwhmApprox = fwhmX2 - fwhmX1;
  282.             } catch (OutOfRangeException e) {
  283.                 // TODO: Exceptions should not be used for flow control.
  284.                 fwhmApprox = points[points.length - 1].getX() - points[0].getX();
  285.             }
  286.             final double s = fwhmApprox / (2 * FastMath.sqrt(2 * FastMath.log(2)));

  287.             return new double[] { n, m, s };
  288.         }

  289.         /**
  290.          * Finds index of point in specified points with the largest Y.
  291.          *
  292.          * @param points Points to search.
  293.          * @return the index in specified points array.
  294.          */
  295.         private int findMaxY(WeightedObservedPoint[] points) {
  296.             int maxYIdx = 0;
  297.             for (int i = 1; i < points.length; i++) {
  298.                 if (points[i].getY() > points[maxYIdx].getY()) {
  299.                     maxYIdx = i;
  300.                 }
  301.             }
  302.             return maxYIdx;
  303.         }

  304.         /**
  305.          * Interpolates using the specified points to determine X at the
  306.          * specified Y.
  307.          *
  308.          * @param points Points to use for interpolation.
  309.          * @param startIdx Index within points from which to start the search for
  310.          * interpolation bounds points.
  311.          * @param idxStep Index step for searching interpolation bounds points.
  312.          * @param y Y value for which X should be determined.
  313.          * @return the value of X for the specified Y.
  314.          * @throws ZeroException if {@code idxStep} is 0.
  315.          * @throws OutOfRangeException if specified {@code y} is not within the
  316.          * range of the specified {@code points}.
  317.          */
  318.         private double interpolateXAtY(WeightedObservedPoint[] points,
  319.                                        int startIdx,
  320.                                        int idxStep,
  321.                                        double y)
  322.             throws OutOfRangeException {
  323.             if (idxStep == 0) {
  324.                 throw new ZeroException();
  325.             }
  326.             final WeightedObservedPoint[] twoPoints
  327.                 = getInterpolationPointsForY(points, startIdx, idxStep, y);
  328.             final WeightedObservedPoint p1 = twoPoints[0];
  329.             final WeightedObservedPoint p2 = twoPoints[1];
  330.             if (p1.getY() == y) {
  331.                 return p1.getX();
  332.             }
  333.             if (p2.getY() == y) {
  334.                 return p2.getX();
  335.             }
  336.             return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) /
  337.                                 (p2.getY() - p1.getY()));
  338.         }

  339.         /**
  340.          * Gets the two bounding interpolation points from the specified points
  341.          * suitable for determining X at the specified Y.
  342.          *
  343.          * @param points Points to use for interpolation.
  344.          * @param startIdx Index within points from which to start search for
  345.          * interpolation bounds points.
  346.          * @param idxStep Index step for search for interpolation bounds points.
  347.          * @param y Y value for which X should be determined.
  348.          * @return the array containing two points suitable for determining X at
  349.          * the specified Y.
  350.          * @throws ZeroException if {@code idxStep} is 0.
  351.          * @throws OutOfRangeException if specified {@code y} is not within the
  352.          * range of the specified {@code points}.
  353.          */
  354.         private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points,
  355.                                                                    int startIdx,
  356.                                                                    int idxStep,
  357.                                                                    double y)
  358.             throws OutOfRangeException {
  359.             if (idxStep == 0) {
  360.                 throw new ZeroException();
  361.             }
  362.             for (int i = startIdx;
  363.                  idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length;
  364.                  i += idxStep) {
  365.                 final WeightedObservedPoint p1 = points[i];
  366.                 final WeightedObservedPoint p2 = points[i + idxStep];
  367.                 if (isBetween(y, p1.getY(), p2.getY())) {
  368.                     if (idxStep < 0) {
  369.                         return new WeightedObservedPoint[] { p2, p1 };
  370.                     } else {
  371.                         return new WeightedObservedPoint[] { p1, p2 };
  372.                     }
  373.                 }
  374.             }

  375.             // Boundaries are replaced by dummy values because the raised
  376.             // exception is caught and the message never displayed.
  377.             // TODO: Exceptions should not be used for flow control.
  378.             throw new OutOfRangeException(y,
  379.                                           Double.NEGATIVE_INFINITY,
  380.                                           Double.POSITIVE_INFINITY);
  381.         }

  382.         /**
  383.          * Determines whether a value is between two other values.
  384.          *
  385.          * @param value Value to test whether it is between {@code boundary1}
  386.          * and {@code boundary2}.
  387.          * @param boundary1 One end of the range.
  388.          * @param boundary2 Other end of the range.
  389.          * @return {@code true} if {@code value} is between {@code boundary1} and
  390.          * {@code boundary2} (inclusive), {@code false} otherwise.
  391.          */
  392.         private boolean isBetween(double value,
  393.                                   double boundary1,
  394.                                   double boundary2) {
  395.             return (value >= boundary1 && value <= boundary2) ||
  396.                 (value >= boundary2 && value <= boundary1);
  397.         }
  398.     }
  399. }