NelderMeadSimplex.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.math3.optim.nonlinear.scalar.noderiv;

  18. import java.util.Comparator;

  19. import org.apache.commons.math3.optim.PointValuePair;
  20. import org.apache.commons.math3.analysis.MultivariateFunction;

  21. /**
  22.  * This class implements the Nelder-Mead simplex algorithm.
  23.  *
  24.  * @since 3.0
  25.  */
  26. public class NelderMeadSimplex extends AbstractSimplex {
  27.     /** Default value for {@link #rho}: {@value}. */
  28.     private static final double DEFAULT_RHO = 1;
  29.     /** Default value for {@link #khi}: {@value}. */
  30.     private static final double DEFAULT_KHI = 2;
  31.     /** Default value for {@link #gamma}: {@value}. */
  32.     private static final double DEFAULT_GAMMA = 0.5;
  33.     /** Default value for {@link #sigma}: {@value}. */
  34.     private static final double DEFAULT_SIGMA = 0.5;
  35.     /** Reflection coefficient. */
  36.     private final double rho;
  37.     /** Expansion coefficient. */
  38.     private final double khi;
  39.     /** Contraction coefficient. */
  40.     private final double gamma;
  41.     /** Shrinkage coefficient. */
  42.     private final double sigma;

  43.     /**
  44.      * Build a Nelder-Mead simplex with default coefficients.
  45.      * The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
  46.      * for both gamma and sigma.
  47.      *
  48.      * @param n Dimension of the simplex.
  49.      */
  50.     public NelderMeadSimplex(final int n) {
  51.         this(n, 1d);
  52.     }

  53.     /**
  54.      * Build a Nelder-Mead simplex with default coefficients.
  55.      * The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
  56.      * for both gamma and sigma.
  57.      *
  58.      * @param n Dimension of the simplex.
  59.      * @param sideLength Length of the sides of the default (hypercube)
  60.      * simplex. See {@link AbstractSimplex#AbstractSimplex(int,double)}.
  61.      */
  62.     public NelderMeadSimplex(final int n, double sideLength) {
  63.         this(n, sideLength,
  64.              DEFAULT_RHO, DEFAULT_KHI, DEFAULT_GAMMA, DEFAULT_SIGMA);
  65.     }

  66.     /**
  67.      * Build a Nelder-Mead simplex with specified coefficients.
  68.      *
  69.      * @param n Dimension of the simplex. See
  70.      * {@link AbstractSimplex#AbstractSimplex(int,double)}.
  71.      * @param sideLength Length of the sides of the default (hypercube)
  72.      * simplex. See {@link AbstractSimplex#AbstractSimplex(int,double)}.
  73.      * @param rho Reflection coefficient.
  74.      * @param khi Expansion coefficient.
  75.      * @param gamma Contraction coefficient.
  76.      * @param sigma Shrinkage coefficient.
  77.      */
  78.     public NelderMeadSimplex(final int n, double sideLength,
  79.                              final double rho, final double khi,
  80.                              final double gamma, final double sigma) {
  81.         super(n, sideLength);

  82.         this.rho = rho;
  83.         this.khi = khi;
  84.         this.gamma = gamma;
  85.         this.sigma = sigma;
  86.     }

  87.     /**
  88.      * Build a Nelder-Mead simplex with specified coefficients.
  89.      *
  90.      * @param n Dimension of the simplex. See
  91.      * {@link AbstractSimplex#AbstractSimplex(int)}.
  92.      * @param rho Reflection coefficient.
  93.      * @param khi Expansion coefficient.
  94.      * @param gamma Contraction coefficient.
  95.      * @param sigma Shrinkage coefficient.
  96.      */
  97.     public NelderMeadSimplex(final int n,
  98.                              final double rho, final double khi,
  99.                              final double gamma, final double sigma) {
  100.         this(n, 1d, rho, khi, gamma, sigma);
  101.     }

  102.     /**
  103.      * Build a Nelder-Mead simplex with default coefficients.
  104.      * The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
  105.      * for both gamma and sigma.
  106.      *
  107.      * @param steps Steps along the canonical axes representing box edges.
  108.      * They may be negative but not zero. See
  109.      */
  110.     public NelderMeadSimplex(final double[] steps) {
  111.         this(steps, DEFAULT_RHO, DEFAULT_KHI, DEFAULT_GAMMA, DEFAULT_SIGMA);
  112.     }

  113.     /**
  114.      * Build a Nelder-Mead simplex with specified coefficients.
  115.      *
  116.      * @param steps Steps along the canonical axes representing box edges.
  117.      * They may be negative but not zero. See
  118.      * {@link AbstractSimplex#AbstractSimplex(double[])}.
  119.      * @param rho Reflection coefficient.
  120.      * @param khi Expansion coefficient.
  121.      * @param gamma Contraction coefficient.
  122.      * @param sigma Shrinkage coefficient.
  123.      * @throws IllegalArgumentException if one of the steps is zero.
  124.      */
  125.     public NelderMeadSimplex(final double[] steps,
  126.                              final double rho, final double khi,
  127.                              final double gamma, final double sigma) {
  128.         super(steps);

  129.         this.rho = rho;
  130.         this.khi = khi;
  131.         this.gamma = gamma;
  132.         this.sigma = sigma;
  133.     }

  134.     /**
  135.      * Build a Nelder-Mead simplex with default coefficients.
  136.      * The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
  137.      * for both gamma and sigma.
  138.      *
  139.      * @param referenceSimplex Reference simplex. See
  140.      * {@link AbstractSimplex#AbstractSimplex(double[][])}.
  141.      */
  142.     public NelderMeadSimplex(final double[][] referenceSimplex) {
  143.         this(referenceSimplex, DEFAULT_RHO, DEFAULT_KHI, DEFAULT_GAMMA, DEFAULT_SIGMA);
  144.     }

  145.     /**
  146.      * Build a Nelder-Mead simplex with specified coefficients.
  147.      *
  148.      * @param referenceSimplex Reference simplex. See
  149.      * {@link AbstractSimplex#AbstractSimplex(double[][])}.
  150.      * @param rho Reflection coefficient.
  151.      * @param khi Expansion coefficient.
  152.      * @param gamma Contraction coefficient.
  153.      * @param sigma Shrinkage coefficient.
  154.      * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException
  155.      * if the reference simplex does not contain at least one point.
  156.      * @throws org.apache.commons.math3.exception.DimensionMismatchException
  157.      * if there is a dimension mismatch in the reference simplex.
  158.      */
  159.     public NelderMeadSimplex(final double[][] referenceSimplex,
  160.                              final double rho, final double khi,
  161.                              final double gamma, final double sigma) {
  162.         super(referenceSimplex);

  163.         this.rho = rho;
  164.         this.khi = khi;
  165.         this.gamma = gamma;
  166.         this.sigma = sigma;
  167.     }

  168.     /** {@inheritDoc} */
  169.     @Override
  170.     public void iterate(final MultivariateFunction evaluationFunction,
  171.                         final Comparator<PointValuePair> comparator) {
  172.         // The simplex has n + 1 points if dimension is n.
  173.         final int n = getDimension();

  174.         // Interesting values.
  175.         final PointValuePair best = getPoint(0);
  176.         final PointValuePair secondBest = getPoint(n - 1);
  177.         final PointValuePair worst = getPoint(n);
  178.         final double[] xWorst = worst.getPointRef();

  179.         // Compute the centroid of the best vertices (dismissing the worst
  180.         // point at index n).
  181.         final double[] centroid = new double[n];
  182.         for (int i = 0; i < n; i++) {
  183.             final double[] x = getPoint(i).getPointRef();
  184.             for (int j = 0; j < n; j++) {
  185.                 centroid[j] += x[j];
  186.             }
  187.         }
  188.         final double scaling = 1.0 / n;
  189.         for (int j = 0; j < n; j++) {
  190.             centroid[j] *= scaling;
  191.         }

  192.         // compute the reflection point
  193.         final double[] xR = new double[n];
  194.         for (int j = 0; j < n; j++) {
  195.             xR[j] = centroid[j] + rho * (centroid[j] - xWorst[j]);
  196.         }
  197.         final PointValuePair reflected
  198.             = new PointValuePair(xR, evaluationFunction.value(xR), false);

  199.         if (comparator.compare(best, reflected) <= 0 &&
  200.             comparator.compare(reflected, secondBest) < 0) {
  201.             // Accept the reflected point.
  202.             replaceWorstPoint(reflected, comparator);
  203.         } else if (comparator.compare(reflected, best) < 0) {
  204.             // Compute the expansion point.
  205.             final double[] xE = new double[n];
  206.             for (int j = 0; j < n; j++) {
  207.                 xE[j] = centroid[j] + khi * (xR[j] - centroid[j]);
  208.             }
  209.             final PointValuePair expanded
  210.                 = new PointValuePair(xE, evaluationFunction.value(xE), false);

  211.             if (comparator.compare(expanded, reflected) < 0) {
  212.                 // Accept the expansion point.
  213.                 replaceWorstPoint(expanded, comparator);
  214.             } else {
  215.                 // Accept the reflected point.
  216.                 replaceWorstPoint(reflected, comparator);
  217.             }
  218.         } else {
  219.             if (comparator.compare(reflected, worst) < 0) {
  220.                 // Perform an outside contraction.
  221.                 final double[] xC = new double[n];
  222.                 for (int j = 0; j < n; j++) {
  223.                     xC[j] = centroid[j] + gamma * (xR[j] - centroid[j]);
  224.                 }
  225.                 final PointValuePair outContracted
  226.                     = new PointValuePair(xC, evaluationFunction.value(xC), false);
  227.                 if (comparator.compare(outContracted, reflected) <= 0) {
  228.                     // Accept the contraction point.
  229.                     replaceWorstPoint(outContracted, comparator);
  230.                     return;
  231.                 }
  232.             } else {
  233.                 // Perform an inside contraction.
  234.                 final double[] xC = new double[n];
  235.                 for (int j = 0; j < n; j++) {
  236.                     xC[j] = centroid[j] - gamma * (centroid[j] - xWorst[j]);
  237.                 }
  238.                 final PointValuePair inContracted
  239.                     = new PointValuePair(xC, evaluationFunction.value(xC), false);

  240.                 if (comparator.compare(inContracted, worst) < 0) {
  241.                     // Accept the contraction point.
  242.                     replaceWorstPoint(inContracted, comparator);
  243.                     return;
  244.                 }
  245.             }

  246.             // Perform a shrink.
  247.             final double[] xSmallest = getPoint(0).getPointRef();
  248.             for (int i = 1; i <= n; i++) {
  249.                 final double[] x = getPoint(i).getPoint();
  250.                 for (int j = 0; j < n; j++) {
  251.                     x[j] = xSmallest[j] + sigma * (x[j] - xSmallest[j]);
  252.                 }
  253.                 setPoint(i, new PointValuePair(x, Double.NaN, false));
  254.             }
  255.             evaluate(evaluationFunction, comparator);
  256.         }
  257.     }
  258. }