Network.java
- /*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- package org.apache.commons.math3.ml.neuralnet;
- import java.io.Serializable;
- import java.io.ObjectInputStream;
- import java.util.NoSuchElementException;
- import java.util.List;
- import java.util.ArrayList;
- import java.util.Set;
- import java.util.HashSet;
- import java.util.Collection;
- import java.util.Iterator;
- import java.util.Comparator;
- import java.util.Collections;
- import java.util.concurrent.ConcurrentHashMap;
- import java.util.concurrent.atomic.AtomicLong;
- import org.apache.commons.math3.exception.DimensionMismatchException;
- import org.apache.commons.math3.exception.MathIllegalStateException;
- /**
- * Neural network, composed of {@link Neuron} instances and the links
- * between them.
- *
- * Although updating a neuron's state is thread-safe, modifying the
- * network's topology (adding or removing links) is not.
- *
- * @since 3.3
- */
- public class Network
- implements Iterable<Neuron>,
- Serializable {
- /** Serializable. */
- private static final long serialVersionUID = 20130207L;
- /** Neurons. */
- private final ConcurrentHashMap<Long, Neuron> neuronMap
- = new ConcurrentHashMap<Long, Neuron>();
- /** Next available neuron identifier. */
- private final AtomicLong nextId;
- /** Neuron's features set size. */
- private final int featureSize;
- /** Links. */
- private final ConcurrentHashMap<Long, Set<Long>> linkMap
- = new ConcurrentHashMap<Long, Set<Long>>();
- /**
- * Comparator that prescribes an order of the neurons according
- * to the increasing order of their identifier.
- */
- public static class NeuronIdentifierComparator
- implements Comparator<Neuron>,
- Serializable {
- /** Version identifier. */
- private static final long serialVersionUID = 20130207L;
- /** {@inheritDoc} */
- public int compare(Neuron a,
- Neuron b) {
- final long aId = a.getIdentifier();
- final long bId = b.getIdentifier();
- return aId < bId ? -1 :
- aId > bId ? 1 : 0;
- }
- }
- /**
- * Constructor with restricted access, solely used for deserialization.
- *
- * @param nextId Next available identifier.
- * @param featureSize Number of features.
- * @param neuronList Neurons.
- * @param neighbourIdList Links associated to each of the neurons in
- * {@code neuronList}.
- * @throws MathIllegalStateException if an inconsistency is detected
- * (which probably means that the serialized form has been corrupted).
- */
- Network(long nextId,
- int featureSize,
- Neuron[] neuronList,
- long[][] neighbourIdList) {
- final int numNeurons = neuronList.length;
- if (numNeurons != neighbourIdList.length) {
- throw new MathIllegalStateException();
- }
- for (int i = 0; i < numNeurons; i++) {
- final Neuron n = neuronList[i];
- final long id = n.getIdentifier();
- if (id >= nextId) {
- throw new MathIllegalStateException();
- }
- neuronMap.put(id, n);
- linkMap.put(id, new HashSet<Long>());
- }
- for (int i = 0; i < numNeurons; i++) {
- final long aId = neuronList[i].getIdentifier();
- final Set<Long> aLinks = linkMap.get(aId);
- for (Long bId : neighbourIdList[i]) {
- if (neuronMap.get(bId) == null) {
- throw new MathIllegalStateException();
- }
- addLinkToLinkSet(aLinks, bId);
- }
- }
- this.nextId = new AtomicLong(nextId);
- this.featureSize = featureSize;
- }
- /**
- * @param initialIdentifier Identifier for the first neuron that
- * will be added to this network.
- * @param featureSize Size of the neuron's features.
- */
- public Network(long initialIdentifier,
- int featureSize) {
- nextId = new AtomicLong(initialIdentifier);
- this.featureSize = featureSize;
- }
- /**
- * {@inheritDoc}
- */
- public Iterator<Neuron> iterator() {
- return neuronMap.values().iterator();
- }
- /**
- * Creates a list of the neurons, sorted in a custom order.
- *
- * @param comparator {@link Comparator} used for sorting the neurons.
- * @return a list of neurons, sorted in the order prescribed by the
- * given {@code comparator}.
- * @see NeuronIdentifierComparator
- */
- public Collection<Neuron> getNeurons(Comparator<Neuron> comparator) {
- final List<Neuron> neurons = new ArrayList<Neuron>();
- neurons.addAll(neuronMap.values());
- Collections.sort(neurons, comparator);
- return neurons;
- }
- /**
- * Creates a neuron and assigns it a unique identifier.
- *
- * @param features Initial values for the neuron's features.
- * @return the neuron's identifier.
- * @throws DimensionMismatchException if the length of {@code features}
- * is different from the expected size (as set by the
- * {@link #Network(long,int) constructor}).
- */
- public long createNeuron(double[] features) {
- if (features.length != featureSize) {
- throw new DimensionMismatchException(features.length, featureSize);
- }
- final long id = createNextId();
- neuronMap.put(id, new Neuron(id, features));
- linkMap.put(id, new HashSet<Long>());
- return id;
- }
- /**
- * Deletes a neuron.
- * Links from all neighbours to the removed neuron will also be
- * {@link #deleteLink(Neuron,Neuron) deleted}.
- *
- * @param neuron Neuron to be removed from this network.
- * @throws NoSuchElementException if {@code n} does not belong to
- * this network.
- */
- public void deleteNeuron(Neuron neuron) {
- final Collection<Neuron> neighbours = getNeighbours(neuron);
- // Delete links to from neighbours.
- for (Neuron n : neighbours) {
- deleteLink(n, neuron);
- }
- // Remove neuron.
- neuronMap.remove(neuron.getIdentifier());
- }
- /**
- * Gets the size of the neurons' features set.
- *
- * @return the size of the features set.
- */
- public int getFeaturesSize() {
- return featureSize;
- }
- /**
- * Adds a link from neuron {@code a} to neuron {@code b}.
- * Note: the link is not bi-directional; if a bi-directional link is
- * required, an additional call must be made with {@code a} and
- * {@code b} exchanged in the argument list.
- *
- * @param a Neuron.
- * @param b Neuron.
- * @throws NoSuchElementException if the neurons do not exist in the
- * network.
- */
- public void addLink(Neuron a,
- Neuron b) {
- final long aId = a.getIdentifier();
- final long bId = b.getIdentifier();
- // Check that the neurons belong to this network.
- if (a != getNeuron(aId)) {
- throw new NoSuchElementException(Long.toString(aId));
- }
- if (b != getNeuron(bId)) {
- throw new NoSuchElementException(Long.toString(bId));
- }
- // Add link from "a" to "b".
- addLinkToLinkSet(linkMap.get(aId), bId);
- }
- /**
- * Adds a link to neuron {@code id} in given {@code linkSet}.
- * Note: no check verifies that the identifier indeed belongs
- * to this network.
- *
- * @param linkSet Neuron identifier.
- * @param id Neuron identifier.
- */
- private void addLinkToLinkSet(Set<Long> linkSet,
- long id) {
- linkSet.add(id);
- }
- /**
- * Deletes the link between neurons {@code a} and {@code b}.
- *
- * @param a Neuron.
- * @param b Neuron.
- * @throws NoSuchElementException if the neurons do not exist in the
- * network.
- */
- public void deleteLink(Neuron a,
- Neuron b) {
- final long aId = a.getIdentifier();
- final long bId = b.getIdentifier();
- // Check that the neurons belong to this network.
- if (a != getNeuron(aId)) {
- throw new NoSuchElementException(Long.toString(aId));
- }
- if (b != getNeuron(bId)) {
- throw new NoSuchElementException(Long.toString(bId));
- }
- // Delete link from "a" to "b".
- deleteLinkFromLinkSet(linkMap.get(aId), bId);
- }
- /**
- * Deletes a link to neuron {@code id} in given {@code linkSet}.
- * Note: no check verifies that the identifier indeed belongs
- * to this network.
- *
- * @param linkSet Neuron identifier.
- * @param id Neuron identifier.
- */
- private void deleteLinkFromLinkSet(Set<Long> linkSet,
- long id) {
- linkSet.remove(id);
- }
- /**
- * Retrieves the neuron with the given (unique) {@code id}.
- *
- * @param id Identifier.
- * @return the neuron associated with the given {@code id}.
- * @throws NoSuchElementException if the neuron does not exist in the
- * network.
- */
- public Neuron getNeuron(long id) {
- final Neuron n = neuronMap.get(id);
- if (n == null) {
- throw new NoSuchElementException(Long.toString(id));
- }
- return n;
- }
- /**
- * Retrieves the neurons in the neighbourhood of any neuron in the
- * {@code neurons} list.
- * @param neurons Neurons for which to retrieve the neighbours.
- * @return the list of neighbours.
- * @see #getNeighbours(Iterable,Iterable)
- */
- public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons) {
- return getNeighbours(neurons, null);
- }
- /**
- * Retrieves the neurons in the neighbourhood of any neuron in the
- * {@code neurons} list.
- * The {@code exclude} list allows to retrieve the "concentric"
- * neighbourhoods by removing the neurons that belong to the inner
- * "circles".
- *
- * @param neurons Neurons for which to retrieve the neighbours.
- * @param exclude Neurons to exclude from the returned list.
- * Can be {@code null}.
- * @return the list of neighbours.
- */
- public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons,
- Iterable<Neuron> exclude) {
- final Set<Long> idList = new HashSet<Long>();
- for (Neuron n : neurons) {
- idList.addAll(linkMap.get(n.getIdentifier()));
- }
- if (exclude != null) {
- for (Neuron n : exclude) {
- idList.remove(n.getIdentifier());
- }
- }
- final List<Neuron> neuronList = new ArrayList<Neuron>();
- for (Long id : idList) {
- neuronList.add(getNeuron(id));
- }
- return neuronList;
- }
- /**
- * Retrieves the neighbours of the given neuron.
- *
- * @param neuron Neuron for which to retrieve the neighbours.
- * @return the list of neighbours.
- * @see #getNeighbours(Neuron,Iterable)
- */
- public Collection<Neuron> getNeighbours(Neuron neuron) {
- return getNeighbours(neuron, null);
- }
- /**
- * Retrieves the neighbours of the given neuron.
- *
- * @param neuron Neuron for which to retrieve the neighbours.
- * @param exclude Neurons to exclude from the returned list.
- * Can be {@code null}.
- * @return the list of neighbours.
- */
- public Collection<Neuron> getNeighbours(Neuron neuron,
- Iterable<Neuron> exclude) {
- final Set<Long> idList = linkMap.get(neuron.getIdentifier());
- if (exclude != null) {
- for (Neuron n : exclude) {
- idList.remove(n.getIdentifier());
- }
- }
- final List<Neuron> neuronList = new ArrayList<Neuron>();
- for (Long id : idList) {
- neuronList.add(getNeuron(id));
- }
- return neuronList;
- }
- /**
- * Creates a neuron identifier.
- *
- * @return a value that will serve as a unique identifier.
- */
- private Long createNextId() {
- return nextId.getAndIncrement();
- }
- /**
- * Prevents proxy bypass.
- *
- * @param in Input stream.
- */
- private void readObject(ObjectInputStream in) {
- throw new IllegalStateException();
- }
- /**
- * Custom serialization.
- *
- * @return the proxy instance that will be actually serialized.
- */
- private Object writeReplace() {
- final Neuron[] neuronList = neuronMap.values().toArray(new Neuron[0]);
- final long[][] neighbourIdList = new long[neuronList.length][];
- for (int i = 0; i < neuronList.length; i++) {
- final Collection<Neuron> neighbours = getNeighbours(neuronList[i]);
- final long[] neighboursId = new long[neighbours.size()];
- int count = 0;
- for (Neuron n : neighbours) {
- neighboursId[count] = n.getIdentifier();
- ++count;
- }
- neighbourIdList[i] = neighboursId;
- }
- return new SerializationProxy(nextId.get(),
- featureSize,
- neuronList,
- neighbourIdList);
- }
- /**
- * Serialization.
- */
- private static class SerializationProxy implements Serializable {
- /** Serializable. */
- private static final long serialVersionUID = 20130207L;
- /** Next identifier. */
- private final long nextId;
- /** Number of features. */
- private final int featureSize;
- /** Neurons. */
- private final Neuron[] neuronList;
- /** Links. */
- private final long[][] neighbourIdList;
- /**
- * @param nextId Next available identifier.
- * @param featureSize Number of features.
- * @param neuronList Neurons.
- * @param neighbourIdList Links associated to each of the neurons in
- * {@code neuronList}.
- */
- SerializationProxy(long nextId,
- int featureSize,
- Neuron[] neuronList,
- long[][] neighbourIdList) {
- this.nextId = nextId;
- this.featureSize = featureSize;
- this.neuronList = neuronList;
- this.neighbourIdList = neighbourIdList;
- }
- /**
- * Custom serialization.
- *
- * @return the {@link Network} for which this instance is the proxy.
- */
- private Object readResolve() {
- return new Network(nextId,
- featureSize,
- neuronList,
- neighbourIdList);
- }
- }
- }