This is a guide to implementing a neural network and training it with gradient descent.
Background
We will start with an example identifying handwritten digits using the MNIST dataset. This problem has 10 classes (digits $0$ through $9$). Each image in the database is $28 \times 28$ pixels, meaning each input is a linearized $784 \times 1$ vector. The output of the network will be a $10 \times 1$ vector, where each feature represents the probability of the input being a digit $0$ through $9$.
For now, let’s suppose our network has one fully-connected layer with 10 neurons, one neuron per class or digit. Each neuron has a weight for each input and a bias term.
A neuron is a single node within a neural network layer. Neurons, also known as perceptrons, take in a set of inputs, process them, and produce an output. More information on the neuron (perceptron) can be found here.
Because our first layer is a fully-connected layer, each neuron in this layer takes in all the input data, making 784 connections, each with its own weight. More on different neural network layers here.
Definitions
Let’s define variables for our network:
- $x$ is the input data, a $784 \times 1$ vector.
- $y$ are the input data labels, a $10 \times 1$ vector.
- $p$ is the probability output of the network, a $10 \times 1$ vector. Each value $p_j$ in $p$ corresponds to a value $y_j$ in $y$.
- $w$ are the weights per neuron, a $1 \times 784$ matrix per neuron. For all 10 neurons, $w$ is a $10 \times 784$ matrix.
- $b$ is the bias of a neuron. For all 10 neurons, $b$ is a $10 \times 1$ vector.
- $l$ is the output of the fully-connected layer, a $10 \times 1$ vector. $l$ stands for logits.
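To make the shapes concrete, here is a minimal NumPy sketch of these variables; the random image and the digit-3 label are placeholder values for illustration only:

```python
import numpy as np

rng = np.random.default_rng(0)

x = rng.random((784, 1))                   # input image, flattened to a 784 x 1 vector
y = np.zeros((10, 1)); y[3] = 1.0          # one-hot label: this example is the digit 3
w = rng.standard_normal((10, 784)) * 0.01  # one 1 x 784 weight row per neuron
b = np.zeros((10, 1))                      # one bias per neuron, a 10 x 1 vector
```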
Now we can define an output per neuron in the fully-connected layer:

$$l_j = \sum_{i=1}^{784} w_{ji} x_i + b_j$$

where $j$ is the index of the current neuron.
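In code, all 10 per-neuron sums collapse into a single matrix-vector product (a sketch continuing the NumPy variables above):

```python
l = w @ x + b  # logits; l[j] = sum_i w[j, i] * x[i] + b[j], shape (10, 1)
```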
We can turn our logits into probabilities for each class:

$$p_j = \frac{e^{l_j}}{\sum_{k=1}^{10} e^{l_k}}$$

This is simply the exponential of the current logit divided by the sum of all logit exponentials. This is known as the softmax function.
The softmax function is guaranteed to output a probability distribution ($0 \le p_j \le 1$ and $\sum_j p_j = 1$), and is popular for determining the best class in a classification problem for convolutional neural networks.
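A minimal NumPy sketch of softmax; subtracting the maximum logit before exponentiating is a common numerical-stability trick that does not change the output:

```python
def softmax(l):
    """Map logits to a probability distribution that sums to 1."""
    e = np.exp(l - np.max(l))  # shift logits to avoid overflow in exp
    return e / np.sum(e)

p = softmax(l)  # 10 x 1 vector of class probabilities
```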
To train our model, we want to define a loss function for the difference between $p$ and $y$, our predicted and actual values:

$$L = -\sum_{j=1}^{10} y_j \ln(p_j)$$

where $y_j = 1$ if the input is class $j$, and $y_j = 0$ otherwise.
This is known as the cross-entropy loss function. We can use this loss function to compute error, defined here as the mean loss over all training examples.
Cross-entropy loss measures how well the predicted probability distribution matches the actual distribution. It minimizes the amount of extra information needed to represent the true distribution using our predicted distribution. When that amount of information is small, the loss is low and the two distributions are similar. There are other loss functions available, but cross-entropy is most popular for classification problems.
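A sketch of the loss under the one-hot $y$ defined above; the small epsilon guards against $\ln(0)$ and is an implementation detail, not part of the math:

```python
def cross_entropy(p, y, eps=1e-12):
    """L = -sum_j y_j * ln(p_j); with one-hot y this is -ln(p) at the true class."""
    return float(-np.sum(y * np.log(p + eps)))

loss = cross_entropy(p, y)
```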
Training
To train our model, we calculate the loss function for every training example; a full pass over all training examples is known as an epoch. We repeat this for many epochs until the loss over all training examples is minimized.
A training example is a single input and output pair; each example is used to update the weights and biases of the network. An epoch is a single pass through the entire dataset.
The most common strategy for minimizing the loss function is gradient descent. For each training example, we will use backpropagation to update weights and biases via a learning rate.
Gradient descent is an optimization algorithm that minimizes a function by iteratively moving in the direction of steepest descent, which is given by the negative of the function's gradient at the current point.
The learning rate is one of the neural network’s hyperparameters. It determines how far each step of gradient descent should go.
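To see both ideas in isolation, here is gradient descent on the toy function $f(v) = v^2$, whose gradient is $2v$; this standalone sketch is separate from the network:

```python
v, alpha = 5.0, 0.1   # starting point and learning rate
for step in range(50):
    grad = 2 * v      # gradient of f(v) = v**2
    v -= alpha * grad # step in the direction of steepest descent
print(v)              # approaches 0, the minimizer of f
```

A larger learning rate takes bigger steps but can overshoot the minimum; a smaller one converges more slowly.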
Backpropagation is a method to calculate the gradient of the loss function with respect to the weights and biases of the network. Backpropagation is used with gradient descent to update the weights and biases.
The weights and biases are updated as follows:

$$w := w - \alpha \frac{\partial L}{\partial w}, \qquad b := b - \alpha \frac{\partial L}{\partial b}$$

where $\alpha$ is the scalar learning rate.
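In NumPy form, assuming placeholder arrays dL_dw and dL_db hold the partial derivatives derived below:

```python
alpha = 0.01        # learning rate (hyperparameter)
w -= alpha * dL_dw  # w := w - alpha * dL/dw
b -= alpha * dL_db  # b := b - alpha * dL/db
```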
In order to calculate the partial derivatives, we need to deduce $\frac{\partial L}{\partial w}$ and $\frac{\partial L}{\partial b}$ in terms of $x$, $p$, and $y$.
The derivatives are as follows:

$$\frac{\partial L}{\partial w_{ji}} = (p_j - y_j)\, x_i, \qquad \frac{\partial L}{\partial b_j} = p_j - y_j$$
We skip much of the calculation here, but the derivatives are derived from the chain rule, using backpropagation. More extensive derivation walkthroughs can be found here.
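Putting the pieces together, here is a sketch of one full training step for a single example, using the derivatives above and the softmax helper from earlier:

```python
def train_step(x, y, w, b, alpha=0.01):
    # forward pass: logits, then probabilities
    l = w @ x + b
    p = softmax(l)
    # backward pass: for softmax + cross-entropy, dL/dl = p - y
    dL_dl = p - y
    dL_dw = dL_dl @ x.T  # (10, 1) @ (1, 784) -> (10, 784)
    dL_db = dL_dl        # (10, 1)
    # gradient descent update
    w -= alpha * dL_dw
    b -= alpha * dL_db
    return w, b
```

Running this step over every example in the dataset once constitutes one epoch; repeating it for many epochs drives the loss down.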