Verified Deep Learning with Lean 4

B On Verification

The proofs in this book compile with zero sorrys. Every VJP correctness theorem — dense layers, convolution, batch normalization, residual connections, depthwise convolution, squeeze-and-excitation, layer normalization, and self-attention — is machine-checked by Lean’s type system. If it builds, it’s correct.

But “it builds” is only the first of three kinds of certainty this appendix relies on, and they are not the same kind.

Proven (deductive): in Lean, against three axioms, each operator’s gradient is its exact reverse-mode derivative, and an independent re-checker confirms that no project axiom slipped into a proof, so “proven” means proven-clean.

By construction (structural): the StableHLO the GPU runs is not written but printed from the same datatype the proofs reason about, with a bridge theorem pinning it to the proven formula, so the code cannot drift from the math. This now covers whole training steps (forward, backward, update) for ViT-Tiny and ConvNeXt-T at the production trainers’ exact signatures.

Cross-checked (empirical): IREE cannot itself be proven, and float32 can be only conditionally (its rounding now carries a theorem layer whose interface constants are empirically anchored, § B.7). So three independent oracles (a finite-difference probe, a parallel JAX pipeline, and a CPU/GPU/NumPy comparator) must agree on the gradient before the hardware is believed.

Each kind catches what the others structurally cannot; the five mechanisms this appendix walks through are these three kinds of certainty, with the empirical one split among its witnesses.

B.1 Trust kernel

The Verified VJP Proofs suite proves 48 VJP correctness theorems — one per layer, operator, and whole-network architecture, each asserting backward \(= \sum _j \operatorname {pdiv}f\, x\, i\, j \cdot dy_j\), now including the full-architecture closes of all five Part I networks — on top of the foundation \(\operatorname {pdiv}\) calculus and the differentiability/forward/witness machinery they rest on, plus 22 architecture definitions in the bestiary, across 54 Lean proof files, with zero project axioms. The axiom audit (tests/AuditAxioms.lean) re-checks 421 theorems in all — the VJP contracts plus the forward-graph faithfulness, render, and cotangent-chain results — every one closing under the same three core axioms. The most recent additions are the \(\mathbb {R}\to \texttt{float32}\) bridge and the descent results (§ B.7): 38 theorems — rounding budgets, the inexact-descent lemma, the linear loss’s Lipschitz constants — all hypothesis-style, so the audit still closes under the same three axioms with zero project axioms.

What landed in the final rounds, retiring the last of the project axioms that earlier drafts had shipped:

  • Phase 1–2 (input-side conv VJPs). The \(\texttt{conv2d}\) and \(\texttt{depthwiseConv2d}\) input VJPs are now theorems proved via pdiv_finset_sum \(\times \) 3 + pdiv_const_mul_pi_pad_eval per summand + \(\sum _{c, k_h, k_w}\) collapse.

  • Phase 3 (multi-head SDPA). A new \(\operatorname {pdivMat}\_ \texttt{colIndep}\) theorem and \(\texttt{colSlabwise\_ has\_ vjp\_ mat}\) framework lift the proved single-head SDPA backward over the head axis. Both \(\texttt{mhsa\_ has\_ vjp\_ mat}\) and its \(\mathsf{Differentiable}\) sibling are now theorems.

  • Phase 6a/6b (patch embedding). De-opaqued the \(\texttt{patchEmbed\_ flat}\) forward into a concrete def (conv-with-stride + CLS prepend + positional embed), then proved its input VJP via spatial-rearrangement bridge to \(\texttt{conv2d\_ has\_ vjp3}\).

  • Phase 7 (the final four). \(\operatorname {pdiv}\_ \texttt{relu}\) is proved via local-diagonal-CLM transport: at a smooth point ReLU agrees with the diagonal indicator CLM on a neighborhood (every coordinate keeps its sign), and HasFDerivAt.congr_of_eventuallyEq transports the CLM’s self-fderiv to ReLU. The remaining three (\(\texttt{relu\_ has\_ vjp}\), \(\texttt{mlp\_ has\_ vjp}\), and \(\texttt{maxPool2\_ has\_ vjp3}\)) became noncomputable defs over the canonical pdiv-derived witness; HasVJP.correct holds by rfl since \(\operatorname {pdiv}\) is a def over \(\operatorname {fderiv}\). At non-smooth points the canonical backward is \(\operatorname {fderiv}\)’s junk default of \(0\); the codegen substitutes the standard subgradient/argmax convention (see “Codegen trust boundary” in LeanMlir/Proofs/README.md).

Pure-Mathlib closure on every theorem. \(\texttt{\# print axioms vit\_ full\_ has\_ vjp}\) shows only Lean core (propext, Classical.choice, Quot.sound); nothing project-level beneath those.

B.2 Finite-difference gradient checks

The script LeanMlir/Proofs/check_jacobians.py runs 25 finite-difference gradient checks. For each, it perturbs the input by \(\varepsilon \), compares the claimed VJP against the centered difference \(\bigl(f(x + \varepsilon ) - f(x - \varepsilon )\bigr) / (2\varepsilon )\), and asserts agreement within tolerance (typical max-error \(\sim 10^{-11}\) at \(\varepsilon = 10^{-5}\) in float64).
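A minimal sketch of one such check, using only the Python standard library. It is not taken from check_jacobians.py; the function names, the test point, and the choice of softmax as the operator are assumptions made for illustration. The claimed closed form is the softmax VJP \(s_i\,(dy_i - \sum _j s_j\, dy_j)\), compared against a centered difference of the scalar \(dy \cdot f(x)\).

```python
import math

def softmax(x):
    # Numerically stable softmax over a list of floats.
    m = max(x)
    exps = [math.exp(v - m) for v in x]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_vjp(x, dy):
    # Claimed closed form: back_i = s_i * (dy_i - sum_j s_j * dy_j).
    s = softmax(x)
    dot = sum(sj * dyj for sj, dyj in zip(s, dy))
    return [si * (dyi - dot) for si, dyi in zip(s, dy)]

def fd_vjp(f, x, dy, eps=1e-5):
    # Centered difference of dy . f(x), one input coordinate at a time.
    out = []
    for i in range(len(x)):
        xp = list(x)
        xm = list(x)
        xp[i] += eps
        xm[i] -= eps
        fp, fm = f(xp), f(xm)
        out.append(sum(d * (a - b) for d, a, b in zip(dy, fp, fm)) / (2 * eps))
    return out

# Hypothetical test point and cotangent (Python floats are float64).
x = [0.3, -1.2, 0.8, 2.0]
dy = [1.0, -0.5, 0.25, 2.0]

claimed = softmax_vjp(x, dy)
numeric = fd_vjp(softmax, x, dy)
max_err = max(abs(c - n) for c, n in zip(claimed, numeric))
print(f"max |claimed - FD| = {max_err:.2e}")
```

A sign flip or a swapped index in softmax_vjp would push max_err from rounding-noise scale up to the scale of the gradient itself, which is the kind of formula typo this layer exists to catch.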

Every FD check is a belt-and-suspenders pass over a proved Jacobian theorem. The proofs already establish that each formula equals \(\operatorname {fderiv}\) at the relevant points; the FD checks confirm the formulas-as-written agree numerically with what the function actually does. Coverage spans every closed-form Jacobian we use downstream: \(\operatorname {pdiv}\_ \texttt{dense}\) and its weight/bias companions, \(\operatorname {pdiv}\_ \texttt{relu}\) (at smooth points), softmax cross-entropy, all four BN pieces (bnNormalize, bnCentered, bnIstdBroadcast, bnAffine), the conv2d and depthwise input/weight/bias VJPs, the maxPool2 input VJP, the softmax Jacobian, the three single-head SDPA Q/K/V backwards, GELU, the bundled multi-head SDPA reduction (per-head sdpa_back stacked over the head axis), the patch-embed input VJP, and the full-network MLP composition.

What the FD pass catches that the symbolic proof can’t: typos between the proof and the prose that uses the same Jacobian. If a chapter narrates one formula but the Lean theorem states another, the proof still type-checks (Lean only verifies its own statement), but the FD test runs against the formula the prose published and would diverge.

FD is cheap, easy to reason about, and tight enough for spot-checking formulas, but it can’t probe what the compiled code actually computes on the GPU — only what the formula says in Python — and it struggles at non-smooth points where the limit definition breaks down. For those gaps and for end-to-end pipeline verification, the third layer takes over.
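The non-smooth limitation is easy to see on scalar ReLU. The sketch below (hypothetical helper names, standard library only) evaluates the same centered difference at a smooth point and exactly on the kink.

```python
def relu(x):
    return x if x > 0.0 else 0.0

def centered_diff(f, x, eps=1e-5):
    # (f(x + eps) - f(x - eps)) / (2 * eps)
    return (f(x + eps) - f(x - eps)) / (2 * eps)

smooth = centered_diff(relu, 1.0)  # away from the kink: slope 1
kink = centered_diff(relu, 0.0)    # on the kink: average of the one-sided slopes

print("at x = 1.0:", smooth)
print("at x = 0.0:", kink)
```

On the kink the probe reports 0.5, the average of the two one-sided slopes 0 and 1, so it cannot arbitrate between tiebreak conventions there.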

B.3 The JAX parallel pipeline

A separate Lean \(\to \) JAX \(\to \) XLA pipeline (jax/Jax/Codegen.lean, \(\sim 1100\) lines) produces an idiomatic JAX training script from the same NetSpec the primary Lean \(\to \) StableHLO MLIR \(\to \) IREE pipeline consumes. XLA is the compiler JAX uses to produce GPU code (the same backend that powers TensorFlow and other frameworks); IREE is its Lean-side counterpart. The two stacks are independent end to end — different codegen, different runtime, different kernels — which is why agreement between them is a meaningful cross-check. Running both from identical initial parameters on identical batches gives us a pair of trace files, one per stack, that should agree modulo float32 rounding if both pipelines compute the same math.

The agreement is very tight. For the MNIST MLP, step-1 losses agree to \(\sim 2 \times 10^{-7}\) — float32 ULP — across the JAX-CPU-vs-IREE-ROCm comparison, and phase-3 IREE output is bit-identical across AMD and NVIDIA hardware at step 1. For the MNIST CNN with batch norm, step-1 agrees to \(\sim 10^{-4}\), looser because variance reductions over \(\sim 100\)k-element tensors amplify cross-compiler reduction-tree differences — both pipelines do correct math; they just sum it in different orders. Full results are committed to the repo as reproducible JSON-Lines traces; see traces/CROSS_BACKEND_RESULTS.md.
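The reduction-order effect can be reproduced without any GPU. The following sketch emulates binary32 accumulation in pure Python by rounding after every add, then sums the same values sequentially and as a tree. The data and helper names are hypothetical; neither order is claimed to be what IREE or XLA actually uses.

```python
import struct

def f32(x):
    """Round a Python float (binary64) to the nearest binary32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def f32_bits(x):
    """Raw 32-bit pattern of the binary32 value nearest to x."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

def ulp_distance(a, b):
    """Count of binary32 steps between two positive finite values."""
    return abs(f32_bits(a) - f32_bits(b))

def sum_sequential(xs):
    """Left-to-right reduction, rounding to binary32 after every add."""
    acc = 0.0
    for v in xs:
        acc = f32(acc + v)
    return acc

def sum_pairwise(xs):
    """Tree-shaped reduction over the same values."""
    if len(xs) == 1:
        return xs[0]
    mid = len(xs) // 2
    return f32(sum_pairwise(xs[:mid]) + sum_pairwise(xs[mid:]))

# Hypothetical data: 1000 positive binary32 values.
values = [f32(1.0 / (k + 1)) for k in range(1000)]
seq = sum_sequential(values)
tree = sum_pairwise(values)

print("sequential:", seq)
print("pairwise:  ", tree)
print("ULP distance:", ulp_distance(seq, tree))
print("relative difference:", abs(seq - tree) / tree)
```

Both results are valid float32 evaluations of the same sum; any difference between them comes only from the order of the adds, which is why a trace diff has to be read modulo rounding rather than bit for bit.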

Layered on top of the end-to-end trace diff is a per-axiom differential test in tests/vjp_oracle/, which compares each Lean-proved backward pass against JAX’s value_and_grad autodiff on a minimal one-step training run. Nine cases — dense, dense+ReLU, conv, conv+BN, conv+maxPool, residual, depthwise, SE, attention — each agree with JAX autodiff at 1–2 ULP of step-2 loss. Any future hand-derived VJP added to the Lean proof base can be validated by a one-step comparison against JAX, catching algebraic errors that FD would miss (sign flips, wrong contraction axis, swapped indices).

B.4 Independent kernel re-check (comparator)

tests/comparator/ wires the project up to leanprover/comparator, the Lean community’s trustworthy-judge tool for projects that claim zero project axioms. It runs 49 theorems — the foundation rules, every chapter’s headline Jacobian, the public *_has_vjp_correct wrappers, and three smooth-point pointwise variants (relu_has_vjp_at_correct, mlp_has_vjp_at_correct, maxPool2_has_vjp_at3_correct) whose underlying .correct field is a real proof rather than rfl — through Lean’s kernel typechecker independently of the elaborator, with an axiom-allowlist of exactly \(\{ \texttt{propext}, \texttt{Quot.sound}, \texttt{Classical.choice}\} \). Any project axiom in the transitive closure of any verified theorem would fail the run.

This is what closes the gap between “the elaborator accepted my proofs” (which lake build confirms) and “the kernel agrees, audited from a separate process” (which comparator confirms). The 49-theorem coverage is illustrative rather than exhaustive — the same recipe scales to any subset of the proof suite, and the same allowlist applies because every theorem in the project closes the same way.

B.5 Verified code generation

The four layers above answer whether the mathematics is right — the hand-derived VJPs, their proofs, and the elaborator that accepted them. They say nothing about a different gap. The GPU never runs the Lean functions; it runs a block of emitted stablehlo that IREE compiles once per training program. A code generator can be handed a correct proof and still emit an operation that contracts the wrong axis, transposes backwards, or silently drops a term — and the only thing traditionally linking the proof to the emitted string is a code comment. The “MLIR: operator” section in each chapter closes that gap for one operator; this section describes the machinery they share.

Three pieces. The link from a proof to the GPU is built from three components:

  • Denoted IR and a bridge theorem. The emitted backward (and forward) is not a string but a Lean datatype — a small abstract syntax tree (Back, Fwd, Back3) — carrying a denotation \([\! [\cdot ]\! ]\) valued in the proofs’ own Vec and Tensor3 types. A bridge theorem then proves that this denotation equals the proven derivative: \([\! [\, \text{emitted graph}\, ]\! ] = (\text{proven VJP}).\mathrm{backward}\). Examples are conv_back_bridge, bn_back_bridge, and relu_back_bridge; each closes under the same three axioms as the rest of the project.

  • A computable printer. A small printer walks that same IR and emits one stablehlo operation per node — dotGeneral \(W\) becomes stablehlo.dot_general, a selectPos node becomes compare GT 0 plus select, and so on. The emitted text is the printout of the IR, by construction.

  • An execution oracle. A Python harness regenerates the .mlir from the printer, compiles it with IREE for both the llvm-cpu and rocm backends, runs it, and diffs the result against an independent NumPy reference. The CPU run is the correctness gate; the GPU run (ROCm on a Radeon RX 7900 XTX, gfx1100) confirms the proof-backed graph also executes on real hardware, matching the reference to roughly \(10^{-6}\).

Proven versus trusted. The resulting claim is exact and bounded:

  • Proven (Lean, three axioms): the IR’s denotation equals the proven derivative.

  • By construction: the emitted StableHLO is the rendering of exactly that IR.

  • Trusted (validated numerically, not proven): that the StableHLO text faithfully denotes \([\! [\cdot ]\! ]\) — a property that would need a formal semantics of StableHLO text, which does not yet exist; that IREE lowers the text correctly; and that float32 approximates \(\mathbb {R}\).

So the gradient the GPU computes is, by a machine-checked theorem, the network’s exact reverse-mode derivative over \(\mathbb {R}\) — up to one printer, IREE, and floating point. Where a conventional generator leaves thousands of lines of string-building linked to the proof by nothing at all, the unproven surface here is a single printer, tested end to end.

Why it is tractable. The reason a deep ResNet or a transformer block comes under this scheme as a focused engineering build, rather than a research project, is the order things were proved in. The VJP library is per-operator and generic — proved once, over abstract dimensions, before any code generation existed — and the whole-network VJPs compose those per-operator lemmas through the chain rule. Code generation therefore carries no new proof obligation: every architecture’s operators were proven once, and the emitter reuses them, adding only the printer and the numerical check. Build the mathematical foundation per-operator and generic, and the code generation becomes mechanical.

The one conditional. For the smooth operators — BatchNorm and LayerNorm (given \(\varepsilon {\gt} 0\)), GELU, swish, sigmoid, softmax, and attention — the bridge is unconditional, holding at every input. For the kinked operators — ReLU, ReLU6, and max-pool — it holds only at a smooth point, where no pre-activation sits exactly on the kink (zero for ReLU, \(0\) or \(6\) for ReLU6, an argmax tie for max-pool); the equality is permitted to fail precisely on that measure-zero set, and nowhere else. That set is the one irreducible boundary the smooth-point caveats in the chapter sections refer to.

B.6 Inside a bridge theorem

The “denoted IR and a bridge theorem” above is worth seeing concretely, because it is the step that does the real work — the place where a string of emitted code becomes a proposition Lean can check. Take the backward pass. It is represented not as text but as a value of an inductive type, a small abstract syntax tree whose constructors are exactly the StableHLO operations a backward uses:

inductive Back (inp : Nat) : Nat -> Type where
  | cotangent  : Back inp inp            -- the input dy
  | dotGeneral (A : Mat m n) : Back inp n -> Back inp m   -- matmul
  | selectPos  (x : Vec n)   : Back inp n -> Back inp n   -- ReLU mask
  -- plus scale, sub, add, sumBroadcast, scaleConst (BN, residuals)

A Back value is a closed description of one backward graph. The dense backward, for instance, is just dotGeneral W cotangent — feed the incoming cotangent into a single matrix multiply. Two functions are defined on this type, and everything rests on the gap between them being closed by a theorem.

The denotation. The first function, Back.denote, interprets a graph into the proofs’ own Vec type — it says what the graph means mathematically:

\[ [\! [\texttt{cotangent}]\! ]\, dy = dy, \qquad [\! [\texttt{dotGeneral}\, A\, e]\! ]\, dy = \texttt{Mat.mulVec}\, A\, ([\! [e]\! ]\, dy), \]
\[ [\! [\texttt{selectPos}\, x\, e]\! ]\, dy = \bigl(i \mapsto \text{if } x_i {\gt} 0 \text{ then } [\! [e]\! ]\, dy\, i \text{ else } 0\bigr), \]

and likewise for the remaining constructors — scale, sub, sumBroadcast, scaleConst, add — which are the pieces that assemble BatchNorm’s three-term backward and the residual fan-in. So a Back value denotes a concrete \(\texttt{Vec} \to \texttt{Vec}\) function, living in the same world as the VJP theorems.

The bridge, as an equation. A bridge theorem states that this denotation equals the proven derivative. For the dense layer:

\[ \texttt{dense\_ back\_ bridge}:\quad [\! [\texttt{emitDenseBack}\, W]\! ]\, dy = (\texttt{dense\_ has\_ vjp}\, W\, b).\mathrm{backward}\, x\, dy. \]

Its proof is one word, rfl: both sides reduce to the same term, Mat.mulVec \(W\, dy\), so the denotation of the emitted graph and the proven backward are definitionally identical. That base case pins the plumbing. The ReLU bridge is the first with real content:

\[ \texttt{relu\_ back\_ bridge}\ \ (h_{\text{smooth}} : \forall k,\ x_k \ne 0): \quad [\! [\texttt{emitReluBack}\, x]\! ]\, dy\, i = (\texttt{relu\_ has\_ vjp}\, n).\mathrm{backward}\, x\, dy\, i. \]

The proof unfolds the denotation to \(\text{if } x_i {\gt} 0 \text{ then } dy_i \text{ else } 0\) and shows it equals the canonical ReLU subgradient — but only under the hypothesis \(h_{\text{smooth}}\) that no coordinate sits on the kink. That hypothesis is not a technicality. It names exactly the measure-zero set where the emitted compare GT 0 disagrees with the true derivative, and the theorem is written to permit failure precisely there and nowhere else. The convolution bridge of Chapter 4 is the same shape with a harder proof: convBackDenote unfolds to a forward conv2d of the reversed-and-transposed kernel, discharged by expansion at the concrete tensor shape.

Composing the per-operator bridges. Whole-network backwards are assembled by substitution. Back.subst plugs one graph into another’s cotangent leaf, and a chain-rule lemma proves the denotation composes:

\[ \texttt{denote\_ subst}:\quad [\! [\, e[g/\texttt{cotangent}]\, ]\! ]\, dz = [\! [e]\! ]\, ([\! [g]\! ]\, dz), \]

proved by induction over the graph — the IR-level analogue of the vjp_comp that builds whole-network VJPs from per-layer ones. So the per-operator bridges compose into a whole-network bridge the same way the VJP theorems compose into a whole-network VJP: the MLP’s mlp_whole_bridge is this one substitution, chaining its five per-operator bridges.
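As a worked instance, using only the constructors and denotation clauses given above, take \(e = \texttt{dotGeneral}\, W\, \texttt{cotangent}\) and \(g = \texttt{selectPos}\, x\, \texttt{cotangent}\) (a ReLU mask feeding a dense backward; the pairing is chosen here for illustration). Substitution replaces the leaf of \(e\) with \(g\), and unfolding both sides gives the same term:

\[ e[g/\texttt{cotangent}] = \texttt{dotGeneral}\, W\, (\texttt{selectPos}\, x\, \texttt{cotangent}), \]
\[ [\! [\, e[g/\texttt{cotangent}]\, ]\! ]\, dz = \texttt{Mat.mulVec}\, W\, \bigl(i \mapsto \text{if } x_i {\gt} 0 \text{ then } dz_i \text{ else } 0\bigr) = [\! [e]\! ]\, ([\! [g]\! ]\, dz). \]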

Why the datatype is the point. The Back value is the pivot between two arrows. The printer walks it and emits one stablehlo operation per constructor (dotGeneral \(\mapsto \) stablehlo.dot_general; selectPos \(\mapsto \) compare GT 0 plus select) — that arrow is the trusted one, producing the text in each chapter’s listing. The denotation interprets the same value into the proofs’ Vec type, and the bridge proves that equals the derivative — that arrow is machine-checked. Because both arrows start from one concrete datatype rather than from a string, “the emitted code computes the proven gradient” is a theorem about a value, not a hope about a comment.
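The two arrows can be mimicked in a few lines of Python. This is a loose analogue for intuition only: the class names, the `name` fields, and the printed op syntax are assumptions, the output is not valid StableHLO text, and nothing here is proved. What it does show is a single value consumed by both an interpreter and a printer.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Cotangent:
    """Leaf: the incoming cotangent dy."""

@dataclass(frozen=True)
class DotGeneral:
    name: str    # hypothetical SSA name of the matrix operand
    A: tuple     # matrix as a tuple of rows
    arg: object  # sub-graph feeding the matmul

@dataclass(frozen=True)
class SelectPos:
    name: str    # hypothetical SSA name of the pre-activation
    x: tuple     # pre-activation values
    arg: object  # sub-graph being masked

def denote(e, dy):
    """Interpret a graph as a function on vectors (the checked arrow)."""
    if isinstance(e, Cotangent):
        return list(dy)
    v = denote(e.arg, dy)
    if isinstance(e, DotGeneral):
        return [sum(a * b for a, b in zip(row, v)) for row in e.A]
    if isinstance(e, SelectPos):
        return [vi if xi > 0 else 0.0 for xi, vi in zip(e.x, v)]
    raise TypeError(e)

def render(e):
    """Print pseudo-ops for the same graph (the trusted arrow)."""
    lines = []
    counter = [0]

    def fresh():
        counter[0] += 1
        return f"%v{counter[0]}"

    def go(node):
        if isinstance(node, Cotangent):
            return "%dy"
        src = go(node.arg)
        if isinstance(node, DotGeneral):
            out = fresh()
            lines.append(f"{out} = stablehlo.dot_general {node.name}, {src}")
            return out
        if isinstance(node, SelectPos):
            mask, out = fresh(), fresh()
            lines.append(f"{mask} = stablehlo.compare GT, {node.name}, 0")
            lines.append(f"{out} = stablehlo.select {mask}, {src}, 0")
            return out
        raise TypeError(node)

    go(e)
    return lines

# One value, two readings: mask by x, then multiply by W.
graph = DotGeneral("%W", ((1.0, 2.0), (3.0, 4.0)),
                   SelectPos("%x", (0.5, -1.0), Cotangent()))
result = denote(graph, [10.0, 20.0])
lines = render(graph)
print(result)
print("\n".join(lines))
```

Because render and denote pattern-match on the same constructors, adding an operator means extending both in lockstep; there is no second source of truth for the emitted text to drift from.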

B.7 Float32, and whether it still trains

Every theorem in this book is over exact reals; the GPU computes in binary32. Until recently that gap lived entirely in the empirical layer — the oracles agree to 1–2 ULP, and you were asked to find that persuasive. For the Tier-1 nets it is now a theorem chain (LeanMlir/Proofs/FloatBridge.lean, SgdDescent.lean, SgdDescentLinear.lean), and the way it avoids new axioms is the part worth explaining.

The model is a hypothesis, not an axiom.

A FloatModel is any rounding operator \(\mathrm{rnd}\) with relative error \(u\): \(|\mathrm{rnd}(x) - x| \le u\, |x|\). Binary32 round-to-nearest satisfies this with \(u = 2^{-24}\) on the normal range (subnormals are open); the exact-arithmetic model (\(\mathrm{rnd} = \mathrm{id}\), \(u = 0\)) shows the interface is inhabited and collapses every budget to zero. Nothing about IEEE-754 is postulated — the theorems are conditional on the standard model, the same way the ReLU theorems are conditional on being off the kink, and the axiom audit is untouched. Two further design choices are forced by the hardware: the dot-product budgets are stated in the classical compounded form valid for every summation association, because IREE tiles and reorders reductions freely; and \(\exp \) enters as a hypothesis (\(|\widehat{\exp }(t) - e^t| \le e_{\exp }\, e^t\)) because GPU transcendentals have no IEEE specification — \(e_{\exp }\) is precisely the constant the VJP oracle (§ B.3) measures, so the deductive and empirical layers meet at a named interface instead of a hand-wave.
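The interface is easy to spot-check numerically. The sketch below uses Python's struct module to round to binary32 and exact rationals to measure the error. It checks the relative-error bound on a handful of normal-range values, then checks a dot product under two association orders against \(((1+u)^n - 1)\sum _i |x_i y_i|\). That expression is one classical compounded form and is assumed here; the budget stated in FloatBridge.lean may be phrased differently. All names and data are hypothetical.

```python
import struct
from fractions import Fraction

U = Fraction(1, 2 ** 24)  # binary32 unit roundoff, round-to-nearest

def rnd(x):
    """Round a binary64 value to the nearest binary32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def within_model(x):
    """Check |rnd(x) - x| <= u * |x| exactly, in rational arithmetic."""
    exact = Fraction(x)
    return abs(Fraction(rnd(x)) - exact) <= U * abs(exact)

def dot_left(xs, ys):
    """Left-to-right float32 dot product."""
    acc = 0.0
    for a, b in zip(xs, ys):
        acc = rnd(acc + rnd(a * b))
    return acc

def dot_pairwise(xs, ys):
    """Tree-shaped float32 dot product over the same inputs."""
    if len(xs) == 1:
        return rnd(xs[0] * ys[0])
    mid = len(xs) // 2
    return rnd(dot_pairwise(xs[:mid], ys[:mid]) + dot_pairwise(xs[mid:], ys[mid:]))

# Normal-range samples only; subnormals are outside the stated model.
samples = [0.1, -3.14159, 1e-3, 12345.678, 2.0 ** -100, 1 / 3, 1e30]
model_ok = all(within_model(x) for x in samples)

# Hypothetical float32 inputs for the dot-product budget.
n = 16
xs = [rnd((-1) ** k * (k + 1) / 7) for k in range(n)]
ys = [rnd(1 / (k + 3)) for k in range(n)]
exact = sum(Fraction(a) * Fraction(b) for a, b in zip(xs, ys))
budget = ((1 + U) ** n - 1) * sum(abs(Fraction(a) * Fraction(b)) for a, b in zip(xs, ys))
err_left = abs(Fraction(dot_left(xs, ys)) - exact)
err_tree = abs(Fraction(dot_pairwise(xs, ys)) - exact)

print("relative-error model holds on samples:", model_ok)
print("left-to-right within budget:", err_left <= budget)
print("pairwise within budget:     ", err_tree <= budget)
```

The same budget covers both orders because it charges every term the maximum number of roundings any association could apply, which is the price of not knowing how the compiler tiles the reduction.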

The chain.

Four links, each in the three-axiom audit. Forward: mlp_float_close_uniform budgets the rounded \(784{\to }512{\to }512{\to }10\) forward against the exact one, from coordinatewise magnitude bounds alone. Backward: mlp_{w2,w1,w0,b2,b1,b0}_step_float_close budget every rounded SGD parameter entry against \(\theta - \mathrm{lr}\cdot (a_i c_j)\) — entry for entry the emitWeightGrad quantities the render closes (§ B.5) prove equal to the \(\operatorname {pdiv}\)-Jacobian contractions, so the float step chains to the proven gradient. The loss head: softmax_ce_cot_close budgets the rounded softmax-minus-onehot cotangent against the certified \(\partial (\mathrm{crossEntropy})/\partial (\mathrm{logits})\). Descent: sgd_descends proves an \(\eta \)-accurate gradient step still decreases the loss, and for the linear classifier linear_sgd_descends discharges its smoothness hypothesis with the explicit constant \(2a^2/(1-2aD)\) — no Hessian: the softmax ratio sandwich that powers the float budgets turns out to be the Lipschitz engine too.
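For orientation, here is a standard textbook form of the inexact-descent argument. The smoothness constant \(L\), the computed gradient \(g\), the error \(e\), and the exact shape of the hypotheses are assumptions made for this sketch; sgd_descends may be stated differently. If \(f\) is \(L\)-smooth, then

\[ f(\theta - \mathrm{lr}\, g) \le f(\theta) - \mathrm{lr}\, \langle \nabla f(\theta), g\rangle + \tfrac{L\, \mathrm{lr}^2}{2}\, \| g\| ^2. \]

Writing \(g = \nabla f(\theta ) + e\) with \(\| e\| \le \eta \, \| \nabla f(\theta )\| \) and \(\eta {\lt} 1\) gives \(\langle \nabla f(\theta ), g\rangle \ge (1-\eta )\, \| \nabla f(\theta )\| ^2\) and \(\| g\| \le (1+\eta )\, \| \nabla f(\theta )\| \), so

\[ f(\theta - \mathrm{lr}\, g) \le f(\theta) - \mathrm{lr}\, \Bigl[(1-\eta ) - \tfrac{L\, \mathrm{lr}}{2}(1+\eta )^2\Bigr]\, \| \nabla f(\theta )\| ^2, \]

which is a strict decrease whenever \(\nabla f(\theta ) \ne 0\) and \(0 {\lt} \mathrm{lr} {\lt} 2(1-\eta )/\bigl(L\, (1+\eta )^2\bigr)\).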

The kink, quantitatively.

Over \(\mathbb {R}\), ReLU forced the hypotheses \(x_k \ne 0\). In float the same op inverts its role twice. The forward mask is exact — compare-and-select rounds nothing, so the op that causes all the \(\mathbb {R}\)-side conditions is the free op here. But the backward mask reads the rounded pre-activation, so the hypotheses return with a number in them: \(\mathit{ez} {\lt} |z_i|\) — the accumulated rounding error must not flip a sign (reluMask_close). A qualitative side condition became a checkable margin.
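A small sketch of why the margin suffices, with hypothetical names and values (this is not reluMask_close itself). If every pre-activation is farther from zero than the accumulated error, no perturbation of that size can change a sign, so the mask read from the rounded value equals the mask of the exact one.

```python
def relu_mask(z):
    """The compare-GT-0 mask the backward pass applies."""
    return [1 if zi > 0 else 0 for zi in z]

def margin_holds(z, ez):
    """The checkable side condition: ez < |z_i| for every coordinate."""
    return all(ez < abs(zi) for zi in z)

def worst_case_perturbations(z, ez):
    """Shift every coordinate by the full error, in each direction."""
    return [[zi + s * ez for zi in z] for s in (-1.0, 1.0)]

# Hypothetical pre-activations and error budgets.
z = [0.80, -0.35, 0.02, -1.50]

small_ez = 0.01  # below the smallest |z_i| = 0.02
stable = all(relu_mask(p) == relu_mask(z)
             for p in worst_case_perturbations(z, small_ez))

large_ez = 0.05  # above the smallest |z_i|, so the margin fails
flipped = any(relu_mask(p) != relu_mask(z)
              for p in worst_case_perturbations(z, large_ez))

print("margin holds at ez=0.01:", margin_holds(z, small_ez), "| mask stable:", stable)
print("margin holds at ez=0.05:", margin_holds(z, large_ez), "| mask can flip:", flipped)
```

The second case shows the hypothesis is not decorative: once the error budget exceeds the smallest pre-activation magnitude, a flip is actually reachable.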

Measured against proven.

The numeric capstones are instantiated at the trained magnitudes of a real 12-epoch, 97.8% GPU run (\(|W| \le 3/5\), covering the measured \(0.52\); He initialization already exceeds prettier bounds in its tails). An f32/f64 twin of that run (scripts/margin_probe.py, per-step coupled to match the single-step theorems) measures what the theorems bound:

| quantity | worst-case theorem | measured |
| --- | --- | --- |
| logit drift | \(\le 5100\) (mnist_mlp_float_budget) | \(1.6\cdot 10^{-5}\) |
| cotangent | \(\le 21/1000\) (mnist_cot_budget) | \(2.2\cdot 10^{-6}\) |
| \(W_2\) SGD step | \(\le 5/4\) (mnist_w2_step_float_budget) | \(7.5\cdot 10^{-9}\) |
| ReLU mask flips | \(0\) under margins | \(\mathbf{0}\, /\, 29.5\mathrm{M}\) |

Read the gap honestly, in both directions. The worst-case bounds hold with a factor of up to \(10^{8}\) to spare, because worst-case composition compounds magnitude bounds the way no real activation pattern does. That gap is the quantitative argument for a-posteriori certificates past toy depth, now visible in a theorem rather than asserted. And the flip count is zero across 29.5 million measured pre-activations: the margin hypotheses are not a technicality the proofs hide behind; they are what training actually looks like.
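The slack can be read directly off the table by dividing each worst-case bound by its measured value:

\[ \frac{5100}{1.6\cdot 10^{-5}} \approx 3.2\cdot 10^{8}, \qquad \frac{5/4}{7.5\cdot 10^{-9}} \approx 1.7\cdot 10^{8}, \qquad \frac{21/1000}{2.2\cdot 10^{-6}} \approx 9.5\cdot 10^{3}. \]

The logit-drift and \(W_2\)-step budgets sit about eight orders of magnitude above measurement, while the cotangent budget is within about four.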

B.8 The five layers, as three kinds of certainty

The five mechanisms are three kinds of certainty about a single claim — that the GPU computed the network’s exact gradient — with the empirical kind split among its witnesses. Each kind catches what the others structurally cannot.

Proven (deductive): the math is right.

  • Formal proofs fail if the math is wrong. Every theorem reduces to Mathlib’s \(\operatorname {fderiv}\) and Lean core; the trust kernel is exactly propext, Classical.choice, Quot.sound, plus what Mathlib proves about \(\operatorname {fderiv}\). Type-checking can’t catch the prose narrating a different formula than the theorem states, though.

  • The comparator fails if the elaborator accepted a term the kernel wouldn’t, or if any project axiom snuck into a proof’s transitive closure — the independent re-check that makes “proven” mean proven-clean, since \(\# \)print axioms and \(\texttt{lake build}\) share an elaborator and cannot audit themselves.

By construction (structural): the code is the math.

  • The bridge theorems (§ B.5) fail if the emitted StableHLO denotes anything other than the proven derivative. The graph is printed from the same datatype the bridge reasons about, so it cannot drift: “the code computes the proven gradient” is a theorem about a value, not a hope about a comment.

Cross-checked (empirical): the hardware agrees.

  • FD gradient checks catch prose-vs-theorem drift and any algebraic typo in a closed-form Jacobian. FD can’t spot compiler bugs or non-smooth-point tiebreak disagreements.

  • The JAX-parallel pipeline fails if either compiler has a codegen bug, if the runtime drifts, or if a hand-derived VJP disagrees with JAX autodiff.

  • The IREE-versus-NumPy oracle fails if the compiled graph computes the wrong thing on real hardware. It is the one check that reaches past the proofs to IREE’s lowering, which cannot be proven, and to float32, whose rounding now can be, conditionally (§ B.7).

  • The margin probe (scripts/margin_probe.py) fails if a real training run violates the float theorems’ hypotheses — a flipped ReLU mask, a logit drift past \(\delta \) — the check that keeps the conditional theorems honest about actual training rather than hypothetical nets.

Each kind fails differently — a deductive proof is blind to a miscompile, an empirical diff is blind to a measure-zero kink, a faithful bridge is blind to a wrong formula it renders perfectly — and overlaying them is the guarantee. Verified code generation straddles two of these: its bridge is structural, its oracle empirical. The float32 bridge straddles the other pair: its budgets are deductive, its interface constants (\(u\), \(e_{\exp }\), \(\delta \)) empirical.