Network.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math3.ml.neuralnet;

  18. import java.io.Serializable;
  19. import java.io.ObjectInputStream;
  20. import java.util.NoSuchElementException;
  21. import java.util.List;
  22. import java.util.ArrayList;
  23. import java.util.Set;
  24. import java.util.HashSet;
  25. import java.util.Collection;
  26. import java.util.Iterator;
  27. import java.util.Comparator;
  28. import java.util.Collections;
  29. import java.util.concurrent.ConcurrentHashMap;
  30. import java.util.concurrent.atomic.AtomicLong;
  31. import org.apache.commons.math3.exception.DimensionMismatchException;
  32. import org.apache.commons.math3.exception.MathIllegalStateException;

  33. /**
  34.  * Neural network, composed of {@link Neuron} instances and the links
  35.  * between them.
  36.  *
  37.  * Although updating a neuron's state is thread-safe, modifying the
  38.  * network's topology (adding or removing links) is not.
  39.  *
  40.  * @since 3.3
  41.  */
  42. public class Network
  43.     implements Iterable<Neuron>,
  44.                Serializable {
  45.     /** Serializable. */
  46.     private static final long serialVersionUID = 20130207L;
  47.     /** Neurons. */
  48.     private final ConcurrentHashMap<Long, Neuron> neuronMap
  49.         = new ConcurrentHashMap<Long, Neuron>();
  50.     /** Next available neuron identifier. */
  51.     private final AtomicLong nextId;
  52.     /** Neuron's features set size. */
  53.     private final int featureSize;
  54.     /** Links. */
  55.     private final ConcurrentHashMap<Long, Set<Long>> linkMap
  56.         = new ConcurrentHashMap<Long, Set<Long>>();

  57.     /**
  58.      * Comparator that prescribes an order of the neurons according
  59.      * to the increasing order of their identifier.
  60.      */
  61.     public static class NeuronIdentifierComparator
  62.         implements Comparator<Neuron>,
  63.                    Serializable {
  64.         /** Version identifier. */
  65.         private static final long serialVersionUID = 20130207L;

  66.         /** {@inheritDoc} */
  67.         public int compare(Neuron a,
  68.                            Neuron b) {
  69.             final long aId = a.getIdentifier();
  70.             final long bId = b.getIdentifier();
  71.             return aId < bId ? -1 :
  72.                 aId > bId ? 1 : 0;
  73.         }
  74.     }

  75.     /**
  76.      * Constructor with restricted access, solely used for deserialization.
  77.      *
  78.      * @param nextId Next available identifier.
  79.      * @param featureSize Number of features.
  80.      * @param neuronList Neurons.
  81.      * @param neighbourIdList Links associated to each of the neurons in
  82.      * {@code neuronList}.
  83.      * @throws MathIllegalStateException if an inconsistency is detected
  84.      * (which probably means that the serialized form has been corrupted).
  85.      */
  86.     Network(long nextId,
  87.             int featureSize,
  88.             Neuron[] neuronList,
  89.             long[][] neighbourIdList) {
  90.         final int numNeurons = neuronList.length;
  91.         if (numNeurons != neighbourIdList.length) {
  92.             throw new MathIllegalStateException();
  93.         }

  94.         for (int i = 0; i < numNeurons; i++) {
  95.             final Neuron n = neuronList[i];
  96.             final long id = n.getIdentifier();
  97.             if (id >= nextId) {
  98.                 throw new MathIllegalStateException();
  99.             }
  100.             neuronMap.put(id, n);
  101.             linkMap.put(id, new HashSet<Long>());
  102.         }

  103.         for (int i = 0; i < numNeurons; i++) {
  104.             final long aId = neuronList[i].getIdentifier();
  105.             final Set<Long> aLinks = linkMap.get(aId);
  106.             for (Long bId : neighbourIdList[i]) {
  107.                 if (neuronMap.get(bId) == null) {
  108.                     throw new MathIllegalStateException();
  109.                 }
  110.                 addLinkToLinkSet(aLinks, bId);
  111.             }
  112.         }

  113.         this.nextId = new AtomicLong(nextId);
  114.         this.featureSize = featureSize;
  115.     }

  116.     /**
  117.      * @param initialIdentifier Identifier for the first neuron that
  118.      * will be added to this network.
  119.      * @param featureSize Size of the neuron's features.
  120.      */
  121.     public Network(long initialIdentifier,
  122.                    int featureSize) {
  123.         nextId = new AtomicLong(initialIdentifier);
  124.         this.featureSize = featureSize;
  125.     }

  126.     /**
  127.      * {@inheritDoc}
  128.      */
  129.     public Iterator<Neuron> iterator() {
  130.         return neuronMap.values().iterator();
  131.     }

  132.     /**
  133.      * Creates a list of the neurons, sorted in a custom order.
  134.      *
  135.      * @param comparator {@link Comparator} used for sorting the neurons.
  136.      * @return a list of neurons, sorted in the order prescribed by the
  137.      * given {@code comparator}.
  138.      * @see NeuronIdentifierComparator
  139.      */
  140.     public Collection<Neuron> getNeurons(Comparator<Neuron> comparator) {
  141.         final List<Neuron> neurons = new ArrayList<Neuron>();
  142.         neurons.addAll(neuronMap.values());

  143.         Collections.sort(neurons, comparator);

  144.         return neurons;
  145.     }

  146.     /**
  147.      * Creates a neuron and assigns it a unique identifier.
  148.      *
  149.      * @param features Initial values for the neuron's features.
  150.      * @return the neuron's identifier.
  151.      * @throws DimensionMismatchException if the length of {@code features}
  152.      * is different from the expected size (as set by the
  153.      * {@link #Network(long,int) constructor}).
  154.      */
  155.     public long createNeuron(double[] features) {
  156.         if (features.length != featureSize) {
  157.             throw new DimensionMismatchException(features.length, featureSize);
  158.         }

  159.         final long id = createNextId();
  160.         neuronMap.put(id, new Neuron(id, features));
  161.         linkMap.put(id, new HashSet<Long>());
  162.         return id;
  163.     }

  164.     /**
  165.      * Deletes a neuron.
  166.      * Links from all neighbours to the removed neuron will also be
  167.      * {@link #deleteLink(Neuron,Neuron) deleted}.
  168.      *
  169.      * @param neuron Neuron to be removed from this network.
  170.      * @throws NoSuchElementException if {@code n} does not belong to
  171.      * this network.
  172.      */
  173.     public void deleteNeuron(Neuron neuron) {
  174.         final Collection<Neuron> neighbours = getNeighbours(neuron);

  175.         // Delete links to from neighbours.
  176.         for (Neuron n : neighbours) {
  177.             deleteLink(n, neuron);
  178.         }

  179.         // Remove neuron.
  180.         neuronMap.remove(neuron.getIdentifier());
  181.     }

  182.     /**
  183.      * Gets the size of the neurons' features set.
  184.      *
  185.      * @return the size of the features set.
  186.      */
  187.     public int getFeaturesSize() {
  188.         return featureSize;
  189.     }

  190.     /**
  191.      * Adds a link from neuron {@code a} to neuron {@code b}.
  192.      * Note: the link is not bi-directional; if a bi-directional link is
  193.      * required, an additional call must be made with {@code a} and
  194.      * {@code b} exchanged in the argument list.
  195.      *
  196.      * @param a Neuron.
  197.      * @param b Neuron.
  198.      * @throws NoSuchElementException if the neurons do not exist in the
  199.      * network.
  200.      */
  201.     public void addLink(Neuron a,
  202.                         Neuron b) {
  203.         final long aId = a.getIdentifier();
  204.         final long bId = b.getIdentifier();

  205.         // Check that the neurons belong to this network.
  206.         if (a != getNeuron(aId)) {
  207.             throw new NoSuchElementException(Long.toString(aId));
  208.         }
  209.         if (b != getNeuron(bId)) {
  210.             throw new NoSuchElementException(Long.toString(bId));
  211.         }

  212.         // Add link from "a" to "b".
  213.         addLinkToLinkSet(linkMap.get(aId), bId);
  214.     }

  215.     /**
  216.      * Adds a link to neuron {@code id} in given {@code linkSet}.
  217.      * Note: no check verifies that the identifier indeed belongs
  218.      * to this network.
  219.      *
  220.      * @param linkSet Neuron identifier.
  221.      * @param id Neuron identifier.
  222.      */
  223.     private void addLinkToLinkSet(Set<Long> linkSet,
  224.                                   long id) {
  225.         linkSet.add(id);
  226.     }

  227.     /**
  228.      * Deletes the link between neurons {@code a} and {@code b}.
  229.      *
  230.      * @param a Neuron.
  231.      * @param b Neuron.
  232.      * @throws NoSuchElementException if the neurons do not exist in the
  233.      * network.
  234.      */
  235.     public void deleteLink(Neuron a,
  236.                            Neuron b) {
  237.         final long aId = a.getIdentifier();
  238.         final long bId = b.getIdentifier();

  239.         // Check that the neurons belong to this network.
  240.         if (a != getNeuron(aId)) {
  241.             throw new NoSuchElementException(Long.toString(aId));
  242.         }
  243.         if (b != getNeuron(bId)) {
  244.             throw new NoSuchElementException(Long.toString(bId));
  245.         }

  246.         // Delete link from "a" to "b".
  247.         deleteLinkFromLinkSet(linkMap.get(aId), bId);
  248.     }

  249.     /**
  250.      * Deletes a link to neuron {@code id} in given {@code linkSet}.
  251.      * Note: no check verifies that the identifier indeed belongs
  252.      * to this network.
  253.      *
  254.      * @param linkSet Neuron identifier.
  255.      * @param id Neuron identifier.
  256.      */
  257.     private void deleteLinkFromLinkSet(Set<Long> linkSet,
  258.                                        long id) {
  259.         linkSet.remove(id);
  260.     }

  261.     /**
  262.      * Retrieves the neuron with the given (unique) {@code id}.
  263.      *
  264.      * @param id Identifier.
  265.      * @return the neuron associated with the given {@code id}.
  266.      * @throws NoSuchElementException if the neuron does not exist in the
  267.      * network.
  268.      */
  269.     public Neuron getNeuron(long id) {
  270.         final Neuron n = neuronMap.get(id);
  271.         if (n == null) {
  272.             throw new NoSuchElementException(Long.toString(id));
  273.         }
  274.         return n;
  275.     }

  276.     /**
  277.      * Retrieves the neurons in the neighbourhood of any neuron in the
  278.      * {@code neurons} list.
  279.      * @param neurons Neurons for which to retrieve the neighbours.
  280.      * @return the list of neighbours.
  281.      * @see #getNeighbours(Iterable,Iterable)
  282.      */
  283.     public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons) {
  284.         return getNeighbours(neurons, null);
  285.     }

  286.     /**
  287.      * Retrieves the neurons in the neighbourhood of any neuron in the
  288.      * {@code neurons} list.
  289.      * The {@code exclude} list allows to retrieve the "concentric"
  290.      * neighbourhoods by removing the neurons that belong to the inner
  291.      * "circles".
  292.      *
  293.      * @param neurons Neurons for which to retrieve the neighbours.
  294.      * @param exclude Neurons to exclude from the returned list.
  295.      * Can be {@code null}.
  296.      * @return the list of neighbours.
  297.      */
  298.     public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons,
  299.                                             Iterable<Neuron> exclude) {
  300.         final Set<Long> idList = new HashSet<Long>();

  301.         for (Neuron n : neurons) {
  302.             idList.addAll(linkMap.get(n.getIdentifier()));
  303.         }
  304.         if (exclude != null) {
  305.             for (Neuron n : exclude) {
  306.                 idList.remove(n.getIdentifier());
  307.             }
  308.         }

  309.         final List<Neuron> neuronList = new ArrayList<Neuron>();
  310.         for (Long id : idList) {
  311.             neuronList.add(getNeuron(id));
  312.         }

  313.         return neuronList;
  314.     }

  315.     /**
  316.      * Retrieves the neighbours of the given neuron.
  317.      *
  318.      * @param neuron Neuron for which to retrieve the neighbours.
  319.      * @return the list of neighbours.
  320.      * @see #getNeighbours(Neuron,Iterable)
  321.      */
  322.     public Collection<Neuron> getNeighbours(Neuron neuron) {
  323.         return getNeighbours(neuron, null);
  324.     }

  325.     /**
  326.      * Retrieves the neighbours of the given neuron.
  327.      *
  328.      * @param neuron Neuron for which to retrieve the neighbours.
  329.      * @param exclude Neurons to exclude from the returned list.
  330.      * Can be {@code null}.
  331.      * @return the list of neighbours.
  332.      */
  333.     public Collection<Neuron> getNeighbours(Neuron neuron,
  334.                                             Iterable<Neuron> exclude) {
  335.         final Set<Long> idList = linkMap.get(neuron.getIdentifier());
  336.         if (exclude != null) {
  337.             for (Neuron n : exclude) {
  338.                 idList.remove(n.getIdentifier());
  339.             }
  340.         }

  341.         final List<Neuron> neuronList = new ArrayList<Neuron>();
  342.         for (Long id : idList) {
  343.             neuronList.add(getNeuron(id));
  344.         }

  345.         return neuronList;
  346.     }

  347.     /**
  348.      * Creates a neuron identifier.
  349.      *
  350.      * @return a value that will serve as a unique identifier.
  351.      */
  352.     private Long createNextId() {
  353.         return nextId.getAndIncrement();
  354.     }

  355.     /**
  356.      * Prevents proxy bypass.
  357.      *
  358.      * @param in Input stream.
  359.      */
  360.     private void readObject(ObjectInputStream in) {
  361.         throw new IllegalStateException();
  362.     }

  363.     /**
  364.      * Custom serialization.
  365.      *
  366.      * @return the proxy instance that will be actually serialized.
  367.      */
  368.     private Object writeReplace() {
  369.         final Neuron[] neuronList = neuronMap.values().toArray(new Neuron[0]);
  370.         final long[][] neighbourIdList = new long[neuronList.length][];

  371.         for (int i = 0; i < neuronList.length; i++) {
  372.             final Collection<Neuron> neighbours = getNeighbours(neuronList[i]);
  373.             final long[] neighboursId = new long[neighbours.size()];
  374.             int count = 0;
  375.             for (Neuron n : neighbours) {
  376.                 neighboursId[count] = n.getIdentifier();
  377.                 ++count;
  378.             }
  379.             neighbourIdList[i] = neighboursId;
  380.         }

  381.         return new SerializationProxy(nextId.get(),
  382.                                       featureSize,
  383.                                       neuronList,
  384.                                       neighbourIdList);
  385.     }

  386.     /**
  387.      * Serialization.
  388.      */
  389.     private static class SerializationProxy implements Serializable {
  390.         /** Serializable. */
  391.         private static final long serialVersionUID = 20130207L;
  392.         /** Next identifier. */
  393.         private final long nextId;
  394.         /** Number of features. */
  395.         private final int featureSize;
  396.         /** Neurons. */
  397.         private final Neuron[] neuronList;
  398.         /** Links. */
  399.         private final long[][] neighbourIdList;

  400.         /**
  401.          * @param nextId Next available identifier.
  402.          * @param featureSize Number of features.
  403.          * @param neuronList Neurons.
  404.          * @param neighbourIdList Links associated to each of the neurons in
  405.          * {@code neuronList}.
  406.          */
  407.         SerializationProxy(long nextId,
  408.                            int featureSize,
  409.                            Neuron[] neuronList,
  410.                            long[][] neighbourIdList) {
  411.             this.nextId = nextId;
  412.             this.featureSize = featureSize;
  413.             this.neuronList = neuronList;
  414.             this.neighbourIdList = neighbourIdList;
  415.         }

  416.         /**
  417.          * Custom serialization.
  418.          *
  419.          * @return the {@link Network} for which this instance is the proxy.
  420.          */
  421.         private Object readResolve() {
  422.             return new Network(nextId,
  423.                                featureSize,
  424.                                neuronList,
  425.                                neighbourIdList);
  426.         }
  427.     }
  428. }