HarmonicCurveFitter.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.math3.fitting;

  18. import java.util.ArrayList;
  19. import java.util.Collection;
  20. import java.util.List;

  21. import org.apache.commons.math3.analysis.function.HarmonicOscillator;
  22. import org.apache.commons.math3.exception.MathIllegalStateException;
  23. import org.apache.commons.math3.exception.NumberIsTooSmallException;
  24. import org.apache.commons.math3.exception.ZeroException;
  25. import org.apache.commons.math3.exception.util.LocalizedFormats;
  26. import org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder;
  27. import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem;
  28. import org.apache.commons.math3.linear.DiagonalMatrix;
  29. import org.apache.commons.math3.util.FastMath;

  30. /**
  31.  * Fits points to a {@link
  32.  * org.apache.commons.math3.analysis.function.HarmonicOscillator.Parametric harmonic oscillator}
  33.  * function.
  34.  * <br/>
  35.  * The {@link #withStartPoint(double[]) initial guess values} must be passed
  36.  * in the following order:
  37.  * <ul>
  38.  *  <li>Amplitude</li>
  39.  *  <li>Angular frequency</li>
  40.  *  <li>phase</li>
  41.  * </ul>
  42.  * The optimal values will be returned in the same order.
  43.  *
  44.  * @since 3.3
  45.  */
  46. public class HarmonicCurveFitter extends AbstractCurveFitter {
  47.     /** Parametric function to be fitted. */
  48.     private static final HarmonicOscillator.Parametric FUNCTION = new HarmonicOscillator.Parametric();
  49.     /** Initial guess. */
  50.     private final double[] initialGuess;
  51.     /** Maximum number of iterations of the optimization algorithm. */
  52.     private final int maxIter;

  53.     /**
  54.      * Contructor used by the factory methods.
  55.      *
  56.      * @param initialGuess Initial guess. If set to {@code null}, the initial guess
  57.      * will be estimated using the {@link ParameterGuesser}.
  58.      * @param maxIter Maximum number of iterations of the optimization algorithm.
  59.      */
  60.     private HarmonicCurveFitter(double[] initialGuess,
  61.                                 int maxIter) {
  62.         this.initialGuess = initialGuess;
  63.         this.maxIter = maxIter;
  64.     }

  65.     /**
  66.      * Creates a default curve fitter.
  67.      * The initial guess for the parameters will be {@link ParameterGuesser}
  68.      * computed automatically, and the maximum number of iterations of the
  69.      * optimization algorithm is set to {@link Integer#MAX_VALUE}.
  70.      *
  71.      * @return a curve fitter.
  72.      *
  73.      * @see #withStartPoint(double[])
  74.      * @see #withMaxIterations(int)
  75.      */
  76.     public static HarmonicCurveFitter create() {
  77.         return new HarmonicCurveFitter(null, Integer.MAX_VALUE);
  78.     }

  79.     /**
  80.      * Configure the start point (initial guess).
  81.      * @param newStart new start point (initial guess)
  82.      * @return a new instance.
  83.      */
  84.     public HarmonicCurveFitter withStartPoint(double[] newStart) {
  85.         return new HarmonicCurveFitter(newStart.clone(),
  86.                                        maxIter);
  87.     }

  88.     /**
  89.      * Configure the maximum number of iterations.
  90.      * @param newMaxIter maximum number of iterations
  91.      * @return a new instance.
  92.      */
  93.     public HarmonicCurveFitter withMaxIterations(int newMaxIter) {
  94.         return new HarmonicCurveFitter(initialGuess,
  95.                                        newMaxIter);
  96.     }

  97.     /** {@inheritDoc} */
  98.     @Override
  99.     protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) {
  100.         // Prepare least-squares problem.
  101.         final int len = observations.size();
  102.         final double[] target  = new double[len];
  103.         final double[] weights = new double[len];

  104.         int i = 0;
  105.         for (WeightedObservedPoint obs : observations) {
  106.             target[i]  = obs.getY();
  107.             weights[i] = obs.getWeight();
  108.             ++i;
  109.         }

  110.         final AbstractCurveFitter.TheoreticalValuesFunction model
  111.             = new AbstractCurveFitter.TheoreticalValuesFunction(FUNCTION,
  112.                                                                 observations);

  113.         final double[] startPoint = initialGuess != null ?
  114.             initialGuess :
  115.             // Compute estimation.
  116.             new ParameterGuesser(observations).guess();

  117.         // Return a new optimizer set up to fit a Gaussian curve to the
  118.         // observed points.
  119.         return new LeastSquaresBuilder().
  120.                 maxEvaluations(Integer.MAX_VALUE).
  121.                 maxIterations(maxIter).
  122.                 start(startPoint).
  123.                 target(target).
  124.                 weight(new DiagonalMatrix(weights)).
  125.                 model(model.getModelFunction(), model.getModelFunctionJacobian()).
  126.                 build();

  127.     }

  128.     /**
  129.      * This class guesses harmonic coefficients from a sample.
  130.      * <p>The algorithm used to guess the coefficients is as follows:</p>
  131.      *
  132.      * <p>We know \( f(t) \) at some sampling points \( t_i \) and want
  133.      * to find \( a \), \( \omega \) and \( \phi \) such that
  134.      * \( f(t) = a \cos (\omega t + \phi) \).
  135.      * </p>
  136.      *
  137.      * <p>From the analytical expression, we can compute two primitives :
  138.      * \[
  139.      *     If2(t) = \int f^2 dt  = a^2 (t + S(t)) / 2
  140.      * \]
  141.      * \[
  142.      *     If'2(t) = \int f'^2 dt = a^2 \omega^2 (t - S(t)) / 2
  143.      * \]
  144.      * where \(S(t) = \frac{\sin(2 (\omega t + \phi))}{2\omega}\)
  145.      * </p>
  146.      *
  147.      * <p>We can remove \(S\) between these expressions :
  148.      * \[
  149.      *     If'2(t) = a^2 \omega^2 t - \omega^2 If2(t)
  150.      * \]
  151.      * </p>
  152.      *
  153.      * <p>The preceding expression shows that \(If'2 (t)\) is a linear
  154.      * combination of both \(t\) and \(If2(t)\):
  155.      * \[
  156.      *   If'2(t) = A t + B If2(t)
  157.      * \]
  158.      * </p>
  159.      *
  160.      * <p>From the primitive, we can deduce the same form for definite
  161.      * integrals between \(t_1\) and \(t_i\) for each \(t_i\) :
  162.      * \[
  163.      *   If2(t_i) - If2(t_1) = A (t_i - t_1) + B (If2 (t_i) - If2(t_1))
  164.      * \]
  165.      * </p>
  166.      *
  167.      * <p>We can find the coefficients \(A\) and \(B\) that best fit the sample
  168.      * to this linear expression by computing the definite integrals for
  169.      * each sample points.
  170.      * </p>
  171.      *
  172.      * <p>For a bilinear expression \(z(x_i, y_i) = A x_i + B y_i\), the
  173.      * coefficients \(A\) and \(B\) that minimize a least-squares criterion
  174.      * \(\sum (z_i - z(x_i, y_i))^2\) are given by these expressions:</p>
  175.      * \[
  176.      *   A = \frac{\sum y_i y_i \sum x_i z_i - \sum x_i y_i \sum y_i z_i}
  177.      *            {\sum x_i x_i \sum y_i y_i - \sum x_i y_i \sum x_i y_i}
  178.      * \]
  179.      * \[
  180.      *   B = \frac{\sum x_i x_i \sum y_i z_i - \sum x_i y_i \sum x_i z_i}
  181.      *            {\sum x_i x_i \sum y_i y_i - \sum x_i y_i \sum x_i y_i}
  182.      *
  183.      * \]
  184.      *
  185.      * <p>In fact, we can assume that both \(a\) and \(\omega\) are positive and
  186.      * compute them directly, knowing that \(A = a^2 \omega^2\) and that
  187.      * \(B = -\omega^2\). The complete algorithm is therefore:</p>
  188.      *
  189.      * For each \(t_i\) from \(t_1\) to \(t_{n-1}\), compute:
  190.      * \[ f(t_i) \]
  191.      * \[ f'(t_i) = \frac{f (t_{i+1}) - f(t_{i-1})}{t_{i+1} - t_{i-1}} \]
  192.      * \[ x_i = t_i  - t_1 \]
  193.      * \[ y_i = \int_{t_1}^{t_i} f^2(t) dt \]
  194.      * \[ z_i = \int_{t_1}^{t_i} f'^2(t) dt \]
  195.      * and update the sums:
  196.      * \[ \sum x_i x_i, \sum y_i y_i, \sum x_i y_i, \sum x_i z_i, \sum y_i z_i \]
  197.      *
  198.      * Then:
  199.      * \[
  200.      *  a = \sqrt{\frac{\sum y_i y_i  \sum x_i z_i - \sum x_i y_i \sum y_i z_i }
  201.      *                 {\sum x_i y_i  \sum x_i z_i - \sum x_i x_i \sum y_i z_i }}
  202.      * \]
  203.      * \[
  204.      *  \omega = \sqrt{\frac{\sum x_i y_i \sum x_i z_i - \sum x_i x_i \sum y_i z_i}
  205.      *                      {\sum x_i x_i \sum y_i y_i - \sum x_i y_i \sum x_i y_i}}
  206.      * \]
  207.      *
  208.      * <p>Once we know \(\omega\) we can compute:
  209.      * \[
  210.      *    fc = \omega f(t) \cos(\omega t) - f'(t) \sin(\omega t)
  211.      * \]
  212.      * \[
  213.      *    fs = \omega f(t) \sin(\omega t) + f'(t) \cos(\omega t)
  214.      * \]
  215.      * </p>
  216.      *
  217.      * <p>It appears that \(fc = a \omega \cos(\phi)\) and
  218.      * \(fs = -a \omega \sin(\phi)\), so we can use these
  219.      * expressions to compute \(\phi\). The best estimate over the sample is
  220.      * given by averaging these expressions.
  221.      * </p>
  222.      *
  223.      * <p>Since integrals and means are involved in the preceding
  224.      * estimations, these operations run in \(O(n)\) time, where \(n\) is the
  225.      * number of measurements.</p>
  226.      */
  227.     public static class ParameterGuesser {
  228.         /** Amplitude. */
  229.         private final double a;
  230.         /** Angular frequency. */
  231.         private final double omega;
  232.         /** Phase. */
  233.         private final double phi;

  234.         /**
  235.          * Simple constructor.
  236.          *
  237.          * @param observations Sampled observations.
  238.          * @throws NumberIsTooSmallException if the sample is too short.
  239.          * @throws ZeroException if the abscissa range is zero.
  240.          * @throws MathIllegalStateException when the guessing procedure cannot
  241.          * produce sensible results.
  242.          */
  243.         public ParameterGuesser(Collection<WeightedObservedPoint> observations) {
  244.             if (observations.size() < 4) {
  245.                 throw new NumberIsTooSmallException(LocalizedFormats.INSUFFICIENT_OBSERVED_POINTS_IN_SAMPLE,
  246.                                                     observations.size(), 4, true);
  247.             }

  248.             final WeightedObservedPoint[] sorted
  249.                 = sortObservations(observations).toArray(new WeightedObservedPoint[0]);

  250.             final double aOmega[] = guessAOmega(sorted);
  251.             a = aOmega[0];
  252.             omega = aOmega[1];

  253.             phi = guessPhi(sorted);
  254.         }

  255.         /**
  256.          * Gets an estimation of the parameters.
  257.          *
  258.          * @return the guessed parameters, in the following order:
  259.          * <ul>
  260.          *  <li>Amplitude</li>
  261.          *  <li>Angular frequency</li>
  262.          *  <li>Phase</li>
  263.          * </ul>
  264.          */
  265.         public double[] guess() {
  266.             return new double[] { a, omega, phi };
  267.         }

  268.         /**
  269.          * Sort the observations with respect to the abscissa.
  270.          *
  271.          * @param unsorted Input observations.
  272.          * @return the input observations, sorted.
  273.          */
  274.         private List<WeightedObservedPoint> sortObservations(Collection<WeightedObservedPoint> unsorted) {
  275.             final List<WeightedObservedPoint> observations = new ArrayList<WeightedObservedPoint>(unsorted);

  276.             // Since the samples are almost always already sorted, this
  277.             // method is implemented as an insertion sort that reorders the
  278.             // elements in place. Insertion sort is very efficient in this case.
  279.             WeightedObservedPoint curr = observations.get(0);
  280.             final int len = observations.size();
  281.             for (int j = 1; j < len; j++) {
  282.                 WeightedObservedPoint prec = curr;
  283.                 curr = observations.get(j);
  284.                 if (curr.getX() < prec.getX()) {
  285.                     // the current element should be inserted closer to the beginning
  286.                     int i = j - 1;
  287.                     WeightedObservedPoint mI = observations.get(i);
  288.                     while ((i >= 0) && (curr.getX() < mI.getX())) {
  289.                         observations.set(i + 1, mI);
  290.                         if (i-- != 0) {
  291.                             mI = observations.get(i);
  292.                         }
  293.                     }
  294.                     observations.set(i + 1, curr);
  295.                     curr = observations.get(j);
  296.                 }
  297.             }

  298.             return observations;
  299.         }

  300.         /**
  301.          * Estimate a first guess of the amplitude and angular frequency.
  302.          *
  303.          * @param observations Observations, sorted w.r.t. abscissa.
  304.          * @throws ZeroException if the abscissa range is zero.
  305.          * @throws MathIllegalStateException when the guessing procedure cannot
  306.          * produce sensible results.
  307.          * @return the guessed amplitude (at index 0) and circular frequency
  308.          * (at index 1).
  309.          */
  310.         private double[] guessAOmega(WeightedObservedPoint[] observations) {
  311.             final double[] aOmega = new double[2];

  312.             // initialize the sums for the linear model between the two integrals
  313.             double sx2 = 0;
  314.             double sy2 = 0;
  315.             double sxy = 0;
  316.             double sxz = 0;
  317.             double syz = 0;

  318.             double currentX = observations[0].getX();
  319.             double currentY = observations[0].getY();
  320.             double f2Integral = 0;
  321.             double fPrime2Integral = 0;
  322.             final double startX = currentX;
  323.             for (int i = 1; i < observations.length; ++i) {
  324.                 // one step forward
  325.                 final double previousX = currentX;
  326.                 final double previousY = currentY;
  327.                 currentX = observations[i].getX();
  328.                 currentY = observations[i].getY();

  329.                 // update the integrals of f<sup>2</sup> and f'<sup>2</sup>
  330.                 // considering a linear model for f (and therefore constant f')
  331.                 final double dx = currentX - previousX;
  332.                 final double dy = currentY - previousY;
  333.                 final double f2StepIntegral =
  334.                     dx * (previousY * previousY + previousY * currentY + currentY * currentY) / 3;
  335.                 final double fPrime2StepIntegral = dy * dy / dx;

  336.                 final double x = currentX - startX;
  337.                 f2Integral += f2StepIntegral;
  338.                 fPrime2Integral += fPrime2StepIntegral;

  339.                 sx2 += x * x;
  340.                 sy2 += f2Integral * f2Integral;
  341.                 sxy += x * f2Integral;
  342.                 sxz += x * fPrime2Integral;
  343.                 syz += f2Integral * fPrime2Integral;
  344.             }

  345.             // compute the amplitude and pulsation coefficients
  346.             double c1 = sy2 * sxz - sxy * syz;
  347.             double c2 = sxy * sxz - sx2 * syz;
  348.             double c3 = sx2 * sy2 - sxy * sxy;
  349.             if ((c1 / c2 < 0) || (c2 / c3 < 0)) {
  350.                 final int last = observations.length - 1;
  351.                 // Range of the observations, assuming that the
  352.                 // observations are sorted.
  353.                 final double xRange = observations[last].getX() - observations[0].getX();
  354.                 if (xRange == 0) {
  355.                     throw new ZeroException();
  356.                 }
  357.                 aOmega[1] = 2 * Math.PI / xRange;

  358.                 double yMin = Double.POSITIVE_INFINITY;
  359.                 double yMax = Double.NEGATIVE_INFINITY;
  360.                 for (int i = 1; i < observations.length; ++i) {
  361.                     final double y = observations[i].getY();
  362.                     if (y < yMin) {
  363.                         yMin = y;
  364.                     }
  365.                     if (y > yMax) {
  366.                         yMax = y;
  367.                     }
  368.                 }
  369.                 aOmega[0] = 0.5 * (yMax - yMin);
  370.             } else {
  371.                 if (c2 == 0) {
  372.                     // In some ill-conditioned cases (cf. MATH-844), the guesser
  373.                     // procedure cannot produce sensible results.
  374.                     throw new MathIllegalStateException(LocalizedFormats.ZERO_DENOMINATOR);
  375.                 }

  376.                 aOmega[0] = FastMath.sqrt(c1 / c2);
  377.                 aOmega[1] = FastMath.sqrt(c2 / c3);
  378.             }

  379.             return aOmega;
  380.         }

  381.         /**
  382.          * Estimate a first guess of the phase.
  383.          *
  384.          * @param observations Observations, sorted w.r.t. abscissa.
  385.          * @return the guessed phase.
  386.          */
  387.         private double guessPhi(WeightedObservedPoint[] observations) {
  388.             // initialize the means
  389.             double fcMean = 0;
  390.             double fsMean = 0;

  391.             double currentX = observations[0].getX();
  392.             double currentY = observations[0].getY();
  393.             for (int i = 1; i < observations.length; ++i) {
  394.                 // one step forward
  395.                 final double previousX = currentX;
  396.                 final double previousY = currentY;
  397.                 currentX = observations[i].getX();
  398.                 currentY = observations[i].getY();
  399.                 final double currentYPrime = (currentY - previousY) / (currentX - previousX);

  400.                 double omegaX = omega * currentX;
  401.                 double cosine = FastMath.cos(omegaX);
  402.                 double sine = FastMath.sin(omegaX);
  403.                 fcMean += omega * currentY * cosine - currentYPrime * sine;
  404.                 fsMean += omega * currentY * sine + currentYPrime * cosine;
  405.             }

  406.             return FastMath.atan2(-fsMean, fcMean);
  407.         }
  408.     }
  409. }