Designing a Neural Network in Java From a Programmer's Perspective
Learn an approach to programming a neural network using Java in a simple and understandable way so that the code can be reused for various experiments.
According to Wikipedia:
Artificial neural networks (ANNs) or connectionist systems are computing systems inspired by the biological neural networks that constitute animal brains. Such systems learn (progressively improve performance) to do tasks by considering examples, generally without task-specific programming.
Designing a neural network in Java or any other programming language requires an understanding of the structure and functionality of artificial neural networks.
Artificial neural networks perform tasks such as pattern recognition, learning from data, and forecasting trends in much the same way a human expert can. This is different from the conventional algorithmic approach, which requires an explicit set of steps to be performed to achieve the defined goal. An ANN can learn how to solve a task by itself because of its highly interconnected network structure.
The artificial neuron has a structure similar to that of the neurons in the human brain. A natural neuron is composed of a nucleus, dendrites, and an axon. The axon branches out to form synapses with the dendrites of other neurons.
So far, we have identified the structure of a neuron and the network of connected neurons. Another important aspect is the processing, or calculation, performed by a neural network and, in particular, by a single neuron. Natural neurons are signal processors: they receive micro signals in the dendrites that can trigger a signal in the axon. When a threshold potential is reached, the axon fires and propagates the signal to other neurons. We can therefore think of an artificial neuron as having a signal collector at its inputs and an activation unit at its output that can trigger a signal to be forwarded to other neurons, as shown in the picture.
Furthermore, the connections between neurons have associated weights that modify the signals, thus influencing the neuron's output. Since the weights are internal to the neural network and influence its outputs, they can be considered the neural network's internal knowledge. Adjusting the weights of a neuron's connections to other neurons or to the external world is what determines the neural network's capabilities.
As stated by Bioinfo Publications:
The artificial neuron receives one or more inputs (representing dendrites) and sums them up to produce an output/activation (representing a neuron's axon). Usually, the sums of each node are weighted and the sum is passed through an activation function or transfer function.
This component adds nonlinearity to neural network processing, which is needed because the natural neuron has nonlinear behaviors. In some special cases, it can be a linear function.
A standard computer chip circuit can be seen as a digital network of activation functions that can be "ON" (1) or "OFF" (0), depending on input. This is similar to the behavior of the linear perceptron in neural networks. However, it is the nonlinear activation function that allows such networks to compute nontrivial problems using only a small number of nodes. Examples of popular activation functions used are Sigmoid, hyperbolic tangent, hard limiting threshold, and purely linear.
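To make the "weighted sum, then activation function" computation concrete, here is a minimal sketch of a single neuron that uses plain arrays rather than the classes developed later in this article. The class and method names (`SingleNeuronDemo`, `weightedSum`, `output`) are hypothetical; it applies two of the activation functions listed above, the sigmoid and the hard-limiting threshold, to the same weighted sum.

```java
import java.util.function.DoubleUnaryOperator;

/** Hypothetical sketch of a single artificial neuron: weighted sum, then activation. */
final class SingleNeuronDemo {

    /** Logistic sigmoid: an S-shaped curve between 0 and 1. */
    static final DoubleUnaryOperator SIGMOID = x -> 1.0 / (1.0 + Math.exp(-x));

    /** Hard-limiting threshold: returns 1 for non-negative input, otherwise 0. */
    static final DoubleUnaryOperator STEP = x -> x >= 0 ? 1.0 : 0.0;

    /** Signal collector: sum of inputs[i] * weights[i]. */
    static double weightedSum(double[] inputs, double[] weights) {
        if (inputs.length != weights.length) {
            throw new IllegalArgumentException("inputs and weights must have the same length");
        }
        double sum = 0.0;
        for (int i = 0; i < inputs.length; i++) {
            sum += inputs[i] * weights[i];
        }
        return sum;
    }

    /** Neuron output: the activation function applied to the weighted sum. */
    static double output(double[] inputs, double[] weights, DoubleUnaryOperator activation) {
        return activation.applyAsDouble(weightedSum(inputs, weights));
    }

    public static void main(String[] args) {
        double[] inputs = {1.0, 0.5, -1.0};
        double[] weights = {0.5, 1.0, 0.25};
        // 1.0 * 0.5 + 0.5 * 1.0 + (-1.0) * 0.25 = 0.75
        System.out.println(weightedSum(inputs, weights));     // 0.75
        System.out.println(output(inputs, weights, SIGMOID)); // about 0.679
        System.out.println(output(inputs, weights, STEP));    // 1.0
    }
}
```

The same weighted sum of 0.75 produces a graded value through the sigmoid but a binary "fire / don't fire" decision through the threshold function.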
Translating this knowledge into Java code, we will have a neuron class as follows:
import java.util.ArrayList;
import java.util.List;
import edu.neuralnet.core.activation.ActivationFunction;
import edu.neuralnet.core.input.InputSummingFunction;
/**
* Represents a neuron model composed of:
* <ul>
* <li>Summing part - input summing function</li>
* <li>Activation function</li>
* <li>Input connections</li>
* <li>Output connections</li>
* </ul>
*/
public class Neuron {
/**
* Neuron's identifier
*/
private String id;
/**
* Collection of neuron's input connections (connections to this neuron)
*/
protected List<Connection> inputConnections;
/**
* Collection of neuron's output connections (connections from this to other
* neurons)
*/
protected List<Connection> outputConnections;
/**
* Input summing function for this neuron
*/
protected InputSummingFunction inputSummingFunction;
/**
* Activation function for this neuron
*/
protected ActivationFunction activationFunction;
/**
* Default constructor
*/
public Neuron() {
this.inputConnections = new ArrayList<>();
this.outputConnections = new ArrayList<>();
}
/**
* Calculates the neuron's output: the input summing function combines the
* inputs, and the activation function is applied to the result.
*/
public double calculateOutput() {
double totalInput = inputSummingFunction.collectOutput(inputConnections);
return activationFunction.calculateOutput(totalInput);
}
...
}
The neuron has input and output connections, an input summing function, and an activation function, but where are the input weights? They are stored in the connection itself, as follows:
/**
* Represents a connection between two neurons and the associated weight.
*/
public class Connection {
/**
* From neuron for this connection (source neuron). This connection is an
* output connection of the from neuron.
*/
protected Neuron fromNeuron;
/**
* To neuron for this connection (target, destination neuron). This
* connection is an input connection of the to neuron.
*/
protected Neuron toNeuron;
/**
* Connection weight
*/
protected double weight;
/**
* Creates a new connection between specified neurons with a random weight.
*
* @param fromNeuron
* neuron to connect from
* @param toNeuron
* neuron to connect to
*/
public Connection(Neuron fromNeuron, Neuron toNeuron) {
this.fromNeuron = fromNeuron;
this.toNeuron = toNeuron;
// random initial weight in the range [0.0, 1.0)
this.weight = Math.random();
}
/**
* Creates a new connection between specified neurons with the specified weight.
*
* @param fromNeuron
* neuron to connect from
* @param toNeuron
* neuron to connect to
* @param weight
* weight for this connection
*/
public Connection(Neuron fromNeuron, Neuron toNeuron, double weight) {
this(fromNeuron, toNeuron);
this.weight = weight;
}
/**
* Returns weight for this connection
*
* @return weight for this connection
*/
public double getWeight() {
return weight;
}
/**
* Set the weight of the connection.
*
* @param weight
* The new weight of the connection to be set
*/
public void setWeight(double weight) {
this.weight = weight;
}
/**
* Returns input of this connection - the activation function result
* calculated in the input neuron of this connection.
*
* @return input received through this connection
*/
public double getInput() {
return fromNeuron.calculateOutput();
}
/**
* Returns the weighted input of this connection
*
* @return weighted input of the connection
*/
public double getWeightedInput() {
return fromNeuron.calculateOutput() * weight;
}
/**
* Gets from neuron for this connection
*
* @return from neuron for this connection
*/
public Neuron getFromNeuron() {
return fromNeuron;
}
/**
* Gets to neuron for this connection
*
* @return to neuron for this connection
*/
public Neuron getToNeuron() {
return toNeuron;
}
...
}
The connection object provides the weights and is responsible for calculating weighted inputs.
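Note that `getInput()` calls `calculateOutput()` on the source neuron, so evaluation is pull-based: asking an output neuron for its value recursively pulls values back through its input connections. The following self-contained sketch illustrates this with stripped-down, hypothetical `MiniNeuron` and `MiniConnection` classes. Two details are assumptions, because the article elides them: input neurons receive external values through a hypothetical `setInput` method, and a connection registers itself with its target neuron when it is created.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical, stripped-down stand-in for the article's Neuron class. */
class MiniNeuron {
    final List<MiniConnection> inputConnections = new ArrayList<>();
    // Assumption: input-layer neurons simply emit an externally supplied value.
    private Double externalInput;

    void setInput(double value) {
        this.externalInput = value;
    }

    double calculateOutput() {
        if (externalInput != null) {
            return externalInput;
        }
        double sum = 0.0;
        for (MiniConnection connection : inputConnections) {
            sum += connection.getWeightedInput(); // pulls the value from upstream
        }
        return 1.0 / (1.0 + Math.exp(-sum)); // sigmoid activation
    }
}

/** Hypothetical, stripped-down stand-in for the article's connection class. */
class MiniConnection {
    final MiniNeuron fromNeuron;
    final MiniNeuron toNeuron;
    final double weight;

    MiniConnection(MiniNeuron fromNeuron, MiniNeuron toNeuron, double weight) {
        this.fromNeuron = fromNeuron;
        this.toNeuron = toNeuron;
        this.weight = weight;
        // Assumption: the connection registers itself as an input of its target.
        toNeuron.inputConnections.add(this);
    }

    double getWeightedInput() {
        return fromNeuron.calculateOutput() * weight;
    }
}

class PullEvaluationDemo {
    public static void main(String[] args) {
        MiniNeuron x1 = new MiniNeuron();
        MiniNeuron x2 = new MiniNeuron();
        MiniNeuron output = new MiniNeuron();
        x1.setInput(1.0);
        x2.setInput(2.0);
        new MiniConnection(x1, output, 0.5);
        new MiniConnection(x2, output, -0.25);
        // weighted sum = 1.0 * 0.5 + 2.0 * (-0.25) = 0.0, and sigmoid(0.0) = 0.5
        System.out.println(output.calculateOutput()); // 0.5
    }
}
```

Because nothing is cached in this design, an upstream neuron's output is recomputed once for every path that reaches it; in larger networks, caching each neuron's output for the duration of one forward pass would avoid that repeated work.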
The input summing and activation functions are defined as interfaces so that the neuron's calculation strategies can be replaced:
import java.util.List;
import edu.neuralnet.core.Connection;
/**
* Represents the input summing part of a neuron, also called the signal collector.
*/
public interface InputSummingFunction {
/**
* Performs calculations based on the output values of the input neurons.
*
* @param inputConnections
* neuron's input connections
* @return total input for the neuron having the input connections
*/
double collectOutput(List<Connection> inputConnections);
}
And a corresponding implementation:
import java.util.List;
import edu.neuralnet.core.Connection;
/**
* Calculates the weighted sum of the input neurons' outputs.
*/
public final class WeightedSumFunction implements InputSummingFunction {
/**
* {@inheritDoc}
*/
@Override
public double collectOutput(List<Connection> inputConnections) {
double weightedSum = 0d;
for (Connection connection : inputConnections) {
weightedSum += connection.getWeightedInput();
}
return weightedSum;
}
}
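Because `InputSummingFunction` declares a single abstract method, it is a functional interface, so alternative strategies can also be written as lambdas. The sketch below uses a simplified, hypothetical `SummingStrategy` interface over parallel arrays of inputs and weights (not the article's `List<Connection>` signature) to show a plain weighted sum next to a hypothetical variant that adds a constant bias term.

```java
/** Hypothetical, simplified stand-in for InputSummingFunction over parallel arrays. */
@FunctionalInterface
interface SummingStrategy {
    double collect(double[] inputs, double[] weights);
}

final class SummingStrategies {
    /** Plain weighted sum, the same idea as WeightedSumFunction. */
    static final SummingStrategy WEIGHTED_SUM = (inputs, weights) -> {
        if (inputs.length != weights.length) {
            throw new IllegalArgumentException("inputs and weights must have the same length");
        }
        double sum = 0.0;
        for (int i = 0; i < inputs.length; i++) {
            sum += inputs[i] * weights[i];
        }
        return sum;
    };

    /** Hypothetical alternative strategy: weighted sum plus a constant bias term. */
    static SummingStrategy withBias(double bias) {
        return (inputs, weights) -> WEIGHTED_SUM.collect(inputs, weights) + bias;
    }

    public static void main(String[] args) {
        double[] inputs = {1.0, 2.0};
        double[] weights = {0.5, 0.25};
        System.out.println(WEIGHTED_SUM.collect(inputs, weights));   // 1.0
        System.out.println(withBias(-1.0).collect(inputs, weights)); // 0.0
    }
}
```

The neuron code never needs to change when a strategy is swapped; only the object (or lambda) assigned to the summing-function field differs.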
For the activation function, the interface can be defined as follows:
/**
* Neural networks activation function interface.
*/
public interface ActivationFunction {
/**
* Calculates the neuron's output based on its summed input.
*
* @param summedInput
* the neuron's total input, i.e. the sum calculated from the outputs
* of the neurons connected to it
*
* @return the output calculated from the summed input
*/
double calculateOutput(double summedInput);
}
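The article leaves the concrete activation functions for a later installment. As a hedged illustration only, the four functions named earlier (sigmoid, hyperbolic tangent, hard-limiting threshold, and purely linear) could implement the interface as follows. The class names `SigmoidFunction`, `TanhFunction`, `StepFunction`, and `LinearFunction` are hypothetical, and the interface is re-declared so the sketch compiles on its own.

```java
/** Mirrors the article's ActivationFunction interface so this sketch compiles on its own. */
interface ActivationFunction {
    double calculateOutput(double summedInput);
}

/** Logistic sigmoid: an S-shaped curve between 0 and 1. */
final class SigmoidFunction implements ActivationFunction {
    @Override
    public double calculateOutput(double summedInput) {
        return 1.0 / (1.0 + Math.exp(-summedInput));
    }
}

/** Hyperbolic tangent: an S-shaped curve between -1 and 1. */
final class TanhFunction implements ActivationFunction {
    @Override
    public double calculateOutput(double summedInput) {
        return Math.tanh(summedInput);
    }
}

/** Hard-limiting threshold: 1 when the input reaches the threshold, otherwise 0. */
final class StepFunction implements ActivationFunction {
    private final double threshold;

    StepFunction(double threshold) {
        this.threshold = threshold;
    }

    @Override
    public double calculateOutput(double summedInput) {
        return summedInput >= threshold ? 1.0 : 0.0;
    }
}

/** Purely linear (identity) activation. */
final class LinearFunction implements ActivationFunction {
    @Override
    public double calculateOutput(double summedInput) {
        return summedInput;
    }
}

class ActivationDemo {
    public static void main(String[] args) {
        ActivationFunction[] functions = {
            new SigmoidFunction(), new TanhFunction(), new StepFunction(0.0), new LinearFunction()
        };
        for (ActivationFunction f : functions) {
            System.out.println(f.getClass().getSimpleName() + "(0.5) = " + f.calculateOutput(0.5));
        }
    }
}
```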
The last aspect that needs attention is the organization of neurons into layers. Neural networks can be composed of several linked layers, forming so-called multilayer networks. Neural layers can be divided into three types:
Input layer
Hidden layer
Output layer
In practice, an additional neural layer adds another level of abstraction of the outside stimuli, enhancing the neural network's capacity to represent more complex knowledge.
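A classic example illustrates what an extra layer buys. XOR is not linearly separable, so no single threshold neuron can compute it, but a network with one hidden layer of two threshold neurons can. The sketch below hand-picks the weights and a bias term (a constant added to the weighted sum) for each neuron; nothing is learned. It uses plain arrays instead of the article's classes, and `XorNetwork` and its methods are hypothetical names.

```java
/** Hand-wired 2-2-1 network of threshold neurons that computes XOR (weights chosen by hand, not learned). */
final class XorNetwork {

    /** Hard-limiting threshold activation. */
    static double step(double x) {
        return x >= 0 ? 1.0 : 0.0;
    }

    /** One threshold neuron: step(bias + sum of inputs[i] * weights[i]). */
    static double neuron(double[] inputs, double[] weights, double bias) {
        double sum = bias;
        for (int i = 0; i < inputs.length; i++) {
            sum += inputs[i] * weights[i];
        }
        return step(sum);
    }

    static double xor(double x1, double x2) {
        double[] input = {x1, x2};
        // Hidden layer
        double h1 = neuron(input, new double[] {1.0, 1.0}, -0.5); // fires for x1 OR x2
        double h2 = neuron(input, new double[] {1.0, 1.0}, -1.5); // fires for x1 AND x2
        // Output layer: OR but not AND
        return neuron(new double[] {h1, h2}, new double[] {1.0, -1.0}, -0.5);
    }

    public static void main(String[] args) {
        for (int a = 0; a <= 1; a++) {
            for (int b = 0; b <= 1; b++) {
                System.out.println(a + " XOR " + b + " = " + (int) xor(a, b));
            }
        }
    }
}
```

Each hidden neuron extracts a simpler feature (OR, AND) from the raw inputs, and the output neuron combines those features; this is the "additional level of abstraction" a hidden layer provides.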
A layer class can be defined as a list of neurons having their connections:
import java.util.ArrayList;
import java.util.List;
/**
* Neural networks can be composed of several linked layers, forming the
* so-called multilayer networks. A layer can be defined as a set of neurons
* comprising a single neural net's layer.
*/
public class NeuralNetLayer {
/**
* Layer's identifier
*/
private String id;
/**
* Collection of neurons in this layer
*/
protected List<Neuron> neurons;
/**
* Creates an empty layer with an id.
* @param id
* layer's identifier
*/
public NeuralNetLayer(String id) {
this.id = id;
neurons = new ArrayList<>();
}
/**
* Creates a layer with a list of neurons and an id.
*
* @param id
* layer's identifier
* @param neurons
* list of neurons to be added to the layer
*/
public NeuralNetLayer(String id, List<Neuron> neurons) {
this.id = id;
this.neurons = neurons;
}
...
}
And finally, a simple neural net created in Java with layers of neurons:
import java.util.List;
/**
* Represents an artificial neural network with layers containing neurons.
*/
public class NeuralNet {
/**
* Neural network id
*/
private String id;
/**
* Neural network input layer
*/
private NeuralNetLayer inputLayer;
/**
* Neural network hidden layers
*/
private List<NeuralNetLayer> hiddenLayers;
/**
* Neural network output layer
*/
private NeuralNetLayer outputLayer;
/**
* Constructs a neural net with all layers present.
*
* @param id
* Neural network id to be set
* @param inputLayer
* Neural network input layer to be set
* @param hiddenLayers
* Neural network hidden layers to be set
* @param outputLayer
* Neural network output layer to be set
*/
public NeuralNet(String id, NeuralNetLayer inputLayer, List<NeuralNetLayer> hiddenLayers,
NeuralNetLayer outputLayer) {
this.id = id;
this.inputLayer = inputLayer;
this.hiddenLayers = hiddenLayers;
this.outputLayer = outputLayer;
}
/**
* Constructs a neural net without hidden layers.
*
* @param id
* Neural network id to be set
* @param inputLayer
* Neural network input layer to be set
* @param outputLayer
* Neural network output layer to be set
*/
public NeuralNet(String id, NeuralNetLayer inputLayer, NeuralNetLayer outputLayer) {
// delegate with an empty hidden-layer list so that hiddenLayers is never null
this(id, inputLayer, new java.util.ArrayList<>(), outputLayer);
}
...
}
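The constructors above accept prebuilt layers, but the code that links neurons across layers is elided. A minimal sketch of one common approach, fully connecting every neuron of each layer to every neuron of the next, is shown here. `SimpleNeuron`, `SimpleConnection`, `SimpleLayer`, and `NetworkBuilder` are hypothetical, stripped-down stand-ins, and the random initial weights mirror the random default weight used by the two-argument connection constructor.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/** Hypothetical stand-in for Neuron, keeping only the connection lists. */
class SimpleNeuron {
    final List<SimpleConnection> inputConnections = new ArrayList<>();
    final List<SimpleConnection> outputConnections = new ArrayList<>();
}

/** Hypothetical stand-in for the connection class. */
class SimpleConnection {
    final SimpleNeuron fromNeuron;
    final SimpleNeuron toNeuron;
    final double weight;

    SimpleConnection(SimpleNeuron fromNeuron, SimpleNeuron toNeuron, double weight) {
        this.fromNeuron = fromNeuron;
        this.toNeuron = toNeuron;
        this.weight = weight;
    }
}

/** Hypothetical stand-in for NeuralNetLayer. */
class SimpleLayer {
    final List<SimpleNeuron> neurons = new ArrayList<>();

    SimpleLayer(int size) {
        for (int i = 0; i < size; i++) {
            neurons.add(new SimpleNeuron());
        }
    }
}

final class NetworkBuilder {
    /** Creates one layer per entry in sizes (input, hidden..., output) and fully connects consecutive layers. */
    static List<SimpleLayer> build(int[] sizes, Random random) {
        List<SimpleLayer> layers = new ArrayList<>();
        for (int size : sizes) {
            layers.add(new SimpleLayer(size));
        }
        for (int l = 0; l + 1 < layers.size(); l++) {
            for (SimpleNeuron from : layers.get(l).neurons) {
                for (SimpleNeuron to : layers.get(l + 1).neurons) {
                    SimpleConnection c = new SimpleConnection(from, to, random.nextDouble());
                    from.outputConnections.add(c); // output connection of the source neuron
                    to.inputConnections.add(c);    // input connection of the target neuron
                }
            }
        }
        return layers;
    }

    /** Total number of connections, counted on the input side of every neuron. */
    static int countConnections(List<SimpleLayer> layers) {
        int total = 0;
        for (SimpleLayer layer : layers) {
            for (SimpleNeuron neuron : layer.neurons) {
                total += neuron.inputConnections.size();
            }
        }
        return total;
    }

    public static void main(String[] args) {
        // 3 inputs, one hidden layer of 4 neurons, 2 outputs
        List<SimpleLayer> net = build(new int[] {3, 4, 2}, new Random(42));
        System.out.println(countConnections(net)); // 20
    }
}
```

With layer sizes {3, 4, 2}, full connectivity yields 3 * 4 + 4 * 2 = 20 connections, each registered both as an output connection of its source neuron and as an input connection of its target neuron.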
What we achieved is a structural definition of a Java-based neural network with layers, neurons, and connections. We also talked a bit about the activation functions and defined an interface for them. For simplicity, we omitted the implementation of various activation functions and the basics of learning neural networks. These two topics will be presented in subsequent articles of this series.
Opinions expressed by DZone contributors are their own.