5 CIFAR with BatchNorm
This is structurally the hardest chapter in the book. BN’s inverse-stddev term \(1/\sqrt{\sigma ^2 + \varepsilon }\) has a gradient that blows up as \(\sigma ^2 \to 0\) unless \(\varepsilon {\gt} 0\); proving the backward pass exists and is bounded requires ContinuousLinearMap-based real-analysis machinery from Mathlib that none of the previous chapters needed. If you’re new to formal math, skim the proofs and trust them. The takeaways are concrete:
BN’s gradient has a closed-form 3-term formula (Theorem 32).
The formula needs \(\varepsilon {\gt} 0\) to stay bounded (Theorem 31).
BN’s whole purpose is to let deeper networks train — the example in §5.1 demonstrates that directly.
The proofs themselves use Mathlib’s HasFDerivAt.sqrt, (hasDerivAt_inv).comp_hasFDerivAt, and a centering CLM chained through the chain rule from Ch 2. They’re correct (the Lean kernel checks them) and they’re available in Proofs/BatchNorm.lean for the curious; you do not need to understand them line-by-line to use BN as a layer or to follow the rest of the book. Ch 6 (ResNet-34) is the easiest chapter in the book and follows immediately — this is a localized difficulty spike, not the new normal.
What BN actually is
BatchNorm (Ioffe & Szegedy, 2015) takes a batch of activations, \(x\), and does three things in sequence. Each is one line of code; together they are the layer.
Center. Compute the batch mean \(\mu = \frac{1}{n} \sum _k x_k\) and subtract it from every sample: \(x - \mu \). Output has mean zero.
Normalize. Compute the batch variance \(\sigma ^2 = \frac{1}{n} \sum _k (x_k - \mu )^2\), add a small \(\varepsilon \) for numerical safety, and divide: \(\hat{x} = (x - \mu )/\sqrt{\sigma ^2 + \varepsilon }\). Output has unit variance (up to the small \(\varepsilon \)).
Affine. Scale and shift with learnable per-channel parameters \(\gamma \) and \(\beta \): \(\mathrm{bn}(x) = \gamma \hat{x} + \beta \).
The chapter’s first six theorems are the Jacobian of each step and the VJPs that fall out of them. Theorem 34 closes the loop by composing them. The forward is three steps, so the chain rule from Chapter 2 gives us three Jacobians to multiply, and the famous “BN three-term backward” is exactly that product written out.
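The three forward steps can be sketched in a few lines of plain Python. This is a minimal illustration on a single channel with a 1-D batch, not the book’s Lean definition; the name `bn_forward` and the default `eps=1e-5` are assumptions made for the example.

```python
import math

def bn_forward(x, gamma=1.0, beta=0.0, eps=1e-5):
    """BatchNorm forward on a 1-D batch (one channel): center, normalize, affine."""
    n = len(x)
    mu = sum(x) / n                          # batch mean
    centered = [xk - mu for xk in x]         # step 1: center
    var = sum(c * c for c in centered) / n   # batch variance (divides by n)
    istd = 1.0 / math.sqrt(var + eps)        # inverse standard deviation
    xhat = [c * istd for c in centered]      # step 2: normalize
    y = [gamma * v + beta for v in xhat]     # step 3: affine
    return y, xhat

x = [1.0, 2.0, 4.0, 7.0]
y, xhat = bn_forward(x, gamma=2.0, beta=0.5)

# The normalized batch should have mean ~0 and variance just under 1 (because of eps).
mean_xhat = sum(xhat) / len(xhat)
var_xhat = sum(v * v for v in xhat) / len(xhat)
print(f"mean(xhat) = {mean_xhat:.2e}, var(xhat) = {var_xhat:.6f}")
```

The variance of the normalized batch lands slightly below 1 because the divisor is \(\sqrt{\sigma ^2 + \varepsilon }\) rather than \(\sigma \).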
The centering term has an indirect path
Here is the first place a careful reader can trip. When you jiggle a single input \(x_i\), you do not only jiggle \(x_i\); you also jiggle the batch mean \(\mu \), and \(\mu \) appears in every sample’s centered value \(x_k - \mu \). So jiggling \(x_i\) by a small \(h\) moves \(\mu \) by \(h/n\), which shifts every centered value by \(-h/n\), on top of the direct \(+h\) on the \(i\)th sample:
\[ \frac{\partial (x_j - \mu )}{\partial x_i} = \delta _{ij} - \frac{1}{n}. \]
The \(\delta _{ij}\) is the direct effect; the \(-1/n\) is the indirect effect through \(\mu \). Theorem 29 states this formally. Forgetting the \(-1/n\) is the most common hand-derivation mistake on BN, and is exactly what the formal proof prevents.
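A quick numerical sanity check of the \(\delta _{ij} - 1/n\) Jacobian, as a sketch: perturb each input in turn, recenter, and compare the finite-difference slope with the formula. The helper names (`centered`, `centering_jacobian_fd`) and the test batch are made up for this illustration.

```python
def centered(x):
    """Subtract the batch mean from every sample."""
    mu = sum(x) / len(x)
    return [xk - mu for xk in x]

def centering_jacobian_fd(x, h=1e-6):
    """J[i][j] approximates d(x_j - mu)/d x_i by central differences."""
    n = len(x)
    J = []
    for i in range(n):
        xp = list(x); xp[i] += h
        xm = list(x); xm[i] -= h
        cp, cm = centered(xp), centered(xm)
        J.append([(cp[j] - cm[j]) / (2 * h) for j in range(n)])
    return J

x = [0.3, -1.2, 2.5, 0.9]
n = len(x)
J_fd = centering_jacobian_fd(x)

# Formula from the text: direct effect delta_ij, indirect effect -1/n through the mean.
J_formula = [[(1.0 if i == j else 0.0) - 1.0 / n for j in range(n)] for i in range(n)]

max_err = max(abs(J_fd[i][j] - J_formula[i][j]) for i in range(n) for j in range(n))
print(f"max |finite difference - formula| = {max_err:.2e}")
```

Dropping the \(-1/n\) term would leave every off-diagonal entry at 0 instead of \(-1/n\), which this check would flag immediately.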
The inverse-stddev term is the hard part
The normalize step divides by \(\sqrt{\sigma ^2 + \varepsilon }\), where \(\sigma ^2 = \frac{1}{n} \sum _k (x_k - \mu )^2\) is itself a function of every input. Jiggle \(x_i\) by \(\varepsilon \) and \(\sigma ^2\) changes, which means the divisor changes, which means every output \(\hat{x}_j\) changes, not just \(\hat{x}_i\).
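Working through the chain rule (\(x \mapsto (x - \mu )^2 \mapsto \text{mean} \mapsto \sqrt{\cdot + \varepsilon } \mapsto 1/\cdot \)) gives, writing \(\mathrm{istd} = 1/\sqrt{\sigma ^2 + \varepsilon }\),
\[ \frac{\partial \, \mathrm{istd}}{\partial x_i} = -\mathrm{istd}^3 \cdot \frac{x_i - \mu }{n}. \]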
That \(\mathrm{istd}^3\) is what makes the BN backward expensive: the gradient of one sample depends on every sample’s centered value, scaled by the cube of the inverse standard deviation. Theorem 30 states this.
Notice what happens to the formula as \(\sigma ^2 \to 0\): \(\mathrm{istd} \to 1/\sqrt{\varepsilon }\), bounded; without \(\varepsilon \) the gradient diverges. Theorem 31 is the formal statement that \(\varepsilon {\gt} 0\) is sufficient to make this term differentiable. This is the one place in the book where the math actually requires real-analysis machinery beyond chain-sum-product; everything else in the framework reduces to those three rules.
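The inverse-stddev gradient \(-\mathrm{istd}^3 \cdot (x_i - \mu )/n\) and the role of \(\varepsilon \) can both be checked numerically. The sketch below assumes a 1-D batch and uses hypothetical helper names; it compares the closed form against central differences, then evaluates \(\mathrm{istd}\) on a constant batch, where \(\sigma ^2 = 0\).

```python
import math

def istd(x, eps):
    """Inverse standard deviation 1/sqrt(var + eps) of a 1-D batch."""
    n = len(x)
    mu = sum(x) / n
    var = sum((xk - mu) ** 2 for xk in x) / n
    return 1.0 / math.sqrt(var + eps)

def istd_grad_formula(x, eps):
    """Closed form from the text: d istd / d x_i = -istd^3 * (x_i - mu) / n."""
    n = len(x)
    mu = sum(x) / n
    s = istd(x, eps)
    return [-(s ** 3) * (xi - mu) / n for xi in x]

def istd_grad_fd(x, eps, h=1e-6):
    """Central-difference reference for the same gradient."""
    g = []
    for i in range(len(x)):
        xp = list(x); xp[i] += h
        xm = list(x); xm[i] -= h
        g.append((istd(xp, eps) - istd(xm, eps)) / (2 * h))
    return g

x = [0.3, -1.2, 2.5, 0.9]
eps = 1e-5
g_formula = istd_grad_formula(x, eps)
g_fd = istd_grad_fd(x, eps)
max_err = max(abs(a - b) for a, b in zip(g_formula, g_fd))
print(f"max |finite difference - formula| = {max_err:.2e}")

# Constant batch: the variance is exactly zero, so istd is capped at 1/sqrt(eps).
istd_const = istd([2.0, 2.0, 2.0, 2.0], eps)
print(istd_const, 1.0 / math.sqrt(eps))
```

With `eps = 0.0` the last call would divide by zero, which is the numerical face of the divergence the text describes.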
The famous three-term backward
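Compose the three forward steps and apply the product rule on \(\hat{x} = (x - \mu ) \cdot \mathrm{istd}\). The cross-terms collapse (the centered sum \(\sum _k (x_k - \mu ) = 0\) is what saves us) and what falls out is a one-line backward:
\[ dx_i = \frac{\mathrm{istd}}{n} \Bigl( n\, dy_i - \sum _k dy_k - \hat{x}_i \sum _k dy_k\, \hat{x}_k \Bigr). \]
Three terms inside the parentheses, the direct path plus one correction per shared batch statistic (\(\mu \) and \(\sigma ^2\)):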
\(n\, dy\)—the direct effect, every sample’s upstream gradient.
\(-\sum _k dy_k\)—the centering correction, subtracts the total upstream gradient because shifting the mean shifts every sample.
\(-\hat{x} \cdot \sum _k (dy_k\, \hat{x}_k)\)—the normalization correction, subtracts the projection of the upstream gradient onto \(\hat{x}\), because rescaling by the inverse stddev couples every sample’s gradient through the shared divisor.
Theorem 32 formalizes this. Production BN implementations (PyTorch, JAX, TensorFlow, custom CUDA kernels) compute this expression or an algebraically equivalent one. The value of the formal proof is not that we discovered the formula (it has been known since 2015) but that the Lean kernel mechanically verifies that the formula is the correct VJP of the forward pass, so the quantity computed at every training step is the right one.
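To see the three-term expression at work, here is a small sketch that implements \(dx_i = \frac{\mathrm{istd}}{n} ( n\, dy_i - \sum _k dy_k - \hat{x}_i \sum _k dy_k \hat{x}_k )\) for the normalize step and compares it with a finite-difference VJP. The function names, the batch, and the upstream gradient `dy` are illustrative assumptions, not code from the book’s repository.

```python
import math

def normalize(x, eps):
    """Return (xhat, istd) for a 1-D batch."""
    n = len(x)
    mu = sum(x) / n
    var = sum((xk - mu) ** 2 for xk in x) / n
    istd = 1.0 / math.sqrt(var + eps)
    return [(xk - mu) * istd for xk in x], istd

def normalize_vjp(x, dy, eps):
    """Three-term backward: direct term, centering correction, normalization correction."""
    n = len(x)
    xhat, istd = normalize(x, eps)
    sum_dy = sum(dy)                                    # total upstream gradient
    sum_dy_xhat = sum(d * v for d, v in zip(dy, xhat))  # projection of dy onto xhat
    return [istd / n * (n * dy[i] - sum_dy - xhat[i] * sum_dy_xhat) for i in range(n)]

def normalize_vjp_fd(x, dy, eps, h=1e-6):
    """Reference: dx_i = sum_j dy_j * d xhat_j / d x_i, by central differences."""
    dx = []
    for i in range(len(x)):
        xp = list(x); xp[i] += h
        xm = list(x); xm[i] -= h
        yp, _ = normalize(xp, eps)
        ym, _ = normalize(xm, eps)
        dx.append(sum(d * (a - b) for d, a, b in zip(dy, yp, ym)) / (2 * h))
    return dx

x = [0.3, -1.2, 2.5, 0.9]
dy = [1.0, -0.5, 0.25, 2.0]
eps = 1e-5

dx = normalize_vjp(x, dy, eps)
dx_fd = normalize_vjp_fd(x, dy, eps)
max_err = max(abs(a - b) for a, b in zip(dx, dx_fd))
print(f"max |finite difference - three-term formula| = {max_err:.2e}")
print(f"sum(dx) = {sum(dx):.2e}")
```

The gradient sums to (numerically) zero across the batch: adding a constant to every sample leaves \(\hat{x}\) unchanged, so the backward pass cannot assign any gradient to that direction.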
The affine step is just dense-with-broadcasting
The third step, \(\gamma \hat{x} + \beta \), is structurally a dense layer applied per-channel. Its Jacobian \(\partial (\gamma v_j + \beta )/\partial v_i = \gamma \delta _{ij}\) is exactly the dense-Jacobian computation from Chapter 3, just lifted to a tensor shape and broadcast across spatial dimensions. Theorem 28 and Theorem 33 are essentially corollaries of the dense theorems; the new content here is zero.
Putting it together
Theorem 34 is the composition: \(\mathrm{bn} = \mathrm{affine} \circ \mathrm{normalize}\), with VJPs chained via the same \(\mathrm{vjp\_comp}\) rule from Chapter 2. The 3-term backward is the centerpiece; affine is a corollary; composition is a one-line proof. The structural story of the chapter is: one new Mathlib-level analytic dependency (sqrt and recip differentiability), one famous formula, and the rest is the same chain rule we already had.
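The composition can be illustrated the same way: pull the upstream gradient back through the affine step (multiply by \(\gamma \)), then through the normalize step (three-term formula). The sketch below is a numerical check under assumed names and test values, for the gradient with respect to the input \(x\) only; \(\gamma \) and \(\beta \) are held fixed.

```python
import math

def normalize(x, eps):
    n = len(x)
    mu = sum(x) / n
    var = sum((xk - mu) ** 2 for xk in x) / n
    istd = 1.0 / math.sqrt(var + eps)
    return [(xk - mu) * istd for xk in x], istd

def bn(x, gamma, beta, eps):
    """bn = affine after normalize."""
    xhat, _ = normalize(x, eps)
    return [gamma * v + beta for v in xhat]

def affine_vjp(dy, gamma):
    # The affine Jacobian is gamma * delta_ij, so its VJP just rescales dy.
    return [gamma * d for d in dy]

def normalize_vjp(x, dxhat, eps):
    # Three-term backward of the normalize step.
    n = len(x)
    xhat, istd = normalize(x, eps)
    s1 = sum(dxhat)
    s2 = sum(d * v for d, v in zip(dxhat, xhat))
    return [istd / n * (n * dxhat[i] - s1 - xhat[i] * s2) for i in range(n)]

def bn_vjp(x, dy, gamma, eps):
    # Chain rule: affine is applied last in the forward, so it is undone first.
    return normalize_vjp(x, affine_vjp(dy, gamma), eps)

def bn_vjp_fd(x, dy, gamma, beta, eps, h=1e-6):
    """Central-difference reference for the VJP of the whole layer."""
    dx = []
    for i in range(len(x)):
        xp = list(x); xp[i] += h
        xm = list(x); xm[i] -= h
        yp, ym = bn(xp, gamma, beta, eps), bn(xm, gamma, beta, eps)
        dx.append(sum(d * (a - b) for d, a, b in zip(dy, yp, ym)) / (2 * h))
    return dx

x = [0.3, -1.2, 2.5, 0.9]
dy = [1.0, -0.5, 0.25, 2.0]
gamma, beta, eps = 1.5, -0.2, 1e-5

dx = bn_vjp(x, dy, gamma, eps)
dx_fd = bn_vjp_fd(x, dy, gamma, beta, eps)
max_err = max(abs(a - b) for a, b in zip(dx, dx_fd))
print(f"max |finite difference - composed VJP| = {max_err:.2e}")
```

Note that \(\beta \) never appears in `bn_vjp`: a constant shift of the output has no effect on the gradient with respect to the input.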
The theorems
Theorem 28. \(\partial (\gamma v_j + \beta ) / \partial v_i = \gamma \delta _{ij}\). Derived from foundation rules.
Proof. Mechanical; see Proofs.pdiv_bnAffine.
Theorem 29. \(\partial (x_j - \mu ) / \partial x_i = \delta _{ij} - 1/n\). Derived from foundation rules.
Proof. Mechanical; see Proofs.pdiv_bnCentered.
Theorem 30. \(\partial \mathrm{istd}/\partial x_i = -\mathrm{istd}^3 \cdot (x_i - \mu )/n\). Proved via the centering \(\mathrm{ContinuousLinearMap}\), HasFDerivAt.sqrt (under \(\mathrm{bnVar} + \varepsilon {\gt} 0\)), and (hasDerivAt_inv).comp_hasFDerivAt; centered sum collapses by \(\sum _k (x_k - \mu ) = 0\).
Proof. Mechanical; see Proofs.pdiv_bnIstdBroadcast.
Theorem 31. \(\operatorname {bnIstdBroadcast}\) is \(\mathsf{Differentiable}\) when \(\varepsilon {\gt} 0\). Proved via Differentiable.sqrt and Differentiable.inv over \(\mathrm{bnVar} + \varepsilon {\gt} 0\); captures the sqrt/recip smoothness of \(1/\sqrt{\sigma ^2 + \varepsilon }\) required by pdiv_mul inside pdiv_bnNormalize.
Proof. Mechanical; see Proofs.bnIstdBroadcast_diff.
Theorem 32. The consolidated 3-term backward: factor \(\mathrm{bnXhat}\) as \((x - \mu ) \cdot \mathrm{istd}\), apply product rule, collapse.
Proof. Mechanical; see Proofs.bnNormalize_has_vjp.
Theorem 33 (VJP of the affine step).
Proof. Mechanical; see Proofs.bnAffine_has_vjp.
Theorem 34. \(\mathrm{bn} = \mathrm{affine} \circ \mathrm{normalize}\).
Proof. Mechanical; see Proofs.bn_has_vjp.
5.1 Example: the BN lift on CIFAR
The previous two chapters trained MNIST classifiers with SGD at learning rate 0.1, no regularization, no tricks. That configuration works on MNIST. The moment you swap in CIFAR-10 — color \(32 \times 32\) images, harder learning problem, a real test of whether the training actually does anything — the same config fails completely unless you add BatchNorm.
Here’s the demo. Two CIFAR CNNs, same architecture, same SGD 0.1 training config. One has BN. The other doesn’t. Same s4tfBaseline we’ve been using.
The two specs, differing by one keyword per layer
Without BN:
def cifarCnnNoBn : NetSpec where
  name := "CIFAR-CNN-noBN"
  imageH := 32
  imageW := 32
  layers := [
    .conv2d 3 32 3 .same .relu,
    .conv2d 32 32 3 .same .relu,
    .maxPool 2 2,
    .conv2d 32 64 3 .same .relu,
    .conv2d 64 64 3 .same .relu,
    .maxPool 2 2,
    .flatten,
    .dense 4096 512 .relu,
    .dense 512 512 .relu,
    .dense 512 10 .identity
  ]
With BN:
def cifarCnnBn : NetSpec where
  name := "CIFAR-CNN-BN"
  imageH := 32
  imageW := 32
  layers := [
    .convBn 3 32 3 1 .same,
    .convBn 32 32 3 1 .same,
    .maxPool 2 2,
    .convBn 32 64 3 1 .same,
    .convBn 64 64 3 1 .same,
    .maxPool 2 2,
    .flatten,
    .dense 4096 512 .relu,
    .dense 512 512 .relu,
    .dense 512 10 .identity
  ]
The diff: four .conv2d \(\cdot \) \(\cdot \) 3 .same .relu layers become four .convBn \(\cdot \) \(\cdot \) 3 1 .same layers. That’s it. The dense head, the max-pool, and the training config are all identical.
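Both specs feed 4096 features into the dense head. As a sanity check on where that number comes from, here is a small shape trace in Python. It rests on two assumptions about the spec language that the text does not spell out: that `.same` padding preserves height and width, and that `.maxPool 2 2` is a \(2 \times 2\) window with stride 2. The tuple encoding and the name `trace_shapes` are invented for this sketch.

```python
def trace_shapes(channels, height, width, layers):
    """Follow (channels, height, width) through a list of layer tuples."""
    shapes = []
    for layer in layers:
        if layer[0] == "conv_same":
            # ("conv_same", in_ch, out_ch): assumed to keep height and width
            _, in_ch, out_ch = layer
            assert in_ch == channels, "channel mismatch between layers"
            channels = out_ch
        elif layer[0] == "max_pool":
            # ("max_pool", size, stride): assumed unpadded
            _, size, stride = layer
            height = (height - size) // stride + 1
            width = (width - size) // stride + 1
        shapes.append((channels, height, width))
    return shapes

# The convolutional stack shared by both specs (BN does not change shapes).
conv_stack = [
    ("conv_same", 3, 32),
    ("conv_same", 32, 32),
    ("max_pool", 2, 2),
    ("conv_same", 32, 64),
    ("conv_same", 64, 64),
    ("max_pool", 2, 2),
]

shapes = trace_shapes(3, 32, 32, conv_stack)
c, h, w = shapes[-1]
flat = c * h * w
print(shapes[-1], flat)   # (64, 8, 8) 4096
```

Under those assumptions the flattened size is \(64 \times 8 \times 8 = 4096\), matching the first `.dense 4096 512` layer in both specs.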
Training: one learns, one doesn’t
Both runs use s4tfBaseline (SGD 0.1, 30 epochs). Output captured from logs/ablation_cifar-nobn-sgd.log and logs/ablation_cifar-bn-sgd.log:
# cifar-nobn-sgd: same config, no BN
Epoch 1/30: loss=2.318085 lr=0.100000
Epoch 2/30: loss=2.305088 lr=0.100000
Epoch 3/30: loss=2.301315 lr=0.100000
Epoch 4/30: loss=2.304873 lr=0.100000
Epoch 5/30: loss=2.304378 lr=0.100000
Epoch 6/30: loss=2.304410 lr=0.100000
Epoch 7/30: loss=2.304158 lr=0.100000
...
Epoch 30/30: loss=~2.304 lr=0.100000
val accuracy: 998/9984 = 10.00% ← random-guess on 10 classes

# cifar-bn-sgd: same config, add BN
Epoch 1/30: loss=2.360751 lr=0.100000
Epoch 2/30: loss=1.618361 lr=0.100000
Epoch 3/30: loss=1.485531 lr=0.100000
Epoch 4/30: loss=1.391280 lr=0.100000
Epoch 5/30: loss=1.295481 lr=0.100000
Epoch 6/30: loss=1.191307 lr=0.100000
Epoch 7/30: loss=1.110645 lr=0.100000
...
Epoch 30/30: loss=~0.35 lr=0.100000
val accuracy: 6147/9984 = 61.57%
The no-BN run is not diverging. It’s frozen. Its loss sits at about 2.30, which is \(\ln 10 \approx 2.3026\): the cross-entropy of a uniform prediction over 10 classes. The network’s output distribution is essentially unchanged after 30 epochs. Gradients at learning rate 0.1 are too noisy for the optimizer to make coherent progress; every step nudges parameters into meaningless neighborhoods. The model is effectively a random classifier from step zero to the end of training.
Figure: CIFAR-10, 4-conv arch, SGD 0.1, 30 epochs (BN vs no-BN; logs/ablation_cifar-{bn,nobn}-sgd.log). The no-BN run sits at \(\log 10 \approx 2.30\), the random-guess loss, for the entire run.
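The random-guess figures quoted above are easy to reproduce. The sketch below computes the cross-entropy of a uniform 10-class prediction and recomputes the two validation accuracies from the counts in the logs; the variable names are illustrative.

```python
import math

num_classes = 10
uniform = [1.0 / num_classes] * num_classes

# Cross-entropy against a one-hot target is -log(probability of the true class).
# For a uniform prediction that is -log(1/10) = ln 10, whatever the true class is.
uniform_loss = -math.log(uniform[0])
print(round(uniform_loss, 6))   # 2.302585

# Validation accuracies, recomputed from the counts reported in the logs.
acc_nobn = 100.0 * 998 / 9984
acc_bn = 100.0 * 6147 / 9984
print(f"no-BN: {acc_nobn:.2f}%   BN: {acc_bn:.2f}%")
```

A logged loss of about 2.304 is within a few thousandths of \(\ln 10\), which is why the no-BN run is best read as a network that never left the uniform prediction.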
The BN run’s loss drops by nearly a third by epoch 2 (2.36 to 1.62), is halved by epoch 7, and keeps dropping. Same optimizer, same learning rate, same dataset, same architecture, plus the BN layers whose backward pass we proved correct in this chapter. The three-term formula from Theorem 32 is what makes this lift possible: BN standardizes each layer’s input distribution (\(\mu = 0\), \(\sigma = 1\)) and its backward pass returns a gradient that accounts for how that standardization affected every other sample in the batch. The result: gradients are well-conditioned, SGD steps are meaningful, training proceeds.
Where the BN lift actually lives
Why does normalizing help at all? Each layer is tuned for the distribution of its inputs, but those inputs are the outputs of every layer below, which are themselves changing every step. So each layer chases a moving target, and at a high learning rate the target moves fast: activations drift to scales where the nonlinearity saturates or gradients blow up, and the step that helped one layer wrecks the next. BN removes the moving part: by pinning every layer’s input to mean-zero, unit-variance before the learnable \(\gamma , \beta \) get a say, each layer always sees a well-conditioned input no matter what the layers below just did. Ioffe & Szegedy framed this as reducing “internal covariate shift”; later analysis (Santurkar et al., 2018) argued the sharper reason is that normalization smooths the loss landscape, so gradients stay predictable from one step to the next, which is exactly what lets you take the bigger steps the rest of this section measures. Either way: BN decouples a layer’s job from the chaos below it, and that decoupling is worth a severalfold larger learning rate (5\(\times \) in the experiment here).
If you rerun the no-BN CIFAR config at a smaller learning rate — SGD at 0.02 instead of 0.1, for example — it trains fine and hits 72.5% (logs/ablation_cifar-nobn-sgd002.log). At the same smaller rate, the BN version hits 72.9% (cifar-bn-sgd002). So BN isn’t magically giving you more model capacity — both specs top out around the same accuracy when you tune the learning rate.
What BN is doing is giving you learning-rate headroom for free: at least 5\(\times \) in this experiment (0.02 to 0.1). The lr=0.1 setting that stalls without BN trains fine with it. That matters because larger steps make more progress per epoch, so a given loss can be reached in fewer passes over the data. Over a multi-day ImageNet training run that difference is the difference between experimentally tractable and not.
This is also why most post-2015 convolutional image architectures have BN baked in by default. Before BN, picking the right lr schedule was a dark art; after BN, you can often just set lr to 0.1 and let the network figure itself out. The proof suite we built in this chapter is what mechanically guarantees that the BN backward formula is the correct VJP of the forward pass, so the empirical result (“BN lets you train with a 5\(\times \) bigger learning rate”) rests on a machine-checked mathematical claim, not on folklore.
The next chapter (Chapter 6) adds residual connections and follows the same mechanical approach: prove that the VJP of a skip connection is additive fan-in, compose with BN and conv, and the rest of the ResNet family falls out without introducing any new math.