Verified Deep Learning with Lean 4

3 MNIST: 1D MLP

The pdiv we built last chapter, now applied

In Chapter 2 we defined a function \(\operatorname {pdiv}\) that captures the partial derivative of any sufficiently smooth function. We proved that three structural rules (chain, sum, product) suffice to compose new partials out of old ones. That was machinery without a target. This chapter picks the target. We compute the partial derivative of every component of a small image classifier, end-to-end, and the goal is a classifier that, trained on those partials, recognizes handwritten digits.

To say the same thing more concretely: a neural network is a long chain of functions, and “training” means using the partials of those functions to nudge their parameters in a direction that reduces a loss. Every theorem in this chapter is the answer to the same question, asked about a different building block: if I jiggle this input by \(\varepsilon \), how does the output jiggle?

The Jacobian as multidimensional “how does it jiggle?”

For a scalar function \(f : \mathbb {R} \to \mathbb {R}\) the answer is a single number: the derivative. For our networks, no function is scalar-in scalar-out. The smallest layer maps an input vector to an output vector—\(n\) numbers in, \(m\) numbers out. “Jiggle the input” now means: pick a direction in \(\mathbb {R}^n\) and push slightly along it. “The output jiggles” now means: a vector in \(\mathbb {R}^m\). The full picture of how every output coordinate responds to every input coordinate is an \(m \times n\) table of numbers, one slope per (output, input) pair. That table is the Jacobian. There is nothing more to it. It is just the multidimensional generalization of “the derivative is the slope.”

The previous chapter’s \(\operatorname {pdiv}\) is a single entry of this table, the one at a chosen \((i, j)\). The Jacobian is \(\operatorname {pdiv}\) evaluated at every \((i, j)\) at once. It is organized as a matrix so that it can be multiplied into other matrices later.
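A quick way to make the table concrete is to estimate it numerically. Jiggle one input coordinate at a time and record how every output moves. The Lean sketch below does this with central differences on plain `List Float` vectors. `jacobianFD` and `toy` are illustrative names of our own and are not part of LeanMlir. Floating-point finite differences only approximate the exact \(\operatorname {pdiv}\) entries.

```lean
-- Central-difference estimate of the m × n Jacobian of f at x.
-- Entry (j, i) approximates ∂f_j/∂x_i: jiggle x_i by ±eps, watch f_j.
def jacobianFD (f : List Float → List Float) (x : List Float)
    (eps : Float := 1e-4) : List (List Float) :=
  let fx := f x
  -- column i: how every output moves when only x_i is jiggled
  let column := fun (i : Nat) =>
    let xPlus  := x.set i (x[i]! + eps)
    let xMinus := x.set i (x[i]! - eps)
    List.zipWith (fun a b => (a - b) / (2 * eps)) (f xPlus) (f xMinus)
  let cols := (List.range x.length).map column
  -- transpose the columns into rows: row j, column i holds ∂f_j/∂x_i
  (List.range fx.length).map fun j => cols.map fun c => c[j]!

-- Toy map ℝ² → ℝ³: (x₀, x₁) ↦ (x₀·x₁, x₀ + x₁, x₀²)
def toy (x : List Float) : List Float :=
  [x[0]! * x[1]!, x[0]! + x[1]!, x[0]! * x[0]!]

-- Exact Jacobian at (2, 3) is [[3, 2], [1, 1], [4, 0]]; the estimate matches up to rounding.
#eval jacobianFD toy [2.0, 3.0]
```

Each column costs two evaluations of `f`. That is fine for checking a hand-derived Jacobian, but far too slow for training.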

The dense layer’s Jacobian is the weight matrix

Our smallest building block is a dense (or “fully connected”) layer: \(y = Wx + b\), where \(W\) is an \(m \times n\) weight matrix, \(x\) is an \(n\)-vector, and \(b\) is an \(m\)-vector. Pick any output coordinate \(y_j\) and write it out:

\[ y_j = \sum _{k=1}^{n} W_{jk}\, x_k + b_j. \]

Now jiggle \(x_i\) by \(\varepsilon \). Of the \(n\) terms in the sum, only the \(k = i\) term notices, and it changes by \(W_{ji}\, \varepsilon \). So \(\partial y_j / \partial x_i = W_{ji}\). The Jacobian of \(y = Wx + b\) with respect to \(x\) is the weight matrix \(W\) itself. No new structure, no surprise—the Jacobian of a linear map is the linear map, written as a matrix.

The same exercise works for the partial with respect to \(W\). Jiggling a single weight \(W_{j'k}\) by \(\varepsilon \) touches only the output row it belongs to. It moves \(y_{j'}\) by \(x_k\, \varepsilon \) and leaves every other output alone, so \(\partial y_j / \partial W_{j'k} = \delta _{jj'}\, x_k\). The dependence is local: each weight influences exactly one output. The input-Jacobian and the weight-Jacobian are the only objects we need for a dense layer. Theorems 13 and 14 below formalize them; the substantive work is already done.
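Both slopes can be checked by literally jiggling. The sketch below is plain Lean 4 on `List Float`, with hypothetical names (`dense`, `slopeInput`, `slopeWeight`). It stores \(W\) as one row per output, following the prose convention \(y_j = \sum _k W_{jk} x_k + b_j\). The layer is affine in both \(x\) and \(W\), so a forward difference recovers the slopes up to rounding.

```lean
-- Dense layer y = W x + b, with W stored as m rows (one per output) of
-- length n, matching the convention y_j = Σ_k W_jk x_k + b_j.
def dense (W : List (List Float)) (b x : List Float) : List Float :=
  List.zipWith (fun row bj => (List.zipWith (· * ·) row x).foldl (· + ·) 0.0 + bj) W b

-- Small example: m = 2 outputs, n = 3 inputs.
def W₀ : List (List Float) := [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
def b₀ : List Float := [0.5, -0.5]
def x₀ : List Float := [0.1, 0.2, 0.3]

-- Jiggle x_i by eps; report how far y_j moved per unit eps (should be W_ji).
def slopeInput (j i : Nat) (eps : Float := 1e-3) : Float :=
  let xp := x₀.set i (x₀[i]! + eps)
  ((dense W₀ b₀ xp)[j]! - (dense W₀ b₀ x₀)[j]!) / eps

-- Jiggle the single weight W_{j'k} by eps (should be δ_jj' · x_k).
def slopeWeight (j j' k : Nat) (eps : Float := 1e-3) : Float :=
  let row := W₀[j']!
  let Wp := W₀.set j' (row.set k (row[k]! + eps))
  ((dense Wp b₀ x₀)[j]! - (dense W₀ b₀ x₀)[j]!) / eps

#eval slopeInput 1 2       -- ≈ 6.0, which is W₀[1][2]
#eval slopeWeight 1 1 2    -- ≈ 0.3, which is x₀[2]
#eval slopeWeight 0 1 2    -- 0.0: output 0 never reads row 1 of W
```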

ReLU: the piecewise case

ReLU is the function \(\mathrm{relu}(x) = \max (x, 0)\) applied coordinatewise. Its Jacobian is a diagonal matrix: each output coordinate depends only on the matching input coordinate. The diagonal entry is \(1\) where \(x_i > 0\) and \(0\) where \(x_i < 0\). At \(x_i = 0\) the function is not differentiable in the classical sense: the slope jumps from \(0\) to \(1\). For our purposes this matters less than you might fear. Theorem 15 states the Jacobian at smooth points; the codegen substitutes the standard subgradient convention at the kink and we move on. The kink is the only place in the chapter where “differentiable” becomes slightly subtle.
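In code, the diagonal is a one-line map. This is a sketch only. It assumes the derivative at exactly \(x_i = 0\) is taken to be \(0\). That is one common subgradient choice, but it is our assumption here: the chapter only says the codegen uses “the standard subgradient convention.” Names are illustrative.

```lean
-- ReLU, applied coordinatewise
def relu (x : List Float) : List Float :=
  x.map fun v => if v > 0.0 then v else 0.0

-- Diagonal of ReLU's Jacobian: 1 where x_i > 0, 0 where x_i < 0.
-- ASSUMPTION: at the kink x_i = 0 this sketch uses 0.
def reluJacDiag (x : List Float) : List Float :=
  x.map fun v => if v > 0.0 then 1.0 else 0.0

-- Full Jacobian: zero off the diagonal, since output j reads only input j
def reluJac (x : List Float) : List (List Float) :=
  let d := reluJacDiag x
  (List.range x.length).map fun j =>
    (List.range x.length).map fun i => if i == j then d[j]! else 0.0

#eval relu    [-1.5, 2.0, 0.0]   -- values 0, 2, 0
#eval reluJac [-1.5, 2.0, 0.0]   -- diagonal 0, 1, 0; zeros elsewhere
```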

Softmax cross-entropy: where vectors collapse to a scalar

The last building block is the loss: softmax cross-entropy between the model’s output \(z \in \mathbb {R}^{10}\) and the true class label \(y \in \{ 0, \ldots , 9\} \). The full computation is

\[ L(z, y) = -\log \frac{e^{z_y}}{\sum _k e^{z_k}}. \]

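The loss is a scalar, so its Jacobian with respect to \(z\) is a vector, not a matrix. Working through the algebra (chain rule on \(-\log \circ \mathrm{softmax}\) at the label index) produces a remarkably clean answer. The shortest route is to expand the log first: \(L = -z_y + \log \sum _k e^{z_k}\), so \(\partial L / \partial z_i = e^{z_i} / \sum _k e^{z_k} - \delta _{iy}\). In vector form: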

\[ \frac{\partial L}{\partial z} = \mathrm{softmax}(z) - \mathrm{onehot}(y). \]

The gradient of the loss is the difference between the model’s predicted distribution and the truth. Theorem 16 makes this formal.
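The closed form is easy to sanity-check against finite differences. The sketch below computes four things: softmax, the loss \(-\log \mathrm{softmax}(z)_y\), the closed-form gradient, and a central-difference estimate of that gradient. Softmax is shifted by \(\max _k z_k\), which cancels in the ratio and avoids overflow. All names are illustrative; this is not the framework’s implementation.

```lean
-- Softmax, shifted by max z for numerical stability (the shift cancels).
def softmax (z : List Float) : List Float :=
  let zmax := z.foldl (fun a b => if b > a then b else a) (z[0]!)
  let es := z.map fun v => Float.exp (v - zmax)
  let s := es.foldl (· + ·) 0.0
  es.map (· / s)

-- L(z, y) = -log softmax(z)_y
def softmaxCE (z : List Float) (y : Nat) : Float :=
  let p := softmax z
  -(Float.log (p[y]!))

-- Closed form: ∂L/∂z = softmax(z) - onehot(y)
def softmaxCEGrad (z : List Float) (y : Nat) : List Float :=
  let p := softmax z
  (List.range z.length).map fun k => p[k]! - (if k == y then 1.0 else 0.0)

-- Central-difference estimate of the same gradient
def softmaxCEGradFD (z : List Float) (y : Nat) (eps : Float := 1e-5) : List Float :=
  (List.range z.length).map fun k =>
    (softmaxCE (z.set k (z[k]! + eps)) y - softmaxCE (z.set k (z[k]! - eps)) y) / (2 * eps)

#eval softmaxCEGrad   [2.0, 1.0, 0.1] 0
#eval softmaxCEGradFD [2.0, 1.0, 0.1] 0   -- agrees with the closed form to many decimal places
```

The gradient entries sum to zero, because softmax and the one-hot vector each sum to one. Only the label entry is negative.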

From Jacobians to VJPs

Training does not multiply Jacobians together directly. The loss is a scalar, and the quantity we actually want is “how does each parameter affect that scalar?” Backpropagation answers it one layer at a time. Each layer receives the upstream gradient \(dy = \partial L / \partial y\), the gradient of the loss with respect to that layer’s output. It multiplies \(dy\) by the transpose of its local Jacobian \(J\). The result, \(J^\top dy\), is a vector-Jacobian product, or VJP, and it can be computed without ever building \(J\) explicitly. For a dense layer we want the corresponding \(dx\), \(dW\), and \(db\). The answers fall out by transposing the picture we already have:

\[ dx = W^\top dy, \qquad dW = dy \otimes x, \qquad db = dy. \]
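The three identities translate directly into code. In this sketch (plain Lean 4, hypothetical names), \(W\) is stored as one row per output. That matches \(y = Wx + b\) with \(W\) of shape \(m \times n\). With an input-first layout, the roles of rows and columns would swap.

```lean
-- W: m rows (one per output) of length n. Inputs: x (length n) and the
-- upstream gradient dy (length m). Returns (dx, dW, db).
def denseVJP (W : List (List Float)) (x dy : List Float) :
    List Float × List (List Float) × List Float :=
  let n := x.length
  -- dx_k = Σ_j W_jk · dy_j, i.e. dx = Wᵀ dy
  let dx := (List.range n).map fun k =>
    (List.zipWith (fun row dyj => row[k]! * dyj) W dy).foldl (· + ·) 0.0
  -- dW_jk = dy_j · x_k, i.e. the outer product dy ⊗ x
  let dW := dy.map fun dyj => x.map fun xk => dyj * xk
  -- db = dy
  (dx, dW, dy)

-- dx = [-3, -3, -3], dW = [[0.1, 0.2, 0.3], [-0.1, -0.2, -0.3]], db = [1, -1]
#eval denseVJP [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] [0.1, 0.2, 0.3] [1.0, -1.0]
```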

The remaining results in this chapter (Theorems 17 through 19 and Definitions 20 and 21) formalize these identities. They also establish that the VJP of a composition of layers is obtained by composing the building blocks’ VJPs in reverse order. That last fact is what makes a 3-layer MLP’s backward pass simple. The softmax cross-entropy gradient is fed through three dense VJPs and two ReLU VJPs, from the last layer back to the first, and nothing more is needed. It is the structural fact that the rest of the book leans on.
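The reverse-order rule fits in a few lines. Here a hypothetical `VjpLayer` bundles a forward map with its input VJP. Parameter gradients are omitted to keep the sketch short. `andThen` composes two layers: the forward pass runs \(f\) then \(g\), and the backward pass runs \(g\)’s VJP first and feeds the result to \(f\)’s. This mirrors the idea behind Definition 21 but is not the formal `HasVJP` structure.

```lean
-- HYPOTHETICAL: a layer = forward map + its input VJP (x, dy) ↦ dx.
-- Parameter gradients are left out to keep the sketch small.
structure VjpLayer where
  forward : List Float → List Float
  vjp     : List Float → List Float → List Float

-- Compose: forward runs f then g; backward runs g's VJP first, then f's.
def VjpLayer.andThen (f g : VjpLayer) : VjpLayer where
  forward := fun x => g.forward (f.forward x)
  vjp     := fun x dy => f.vjp x (g.vjp (f.forward x) dy)

def reluLayer : VjpLayer where
  forward := fun x => x.map fun v => if v > 0.0 then v else 0.0
  vjp     := fun x dy => List.zipWith (fun v d => if v > 0.0 then d else 0.0) x dy

def scaleLayer (c : Float) : VjpLayer where
  forward := fun x => x.map (c * ·)
  vjp     := fun _ dy => dy.map (c * ·)   -- Jacobian is c·I, so the VJP is c·dy

-- x ↦ -3 · relu(2x): slope is -6 where x > 0 and 0 where x < 0
def net : VjpLayer :=
  ((scaleLayer 2.0).andThen reluLayer).andThen (scaleLayer (-3.0))

#eval net.vjp [1.0, -1.0, 0.5] [1.0, 1.0, 1.0]   -- values -6, 0, -6
```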

The theorems

Theorem 13 Dense Jacobian wrt input

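\(\partial (Wx + b)_j / \partial x_i = W_{ij}\). Derived from foundation rules. Note the index order. Theorems 13, 14, and 18 are stated with \(W\) indexed input-first, so \(W_{ij}\) is the weight from input \(i\) to output \(j\) and \((Wx + b)_j = \sum _i x_i W_{ij} + b_j\). This is the transpose of the \(m \times n\) convention used in the derivation above, where the same slope is written \(W_{ji}\).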

Proof

Mechanical; see Proofs.pdiv_dense.

Theorem 14 Dense Jacobian wrt weight

\(\partial (Wx + b)_j / \partial W_{i'j'} = x_{i'} \delta _{jj'}\). Phase 7; derived from foundation rules over the flatten bijection.

Proof

Mechanical; see Proofs.pdiv_dense_W.

Theorem 15 ReLU Jacobian (guarded subgradient)

\(\operatorname {pdiv}\) of \(\mathrm{ReLU}\) at smooth points (\(\forall k,\ x_k \neq 0\)) is the diagonal 0/1 indicator described above. Proved via local-diagonal-CLM transport. At a smooth point, ReLU agrees with the diagonal indicator CLM \(\Pi _k\, (\text{if } x_k > 0 \text{ then } \mathrm{proj}_k \text{ else } 0)\) on \(\mathrm{Metric.ball}\, x\, (\min _k |x_k|)\), because inside that ball every coordinate keeps its sign. HasFDerivAt.congr_of_eventuallyEq then transports the CLM’s own fderiv (a CLM is its own derivative) to ReLU.

Proof

Mechanical; see Proofs.pdiv_relu.

Theorem 16 Softmax cross-entropy gradient

\(\partial L / \partial z = \mathrm{softmax}(z) - \mathrm{onehot}(y)\). Proved by chain rule on \(-\log \circ \mathrm{softmax\_label}\) using HasFDerivAt.log (with \(\mathrm{softmax}(z)[\ell ] > 0\) from exp positivity) composed with the proved \(\mathrm{softmax}\) Jacobian.

Proof

Mechanical; see Proofs.softmaxCE_grad.

Theorem 17 Dense VJP
Proof

Mechanical; see Proofs.dense_has_vjp.

Theorem 18 Dense weight gradient is the outer product

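\(dW = x \otimes dy\), in the input-first indexing of Theorem 14. This is the transpose of \(dy \otimes x\) in the \(m \times n\) convention. Phase 7 promoted from vacuous rfl to theorem.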

Proof

Theorem 19 Dense bias gradient is identity

\(db = dy\). Phase 7: derived, no new axiom.

Proof

Definition 20 ReLU VJP

noncomputable def over the canonical pdiv-derived witness; HasVJP.correct holds by rfl since \(\operatorname {pdiv}\) is a def over \(\operatorname {fderiv}\). At non-smooth points the canonical backward is \(\operatorname {fderiv}\)’s junk default of \(0\); the codegen substitutes the standard subgradient convention.

Definition 21 MLP composition VJP

noncomputable def over the canonical pdiv-derived witness; same shape as relu_has_vjp. Codegen routes the ReLU subgradient at the kink.

3.1 Example: MNIST MLP

The theorems above are the calculus. Here is a concrete architecture built from those pieces: a three-layer fully-connected classifier for 28\(\times \)28 MNIST digits.

Dataset overview

MNIST is the standard testbed for image-recognition learning. It is a collection of 28\(\times \)28 grayscale images of handwritten digits 0–9: 60 000 training images and 10 000 test images, each labeled with the digit that was drawn. The dataset dates from 1998. At 784 pixels per image and 10 classes, it is small enough to train a competitive model on a laptop CPU in minutes, yet it still poses a nontrivial learning problem.

Our goal in this chapter is to correctly classify held-out test digits using a model trained on the 60 000 training digits. For now we ignore the 2D spatial structure of the image entirely. We flatten each 28\(\times \)28 image into a 784-dimensional vector and treat the task as plain supervised classification. The model that consumes these flat vectors is a multilayer perceptron (MLP). Chapter 4 revisits MNIST with convolutions that respect the spatial structure.
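Flattening is a pure reindexing. The sketch below assumes row-major order, so pixel \((r, c)\) lands at index \(28r + c\). The chapter does not say which order or pixel normalization the framework’s loader uses, so treat both as assumptions. The names `pixelIndex`, `flattenImage`, and `demoImage` are illustrative.

```lean
-- Row-major flattening (ASSUMED order): row r, column c ↦ index width·r + c
def pixelIndex (r c : Nat) (width : Nat := 28) : Nat := width * r + c

-- Concatenate the rows of an image into one vector
def flattenImage (img : List (List Float)) : List Float :=
  img.foldr (· ++ ·) []

-- Synthetic 28 × 28 image whose pixel (r, c) stores its own row-major index
def demoImage : List (List Float) :=
  (List.range 28).map fun r => (List.range 28).map fun c => (pixelIndex r c).toFloat

#eval (flattenImage demoImage).length             -- 784
#eval (flattenImage demoImage)[pixelIndex 5 7]!   -- the value 147 stored at (5, 7)
```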

Architecture

\includegraphics[width=0.85\textwidth ]{arch/mnist_mlp}

Three dense layers stacked with ReLU non-linearities between them: \(784 \to 512 \to 512 \to 10\). The first layer ingests the flattened image. The two 512-wide hidden layers, each followed by a ReLU, let the network learn nonlinear features. The final layer maps to 10-dimensional logits, one per digit class.

Code: the full training program

The entire training program in Lean is about twenty-five lines. This is the old-school SGD baseline — the exact config the Swift for TensorFlow book used in its Chapter 1 (SGD with learning rate 0.1, 12 epochs, no regularization tricks). In our repo this config is s4tfBaseline in MainAblation.lean:

-- 1
import LeanMlir

-- 2
def mnistMlp : NetSpec where
  name   := "MNIST-MLP"
  imageH := 28
  imageW := 28
  layers := [
    .dense 784 512 .relu,
    .dense 512 512 .relu,
    .dense 512  10 .identity
  ]

-- 3
def s4tfBaseline : TrainConfig where
  learningRate   := 0.1    -- old-school SGD learning rate
  batchSize      := 128
  epochs         := 12
  useAdam        := false  -- plain SGD, no momentum, no moment buffers
  weightDecay    := 0.0
  cosineDecay    := false
  warmupEpochs   := 0
  augment        := false
  labelSmoothing := 0.0

-- 4
def main (args : List String) : IO Unit :=
  mnistMlp.train s4tfBaseline (args.head?.getD "data") .mnist

Walking through the numbered sections:

1. The import. One line. LeanMlir is the framework implemented in the LeanMlir/ directory of the repo — it exposes NetSpec, TrainConfig, and the train method. Everything downstream in this chapter and the next is a specialization of this framework.

2. The architecture. NetSpec is a plain data structure: a name, input dimensions, and a list of Layer values. Our three layers are .dense 784 512 .relu, .dense 512 512 .relu, .dense 512 10 .identity. Read left to right: input size, output size, activation. That is the entire model definition. There is no class, no forward-pass function, and no @differentiable attribute. The forward pass is derived from the layer list. The backward pass is assembled from the VJPs proved correct earlier in this chapter.

3. The training hyperparameters. Plain SGD at learning rate 0.1, batch size 128, 12 epochs. No Adam, no weight decay, no cosine schedule, no warmup, no augmentation, no label smoothing. This is the deliberately minimal “Chapter 1” baseline that the Swift for TensorFlow book shipped as its reference example. It is a pedagogically useful starting point because every later tweak then has a clean baseline to measure against. That covers both training-recipe tweaks (Adam, cosine, warmup, weight decay, augmentation, label smoothing) and architectural ones (BN, residuals). The repo’s MainAblation.lean ships a dozen other configs alongside this one, so you can run each ablation and see the delta.

4. The program entry point. One line. Call mnistMlp.train with the config, the data directory (default data/), and the dataset kind (.mnist). Everything inside .train is walked through in Chapter 2, “What’s inside .train?”. That means data loading, mini-batching, forward-pass MLIR generation, IREE compilation, the gradient-descent loop, the evaluation loop, and logging, all already implemented in the framework. The user-facing program is just “specify the network, specify the knobs, run.”
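A sketch shows why no forward-pass function has to be written by hand. It derives one from a layer list. The types `Act`, `DenseSpec`, and `DenseParams` are hypothetical stand-ins for the framework’s `Layer` values, which this chapter does not show. The only point is that the forward pass is a fold over the list.

```lean
-- HYPOTHETICAL stand-ins for the framework's layer values (not LeanMlir's API).
inductive Act where
  | relu
  | identity

structure DenseSpec where
  inDim  : Nat
  outDim : Nat
  act    : Act

-- Parameters of one dense layer: W has outDim rows, each of length inDim.
structure DenseParams where
  W : List (List Float)
  b : List Float

def applyAct : Act → List Float → List Float
  | .relu,     x => x.map fun v => if v > 0.0 then v else 0.0
  | .identity, x => x

def denseFwd (p : DenseParams) (x : List Float) : List Float :=
  List.zipWith (fun row bj => (List.zipWith (· * ·) row x).foldl (· + ·) 0.0 + bj) p.W p.b

-- The forward pass is a left fold over the layer list: dense, then activation.
def forward (specs : List DenseSpec) (params : List DenseParams)
    (x : List Float) : List Float :=
  (List.zip specs params).foldl (fun h sp => applyAct sp.1.act (denseFwd sp.2 h)) x

-- The chapter's architecture, transcribed into the stand-in types
def mnistLikeSpec : List DenseSpec :=
  [⟨784, 512, .relu⟩, ⟨512, 512, .relu⟩, ⟨512, 10, .identity⟩]

-- A tiny 2 → 2 → 1 network to run the fold on
def tinySpec : List DenseSpec := [⟨2, 2, .relu⟩, ⟨2, 1, .identity⟩]
def tinyParams : List DenseParams :=
  [⟨[[1.0, -1.0], [0.5, 0.5]], [0.0, 0.0]⟩, ⟨[[1.0, 1.0]], [0.1]⟩]

#eval forward tinySpec tinyParams [3.0, 1.0]   -- ≈ [4.1]
```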

That’s the whole thing. The parts people usually write by hand every time (the training loop, the eval loop, the gradient-computation machinery) live inside the framework, because they were the same every time. The proof chapters prove the correctness of the gradient computation inside .train.

Results

Build the ablation runner and invoke it with the mlp-sgd config. Output below is from a real run, captured verbatim from logs/ablation_mlp-sgd.log, on an AMD 7900 XTX (ROCm, gfx1100):

$ lake build ablation
$ IREE_BACKEND=rocm IREE_CHIP=gfx1100 \
    ./.lake/build/bin/ablation mlp-sgd
Ablation: mlp-sgd
  spec: MNIST-MLP, optimizer: SGD
  lr: 0.100000, cosine: false, wd: 0.000000
  aug: false, label_smooth: 0.000000
MNIST-MLP-mlp-sgd: 669706 params
Generating train step MLIR...
  10375 chars
Compiling vmfbs...
  forward compiled
  eval forward compiled
  train step compiled
  session loaded
  train: 60000 images (784 floats/image)
training: 468 batches/epoch, batch=128, SGD, lr=0.100000
  step 0/468: loss=2.437782 (34ms)
Epoch  1/12: loss=0.218235 lr=0.100000 (9417ms)
Epoch  2/12: loss=0.082146 lr=0.100000 (9269ms)
Epoch  3/12: loss=0.050954 lr=0.100000 (9280ms)
Epoch  4/12: loss=0.035614 lr=0.100000 (9171ms)
Epoch  5/12: loss=0.025503 lr=0.100000 (9119ms)
Epoch  6/12: loss=0.017866 lr=0.100000 (9118ms)
Epoch  7/12: loss=0.011806 lr=0.100000 (9153ms)
Epoch  8/12: loss=0.009618 lr=0.100000 (9169ms)
Epoch  9/12: loss=0.005895 lr=0.100000 (9250ms)
Epoch 10/12: loss=0.003566 lr=0.100000 (9262ms)
  val accuracy: 9795/9984 = 98.11%
Epoch 11/12: loss=0.002229 lr=0.100000 (9114ms)
Epoch 12/12: loss=0.000999 lr=0.100000 (9093ms)
  val accuracy: 9841/9984 = 98.57%
Saved params + BN stats.

Twelve epochs at about 9 seconds each, \(\sim \)110 seconds total. Training loss drops from \(2.44\) at step 0 to below \(10^{-3}\) by the end. The starting value sits a little above \(\log 10 \approx 2.303\), the cross-entropy of a uniform guess over 10 classes, because random initialization is not perfectly uniform. Final test accuracy lands at 98.57%, the going rate for a three-layer MLP on MNIST with plain SGD. The evaluation scores 9 984 of the 10 000 test images (78 full batches of 128), which is why the log’s denominator is 9984. The matching CPU run via the Docker image in Appendix A takes about 5 minutes and lands at the same accuracy.
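The bookkeeping numbers in the log follow from simple arithmetic. The numbers suggest that the loop drops the final partial batch in both training and evaluation. That is our inference from the log, not something the chapter states.

```lean
-- Training: 60 000 images in batches of 128
#eval 60000 / 128            -- 468, the log's "468 batches/epoch"
#eval 60000 % 128            -- 96 images that do not fill a 469th batch

-- Evaluation: 10 000 test images, scored in full batches of 128
#eval 10000 / 128 * 128      -- 9984, the log's accuracy denominator

-- Uniform-guess cross-entropy over 10 classes
#eval Float.log 10.0         -- ≈ 2.302585

-- Final accuracy from the last log line
#eval 9841.0 / 9984.0 * 100.0   -- ≈ 98.57
```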

\begin{tikzpicture} 
\begin{axis}[
    width=0.92\linewidth, height=6.5cm,
    xlabel={Epoch}, ylabel={Training loss},
    ymode=log,
    xmin=0, xmax=13, ymin=0.0007, ymax=0.4,
    xtick={0,2,4,6,8,10,12},
    grid=major, grid style={gray!18},
    tick label style={font=\small},
    label style={font=\small},
    every axis plot/.append style={line width=1pt, mark size=1pt},
]
\addplot[blue, mark=*, mark options={fill=blue}] coordinates {
(1,0.218235) (2,0.082146) (3,0.050954) (4,0.035614) (5,0.025503) (6,0.017866) (7,0.011806) (8,0.009618) (9,0.005895) (10,0.003566) (11,0.002229) (12,0.000999)
};
\end{axis}
\end{tikzpicture}

MNIST MLP (\(784{\to }512{\to }512{\to }10\)), SGD 0.1, 12 epochs, log-scale loss (logs/ablation_mlp-sgd.log). Per-epoch training loss falls \(\sim \)200\(\times \) from the first epoch to below \(10^{-3}\); final test accuracy 98.57%.

Notice that the emitted MLIR is only 10 375 characters, about half of what the Adam variant would emit. Plain SGD’s per-step update is a single param = param - lr * grad subtraction. Adam, by contrast, has to maintain and update first- and second-moment buffers and apply a bias-corrected update. Same network, same data, same backward pass: the optimizer choice alone roughly doubles the size of the emitted program.
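A direct transcription makes the per-parameter difference between the two optimizers visible. `sgdStep` is the single subtraction described above. `adamStep` is textbook Adam (Kingma and Ba) with its usual default hyperparameters. It is an assumption that the framework’s Adam matches this form exactly. Both are illustrative sketches on `List Float`.

```lean
-- Plain SGD: one subtraction per parameter, no optimizer state.
def sgdStep (lr : Float) (p g : List Float) : List Float :=
  List.zipWith (fun pi gi => pi - lr * gi) p g

-- Textbook Adam state: two extra buffers per parameter plus a step counter.
-- ASSUMED to match the framework's Adam; the chapter does not show its code.
structure AdamState where
  m : List Float
  v : List Float
  t : Nat

def adamStep (lr : Float) (s : AdamState) (p g : List Float)
    (β1 : Float := 0.9) (β2 : Float := 0.999) (eps : Float := 1e-8) :
    List Float × AdamState :=
  let t := s.t + 1
  -- first- and second-moment estimates
  let m := List.zipWith (fun mi gi => β1 * mi + (1 - β1) * gi) s.m g
  let v := List.zipWith (fun vi gi => β2 * vi + (1 - β2) * gi * gi) s.v g
  -- bias correction factors
  let c1 := 1 - Float.pow β1 t.toFloat
  let c2 := 1 - Float.pow β2 t.toFloat
  let step := List.zipWith (fun mi vi => lr * (mi / c1) / (Float.sqrt (vi / c2) + eps)) m v
  (List.zipWith (· - ·) p step, { m := m, v := v, t := t })

#eval sgdStep 0.1 [1.0, 2.0] [0.5, -0.5]                                       -- ≈ [0.95, 2.05]
#eval (adamStep 0.001 ⟨[0.0, 0.0], [0.0, 0.0], 0⟩ [1.0, 2.0] [0.5, -0.5]).1   -- ≈ [0.999, 2.001]
```

On the very first step, bias correction makes Adam move every parameter by almost exactly the learning rate, whatever the gradient’s magnitude.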

Chapter 4 (MNIST 2D CNN) runs the same s4tfBaseline config against a CNN architecture for a direct apples-to-apples comparison. Later chapters switch to the modern recipe (Adam + cosine + warmup + weight decay + label smoothing + data augmentation) once we can benchmark the deltas against this baseline.

What is actually in those “10 375 chars”

The 10 375-character string is the entire training step emitted as StableHLO MLIR. It holds the forward pass through all three dense layers, the softmax cross-entropy loss, the backward pass (the VJP of every layer), and the plain-SGD update for each parameter. The Adam variant’s step, at 20 502 characters, additionally carries the Adam optimizer update. IREE compiles that MLIR once into a vmfb (VM FlatBuffer) and then executes it per mini-batch with no further translation overhead. The 669 706 parameter count is the sum of the three dense layers: \(784 \cdot 512 + 512 = 401\, 920\) for layer 1, \(512 \cdot 512 + 512 = 262\, 656\) for layer 2, \(512 \cdot 10 + 10 = 5\, 130\) for layer 3. Subsequent chapters show the per-layer MLIR fragments this code is assembled from.
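The parameter arithmetic can be checked mechanically. The shape list below is our own transcription of \(784 \to 512 \to 512 \to 10\).

```lean
-- (inputs, outputs) of each dense layer in 784 → 512 → 512 → 10
def mlpShape : List (Nat × Nat) := [(784, 512), (512, 512), (512, 10)]

-- each layer has in·out weights plus out biases
def paramsPerLayer : List Nat := mlpShape.map fun (i, o) => i * o + o

#eval paramsPerLayer                      -- [401920, 262656, 5130]
#eval paramsPerLayer.foldl (· + ·) 0      -- 669706, matching the log
```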