NelderMeadSimplex.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math3.optimization.direct;

  18. import java.util.Comparator;

  19. import org.apache.commons.math3.optimization.PointValuePair;
  20. import org.apache.commons.math3.analysis.MultivariateFunction;

  21. /**
  22.  * This class implements the Nelder-Mead simplex algorithm.
  23.  *
  24.  * @deprecated As of 3.1 (to be removed in 4.0).
  25.  * @since 3.0
  26.  */
  27. @Deprecated
  28. public class NelderMeadSimplex extends AbstractSimplex {
  29.     /** Default value for {@link #rho}: {@value}. */
  30.     private static final double DEFAULT_RHO = 1;
  31.     /** Default value for {@link #khi}: {@value}. */
  32.     private static final double DEFAULT_KHI = 2;
  33.     /** Default value for {@link #gamma}: {@value}. */
  34.     private static final double DEFAULT_GAMMA = 0.5;
  35.     /** Default value for {@link #sigma}: {@value}. */
  36.     private static final double DEFAULT_SIGMA = 0.5;
  37.     /** Reflection coefficient. */
  38.     private final double rho;
  39.     /** Expansion coefficient. */
  40.     private final double khi;
  41.     /** Contraction coefficient. */
  42.     private final double gamma;
  43.     /** Shrinkage coefficient. */
  44.     private final double sigma;

  45.     /**
  46.      * Build a Nelder-Mead simplex with default coefficients.
  47.      * The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
  48.      * for both gamma and sigma.
  49.      *
  50.      * @param n Dimension of the simplex.
  51.      */
  52.     public NelderMeadSimplex(final int n) {
  53.         this(n, 1d);
  54.     }

  55.     /**
  56.      * Build a Nelder-Mead simplex with default coefficients.
  57.      * The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
  58.      * for both gamma and sigma.
  59.      *
  60.      * @param n Dimension of the simplex.
  61.      * @param sideLength Length of the sides of the default (hypercube)
  62.      * simplex. See {@link AbstractSimplex#AbstractSimplex(int,double)}.
  63.      */
  64.     public NelderMeadSimplex(final int n, double sideLength) {
  65.         this(n, sideLength,
  66.              DEFAULT_RHO, DEFAULT_KHI, DEFAULT_GAMMA, DEFAULT_SIGMA);
  67.     }

  68.     /**
  69.      * Build a Nelder-Mead simplex with specified coefficients.
  70.      *
  71.      * @param n Dimension of the simplex. See
  72.      * {@link AbstractSimplex#AbstractSimplex(int,double)}.
  73.      * @param sideLength Length of the sides of the default (hypercube)
  74.      * simplex. See {@link AbstractSimplex#AbstractSimplex(int,double)}.
  75.      * @param rho Reflection coefficient.
  76.      * @param khi Expansion coefficient.
  77.      * @param gamma Contraction coefficient.
  78.      * @param sigma Shrinkage coefficient.
  79.      */
  80.     public NelderMeadSimplex(final int n, double sideLength,
  81.                              final double rho, final double khi,
  82.                              final double gamma, final double sigma) {
  83.         super(n, sideLength);

  84.         this.rho = rho;
  85.         this.khi = khi;
  86.         this.gamma = gamma;
  87.         this.sigma = sigma;
  88.     }

  89.     /**
  90.      * Build a Nelder-Mead simplex with specified coefficients.
  91.      *
  92.      * @param n Dimension of the simplex. See
  93.      * {@link AbstractSimplex#AbstractSimplex(int)}.
  94.      * @param rho Reflection coefficient.
  95.      * @param khi Expansion coefficient.
  96.      * @param gamma Contraction coefficient.
  97.      * @param sigma Shrinkage coefficient.
  98.      */
  99.     public NelderMeadSimplex(final int n,
  100.                              final double rho, final double khi,
  101.                              final double gamma, final double sigma) {
  102.         this(n, 1d, rho, khi, gamma, sigma);
  103.     }

  104.     /**
  105.      * Build a Nelder-Mead simplex with default coefficients.
  106.      * The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
  107.      * for both gamma and sigma.
  108.      *
  109.      * @param steps Steps along the canonical axes representing box edges.
  110.      * They may be negative but not zero. See
  111.      */
  112.     public NelderMeadSimplex(final double[] steps) {
  113.         this(steps, DEFAULT_RHO, DEFAULT_KHI, DEFAULT_GAMMA, DEFAULT_SIGMA);
  114.     }

  115.     /**
  116.      * Build a Nelder-Mead simplex with specified coefficients.
  117.      *
  118.      * @param steps Steps along the canonical axes representing box edges.
  119.      * They may be negative but not zero. See
  120.      * {@link AbstractSimplex#AbstractSimplex(double[])}.
  121.      * @param rho Reflection coefficient.
  122.      * @param khi Expansion coefficient.
  123.      * @param gamma Contraction coefficient.
  124.      * @param sigma Shrinkage coefficient.
  125.      * @throws IllegalArgumentException if one of the steps is zero.
  126.      */
  127.     public NelderMeadSimplex(final double[] steps,
  128.                              final double rho, final double khi,
  129.                              final double gamma, final double sigma) {
  130.         super(steps);

  131.         this.rho = rho;
  132.         this.khi = khi;
  133.         this.gamma = gamma;
  134.         this.sigma = sigma;
  135.     }

  136.     /**
  137.      * Build a Nelder-Mead simplex with default coefficients.
  138.      * The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
  139.      * for both gamma and sigma.
  140.      *
  141.      * @param referenceSimplex Reference simplex. See
  142.      * {@link AbstractSimplex#AbstractSimplex(double[][])}.
  143.      */
  144.     public NelderMeadSimplex(final double[][] referenceSimplex) {
  145.         this(referenceSimplex, DEFAULT_RHO, DEFAULT_KHI, DEFAULT_GAMMA, DEFAULT_SIGMA);
  146.     }

  147.     /**
  148.      * Build a Nelder-Mead simplex with specified coefficients.
  149.      *
  150.      * @param referenceSimplex Reference simplex. See
  151.      * {@link AbstractSimplex#AbstractSimplex(double[][])}.
  152.      * @param rho Reflection coefficient.
  153.      * @param khi Expansion coefficient.
  154.      * @param gamma Contraction coefficient.
  155.      * @param sigma Shrinkage coefficient.
  156.      * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException
  157.      * if the reference simplex does not contain at least one point.
  158.      * @throws org.apache.commons.math3.exception.DimensionMismatchException
  159.      * if there is a dimension mismatch in the reference simplex.
  160.      */
  161.     public NelderMeadSimplex(final double[][] referenceSimplex,
  162.                              final double rho, final double khi,
  163.                              final double gamma, final double sigma) {
  164.         super(referenceSimplex);

  165.         this.rho = rho;
  166.         this.khi = khi;
  167.         this.gamma = gamma;
  168.         this.sigma = sigma;
  169.     }

  170.     /** {@inheritDoc} */
  171.     @Override
  172.     public void iterate(final MultivariateFunction evaluationFunction,
  173.                         final Comparator<PointValuePair> comparator) {
  174.         // The simplex has n + 1 points if dimension is n.
  175.         final int n = getDimension();

  176.         // Interesting values.
  177.         final PointValuePair best = getPoint(0);
  178.         final PointValuePair secondBest = getPoint(n - 1);
  179.         final PointValuePair worst = getPoint(n);
  180.         final double[] xWorst = worst.getPointRef();

  181.         // Compute the centroid of the best vertices (dismissing the worst
  182.         // point at index n).
  183.         final double[] centroid = new double[n];
  184.         for (int i = 0; i < n; i++) {
  185.             final double[] x = getPoint(i).getPointRef();
  186.             for (int j = 0; j < n; j++) {
  187.                 centroid[j] += x[j];
  188.             }
  189.         }
  190.         final double scaling = 1.0 / n;
  191.         for (int j = 0; j < n; j++) {
  192.             centroid[j] *= scaling;
  193.         }

  194.         // compute the reflection point
  195.         final double[] xR = new double[n];
  196.         for (int j = 0; j < n; j++) {
  197.             xR[j] = centroid[j] + rho * (centroid[j] - xWorst[j]);
  198.         }
  199.         final PointValuePair reflected
  200.             = new PointValuePair(xR, evaluationFunction.value(xR), false);

  201.         if (comparator.compare(best, reflected) <= 0 &&
  202.             comparator.compare(reflected, secondBest) < 0) {
  203.             // Accept the reflected point.
  204.             replaceWorstPoint(reflected, comparator);
  205.         } else if (comparator.compare(reflected, best) < 0) {
  206.             // Compute the expansion point.
  207.             final double[] xE = new double[n];
  208.             for (int j = 0; j < n; j++) {
  209.                 xE[j] = centroid[j] + khi * (xR[j] - centroid[j]);
  210.             }
  211.             final PointValuePair expanded
  212.                 = new PointValuePair(xE, evaluationFunction.value(xE), false);

  213.             if (comparator.compare(expanded, reflected) < 0) {
  214.                 // Accept the expansion point.
  215.                 replaceWorstPoint(expanded, comparator);
  216.             } else {
  217.                 // Accept the reflected point.
  218.                 replaceWorstPoint(reflected, comparator);
  219.             }
  220.         } else {
  221.             if (comparator.compare(reflected, worst) < 0) {
  222.                 // Perform an outside contraction.
  223.                 final double[] xC = new double[n];
  224.                 for (int j = 0; j < n; j++) {
  225.                     xC[j] = centroid[j] + gamma * (xR[j] - centroid[j]);
  226.                 }
  227.                 final PointValuePair outContracted
  228.                     = new PointValuePair(xC, evaluationFunction.value(xC), false);
  229.                 if (comparator.compare(outContracted, reflected) <= 0) {
  230.                     // Accept the contraction point.
  231.                     replaceWorstPoint(outContracted, comparator);
  232.                     return;
  233.                 }
  234.             } else {
  235.                 // Perform an inside contraction.
  236.                 final double[] xC = new double[n];
  237.                 for (int j = 0; j < n; j++) {
  238.                     xC[j] = centroid[j] - gamma * (centroid[j] - xWorst[j]);
  239.                 }
  240.                 final PointValuePair inContracted
  241.                     = new PointValuePair(xC, evaluationFunction.value(xC), false);

  242.                 if (comparator.compare(inContracted, worst) < 0) {
  243.                     // Accept the contraction point.
  244.                     replaceWorstPoint(inContracted, comparator);
  245.                     return;
  246.                 }
  247.             }

  248.             // Perform a shrink.
  249.             final double[] xSmallest = getPoint(0).getPointRef();
  250.             for (int i = 1; i <= n; i++) {
  251.                 final double[] x = getPoint(i).getPoint();
  252.                 for (int j = 0; j < n; j++) {
  253.                     x[j] = xSmallest[j] + sigma * (x[j] - xSmallest[j]);
  254.                 }
  255.                 setPoint(i, new PointValuePair(x, Double.NaN, false));
  256.             }
  257.             evaluate(evaluationFunction, comparator);
  258.         }
  259.     }
  260. }