Understanding Backpropagation
The Algorithm that Powers Neural Network Training
1. The Chain Rule: Foundation of Backpropagation
Backpropagation is simply the chain rule from calculus applied to neural networks. If we have a composite function \( y = f(g(x)) \), the chain rule tells us:
$$ \frac{dy}{dx} = f'(g(x)) \cdot g'(x) $$
In neural networks, we have many layers of composition: \( L = f_n(f_{n-1}(...f_2(f_1(x)))) \). To compute how much each weight contributes to the final loss, we apply the chain rule recursively.
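The chain rule is easy to verify numerically. A minimal sketch (the choices \( f(u) = u^2 \) and \( g(x) = \sin x \) are hypothetical, just to make the composition concrete):

```python
import math

# Composite y = f(g(x)) with (hypothetical) f(u) = u**2, g(x) = sin(x)
# Chain rule: dy/dx = f'(g(x)) * g'(x) = 2*sin(x) * cos(x)
def analytic(x):
    return 2.0 * math.sin(x) * math.cos(x)

def numerical(x, h=1e-6):
    # Central finite difference on y = sin(x)**2
    f = lambda t: math.sin(t) ** 2
    return (f(x + h) - f(x - h)) / (2.0 * h)

print(abs(analytic(0.7) - numerical(0.7)))  # near zero
```

The same check, applied per parameter, is the standard way to validate a hand-written backward pass.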
2. Forward Pass: Computing Activations
Let's consider a simple 3-layer network:
Layer 1 (Input → Hidden):
$$ z^{(1)} = W^{(1)} x + b^{(1)} $$ $$ a^{(1)} = \sigma(z^{(1)}) $$
Layer 2 (Hidden → Output):
$$ z^{(2)} = W^{(2)} a^{(1)} + b^{(2)} $$ $$ a^{(2)} = \sigma(z^{(2)}) $$
Loss:
$$ L = \frac{1}{2}(a^{(2)} - y)^2 $$
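The three equations above translate almost line for line into code. A minimal NumPy sketch (function and variable names are illustrative, not from the text):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def forward(x, W1, b1, W2, b2, y):
    # Layer 1 (Input -> Hidden)
    z1 = W1 @ x + b1
    a1 = sigmoid(z1)
    # Layer 2 (Hidden -> Output)
    z2 = W2 @ a1 + b2
    a2 = sigmoid(z2)
    # Squared-error loss L = 1/2 (a2 - y)^2
    L = 0.5 * (a2 - y) ** 2
    # Return every intermediate: the backward pass reuses them
    return z1, a1, z2, a2, L
```

Every intermediate (`z1`, `a1`, `z2`, `a2`) is returned deliberately: the backward pass reuses them, which is one source of backpropagation's efficiency.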
3. Backward Pass: Computing Gradients
We compute gradients from output to input:
Step 1: Gradient w.r.t. output activation
$$ \frac{\partial L}{\partial a^{(2)}} = a^{(2)} - y $$
Step 2: Gradient w.r.t. pre-activation
$$ \frac{\partial L}{\partial z^{(2)}} = \frac{\partial L}{\partial a^{(2)}} \cdot \sigma'(z^{(2)}) $$
Step 3: Gradient w.r.t. weights (Layer 2)
$$ \frac{\partial L}{\partial W^{(2)}} = \frac{\partial L}{\partial z^{(2)}} \cdot (a^{(1)})^T $$
Step 4: Propagate to previous layer
$$ \frac{\partial L}{\partial a^{(1)}} = (W^{(2)})^T \cdot \frac{\partial L}{\partial z^{(2)}} $$
Step 5: Continue backward...
$$ \frac{\partial L}{\partial z^{(1)}} = \frac{\partial L}{\partial a^{(1)}} \cdot \sigma'(z^{(1)}) $$ $$ \frac{\partial L}{\partial W^{(1)}} = \frac{\partial L}{\partial z^{(1)}} \cdot x^T $$
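Steps 1 through 5 can be sketched the same way. A hedged NumPy version (names are illustrative; it assumes the cached forward-pass activations are passed in, and uses \( \sigma'(z) = \sigma(z)(1-\sigma(z)) \) so only the activations are needed, not the pre-activations):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def backward(x, y, a1, a2, W2):
    # Steps 1-2: gradient at the output; sigma'(z2) = a2 * (1 - a2)
    dL_dz2 = (a2 - y) * a2 * (1 - a2)
    # Step 3: layer-2 parameter gradients
    dL_dW2 = np.outer(dL_dz2, a1)
    dL_db2 = dL_dz2
    # Step 4: propagate to the hidden layer
    dL_da1 = W2.T @ dL_dz2
    # Step 5: layer-1 gradients; sigma'(z1) = a1 * (1 - a1)
    dL_dz1 = dL_da1 * a1 * (1 - a1)
    dL_dW1 = np.outer(dL_dz1, x)
    dL_db1 = dL_dz1
    return dL_dW1, dL_db1, dL_dW2, dL_db2
```

Note how `dL_dz2` is computed once and reused in Steps 3 and 4: each gradient builds on the one before it, exactly as in the equations.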
4. Numerical Example
Setup: 2 inputs, 2 hidden neurons, 1 output
Input: \( x = [0.5, 0.3] \)
Target: \( y = 0.8 \)
Weights: \( W^{(1)} = \begin{bmatrix} 0.2 & 0.5 \\ 0.3 & 0.4 \end{bmatrix}, W^{(2)} = \begin{bmatrix} 0.6 & 0.7 \end{bmatrix} \)
Biases: \( b^{(1)} = [0.1, 0.2], b^{(2)} = 0.1 \)
Activation: Sigmoid \( \sigma(z) = \frac{1}{1+e^{-z}} \), Derivative: \( \sigma'(z) = \sigma(z)(1-\sigma(z)) \)
Forward Pass
\( z_1^{(1)} = 0.2(0.5) + 0.5(0.3) + 0.1 = 0.1 + 0.15 + 0.1 = 0.35 \)
\( z_2^{(1)} = 0.3(0.5) + 0.4(0.3) + 0.2 = 0.15 + 0.12 + 0.2 = 0.47 \)
\( a_1^{(1)} = \sigma(0.35) \approx 0.587 \)
\( a_2^{(1)} = \sigma(0.47) \approx 0.615 \)
Layer 2:
\( z^{(2)} = 0.6(0.587) + 0.7(0.615) + 0.1 = 0.352 + 0.431 + 0.1 = 0.883 \)
\( a^{(2)} = \sigma(0.883) \approx 0.708 \)
Loss:
\( L = \frac{1}{2}(0.708 - 0.8)^2 = \frac{1}{2}(-0.092)^2 \approx 0.00423 \)
Backward Pass
\( \frac{\partial L}{\partial a^{(2)}} = 0.708 - 0.8 = -0.092 \)
\( \sigma'(z^{(2)}) = 0.708(1-0.708) \approx 0.207 \)
\( \frac{\partial L}{\partial z^{(2)}} = -0.092 \times 0.207 \approx -0.019 \)
Weights \( W^{(2)} \):
\( \frac{\partial L}{\partial W_1^{(2)}} = -0.019 \times 0.587 \approx -0.011 \)
\( \frac{\partial L}{\partial W_2^{(2)}} = -0.019 \times 0.615 \approx -0.012 \)
\( \frac{\partial L}{\partial b^{(2)}} = -0.019 \)
Hidden Layer Gradients:
\( \frac{\partial L}{\partial a_1^{(1)}} = 0.6 \times (-0.019) = -0.0114 \)
\( \frac{\partial L}{\partial a_2^{(1)}} = 0.7 \times (-0.019) = -0.0133 \)
\( \sigma'(z_1^{(1)}) = 0.587(1-0.587) \approx 0.242 \)
\( \sigma'(z_2^{(1)}) = 0.615(1-0.615) \approx 0.237 \)
\( \frac{\partial L}{\partial z_1^{(1)}} = -0.0114 \times 0.242 \approx -0.00276 \)
\( \frac{\partial L}{\partial z_2^{(1)}} = -0.0133 \times 0.237 \approx -0.00315 \)
Input Weights:
\( \frac{\partial L}{\partial W_{11}^{(1)}} = -0.00276 \times 0.5 = -0.00138 \)
\( \frac{\partial L}{\partial W_{12}^{(1)}} = -0.00276 \times 0.3 = -0.00083 \)
(and so on for all weights...)
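The hand computation above can be reproduced end to end in a few lines; tiny differences in the last digit are expected because the text rounds intermediate values. A sketch:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

x = np.array([0.5, 0.3])
y = 0.8
W1 = np.array([[0.2, 0.5], [0.3, 0.4]])
b1 = np.array([0.1, 0.2])
W2 = np.array([0.6, 0.7])   # row of layer-2 weights as a 1-D array
b2 = 0.1

# Forward pass
z1 = W1 @ x + b1            # [0.35, 0.47]
a1 = sigmoid(z1)            # ~[0.587, 0.615]
z2 = W2 @ a1 + b2           # ~0.883
a2 = sigmoid(z2)            # ~0.708
L = 0.5 * (a2 - y) ** 2     # ~0.0043

# Backward pass
dL_dz2 = (a2 - y) * a2 * (1 - a2)   # ~ -0.019
dL_dW2 = dL_dz2 * a1                # ~ [-0.011, -0.012]
dL_db2 = dL_dz2
dL_da1 = W2 * dL_dz2                # ~ [-0.0114, -0.0133]
dL_dz1 = dL_da1 * a1 * (1 - a1)     # ~ [-0.0028, -0.0032]
dL_dW1 = np.outer(dL_dz1, x)        # includes ~ -0.00139, -0.00083, ...
```

Printing these arrays recovers every number in the worked example, including the "(and so on)" entries for the remaining layer-1 weights.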
5. Interactive Computational Graph
In the interactive version of this page, the "Forward Pass" control animates activations flowing forward through the graph, and "Backward Pass" animates gradients propagating backward, with each step recorded in a computation log.
6. Key Insights
- Computing all gradients in one backward pass costs \( O(n) \), where n is the number of parameters
- The finite-difference alternative costs \( O(n^2) \): one \( O(n) \) loss evaluation per parameter, so n full loss evaluations in total
- We reuse intermediate computations from the forward pass
- The chain rule allows us to decompose complex derivatives into simple ones
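The O(n) versus O(n²) contrast can be made concrete with a finite-difference gradient check on the layer-2 weights from the worked example (a sketch; the helper names are illustrative):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def loss(W2, a1, b2, y):
    return 0.5 * (sigmoid(W2 @ a1 + b2) - y) ** 2

# Values from the worked example
a1 = np.array([0.587, 0.615])
W2 = np.array([0.6, 0.7])
b2, y = 0.1, 0.8

# Finite differences: one extra loss evaluation PER parameter -> O(n^2) overall
h = 1e-6
base = loss(W2, a1, b2, y)
fd_grad = np.zeros_like(W2)
for i in range(len(W2)):
    Wp = W2.copy()
    Wp[i] += h
    fd_grad[i] = (loss(Wp, a1, b2, y) - base) / h

# Backprop: all gradients from ONE backward pass -> O(n)
a2 = sigmoid(W2 @ a1 + b2)
bp_grad = (a2 - y) * a2 * (1 - a2) * a1

print(np.max(np.abs(fd_grad - bp_grad)))  # tiny: the two methods agree
```

The two methods agree to numerical precision, but the loop over parameters is exactly the extra factor of n that backpropagation avoids.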