4 MNIST: 2D CNN
The MLP wastes information; we can do better
Chapter 3’s MLP recognized digits by flattening every 28\(\times \)28 image into a 784-dim vector and feeding it through three dense layers. It worked. But it worked by ignoring something the data was telling us. Pixel \((5, 5)\) and pixel \((5, 6)\) are neighbors: they were next to each other on the page where the digit was drawn. Pixel \((5, 5)\) and pixel \((23, 14)\) are not. The dense layer has no idea. It treats every input coordinate as an independent feature; its weight matrix has no notion that some pairs of pixels are geometrically close. A trained MLP can learn to compensate, but only from the data itself, spending a separate weight on every input-to-unit connection whether or not the pixels involved are related.
A handwritten “5” drawn in the top-left of the image and the same “5” drawn in the bottom-right are the same digit. The MLP sees them as completely unrelated vectors. It can be trained to classify both, but each is its own learning problem; nothing it learned about one transfers to the other for free.
Convolutional networks fix this by building three priors directly into the architecture—priors that the data has but the MLP threw away.
Three properties baked into a convolution
Locality. Each output value depends on only a small neighborhood of inputs. A convolution with a \(3 \times 3\) kernel computes each output pixel as a weighted sum of the 9 input pixels in the \(3 \times 3\) window around it (in every input channel). Faraway pixels do not enter directly; the network can only combine them later, by stacking convolutions so that receptive fields grow.
Translation equivariance. The same weights are applied at every spatial position, so shifting the input shifts the output feature map by the same amount. A feature detector that recognizes the bottom-left curve of a “5” uses the same parameters whether the curve appears at row 5 or row 25. The architecture imposes the symmetry; the data does not have to teach it.
Parameter sharing. A \(3 \times 3 \times 1 \to 32\) convolution has \(3 \times 3 \times 1 \times 32 + 32 = 320\) parameters total, regardless of image size. A fully-connected layer producing the same \(28 \times 28 \times 32\) output from a 28\(\times \)28 image would need \(28^2 \times 28^2 \times 32 \approx 20\) million weights. The convolution is far less expressive than that dense layer: it can only represent local, position-independent detectors, which are exactly what the task needs, and by reusing one small kernel at every position it gets by with roughly 61,000\(\times \) (nearly five orders of magnitude) fewer parameters.
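The parameter arithmetic is easy to check mechanically. This is a minimal sketch in plain Lean 4 (core only); `conv2dParams` and `denseEquivalentWeights` are hypothetical helpers for illustration, not LeanMlir functions, and the dense count omits the bias, as the text does.

```lean
/-- Parameters of a k×k convolution from `cIn` to `cOut` channels:
    one k×k×cIn filter per output channel, plus one bias per output channel. -/
def conv2dParams (k cIn cOut : Nat) : Nat :=
  k * k * cIn * cOut + cOut

/-- Weights of a dense layer mapping an h×w×cIn image to an h×w×cOut
    feature map: every output value connects to every input value (no bias). -/
def denseEquivalentWeights (h w cIn cOut : Nat) : Nat :=
  (h * w * cIn) * (h * w * cOut)

-- Checked by Lean at compile time.
example : conv2dParams 3 1 32 = 320 := rfl
example : denseEquivalentWeights 28 28 1 32 = 19668992 := rfl

#eval denseEquivalentWeights 28 28 1 32 / conv2dParams 3 1 32   -- 61465
```

The integer ratio comes out a little over 61,000.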
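Concretely, for an odd kernel size \(k\) with SAME padding (\(p = \lfloor k/2 \rfloor\), and \(x\) read as zero outside the image), a convolution computes
\[ y_{c_{\text{out}}, i, j} = b_{c_{\text{out}}} + \sum_{c_{\text{in}}} \sum_{k_h = 0}^{k-1} \sum_{k_w = 0}^{k-1} W_{c_{\text{out}}, c_{\text{in}}, k_h, k_w} \, x_{c_{\text{in}}, \, i + k_h - p, \, j + k_w - p} \]
Each output cell is a small dot product between the kernel \(W\) and a window of the input, plus a per-channel bias \(b\). (Strictly this is a cross-correlation: the kernel is not flipped.) The sum is over a tiny \(3 \times 3\) spatial region and over input channels: 9 terms per input channel, not the 784 a dense layer touches.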
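To make the sum concrete, here is a minimal, unoptimized sketch in plain Lean 4 (core only) that computes each output cell as the dot product of the kernel with a zero-padded window of the input, plus a per-channel bias. Tensors are nested arrays indexed [channel][row][column] and the kernel is [c_out][c_in][k_h][k_w]; the names `Tensor3`, `Tensor4`, `pixel`, and `conv2dSame` are hypothetical, not LeanMlir's API, and the real codegen emits stablehlo.convolution rather than loops. It assumes an odd kernel size and stride 1.

```lean
abbrev Tensor3 := Array (Array (Array Float))          -- [channel][row][col]
abbrev Tensor4 := Array (Array (Array (Array Float)))  -- [cOut][cIn][kh][kw]

/-- x[c][i][j], or 0 outside the image (SAME zero padding). -/
def pixel (x : Tensor3) (c : Nat) (i j : Int) : Float :=
  if i < 0 ∨ j < 0 then 0 else
    match x[c]? with
    | none => 0
    | some plane =>
      match plane[i.toNat]? with
      | none => 0
      | some row => row[j.toNat]?.getD 0

/-- y[o][i][j] = bias[o] + Σ c a b, w[o][c][a][b] * x[c][i+a-p][j+b-p], with p = k / 2.
    Stride 1, odd k; the output has the same h × wd size as the input. -/
def conv2dSame (h wd k : Nat) (w : Tensor4) (bias : Array Float) (x : Tensor3) : Tensor3 :=
  let p : Int := Int.ofNat (k / 2)
  (Array.range w.size).map fun o =>          -- one output plane per filter
    (Array.range h).map fun i =>
      (Array.range wd).map fun j => Id.run do
        let mut acc := bias[o]?.getD 0       -- a missing bias entry reads as 0
        for c in List.range x.size do        -- sum over input channels ...
          for a in List.range k do           -- ... and over the k × k window
            for b in List.range k do
              acc := acc + w[o]![c]![a]![b]! * pixel x c ((i : Int) + a - p) ((j : Int) + b - p)
        return acc

def x0 : Tensor3 := #[#[#[1, 2, 3], #[4, 5, 6], #[7, 8, 9]]]       -- 1 channel, 3 × 3
def ones3 : Tensor4 := #[#[#[#[1, 1, 1], #[1, 1, 1], #[1, 1, 1]]]]  -- 1 → 1 channel, all-ones 3 × 3

-- Each output is its 3 × 3 neighborhood sum. Expected values, row by row:
-- 12 21 16 / 27 45 33 / 24 39 28 (edges are smaller because the padding adds zeros).
#eval conv2dSame 3 3 3 ones3 #[0] x0
```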
The Jacobian is sparse and structured
The Jacobian of a convolution looks like a giant matrix, but it isn’t a generic giant matrix—it’s mostly zeros. Specifically, \(\partial y_{c_{\text{out}}, i, j} / \partial x_{c_{\text{in}}, i', j'}\) is zero unless \((i', j')\) falls inside the kernel window centered at \((i, j)\). When it does, the entry is just the corresponding weight \(W_{c_{\text{out}}, c_{\text{in}}, k_h, k_w}\).
So a convolution is linear in its input (plus a bias), like a dense layer, but its Jacobian is highly structured: each row has at most \(k^2 \times C_{\text{in}}\) nonzero entries (fewer at the image border, where the window hangs over the zero padding), and those entries are reused across many rows because the same kernel is used at every spatial position. The picture from Chapter 3, “the Jacobian of a linear map is the linear map,” still holds. The Jacobian is just stored compactly: instead of one big dense matrix, it is a list of kernel weights plus a rule for where each one lives.
The backward of a convolution is another convolution
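This is the part that surprises a first reader, and it is the single most important fact in the chapter. We want \(dx\) (how a small jiggle in each input pixel changes the loss) given an upstream \(dy\). Plugging into the Jacobian above and grinding through the algebra, with \(p = \lfloor k/2 \rfloor\), \(k\) odd, and \(dy\) read as zero outside the image, the answer is:
\[ dx_{c_{\text{in}}, i, j} = \sum_{c_{\text{out}}} \sum_{k_h, k_w} W_{c_{\text{out}}, c_{\text{in}}, k_h, k_w} \, dy_{c_{\text{out}}, \, i - k_h + p, \, j - k_w + p} = \sum_{c_{\text{out}}} \sum_{k_h, k_w} W_{c_{\text{out}}, c_{\text{in}}, k-1-k_h, k-1-k_w} \, dy_{c_{\text{out}}, \, i + k_h - p, \, j + k_w - p} \]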
That is itself a convolution. Specifically, it is a convolution of the upstream gradient \(dy\) against the same kernel \(W\), but with the spatial axes reversed and the input/output channel axes swapped—a “transposed convolution”. The backward of a CNN block reuses the same kind of operation as the forward; no new primitive is needed.
The weight gradient \(dW\) similarly works out to a convolution, this time of the saved input \(x\) against the upstream gradient \(dy\) (the “transpose trick”). The bias gradient \(db\) is just \(dy\) summed over spatial positions per channel. Theorems 23, 24, and 25 formalize each of these.
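The three gradients can be sketched in plain Lean 4 (core only), assuming an odd kernel size, stride 1, and SAME zero padding. All names are hypothetical, not LeanMlir's API, and the block includes a minimal forward pass so that it stands alone. The point is that `conv2dInputGrad` does no new arithmetic: it calls the forward `conv2dSame` on \(dy\) with a kernel whose spatial axes are reversed and whose channel axes are swapped. The adjoint identity \(\langle \operatorname{conv}(x), dy \rangle = \langle x, dx \rangle\) (with zero bias) gives a quick numerical check.

```lean
abbrev Tensor3 := Array (Array (Array Float))          -- [channel][row][col]
abbrev Tensor4 := Array (Array (Array (Array Float)))  -- [cOut][cIn][kh][kw]

/-- x[c][i][j], or 0 outside the image (SAME zero padding). -/
def pixel (x : Tensor3) (c : Nat) (i j : Int) : Float :=
  if i < 0 ∨ j < 0 then 0 else
    match x[c]? with
    | none => 0
    | some plane =>
      match plane[i.toNat]? with
      | none => 0
      | some row => row[j.toNat]?.getD 0

/-- Forward: y[o][i][j] = bias[o] + Σ c a b, w[o][c][a][b] * x[c][i+a-p][j+b-p]. -/
def conv2dSame (h wd k : Nat) (w : Tensor4) (bias : Array Float) (x : Tensor3) : Tensor3 :=
  let p : Int := Int.ofNat (k / 2)
  (Array.range w.size).map fun o =>
    (Array.range h).map fun i =>
      (Array.range wd).map fun j => Id.run do
        let mut acc := bias[o]?.getD 0   -- a missing bias entry reads as 0
        for c in List.range x.size do
          for a in List.range k do
            for b in List.range k do
              acc := acc + w[o]![c]![a]![b]! * pixel x c ((i : Int) + a - p) ((j : Int) + b - p)
        return acc

/-- w'[c][o][a][b] = w[o][c][k-1-a][k-1-b]: swap channel axes, reverse both spatial axes. -/
def flipSwap (k cIn : Nat) (w : Tensor4) : Tensor4 :=
  (Array.range cIn).map fun c =>
    (Array.range w.size).map fun o =>
      (Array.range k).map fun a =>
        (Array.range k).map fun b => w[o]![c]![k - 1 - a]![k - 1 - b]!

/-- dx: the *forward* convolution of dy against the flipped, channel-swapped kernel, no bias. -/
def conv2dInputGrad (h wd k cIn : Nat) (w : Tensor4) (dy : Tensor3) : Tensor3 :=
  conv2dSame h wd k (flipSwap k cIn w) #[] dy

/-- dW[o][c][a][b] = Σ i j, dy[o][i][j] * x[c][i+a-p][j+b-p]: correlate the saved input with dy. -/
def conv2dWeightGrad (h wd k : Nat) (x dy : Tensor3) : Tensor4 :=
  let p : Int := Int.ofNat (k / 2)
  (Array.range dy.size).map fun o =>
    (Array.range x.size).map fun c =>
      (Array.range k).map fun a =>
        (Array.range k).map fun b => Id.run do
          let mut acc : Float := 0
          for i in List.range h do
            for j in List.range wd do
              acc := acc + dy[o]![i]![j]! * pixel x c ((i : Int) + a - p) ((j : Int) + b - p)
          return acc

/-- db[o] = Σ i j, dy[o][i][j]: sum the upstream gradient over spatial positions. -/
def conv2dBiasGrad (dy : Tensor3) : Array Float :=
  dy.map fun plane => plane.foldl (fun (s : Float) row => row.foldl (fun t v => t + v) s) 0

/-- Inner product of two tensors of the same shape. -/
def dot (a b : Tensor3) : Float := Id.run do
  let mut s : Float := 0
  for (pa, pb) in a.zip b do
    for (ra, rb) in pa.zip pb do
      for (u, v) in ra.zip rb do
        s := s + u * v
  return s

def x0 : Tensor3 := #[#[#[1, 2, 3], #[4, 5, 6], #[7, 8, 9]]]        -- cIn = 1, 3 × 3
def w0 : Tensor4 := #[#[#[#[1, 0, 2], #[0, -1, 0], #[3, 0, 1]]],     -- cOut = 2
                      #[#[#[0, 1, 0], #[1, 1, 1], #[0, 1, 0]]]]
def dy0 : Tensor3 := #[#[#[1, 0, 0], #[0, 2, 0], #[0, 0, 1]],
                       #[#[0, 1, 0], #[1, 0, 1], #[0, 1, 0]]]

-- Adjoint check: both components should be 144 (exact, since every value is a small integer).
#eval (dot (conv2dSame 3 3 3 w0 #[] x0) dy0, dot x0 (conv2dInputGrad 3 3 3 1 w0 dy0))
#eval conv2dBiasGrad dy0             -- 4 and 4
#eval conv2dWeightGrad 3 3 3 x0 dy0  -- shape [2][1][3][3], same as w0
```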
Max pooling: the winner gets the gradient
The other new primitive in this chapter is max pooling. A \(2 \times 2\) max pool collapses each \(2 \times 2\) block of the input into a single output value: the max of the four. It has no parameters; its purpose is to halve the spatial resolution while keeping the strongest local response.
The Jacobian of \(\max \) is sparse. At a generic point only one of the four inputs is the maximum, so the output’s slope is \(1\) with respect to that input and \(0\) with respect to the other three. “Jiggle the max” moves the output; “jiggle a non-max” does nothing. So the backward of max pool routes the upstream gradient back to whichever input was the winner in each window and sends zero to the rest. At ties, where \(\max \) is not differentiable, the codegen picks the winner deterministically (the same convention as ReLU’s kink). Definition 27 states this formally.
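A minimal single-channel sketch of \(2 \times 2\), stride-2 max pooling and its backward in plain Lean 4 (core only). The names are hypothetical, and the tie-breaking rule here (the first position in row-major order wins) is an assumption for illustration; the codegen's actual convention may differ.

```lean
abbrev Plane := Array (Array Float)   -- a single channel, [row][col]

/-- Forward: 2×2 windows, stride 2; h and w are assumed even. -/
def maxPool2 (h w : Nat) (x : Plane) : Plane :=
  (Array.range (h / 2)).map fun i =>
    (Array.range (w / 2)).map fun j =>
      let a := x[2*i]![2*j]!
      let b := x[2*i]![2*j+1]!
      let c := x[2*i+1]![2*j]!
      let d := x[2*i+1]![2*j+1]!
      max (max a b) (max c d)

/-- Backward: each window's upstream gradient goes to the input that held the max;
    the other three positions get 0. Ties go to the first position in row-major
    order (an assumed convention for this sketch). -/
def maxPool2Backward (h w : Nat) (x dy : Plane) : Plane := Id.run do
  let mut dx : Plane := (Array.range h).map fun _ => (Array.range w).map fun _ => (0 : Float)
  for i in List.range (h / 2) do
    for j in List.range (w / 2) do
      -- the four window positions, in row-major order
      let cands := [(2*i, 2*j), (2*i, 2*j+1), (2*i+1, 2*j), (2*i+1, 2*j+1)]
      let mut best := (2*i, 2*j)
      for (r, c) in cands do
        if x[r]![c]! > x[best.1]![best.2]! then best := (r, c)   -- strict: earlier wins ties
      let g := dy[i]![j]!
      dx := dx.modify best.1 (fun row => row.modify best.2 (fun v => v + g))
  return dx

def x1 : Plane := #[#[1, 3, 0, 0],
                    #[2, 0, 5, 5],
                    #[0, 0, 1, 1],
                    #[0, 4, 1, 1]]

#eval maxPool2 4 4 x1                                  -- maxima 3 5 / 4 1
#eval maxPool2Backward 4 4 x1 #[#[10, 20], #[30, 40]]
-- 10 lands at (0,1); 20 at (1,2), the first of the tied 5s; 30 at (3,1);
-- 40 at (2,2), where all four values tie. Every other entry is 0.
```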
The rest of the network is Chapter 3
After two convolutions and a max pool, the feature map is \(14 \times 14 \times 32\). At that point we flatten the \(14 \cdot 14 \cdot 32 = 6272\)-dim feature vector and feed it into the same three-dense-layer head from Chapter 3, ending in 10-class logits. The backward through the dense head and the softmax-CE loss is exactly the calculation we did in the previous chapter—we do not re-prove it here. The composition rule from Chapter 2 guarantees that the full network’s VJP is just the per-layer VJPs strung together in reverse order: softmax-CE backward, dense backward \(\times 3\), flatten backward, max-pool backward, conv2d backward \(\times 2\). Each piece is its own theorem; together they form a working trainer for MNIST that respects the image’s 2-D structure.
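The reverse-order stringing of VJPs can be sketched generically. This is a toy model in plain Lean 4, not LeanMlir's HasVJP machinery: the `Layer` structure and `backprop` function are hypothetical, and each layer is simply a forward map on a flat vector plus a VJP that takes the layer's saved input and the upstream gradient.

```lean
/-- A toy differentiable layer on flat vectors (hypothetical, for illustration):
    `vjp inp g` maps the upstream gradient `g` to the gradient w.r.t. `inp`. -/
structure Layer where
  fwd : Array Float → Array Float
  vjp : Array Float → Array Float → Array Float

/-- Forward pass saving each layer's input, then the VJPs applied in reverse order. -/
def backprop (layers : List Layer) (x dy : Array Float) : Array Float := Id.run do
  let mut saved : List (Array Float) := []   -- layer inputs, most recent first
  let mut a := x
  for l in layers do
    saved := a :: saved
    a := l.fwd a
  let mut g := dy                            -- gradient w.r.t. the final output
  for (l, inp) in layers.reverse.zip saved do
    g := l.vjp inp g
  return g

def scale3 : Layer :=
  { fwd := fun a => a.map (fun v => 3 * v),
    vjp := fun _ g => g.map (fun gv => 3 * gv) }

def square : Layer :=
  { fwd := fun a => a.map (fun v => v * v),
    vjp := fun a g => (a.zip g).map (fun p => 2 * p.1 * p.2) }

-- d/dx (3x)^2 = 18x, so this should give 18 and 36.
#eval backprop [scale3, square] #[1, 2] #[1, 1]
```

Neither layer knows about the other; the driver only needs each layer's saved input and its VJP, which is the same division of labor the composition rule relies on.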
The theorems
Definition (conv2d forward). Concrete \(\sum _{c,k_h,k_w}\) cross-correlation with SAME padding (Phase 7). Codegen emits stablehlo.convolution; proofs reason about it via the explicit definition.
Theorem 23 (conv2d input VJP). Reversed-kernel convolution formula. Phase 1 (Apr 2026) proof via pdiv_finset_sum \(\times 3\) + pdiv_const_mul_pi_pad_eval per summand + \(\sum _{c,k_h,k_w}\) collapse. Proof: mechanical; see Proofs.conv2d_has_vjp3.
Theorem 24 (conv2d weight gradient). Phase 7: the transpose-trick formula via Kernel4.flatten. Derived from foundation rules. Proof: mechanical; see Proofs.conv2d_weight_grad_has_vjp.
Theorem 25 (conv2d bias gradient). Phase 9: sum the cotangent over spatial dims per channel. Derived from foundation rules. Proof: mechanical; see Proofs.conv2d_bias_grad_has_vjp.
Definition (max pool forward). Concrete four-way max over \(2 \times 2\) windows (Phase 7).
Definition 27 (max pool backward). A noncomputable def over the canonical pdiv-derived witness. HasVJP.correct holds by rfl; codegen substitutes the standard argmax routing convention at tiebreaks.
4.1 Example: MNIST 2D CNN
The MLP in Chapter 3 crushed MNIST by flattening the image into a 784-dim vector and throwing dense layers at it. That works because MNIST is small, but it also throws away all the spatial structure: pixel \((5, 5)\) and pixel \((5, 6)\) are neighbors, but the MLP sees no more relationship between them than it does between pixel \((5, 5)\) and pixel \((23, 14)\). Convolutions fix that. This chapter does MNIST again, the same way the Swift for TensorFlow book’s Chapter 2 did it: two plain \(3 \times 3\) convs to pull out local features, a max-pool to halve the spatial grid, then the same dense head you already know.
Architecture
Two convolutions lift the 1-channel grayscale input to 32 channels, maxPool halves the spatial resolution, then the flattened \(14 \times 14 \times 32 = 6272\) feature vector goes through three dense layers into 10-class logits. No batch norm yet — BN shows up in Chapter 5.
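The shape bookkeeping can be traced with a small fold over the layer list. This is a hypothetical shape tracer in plain Lean 4, not LeanMlir's layer type: `ShapeLayer` and `step` are illustrative names, and the sketch assumes SAME-padded stride-1 convolutions and a max pool whose window equals its stride.

```lean
/-- Just enough of a layer description to track shapes (hypothetical). -/
inductive ShapeLayer where
  | conv2dSame (cIn cOut k : Nat)   -- SAME padding, stride 1: height and width unchanged
  | maxPool (size stride : Nat)
  | flatten
  | dense (nIn nOut : Nat)

/-- Shape (channels, height, width) after one layer; a flat vector of n values is (n, 1, 1). -/
def step : Nat × Nat × Nat → ShapeLayer → Nat × Nat × Nat
  | (_, h, w), .conv2dSame _ cOut _ => (cOut, h, w)
  | (c, h, w), .maxPool _ stride    => (c, h / stride, w / stride)
  | (c, h, w), .flatten             => (c * h * w, 1, 1)
  | _,         .dense _ nOut        => (nOut, 1, 1)

def arch : List ShapeLayer :=
  [.conv2dSame 1 32 3, .conv2dSame 32 32 3, .maxPool 2 2, .flatten,
   .dense 6272 512, .dense 512 512, .dense 512 10]

#eval (arch.take 4).foldl step (1, 28, 28)   -- (6272, 1, 1): the flattened feature vector
#eval arch.foldl step (1, 28, 28)            -- (10, 1, 1): the class logits
```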
Code: the full training program
Same format as the MLP chapter. The only thing that changes from Chapter 3’s training program is the NetSpec; the hyperparameters are exactly the same s4tfBaseline config.
-- 1
import LeanMlir
-- 2
def mnistCnnNoBn : NetSpec where
  name := "MNIST-CNN-noBN"
  imageH := 28
  imageW := 28
  layers := [
    .conv2d 1 32 3 .same .relu,
    .conv2d 32 32 3 .same .relu,
    .maxPool 2 2,
    .flatten,
    .dense 6272 512 .relu,
    .dense 512 512 .relu,
    .dense 512 10 .identity
  ]
-- 3 (identical to the MLP chapter's s4tfBaseline)
def s4tfBaseline : TrainConfig where
  learningRate := 0.1
  batchSize := 128
  epochs := 12
  useAdam := false
  weightDecay := 0.0
  cosineDecay := false
  warmupEpochs := 0
  augment := false
  labelSmoothing := 0.0
-- 4
def main (args : List String) : IO Unit :=
  mnistCnnNoBn.train s4tfBaseline (args.head?.getD "data") .mnist
Walking through the numbered sections:
1. The import. Same as Chapter 3.
2. The architecture. Seven layers: two .conv2d layers at 32 channels, one .maxPool, .flatten, three .dense layers down to 10. Each .conv2d takes in-channels, out-channels, kernel size, padding, activation — five parameters, no classes. Same data-first style as the MLP spec.
The .conv2d backward pass is the theorem from earlier in this chapter (§ 23): the input VJP is another convolution, using the kernel reversed along both spatial axes and with its input and output channels swapped. The max-pool backward (§ 27) routes the gradient to the argmax position in each \(2 \times 2\) window. The dense backward was already proved in Chapter 3 (§ 17). Adding convolution to the model introduces no new training-loop code, only new layer types, which the framework already knows how to compose.
3. The training hyperparameters. Literally the same s4tfBaseline config as the MLP chapter — SGD 0.1, 12 epochs, no regularization, no augmentation. That’s the framework’s pitch made concrete: swap the architecture, keep the trainer. The ablation framework gives you access to the same variant configs (adamOnly, fullRecipe, etc.) without any more changes on your side.
4. The program entry point. One line, same as the MLP chapter. mnistCnnNoBn.train s4tfBaseline replaces mnistMlp.train s4tfBaseline. That’s the whole diff.
Results
Build and run via the ablation framework:
$ IREE_BACKEND=rocm IREE_CHIP=gfx1100 \
./.lake/build/bin/ablation cnn-nobn-sgd
Ablation: cnn-nobn-sgd
spec: MNIST-CNN-noBN, optimizer: SGD
lr: 0.100000, cosine: false, wd: 0.000000
aug: false, label_smooth: 0.000000
MNIST-CNN-noBN-cnn-nobn-sgd: 3489130 params
Generating train step MLIR...
21884 chars
Compiling vmfbs...
forward compiled
eval forward compiled
train step compiled
session loaded
train: 60000 images (784 floats/image)
training: 468 batches/epoch, batch=128, SGD, lr=0.100000
step 0/468: loss=2.436... (75ms)
Epoch 1/12: loss=0.196587 lr=0.100000 (39327ms)
Epoch 2/12: loss=0.043851 lr=0.100000 (39243ms)
...
Epoch 12/12: loss=(small) lr=0.100000 (~39000ms)
val accuracy: 9898/9984 = 98.98%
(Full log in logs/ablation_cnn-nobn-sgd.log.) Twelve epochs, about 40 seconds per epoch on the 7900 XTX, \(\sim \)8 minutes total. Final test accuracy: 98.98%, up from the MLP’s 98.57% on the same data, with the same optimizer and the same number of epochs. The delta is 0.4 percentage points: not huge, but consistent and free, because the convolution gets local spatial features by construction instead of having to learn them one pixel pair at a time the way the MLP does.
MNIST, same SGD 0.1 / 12-epoch recipe, log-scale training loss (logs/ablation_{mlp,cnn-nobn}-sgd.log). The CNN converges faster early (half the MLP’s loss by epoch 2); both bottom out near \(10^{-3}\) train loss, but the CNN generalizes better — 98.98% vs 98.57% test.
Why this network is still 99% dense-head
At 3,489,130 total parameters, the no-BN CNN is actually \(\sim \)5.2\(\times \) bigger than the MLP, not smaller. You might have expected the opposite — convolutions are supposed to be parameter-efficient. They are, but the head is what’s eating the budget: the flatten of a \(14 \times 14 \times 32\) feature map produces a 6272-dim vector, and the first dense layer (6272 \(\to \) 512) alone is 3,211,776 parameters. The entire conv+maxPool backbone is 9,568 parameters — 0.27% of the total.
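The budget split can be recomputed from the layer list with two one-line formulas. A sketch in plain Lean 4; `convParams` and `denseParams` are hypothetical helpers that count one bias per output channel or unit, which is the convention that reproduces the 3,489,130 in the log.

```lean
-- One bias per output channel (conv) or per output unit (dense).
def convParams (k cIn cOut : Nat) : Nat := k * k * cIn * cOut + cOut
def denseParams (nIn nOut : Nat) : Nat := nIn * nOut + nOut

def backbone : Nat := convParams 3 1 32 + convParams 3 32 32          -- 320 + 9248
def head : Nat := denseParams 6272 512 + denseParams 512 512 + denseParams 512 10

example : backbone = 9568 := rfl
example : denseParams 6272 512 = 3211776 := rfl
example : backbone + head = 3489130 := rfl   -- the total reported in the ablation log

#eval head                                   -- 3479562
#eval backbone * 10000 / (backbone + head)   -- 27 basis points, i.e. about 0.27%
```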
This is the same “fat FC head” pattern you’ll see in AlexNet, VGG, and YOLOv1: a compact convolutional backbone feeding a massive fully-connected classifier. Historically it drove practitioners crazy: the conv layers were the interesting new machinery, but the weight budget lived in the part nobody had done anything new with. Most post-2015 vision architectures responded by replacing the dense head with globalAvgPool followed by a single small dense layer. That is why ResNet-18’s \(\sim \)11M parameters for 1000-class ImageNet sit almost entirely in its convolutional backbone (its final 512 \(\to \) 1000 dense layer is only about 0.5M), while nearly all of Chapter 4’s 3.5M for 10-class MNIST sit in the head. We’ll see the switch firsthand starting with ResNet in Chapter 6.
The MLIR character count tells a similar story: the CNN emits 21,884 chars versus the MLP’s 10,375, about 2.1\(\times \) bigger. Adding plain convolutions and a maxPool roughly doubles the size of the emitted StableHLO, rather than the 3.5\(\times \) growth our other (BN-equipped) CNN shows. Every .convBn adds batch-norm normalize and affine ops; the plain .conv2d used here keeps the StableHLO much tighter.