# Batch Normalization

*Stabilizing Deep Network Training*

## What is Batch Normalization?
Batch Normalization (BatchNorm) normalizes the inputs to each layer during training, making the network more stable and faster to train.
## The Problem

Deep networks suffer from *internal covariate shift*: the distribution of each layer's inputs keeps changing as the parameters of earlier layers update during training, which slows learning and makes it unstable.
## How BatchNorm Works
For a mini-batch of size $m$, BatchNorm normalizes activations:
1. Compute mean and variance:
$$\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i \quad \sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m}(x_i - \mu_B)^2$$
2. Normalize:
$$\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}$$
3. Scale and shift:
$$y_i = \gamma \hat{x}_i + \beta$$
$\gamma$ and $\beta$ are learnable parameters that allow the network to undo normalization if beneficial.
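The three steps above can be sketched in a few lines of NumPy. The function name `batchnorm_forward` is illustrative, not from any particular library; it follows the equations directly, using the biased batch variance as in the formulas.

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    """Normalize a mini-batch x of shape (m, num_features), per feature."""
    mu = x.mean(axis=0)                    # step 1: batch mean per feature
    var = x.var(axis=0)                    # step 1: biased batch variance per feature
    x_hat = (x - mu) / np.sqrt(var + eps)  # step 2: normalize to zero mean, unit variance
    return gamma * x_hat + beta            # step 3: learnable scale and shift

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=3.0, size=(64, 10))  # batch of 64, 10 features
y = batchnorm_forward(x, gamma=np.ones(10), beta=np.zeros(10))
```

With $\gamma = 1$ and $\beta = 0$, each feature column of `y` has (approximately) zero mean and unit variance regardless of the input's original scale.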
## Key Benefits

- Faster training: tolerates higher learning rates
- More stable: reduces internal covariate shift
- Regularization: the noise in per-batch statistics acts as a mild regularizer, reducing overfitting
- Less sensitive: reduces dependence on careful weight initialization
## Training vs Inference
| During Training | During Inference |
|---|---|
| Uses batch statistics ($\mu_B$, $\sigma_B^2$) | Uses running statistics (population mean/variance) |
| Stochastic, changes per batch | Deterministic, fixed |
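A minimal sketch of how the two modes fit together, assuming an exponential moving average for the running statistics (the class name, `momentum` value, and `training` flag are illustrative conventions, not a specific library's API):

```python
import numpy as np

class BatchNorm1d:
    """Sketch: batch statistics while training, fixed running statistics at inference."""
    def __init__(self, num_features, momentum=0.1, eps=1e-5):
        self.gamma = np.ones(num_features)
        self.beta = np.zeros(num_features)
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)
        self.momentum = momentum
        self.eps = eps
        self.training = True

    def __call__(self, x):
        if self.training:
            mu, var = x.mean(axis=0), x.var(axis=0)  # stochastic: changes per batch
            # update running estimates of the population statistics
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mu
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        else:
            mu, var = self.running_mean, self.running_var  # deterministic: fixed
        x_hat = (x - mu) / np.sqrt(var + self.eps)
        return self.gamma * x_hat + self.beta

rng = np.random.default_rng(1)
bn = BatchNorm1d(4)
for _ in range(300):
    bn(rng.normal(loc=2.0, scale=1.0, size=(32, 4)))  # training updates running stats
bn.training = False
out = bn(np.full((2, 4), 2.0))  # inference: same input always gives the same output
```

After enough training batches, the running mean converges near the data mean (2.0 here), so an input equal to that mean is normalized close to zero at inference time.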
## Where to Use It
- After dense layers in MLPs
- After convolutional layers in CNNs
- Before activation functions (the original placement), or after them; the best choice is debated
- Usually not immediately before the final output layer
## Variants

- Layer Normalization: normalizes across features within each example instead of across the batch
- Instance Normalization: normalizes each channel of each example independently; popular in style transfer and GANs
- Group Normalization: normalizes over groups of channels within each example; independent of batch size
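The key difference between the variants is simply which axis the statistics are computed over. A small sketch contrasting BatchNorm and LayerNorm on a 2D activation matrix (scale/shift parameters omitted for brevity):

```python
import numpy as np

x = np.arange(12, dtype=float).reshape(3, 4)  # (batch=3, features=4)

def normalize(x, axis, eps=1e-5):
    """Zero-mean, unit-variance normalization along the given axis."""
    mu = x.mean(axis=axis, keepdims=True)
    var = x.var(axis=axis, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

bn_out = normalize(x, axis=0)  # BatchNorm: statistics per feature, across the batch
ln_out = normalize(x, axis=1)  # LayerNorm: statistics per example, across features
```

After BatchNorm each feature column has zero mean; after LayerNorm each example row does. LayerNorm's batch independence is why it suits variable or small batch sizes.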