001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.apache.commons.math3.ml.neuralnet; 019 020import java.io.Serializable; 021import java.io.ObjectInputStream; 022import java.util.NoSuchElementException; 023import java.util.List; 024import java.util.ArrayList; 025import java.util.Set; 026import java.util.HashSet; 027import java.util.Collection; 028import java.util.Iterator; 029import java.util.Comparator; 030import java.util.Collections; 031import java.util.concurrent.ConcurrentHashMap; 032import java.util.concurrent.atomic.AtomicLong; 033import org.apache.commons.math3.exception.DimensionMismatchException; 034import org.apache.commons.math3.exception.MathIllegalStateException; 035 036/** 037 * Neural network, composed of {@link Neuron} instances and the links 038 * between them. 039 * 040 * Although updating a neuron's state is thread-safe, modifying the 041 * network's topology (adding or removing links) is not. 042 * 043 * @since 3.3 044 */ 045public class Network 046 implements Iterable<Neuron>, 047 Serializable { 048 /** Serializable. */ 049 private static final long serialVersionUID = 20130207L; 050 /** Neurons. */ 051 private final ConcurrentHashMap<Long, Neuron> neuronMap 052 = new ConcurrentHashMap<Long, Neuron>(); 053 /** Next available neuron identifier. */ 054 private final AtomicLong nextId; 055 /** Neuron's features set size. */ 056 private final int featureSize; 057 /** Links. */ 058 private final ConcurrentHashMap<Long, Set<Long>> linkMap 059 = new ConcurrentHashMap<Long, Set<Long>>(); 060 061 /** 062 * Comparator that prescribes an order of the neurons according 063 * to the increasing order of their identifier. 064 */ 065 public static class NeuronIdentifierComparator 066 implements Comparator<Neuron>, 067 Serializable { 068 /** Version identifier. */ 069 private static final long serialVersionUID = 20130207L; 070 071 /** {@inheritDoc} */ 072 public int compare(Neuron a, 073 Neuron b) { 074 final long aId = a.getIdentifier(); 075 final long bId = b.getIdentifier(); 076 return aId < bId ? -1 : 077 aId > bId ? 1 : 0; 078 } 079 } 080 081 /** 082 * Constructor with restricted access, solely used for deserialization. 083 * 084 * @param nextId Next available identifier. 085 * @param featureSize Number of features. 086 * @param neuronList Neurons. 087 * @param neighbourIdList Links associated to each of the neurons in 088 * {@code neuronList}. 089 * @throws MathIllegalStateException if an inconsistency is detected 090 * (which probably means that the serialized form has been corrupted). 091 */ 092 Network(long nextId, 093 int featureSize, 094 Neuron[] neuronList, 095 long[][] neighbourIdList) { 096 final int numNeurons = neuronList.length; 097 if (numNeurons != neighbourIdList.length) { 098 throw new MathIllegalStateException(); 099 } 100 101 for (int i = 0; i < numNeurons; i++) { 102 final Neuron n = neuronList[i]; 103 final long id = n.getIdentifier(); 104 if (id >= nextId) { 105 throw new MathIllegalStateException(); 106 } 107 neuronMap.put(id, n); 108 linkMap.put(id, new HashSet<Long>()); 109 } 110 111 for (int i = 0; i < numNeurons; i++) { 112 final long aId = neuronList[i].getIdentifier(); 113 final Set<Long> aLinks = linkMap.get(aId); 114 for (Long bId : neighbourIdList[i]) { 115 if (neuronMap.get(bId) == null) { 116 throw new MathIllegalStateException(); 117 } 118 addLinkToLinkSet(aLinks, bId); 119 } 120 } 121 122 this.nextId = new AtomicLong(nextId); 123 this.featureSize = featureSize; 124 } 125 126 /** 127 * @param initialIdentifier Identifier for the first neuron that 128 * will be added to this network. 129 * @param featureSize Size of the neuron's features. 130 */ 131 public Network(long initialIdentifier, 132 int featureSize) { 133 nextId = new AtomicLong(initialIdentifier); 134 this.featureSize = featureSize; 135 } 136 137 /** 138 * {@inheritDoc} 139 */ 140 public Iterator<Neuron> iterator() { 141 return neuronMap.values().iterator(); 142 } 143 144 /** 145 * Creates a list of the neurons, sorted in a custom order. 146 * 147 * @param comparator {@link Comparator} used for sorting the neurons. 148 * @return a list of neurons, sorted in the order prescribed by the 149 * given {@code comparator}. 150 * @see NeuronIdentifierComparator 151 */ 152 public Collection<Neuron> getNeurons(Comparator<Neuron> comparator) { 153 final List<Neuron> neurons = new ArrayList<Neuron>(); 154 neurons.addAll(neuronMap.values()); 155 156 Collections.sort(neurons, comparator); 157 158 return neurons; 159 } 160 161 /** 162 * Creates a neuron and assigns it a unique identifier. 163 * 164 * @param features Initial values for the neuron's features. 165 * @return the neuron's identifier. 166 * @throws DimensionMismatchException if the length of {@code features} 167 * is different from the expected size (as set by the 168 * {@link #Network(long,int) constructor}). 169 */ 170 public long createNeuron(double[] features) { 171 if (features.length != featureSize) { 172 throw new DimensionMismatchException(features.length, featureSize); 173 } 174 175 final long id = createNextId(); 176 neuronMap.put(id, new Neuron(id, features)); 177 linkMap.put(id, new HashSet<Long>()); 178 return id; 179 } 180 181 /** 182 * Deletes a neuron. 183 * Links from all neighbours to the removed neuron will also be 184 * {@link #deleteLink(Neuron,Neuron) deleted}. 185 * 186 * @param neuron Neuron to be removed from this network. 187 * @throws NoSuchElementException if {@code n} does not belong to 188 * this network. 189 */ 190 public void deleteNeuron(Neuron neuron) { 191 final Collection<Neuron> neighbours = getNeighbours(neuron); 192 193 // Delete links to from neighbours. 194 for (Neuron n : neighbours) { 195 deleteLink(n, neuron); 196 } 197 198 // Remove neuron. 199 neuronMap.remove(neuron.getIdentifier()); 200 } 201 202 /** 203 * Gets the size of the neurons' features set. 204 * 205 * @return the size of the features set. 206 */ 207 public int getFeaturesSize() { 208 return featureSize; 209 } 210 211 /** 212 * Adds a link from neuron {@code a} to neuron {@code b}. 213 * Note: the link is not bi-directional; if a bi-directional link is 214 * required, an additional call must be made with {@code a} and 215 * {@code b} exchanged in the argument list. 216 * 217 * @param a Neuron. 218 * @param b Neuron. 219 * @throws NoSuchElementException if the neurons do not exist in the 220 * network. 221 */ 222 public void addLink(Neuron a, 223 Neuron b) { 224 final long aId = a.getIdentifier(); 225 final long bId = b.getIdentifier(); 226 227 // Check that the neurons belong to this network. 228 if (a != getNeuron(aId)) { 229 throw new NoSuchElementException(Long.toString(aId)); 230 } 231 if (b != getNeuron(bId)) { 232 throw new NoSuchElementException(Long.toString(bId)); 233 } 234 235 // Add link from "a" to "b". 236 addLinkToLinkSet(linkMap.get(aId), bId); 237 } 238 239 /** 240 * Adds a link to neuron {@code id} in given {@code linkSet}. 241 * Note: no check verifies that the identifier indeed belongs 242 * to this network. 243 * 244 * @param linkSet Neuron identifier. 245 * @param id Neuron identifier. 246 */ 247 private void addLinkToLinkSet(Set<Long> linkSet, 248 long id) { 249 linkSet.add(id); 250 } 251 252 /** 253 * Deletes the link between neurons {@code a} and {@code b}. 254 * 255 * @param a Neuron. 256 * @param b Neuron. 257 * @throws NoSuchElementException if the neurons do not exist in the 258 * network. 259 */ 260 public void deleteLink(Neuron a, 261 Neuron b) { 262 final long aId = a.getIdentifier(); 263 final long bId = b.getIdentifier(); 264 265 // Check that the neurons belong to this network. 266 if (a != getNeuron(aId)) { 267 throw new NoSuchElementException(Long.toString(aId)); 268 } 269 if (b != getNeuron(bId)) { 270 throw new NoSuchElementException(Long.toString(bId)); 271 } 272 273 // Delete link from "a" to "b". 274 deleteLinkFromLinkSet(linkMap.get(aId), bId); 275 } 276 277 /** 278 * Deletes a link to neuron {@code id} in given {@code linkSet}. 279 * Note: no check verifies that the identifier indeed belongs 280 * to this network. 281 * 282 * @param linkSet Neuron identifier. 283 * @param id Neuron identifier. 284 */ 285 private void deleteLinkFromLinkSet(Set<Long> linkSet, 286 long id) { 287 linkSet.remove(id); 288 } 289 290 /** 291 * Retrieves the neuron with the given (unique) {@code id}. 292 * 293 * @param id Identifier. 294 * @return the neuron associated with the given {@code id}. 295 * @throws NoSuchElementException if the neuron does not exist in the 296 * network. 297 */ 298 public Neuron getNeuron(long id) { 299 final Neuron n = neuronMap.get(id); 300 if (n == null) { 301 throw new NoSuchElementException(Long.toString(id)); 302 } 303 return n; 304 } 305 306 /** 307 * Retrieves the neurons in the neighbourhood of any neuron in the 308 * {@code neurons} list. 309 * @param neurons Neurons for which to retrieve the neighbours. 310 * @return the list of neighbours. 311 * @see #getNeighbours(Iterable,Iterable) 312 */ 313 public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons) { 314 return getNeighbours(neurons, null); 315 } 316 317 /** 318 * Retrieves the neurons in the neighbourhood of any neuron in the 319 * {@code neurons} list. 320 * The {@code exclude} list allows to retrieve the "concentric" 321 * neighbourhoods by removing the neurons that belong to the inner 322 * "circles". 323 * 324 * @param neurons Neurons for which to retrieve the neighbours. 325 * @param exclude Neurons to exclude from the returned list. 326 * Can be {@code null}. 327 * @return the list of neighbours. 328 */ 329 public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons, 330 Iterable<Neuron> exclude) { 331 final Set<Long> idList = new HashSet<Long>(); 332 333 for (Neuron n : neurons) { 334 idList.addAll(linkMap.get(n.getIdentifier())); 335 } 336 if (exclude != null) { 337 for (Neuron n : exclude) { 338 idList.remove(n.getIdentifier()); 339 } 340 } 341 342 final List<Neuron> neuronList = new ArrayList<Neuron>(); 343 for (Long id : idList) { 344 neuronList.add(getNeuron(id)); 345 } 346 347 return neuronList; 348 } 349 350 /** 351 * Retrieves the neighbours of the given neuron. 352 * 353 * @param neuron Neuron for which to retrieve the neighbours. 354 * @return the list of neighbours. 355 * @see #getNeighbours(Neuron,Iterable) 356 */ 357 public Collection<Neuron> getNeighbours(Neuron neuron) { 358 return getNeighbours(neuron, null); 359 } 360 361 /** 362 * Retrieves the neighbours of the given neuron. 363 * 364 * @param neuron Neuron for which to retrieve the neighbours. 365 * @param exclude Neurons to exclude from the returned list. 366 * Can be {@code null}. 367 * @return the list of neighbours. 368 */ 369 public Collection<Neuron> getNeighbours(Neuron neuron, 370 Iterable<Neuron> exclude) { 371 final Set<Long> idList = linkMap.get(neuron.getIdentifier()); 372 if (exclude != null) { 373 for (Neuron n : exclude) { 374 idList.remove(n.getIdentifier()); 375 } 376 } 377 378 final List<Neuron> neuronList = new ArrayList<Neuron>(); 379 for (Long id : idList) { 380 neuronList.add(getNeuron(id)); 381 } 382 383 return neuronList; 384 } 385 386 /** 387 * Creates a neuron identifier. 388 * 389 * @return a value that will serve as a unique identifier. 390 */ 391 private Long createNextId() { 392 return nextId.getAndIncrement(); 393 } 394 395 /** 396 * Prevents proxy bypass. 397 * 398 * @param in Input stream. 399 */ 400 private void readObject(ObjectInputStream in) { 401 throw new IllegalStateException(); 402 } 403 404 /** 405 * Custom serialization. 406 * 407 * @return the proxy instance that will be actually serialized. 408 */ 409 private Object writeReplace() { 410 final Neuron[] neuronList = neuronMap.values().toArray(new Neuron[0]); 411 final long[][] neighbourIdList = new long[neuronList.length][]; 412 413 for (int i = 0; i < neuronList.length; i++) { 414 final Collection<Neuron> neighbours = getNeighbours(neuronList[i]); 415 final long[] neighboursId = new long[neighbours.size()]; 416 int count = 0; 417 for (Neuron n : neighbours) { 418 neighboursId[count] = n.getIdentifier(); 419 ++count; 420 } 421 neighbourIdList[i] = neighboursId; 422 } 423 424 return new SerializationProxy(nextId.get(), 425 featureSize, 426 neuronList, 427 neighbourIdList); 428 } 429 430 /** 431 * Serialization. 432 */ 433 private static class SerializationProxy implements Serializable { 434 /** Serializable. */ 435 private static final long serialVersionUID = 20130207L; 436 /** Next identifier. */ 437 private final long nextId; 438 /** Number of features. */ 439 private final int featureSize; 440 /** Neurons. */ 441 private final Neuron[] neuronList; 442 /** Links. */ 443 private final long[][] neighbourIdList; 444 445 /** 446 * @param nextId Next available identifier. 447 * @param featureSize Number of features. 448 * @param neuronList Neurons. 449 * @param neighbourIdList Links associated to each of the neurons in 450 * {@code neuronList}. 451 */ 452 SerializationProxy(long nextId, 453 int featureSize, 454 Neuron[] neuronList, 455 long[][] neighbourIdList) { 456 this.nextId = nextId; 457 this.featureSize = featureSize; 458 this.neuronList = neuronList; 459 this.neighbourIdList = neighbourIdList; 460 } 461 462 /** 463 * Custom serialization. 464 * 465 * @return the {@link Network} for which this instance is the proxy. 466 */ 467 private Object readResolve() { 468 return new Network(nextId, 469 featureSize, 470 neuronList, 471 neighbourIdList); 472 } 473 } 474}