Batch Normalization Explained

To help readers understand Batch Normalization, this article answers the following questions:

  • What advantages does Batch Normalization bring?

  • What does Batch Normalization do?

  • How is Batch Normalization applied during training?

The main benefits of Batch Normalization are that it mitigates vanishing gradients and, to some extent, overfitting. As a network gets deeper, the inputs to its activation functions can easily drift into the saturated, nonlinear regions of those functions. Take the sigmoid as an example (it is the usual choice for the output of a binary classifier). For large positive or negative inputs the sigmoid is almost flat, so its gradient there is close to 0. Because backpropagation multiplies these local gradients together through the chain rule, the gradient that reaches the earlier parameters for this data point is also close to 0, and those parameters barely get updated. This is the so-called vanishing gradient problem, and it makes the model converge slowly. The most straightforward fix is to keep the inputs to the activation function mostly within its linear region. And this is EXACTLY what Batch Normalization does!
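The chain-rule argument can be made concrete with a small sketch in plain Python (standard library only). It compares the sigmoid's derivative at a few input values and raises its ratio to the peak derivative to a power, which mimics how backpropagation multiplies one activation derivative per layer. The helper names are illustrative, and the sketch deliberately ignores the weight matrices, so it isolates only the activation's contribution to the gradient.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic sigmoid, 1 / (1 + e^(-x))."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def sigmoid_grad(x: float) -> float:
    """Derivative of the sigmoid, s(x) * (1 - s(x)); its maximum is 0.25 at x = 0."""
    s = sigmoid(x)
    return s * (1.0 - s)


def relative_grad(x: float, depth: int) -> float:
    """Activation part of the gradient when each of `depth` layers sees input x,
    relative to the same chain with every input at 0 (weights ignored)."""
    return (sigmoid_grad(x) / sigmoid_grad(0.0)) ** depth


for x in (0.0, 2.0, 6.0):
    print(f"x = {x:3.1f}: sigmoid'(x) = {sigmoid_grad(x):.4f}, "
          f"relative gradient after 5 layers = {relative_grad(x, 5):.1e}")
```

At x = 6 the derivative is already about 100 times smaller than at x = 0, so after five such layers the activation's contribution has shrunk by roughly ten orders of magnitude.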

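Batch Normalization first computes, for each neuron, the mean and variance of its input over the mini-batch. It then normalizes the input to the activation function by subtracting that mean and dividing by the standard deviation (the square root of the variance plus a small constant ε that avoids division by zero). Within the mini-batch, each neuron's input then has a mean of 0 and a standard deviation of (almost exactly) 1. Note that this only fixes the mean and the spread: it does not turn the data into a Gaussian distribution.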
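Below is a minimal sketch of this normalization step in plain Python (standard library only), assuming the mini-batch is a list of rows with one column per neuron. The function name batch_normalize and the default eps = 1e-5 are illustrative choices, not taken from any particular library. It uses the variance of the batch itself (dividing by the batch size) and adds eps before taking the square root.

```python
import math


def batch_normalize(batch, eps=1e-5):
    """Normalize every column (neuron) of a mini-batch to mean 0 and variance ~1.

    batch: list of rows, one row per example and one column per neuron.
    Returns the normalized batch plus the per-neuron means and variances.
    """
    n = len(batch)
    num_neurons = len(batch[0])
    means = [sum(row[j] for row in batch) / n for j in range(num_neurons)]
    # Biased variance (divide by n), i.e. the variance of this batch itself.
    variances = [sum((row[j] - means[j]) ** 2 for row in batch) / n
                 for j in range(num_neurons)]
    normalized = [[(row[j] - means[j]) / math.sqrt(variances[j] + eps)
                   for j in range(num_neurons)]
                  for row in batch]
    return normalized, means, variances


# Two neurons whose inputs live on very different scales.
batch = [[1.0, 50.0], [2.0, 60.0], [3.0, 70.0], [4.0, 80.0]]
normalized, means, variances = batch_normalize(batch)
print(means, variances)
for row in normalized:
    print([round(v, 2) for v in row])
```

Both columns come out as roughly -1.34, -0.45, 0.45 and 1.34: the two neurons started with means of 2.5 and 65, but after normalization their inputs share the same zero-centered, unit-variance spread.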

Why does this prevent vanishing gradients? Recall the shape of the sigmoid function, σ(x) = 1 / (1 + e^(-x)): it is S-shaped, roughly linear around x = 0 and almost flat for large positive or negative x.

After normalization, the values are centered at 0 with a standard deviation of 1, so most of them lie close to 0. Near 0 the sigmoid is approximately linear and its derivative is at its maximum of 0.25, so the gradient flowing back through it stays large enough for useful parameter updates.
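A short sketch of this effect in plain Python (standard library only): a hypothetical neuron whose pre-activations in one mini-batch are all large (8 to 12) sits in the flat tail of the sigmoid, and standardizing those values over the batch moves them back to where the derivative is large. The batch values and helper names are made up for illustration.

```python
import math


def sigmoid_grad(x: float) -> float:
    """Derivative of the logistic sigmoid, s(x) * (1 - s(x))."""
    s = 1.0 / (1.0 + math.exp(-x))
    return s * (1.0 - s)


def standardize(values, eps=1e-5):
    """Subtract the batch mean and divide by the batch standard deviation."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]


raw = [8.0, 9.0, 10.0, 11.0, 12.0]  # hypothetical pre-activations of one neuron
normed = standardize(raw)

print("before:", [f"{sigmoid_grad(v):.5f}" for v in raw])
print("after: ", [f"{sigmoid_grad(v):.3f}" for v in normed])
```

Before normalization every derivative is below 0.001; afterwards the smallest one is above 0.15, not far from the sigmoid's maximum of 0.25.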

Readers may then ask: if the input to every activation function is forced into its linear region, won't each layer behave almost like a linear function? A stack of linear layers collapses into a single linear layer no matter how deep it is, which would greatly reduce the model's expressiveness and throw away the advantage of a deep network. To keep the non-linearity we want, Batch Normalization adds a second step. After normalization, a scale and a shift are applied to the zero-mean, unit-variance values (z = scale * y_norm + shift, where y_norm is the normalized value). Each neuron therefore gets two extra parameters, scale and shift, which are learned during training along with the weights. Shifting moves the normalized distribution to the left or right, and scaling stretches or squeezes it, so training can move the activation input from the linear region toward the nonlinear region as far as is useful. In the extreme, setting scale to the batch standard deviation and shift to the batch mean undoes the normalization entirely.

Batch Normalization also has a mild regularizing effect that helps against overfitting. This does not come from the scale and shift, which are deterministic learned parameters, but from the batch statistics: each example is normalized with the mean and standard deviation of whichever mini-batch it happens to land in, so the same input is transformed slightly differently from one batch to the next. This noise makes it harder for the model to fit individual training examples exactly.
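The sketch below (plain Python, standard library only; the function names are illustrative) demonstrates two properties of Batch Normalization. First, with scale set to the batch standard deviation and shift set to the batch mean, the scale-and-shift step exactly undoes the normalization, so the layer can still represent its original input if training favors that. Second, the same input value is normalized differently depending on which other examples share its mini-batch, which is the source of the noise behind the regularizing effect.

```python
import math


def normalize(values, eps=1e-5):
    """Standardize one neuron's inputs over a mini-batch; also return mean and std."""
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n
    std = math.sqrt(var + eps)
    return [(v - mean) / std for v in values], mean, std


def scale_and_shift(y_norm, scale, shift):
    """Learned affine step z = scale * y_norm + shift (one scale/shift per neuron)."""
    return [scale * v + shift for v in y_norm]


# 1) scale = std and shift = mean undo the normalization (up to rounding),
#    so Batch Normalization does not have to reduce what the layer can express.
batch = [8.0, 9.0, 10.0, 11.0, 12.0]
y_norm, mean, std = normalize(batch)
print(scale_and_shift(y_norm, scale=std, shift=mean))

# 2) The same input, 10.0, lands in two different mini-batches and is
#    normalized to different values; this batch-to-batch noise is what
#    gives Batch Normalization its mild regularizing effect.
first, _, _ = normalize([10.0, 9.0, 11.0])
second, _, _ = normalize([10.0, 14.0, 12.0])
print(first[0], second[0])
```

The first print reproduces the original batch up to floating-point rounding; the second prints 0.0 for the first batch and about -1.22 for the second.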

To summarize, here is how Batch Normalization is applied to one layer during training:

  • First, compute the output of the linear operation, y = Wx + b, where x is the output of the previous layer's activation function.

  • Second, for every neuron, compute the mean and standard deviation of y over the current mini-batch.

  • Third, normalize y with these statistics, then apply the learned scale and shift. Call the result z.

  • Finally, feed z to the activation function of the current layer.
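Putting the four steps together, here is a minimal training-mode forward pass for one fully connected layer with Batch Normalization and a sigmoid activation, written in plain Python (standard library only) for readability rather than speed. The function name dense_bn_sigmoid_forward, the argument layout (weights as one row per neuron) and eps = 1e-5 are assumptions for illustration. At inference time, implementations typically replace the batch statistics with running averages collected during training; that part is not shown.

```python
import math


def sigmoid(v: float) -> float:
    """Numerically stable logistic sigmoid."""
    if v >= 0:
        return 1.0 / (1.0 + math.exp(-v))
    e = math.exp(v)
    return e / (1.0 + e)


def dense_bn_sigmoid_forward(X, W, b, scale, shift, eps=1e-5):
    """Training-mode forward pass: linear -> Batch Normalization -> sigmoid.

    X: mini-batch, one row per example (the previous layer's activations)
    W: weights, one row per neuron of this layer
    b, scale, shift: one value per neuron of this layer
    Returns (activations, z), both with one row per example.
    """
    n, m = len(X), len(W)
    # Step 1: y = W x + b for every example in the batch.
    Y = [[sum(w * xi for w, xi in zip(W[j], x)) + b[j] for j in range(m)]
         for x in X]
    # Step 2: per-neuron mean and variance over the batch.
    means = [sum(Y[i][j] for i in range(n)) / n for j in range(m)]
    variances = [sum((Y[i][j] - means[j]) ** 2 for i in range(n)) / n
                 for j in range(m)]
    # Step 3: normalize y, then apply the learned scale and shift to get z.
    Z = [[scale[j] * (Y[i][j] - means[j]) / math.sqrt(variances[j] + eps) + shift[j]
          for j in range(m)] for i in range(n)]
    # Step 4: feed z to this layer's activation function.
    A = [[sigmoid(z) for z in row] for row in Z]
    return A, Z


X = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]  # batch of 4, 2 inputs each
W = [[0.5, -1.0], [2.0, 3.0], [10.0, 10.0]]           # 3 neurons
b = [0.0, 1.0, 50.0]
A, Z = dense_bn_sigmoid_forward(X, W, b, scale=[1.0] * 3, shift=[0.0] * 3)
for row in A:
    print([round(a, 3) for a in row])
```

The third neuron's pre-activations (80 to 200) would push a plain sigmoid to 1.0, but after normalization its outputs spread between about 0.21 and 0.79. Because subtracting the batch mean cancels b, the bias is redundant in front of Batch Normalization (shift takes over its role), which is why layers followed by Batch Normalization are often built without a bias.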