Batch normalization is an improvement on Dropout, and was originally designed to address internal covariate shift.
When training with mini-batches in stochastic gradient descent, the idea is to standardize the input values within each batch, subtracting off the mean and dividing by the standard deviation of each input dimension.
This means that the scale of the inputs to each layer remains the same, no matter how the weights in previous layers change.
- However, this somewhat complicates matters, because the computation of the weight updates will need to take into account that we are performing this transformation.
- In the modular view, batch normalization can be seen as a module that is applied to the pre-activation $Z$, interposed after the product with the weight matrix $W$ and before input to the activation function $f$.
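To make the modular view concrete, here is a minimal sketch in NumPy. The function names, shapes, and the default activation are illustrative assumptions, not part of the text:

```python
import numpy as np

def linear(A, W):
    # product with the weights: (n x m) @ (m x K) -> (n x K)
    return W @ A

def batchnorm(Z, G, B, eps=1e-5):
    # normalize each row (input dimension) across the batch, then scale and shift
    mu = Z.mean(axis=1, keepdims=True)
    sigma = Z.std(axis=1, keepdims=True)
    return G * (Z - mu) / (sigma + eps) + B

def layer(A, W, G, B, f=np.tanh):
    # the batch-norm module sits after the product with W,
    # before the activation function f
    return f(batchnorm(linear(A, W), G, B))
```

Note that the normalization is computed per input dimension (per row), across the batch dimension (the columns).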
Batch normalization has a regularization effect similar to Weight Perturbation and Dropout. Each mini-batch of data ends up slightly perturbed, which prevents the network from exploiting very particular values of data points.
Use cases and weaknesses
- Use cases: Particularly effective in CNNs and fully-connected layers.
- Weaknesses: Performance can degrade with small batch sizes, because the estimates of the batch mean and variance become noisier. BatchNorm also introduces a dependency between the examples in a mini-batch, which can be problematic for tasks that require strong independence assumptions between samples.
Formulation
Let’s think of the batch-norm layer as taking $Z$ as input and producing $Y$ as output. But instead of thinking of $Z$ as an $n \times 1$ vector, we have to think about handling a mini-batch of $K$ data points all at once, so $Z$ and the output $Y$ will each be $n \times K$ matrices.
Our first step is to compute the batchwise mean and standard deviation. Let the mean $\mu$ be an $n \times 1$ vector where

$$\mu_i = \frac{1}{K} \sum_{j=1}^{K} Z_{ij},$$

and let $\sigma$ be an $n \times 1$ vector where

$$\sigma_i = \sqrt{\frac{1}{K} \sum_{j=1}^{K} (Z_{ij} - \mu_i)^2}.$$
The basic normalized version of our data would be an $n \times K$ matrix $\bar{Z}$, element $(i, j)$ of which is

$$\bar{Z}_{ij} = \frac{Z_{ij} - \mu_i}{\sigma_i + \epsilon},$$

where $\epsilon$ is a very small constant to guard against division by zero. However, if we let these be our output values, we are forcing something too strong on our data—our goal was to normalize across the data batch, but not necessarily to force the outputs to have mean exactly $0$ and standard deviation exactly $1$. So, we will give the layer the “opportunity” to shift and scale the outputs by adding new weights to the layer. These weights are $G$ and $B$, each of which is an $n \times 1$ vector. Using the weights, we define the final output to be

$$Y_{ij} = G_i \bar{Z}_{ij} + B_i.$$
That’s the forward pass.
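The forward pass can be checked numerically. A sketch, assuming arbitrary values for the scale and shift weights: each row of $\bar{Z}$ ends up with mean roughly $0$ and standard deviation roughly $1$, while each row of $Y$ has mean $B_i$ and standard deviation $|G_i|$.

```python
import numpy as np

rng = np.random.default_rng(0)
n, K = 4, 100
Z = rng.normal(loc=3.0, scale=2.0, size=(n, K))  # mini-batch, rows = dimensions
eps = 1e-8

mu = Z.mean(axis=1, keepdims=True)      # batchwise mean, (n, 1)
sigma = Z.std(axis=1, keepdims=True)    # batchwise std with 1/K normalization
Zbar = (Z - mu) / (sigma + eps)         # normalized data

G = np.full((n, 1), 2.0)                # scale weights (arbitrary values here)
B = np.full((n, 1), -1.0)               # shift weights (arbitrary values here)
Y = G * Zbar + B                        # final output of the layer
```

During training, $G$ and $B$ are learned by gradient descent like any other weights; the fixed values above are only for illustration.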
For the backward pass, we have to do two things. Given $\partial L / \partial Y$, we have to:
- Compute $\partial L / \partial Z$ for backpropagation, and
- Compute $\partial L / \partial G$ and $\partial L / \partial B$ for gradient updates of the weights in this layer.
Schematically,

$$\frac{\partial L}{\partial B} = \frac{\partial L}{\partial Y} \frac{\partial Y}{\partial B}.$$

It’s hard to think about these derivatives in matrix terms, so we’ll see how it works for the components. $B_i$ contributes to $Y_{ij}$ for all data points $j$ in the batch. So

$$\frac{\partial L}{\partial B_i} = \sum_j \frac{\partial L}{\partial Y_{ij}} \frac{\partial Y_{ij}}{\partial B_i} = \sum_j \frac{\partial L}{\partial Y_{ij}}.$$

Similarly, $G_i$ contributes to $Y_{ij}$ for all data points $j$ in the batch, so that

$$\frac{\partial L}{\partial G_i} = \sum_j \frac{\partial L}{\partial Y_{ij}} \frac{\partial Y_{ij}}{\partial G_i} = \sum_j \frac{\partial L}{\partial Y_{ij}} \bar{Z}_{ij}.$$
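The two weight gradients are just batchwise sums, so they vectorize in one line each. A minimal sketch, assuming the upstream gradient `dLdY` and the normalized values `Zbar` cached from the forward pass are given:

```python
import numpy as np

rng = np.random.default_rng(1)
n, K = 3, 5
dLdY = rng.normal(size=(n, K))   # upstream gradient dL/dY, assumed given
Zbar = rng.normal(size=(n, K))   # normalized values cached from the forward pass

# dL/dB_i = sum_j dL/dY_ij          (B_i touches every data point j)
dLdB = dLdY.sum(axis=1, keepdims=True)

# dL/dG_i = sum_j dL/dY_ij * Zbar_ij
dLdG = (dLdY * Zbar).sum(axis=1, keepdims=True)
```

Both gradients come out $n \times 1$, matching the shapes of $B$ and $G$.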
Now let’s figure out how to do backpropagation. We can start schematically:

$$\frac{\partial L}{\partial Z} = \frac{\partial L}{\partial Y} \frac{\partial Y}{\partial Z}.$$

Because dependencies exist across the batch, but not across the output units,

$$\frac{\partial L}{\partial Z_{ij}} = \sum_{k=1}^{K} \frac{\partial L}{\partial Y_{ik}} \frac{\partial Y_{ik}}{\partial Z_{ij}}.$$

The next step is to note that

$$\frac{\partial Y_{ik}}{\partial Z_{ij}} = \frac{\partial Y_{ik}}{\partial \bar{Z}_{ik}} \frac{\partial \bar{Z}_{ik}}{\partial Z_{ij}} = G_i \frac{\partial \bar{Z}_{ik}}{\partial Z_{ij}},$$

and because $\bar{Z}_{ik} = (Z_{ik} - \mu_i)/\sigma_i$ (ignoring $\epsilon$),

$$\frac{\partial \bar{Z}_{ik}}{\partial Z_{ij}} = \left(\delta_{jk} - \frac{\partial \mu_i}{\partial Z_{ij}}\right) \frac{1}{\sigma_i} - \frac{Z_{ik} - \mu_i}{\sigma_i^2} \frac{\partial \sigma_i}{\partial Z_{ij}},$$

where $\delta_{jk} = 1$ if $j = k$ and $\delta_{jk} = 0$ otherwise.

We need two more small parts:

$$\frac{\partial \mu_i}{\partial Z_{ij}} = \frac{1}{K}, \qquad \frac{\partial \sigma_i}{\partial Z_{ij}} = \frac{Z_{ij} - \mu_i}{K \sigma_i}.$$

Putting everything together:

$$\frac{\partial L}{\partial Z_{ij}} = \sum_{k=1}^{K} \frac{\partial L}{\partial Y_{ik}} \frac{G_i}{K \sigma_i} \left(K \delta_{jk} - 1 - \frac{(Z_{ik} - \mu_i)(Z_{ij} - \mu_i)}{\sigma_i^2}\right).$$
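A derivation like this is easy to get wrong, so it is worth checking the final expression against a finite-difference gradient. A sketch, where the test loss $L = \sum_{ij} C_{ij} Y_{ij}$ (so that $\partial L / \partial Y = C$) and the small sizes are assumptions made just for the check; $\epsilon$ is omitted, as in the derivation:

```python
import numpy as np

rng = np.random.default_rng(0)
n, K = 3, 4
Z = rng.normal(size=(n, K))
G = rng.normal(size=(n, 1))
B = rng.normal(size=(n, 1))
C = rng.normal(size=(n, K))  # test loss L = sum(C * Y), so dL/dY = C

def forward(Z):
    mu = Z.mean(axis=1, keepdims=True)
    sigma = Z.std(axis=1, keepdims=True)  # 1/K normalization, matching the text
    return G * (Z - mu) / sigma + B       # epsilon omitted, as in the derivation

# Analytic gradient: the final expression above, with the sum over k vectorized
mu = Z.mean(axis=1, keepdims=True)
sigma = Z.std(axis=1, keepdims=True)
Zbar = (Z - mu) / sigma
dY = C
dZ = (G / (K * sigma)) * (K * dY
                          - dY.sum(axis=1, keepdims=True)
                          - Zbar * (dY * Zbar).sum(axis=1, keepdims=True))

# Central-difference estimate of every entry of dL/dZ
num = np.zeros_like(Z)
h = 1e-6
for i in range(n):
    for j in range(K):
        Zp, Zm = Z.copy(), Z.copy()
        Zp[i, j] += h
        Zm[i, j] -= h
        num[i, j] = (np.sum(C * forward(Zp)) - np.sum(C * forward(Zm))) / (2 * h)
```

The vectorized `dZ` line uses the fact that $(Z_{ik} - \mu_i)(Z_{ij} - \mu_i)/\sigma_i^2 = \bar{Z}_{ik}\bar{Z}_{ij}$, so the sum over $k$ collapses into two batchwise reductions.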