Verified Deep Learning with Lean 4

B On Verification

The proofs in this book compile with zero sorrys. Every VJP correctness theorem — dense layers, convolution, batch normalization, residual connections, depthwise convolution, squeeze-and-excitation, layer normalization, and self-attention — is machine-checked by Lean’s type system. If it builds, it’s correct.

But “it builds” is only one of four independent checks we rely on. The formal proofs catch math errors, finite-difference checks confirm each closed-form Jacobian formula in the proof matches the function it claims to differentiate, a parallel JAX pipeline catches codegen and implementation errors, and an independent kernel re-check (comparator) audits the transitive axiom closure of every theorem outside Lean’s elaborator. Each layer fails in a different way, so overlaying them is the defense. This appendix walks through all four.

B.1 Trust kernel

The Verified VJP Proofs suite contains 81 numbered items across the proof chapters (72 theorems + 9 definitions), plus 22 architecture definitions in the bestiary (103 total) across 9 Lean files, with zero project axioms. Earlier drafts shipped 30 project axioms; the progression was 30 (early) \(\to \) 10 (post diff-threading) \(\to \) 0 (post column-slab + Phase 7 cleanup, Apr 2026).

What landed in the final rounds to retire the last of them:

  • Phase 1–2 (input-side conv VJPs). The \(\texttt{conv2d}\) and \(\texttt{depthwiseConv2d}\) input VJPs are now theorems, proved via three applications of pdiv_finset_sum, pdiv_const_mul_pi_pad_eval per summand, and a \(\sum _{c, k_h, k_w}\) collapse.

  • Phase 3 (multi-head SDPA). A new \(\operatorname{pdivMat}\_\texttt{colIndep}\) theorem and \(\texttt{colSlabwise\_has\_vjp\_mat}\) framework lift the proved single-head SDPA backward over the head axis. Both \(\texttt{mhsa\_has\_vjp\_mat}\) and its \(\mathsf{Differentiable}\) sibling are now theorems.

  • Phase 6a/6b (patch embedding). De-opaqued the \(\texttt{patchEmbed\_flat}\) forward into a concrete def (conv-with-stride + CLS prepend + positional embed), then proved its input VJP via a spatial-rearrangement bridge to \(\texttt{conv2d\_has\_vjp3}\).

  • Phase 7 (the final four). \(\operatorname{pdiv}\_\texttt{relu}\) proved via local-diagonal-CLM transport: at a smooth point ReLU agrees with the diagonal indicator CLM on a neighborhood (every coordinate keeps its sign), and HasFDerivAt.congr_of_eventuallyEq transports the CLM’s self-fderiv to ReLU; a one-dimensional Mathlib sketch of this transport follows the list. The remaining three (\(\texttt{relu\_has\_vjp}\), \(\texttt{mlp\_has\_vjp}\), and \(\texttt{maxPool2\_has\_vjp3}\)) became noncomputable defs over the canonical pdiv-derived witness; HasVJP.correct holds by rfl since \(\operatorname {pdiv}\) is a def over \(\operatorname {fderiv}\). At non-smooth points the canonical backward is \(\operatorname {fderiv}\)’s junk default of \(0\); the codegen substitutes the standard subgradient/argmax convention — see “Codegen trust boundary” in LeanMlir/Proofs/README.md.
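
To make the transport concrete, here is a minimal one-dimensional sketch of the same argument over plain Mathlib — \(\max(\cdot, 0)\) standing in for ReLU and the identity CLM standing in for the diagonal indicator. The lemma names below are generic Mathlib ones, not the project’s; the real theorem is the \(n\)-dimensional statement over the pdiv machinery.

    -- 1-D sketch: at x > 0, `max · 0` agrees with `id` on a neighborhood of x,
    -- so `congr_of_eventuallyEq` transports the identity derivative to it.
    import Mathlib

    open Filter

    example (x : ℝ) (hx : 0 < x) :
        HasFDerivAt (fun y : ℝ => max y 0) (ContinuousLinearMap.id ℝ ℝ) x := by
      have h : (fun y : ℝ => max y 0) =ᶠ[nhds x] id := by
        -- on the open set (0, ∞), every point keeps its sign, so max y 0 = y = id y
        filter_upwards [isOpen_Ioi.mem_nhds (Set.mem_Ioi.mpr hx)] with y hy
        simp [max_eq_left (Set.mem_Ioi.mp hy).le]
      exact (hasFDerivAt_id x).congr_of_eventuallyEq h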

Pure-Mathlib closure on every theorem. \(\texttt{\#print axioms vit\_full\_has\_vjp}\) shows only Lean core (propext, Classical.choice, Quot.sound); nothing project-level beneath those.

B.2 Finite-difference gradient checks

The script LeanMlir/Proofs/check_axioms.py (the name predates the axiom-count going to zero) runs 25 finite-difference gradient checks. For each, it perturbs the input by \(\varepsilon \), compares the claimed VJP against the centered difference \((f(x + \varepsilon ) - f(x - \varepsilon )) / (2\varepsilon )\), and asserts agreement within tolerance (typical max-error \(\sim 10^{-11}\) at \(\varepsilon = 10^{-5}\) in float64).
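
A minimal sketch of the shape of one such check — NumPy only, not the repo’s actual harness; f and claimed_vjp are placeholders for a layer’s forward pass and its closed-form VJP:

    import numpy as np

    def fd_check(f, claimed_vjp, x, eps=1e-5, tol=1e-8, seed=0):
        """Compare a claimed VJP against a centered difference along a random direction."""
        rng = np.random.default_rng(seed)
        u = rng.standard_normal(f(x).shape)   # random cotangent
        v = rng.standard_normal(x.shape)      # random input direction
        # directional derivative u . df(x)[v] via the centered difference
        fd = u.ravel() @ ((f(x + eps * v) - f(x - eps * v)) / (2 * eps)).ravel()
        # the same scalar via the claimed VJP: <VJP_x(u), v>
        vjp = claimed_vjp(x, u).ravel() @ v.ravel()
        err = abs(fd - vjp)
        assert err < tol, f"FD mismatch: {err:.3e}"
        return err

    # example: dense layer y = W x with its hand-written input VJP W^T u
    W = np.random.default_rng(1).standard_normal((4, 3))
    fd_check(lambda x: W @ x, lambda x, u: W.T @ u, np.ones(3))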

With zero project axioms, every FD check is now a belt-and-suspenders pass over a proved Jacobian theorem. The proofs already establish that each formula equals \(\operatorname {fderiv}\) at the relevant points; the FD checks confirm the formulas-as-written agree numerically with what the function actually does. Coverage spans every closed-form Jacobian we use downstream: \(\operatorname{pdiv}\_\texttt{dense}\) and its weight/bias companions, \(\operatorname{pdiv}\_\texttt{relu}\) (at smooth points), softmax cross-entropy, all four BN pieces (bnNormalize, bnCentered, bnIstdBroadcast, bnAffine), the conv2d and depthwise input/weight/bias VJPs, the maxPool2 input VJP, the softmax Jacobian, the three single-head SDPA Q/K/V backwards, GELU, the bundled multi-head SDPA reduction (per-head sdpa_back stacked over the head axis), the patch-embed input VJP, and the full-network MLP composition.

What the FD pass catches that the symbolic proof can’t: typos between the proof and the prose that uses the same Jacobian. If a chapter narrates one formula but the Lean theorem states another, the proof still type-checks (Lean only verifies its own statement), but the FD test runs against the formula the prose published and would diverge.

FD is cheap, easy to reason about, and tight enough for spot-checking formulas, but it can’t probe what the compiled code actually computes on the GPU — only what the formula says in Python — and it struggles at non-smooth points where the limit definition breaks down. For those gaps and for end-to-end pipeline verification, the third layer takes over.

B.3 The JAX parallel pipeline

A separate Lean \(\to \) JAX \(\to \) XLA pipeline (jax/Jax/Codegen.lean, \(\sim 1100\) lines) produces an idiomatic JAX training script from the same NetSpec the primary Lean \(\to \) StableHLO MLIR \(\to \) IREE pipeline consumes. XLA is the compiler JAX uses to produce GPU code (the same backend that powers TensorFlow and other frameworks); IREE is its Lean-side counterpart. The two stacks are independent end to end — different codegen, different runtime, different kernels — which is why agreement between them is a meaningful cross-check. Running both from identical initial parameters on identical batches gives us a pair of trace files, one per stack, that should agree modulo float32 rounding if both pipelines compute the same math.

The agreement is very tight. For the MNIST MLP, step-1 losses agree to \(\sim 2 \times 10^{-7}\) — float32 ULP — across the JAX-CPU-vs-IREE-ROCm comparison, and phase-3 IREE output is bit-identical across AMD and NVIDIA hardware at step 1. For the MNIST CNN with batch norm, step-1 losses agree to \(\sim 10^{-4}\), looser because variance reductions over \(\sim 100\)k-element tensors amplify cross-compiler reduction-tree differences — both pipelines do correct math; they just sum it in different orders. Full results are committed to the repo as reproducible JSON-Lines traces; see traces/CROSS_BACKEND_RESULTS.md.
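
The trace diff itself amounts to a few lines — a hedged sketch, assuming each stack emits JSON-Lines records of the form {"step": ..., "loss": ...}; the file names below are illustrative, not the committed ones:

    import json

    def load_losses(path):
        with open(path) as f:
            return {rec["step"]: rec["loss"] for rec in map(json.loads, f)}

    jax_trace  = load_losses("traces/mnist_mlp_jax_cpu.jsonl")    # illustrative path
    iree_trace = load_losses("traces/mnist_mlp_iree_rocm.jsonl")  # illustrative path

    for step in sorted(jax_trace.keys() & iree_trace.keys()):
        a, b = jax_trace[step], iree_trace[step]
        rel = abs(a - b) / max(abs(a), 1e-30)
        # compare rel against the per-model tolerance discussed above
        print(f"step {step}: jax={a:.8f}  iree={b:.8f}  rel={rel:.2e}")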

Layered on top of the end-to-end trace diff is a per-axiom differential test in tests/vjp_oracle/, which compares each Lean-proved backward pass against JAX’s value_and_grad autodiff on a minimal one-step training run. Nine cases — dense, dense+ReLU, conv, conv+BN, conv+maxPool, residual, depthwise, SE, attention — each agree with JAX autodiff to within 1–2 ULP on the step-2 loss. Any future hand-derived VJP added to the Lean proof base can be validated by a one-step comparison against JAX, catching algebraic errors that FD would miss (sign flips, wrong contraction axis, swapped indices).
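
The per-case comparison is simple in shape — a hedged sketch, with a dense layer and a hand-derived gradient standing in for the Lean-proved formulas (the names below are illustrative, not the vjp_oracle harness):

    import jax
    import jax.numpy as jnp

    def dense_loss(W, x, y):
        return 0.5 * jnp.sum((W @ x - y) ** 2)

    # hand-derived gradient w.r.t. W: outer(W x - y, x)
    def dense_grad_manual(W, x, y):
        return jnp.outer(W @ x - y, x)

    W = jax.random.normal(jax.random.PRNGKey(0), (4, 3))
    x, y = jnp.arange(3.0), jnp.ones(4)

    loss, grad_auto = jax.value_and_grad(dense_loss)(W, x, y)
    print(jnp.max(jnp.abs(grad_auto - dense_grad_manual(W, x, y))))   # ~float32 ULP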

B.4 Independent kernel re-check (comparator)

tests/comparator/ wires the project up to leanprover/comparator, the Lean community’s trustworthy-judge tool for projects that claim zero project axioms. It runs 38 theorems — the foundation rules, every chapter’s headline Jacobian, and the public *_has_vjp_correct wrappers — through Lean’s kernel typechecker independently of the elaborator, with an axiom-allowlist of exactly \(\{ \texttt{propext}, \texttt{Quot.sound}, \texttt{Classical.choice}\} \). Any project axiom in the transitive closure of any verified theorem would fail the run.

This is what closes the gap between “the elaborator accepted my proofs” (which lake build confirms) and “the kernel agrees, audited from a separate process” (which comparator confirms). The 38-theorem coverage is illustrative rather than exhaustive — the same recipe scales to any subset of the proof suite, and the same allowlist applies because every theorem in the project closes the same way.

B.5 The four layers, summarized

Together, the four layers catch four different kinds of bugs:

  • Formal proofs fail if the math is wrong. With zero project axioms, every theorem reduces to Mathlib’s \(\operatorname {fderiv}\) and Lean core; the trust kernel is exactly propext, Classical.choice, Quot.sound, plus what Mathlib proves about \(\operatorname {fderiv}\). Type-checking can’t catch the prose narrating a different formula than the theorem states, though.

  • FD gradient checks catch that prose-vs-theorem drift, plus any algebraic typo in a closed-form Jacobian. FD can’t spot compiler bugs or non-smooth-point tiebreak disagreements.

  • The JAX-parallel pipeline fails if either compiler has a codegen bug, if the runtime drifts, or if the hand-derived VJP disagrees with JAX autodiff — the places formal proofs and FD can’t see.

  • Comparator fails if the elaborator accepted a term the kernel wouldn’t, or if any project axiom snuck into a proof’s transitive closure — failures that \(\texttt{\#print axioms}\) and \(\texttt{lake build}\) can’t independently audit, because they run inside the same elaborator.

Each layer fails differently. Overlaying them is the guarantee.