LoessInterpolator.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.math3.analysis.interpolation;

  18. import java.io.Serializable;
  19. import java.util.Arrays;

  20. import org.apache.commons.math3.analysis.polynomials.PolynomialSplineFunction;
  21. import org.apache.commons.math3.exception.NotPositiveException;
  22. import org.apache.commons.math3.exception.OutOfRangeException;
  23. import org.apache.commons.math3.exception.DimensionMismatchException;
  24. import org.apache.commons.math3.exception.NoDataException;
  25. import org.apache.commons.math3.exception.NumberIsTooSmallException;
  26. import org.apache.commons.math3.exception.NonMonotonicSequenceException;
  27. import org.apache.commons.math3.exception.NotFiniteNumberException;
  28. import org.apache.commons.math3.exception.util.LocalizedFormats;
  29. import org.apache.commons.math3.util.FastMath;
  30. import org.apache.commons.math3.util.MathUtils;
  31. import org.apache.commons.math3.util.MathArrays;

  32. /**
  33.  * Implements the <a href="http://en.wikipedia.org/wiki/Local_regression">
  34.  * Local Regression Algorithm</a> (also Loess, Lowess) for interpolation of
  35.  * real univariate functions.
  36.  * <p/>
  37.  * For reference, see
  38.  * <a href="http://www.math.tau.ac.il/~yekutiel/MA seminar/Cleveland 1979.pdf">
  39.  * William S. Cleveland - Robust Locally Weighted Regression and Smoothing
  40.  * Scatterplots</a>
  41.  * <p/>
  42.  * This class implements both the loess method and serves as an interpolation
  43.  * adapter to it, allowing one to build a spline on the obtained loess fit.
  44.  *
  45.  * @since 2.0
  46.  */
  47. public class LoessInterpolator
  48.     implements UnivariateInterpolator, Serializable {
  49.     /** Default value of the bandwidth parameter. */
  50.     public static final double DEFAULT_BANDWIDTH = 0.3;
  51.     /** Default value of the number of robustness iterations. */
  52.     public static final int DEFAULT_ROBUSTNESS_ITERS = 2;
  53.     /**
  54.      * Default value for accuracy.
  55.      * @since 2.1
  56.      */
  57.     public static final double DEFAULT_ACCURACY = 1e-12;
  58.     /** serializable version identifier. */
  59.     private static final long serialVersionUID = 5204927143605193821L;
  60.     /**
  61.      * The bandwidth parameter: when computing the loess fit at
  62.      * a particular point, this fraction of source points closest
  63.      * to the current point is taken into account for computing
  64.      * a least-squares regression.
  65.      * <p/>
  66.      * A sensible value is usually 0.25 to 0.5.
  67.      */
  68.     private final double bandwidth;
  69.     /**
  70.      * The number of robustness iterations parameter: this many
  71.      * robustness iterations are done.
  72.      * <p/>
  73.      * A sensible value is usually 0 (just the initial fit without any
  74.      * robustness iterations) to 4.
  75.      */
  76.     private final int robustnessIters;
  77.     /**
  78.      * If the median residual at a certain robustness iteration
  79.      * is less than this amount, no more iterations are done.
  80.      */
  81.     private final double accuracy;

  82.     /**
  83.      * Constructs a new {@link LoessInterpolator}
  84.      * with a bandwidth of {@link #DEFAULT_BANDWIDTH},
  85.      * {@link #DEFAULT_ROBUSTNESS_ITERS} robustness iterations
  86.      * and an accuracy of {#link #DEFAULT_ACCURACY}.
  87.      * See {@link #LoessInterpolator(double, int, double)} for an explanation of
  88.      * the parameters.
  89.      */
  90.     public LoessInterpolator() {
  91.         this.bandwidth = DEFAULT_BANDWIDTH;
  92.         this.robustnessIters = DEFAULT_ROBUSTNESS_ITERS;
  93.         this.accuracy = DEFAULT_ACCURACY;
  94.     }

  95.     /**
  96.      * Construct a new {@link LoessInterpolator}
  97.      * with given bandwidth and number of robustness iterations.
  98.      * <p>
  99.      * Calling this constructor is equivalent to calling {link {@link
  100.      * #LoessInterpolator(double, int, double) LoessInterpolator(bandwidth,
  101.      * robustnessIters, LoessInterpolator.DEFAULT_ACCURACY)}
  102.      * </p>
  103.      *
  104.      * @param bandwidth  when computing the loess fit at
  105.      * a particular point, this fraction of source points closest
  106.      * to the current point is taken into account for computing
  107.      * a least-squares regression.</br>
  108.      * A sensible value is usually 0.25 to 0.5, the default value is
  109.      * {@link #DEFAULT_BANDWIDTH}.
  110.      * @param robustnessIters This many robustness iterations are done.</br>
  111.      * A sensible value is usually 0 (just the initial fit without any
  112.      * robustness iterations) to 4, the default value is
  113.      * {@link #DEFAULT_ROBUSTNESS_ITERS}.

  114.      * @see #LoessInterpolator(double, int, double)
  115.      */
  116.     public LoessInterpolator(double bandwidth, int robustnessIters) {
  117.         this(bandwidth, robustnessIters, DEFAULT_ACCURACY);
  118.     }

  119.     /**
  120.      * Construct a new {@link LoessInterpolator}
  121.      * with given bandwidth, number of robustness iterations and accuracy.
  122.      *
  123.      * @param bandwidth  when computing the loess fit at
  124.      * a particular point, this fraction of source points closest
  125.      * to the current point is taken into account for computing
  126.      * a least-squares regression.</br>
  127.      * A sensible value is usually 0.25 to 0.5, the default value is
  128.      * {@link #DEFAULT_BANDWIDTH}.
  129.      * @param robustnessIters This many robustness iterations are done.</br>
  130.      * A sensible value is usually 0 (just the initial fit without any
  131.      * robustness iterations) to 4, the default value is
  132.      * {@link #DEFAULT_ROBUSTNESS_ITERS}.
  133.      * @param accuracy If the median residual at a certain robustness iteration
  134.      * is less than this amount, no more iterations are done.
  135.      * @throws OutOfRangeException if bandwidth does not lie in the interval [0,1].
  136.      * @throws NotPositiveException if {@code robustnessIters} is negative.
  137.      * @see #LoessInterpolator(double, int)
  138.      * @since 2.1
  139.      */
  140.     public LoessInterpolator(double bandwidth, int robustnessIters, double accuracy)
  141.         throws OutOfRangeException,
  142.                NotPositiveException {
  143.         if (bandwidth < 0 ||
  144.             bandwidth > 1) {
  145.             throw new OutOfRangeException(LocalizedFormats.BANDWIDTH, bandwidth, 0, 1);
  146.         }
  147.         this.bandwidth = bandwidth;
  148.         if (robustnessIters < 0) {
  149.             throw new NotPositiveException(LocalizedFormats.ROBUSTNESS_ITERATIONS, robustnessIters);
  150.         }
  151.         this.robustnessIters = robustnessIters;
  152.         this.accuracy = accuracy;
  153.     }

  154.     /**
  155.      * Compute an interpolating function by performing a loess fit
  156.      * on the data at the original abscissae and then building a cubic spline
  157.      * with a
  158.      * {@link org.apache.commons.math3.analysis.interpolation.SplineInterpolator}
  159.      * on the resulting fit.
  160.      *
  161.      * @param xval the arguments for the interpolation points
  162.      * @param yval the values for the interpolation points
  163.      * @return A cubic spline built upon a loess fit to the data at the original abscissae
  164.      * @throws NonMonotonicSequenceException if {@code xval} not sorted in
  165.      * strictly increasing order.
  166.      * @throws DimensionMismatchException if {@code xval} and {@code yval} have
  167.      * different sizes.
  168.      * @throws NoDataException if {@code xval} or {@code yval} has zero size.
  169.      * @throws NotFiniteNumberException if any of the arguments and values are
  170.      * not finite real numbers.
  171.      * @throws NumberIsTooSmallException if the bandwidth is too small to
  172.      * accomodate the size of the input data (i.e. the bandwidth must be
  173.      * larger than 2/n).
  174.      */
  175.     public final PolynomialSplineFunction interpolate(final double[] xval,
  176.                                                       final double[] yval)
  177.         throws NonMonotonicSequenceException,
  178.                DimensionMismatchException,
  179.                NoDataException,
  180.                NotFiniteNumberException,
  181.                NumberIsTooSmallException {
  182.         return new SplineInterpolator().interpolate(xval, smooth(xval, yval));
  183.     }

  184.     /**
  185.      * Compute a weighted loess fit on the data at the original abscissae.
  186.      *
  187.      * @param xval Arguments for the interpolation points.
  188.      * @param yval Values for the interpolation points.
  189.      * @param weights point weights: coefficients by which the robustness weight
  190.      * of a point is multiplied.
  191.      * @return the values of the loess fit at corresponding original abscissae.
  192.      * @throws NonMonotonicSequenceException if {@code xval} not sorted in
  193.      * strictly increasing order.
  194.      * @throws DimensionMismatchException if {@code xval} and {@code yval} have
  195.      * different sizes.
  196.      * @throws NoDataException if {@code xval} or {@code yval} has zero size.
  197.      * @throws NotFiniteNumberException if any of the arguments and values are
  198.      not finite real numbers.
  199.      * @throws NumberIsTooSmallException if the bandwidth is too small to
  200.      * accomodate the size of the input data (i.e. the bandwidth must be
  201.      * larger than 2/n).
  202.      * @since 2.1
  203.      */
  204.     public final double[] smooth(final double[] xval, final double[] yval,
  205.                                  final double[] weights)
  206.         throws NonMonotonicSequenceException,
  207.                DimensionMismatchException,
  208.                NoDataException,
  209.                NotFiniteNumberException,
  210.                NumberIsTooSmallException {
  211.         if (xval.length != yval.length) {
  212.             throw new DimensionMismatchException(xval.length, yval.length);
  213.         }

  214.         final int n = xval.length;

  215.         if (n == 0) {
  216.             throw new NoDataException();
  217.         }

  218.         checkAllFiniteReal(xval);
  219.         checkAllFiniteReal(yval);
  220.         checkAllFiniteReal(weights);

  221.         MathArrays.checkOrder(xval);

  222.         if (n == 1) {
  223.             return new double[]{yval[0]};
  224.         }

  225.         if (n == 2) {
  226.             return new double[]{yval[0], yval[1]};
  227.         }

  228.         int bandwidthInPoints = (int) (bandwidth * n);

  229.         if (bandwidthInPoints < 2) {
  230.             throw new NumberIsTooSmallException(LocalizedFormats.BANDWIDTH,
  231.                                                 bandwidthInPoints, 2, true);
  232.         }

  233.         final double[] res = new double[n];

  234.         final double[] residuals = new double[n];
  235.         final double[] sortedResiduals = new double[n];

  236.         final double[] robustnessWeights = new double[n];

  237.         // Do an initial fit and 'robustnessIters' robustness iterations.
  238.         // This is equivalent to doing 'robustnessIters+1' robustness iterations
  239.         // starting with all robustness weights set to 1.
  240.         Arrays.fill(robustnessWeights, 1);

  241.         for (int iter = 0; iter <= robustnessIters; ++iter) {
  242.             final int[] bandwidthInterval = {0, bandwidthInPoints - 1};
  243.             // At each x, compute a local weighted linear regression
  244.             for (int i = 0; i < n; ++i) {
  245.                 final double x = xval[i];

  246.                 // Find out the interval of source points on which
  247.                 // a regression is to be made.
  248.                 if (i > 0) {
  249.                     updateBandwidthInterval(xval, weights, i, bandwidthInterval);
  250.                 }

  251.                 final int ileft = bandwidthInterval[0];
  252.                 final int iright = bandwidthInterval[1];

  253.                 // Compute the point of the bandwidth interval that is
  254.                 // farthest from x
  255.                 final int edge;
  256.                 if (xval[i] - xval[ileft] > xval[iright] - xval[i]) {
  257.                     edge = ileft;
  258.                 } else {
  259.                     edge = iright;
  260.                 }

  261.                 // Compute a least-squares linear fit weighted by
  262.                 // the product of robustness weights and the tricube
  263.                 // weight function.
  264.                 // See http://en.wikipedia.org/wiki/Linear_regression
  265.                 // (section "Univariate linear case")
  266.                 // and http://en.wikipedia.org/wiki/Weighted_least_squares
  267.                 // (section "Weighted least squares")
  268.                 double sumWeights = 0;
  269.                 double sumX = 0;
  270.                 double sumXSquared = 0;
  271.                 double sumY = 0;
  272.                 double sumXY = 0;
  273.                 double denom = FastMath.abs(1.0 / (xval[edge] - x));
  274.                 for (int k = ileft; k <= iright; ++k) {
  275.                     final double xk   = xval[k];
  276.                     final double yk   = yval[k];
  277.                     final double dist = (k < i) ? x - xk : xk - x;
  278.                     final double w    = tricube(dist * denom) * robustnessWeights[k] * weights[k];
  279.                     final double xkw  = xk * w;
  280.                     sumWeights += w;
  281.                     sumX += xkw;
  282.                     sumXSquared += xk * xkw;
  283.                     sumY += yk * w;
  284.                     sumXY += yk * xkw;
  285.                 }

  286.                 final double meanX = sumX / sumWeights;
  287.                 final double meanY = sumY / sumWeights;
  288.                 final double meanXY = sumXY / sumWeights;
  289.                 final double meanXSquared = sumXSquared / sumWeights;

  290.                 final double beta;
  291.                 if (FastMath.sqrt(FastMath.abs(meanXSquared - meanX * meanX)) < accuracy) {
  292.                     beta = 0;
  293.                 } else {
  294.                     beta = (meanXY - meanX * meanY) / (meanXSquared - meanX * meanX);
  295.                 }

  296.                 final double alpha = meanY - beta * meanX;

  297.                 res[i] = beta * x + alpha;
  298.                 residuals[i] = FastMath.abs(yval[i] - res[i]);
  299.             }

  300.             // No need to recompute the robustness weights at the last
  301.             // iteration, they won't be needed anymore
  302.             if (iter == robustnessIters) {
  303.                 break;
  304.             }

  305.             // Recompute the robustness weights.

  306.             // Find the median residual.
  307.             // An arraycopy and a sort are completely tractable here,
  308.             // because the preceding loop is a lot more expensive
  309.             System.arraycopy(residuals, 0, sortedResiduals, 0, n);
  310.             Arrays.sort(sortedResiduals);
  311.             final double medianResidual = sortedResiduals[n / 2];

  312.             if (FastMath.abs(medianResidual) < accuracy) {
  313.                 break;
  314.             }

  315.             for (int i = 0; i < n; ++i) {
  316.                 final double arg = residuals[i] / (6 * medianResidual);
  317.                 if (arg >= 1) {
  318.                     robustnessWeights[i] = 0;
  319.                 } else {
  320.                     final double w = 1 - arg * arg;
  321.                     robustnessWeights[i] = w * w;
  322.                 }
  323.             }
  324.         }

  325.         return res;
  326.     }

  327.     /**
  328.      * Compute a loess fit on the data at the original abscissae.
  329.      *
  330.      * @param xval the arguments for the interpolation points
  331.      * @param yval the values for the interpolation points
  332.      * @return values of the loess fit at corresponding original abscissae
  333.      * @throws NonMonotonicSequenceException if {@code xval} not sorted in
  334.      * strictly increasing order.
  335.      * @throws DimensionMismatchException if {@code xval} and {@code yval} have
  336.      * different sizes.
  337.      * @throws NoDataException if {@code xval} or {@code yval} has zero size.
  338.      * @throws NotFiniteNumberException if any of the arguments and values are
  339.      * not finite real numbers.
  340.      * @throws NumberIsTooSmallException if the bandwidth is too small to
  341.      * accomodate the size of the input data (i.e. the bandwidth must be
  342.      * larger than 2/n).
  343.      */
  344.     public final double[] smooth(final double[] xval, final double[] yval)
  345.         throws NonMonotonicSequenceException,
  346.                DimensionMismatchException,
  347.                NoDataException,
  348.                NotFiniteNumberException,
  349.                NumberIsTooSmallException {
  350.         if (xval.length != yval.length) {
  351.             throw new DimensionMismatchException(xval.length, yval.length);
  352.         }

  353.         final double[] unitWeights = new double[xval.length];
  354.         Arrays.fill(unitWeights, 1.0);

  355.         return smooth(xval, yval, unitWeights);
  356.     }

  357.     /**
  358.      * Given an index interval into xval that embraces a certain number of
  359.      * points closest to {@code xval[i-1]}, update the interval so that it
  360.      * embraces the same number of points closest to {@code xval[i]},
  361.      * ignoring zero weights.
  362.      *
  363.      * @param xval Arguments array.
  364.      * @param weights Weights array.
  365.      * @param i Index around which the new interval should be computed.
  366.      * @param bandwidthInterval a two-element array {left, right} such that:
  367.      * {@code (left==0 or xval[i] - xval[left-1] > xval[right] - xval[i])}
  368.      * and
  369.      * {@code (right==xval.length-1 or xval[right+1] - xval[i] > xval[i] - xval[left])}.
  370.      * The array will be updated.
  371.      */
  372.     private static void updateBandwidthInterval(final double[] xval, final double[] weights,
  373.                                                 final int i,
  374.                                                 final int[] bandwidthInterval) {
  375.         final int left = bandwidthInterval[0];
  376.         final int right = bandwidthInterval[1];

  377.         // The right edge should be adjusted if the next point to the right
  378.         // is closer to xval[i] than the leftmost point of the current interval
  379.         int nextRight = nextNonzero(weights, right);
  380.         if (nextRight < xval.length && xval[nextRight] - xval[i] < xval[i] - xval[left]) {
  381.             int nextLeft = nextNonzero(weights, bandwidthInterval[0]);
  382.             bandwidthInterval[0] = nextLeft;
  383.             bandwidthInterval[1] = nextRight;
  384.         }
  385.     }

  386.     /**
  387.      * Return the smallest index {@code j} such that
  388.      * {@code j > i && (j == weights.length || weights[j] != 0)}.
  389.      *
  390.      * @param weights Weights array.
  391.      * @param i Index from which to start search.
  392.      * @return the smallest compliant index.
  393.      */
  394.     private static int nextNonzero(final double[] weights, final int i) {
  395.         int j = i + 1;
  396.         while(j < weights.length && weights[j] == 0) {
  397.             ++j;
  398.         }
  399.         return j;
  400.     }

  401.     /**
  402.      * Compute the
  403.      * <a href="http://en.wikipedia.org/wiki/Local_regression#Weight_function">tricube</a>
  404.      * weight function
  405.      *
  406.      * @param x Argument.
  407.      * @return <code>(1 - |x|<sup>3</sup>)<sup>3</sup></code> for |x| &lt; 1, 0 otherwise.
  408.      */
  409.     private static double tricube(final double x) {
  410.         final double absX = FastMath.abs(x);
  411.         if (absX >= 1.0) {
  412.             return 0.0;
  413.         }
  414.         final double tmp = 1 - absX * absX * absX;
  415.         return tmp * tmp * tmp;
  416.     }

  417.     /**
  418.      * Check that all elements of an array are finite real numbers.
  419.      *
  420.      * @param values Values array.
  421.      * @throws org.apache.commons.math3.exception.NotFiniteNumberException
  422.      * if one of the values is not a finite real number.
  423.      */
  424.     private static void checkAllFiniteReal(final double[] values) {
  425.         for (int i = 0; i < values.length; i++) {
  426.             MathUtils.checkFinite(values[i]);
  427.         }
  428.     }
  429. }