Verified Deep Learning with Lean 4

2 You Are Here

A VJP (vector-Jacobian product) is the backward function for a layer: it takes an upstream gradient and returns the downstream one. Every backward function in this book is a VJP built by composing the foundation laid here: the partial-derivative function \(\operatorname {pdiv}\) (defined via Mathlib’s \(\operatorname {fderiv}\)), its three structural rules (chain, sum, product), and three VJP record types (\(\mathsf{HasVJP}\), \(\mathsf{HasVJPMat}\), \(\mathsf{HasVJP3}\), one per tensor rank) that bundle a backward function with its correctness claim.

This is the technically hardest chapter in the book, by design. Going from nothing to a complete trained model — the \(\operatorname {pdiv}\) calculus, the VJP framework, a forward pass, a loss, a backward pass, an SGD step, all the way down to the GPU — is the entire machinery, built in a single chapter. We do it on a trivial network (one matrix multiply: the MNIST linear classifier), because the network is not the point — building the whole machinery is. Once that machinery exists, every later chapter is just adding a layer: one new primitive dropped into the same forward / loss / backward / optimize loop.

The next eleven theorems all reduce to a single definition (\(\operatorname {pdiv}\)) plus the chain rule, sum rule, and product rule. Before we drop you into the deep end, here’s the shape of what’s coming:

[Figure: dependency diagram, top to bottom. \(\operatorname {pdiv}\) (defined via Mathlib’s \(\operatorname {fderiv}\)) → Foundation rules (chain, sum, product, identity, const, reindex, finite-sum) → VJP transpose (vjp_comp, biPath, elemwiseProduct, identity) → Used by the MNIST linear classifier here (proven + trained); then reused by every architecture in Chapters 3–10.]

Read top-to-bottom, this is the order things get proved. Read bottom-to-top, this is what every theorem in Chapters 3–9 unfolds to. (Ch 10 attention adds matrix-level machinery on top; see §10.1.) The full clickable dependency graph is in the blueprint web view.

Definition 1 Partial derivative

The partial derivative function. For \(f : \mathbb {R}^{m} \to \mathbb {R}^{n}\), \(\operatorname {pdiv}\, f\, x\, i\, j\) is the \((i, j)\)-entry of the Jacobian at \(x\), defined as \(\operatorname {fderiv}_{\mathbb {R}}\, f\, x\, (\mathbf{e}_i)\, j\) — the \(j\)-th coordinate of Mathlib’s Fréchet derivative applied to the \(i\)-th standard basis vector.

Theorem 2 Chain rule

\(\operatorname {pdiv}(g \circ f)\, x\, i\, k = \sum _j \operatorname {pdiv}f\, x\, i\, j \cdot \operatorname {pdiv}g\, (f\, x)\, j\, k\), conditional on \(f\) differentiable at \(x\) and \(g\) differentiable at \(f(x)\).

Proof

Mechanical; see Proofs.pdiv_comp.
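
As a sanity check on the index convention (input index first, output index second), here is a small worked instance. The functions \(f\) and \(g\) below are illustrative choices, not taken from the Lean development. Take \(f(x_1, x_2) = (x_1 x_2,\; x_1 + x_2)\) and \(g(y_1, y_2) = y_1 y_2\), and compute the \(i = 1\), \(k = 1\) entry:

\begin{align*} \operatorname {pdiv}f\, x\, 1\, 1 & = x_2, & \operatorname {pdiv}f\, x\, 1\, 2 & = 1, \\ \operatorname {pdiv}g\, y\, 1\, 1 & = y_2, & \operatorname {pdiv}g\, y\, 2\, 1 & = y_1, \end{align*}

\[ \sum _j \operatorname {pdiv}f\, x\, 1\, j \cdot \operatorname {pdiv}g\, (f\, x)\, j\, 1 = x_2 (x_1 + x_2) + 1 \cdot x_1 x_2 = 2 x_1 x_2 + x_2^2. \]

Differentiating \((g \circ f)(x) = x_1^2 x_2 + x_1 x_2^2\) directly with respect to \(x_1\) gives the same \(2 x_1 x_2 + x_2^2\).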

Theorem 3 Sum rule

Linearity of the derivative; conditional on both summands differentiable at \(x\).

Proof

Mechanical; see Proofs.pdiv_add.

Theorem 4 Product rule

Pointwise Leibniz rule for elementwise products; conditional on both factors differentiable at \(x\).

Proof

Mechanical; see Proofs.pdiv_mul.

Theorem 5 Identity Jacobian

\(\operatorname {pdiv}(\mathrm{id})\, x\, i\, j = \delta _{ij}\).

Proof

Mechanical; see Proofs.pdiv_id.

Theorem 6 Constant has zero Jacobian

For a constant function \(f(x) = c\), \(\operatorname {pdiv}f\, x\, i\, j = 0\).

Proof

Mechanical; see Proofs.pdiv_const.

Theorem 7 Gather / reindex Jacobian

Covers permutations, reshapes, slicing. Generalizes pdiv_id.

Proof

Mechanical; see Proofs.pdiv_reindex.

Theorem 8 Finite-sum rule

Linearity extended to arbitrary finite sums by induction.

Proof

Mechanical; see Proofs.pdiv_finset_sum.

Theorem 9 VJP chain rule

Given \(\mathsf{HasVJP}\, f\) and \(\mathsf{HasVJP}\, g\), get \(\mathsf{HasVJP}\, (g \circ f)\).

Proof

Mechanical; see Proofs.vjp_comp.

Theorem 10 Additive fan-in VJP

VJP of \(f + g\). Used for residual connections.

Proof

Mechanical; see Proofs.biPath_has_vjp.
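
Unfolding the definitions shows why the two paths simply add. Assuming both \(f\) and \(g\) are differentiable at \(x\), and writing \(\mathrm{backward}_f\) and \(\mathrm{backward}_g\) (names chosen here for illustration) for the two component backward functions, the sum rule gives

\[ \mathrm{backward}_{f+g}(x, dy)_i = \sum _j \operatorname {pdiv}(f + g)\, x\, i\, j \cdot dy_j = \mathrm{backward}_f(x, dy)_i + \mathrm{backward}_g(x, dy)_i. \]

For a residual block \(x \mapsto x + F(x)\), the identity path contributes \(dy_i\) by Theorem 5, so the input gradient is \(dy_i + \mathrm{backward}_F(x, dy)_i\).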

Theorem 11 Multiplicative fan-in VJP

VJP of elementwise product. Used for Squeeze-and-Excitation.

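Proof

Theorem 12 Identity VJP

\(\mathsf{HasVJP}\, \mathrm{id}\), with backward function \(dy \mapsto dy\): the upstream gradient passes through unchanged.

Proof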

Mechanical; see Proofs.identity_has_vjp.

2.1 Example: MNIST linear classifier

The smallest network that learns: a single dense layer mapping 784-dim images to 10-dim logits, trained with softmax cross-entropy and plain SGD. Exactly 7,850 parameters (\(784 \times 10\) weights plus 10 biases). \(\sim \)92% test accuracy in seconds.

The forward pass and loss, in math.

For an input image \(x \in \mathbb {R}^{784}\) and true class index \(t \in \{ 0, \dots , 9\} \):

\begin{align*} z & = W x + b & & \text{(dense layer; $W \in \mathbb {R}^{10 \times 784}$, $b \in \mathbb {R}^{10}$)} \\ \hat{y}_i & = \mathrm{softmax}(z)_i = \frac{e^{z_i}}{\sum _j e^{z_j}} & & \text{(class probabilities)} \\ \ell (W, b;\, x, t) & = -\log \hat{y}_t & & \text{(cross-entropy on the true class)} \end{align*}

The trainable parameters are \(W\) and \(b\) — 7,850 floats total.
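
The softmax and loss lines above are easy to try out directly. The following is a minimal sketch in plain Lean 4 (core only, no Mathlib); the names `softmax` and `crossEntropy` and the `Array Float` representation are assumptions made for illustration and are not the project's definitions.

```lean
/-- Numerically stable softmax: subtract the largest logit before exponentiating. -/
def softmax (z : Array Float) : Array Float :=
  let zmax := z.foldl (fun a b => if a < b then b else a) (z.getD 0 0.0)
  let e    := z.map (fun zi => Float.exp (zi - zmax))
  let s    := e.foldl (· + ·) 0.0
  e.map (· / s)

/-- Cross-entropy on the true class `t`, i.e. `-log ŷ_t`.
    An out-of-range `t` falls back to probability 1.0 (loss 0) in this sketch. -/
def crossEntropy (yhat : Array Float) (t : Nat) : Float :=
  -(Float.log (yhat.getD t 1.0))

-- Ten equal logits give ŷ_i = 1/10 for every class, so the loss is log 10 ≈ 2.3026.
#eval crossEntropy (softmax (List.replicate 10 0.0).toArray) 3
```

Equal logits are the idealized version of an untrained 10-class classifier, which is why a freshly initialized model is expected to start near \(\log 10\).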

The same thing, in NetSpec.

def mnistLinear : NetSpec where
  name   := "MNIST-Linear"
  imageH := 28
  imageW := 28
  layers := [.dense 784 10 .identity]

The .identity activation means the layer outputs raw logits \(z = W x + b\); softmax + cross-entropy are added by .train since this is a classification task. The four-field NetSpec above describes exactly the function written out in math above.

Training program and run command.

def main (args : List String) : IO Unit :=
  mnistLinear.train mnistLinearConfig (args.head?.getD "data") .mnist

In our repo this is the standalone trainer mnist-linear-train (MainMnistLinearTrain.lean), matching the per-chapter trainer pattern every later chapter uses. Build and run it with

lake build mnist-linear-train
./.lake/build/bin/mnist-linear-train

The gradient, in math.

Backpropagation through this network produces a single outer product:

\[ \nabla _W \ell = (\hat{y} - e_t) \otimes x, \qquad \nabla _b \ell = \hat{y} - e_t \]

where \(e_t \in \mathbb {R}^{10}\) is the one-hot vector for the true class. This chapter’s theorems (chain rule + identity Jacobian/VJP), together with the Dense Jacobian and softmax-CE gradient (formalized in this chapter alongside the linear classifier, both themselves proved from the foundation rules), are what guarantee these formulas are correct; linear-sgd’s codegen emits exactly these as fused MLIR.
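
A short derivation shows where the outer product comes from; this is a pen-and-paper sketch, not the formal proof. Since \(\log \hat{y}_t = z_t - \log \sum _j e^{z_j}\),

\[ \ell = -z_t + \log \sum _j e^{z_j}, \qquad \frac{\partial \ell }{\partial z_i} = -\delta _{it} + \frac{e^{z_i}}{\sum _j e^{z_j}} = \hat{y}_i - (e_t)_i. \]

Because \(z_i = \sum _k W_{ik} x_k + b_i\), the chain rule gives

\[ \frac{\partial \ell }{\partial W_{ik}} = (\hat{y}_i - (e_t)_i)\, x_k, \qquad \frac{\partial \ell }{\partial b_i} = \hat{y}_i - (e_t)_i, \]

which are exactly the entries of \((\hat{y} - e_t) \otimes x\) and \(\hat{y} - e_t\).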

Results.

Captured verbatim from runs/2026-05-05-linear-sgd/linear-sgd.log on an AMD 7900 XTX (ROCm, gfx1100):

$ ./.lake/build/bin/mnist-linear-train
MNIST-Linear: 7850 params
training: 468 batches/epoch, batch=128, SGD, lr=0.100000
  step 0/468: loss=2.301343 (11ms)
Epoch  1/12: loss=0.381019 lr=0.100000 (932ms)
Epoch  2/12: loss=0.300518 lr=0.100000 (969ms)
Epoch  3/12: loss=0.289204 lr=0.100000 (934ms)
Epoch  4/12: loss=0.282829 lr=0.100000 (964ms)
Epoch  5/12: loss=0.277735 lr=0.100000 (995ms)
Epoch  6/12: loss=0.273606 lr=0.100000 (977ms)
Epoch  7/12: loss=0.271623 lr=0.100000 (961ms)
Epoch  8/12: loss=0.268908 lr=0.100000 (930ms)
Epoch  9/12: loss=0.265967 lr=0.100000 (969ms)
Epoch 10/12: loss=0.265967 lr=0.100000 (969ms)
  val accuracy: 9219/9984 = 92.34%
Epoch 11/12: loss=0.265719 lr=0.100000 (949ms)
Epoch 12/12: loss=0.263504 lr=0.100000 (967ms)
  val accuracy: 9175/9984 = 91.90%
Saved params.

Twelve epochs, about 1 second per epoch, \(\sim \)12 seconds total. Loss drops from \(\log 10 \approx 2.30\) at step 0 (random init for 10 classes) to \(0.26\) by epoch 12. Final test accuracy 91.90%, which is the going rate for a linear classifier on MNIST (LeCun’s original 1998 benchmark put a comparable model at \(\sim \)92%). Adding hidden layers and ReLU (Ch 3) bumps this to 98.57% on the same dataset and the same recipe; that delta of \(\sim \)6.7 points is the value of non-linearity on this task.

Where this goes.

Every architecture in Part 1 extends this template the same way. Ch 3 stacks dense layers and inserts ReLU between them (one new operator VJP); Ch 4 swaps the dense forward for conv2d (one new operator VJP); Ch 5 adds BatchNorm (one new operator VJP); Ch 6 adds the residual skip (additive fan-in, Theorem 10, no new operator); Ch 7 factors standard conv into depthwise plus pointwise (one new operator VJP, recombined through the chain rule); Ch 8 adds the SE channel-attention block (elementwise product plus GAP plus dense, one new VJP composed from pieces we already have); Ch 9 replaces ReLU with GELU (one new operator); Ch 10 adds attention (new matrix-level machinery plus softmax-VJP chains). The structural rules from this chapter never change; later chapters add operator-specific theorems and compose them through the chain rule we just proved. The training loop that runs all of them is the same hundred lines of Lean, walked through next.

2.2 What’s inside .train?

The section above treated mnistLinear.train as a black box. You specified the network, you specified the hyperparameters, training happened. That’s deliberately the user-facing interface, but there’s no magic: the training loop is real code, \(\sim \)100 lines of Lean in LeanMlir/Train.lean. Every chapter in this book uses the same loop. The only thing that varies chapter to chapter is the NetSpec and TrainConfig values handed to it; never the loop itself.

The training loop, the way a math book would write it.

Algorithm: Mini-batch SGD with Adam optimizer
Input:  spec (architecture f_θ), cfg (hyperparameters),
        dataset D = (X_train, y_train, X_val, y_val)
Output: trained parameters θ

// 1. Load training data
(X_train, y_train) ← load(D)
B ← cfg.batchSize, E ← cfg.epochs, α ← cfg.learningRate

// 2. Initialize parameters + Adam moment buffers
θ ← spec.heInit()
m, v ← 0                                  // Adam 1st and 2nd moment

// 3. Epoch loop: shuffle, schedule LR, train, log
for epoch = 1 to E:
    (X, y) ← shuffle(X_train, y_train)
    α_t ← schedule(α, epoch)              // cosine + warmup

    // 4. Batch loop: forward + loss + backward + optimizer
    for each mini-batch (x, t) of size B:
        ŷ ← f_θ(x)                         // forward
        L ← ℓ(ŷ, t)                        // loss
        g ← ∇_θ L                          // backward — Ch 2's VJP machinery
        (θ, m, v) ← Adam(θ, m, v, g, α_t)  // optimizer step
    log(epoch, mean_loss)

    // 5. Validation every 10 epochs, and always after the final epoch
    if epoch ≡ 0 mod 10 or epoch = E:
        eval(θ, X_val, y_val)

// 6. Save trained parameters
save(θ)
return θ

The only line that needs proof machinery is \(g \leftarrow \nabla _\theta L\) — that’s exactly what Ch 2’s theorems just established. Everything else is data plumbing (loaders, epoch counters, the SGD/Adam update). For mnistLinear above, that gradient is a single outer product; for ResNet-34 it would be the chain rule applied through 34 layers; the loop is the same in either case. Each later chapter swaps in a different \(f_\theta \); the loop never changes.
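
For reference, the optimizer line can be spelled out. The block below is the textbook Adam update, with \(\beta _1\), \(\beta _2\), and \(\varepsilon \) as hyperparameters that this chapter does not list; the exact variant compiled into the training step is an assumption here. With step counter \(t\) and all operations elementwise:

\begin{align*} m & \leftarrow \beta _1 m + (1 - \beta _1)\, g, & v & \leftarrow \beta _2 v + (1 - \beta _2)\, g \odot g, \\ \hat{m} & = \frac{m}{1 - \beta _1^{t}}, & \hat{v} & = \frac{v}{1 - \beta _2^{t}}, \\ \theta & \leftarrow \theta - \alpha _t\, \frac{\hat{m}}{\sqrt{\hat{v}} + \varepsilon }. \end{align*}

Plain SGD, as used in the MNIST run above, drops the moment buffers and is just \(\theta \leftarrow \theta - \alpha _t\, g\).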

The same loop, in Lean.

The Lean realization of the algorithm above is one hundred lines in LeanMlir/Train.lean. The -- 1 through -- 6 comments below match the // 1 through // 6 labels in the pseudocode; read them in parallel.

def runTraining (spec : NetSpec) (cfg : TrainConfig)
    (ds : DatasetKind) (dataDir : String)
    (sess : IreeSession) : IO Unit := do
  -- 1. Load training data
  let batchN := cfg.batchSize
  let dio    := datasetIO ds
  let (trainImg, trainLbl, nTrain) ← dio.loadTrain dataDir

  -- 2. Initialize parameters + Adam moment buffers
  let mut p ← spec.heInitParams                        -- He-init weights
  let mut m ← F32.const (F32.size p).toUSize 0.0        -- Adam 1st moment
  let mut v ← F32.const (F32.size p).toUSize 0.0        -- Adam 2nd moment

  let bpE := nTrain / batchN
  let nP  := spec.totalParams
  let mut globalStep : Nat := 0

  -- 3. Epoch loop: shuffle, schedule LR, train, log
  for epoch in [:cfg.epochs] do
    let (sImg, sLbl) ← F32.shuffle trainImg trainLbl
                         nTrain.toUSize dio.trainPixels.toUSize
                         (epoch + 42).toUSize

    let lr : Float :=                                   -- cosine + warmup
      if epoch < cfg.warmupEpochs then
        cfg.learningRate * (epoch.toFloat + 1.0)
          / cfg.warmupEpochs.toFloat
      else if cfg.cosineDecay then
        cfg.learningRate * 0.5 * (1.0 + Float.cos (
          3.14159265 * (epoch.toFloat - cfg.warmupEpochs.toFloat)
                     / (cfg.epochs.toFloat - cfg.warmupEpochs.toFloat)))
      else cfg.learningRate

    -- 4. Batch loop: forward + loss + backward + optimizer in ONE call
    let mut epochLoss : Float := 0.0
    for bi in [:bpE] do
      globalStep := globalStep + 1
      let xba := F32.sliceImages sImg (bi * batchN) batchN dio.trainPixels
      let yb  := F32.sliceLabels sLbl (bi * batchN) batchN
      let packed := (p.append m).append v

      let out ← IreeSession.trainStepAdamF32 sess spec.trainFnName
                  packed spec.shapesBA xba (spec.xShape batchN) yb
                  lr globalStep.toFloat spec.bnShapesBA batchN.toUSize

      epochLoss := epochLoss + F32.extractLoss out (3 * nP)
      p := F32.slice out 0           nP                 -- updated params
      m := F32.slice out nP          nP                 -- updated m
      v := F32.slice out (2 * nP)    nP                 -- updated v

    IO.eprintln s!"Epoch {epoch+1}/{cfg.epochs}: " ++
                s!"loss={epochLoss / bpE.toFloat} lr={lr}"

    -- 5. Validation every 10 epochs: forward-only vmfb over val set
    if (epoch + 1) % 10 == 0 || epoch + 1 == cfg.epochs then
      let evalSess ← IreeSession.create
                       s!"{spec.buildPrefix}_fwd_eval.vmfb"
      let (valImg, valLbl, nVal) ← dio.loadVal dataDir
      let mut correct : Nat := 0
      for bi in [:nVal / batchN] do
        let xba := F32.sliceImages valImg (bi * batchN) batchN dio.valPixels
        let logits ← IreeSession.forwardF32 evalSess spec.evalFnName
                       p spec.evalShapesBA xba (spec.xShape batchN)
                       batchN.toUSize spec.numClasses.toUSize
        for i in [:batchN] do
          let pred  := F32.argmax10 logits (i * spec.numClasses).toUSize
          let label := (F32.sliceLabels valLbl (bi * batchN) batchN)
                         .data[i * 4]!.toNat
          if pred.toNat == label then correct := correct + 1
      let acc := correct.toFloat / nVal.toFloat * 100.0
      IO.eprintln s!"  val accuracy: {correct}/{nVal} = {acc}%"

  -- 6. Save trained parameters
  IO.FS.writeBinFile s!"{spec.buildPrefix}_params.bin" p
  IO.eprintln "Saved params."

Walking through the numbered sections:

1. Load training data. datasetIO ds returns a per-dataset I/O helper (MNIST, CIFAR-10, Imagenette) that knows how to mmap the on-disk binary files into F32Array buffers. trainImg holds the flattened pixel data; trainLbl holds integer class labels. Nothing ML-specific yet — this is just “put the data somewhere the GPU can reach.”

2. Initialize parameters and optimizer state. He initialization (spec.heInitParams) produces a random-but-scaled weight vector. Adam’s first and second moment buffers start at zero. Everything is a flat F32Array; reshape logic happens per-layer inside the compiled vmfb, not here. Three buffers — p, m, v — plus a running step counter.

3. Epoch loop. Shuffle the data once per epoch with a deterministic seed (epoch + 42) so runs are reproducible. Compute the learning rate: linear warmup for the first warmupEpochs, then cosine decay (if enabled) over the rest of training. If neither warmup nor cosine is on — the s4tfBaseline case we used in the first example — lr is just cfg.learningRate constant. Warmup handles the transformer-style fragile-early-gradient problem; cosine is the standard don’t-overfit-at-the-end schedule.

4. Batch loop — the core of training. For each mini-batch: slice the current epoch’s images and labels, concatenate (params, m, v) into one flat buffer, and call IreeSession.trainStepAdamF32. That one call does everything: forward pass through every layer, cross-entropy loss, backward pass (the VJP of every layer you’ve proved in earlier sections), Adam update of parameters and moment buffers. Whether that update is Adam or plain SGD + momentum is baked into the vmfb at codegen time by cfg.useAdam (step 1); the call name stays trainStepAdamF32 either way. All of it executes as one pre-compiled stablehlo vmfb on the GPU. The framework unpacks the updated (p, m, v) from the output buffer and loops.

The entire Lean \(\to \) MLIR \(\to \) IREE pipeline’s value proposition lives in this one line. There is no Python. There is no graph construction per-step. There is no autograd interpreter. The training step was compiled once, at startup, and every subsequent step is a single dispatched vmfb call. That’s why training at batch 128 runs at \(\sim \)20 ms per step on a consumer GPU (7900 XTX), \(\sim \)80 ms on a 4060 Ti.

5. Validation. Every 10 epochs (and always at the end) we swap in the forward-only eval vmfb, which uses BN’s running statistics (not the per-batch estimates training uses). Loop over the val set, forward-pass each batch, argmax the logits, compare to labels, count correct. Print accuracy. The eval vmfb is a separate compiled artifact because BN behaves differently at inference and we don’t want that branching inside the hot training loop.

6. Save. Write the parameter buffer to disk as a raw .bin blob so it can be loaded later for inference, fine-tuning, or cross-run comparison.
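
The learning-rate schedule from step 3 can be pulled out into a standalone function to see what it produces. This is a minimal sketch in plain Lean 4: the function name `lrAt` and its argument list are hypothetical stand-ins for the TrainConfig fields, while the arithmetic mirrors the listing above.

```lean
/-- Learning rate for a 0-based `epoch`: linear warmup, then optional cosine decay.
    Argument names are illustrative stand-ins for the `TrainConfig` fields. -/
def lrAt (base : Float) (warmupEpochs epochs : Nat) (cosineDecay : Bool)
    (epoch : Nat) : Float :=
  if epoch < warmupEpochs then
    base * (epoch.toFloat + 1.0) / warmupEpochs.toFloat
  else if cosineDecay then
    base * 0.5 * (1.0 + Float.cos (
      3.14159265 * (epoch.toFloat - warmupEpochs.toFloat)
                 / (epochs.toFloat - warmupEpochs.toFloat)))
  else base

-- No warmup, no cosine: the rate stays at `base` for every epoch.
#eval lrAt 0.1 0 12 false 5

-- Two warmup epochs ramp from 0.05 to 0.1; cosine decay then lowers the rate
-- toward (but not all the way to) zero by the last epoch.
#eval (List.range 12).map (lrAt 0.1 2 12 true)
```

The first call corresponds to the constant lr=0.1 seen in the MNIST log; the second shows the shape a warmup-plus-cosine configuration would have over twelve epochs.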

That’s the whole function. Seventy-odd lines of Lean doing what most Python deep-learning tutorials present as an elaborate black box. The reason it’s short is that everything heavyweight — forward pass, loss, backward pass, optimizer update — is one compiled vmfb executing on the GPU. Lean’s role here isn’t to compute gradients; it’s to specify what the training step is, prove that specification is mathematically correct (the theorems earlier in this chapter and in subsequent chapters), and invoke the IREE-compiled implementation. Three separate concerns, cleanly factored.

Every chapter after this one hands a different NetSpec to this same loop. Only the NetSpec changes; the loop is literally the same hundred lines each time. That’s the framework pitch in concrete form: once the loop is written (and proved correct one layer at a time), every architecture in the book and every architecture in the bestiary runs through it with no further infrastructure code. And it all runs on your hardware: Appendix A walks you through the toolchain setup, and if you’re impatient, there’s a Docker image that gets you from zero to training MNIST in one command.

2.3 How this book is organized

Part 1: The framework (Chapters 2–10). One new primitive per chapter, proved correct as we go. The first half pins down the composition rules — chain, additive fan-in, product — on dense layers, convolution, and normalization. The second half climbs the ladder to ViT.

  • Chapter 2: You Are Here — Chain rule, sum rule, product rule for partial derivatives; every backward in the book is a VJP (vector-Jacobian product) composed from these. The book’s hardest chapter — it builds the whole pipeline (foundation calculus \(\to \) VJP framework \(\to \) forward/loss/backward/SGD \(\to \) GPU), used by the MNIST linear classifier, proven and trained. Everything after is just adding a layer.

  • Chapter 3: MNIST: 1D MLP — ReLU and hidden layers on top of Chapter 2’s dense layer: the first multi-layer backward pass.

  • Chapter 4: MNIST: 2D CNN — Convolution, max pooling. The backward pass of conv is another conv with a reversed kernel.

  • Chapter 5: CIFAR with BatchNorm — The first dense Jacobian with a clean closed form. Three terms, one cancellation, and you never have to derive it again.

  • Chapter 6: ResNet-34 — Residual connections. Gradients from two paths add at the input.

  • Chapter 7: MobileNetV2 — Depthwise convolution. Same as regular conv, but diagonal in the channel dimension.

  • Chapter 8: EfficientNet — Squeeze-and-excitation. The product rule: main path times gate, gradients from both paths add.

  • Chapter 9: ConvNeXt — LayerNorm, Batchnorm in disguise. GELU activation function. Data augmentation strategies.

  • Chapter 10: Vision Transformer — Self-attention. Softmax in the middle of the network, three-way fan-in at the input. The capstone.

Part 2: The bestiary (Chapter 11). A catalog of \(\sim \)37 additional architectures, from U-Net to Mamba to diffusion models, each decomposed into the framework’s primitives. Organized by the question you’re asking: “how do I do detection?”, “how do I do generation?”, etc.

Appendices. Toolchain setup (A) and the complete proof framework as a reference (B).

Foundation: Mathlib’s fderiv

Earlier drafts of this book axiomatized the entire calculus foundation — chain rule, sum rule, product rule — as eight opaque facts. The current foundation flips that: \(\operatorname {pdiv}\) is defined in terms of Mathlib’s Fréchet derivative \(\operatorname {fderiv}\), and the structural rules above are theorems proved from Mathlib’s API. Each carries a \(\mathsf{Differentiable}\) hypothesis, threaded through every downstream chapter.

Every later-chapter axiom has been pruned the same way. Where earlier drafts axiomatized \(\texttt{conv2d}\) as an opaque function with a stated Jacobian, the current version defines it concretely and proves the weight, bias, and input gradient lemmas from the foundation. The same is true for multi-head SDPA, patch embedding, the BN inverse-stddev broadcast, GELU, the softmax Jacobian, and every transformer-level composition chain. \(\texttt{\#print axioms vit\_full\_has\_vjp}\) returns only Lean core (propext, Classical.choice, Quot.sound); every other dependency is a Mathlib theorem or a derivation in this project. See Appendix B for the full elimination history.

The cost is mild: \(\mathsf{Differentiable}\) hypotheses propagate. The benefit is that every claim downstream of this chapter is either a definition Lean unfolds or a theorem typechecked against Mathlib — no thumb on the scale.

What about non-smooth points?

A reader who has met ReLU will have spotted a gap. ReLU is \(\max (0, x)\), which has no derivative at \(x=0\); absolute value flunks the same test for the same reason; maxPool over a window has no derivative when two entries tie. Mathlib’s \(\operatorname {fderiv}\) correctly declines to return a value at these points, but ML frameworks have to return something, so they all pick a convention. PyTorch and JAX both use \(0\) as ReLU’s subgradient at the kink and break maxPool ties by the lowest index. Those choices aren’t theorems — they’re conventions that make backprop computable everywhere.

We make those conventions formal. For a non-smooth \(f\), we don’t try to prove a closed-form \(\mathsf{HasVJP}\) from \(\operatorname {fderiv}\) (which would fail at the singular set). Instead we define the \(\mathsf{HasVJP}\) directly, with backward function \(\mathrm{backward}(x, dy)_i := \sum _j \operatorname {pdiv}\, f\, x\, i\, j \cdot dy_j\). The correct field then holds by rfl — by definition, that’s exactly what backward computes. Wherever \(f\) is smooth, \(\operatorname {pdiv}\) agrees with \(\operatorname {fderiv}\) (proved); at the non-smooth points, \(\operatorname {pdiv}\) takes the standard convention by definition. We call this the canonical pdiv-derived witness pattern, and it shows up by name in the next few chapters for ReLU, maxPool, and the depthwise input VJP.

The codegen then substitutes the same subgradient/argmax convention every framework already uses. The benefit isn’t a different numerical result — it’s a paper trail that says “yes, we know exactly which choice we made at each non-smooth point, and the typechecker has verified the algebra around it.”

Why VJPs, not Jacobians?

If you’ve taken multivariable calculus, the natural object to track through a network is the Jacobian. For \(f : \mathbb {R}^{n} \to \mathbb {R}^{m}\), that’s the \(m \times n\) matrix of partial derivatives \(J_{i,j} = \partial f_i / \partial x_j = \operatorname {pdiv}\, f\, x\, j\, i\) (recall from Definition 1 that \(\operatorname {pdiv}\) takes the input index first). The chain rule for a composition \(h = g \circ f\) is matrix multiplication: \(J_h = J_g \cdot J_f\). So if your network is \(L\) layers stacked, the Jacobian of the output with respect to the input is the matrix product \(J_L \cdot J_{L-1} \cdots J_1\).

This is correct and totally impractical. For a 512-dimensional hidden layer, each \(J_k\) is a \(512 \times 512\) matrix; multiplying two of them is \(\sim 1.3 \times 10^8\) flops. ResNet-34 has 34 of these layers, ViT has more. You’d be materializing intermediate matrices that are bigger than the network itself, and most of whose entries you never need.

The fix is to notice what you actually want. Training minimizes a scalar loss \(\mathcal{L}\), so the only output gradient you ever need is \(\nabla \mathcal{L} \in \mathbb {R}^{m}\) — a single vector. Apply that vector to the chain on the left:

\[ (\nabla \mathcal{L})^T \cdot J_L \cdot J_{L-1} \cdots J_1. \]

Read this left-to-right. Start with one \(m\)-dimensional vector. Multiply on the right by \(J_L\), get a vector. Multiply by \(J_{L-1}\), get a vector. Continue. At every step you carry a vector, never a matrix, and its dimension matches the layer’s input or output shape.

The single step

\[ v \mapsto v^T \cdot J_f \]

is the vector–Jacobian product, or VJP. It takes the gradient flowing in from above and produces the gradient flowing out below. The cost is \(O(mn)\) per layer instead of \(O(d^3)\) for a matrix–matrix product, and you never form \(J_f\) at all — you just need a function that, given \(v\), returns \(v^T J_f\). That function is the backward function; together with \(f\) itself, it implements one layer’s contribution to backpropagation.

This is what reverse-mode automatic differentiation does, mechanically, for any program. It is the same vector-pass story: a forward pass to compute activations, then a backward pass to compute gradients, each linear in the number of layers.
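
To make the gap concrete for the 512-dimensional example, count multiply-adds (a rough sketch that ignores constant factors):

\[ \underbrace{512^3 = 134{,}217{,}728 \approx 1.3 \times 10^{8}}_{\text{matrix-matrix: } J_k \cdot J_{k-1}} \qquad \text{versus} \qquad \underbrace{512^2 = 262{,}144 \approx 2.6 \times 10^{5}}_{\text{vector-matrix: } v^T \cdot J_k} \]

so carrying a vector instead of a matrix is cheaper by a factor of 512 per layer.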

What \(\mathsf{HasVJP}\) packages. A \(\mathsf{HasVJP}\) record for \(f\) holds two things:

  • a backward function \(\mathbb {R}^{m} \to \mathbb {R}^{n}\);

  • a correct field: a proof that \(\mathrm{backward}(v)_j = \sum _i v_i \cdot \operatorname {pdiv}\, f\, x\, j\, i\) (input index first, as in Definition 1), i.e. that the function really does compute \(v^T J_f\).

In Lean that’s exactly three lines:

structure HasVJP (f : Vec m → Vec n) where
  backward : Vec m → Vec n → Vec m
  correct  : ∀ x dy i, backward x dy i = ∑ j, pdiv f x i j * dy j

A structure with two fields: a function, and a proof obligation about that function. Constructing a \(\mathsf{HasVJP}\) for some \(f\) means writing both — you don’t get the record without the proof.

The correctness field is the whole point: the backward function is the kernel that ships into the codegen, and the proof is the guarantee that what it computes matches what its name claims. Every later chapter proves a closed-form \(\mathsf{HasVJP}\) for a layer family (dense, conv, BN, attention, \(\ldots \)). Unlike autograd, which records each forward op on a tape and replays its derivative at backward time, we ship the closed form: codegen reads off the backward field and emits a single fused StableHLO kernel — no tape, no replay.
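
The simplest way to inhabit the record is the canonical pdiv-derived witness described earlier for non-smooth layers: define backward as the contraction itself, so the correct field closes by rfl. The block below is a minimal sketch of that pattern, assuming Mathlib is available and that Vec n stands for Fin n → ℝ; the name `canonicalVJP` is hypothetical, and the project's actual definitions may differ in detail. The definitions of Vec, pdiv, and HasVJP are restated so the block stands alone.

```lean
import Mathlib

open scoped BigOperators

/-- Assumed stand-in for the book's `Vec`: real vectors indexed by `Fin n`. -/
abbrev Vec (n : ℕ) := Fin n → ℝ

/-- Definition 1 as a sketch: the `j`-th coordinate of the Fréchet derivative
    of `f` at `x`, applied to the `i`-th standard basis vector. -/
noncomputable def pdiv {m n : ℕ} (f : Vec m → Vec n) (x : Vec m)
    (i : Fin m) (j : Fin n) : ℝ :=
  fderiv ℝ f x (Pi.single i (1 : ℝ)) j

/-- A backward function bundled with the proof that it contracts the
    `pdiv` Jacobian against the cotangent `dy`. -/
structure HasVJP {m n : ℕ} (f : Vec m → Vec n) where
  backward : Vec m → Vec n → Vec m
  correct  : ∀ x dy i, backward x dy i = ∑ j, pdiv f x i j * dy j

/-- Canonical witness (hypothetical name): `backward` is the contraction by
    definition, so `correct` holds by `rfl`. -/
noncomputable def canonicalVJP {m n : ℕ} (f : Vec m → Vec n) : HasVJP f where
  backward := fun x dy i => ∑ j, pdiv f x i j * dy j
  correct  := fun _ _ _ => rfl
```

This witness exists for every \(f\) but says nothing about a closed form; the per-layer theorems in later chapters are what replace the abstract sum with a concrete kernel.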

Theorem and definition budget per chapter

If you’re planning how much of the book to read, here’s what each chapter proves about gradients: its VJP correctness theorems, the \(\texttt{*\_has\_vjp*\_correct}\) results, each asserting \(\texttt{backward}\, x\, dy\, i = \sum _j \operatorname {pdiv}f\, x\, i\, j \cdot dy_j\) (a layer’s backward equals the \(\operatorname {pdiv}\)-Jacobian contracted against the cotangent). Across Chapters 2–10 there are 34 VJP correctness theorems, one citable contract per layer, operator, and whole-network architecture. The foundation calculus rules, differentiability lemmas, and concrete forward/witness definitions that support them are additional machinery, not counted here. Even Chapter 2’s toolkit is used straight away by a proven model (the MNIST linear classifier), so no chapter is pure scaffolding.

Chapter 2 (You Are Here) does the structural heavy lifting: 11 theorems that compose Mathlib’s \(\operatorname {fderiv}\) into the VJP toolkit every later chapter will reuse (the chain rule, additive and multiplicative fan-in, identity, reindexing, finite-sum), plus \(\operatorname {pdiv}\) itself as a definition over \(\operatorname {fderiv}\). The matrix-level extensions (matmul, transpose, scalar-scale, row-wise lifting, 3D chain rule) live at the start of Chapter 10, where attention first uses them. The chapter then puts the kit to work in its first proven model: the 4 VJP contracts of the MNIST linear classifier, namely \(\texttt{mnistLinear\_has\_vjp\_correct}\) (the whole model, a single dense layer), the dense weight- and bias-gradient theorems, and the softmax cross-entropy gradient. It is the degenerate-simplest whole-model VJP, and the ops it needs are introduced right here, where the demo lives.

Chapter 3 (MNIST: 1D) adds depth and nonlinearity on top of Chapter 2’s dense layer: 4 VJP contracts. The guarded ReLU Jacobian \(\texttt{pdiv\_relu}\) is proved via local-diagonal-CLM transport; \(\texttt{relu\_has\_vjp}\) and \(\texttt{mlp\_has\_vjp}\) are noncomputable defs over the canonical pdiv-derived witness (HasVJP.correct by rfl), with the smooth-point \(\texttt{\_at}\) variants carrying the real chain-rule proof.

Chapter 4 (MNIST: 2D) contributes 3 VJP contracts. \(\texttt{conv2d}\) and \(\texttt{maxPool2}\) are concrete forward defs; the conv2d input-side VJP and maxPool2’s VJP (plus its smooth-point \(\texttt{\_at}\) variant, whose correct field is a real proof rather than rfl) are proved from foundation, with the conv weight- and bias-gradient theorems alongside.

Chapter 5 (CIFAR: BN) contributes 1 VJP contract (\(\texttt{bn\_input\_grad}\)), assembled from a chain of supporting Jacobians: the inverse-stddev broadcast Jacobian and its smoothness lemma under \(\varepsilon > 0\), plus the BN affine and centering Jacobians, all proved from foundation.

Chapter 6 (ResNet-34) contributes 6 VJP contracts and no new calculus: the residual additive-fan-in VJP (plus its projected and smooth-point _at variants), the global-average-pool VJP, and (new in this revision) \(\texttt{cnn\_has\_vjp\_at}\), the whole-network fold composing stem, residual blocks, pool, and dense head end to end. The scaling claim in miniature: an entire architecture’s backward, built from Chapter 2’s kit.

Chapter 7 (MobileNetV2) contributes 2 VJP contracts: the depthwise input VJP (same recipe as Ch 4 with one fewer summation level, since there is no cross-channel mixing) and the whole-network \(\texttt{mobilenetv2\_has\_vjp\_at}\) over inverted-residual blocks, which introduces the two-sided-kink relu6 as a new smooth-point activation.

Chapter 8 (EfficientNet) contributes 4 VJP contracts: swish/SiLU (the MBConv activation), the squeeze-and-excitation block (\(\texttt{seBlock\_has\_vjp}\): global-avg-pool plus two 1\(\times \)1 convs plus the elementwise product, every piece already proved elsewhere), the smooth sigmoid gate activation, and the whole-network \(\texttt{efficientnet\_has\_vjp\_at}\) over MBConv blocks with a full-spatial SE gate.

Chapter 9 (ConvNeXt) contributes 4 VJP contracts: GELU (Jacobian via \(\operatorname {fderiv}\) and Real.differentiable_tanh), LayerNorm (reusing the BN proof template), the elementwise layer-scale, and the whole-network \(\texttt{convnext\_ has\_ vjp\_ at}\) — the all-smooth architecture, discharged with no kink hypotheses, only \(\varepsilon {\gt} 0\) positivity. (Swish/SiLU lives in Chapter 8, where EfficientNet uses it; ConvNeXt’s activation is GELU.)

Chapter 10 (ViT) contributes 6 VJP contracts: the three SDPA cotangents, MHSA, the transformer block, and (new in this revision) the whole-network \(\texttt{vit\_ full\_ has\_ vjp\_ correct}\) wrapper. The softmax cross-entropy gradient now lives in Chapter 2 with the linear classifier. Chapter 10 rests on the largest supporting-theorem budget in the book, and it is the only chapter that needs the matrix-level extensions of Chapter 2’s foundation rules. It opens with 16 matrix-machinery theorems (matmul Jacobians and VJPs in both factor positions, transpose, scalar-scale, row-wise lifting, matrix and 3D chain rules) before any attention proof appears. Then come the 20 attention theorems. Phase 3 proved multi-head SDPA via column-stacking, lifting the single-head SDPA backward over the head axis using \(\operatorname {pdivMat}\_ \texttt{colIndep}\) and \(\texttt{colSlabwise\_ has\_ vjp\_ mat}\). Phase 6 closed patch embedding by de-opaquing the forward (Phase 6a) and proving the input VJP via spatial rearrangement (Phase 6b). The softmax Jacobian, row-wise softmax smoothness, and seven transformer-level chains (sublayers, block, tower, ViT body) all factor through the matrix machinery this chapter just assembled.

\begin{tabular}{lrl}
\hline
Chapter & VJP contracts & New primitive / result \\
\hline
2 (Foundation) & 4 & mnistLinear (linear classifier), dense weight/bias-grad, softmax-CE \\
3 (MNIST: 1D) & 4 & ReLU, MLP (\(+\) smooth-point \texttt{\_at}) \\
4 (MNIST: 2D) & 3 & Conv, MaxPool (\(+\) \texttt{\_at}) \\
5 (CIFAR: BN) & 1 & BatchNorm input-gradient \\
6 (ResNet-34) & 6 & Residual (\(+\)proj, \(+\)\texttt{\_at}), global-avg-pool, whole-net cnn \\
7 (MobileNetV2) & 2 & Depthwise, whole-net mobilenetv2 \\
8 (EfficientNet) & 4 & swish, SE (mult. fan-in), sigmoid, whole-net efficientnet \\
9 (ConvNeXt) & 4 & GELU, LayerNorm, layer-scale, whole-net convnext \\
10 (ViT) & 6 & matrix kit, Attention (SDPA/MHSA/block), whole-net vit\_full \\
\hline
Total & 34 & \\
\hline
\end{tabular}

The shift from earlier drafts: every architecture now ends in a single whole-network VJP contract (\(\texttt{cnn\_ has\_ vjp\_ at\_ correct}\), \(\texttt{mobilenetv2\_ has\_ vjp\_ at\_ correct}\), \(\texttt{convnext\_ has\_ vjp\_ at\_ correct}\), \(\texttt{efficientnet\_ has\_ vjp\_ at\_ correct}\), \(\texttt{vit\_ full\_ has\_ vjp\_ correct}\)). Each one is the entire forward pass’s backward, machine-checked end to end, not just the per-layer pieces. (The underlying forward and witness defs, the foundation \(\operatorname {pdiv}\) rules, and the differentiability lemmas are the supporting machinery these 34 contracts rest on; they are not counted in the table.)

One observation worth pulling out: every chapter is pure composition. Every primitive in every chapter either inherits from Mathlib’s \(\operatorname {fderiv}\) or composes over previously-proved theorems. The two non-smooth shortcuts that earlier drafts axiomatized (\(\texttt{relu\_ has\_ vjp}\) and \(\texttt{maxPool2\_ has\_ vjp3}\)) are now noncomputable defs over the canonical pdiv-derived witness; the codegen substitutes the standard subgradient/argmax convention at non-smooth points. The trust kernel for the entire backward pass is just Lean core, Mathlib, and the codegen contract.
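For ReLU, the difference between the smooth-point statement and the codegen convention fits in a few lines. The sketch below is core Lean 4 on `Float`, with hypothetical names `relu` and `reluBack`. The paragraph above does not say which subgradient the codegen picks at the kink, so the choice of slope 0 at \(x = 0\) is an assumption made for illustration.

```lean
namespace ReluSketch

/-- ReLU on a single `Float`. -/
def relu (x : Float) : Float := if x > 0 then x else 0

/-- Backward pass. Away from 0 this is the true derivative times `dy`.
    At `x = 0` it uses the ASSUMED convention that the slope is 0. -/
def reluBack (x dy : Float) : Float := if x > 0 then dy else 0

#eval reluBack 2.0 5.0     -- smooth point, slope 1: passes 5 through
#eval reluBack (-1.0) 5.0  -- smooth point, slope 0: gives 0
#eval reluBack 0.0 5.0     -- kink: the assumed convention gives 0

end ReluSketch
```

The first two evaluations are the cases the smooth-point \(\texttt{\_ at}\) theorems speak about. The third is where the derivative does not exist and a convention has to fill the gap.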

\begin{tikzpicture} [
  >={Stealth[length=2.5mm]},
  every node/.style={font=\sffamily\footnotesize},
  group/.style={
    draw, rounded corners=2pt, fill=blue!6,
    align=center, inner sep=4pt, minimum width=4.0cm, minimum height=1cm
  },
  primitive/.style={
    draw=orange!60!black, rounded corners=2pt, fill=orange!10,
    align=center, inner sep=4pt, minimum width=4.0cm, minimum height=0.95cm
  },
  composed/.style={
    draw=purple!60!black, rounded corners=2pt, fill=purple!8,
    align=center, inner sep=4pt, minimum width=4.0cm, minimum height=0.85cm
  },
  final/.style={
    draw=green!50!black, rounded corners=2pt, fill=green!12,
    align=center, inner sep=4pt, minimum width=4.0cm, minimum height=0.85cm,
    very thick
  },
  arr/.style={->, thick, gray!65, shorten >=1pt},
  shared/.style={->, thick, gray!50, dashed, shorten >=1pt}
]

% Top row: Ch 2
\node[group] (found) at (-1.5, 0) {
  \textbf{Ch~2: foundation calculus}\\
  \texttt{pdiv\_comp/add/mul/id} \\
  \texttt{pdiv\_finset\_sum}
};
\node[group] (matkit) at (8.0, 0) {
  \textbf{Ch~10: matrix kit} \\
  \texttt{pdivMat\_comp} \\
  \texttt{matmul\_left\_const} \\
  \texttt{rowwise\_has\_vjp\_mat}
};

% Left pillar: ResNet-34
\node[primitive] (mlp) at (-9.0, -2.3) {
  \textbf{Ch~3: MLP} \\
  \texttt{dense\_weight\_grad\_correct} \\
  \texttt{relu\_has\_vjp\_correct}
};
\node[primitive] (cnn) at (-9.0, -3.7) {
  \textbf{Ch~4: CNN} \\
  \texttt{conv2d\_has\_vjp3} \\
  \texttt{maxPool2\_has\_vjp3\_correct}
};
\node[primitive] (bn)  at (-9.0, -5.1) {
  \textbf{Ch~5: BatchNorm} \\
  \texttt{pdiv\_bnNormalize} \\
  \texttt{pdiv\_bnAffine}
};
\node[primitive] (res) at (-9.0, -6.5) {
  \textbf{Ch~6: residual skip} \\
  \texttt{residual\_has\_vjp\_correct}
};
\node[final] (r34) at (-9.0, -8.0) {
  \textbf{Full ResNet-34} \\
  (NetSpec composition)
};

% Middle pillar: EfficientNet (shifted south to clear shared-arrow lines)
\node[primitive] (dw) at (-2.5, -4.8) {
  \textbf{Ch~7: depthwise conv} \\
  \texttt{depthwise\_has\_vjp3\_correct}
};
\node[primitive] (se) at (-2.5, -6.5) {
  \textbf{Ch~8: SE block} \\
  \texttt{seBlock\_has\_vjp\_correct}
};
\node[final] (enet) at (-2.5, -9.4) {
  \textbf{Full EfficientNet-B0} \\
  (NetSpec composition)
};

% Right pillar: ViT
\node[primitive] (lngelu) at (3.5, -2.3) {
  \textbf{Ch~9: LayerNorm \& GELU} \\
  \texttt{layerNorm\_has\_vjp\_correct} \\
  \texttt{gelu\_has\_vjp\_correct}
};
\node[primitive] (attn) at (8.0, -2.3) {
  \textbf{Ch~10: attention primitives} \\
  \texttt{pdiv\_softmax} \\
  \texttt{sdpa\_back\_\{Q,K,V\}\_correct}
};
\node[composed] (mhsa) at (8.0, -4.5) {
  \textbf{Ch~10: multi-head attention} \\
  \texttt{mhsa\_has\_vjp\_mat\_correct}
};
\node[composed] (block) at (5.75, -6.4) {
  \textbf{Ch~10: transformer block} \\
  \texttt{transformerBlock\_has\_vjp\_mat\_correct}
};
\node[final] (body) at (5.75, -8.0) {
  \textbf{Ch~10: ViT body} \\
  \texttt{vit\_body\_has\_vjp\_mat}
};
\node[final] (vit) at (5.75, -9.4) {
  \textbf{Full ViT} \\
  \texttt{vit\_full\_has\_vjp}
};

% Foundation -> R34 primitives. Arrow to Ch 4 routes via cnn.east so it
% sweeps below Ch 3 instead of through it.
\draw[arr] (found.south west) to[out=-170, in=85] (mlp.north east);
\draw[arr] (found.south west) to[out=-150, in=20] (cnn.east);
\draw[arr] (found.south west) to[out=-130, in=20] (bn.east);
\draw[arr] (found.south west) to[out=-115, in=20] (res.east);
% Foundation -> ENet primitives
\draw[arr] (found.south) to[out=-100, in=70] (dw.north east);
\draw[arr] (found.south) to[out=-80, in=60] (se.north east);
% Foundation -> ViT primitives
\draw[arr] (found.south east) to[out=-30, in=180] (lngelu.west);
\draw[arr] (found.south east) to[out=-20, in=170] (attn.north west);
% Matrix kit -> ViT side
\draw[arr] (matkit.south) to[out=-100, in=40] (attn.north);
\draw[arr] (matkit.south east) to[out=-60, in=20] (mhsa.east);

% R34 down-flow
\draw[arr] (mlp) -- (cnn);
\draw[arr] (cnn) -- (bn);
\draw[arr] (bn)  -- (res);
\draw[arr] (res) -- (r34);

% ENet path
\draw[arr] (dw) -- (se);
\draw[arr] (se) -- (enet);
\draw[shared] (cnn.east) to[out=-10, in=130] (enet.north west);
\draw[shared] (bn.east)  to[out=-10, in=140] (enet.west);

% ViT down-flow
\draw[arr] (attn) -- (mhsa);
\draw[arr] (mhsa.south) to[out=-100, in=20] (block.north east);
\draw[arr] (lngelu.south) to[out=-90, in=160] (block.north west);
\draw[arr] (block) -- (body);
\draw[arr] (body)  -- (vit);
\end{tikzpicture}
Figure 2.1 Three architecture spines, all the way down to Ch 2 foundation. ResNet-34 (left) needs only the foundation calculus and Chs 3–6 primitives. EfficientNet-B0 (middle) adds Ch 7’s depthwise conv and Ch 8’s SE block; dashed arrows mark Chs 4 (CNN) and 5 (BN), shared with R34. ViT (right) is the only path that needs the matrix kit; its longer chain runs through MHSA and the transformer block before bundling into the full ViT VJP. Every node is a Lean theorem.

Roadmap: skip to your target architecture

Every theorem and definition in the book carries a \uses{} annotation, so the dependency graph from “which target architecture do I care about?” back to which foundation pieces you need is explicit — we get it for free from doing the proofs in Lean. Chapter 2 itself is short enough that the right answer for every target is “just read all of it” (11 theorems + 1 definition, mostly chain rule and fan-in variants). The interesting subsetting decision is which later chapters to read, and whether you need Chapter 10’s matrix-machinery section.

Target: MLP (Ch 3). Ch 2 + Ch 3. The MLP backward pass uses chain rule, additive fan-in, identity, and the Dense Jacobians (which are themselves built from this chapter’s finite-sum + const + reindex).

Target: CNN, with or without BN (Ch 4–5). Ch 2 + Ch 3–5. CNN adds Conv2d and MaxPool VJPs proved from this chapter’s foundation; BN adds the inverse-stddev smoothness chain.

Target: ResNet-34 (Ch 6). Ch 2–6. Adds one theorem on top of CNN+BN: the residual additive-fan-in VJP. This is a great stopping point. ResNet-34 was state of the art for years and still routinely gets 90% on Imagenette; its bigger brother ResNet-50 is the standard benchmark (🐐) for production image classification on consumer hardware. If your destination is a real image classifier that works, Ch 2–6 is a complete reading list; Chs 7–10 are strictly skippable.

Target: MobileNetV2 (Ch 7). ResNet subset \(+\) Ch 7. Adds depthwise conv VJPs (same recipe as Ch 4 with one fewer summation level).

Target: EfficientNet (Ch 8). MobileNet subset \(+\) Ch 8. Adds the SE block VJP (squeeze-and-excitation as elementwise product over global pool + 1\(\times \)1 convs).

Target: ConvNeXt (Ch 9). EfficientNet subset \(+\) Ch 9. Adds GELU and LayerNorm as new operator primitives; both proved from Ch 2 foundation rules.

Target: Attention / ViT (Ch 10). Everything above plus Ch 10. This is the only chapter that needs matrix-level machinery (matmul Jacobians and VJPs in both factor positions, transpose, scalar-scale, row-wise lifting, 3D extensions); they’re gathered in §10.1 at the start of Ch 10, just before the attention proofs that use them. Every earlier chapter routes around the matrix kit entirely.

If you’re unsure: read Ch 2 carefully (it’s short), pick an architecture target that fits your interest, follow the chapter chain to it. You can always extend the read list later as new architectures catch your eye.
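Every reading path above leans on the same chain-rule step: run the outer layer's backward first, at the inner layer's output, then feed the result to the inner layer's backward. Here is a minimal sketch in core Lean 4. The names `ToyVJP`, `comp`, `scale`, and `square` are hypothetical, `Float` stands in for the reals, and the record carries no correctness field, so this shows only the computational shape of what the \(\mathsf{HasVJP}\) composition theorems certify.

```lean
namespace ChainSketch

/-- Hypothetical toy record: forward and backward maps, no correctness proof. -/
structure ToyVJP (α β : Type) where
  forward  : α → β
  backward : α → β → α   -- input point, then upstream gradient

/-- Chain rule: the backward of `g ∘ f` runs `g`'s backward at `f x`,
    then hands that gradient to `f`'s backward at `x`. -/
def comp {α β γ : Type} (g : ToyVJP β γ) (f : ToyVJP α β) : ToyVJP α γ where
  forward  := fun x => g.forward (f.forward x)
  backward := fun x dz => f.backward x (g.backward (f.forward x) dz)

/-- `x ↦ c * x`, with backward `dy ↦ c * dy`. -/
def scale (c : Float) : ToyVJP Float Float where
  forward  := fun x => c * x
  backward := fun _ dy => c * dy

/-- `x ↦ x * x`, with backward `dy ↦ 2 * x * dy`. -/
def square : ToyVJP Float Float where
  forward  := fun x => x * x
  backward := fun x dy => 2 * x * dy

#eval (comp square (scale 3)).forward 2.0       -- (3 * 2)^2 = 36
#eval (comp square (scale 3)).backward 2.0 1.0  -- d/dx (3x)^2 at x = 2 is 18 * 2 = 36

end ChainSketch
```

Adding a layer to a network corresponds to one more call to `comp`, which matches the reading strategy of following the chapter chain one primitive at a time.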

For readers of the first book

If you’re coming to this book from Convolutional Neural Networks with Swift for TensorFlow, you already know most of the architectures and the forward-pass intuition: VGG, ResNet, MobileNet, and EfficientNet are all familiar, and you’ve built and trained them at least once.

Chapters 9 and 10 will be new material even on the forward-pass side. I shipped the final draft of the first book in July 2020; in late August the grapevine started whispering that something attention-shaped was coming, and the ViT paper landed on arXiv that October. Everything in this book that looks like a transformer — LayerNorm, GELU, multi-head attention — postdates the first book and will be genuinely new for you. Otherwise, this book is the same architectures brought under the additional scrutiny of a formal backward pass.

Concretely, that means:

  • The architecture introductions in chapters 3-8 will feel redundant. The primary sequence is a trimmed version of the first-book progression, kept so the chapters stay consistent. Head straight for the *_has_vjp theorems and the \uses{} annotations that pin down the backward pass.

  • Chapter 2 is the new material for you. The tensor-calculus chapter is what the first book didn’t have — where the gradients get formalized. If you only read one chapter thoroughly, make it this one.

  • Training recipe upgraded. The first book’s codegen trained with SGD (momentum 0.9, learning rate 0.002) and no other tricks — one recipe held constant across architectures for consistency. This book adopts the modern recipe — Adam, cosine decay, linear warmup, weight decay, label smoothing, and basic data augmentation — which is how the accuracy numbers in this book’s results table were achieved. The full recipe and each component’s contribution are detailed in Chapter 6 (ResNet), where the ablation comparisons live.

  • The codegen is new. The first book’s Swift-for-TensorFlow pipeline is replaced with Lean 4 \(\to \) StableHLO MLIR \(\to \) IREE \(\to \) GPU. Mostly doesn’t affect the math; it changes where the gradients are computed (at codegen time in Lean, not at runtime by a framework).
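Of the recipe components listed above, the learning-rate schedule is the easiest to show in isolation. The following is a generic sketch of linear warmup followed by cosine decay, written in core Lean 4. The function name `lrAt` and every numeric value in the examples are placeholders chosen for illustration; the book's actual settings are the ones given in Chapter 6.

```lean
namespace ScheduleSketch

/-- Linear warmup to `peak` over `warmup` steps, then cosine decay toward 0
    at step `total`. Assumes `0 < warmup < total`. -/
def lrAt (peak : Float) (warmup total step : Nat) : Float :=
  if step < warmup then
    -- Ramp: step 0 gets peak / warmup, the last warmup step gets peak.
    peak * (step.toFloat + 1) / warmup.toFloat
  else
    -- Fraction of the decay phase completed, from 0 to 1.
    let progress := (step - warmup).toFloat / (total - warmup).toFloat
    0.5 * peak * (1 + Float.cos (3.141592653589793 * progress))

-- Placeholder values only: peak 0.001, 100 warmup steps, 1000 total steps.
#eval lrAt 0.001 100 1000 0     -- first warmup step: peak / 100
#eval lrAt 0.001 100 1000 100   -- decay starts at the peak value
#eval lrAt 0.001 100 1000 1000  -- end of schedule: approximately 0

end ScheduleSketch
```

The schedule only changes the step size fed to the optimizer, so it sits outside the backward pass that the VJP theorems cover.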

Practical reading style: work through Chapter 2 slowly, then skim the rest at the pace of “forward pass I know \(\to \) new *_has_vjp theorem \(\to \) see what theorems it cites \(\to \) move on.” The dependency DAG is your friend.

I have endeavored to read as many of your papers as possible. Part 2 is deliberately patterned on your practice.