Neuron.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math3.ml.neuralnet;

  18. import java.io.Serializable;
  19. import java.io.ObjectInputStream;
  20. import java.util.concurrent.atomic.AtomicReference;
  21. import org.apache.commons.math3.util.Precision;
  22. import org.apache.commons.math3.exception.DimensionMismatchException;


  23. /**
  24.  * Describes a neuron element of a neural network.
  25.  *
  26.  * This class aims to be thread-safe.
  27.  *
  28.  * @since 3.3
  29.  */
  30. public class Neuron implements Serializable {
  31.     /** Serializable. */
  32.     private static final long serialVersionUID = 20130207L;
  33.     /** Identifier. */
  34.     private final long identifier;
  35.     /** Length of the feature set. */
  36.     private final int size;
  37.     /** Neuron data. */
  38.     private final AtomicReference<double[]> features;

  39.     /**
  40.      * Creates a neuron.
  41.      * The size of the feature set is fixed to the length of the given
  42.      * argument.
  43.      * <br/>
  44.      * Constructor is package-private: Neurons must be
  45.      * {@link Network#createNeuron(double[]) created} by the network
  46.      * instance to which they will belong.
  47.      *
  48.      * @param identifier Identifier (assigned by the {@link Network}).
  49.      * @param features Initial values of the feature set.
  50.      */
  51.     Neuron(long identifier,
  52.            double[] features) {
  53.         this.identifier = identifier;
  54.         this.size = features.length;
  55.         this.features = new AtomicReference<double[]>(features.clone());
  56.     }

  57.     /**
  58.      * Gets the neuron's identifier.
  59.      *
  60.      * @return the identifier.
  61.      */
  62.     public long getIdentifier() {
  63.         return identifier;
  64.     }

  65.     /**
  66.      * Gets the length of the feature set.
  67.      *
  68.      * @return the number of features.
  69.      */
  70.     public int getSize() {
  71.         return size;
  72.     }

  73.     /**
  74.      * Gets the neuron's features.
  75.      *
  76.      * @return a copy of the neuron's features.
  77.      */
  78.     public double[] getFeatures() {
  79.         return features.get().clone();
  80.     }

  81.     /**
  82.      * Tries to atomically update the neuron's features.
  83.      * Update will be performed only if the expected values match the
  84.      * current values.<br/>
  85.      * In effect, when concurrent threads call this method, the state
  86.      * could be modified by one, so that it does not correspond to the
  87.      * the state assumed by another.
  88.      * Typically, a caller {@link #getFeatures() retrieves the current state},
  89.      * and uses it to compute the new state.
  90.      * During this computation, another thread might have done the same
  91.      * thing, and updated the state: If the current thread were to proceed
  92.      * with its own update, it would overwrite the new state (which might
  93.      * already have been used by yet other threads).
  94.      * To prevent this, the method does not perform the update when a
  95.      * concurrent modification has been detected, and returns {@code false}.
  96.      * When this happens, the caller should fetch the new current state,
  97.      * redo its computation, and call this method again.
  98.      *
  99.      * @param expect Current values of the features, as assumed by the caller.
  100.      * Update will never succeed if the contents of this array does not match
  101.      * the values returned by {@link #getFeatures()}.
  102.      * @param update Features's new values.
  103.      * @return {@code true} if the update was successful, {@code false}
  104.      * otherwise.
  105.      * @throws DimensionMismatchException if the length of {@code update} is
  106.      * not the same as specified in the {@link #Neuron(long,double[])
  107.      * constructor}.
  108.      */
  109.     public boolean compareAndSetFeatures(double[] expect,
  110.                                          double[] update) {
  111.         if (update.length != size) {
  112.             throw new DimensionMismatchException(update.length, size);
  113.         }

  114.         // Get the internal reference. Note that this must not be a copy;
  115.         // otherwise the "compareAndSet" below will always fail.
  116.         final double[] current = features.get();
  117.         if (!containSameValues(current, expect)) {
  118.             // Some other thread already modified the state.
  119.             return false;
  120.         }

  121.         if (features.compareAndSet(current, update.clone())) {
  122.             // The current thread could atomically update the state.
  123.             return true;
  124.         } else {
  125.             // Some other thread came first.
  126.             return false;
  127.         }
  128.     }

  129.     /**
  130.      * Checks whether the contents of both arrays is the same.
  131.      *
  132.      * @param current Current values.
  133.      * @param expect Expected values.
  134.      * @throws DimensionMismatchException if the length of {@code expected}
  135.      * is not the same as specified in the {@link #Neuron(long,double[])
  136.      * constructor}.
  137.      * @return {@code true} if the arrays contain the same values.
  138.      */
  139.     private boolean containSameValues(double[] current,
  140.                                       double[] expect) {
  141.         if (expect.length != size) {
  142.             throw new DimensionMismatchException(expect.length, size);
  143.         }

  144.         for (int i = 0; i < size; i++) {
  145.             if (!Precision.equals(current[i], expect[i])) {
  146.                 return false;
  147.             }
  148.         }
  149.         return true;
  150.     }

  151.     /**
  152.      * Prevents proxy bypass.
  153.      *
  154.      * @param in Input stream.
  155.      */
  156.     private void readObject(ObjectInputStream in) {
  157.         throw new IllegalStateException();
  158.     }

  159.     /**
  160.      * Custom serialization.
  161.      *
  162.      * @return the proxy instance that will be actually serialized.
  163.      */
  164.     private Object writeReplace() {
  165.         return new SerializationProxy(identifier,
  166.                                       features.get());
  167.     }

  168.     /**
  169.      * Serialization.
  170.      */
  171.     private static class SerializationProxy implements Serializable {
  172.         /** Serializable. */
  173.         private static final long serialVersionUID = 20130207L;
  174.         /** Features. */
  175.         private final double[] features;
  176.         /** Identifier. */
  177.         private final long identifier;

  178.         /**
  179.          * @param identifier Identifier.
  180.          * @param features Features.
  181.          */
  182.         SerializationProxy(long identifier,
  183.                            double[] features) {
  184.             this.identifier = identifier;
  185.             this.features = features;
  186.         }

  187.         /**
  188.          * Custom serialization.
  189.          *
  190.          * @return the {@link Neuron} for which this instance is the proxy.
  191.          */
  192.         private Object readResolve() {
  193.             return new Neuron(identifier,
  194.                               features);
  195.         }
  196.     }
  197. }