Verified Deep Learning with Lean 4

11 Bestiary of Architectures

Part 1 introduced a small set of NetSpec primitives across chapters 2–10 and proved each one’s backward pass:

  • Atomic ops: .dense, .flatten, .conv2d, .maxPool, .convBn, .globalAvgPool, plus the Activation enum (Ch 3–6).

  • .residualBlock (Ch 6, ResNet-34).

  • .invertedResidual (Ch 7, MobileNet V2).

  • .mbConv with optional Squeeze-and-Excitation (Ch 8, EfficientNet-B0).

  • .convNextStage + .convNextDownsample (Ch 9, ConvNeXt-T).

  • .patchEmbed + .transformerEncoder (Ch 10, ViT-Tiny).

Call this the primary-sequence toolbox — everything a reader has internalized after finishing Part 1. The bestiary is Part 2: a catalogue of famous architectures expressed as pure NetSpec values, no training runs, no VJP commitments, just “here’s what this architecture looks like in \(\sim \)15 lines of Lean.”

What “\(N\) new primitives” means.

Every bestiary entry leads with how many new Layer constructors it asks the reader to absorb on top of the primary-sequence toolbox.

  • Zero new primitives means the architecture composes entirely from things you already know — VGG (just \(3\times 3\) conv stacks), AlphaZero (convBn + residualBlock), Highway Networks (two parallel dense paths plus a blend that’s just product rule + additive fan-in from Ch 2). The architecture is conceptually a free read.

  • One new primitive means one bundled architectural idiom that wasn’t in Part 1 — .bottleneckBlock for deeper ResNets, .fireModule for SqueezeNet, .mambaBlock for Mamba, etc. The primitive is bundled at the right abstraction level (a whole Mamba block, a stage of Swin attention) so reading it is one chunk, not twenty sub-ops.

  • Two or three new primitives are reserved for architectures that genuinely combine multiple novel idioms — DETR (.transformerDecoder + .detrHeads), DenseNet (.denseBlock + .transitionLayer), Stable Diffusion (.unetDown + .unetUp + .transformerDecoder). At most three; if you’d need more, the architecture probably warranted its own chapter.

This count is calibrated to the reader (what’s new vs. what they’ve seen), not to the codebase (where some bestiary primitives were drafted earlier and already sit in Types.lean). Formalizing every sub-op of every modern architecture would be an asymptote we don’t approach.

11.1 Bestiary-only Layer primitives

Twenty-two new Layer constructors were added by the bestiary chapters. None are codegen-backed (MlirCodegen emits // UNSUPPORTED); the goal is pedagogical shape + parameter accounting, not training runs.

Definition 82 Inception module

The GoogLeNet parallel-branch module (Szegedy et al. 2014). Four branches computed in parallel: 1\(\times \)1 conv (b1 channels); 1\(\times \)1 reduce then 3\(\times \)3 (b2); 1\(\times \)1 reduce then 5\(\times \)5 (b3); 3\(\times \)3 maxPool then 1\(\times \)1 (b4). Concat along channels for b1 + b2 + b3 + b4 outputs. The 1\(\times \)1 dimension reducers on branches 2 and 3 are the paper’s trick — they make the expensive 3\(\times \)3 and 5\(\times \)5 convs operate on reduced channel counts. Signature: inceptionModule (ic b1 b2reduce b2 b3reduce b3 b4 : Nat).
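A minimal Lean sketch of the channel and cost arithmetic behind this definition. The `InceptionCfg` structure and the helper names are hypothetical stand-ins, not the book's `Layer` type; biases and BN are ignored, and the numbers are those of the 3a module in the GoogLeNet spec.

```lean
-- Hypothetical stand-in for the seven arguments of `inceptionModule`.
structure InceptionCfg where
  ic : Nat
  b1 : Nat
  b2reduce : Nat
  b2 : Nat
  b3reduce : Nat
  b3 : Nat
  b4 : Nat

-- Channel count after concatenating the four branches.
def InceptionCfg.outChannels (c : InceptionCfg) : Nat :=
  c.b1 + c.b2 + c.b3 + c.b4

-- Weights in a k×k conv from `ic` to `oc` channels (biases ignored).
def convWeights (k ic oc : Nat) : Nat := k * k * ic * oc

-- 5×5 branch with the 1×1 reducer in front, as in the definition.
def InceptionCfg.branch3Reduced (c : InceptionCfg) : Nat :=
  convWeights 1 c.ic c.b3reduce + convWeights 5 c.b3reduce c.b3

-- The same branch if the 5×5 conv read all `ic` channels directly.
def InceptionCfg.branch3Direct (c : InceptionCfg) : Nat :=
  convWeights 5 c.ic c.b3

def inception3a : InceptionCfg :=
  { ic := 192, b1 := 64, b2reduce := 96, b2 := 128,
    b3reduce := 16, b3 := 32, b4 := 32 }

#eval inception3a.outChannels     -- 256
#eval inception3a.branch3Reduced  -- 15872
#eval inception3a.branch3Direct   -- 153600
```

Under these assumptions the reducer cuts the 5\(\times \)5 branch to roughly a tenth of its unreduced weight count.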

Definition 83 DenseNet dense block

DenseNet’s bundled dense block (Huang et al. 2017). nLayers BN-ReLU-1\(\times \)1-BN-ReLU-3\(\times \)3 sub-stacks, each adding growthRate channels to the running concatenation of preceding outputs. Input has ic channels; output has ic + nLayers \(\cdot \) growthRate. Bundled because each sub-layer reads a growing concat of all preceding sub-layers, which doesn’t fit a linear NetSpec at sub-layer granularity. Signature: denseBlock (ic growthRate nLayers : Nat).
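The channel bookkeeping can be sketched in a few lines of Lean. The function names are illustrative only; the formula is the one stated in the definition.

```lean
-- Output width of a dense block: each sub-layer appends `growthRate` channels.
def denseOut (ic growthRate nLayers : Nat) : Nat :=
  ic + nLayers * growthRate

-- Width of the concatenation read by each sub-layer, in order.
def denseInputs (ic growthRate nLayers : Nat) : List Nat :=
  (List.range nLayers).map fun i => ic + i * growthRate

#eval denseOut 64 32 6      -- 256
#eval denseInputs 64 32 6   -- [64, 96, 128, 160, 192, 224]
```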

Definition 84 DenseNet transition layer

Inter-block downsample for DenseNet. BN + 1\(\times \)1 conv (ic \(\to \) oc) + 2\(\times \)2 average-pool stride 2. Halves spatial resolution and (typically) halves channel count via the 1\(\times \)1 conv to compress feature reuse between dense blocks. Signature: transitionLayer (ic oc : Nat).

Definition 85 ShuffleNet v1 stage

Grouped 1\(\times \)1 conv + channel-shuffle permutation + 3\(\times \)3 depthwise + grouped 1\(\times \)1 conv, repeated for nUnits units (the first downsamples, the rest are residual). The shuffle is parameter-free; grouping cuts the 1\(\times \)1 cost by a factor of \(g\). Signature: shuffleBlock (ic oc groups nUnits : Nat).
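The channel shuffle itself is a fixed permutation. A hedged sketch on a plain list of channel indices, assuming the usual reshape-transpose-flatten reading of the operation and that groups is positive and divides the channel count (`channelShuffle` is a hypothetical helper, not part of the book's code):

```lean
-- View the channel list as a (groups × n) grid, transpose it, and flatten.
-- Output position j reads input channel (j % groups) * n + j / groups.
def channelShuffle {α : Type} (groups : Nat) (xs : List α) : List α :=
  let n := xs.length / groups
  (List.range xs.length).filterMap fun j =>
    xs[(j % groups) * n + j / groups]?

#eval channelShuffle 3 (List.range 9)   -- [0, 3, 6, 1, 4, 7, 2, 5, 8]
```

After the shuffle, each consecutive run of three channels holds one channel from each original group, which is what lets the next grouped conv mix information across groups.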

Definition 86 ShuffleNet v2 stage

nUnits v2 units (Ma et al. 2018). Basic unit (stride 1): channel-split \([X_1, X_2]\), leave \(X_1\) alone, run \(X_2\) through \(1\times 1 \to 3\times 3\) DW \(\to 1\times 1\) at half-width, concat, then channel-shuffle. Downsample unit (stride 2): both branches see the full input; left does DW-3\(\times \)3-stride-2 \(+\) 1\(\times \)1, right does 1\(\times \)1 \(+\) DW-3\(\times \)3-stride-2 \(+\) 1\(\times \)1; concat doubles channels. v2 throws out v1’s grouped 1\(\times \)1 convs (G2) and skip-add (G4) per the paper’s practical guidelines. Signature: shuffleV2Block (ic oc nUnits : Nat).

Definition 87 MobileViT block

Hybrid local-conv + patch-level transformer (Mehta & Rastegari 2022). Local 3\(\times \)3 conv \(\to \) 1\(\times \)1 projection to transformer dim \(\to \) unfold into patches \(\to \) \(L\) transformer blocks across patches \(\to \) fold back \(\to \) 1\(\times \)1 projection back \(\to \) concat with input \(\to \) 3\(\times \)3 fusion. Signature: mobileVitBlock (ic dim heads mlpDim nTxBlocks : Nat). The unfold/fold operations are pdiv_reindex-style shape transformations; all the genuinely new math is already covered by the transformer proof chapter.

Definition 88 Swin Transformer stage

Windowed multi-head self-attention at fixed spatial resolution (Liu et al. 2021). Signature: swinStage (dim heads mlpDim windowSize nBlocks : Nat). Internal blocks alternate W-MSA and SW-MSA (shifted-window) to let information cross window boundaries.

Definition 89 Patch merging

Swin’s 2\(\times \)2 spatial downsample + linear channel projection (inDim \(\to \) outDim). Transformer-side analog of a stride-2 conv.

Definition 90 Darknet residual block

YOLOv3’s Darknet-53 residual stack (Redmon & Farhadi 2018). nBlocks residual blocks at fixed channels, each being 1\(\times \)1 conv (\(c \to c/2\)) \(+\) 3\(\times \)3 conv (\(c/2 \to c\)) \(+\) residual add. Lighter than a standard ResNet bottleneck; heavier than a ResNet-18 basic block. Signature: darknetBlock (channels nBlocks : Nat).

Definition 91 Cross-Stage Partial block

CSP (Wang et al. 2019), used by YOLOv4 onward. Splits input into two halves, processes one half through a stack of residual blocks, then concatenates with the untouched half and 1\(\times \)1-projects to oc. The specific inner block varies across YOLO versions (C3 in v5, C2f in v8, C3k2 in v11); this primitive approximates all three at the same abstraction level. Signature: cspBlock (ic oc nBlocks : Nat).

Definition 92 Feature Pyramid Network

Lin et al. 2017. Takes the four stage outputs of a CNN backbone (channels c2/c3/c4/c5 at strides \(4/8/16/32\)), projects each to target channels with a \(1\times 1\) lateral conv, merges them top-down (upsample-2\(\times \) then elementwise add), and applies a \(3\times 3\) smoothing conv at each merged level. Output: four feature maps, each target-wide, at the original spatial resolutions. Bundled because the cross-scale add doesn’t fit a linear NetSpec. Standard kit in 2-stage detectors (Mask R-CNN, Cascade R-CNN) and single-stage detectors (RetinaNet). Signature: fpnModule (c2 c3 c4 c5 target : Nat).

Definition 93 Transformer decoder (DETR-style)

nBlocks blocks of self-attention over nQueries learned object queries + cross-attention against an encoder output + FFN. The query embedding is part of the layer’s parameters. Signature: transformerDecoder (dim heads mlpDim nBlocks nQueries : Nat).

Definition 94 DETR prediction heads

Per-query class head (linear dim \(\to \) nClasses+1 with the “no object” slot) + box head (3-layer MLP to 4 scalars (cx, cy, w, h)). Signature: detrHeads (dim nClasses : Nat).

Definition 95 UNet encoder stage

\(2 \times \) (conv3\(\times \)3 + BN + ReLU) then maxPool-2. Saves its pre-pool activation as a skip for the matching unetUp. Signature: unetDown (ic oc : Nat).

Definition 96 UNet decoder stage

Transposed-conv \(2\times \) upsample + concat with matching skip + \(2 \times \) (conv3\(\times \)3 + BN + ReLU). Signature: unetUp (ic oc : Nat), where oc is both the output channel count and the expected skip width.

Definition 97 Atrous Spatial Pyramid Pooling

DeepLab v3+’s marquee module (Chen et al. 2018). Five parallel branches emitting oc channels each: (1) 1\(\times \)1 conv, (2–4) 3\(\times \)3 atrous convs at dilation rates 6 / 12 / 18, (5) global avg-pool + 1\(\times \)1 conv + bilinear upsample. Concatenate, then a 1\(\times \)1 fusion conv back to oc. All branches include BN + ReLU. Atrous rates widen the effective receptive field without changing param count; the pool branch supplies image-level context. Signature: asppModule (ic oc : Nat).

Definition 98 Mamba block

Selective state-space block in the S6 formulation (Gu & Dao 2023). Signature: mambaBlock (dim stateSize expand nBlocks : Nat). Bundles RMSNorm + linear expand + depthwise 1D conv + SiLU + selective-scan SSM + gated product + output projection, for nBlocks stacked layers. The selective scan is the novel primitive; everything else could be decomposed to existing layers if we cared to unpack the bundle.

Definition 99 WaveNet block

One stack of nLayers dilated causal residual blocks with doubling dilation rates \(2^0, 2^1, \ldots , 2^{\texttt{nLayers}-1}\) (van den Oord et al. 2016). Each block: dilated 2-tap causal conv \(\to \) gated activation \(\tanh (\text{filter}) \odot \sigma (\text{gate})\) \(\to \) 1\(\times \)1 project back to residualCh (residual path) + 1\(\times \)1 skip projection to skipCh. Skip outputs across blocks are summed into the final head. Signature: waveNetBlock (residualCh skipCh nLayers : Nat). Output channels are skipCh: bestiary convention picks the skip path as the "forward" output since it’s what feeds the final classifier.
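A short Lean sketch of the dilation schedule and the receptive field it implies. It assumes the 2-tap causal conv described in the definition, so a layer with dilation \(d\) extends the field by \(d\) time steps; the function names are illustrative.

```lean
-- Dilation rates for one stack: 2^0, 2^1, ..., 2^(nLayers - 1).
def waveNetDilations (nLayers : Nat) : List Nat :=
  (List.range nLayers).map fun i => 2 ^ i

-- Receptive field in time steps, starting from 1 for the current sample.
def waveNetReceptiveField (nLayers : Nat) : Nat :=
  (waveNetDilations nLayers).foldl (· + ·) 1

#eval waveNetDilations 4        -- [1, 2, 4, 8]
#eval waveNetReceptiveField 4   -- 16
#eval waveNetReceptiveField 10  -- 1024
```

The field doubles with every added layer, which is why a ten-layer stack already covers 1024 samples.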

Definition 100 Positional encoding

Sinusoidal frequency basis (Vaswani 2017, reused by NeRF 2020): \(\gamma (p) = (\sin 2^0 \pi p, \cos 2^0 \pi p, \ldots , \sin 2^{L-1} \pi p, \cos 2^{L-1} \pi p)\). Zero trainable parameters — it’s a deterministic lift of a low-dim coordinate into a high-frequency feature space where an MLP has enough wiggle room to represent sharp details. Output dim \(=\) inputDim \(\cdot 2 \cdot \) numFrequencies. Signature: positionalEncoding (inputDim numFrequencies : Nat).
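A small Lean sketch of \(\gamma \), using core Float only. The constant `piF` and the function names are assumptions made for the example; the check at the end uses a 3-D coordinate with 10 frequencies, which is one common NeRF setting.

```lean
-- Assumed constant: π spelled out as a Float literal.
def piF : Float := 3.141592653589793

-- γ(p) for one scalar: (sin 2^k π p, cos 2^k π p) for k = 0 .. L-1.
def encodeScalar (numFrequencies : Nat) (p : Float) : List Float :=
  (List.range numFrequencies).foldr
    (fun k acc =>
      let w := Nat.toFloat (2 ^ k) * piF * p
      Float.sin w :: Float.cos w :: acc)
    []

-- Encode every coordinate and concatenate the results.
def positionalEncode (numFrequencies : Nat) (ps : List Float) : List Float :=
  ps.foldr (fun p acc => encodeScalar numFrequencies p ++ acc) []

-- Output width: inputDim * 2 * numFrequencies.
def encodedDim (inputDim numFrequencies : Nat) : Nat :=
  inputDim * 2 * numFrequencies

#eval encodedDim 3 10                                -- 60
#eval (positionalEncode 10 [0.1, 0.2, 0.3]).length   -- 60
```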

Definition 101 NeRF MLP core

The whole NeRF network (Mildenhall et al. 2020) bundled as one primitive: 8 hidden ReLU-FC layers of hiddenDim, mid-skip concatenating \(\gamma (x)\) at layer 5, dual output heads (1-dim volume density \(\sigma \) + 3-dim RGB via a direction-conditioned branch). Under 600K parameters at the canonical config. Signature: nerfMLP (encodedPosDim encodedDirDim hiddenDim : Nat).
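The parameter bound can be checked with a rough Lean count. The layer layout below is an assumption based on the usual reading of the NeRF paper (a 256-wide feature layer before the colour branch, and a direction-conditioned layer of half the hidden width); it is not taken from the book's nerfMLP implementation.

```lean
-- Fully connected layer: weights plus biases.
def fc (i o : Nat) : Nat := i * o + o

def nerfParams (encodedPosDim encodedDirDim hiddenDim : Nat) : Nat :=
  let h := hiddenDim
  fc encodedPosDim h +               -- layer 1
  3 * fc h h +                       -- layers 2-4
  fc (h + encodedPosDim) h +         -- layer 5, reads the skip concat
  3 * fc h h +                       -- layers 6-8
  fc h 1 +                           -- density head
  fc h h +                           -- feature layer for the colour branch (assumed)
  fc (h + encodedDirDim) (h / 2) +   -- direction-conditioned layer (assumed width h / 2)
  fc (h / 2) 3                       -- RGB head

#eval nerfParams 60 24 256   -- 593924
```

With encoded widths 60 and 24 and hiddenDim 256, this layout stays just under 600K, consistent with the figure quoted in the definition.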

Definition 102 Evoformer block (AlphaFold 2)

Dual-representation (MSA + pair) joint update: MSA row-attention with pair bias + MSA column-attention + MSA transition + outer-product-mean (\(\to \) pair) + triangle multiplicative (outgoing, incoming) + triangle self-attention (starting node, ending node) + pair transition. Signature: evoformerBlock (msaChannels pairChannels nBlocks : Nat). The triangulation-aware operations are the key inductive bias.

Definition 103 Structure Module (AlphaFold 2)

Recurrent Invariant Point Attention + backbone frame update + side-chain \(\chi \)-angle prediction. Weights shared across nBlocks rounds — param count does not multiply by nBlocks. Signature: structureModule (singleChannels pairChannels nBlocks : Nat).

11.2 Bestiary entries

The entries are grouped by task domain. The first block (vision classifiers) is where Part 1’s VJP’d primitives show up at real-world scale — every layer is one you’ve already seen proved correct. Subsequent blocks step out to detection, segmentation, reinforcement learning, and the non-vision outliers (language, audio, 3D, multimodal, science).

11.2.1 Vision classifiers — Part 1’s primitives at scale

Image-classification backbones built out of conv / pool / batch-norm / residual / attention / patch-embed — the exact layer kit VJP’d in Part 1. If a chapter of Part 1 proved it, a bestiary entry below puts it to work.

LeNet (Bestiary/LeNet.lean)

Zero new primitives: just conv2d \(+\) maxPool \(+\) dense. The 1998 original (LeCun et al. 1998, Proceedings of the IEEE). Variants: LeNet-5 (61K params, the canonical CNN) and LeNet-300-100 (266K params, pure-MLP baseline). Historically important, and still the ground-truth pattern every later CNN riffs on.

def leNet5 : NetSpec where
  name   := "LeNet-5"
  imageH := 32
  imageW := 32
  layers := [
    .conv2d 1 6 5 .valid .relu,           -- 32×32 → 28×28
    .maxPool 2 2,                         -- 28×28 → 14×14
    .conv2d 6 16 5 .valid .relu,          -- 14×14 → 10×10
    .maxPool 2 2,                         -- 10×10 → 5×5
    .flatten,
    .dense (16 * 5 * 5) 120 .relu,        -- 400 → 120
    .dense 120 84 .relu,
    .dense 84 10 .identity                -- 10-class MNIST output
  ]
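The 61K figure can be reproduced with a small parameter counter. `MiniLayer` is a hypothetical, stripped-down layer type written for this sketch; the book's own Layer type carries more fields (padding, activation) that do not affect the count.

```lean
-- Hypothetical simplified layer type, used only for counting parameters.
inductive MiniLayer where
  | conv (ic oc k : Nat)
  | dense (i o : Nat)
  | other                 -- pooling and flatten: no trainable parameters

def MiniLayer.params : MiniLayer → Nat
  | .conv ic oc k => k * k * ic * oc + oc   -- weights + biases
  | .dense i o    => i * o + o
  | .other        => 0

def countParams (ls : List MiniLayer) : Nat :=
  ls.foldl (fun acc l => acc + l.params) 0

def leNet5Layers : List MiniLayer :=
  [.conv 1 6 5, .other, .conv 6 16 5, .other, .other,
   .dense 400 120, .dense 120 84, .dense 84 10]

#eval countParams leNet5Layers   -- 61706
```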

AlexNet (Bestiary/AlexNet.lean)

Zero new primitives — five convs, three FCs, pools. The 2012 ImageNet winner that restarted modern deep learning (Krizhevsky et al. 2012, NeurIPS 2012). Variants: AlexNet (62M, paper-exact at 60M) + tiny CIFAR fixture. \(\sim \)58M of the 62M live in the three FC layers, which is precisely why every post-2015 CNN dropped FC stacks for globalAvgPool + one final dense. LRN is omitted (replaced in the field by BatchNorm in 2015); dropout is training-time only.

Krizhevsky trained on two GTX 580 cards (3 GB each) for \(\sim \)6 days; the two-tower model is a hardware constraint, not a design choice. The winning submission averaged the predictions of several convnets and used 10-crop test-time averaging.

def alexNet : NetSpec where
  name   := "AlexNet (Krizhevsky 2012)"
  imageH := 227
  imageW := 227
  layers := [
    .convBn 3 96 11 4 .same,              -- 11×11 stride 4
    .maxPool 2 2,                         -- 3×3 stride 2 in paper
    .conv2d 96 256 5 .same .relu,
    .maxPool 2 2,
    .conv2d 256 384 3 .same .relu,
    .conv2d 384 384 3 .same .relu,
    .conv2d 384 256 3 .same .relu,
    .maxPool 2 2,
    .flatten,
    .dense (6 * 6 * 256) 4096 .relu,      -- FC head: the three dense layers hold ~94% of params
    .dense 4096 4096 .relu,
    .dense 4096 1000 .identity
  ]

VGG (Bestiary/VGG.lean)

Zero new primitives — pure stacks of \(3\times 3\) convs + max-pool + a heavy dense head. VGG (Simonyan & Zisserman 2014, arXiv:1409.1556) was the deep-CNN era’s reference architecture for two years between AlexNet and ResNet. Its design philosophy reduced to one rule: “always use \(3\times 3\) kernels and add depth.” Two stacked \(3\times 3\)s cover the receptive field of a \(5\times 5\) with \(\sim \)30% fewer params and an extra nonlinearity in the middle; three stacked \(3\times 3\)s match a \(7\times 7\). The whole architecture is just that trick repeated.
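The stacking argument is plain arithmetic, sketched here in Lean for a layer with 64 input and 64 output channels (an arbitrary width chosen for illustration; biases are ignored and the names are hypothetical).

```lean
-- Weights in a k×k conv with c input and c output channels.
def squareConvWeights (k c : Nat) : Nat := k * k * c * c

-- Receptive field of n stacked stride-1 k×k convs: 1 + n * (k - 1).
def stackedReceptiveField (k n : Nat) : Nat := 1 + n * (k - 1)

#eval stackedReceptiveField 3 2      -- 5
#eval stackedReceptiveField 3 3      -- 7
#eval 2 * squareConvWeights 3 64     -- 73728
#eval squareConvWeights 5 64         -- 102400
#eval 3 * squareConvWeights 3 64     -- 110592
#eval squareConvWeights 7 64         -- 200704
```

Two 3\(\times \)3 layers use 18 weights per channel pair against 25 for a single 5\(\times \)5 (28% fewer); three use 27 against 49 for a 7\(\times \)7 (about 45% fewer).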

All four paper variants:

Name      Per-stage convs    Params
VGG-11    (1, 1, 2, 2, 2)    132.9M
VGG-13    (2, 2, 2, 2, 2)    133.0M
VGG-16    (2, 2, 3, 3, 3)    138.4M
VGG-19    (2, 2, 4, 4, 4)    143.7M

All paper-exact. The big numbers come almost entirely from the FC head: the first dense layer alone (\(7\times 7\times 512 \to 4096\)) is \(\sim \)102M of the 138M for VGG-16.
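A quick Lean check of the head arithmetic for VGG-16, counting weights plus biases (the helper names are illustrative):

```lean
-- Parameters of a dense layer: weights plus biases.
def denseParams (i o : Nat) : Nat := i * o + o

def vgg16FirstDense : Nat := denseParams (7 * 7 * 512) 4096

def vgg16Head : Nat :=
  vgg16FirstDense + denseParams 4096 4096 + denseParams 4096 1000

#eval vgg16FirstDense   -- 102764544
#eval vgg16Head         -- 123642856
```

The three dense layers together account for about 123.6M of the 138.4M total, which leaves under 15M for all thirteen conv layers.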

def vgg16 : NetSpec where
  name   := "VGG-16"
  imageH := 224
  imageW := 224
  layers := [
    .conv2d 3 64 3 .same .relu, .conv2d 64 64 3 .same .relu,    .maxPool 2 2,
    .conv2d 64 128 3 .same .relu, .conv2d 128 128 3 .same .relu, .maxPool 2 2,
    .conv2d 128 256 3 .same .relu, .conv2d 256 256 3 .same .relu,
    .conv2d 256 256 3 .same .relu,                              .maxPool 2 2,
    .conv2d 256 512 3 .same .relu, .conv2d 512 512 3 .same .relu,
    .conv2d 512 512 3 .same .relu,                              .maxPool 2 2,
    .conv2d 512 512 3 .same .relu, .conv2d 512 512 3 .same .relu,
    .conv2d 512 512 3 .same .relu,                              .maxPool 2 2,
    .flatten,
    .dense (7 * 7 * 512) 4096 .relu,
    .dense 4096 4096 .relu,
    .dense 4096 1000 .identity
  ]

VGG-16 / VGG-19 still appear in production today as feature extractors (perceptual loss, style transfer) — those FC features turn out to encode useful image structure even after the classification head fell out of favor.

Highway Networks (Bestiary/Highway.lean)

Zero new math — the architecture is built entirely from the product rule and additive fan-in (Chapter 2’s toolkit) applied to two parallel dense / conv paths. One shape-only primitive, .highwayBlock ic oc, bundling the two paths and the gate-blend (Srivastava, Greff, Schmidhuber 2015, arXiv:1505.00387). Variants: canonical 50-layer Highway-MLP, deeper 100-layer fixture, tiny.

A two-headed network: a main path \(H(x)\) (any transformation, typically dense \(+\) ReLU) and a transform gate \(T(x)\) (its own dense \(+\) sigmoid, output in \([0, 1]\)). The output blends the two:

\[ y = T(x) \cdot H(x) + \bigl(1 - T(x)\bigr) \cdot x. \]

\(T\) learned per element gives the network a continuous knob between “pass through” (\(T \to 0\)) and “transform fully” (\(T \to 1\)). The gradient through the blend is exactly product rule + additive fan-in — both already in Chapter 2’s VJP kit — so no new derivation is needed.
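A scalar Lean sketch of the blend and of the product-rule gradients it yields. `HighwayGrads` and both function names are hypothetical; the book's VJP proofs work on tensors, so this only illustrates the per-element calculus.

```lean
-- Scalar highway blend: y = t * h + (1 - t) * x, with t in [0, 1].
def highwayBlend (t h x : Float) : Float :=
  t * h + (1 - t) * x

-- Hypothetical container for the three gradients of one element.
structure HighwayGrads where
  dT : Float   -- gradient flowing into the gate output T(x)
  dH : Float   -- gradient flowing into the main path H(x)
  dX : Float   -- gradient on the direct carry path only
  deriving Repr

-- Product rule on each term of the blend, given the upstream gradient dy.
def highwayBackward (t h x dy : Float) : HighwayGrads :=
  { dT := dy * (h - x), dH := dy * t, dX := dy * (1 - t) }

#eval highwayBlend 0.0 2.0 5.0          -- carry: returns x
#eval highwayBlend 1.0 2.0 5.0          -- transform: returns h
#eval highwayBackward 0.25 2.0 5.0 1.0  -- dT = -3, dH = 0.25, dX = 0.75
```

The full gradient with respect to \(x\) adds dX to whatever flows back through \(T\) and \(H\), which is the additive fan-in step.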

def highwayMain : NetSpec where
  name   := "Highway block — main path H(x)"
  imageH := 1
  imageW := 1
  layers := [
    .dense 50 50 .relu                           -- H(x)
  ]

def highwayGate : NetSpec where
  name   := "Highway block — transform gate T(x)"
  imageH := 1
  imageW := 1
  layers := [
    .dense 50 50 .identity                       -- T(x), sigmoid applied at blend
  ]

The pedagogical hook: ResNet (Chapter 6) is the special case where \(T\) is fixed at \(1\) on the transform side and the carry path becomes a constant identity. Highway showed in 2015 that some kind of bypass made very deep networks trainable; ResNet showed six months later that the gate could be a constant and you’d save the parameters. ResNet won by simplifying away \(T\), but the underlying calculus is identical — product + additive fan-in either way.

ResNet (Bestiary/ResNet.lean)

One new primitive (.bottleneckBlock). Chapter 6 walks through the basic-block variant ResNet-34 in detail; this entry catalogs the standard ResNet family (He et al. 2015, arXiv:1512.03385). The basic-block variants (R18, R34) stack two \(3\times 3\) convs per residual unit; the bottleneck variants (R50, R101, R152) use \(1\times 1\) reduce \(\to \) \(3\times 3\) work \(\to \) \(1\times 1\) expand around a residual skip — three convs per block instead of two, but the expensive \(3\times 3\) now runs at \(1/4\) the input channels, so per-block parameter count actually drops.
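The parameter claim can be checked with a few lines of Lean. This sketch counts conv weights only (no BN, no biases, no projection shortcut), and the helper names are illustrative.

```lean
-- Conv weights only; batch-norm and biases are ignored.
def convW (k ic oc : Nat) : Nat := k * k * ic * oc

-- Basic block at width c: two 3×3 convs.
def basicBlockW (c : Nat) : Nat := 2 * convW 3 c c

-- Bottleneck block with output width c and inner width c / 4:
-- 1×1 reduce, 3×3 at reduced width, 1×1 expand.
def bottleneckW (c : Nat) : Nat :=
  let m := c / 4
  convW 1 c m + convW 3 m m + convW 1 m c

#eval basicBlockW 64     -- 73728
#eval basicBlockW 256    -- 1179648
#eval bottleneckW 256    -- 69632
```

A 256-wide bottleneck block is slightly smaller than a 64-wide basic block, and about 17 times smaller than a basic block at the same 256-channel width.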

Name          Block         Stage block counts    Params
ResNet-18     basic         (2, 2, 2, 2)          11.7M (paper 11.7M)
ResNet-34     basic         (3, 4, 6, 3)          21.3M (paper 21.8M)
ResNet-50     bottleneck    (3, 4, 6, 3)          25.6M (paper 25.6M)
ResNet-101    bottleneck    (3, 4, 23, 3)         44.5M (paper 44.5M)
ResNet-152    bottleneck    (3, 8, 36, 3)         60.2M (paper 60.2M)

The three bottleneck variants share the same stem (\(7\times 7\) stride-2 conv + \(3\times 3\) stride-2 max pool), the same 4-stage layout (channels \(256 \to 512 \to 1024 \to 2048\)), and the same GAP + single-FC head. They differ only in stage block counts.

def resNet50 : NetSpec where
  name   := "ResNet-50"
  imageH := 224
  imageW := 224
  layers := [
    .convBn 3 64 7 2 .same,
    .maxPool 3 2,
    .bottleneckBlock  64  256 3 1,             -- stage 1: 3 blocks
    .bottleneckBlock 256  512 4 2,             -- stage 2: 4
    .bottleneckBlock 512 1024 6 2,             -- stage 3: 6
    .bottleneckBlock 1024 2048 3 2,            -- stage 4: 3
    .globalAvgPool,
    .dense 2048 1000 .identity
  ]

The same residual chain rule that’s proven for ResNet-34 in Chapter 6 applies to bottleneck-block ResNets — the bottleneck just composes three convolutions instead of two inside the residual branch. No new math.

WRN (Wide ResNet) (Bestiary/WRN.lean)

Zero new primitives: the same .residualBlock as Chapter 6, just widened. Zagoruyko & Komodakis 2016 (arXiv:1605.07146) take the ResNet block and turn the depth knob in (40 layers down from 1000+) and the width knob out (multiply every channel count by \(k\), typically 2–10). Result: WRN-28-10 beats ResNet-1001 on CIFAR-10/100 with \(\sim \)36\(\times \) fewer layers and faster training, although it spends more parameters to get there (36.5M against roughly 10M). On GPUs, width turned out to be a more efficient axis to scale than depth.

Name         Depth    \(k\)    Params
WRN-28-10    28       10       36.5M
WRN-40-2     40       2        2.2M
WRN-22-8     22       8        17.2M

All paper-exact. The N-k naming follows the paper: \(N\) is depth (must be \(6n+4\) for some \(n\) to balance the 3-stage design), \(k\) is the widening factor that multiplies the \(16 \to 32 \to 64\) channel sequence to \(16k \to 32k \to 64k\).
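The naming rule translates directly into Lean. The two helpers below are hypothetical and simply restate the \(6n+4\) constraint and the \(16k \to 32k \to 64k\) widths.

```lean
-- Blocks per stage from depth N = 6n + 4; none if N is not of that form.
def wrnBlocksPerStage (depth : Nat) : Option Nat :=
  if depth ≥ 4 ∧ (depth - 4) % 6 = 0 then some ((depth - 4) / 6) else none

-- Stage widths 16k, 32k, 64k.
def wrnWidths (k : Nat) : List Nat := [16 * k, 32 * k, 64 * k]

#eval wrnBlocksPerStage 28   -- some 4
#eval wrnWidths 10           -- [160, 320, 640]
#eval wrnBlocksPerStage 30   -- none
```

For WRN-28-10 this gives four blocks per stage at widths 160, 320, and 640.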

def wrn28_10 : NetSpec where
  name   := "WRN-28-10"
  imageH := 32
  imageW := 32
  layers := [
    .convBn 3 16 3 1 .same,                   -- stem
    .residualBlock 16 160 4 1,                -- stage 1: 4 blocks @ 32×32
    .residualBlock 160 320 4 2,               -- stage 2: 4 blocks @ 16×16
    .residualBlock 320 640 4 2,               -- stage 3: 4 blocks @ 8×8
    .globalAvgPool,
    .dense 640 10 .identity
  ]

The paper also adds dropout between the two convs inside each residual block, which we omit (static-graph scope — no per-step masks). The dropout-free WRN variant is also reported in the paper and lands within \(\sim \)0.3% of the dropout one on CIFAR-10.

SqueezeNet (Bestiary/SqueezeNet.lean)

One new primitive (.fireModule). AlexNet-level accuracy in 1.25M params via the fire module (Iandola et al. 2016, arXiv:1602.07360): squeeze 1\(\times \)1 conv followed by parallel expand 1\(\times \)1 + 3\(\times \)3 convs concatenated. The early efficiency-CNN family alongside MobileNet and ShuffleNet.

Name              Stem                                                  Params
SqueezeNet 1.0    7\(\times \)7 stride 2 (96 ch)                        1.25M
SqueezeNet 1.1    3\(\times \)3 stride 2 (64 ch), earlier downsample    1.24M

Both paper-exact, plus a tiny fixture. SqueezeNet 1.1 is the default for downstream use — same accuracy at \(\sim \)2\(\times \) fewer FLOPs — and is what the first book trained.

def squeezeNet1_1 : NetSpec where
  name   := "SqueezeNet 1.1"
  imageH := 224
  imageW := 224
  layers := [
    .convBn 3 64 3 2 .same,                     -- 3×3 stride-2 stem (vs 1.0's 7×7)
    .maxPool 2 2,
    .fireModule 64  16 64  64,                  -- Fire2  → 128
    .fireModule 128 16 64  64,                  -- Fire3
    .maxPool 2 2,                               -- earlier downsample
    .fireModule 128 32 128 128,                 -- Fire4  → 256
    .fireModule 256 32 128 128,                 -- Fire5
    .maxPool 2 2,
    .fireModule 256 48 192 192,                 -- Fire6  → 384
    .fireModule 384 48 192 192,                 -- Fire7
    .fireModule 384 64 256 256,                 -- Fire8  → 512
    .fireModule 512 64 256 256,                 -- Fire9
    .conv2d 512 1000 1 .same .relu,             -- 1×1 to classes
    .globalAvgPool                              -- GAP emits 1000 logits
  ]
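A rough Lean count for a single fire module. The argument order (ic, squeeze, expand 1\(\times \)1, expand 3\(\times \)3) is inferred from the spec above and should be read as an assumption, as should the helper names.

```lean
-- Output width: the two expand branches are concatenated.
def fireOut (expand1 expand3 : Nat) : Nat := expand1 + expand3

-- Weights plus biases of the three convs inside one fire module.
def fireParams (ic squeeze expand1 expand3 : Nat) : Nat :=
  (ic * squeeze + squeeze) +               -- 1×1 squeeze
  (squeeze * expand1 + expand1) +          -- 1×1 expand
  (9 * squeeze * expand3 + expand3)        -- 3×3 expand

#eval fireOut 64 64             -- 128
#eval fireParams 64 16 64 64    -- 11408
#eval 9 * 64 * 128 + 128        -- 73856, a plain 3×3 conv 64 → 128
```

Fire2 reaches 128 output channels with roughly one sixth of the parameters of a plain 3\(\times \)3 conv of the same input and output width.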

GoogLeNet / Inception v1 (Bestiary/Inception.lean)

One new primitive (§ 82). The parallel-branch architecture that won ImageNet 2014 (Szegedy et al. 2014, arXiv:1409.4842). Three durable contributions, each still standard kit:

  • 1\(\times \)1 conv as dimension reducer — placed before the expensive 3\(\times \)3 / 5\(\times \)5 branches, it keeps parallel-branch FLOPs tractable. This is the move every later efficiency-CNN family (MobileNet, EfficientNet, ResNet bottleneck, ShuffleNet) builds on.

  • Global average pool replaces the FC head (idea cited from Network-in-Network 2013, popularized here). GoogLeNet’s classifier is one globalAvgPool \(\to \) dense, not VGG’s three-stack of 4096-wide FCs — the same head pattern Chapter 6 onward inherits.

  • Auxiliary classifiers branched off the middle of the network at training time, providing intermediate gradient signal. A precursor to deep supervision; the bestiary omits them for spec simplicity.

GoogLeNet ships at 7M parameters — \(\sim \)\(20\times \) smaller than VGG-16 at competitive accuracy on ImageNet 2014.

def googLeNet : NetSpec where
  name   := "GoogLeNet (Inception v1)"
  imageH := 224
  imageW := 224
  layers := [
    -- Stem
    .convBn 3 64 7 2 .same,
    .maxPool 2 2,                               -- 56×56
    .convBn 64 64 1 1 .same,
    .convBn 64 192 3 1 .same,
    .maxPool 2 2,                               -- 28×28
    -- Inception 3a, 3b
    .inceptionModule 192 64 96 128 16 32 32,    -- → 256
    .inceptionModule 256 128 128 192 32 96 64,  -- → 480
    .maxPool 2 2,                               -- 14×14
    -- Inception 4a..4e
    .inceptionModule 480 192 96  208 16 48  64,  -- 512
    .inceptionModule 512 160 112 224 24 64  64,  -- 512
    .inceptionModule 512 128 128 256 24 64  64,  -- 512
    .inceptionModule 512 112 144 288 32 64  64,  -- 528
    .inceptionModule 528 256 160 320 32 128 128, -- 832
    .maxPool 2 2,                                -- 7×7
    -- Inception 5a, 5b
    .inceptionModule 832 256 160 320 32 128 128, -- 832
    .inceptionModule 832 384 192 384 48 128 128, -- 1024
    .globalAvgPool,
    .dense 1024 1000 .identity
  ]
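Since every module's input width must equal the previous module's concatenated output, the spec can be sanity-checked mechanically. A sketch in Lean, with a hypothetical `Branches` record holding only the fields that determine widths (the reducer widths are dropped because they do not affect the output):

```lean
-- Input width plus the four branch output widths of one Inception module.
structure Branches where
  ic : Nat
  b1 : Nat
  b2 : Nat
  b3 : Nat
  b4 : Nat

def Branches.out (m : Branches) : Nat := m.b1 + m.b2 + m.b3 + m.b4

-- True when each module's output width equals the next module's input width.
def chainOk (ms : List Branches) : Bool :=
  (ms.zip (ms.drop 1)).all fun (a, b) => a.out == b.ic

def googLeNetModules : List Branches := [
  ⟨192, 64, 128, 32, 32⟩, ⟨256, 128, 192, 96, 64⟩,
  ⟨480, 192, 208, 48, 64⟩, ⟨512, 160, 224, 64, 64⟩,
  ⟨512, 128, 256, 64, 64⟩, ⟨512, 112, 288, 64, 64⟩,
  ⟨528, 256, 320, 128, 128⟩,
  ⟨832, 256, 320, 128, 128⟩, ⟨832, 384, 384, 128, 128⟩
]

#eval chainOk googLeNetModules          -- true
#eval googLeNetModules.map (·.out)      -- [256, 480, 512, 512, 512, 528, 832, 832, 1024]
```

Max-pool layers between stages change spatial size only, so they can be skipped in a channel-width check like this one.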

Inception v3 / v4 (Bestiary/Inception.lean)

Zero new primitives — reuses § 82. Inception v3 (Szegedy et al. 2015, arXiv:1512.00567) and v4 (2016, arXiv:1602.07261) are recipe-and-refinement papers more than architecture papers. Three contributions worth pulling out from v3:

  • Label smoothing: the cross-entropy target is \(0.9\) on the true class and \(0.1 / (\mathrm{nClasses} - 1)\) elsewhere instead of a hard one-hot. Prevents the softmax from collapsing all probability onto the argmax, which improves calibration and ImageNet accuracy. Now standard in every modern recipe (DeiT, ConvNeXt, our own TrainConfig.labelSmoothing).

  • Asymmetric factorizations: a 7\(\times \)1 followed by a 1\(\times \)7 conv replaces a single 7\(\times \)7 with the same receptive field at \(\sim \)\(1/3.5\) the FLOPs. Variant blocks include 1\(\times \)3 + 3\(\times \)1 pairs.

  • BN-Aux on the auxiliary classifier branch — the auxiliary classifier of v1 was rebranded as a regularizer rather than a gradient-injection mechanism.
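The first two items reduce to simple arithmetic, sketched here in Lean. `smoothedTarget` is a hypothetical helper that follows the 0.9 and \(0.1 / (\mathrm{nClasses} - 1)\) recipe stated above; it is not the book's TrainConfig implementation.

```lean
-- Smoothed target: 0.9 on the true class, 0.1 / (nClasses - 1) elsewhere.
def smoothedTarget (nClasses trueIdx : Nat) : List Float :=
  let off : Float := 0.1 / Nat.toFloat (nClasses - 1)
  (List.range nClasses).map fun i => if i == trueIdx then 0.9 else off

#eval smoothedTarget 5 2                            -- 0.9 at index 2, 0.025 elsewhere
#eval (smoothedTarget 1000 7).foldl (· + ·) 0.0     -- sums to 1 up to rounding

-- Kernel taps per output position: a 7×1 plus a 1×7 conv versus one 7×7.
#eval (7 + 7, 7 * 7)                                -- (14, 49), a 3.5× reduction
```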

Inception v4 layers more refinement on top: a wider initial stem (stem-A / stem-B / stem-C blocks), three Inception module variants specialized per spatial resolution, and a residual variant (Inception-ResNet) tested in the same paper. Variants here: Inception-v3 (23.8M, paper 23M), Inception-v4 (33M, paper 42M — a bit low because our unified .inceptionModule primitive approximates v3/v4’s richer module catalog).

Xception (Bestiary/Xception.lean)

One new primitive (.separableConv). “Extreme Inception” (Chollet 2017, arXiv:1610.02357): every conv is a depthwise-separable conv. The design choice that made MobileNet possible a year later. Variants: Xception (21.9M, paper-exact at 22M), tiny fixture. Residual skips around each block-of-three-sep-convs are implicit in the linear NetSpec.

def xception : NetSpec where
  name   := "Xception"
  imageH := 299
  imageW := 299
  layers := [
    -- Entry flow
    .convBn 3 32 3 2 .same,
    .convBn 32 64 3 1 .same,
    .separableConv 64 128 1, .separableConv 128 128 1,
    .maxPool 2 2,
    .separableConv 128 256 1, .separableConv 256 256 1,
    .maxPool 2 2,
    .separableConv 256 728 1, .separableConv 728 728 1,
    .maxPool 2 2,
    -- Middle flow: 8× (3 sep-convs + implicit residual), linearized
    .separableConv 728 728 1, .separableConv 728 728 1, .separableConv 728 728 1,
    .separableConv 728 728 1, .separableConv 728 728 1, .separableConv 728 728 1,
    .separableConv 728 728 1, .separableConv 728 728 1, .separableConv 728 728 1,
    .separableConv 728 728 1, .separableConv 728 728 1, .separableConv 728 728 1,
    .separableConv 728 728 1, .separableConv 728 728 1, .separableConv 728 728 1,
    .separableConv 728 728 1, .separableConv 728 728 1, .separableConv 728 728 1,
    .separableConv 728 728 1, .separableConv 728 728 1, .separableConv 728 728 1,
    .separableConv 728 728 1, .separableConv 728 728 1, .separableConv 728 728 1,
    -- Exit flow
    .separableConv 728 728  1,
    .separableConv 728 1024 1,
    .maxPool 2 2,
    .separableConv 1024 1536 1,
    .separableConv 1536 2048 1,
    .globalAvgPool,
    .dense 2048 1000 .identity
  ]

DenseNet (Bestiary/DenseNet.lean)

Two new primitives (.denseBlock, .transitionLayer). DenseNet (Huang et al. 2017, arXiv:1608.06993) takes the ResNet bypass idea to its concatenative extreme: every layer in a “dense block” reads the channel-wise concatenation of all preceding layers’ outputs. Each layer adds only growthRate new channels (32 in the paper), so the per-layer parameter count stays small while feature reuse compounds.

Name            Block counts       Params
DenseNet-121    (6, 12, 24, 16)    7.0M (paper 6.8M)
DenseNet-169    (6, 12, 32, 32)    12.5M (paper 12.5M)
DenseNet-201    (6, 12, 48, 32)    18.1M (paper 18.6M)

All within \(\sim \)3% of the paper. Plus a tiny fixture for testing.

The dense connectivity pattern doesn’t fit a linear NetSpec at the sub-layer granularity (each sub-layer reads a growing concatenation of its predecessors), so the block is bundled as a shape-only primitive Layer.denseBlock ic growthRate nLayers alongside Layer.transitionLayer ic oc (BN + 1\(\times \)1 conv + 2\(\times \)2 avg-pool, which halves channels and spatial size between blocks). This is the same modeling trick we use for mambaBlock, swinStage, and evoformerBlock: the architectural novelty is inside the bundled op.

def denseNet121 : NetSpec where
  name   := "DenseNet-121"
  imageH := 224
  imageW := 224
  layers := [
    .convBn 3 64 7 2 .same,                       -- stem 7×7 stride 2
    .maxPool 3 2,                                  -- 3×3 stride 2 → 56×56
    .denseBlock 64 32 6,                           -- → 256 channels
    .transitionLayer 256 128,                      -- → 128 @ 28×28
    .denseBlock 128 32 12,                         -- → 512 channels
    .transitionLayer 512 256,                      -- → 256 @ 14×14
    .denseBlock 256 32 24,                         -- → 1024 channels
    .transitionLayer 1024 512,                     -- → 512 @ 7×7
    .denseBlock 512 32 16,                         -- → 1024 channels
    .globalAvgPool,
    .dense 1024 10 .identity
  ]

The pedagogical hook: DenseNet sits next to ResNet historically (2017 vs 2015) but architecturally generalizes the bypass primitive — ResNet’s \(y = f(x) + x\) becomes \(y_i = f_i([y_0, \ldots , y_{i-1}])\). The cost: activation memory grows quadratically in block depth from the concatenation. Modern CNN architectures (EfficientNet, ConvNeXt) took the per-layer-parameter-efficiency lesson but dropped the all-pairs concatenation in favor of standard residual paths.
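The quadratic growth can be made concrete by summing the widths that the sub-layers of one block read. A sketch in Lean, assuming sub-layer \(i\) reads ic \(+\) \(i \cdot \) growthRate channels as in the dense-block definition (function names are illustrative):

```lean
-- Total channels read across all sub-layers of one dense block.
def denseReadTotal (ic growthRate nLayers : Nat) : Nat :=
  (List.range nLayers).foldl (fun acc i => acc + (ic + i * growthRate)) 0

-- Closed form: nLayers * ic + growthRate * nLayers * (nLayers - 1) / 2.
def denseReadClosed (ic growthRate nLayers : Nat) : Nat :=
  nLayers * ic + growthRate * (nLayers * (nLayers - 1) / 2)

-- With ic = 0 only the quadratic term remains.
#eval denseReadTotal 0 32 6    -- 480
#eval denseReadTotal 0 32 12   -- 2112
#eval denseReadTotal 0 32 24   -- 8832
#eval denseReadTotal 64 32 6 == denseReadClosed 64 32 6   -- true
```

Doubling the block depth roughly quadruples the concatenated channels that have to be kept around, which is the memory cost the paragraph above refers to.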

ShuffleNet (Bestiary/ShuffleNet.lean)

One new primitive (§ 85). Grouped 1\(\times \)1 convs + channel shuffle (Zhang et al. 2017, arXiv:1707.01083). Variants: ShuffleNet 0.5\(\times \) / 1.0\(\times \) / 2.0\(\times \) at \(g=3\), plus a tiny fixture.

def shuffleNet1x : NetSpec where
  name   := "ShuffleNet 1.0× (g=3)"
  imageH := 224
  imageW := 224
  layers := [
    .convBn 3 24 3 2 .same,                     -- stem
    .maxPool 2 2,
    .shuffleBlock 24  240 3 4,                  -- stage 2
    .shuffleBlock 240 480 3 8,                  -- stage 3
    .shuffleBlock 480 960 3 4,                  -- stage 4
    .globalAvgPool,
    .dense 960 1000 .identity
  ]

ShuffleNet v2 (Bestiary/ShuffleNetV2.lean)

One new primitive (§ 86). The efficient-CNN paper (Ma et al. 2018, arXiv:1807.11164) that called out FLOPs as a bad latency proxy: measured memory-access cost (MAC) directly and derived four practical guidelines (equal channel widths, avoid grouped convs, avoid fragmentation, avoid element-wise ops). v2’s architecture throws out everything in v1 that violated those rules — no grouped 1\(\times \)1 convs, no skip-add — in favor of channel-split + identity + concat + channel-shuffle. Variants: 0.5\(\times \) (1.37M, paper 1.4M), 1.0\(\times \) (2.29M, paper 2.3M), 1.5\(\times \) (3.52M, paper 3.5M), 2.0\(\times \) (7.41M, paper-exact), tiny fixture. All widths within 2% of paper.

def shuffleV2_1_0 : NetSpec where
  name   := "ShuffleNet v2 1.0×"
  imageH := 224
  imageW := 224
  layers := [
    .convBn 3 24 3 2 .same,                     -- stem
    .maxPool 2 2,
    .shuffleV2Block  24 116 4,                  -- stage 2
    .shuffleV2Block 116 232 8,                  -- stage 3
    .shuffleV2Block 232 464 4,                  -- stage 4
    .conv2d 464 1024 1 .same .relu,             -- 1×1 head conv
    .globalAvgPool,
    .dense 1024 1000 .identity
  ]
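The equal-width guideline can be checked with the cost model the ShuffleNet v2 paper uses for a 1×1 convolution: FLOPs of h·w·c1·c2 and a memory-access cost of h·w·(c1 + c2) + c1·c2. The sketch below is illustrative only; the function names and the 56×56 example are assumptions, not repo code.

```lean
/-- Memory-access cost of a 1×1 conv on an h×w map: read input, write output, read weights. -/
def mac1x1 (h w c1 c2 : Nat) : Nat :=
  h * w * (c1 + c2) + c1 * c2

/-- Multiply-accumulate count of the same layer. -/
def flops1x1 (h w c1 c2 : Nat) : Nat :=
  h * w * c1 * c2

#eval flops1x1 56 56 128 128 == flops1x1 56 56 64 256  -- true: identical FLOPs
#eval mac1x1 56 56 128 128                              -- 819200
#eval mac1x1 56 56 64 256                               -- 1019904
```

At the same FLOP budget the unbalanced layer moves roughly a quarter more memory, which is the kind of gap a FLOPs-only comparison hides.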

MobileViT (Bestiary/MobileViT.lean)

One new primitive (§ 87); reuses existing invertedResidual for the MV2 stages. Hybrid mobile backbone (Mehta & Rastegari 2022, arXiv:2110.02178): MobileNet V2 body with MobileViT blocks replacing some of the deeper inverted-residual stages. Variants: MobileViT-S (5.6M, paper-exact), XS (2.3M), XXS (1.3M), tiny fixture.

def mobileViTS : NetSpec where
  name   := "MobileViT-S"
  imageH := 256
  imageW := 256
  layers := [
    .convBn 3 16 3 2 .same,                     -- stem
    .invertedResidual 16 32 4 1 1,              -- stage 1: MV2
    .invertedResidual 32 64 4 2 3,              -- stage 2: 3× MV2 + downsample
    -- Stage 3: MV2 downsample + MobileViT (d=144, L=2)
    .invertedResidual 64 96 4 2 1,
    .mobileVitBlock 96 144 4 288 2,
    -- Stage 4: MV2 downsample + MobileViT (d=192, L=4)
    .invertedResidual 96 128 4 2 1,
    .mobileVitBlock 128 192 4 384 4,
    -- Stage 5: MV2 downsample + MobileViT (d=240, L=3)
    .invertedResidual 128 160 4 2 1,
    .mobileVitBlock 160 240 4 480 3,
    .conv2d 160 640 1 .same .relu,              -- 1×1 expansion
    .globalAvgPool,
    .dense 640 1000 .identity
  ]

Swin Transformer (Bestiary/SwinT.lean)

Two new primitives (§ 88, § 89). Variants: Swin-T / -S / -B / tiny. Swin-T lands at 28M params matching the paper (Liu et al. 2021, arXiv:2103.14030) exactly.

def swinT : NetSpec where
  name   := "Swin-T"
  imageH := 224
  imageW := 224
  layers := [
    .patchEmbed 3 96 4 3136,                    -- 56×56 patches, dim 96
    .swinStage 96 3 384 7 2,                    -- stage 1: 2 blocks
    .patchMerging 96 192,                       -- 56×56×96 → 28×28×192
    .swinStage 192 6 768 7 2,                   -- stage 2: 2 blocks
    .patchMerging 192 384,                      -- 28×28×192 → 14×14×384
    .swinStage 384 12 1536 7 6,                 -- stage 3: 6 blocks
    .patchMerging 384 768,                      -- 14×14×384 → 7×7×768
    .swinStage 768 24 3072 7 2,                 -- stage 4: 2 blocks
    .globalAvgPool,
    .dense 768 1000 .identity
  ]
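The shape comments in the listing follow from two rules: each `.patchMerging` halves the side and doubles the channel count, and window attention tiles the map with fixed-size windows. A small sketch, assuming the fourth `.swinStage` argument is the window side (7) and using hypothetical helper names:

```lean
/-- (side, channels) of the feature map after 0, 1, ..., `merges` patch-merging steps. -/
def swinStages (side dim merges : Nat) : List (Nat × Nat) :=
  (List.range (merges + 1)).map fun i => (side / 2 ^ i, dim * 2 ^ i)

/-- Number of non-overlapping win×win attention windows on a side×side map. -/
def numWindows (side win : Nat) : Nat := (side / win) * (side / win)

#eval swinStages 56 96 3
-- [(56, 96), (28, 192), (14, 384), (7, 768)]
#eval (swinStages 56 96 3).map fun (s, _) => numWindows s 7
-- [64, 16, 4, 1]
```

By stage 4 a single 7×7 window covers the whole map, so windowed attention there is effectively global.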

11.2.2 Object detection

Localize and classify. Detection heads are where the linear NetSpec shape starts to creak: multi-scale FPN outputs need a graph, not a list. The bestiary entries show the single-scale view and defer the multi-head refactor to the limitations discussion.

YOLO on VOC persons, in passing.

[TODO: actually get this working. The objectness loss still needs the focal rebalance described in the third bullet below before the rendered boxes are clean. The trajectory and numbers in this paragraph are the target, not a measured result yet; see planning/yolo_v5.md for the current state.] The demos/MainYolov1VocTrainBootstrap.lean trainer is a single-class person detector: a ResNet-34 backbone (bootstrapped from the repo’s ImageNet ResNet-34 checkpoint, jax_r34_imagenet.bin, \(\sim \)21M backbone params) with a convolutional detection head, trained on Pascal VOC 2007 filtered to the person class (\(224 \times 224\) RGB \(\to \) a \(7 \times 7\) prediction grid). After training, it draws boxes that track people on held-out validation images (see demos/README.md for the rendered grid). Three things change relative to the image classifiers of Part 1 and the segmentation nets below:

  • The label is a variable-length set of boxes, baked onto a grid. A classifier compares one logit vector to one class; a segmentation net compares per-pixel logits to a per-pixel mask. Detection has a variable number of objects per image, each a (class, box). YOLO fixes the shape by tiling the image into a \(7 \times 7\) grid and making the cell containing an object’s center responsible for it: the target tensor carries, per cell, the box center/size, an objectness \(=1\), and a class slot, with zeros everywhere else. The loss is multi-part (coordinate regression + objectness + class), masked to the object cells, riding the lossKind := .yolov1Masked path in TrainConfig. No new VJP machinery is needed, just the Part 1 squared-error and cross-entropy backward passes fired on slices of the grid.

  • The head is convolutional, and that is load-bearing. The paper flattens the final feature map into two fully-connected layers (shown in the catalog entry below). This demo uses a \(1 \times 1\) convolution as the head instead (\(512 \to \) per-cell channels over the \(7 \times 7\) map). The reason is both pedagogical and practical: an FC head must learn the spatial correspondence “these features \(\to \) a box there” through a dense layer, and on a small dataset it collapses to predicting the average box in the image center; the conv head predicts each cell from its own feature column with shared weights, so localization is structural rather than learned. (This is precisely the v1\(\to \)v2 change.)

  • Class and foreground/background imbalance shape the loss. Two imbalances bite. The first is why this demo is single-class: on full 20-class VOC the network collapses to predicting the dominant “person” label everywhere, so we narrow to person-only and let the box head carry the demo. The second is intrinsic — most of the \(49\) cells are empty, so a naive objectness term drives every prediction to “no object”; a focal-weighted objectness loss down-weights the easy background cells so the few foreground cells survive. Inference then adds a post-processing step the classifiers never needed: decode each cell’s box (cell-relative \(\to \) absolute), score it (objectness \(\times \) class), threshold, and run non-max suppression to collapse overlapping boxes (scripts/yolo_render.py).
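The "responsible cell" rule from the first bullet is plain integer arithmetic. The sketch below assumes pixel-coordinate box centers and a row-major target tensor; both helper names are hypothetical and do not come from the trainer.

```lean
/-- Grid cell (row, col) responsible for a box whose center is at pixel (cx, cy).
    Assumes 0 ≤ cx, cy < imgSize. -/
def responsibleCell (imgSize grid cx cy : Nat) : Nat × Nat :=
  (cy * grid / imgSize, cx * grid / imgSize)

/-- Flat index of that cell in a row-major grid × grid target tensor. -/
def cellIndex (grid : Nat) (cell : Nat × Nat) : Nat :=
  cell.1 * grid + cell.2

#eval responsibleCell 224 7 112 50                -- (1, 3)
#eval cellIndex 7 (responsibleCell 224 7 112 50)  -- 10
```

Only that one cell out of 49 gets a nonzero target for this box, which is the foreground/background imbalance the third bullet describes.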

The catalog entry below keeps the paper-exact YOLOv1 shape (\(448\) input, FC head, all 20 classes) for the historical record; the working demo deliberately narrows to the simplest thing that detects — one class, one scale, conv head — and defers anchors and the multi-scale FPN heads (which don’t fit the linear NetSpec) to the limitations discussion.

YOLO v1/v3/v5/v8/v11 (Bestiary/YOLO.lean)

v1 (Redmon et al. 2016, arXiv:1506.02640) uses zero new primitives (conv2d + maxPool + flatten + dense; the YOLO-ness lives in the loss and output reshape, not the architecture). v3 adds § 90 for the Darknet-53 body. v5/v8/v11 use § 91 for their CSPDarknet backbones. Variants: YOLOv1 (271M, paper-exact), fast, tiny, YOLOv3 (40M backbone), YOLOv5s/m (3M/8M single-scale), YOLOv8n/s (0.7M/2.9M single-scale), YOLOv11n/m (backbone only). Multi-scale FPN detection heads don’t linearize — same skip-connection problem as UNet; all entries show a single-scale view.

def yolo : NetSpec where
  name   := "YOLOv1 (Redmon 2016)"
  imageH := 448
  imageW := 448
  layers := [
    .conv2d 3 64 7 .same .relu,
    .maxPool 2 2,
    .conv2d 64 192 3 .same .relu,
    .maxPool 2 2,
    -- Block 3: 1×1 reduce + 3×3 expand
    .conv2d 192 128 1 .same .relu,
    .conv2d 128 256 3 .same .relu,
    .conv2d 256 256 1 .same .relu,
    .conv2d 256 512 3 .same .relu,
    .maxPool 2 2,
    -- Block 4: four 1×1+3×3 pairs, then 1×1 + 3×3
    .conv2d 512 256 1 .same .relu, .conv2d 256 512 3 .same .relu,
    .conv2d 512 256 1 .same .relu, .conv2d 256 512 3 .same .relu,
    .conv2d 512 256 1 .same .relu, .conv2d 256 512 3 .same .relu,
    .conv2d 512 256 1 .same .relu, .conv2d 256 512 3 .same .relu,
    .conv2d 512 512 1 .same .relu,
    .conv2d 512 1024 3 .same .relu,
    .maxPool 2 2,
    -- Block 5: two 1×1+3×3 pairs + two more 3×3 convs
    .conv2d 1024 512 1 .same .relu, .conv2d 512 1024 3 .same .relu,
    .conv2d 1024 512 1 .same .relu, .conv2d 512 1024 3 .same .relu,
    .conv2d 1024 1024 3 .same .relu,
    .conv2d 1024 1024 3 .same .relu,
    -- Block 6: two 3×3 at 7×7
    .conv2d 1024 1024 3 .same .relu,
    .conv2d 1024 1024 3 .same .relu,
    -- Head: flatten + 2 FC → reshape to (7, 7, 2·5 + 20) = 1470
    .flatten,
    .dense (7 * 7 * 1024) 4096 .relu,
    .dense 4096 (7 * 7 * (2 * 5 + 20)) .identity
  ]
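A quick weight count shows where the 271M parameters sit. The sketch below counts weights only (biases ignored, which is an assumption about how the quoted total was produced) and uses hypothetical helper names.

```lean
/-- Weights in a dense layer (biases ignored). -/
def yoloDenseW (i o : Nat) : Nat := i * o

/-- The two fully-connected head layers from the listing above. -/
def yoloHeadW : Nat :=
  yoloDenseW (7 * 7 * 1024) 4096 + yoloDenseW 4096 (7 * 7 * (2 * 5 + 20))

#eval yoloDenseW (7 * 7 * 1024) 4096  -- 205520896
#eval yoloHeadW                        -- 211542016
```

Roughly 211M of the total lives in the two dense layers, so the 24 convolutions account for only about 60M.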

Mask R-CNN (Bestiary/MaskRCNN.lean)

Two new primitives (.bottleneckBlock + § 92). The canonical two-stage detector + instance segmentation reference (He et al. 2017, arXiv:1703.06870). Five architectural pieces: ResNet-FPN backbone, RPN for anchor-box proposals, ROI-Align for per-proposal feature extraction (an orchestration step, not a layer), a 2-layer FC box head for classification + bbox regression, and a 4-conv + transposed-conv mask head for per-class \(28 \times 28\) masks. Shown as separate NetSpecs per head (SAM-style decomposition). Param totals: backbone+FPN 45.8M, box head 14.3M, mask head 2.6M, RPN 0.6M — \(\sim \)63M in total, matching paper. The FPN cross-scale add is the one thing that demanded a new bundled primitive; everything else reuses existing ResNet / conv / dense primitives. DETR is its end-to-end transformer-era cousin; Mask2Former is the DETR-style instance-segmentation successor.

def maskRCNNBackboneR101 : NetSpec where
  name   := "Mask R-CNN backbone (ResNet-101-FPN)"
  imageH := 800
  imageW := 800
  layers := [
    -- ResNet-101 stem + body (C2..C5)
    .convBn 3 64 7 2 .same,
    .maxPool 3 2,
    .bottleneckBlock 64   256  3  1,            -- C2
    .bottleneckBlock 256  512  4  2,            -- C3
    .bottleneckBlock 512  1024 23 2,            -- C4
    .bottleneckBlock 1024 2048 3  2,            -- C5
    -- FPN: combines C2–C5 into a 4-level feature pyramid at 256 channels
    .fpnModule 256 512 1024 2048 256
  ]
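To make the pyramid concrete, the following sketch lists the spatial side and channel count of each backbone level that `.fpnModule` consumes. It assumes the usual total strides of 4, 8, 16, and 32 for C2 to C5 (stem stride 2, pool stride 2, then the per-stage strides in the listing); the helper names are hypothetical.

```lean
/-- Spatial side of a pyramid level with the given total stride. -/
def levelSide (img stride : Nat) : Nat := img / stride

/-- (name, side, backbone channels) for C2..C5 before the FPN projects each to 256 channels. -/
def pyramid (img : Nat) : List (String × Nat × Nat) :=
  [("C2", levelSide img 4, 256), ("C3", levelSide img 8, 512),
   ("C4", levelSide img 16, 1024), ("C5", levelSide img 32, 2048)]

#eval pyramid 800
-- [("C2", 200, 256), ("C3", 100, 512), ("C4", 50, 1024), ("C5", 25, 2048)]
```

Four maps of different sizes have to be read at once, which is why the cross-scale add cannot be written as one more element of a linear layer list.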

DETR (Bestiary/DETR.lean)

Two new primitives (§ 93, § 94; Carion et al. 2020, arXiv:2005.12872); reuses existing bottleneckBlock (ResNet backbone), patchEmbed (absorbs the DETR 1\(\times \)1 channel reduce), and transformerEncoder. Variants: DETR-R50 (41M, paper-exact), DETR-R101 (60M), tiny.

def detrR50 : NetSpec where
  name   := "DETR-R50"
  imageH := 800
  imageW := 800
  layers := [
    -- ResNet-50 backbone (stem + 4 stages)
    .convBn 3 64 7 2 .same,
    .maxPool 3 2,
    .bottleneckBlock 64 256 3 1,                -- C2
    .bottleneckBlock 256 512 4 2,               -- C3
    .bottleneckBlock 512 1024 6 2,              -- C4
    .bottleneckBlock 1024 2048 3 2,             -- C5
    -- Channel projection (2048→256) + flatten spatial + pos embed
    .patchEmbed 2048 256 1 625,                 -- 25×25 = 625 tokens
    .transformerEncoder 256 8 2048 6,           -- 6 encoder blocks, 8 heads
    .transformerDecoder 256 8 2048 6 100,       -- 6 decoder blocks, 100 queries
    .detrHeads 256 91                           -- class (→92) + box (→4)
  ]
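The two numbers in the listing's comments can be reproduced directly. This is a sketch with hypothetical helper names; it assumes a total backbone stride of 32 and one extra "no object" class on top of the 91 passed to `.detrHeads`.

```lean
/-- Encoder tokens: one per spatial position of the stride-32 feature map. -/
def detrTokens (img stride : Nat) : Nat := (img / stride) * (img / stride)

/-- Scalars emitted per image: each query predicts class logits plus 4 box coordinates. -/
def detrOutputs (queries classes : Nat) : Nat := queries * ((classes + 1) + 4)

#eval detrTokens 800 32    -- 625
#eval detrOutputs 100 91   -- 9600
```

The output size is fixed by the query count, not by the number of objects in the image; unmatched queries are expected to predict the "no object" class.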

11.2.3 Semantic segmentation

Pixel-level labeling. Symmetric encoder/decoder with skip connections is the recurring pattern; UNet is the canonical instance, and the same shape later became the diffusion-model backbone.

UNet on Pets, in passing.

The demos/MainUnetPetsTrain.lean trainer in this repo is a working encoder-decoder UNet at 7.85M params on Oxford-IIIT Pets (\(224 \times 224\) RGB \(\to \) 3-class trimap: animal, background, boundary). After \(\sim \)50 epochs the predicted masks track ground-truth shapes closely on held-out validation images (see demos/README.md for the rendered grid). The two codegen primitives that make it work, bilinear upsample (forward + VJP) and channel concat, are exposed at the Layer level as .unetDown and .unetUp, which the canonical bestiary UNet below also uses. Three things change relative to the image classifiers of Part 1:

  • The label is a tensor, not a scalar. Classifiers emit one logit vector per image (shape \(B \times K\)) and compare it to one ground-truth class. Segmentation networks emit one logit vector per pixel (shape \(B \times K \times H \times W\), with \(H \times W\) matching the input) and compare against a shape-identical ground-truth mask. Cross-entropy is computed per pixel, averaged spatially. This rides the same useSeg flag in TrainConfig that the NLP subsection’s TinyGPT uses for per-token loss, applied across the spatial axes instead of the token axis — no new VJP machinery; the per-pixel softmax-CE backward is the same one Ch 3 derived, fired \(HW\) times in parallel.

  • The network has to give spatial resolution back. A classifier downsamples aggressively (input \(224^2\), bottleneck \(7^2\), GAP, \(K\)-way logits) and discards spatial structure on purpose. A segmentation network downsamples and then upsamples, with skip connections from each encoder stage spliced into the matching decoder stage so the output recovers pixel-level detail the bottleneck threw away. That symmetry is exactly what .unetUp’s body bundles (bilinear upsample + channel concat with the saved skip + two convs), and it is the only architectural piece this section’s backbones add to the Part 1 kit.

  • Inference is one forward pass; post-processing differs. Segmentation doesn’t autoregress; one network call produces every pixel label simultaneously. What changes is what you do with the logits: argmax per pixel and render as a colored mask. The pets-predict executable is the smallest example of that path — same vmfb the trainer produced, plus a small C colormap routine.
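The post-processing step in the last bullet is small enough to sketch. The code below is an illustration in plain Lean, not the pets-predict source: it represents the logits as one list of \(K\) floats per pixel and takes the argmax of each.

```lean
/-- Scan helper: index of the largest value seen so far; ties keep the earlier index. -/
def argmaxFrom : List Float → Nat → Nat → Float → Nat
  | [], _, best, _ => best
  | x :: rest, i, best, bestVal =>
    if x > bestVal then argmaxFrom rest (i + 1) i x
    else argmaxFrom rest (i + 1) best bestVal

/-- Index of the largest logit (0 for an empty list). -/
def argmax : List Float → Nat
  | [] => 0
  | x :: rest => argmaxFrom rest 1 0 x

/-- Per-pixel class map from per-pixel logit vectors. -/
def predictMask (logits : List (List Float)) : List Nat :=
  logits.map argmax

#eval predictMask [[2.0, 0.1, -1.0], [0.3, 0.2, 1.5]]  -- [0, 2]
```

Rendering then maps each class index to a color, which is all the colormap routine has to do.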

The encoder-decoder-with-skips shape is not specific to segmentation either: it later became the diffusion-model backbone (DDPM and Stable Diffusion, covered in the diffusion subsection further down), and the same two primitives carry that subsection’s generative pipeline. The layer kit is shared; only the loss and the inference procedure change.

UNet (Bestiary/UNet.lean)

Two new primitives (§ 95, § 96). Skip connections are implicit — the \(i\)-th unetUp from the bottom pairs with the \(i\)-th unetDown from the top. Variants: original (1-channel, 2-class), RGB, small, tiny. Original lands at 31M params matching Ronneberger (Ronneberger et al. 2015, arXiv:1505.04597).

def unet : NetSpec where
  name   := "UNet (original, grayscale → 2-class)"
  imageH := 512
  imageW := 512
  layers := [
    .unetDown 1   64,                           -- encoder stage 1
    .unetDown 64  128,                          -- encoder stage 2
    .unetDown 128 256,                          -- encoder stage 3
    .unetDown 256 512,                          -- encoder stage 4
    .convBn 512 1024 3 1 .same,                 -- bottleneck part 1
    .convBn 1024 1024 3 1 .same,                -- bottleneck part 2
    .unetUp 1024 512,                           -- decoder (skip: encoder 4)
    .unetUp 512 256,                            --         (skip: encoder 3)
    .unetUp 256 128,                            --         (skip: encoder 2)
    .unetUp 128 64,                             --         (skip: encoder 1)
    .conv2d 64 2 1 .same .identity              -- 1×1 conv to 2 classes
  ]
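The implicit pairing rule can be written down and checked. The sketch below copies the channel numbers from the listing into plain lists (hypothetical names, not the repo's data structures) and pairs the decoder stages, in listing order, with the encoder stages in reverse.

```lean
/-- Output channels of the encoder stages, top to bottom. -/
def encoderOut : List Nat := [64, 128, 256, 512]

/-- (in, out) channels of the decoder stages, in listing order. -/
def decoderStages : List (Nat × Nat) := [(1024, 512), (512, 256), (256, 128), (128, 64)]

/-- First decoder stage pairs with the deepest encoder stage, and so on upward. -/
def skipPairs : List ((Nat × Nat) × Nat) :=
  decoderStages.zip encoderOut.reverse

-- Each skip carries as many channels as its decoder stage emits.
#eval skipPairs.all fun (up, skip) => up.2 == skip  -- true
```

Under this pairing the concatenation inside each decoder stage sees twice the stage's output width, which matches the stage's declared input width.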

DeepLab v3+ (Bestiary/DeepLabV3Plus.lean)

One new primitive (§ 97). The pre-transformer segmentation workhorse (Chen et al. 2018, arXiv:1802.02611), still deployed widely in remote sensing, medical imaging, and autonomous-driving perception pipelines. Two ideas: atrous (dilated) convolutions in the backbone’s last stage (param-count-free receptive-field expansion) + an ASPP module for dense multi-scale context at the deepest feature resolution. The “+” in v3+ adds a lightweight decoder that upsamples the ASPP output 4\(\times \) and concatenates a low-level skip from backbone stage 2. Variants: ResNet-101 backbone (59M, paper \(\sim \)63M), MobileNet v2 backbone (5.7M, paper \(\sim \)6M mobile variant), tinyDeepLab fixture. The decoder’s low-level skip from stage 2 doesn’t linearize cleanly, so the listing leaves it implicit, the same hack the UNet and WaveNet entries use. SegFormer (next entry) argues “do ASPP’s multi-scale context via a transformer pyramid”; different mechanism, same goal.

def deeplabv3plusResnet101 : NetSpec where
  name   := "DeepLab v3+ (ResNet-101 backbone)"
  imageH := 513
  imageW := 513
  layers := [
    .convBn 3 64 7 2 .same,                     -- stem
    .maxPool 3 2,
    -- ResNet-101 body; stage 4 uses stride 1 (atrous in real impl)
    .bottleneckBlock 64   256  3  1,
    .bottleneckBlock 256  512  4  2,
    .bottleneckBlock 512  1024 23 2,
    .bottleneckBlock 1024 2048 3  1,
    .asppModule 2048 256,                       -- ASPP: 5 branches + fusion
    -- Decoder (skip from stage 2 doesn't linearize; prose notes it)
    .conv2d 256 256 3 .same .relu,
    .conv2d 256 256 3 .same .relu,
    .conv2d 256 21 1 .same .identity            -- Pascal VOC: 21 classes
  ]
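"Param-count-free receptive-field expansion" is easy to see numerically. The sketch below uses the standard effective-kernel formula for a dilated convolution; the rates 6, 12, and 18 are the ones the DeepLab papers use at output stride 16 and are an assumption about what `.asppModule` bundles.

```lean
/-- Effective side of a k×k kernel at dilation rate r: k + (k - 1)(r - 1). -/
def effectiveKernel (k r : Nat) : Nat := k + (k - 1) * (r - 1)

/-- Weights per (input, output) channel pair: independent of the dilation rate. -/
def kernelWeights (k : Nat) : Nat := k * k

#eval [1, 6, 12, 18].map (effectiveKernel 3)  -- [3, 13, 25, 37]
#eval kernelWeights 3                          -- 9
```

The footprint grows from 3 to 37 pixels per side while every branch still stores nine weights per channel pair.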

SegFormer (Bestiary/SegFormer.lean)

One new primitive (§ 89). Semantic segmentation (Xie et al. 2021, arXiv:2105.15203) via a hierarchical transformer pyramid (MiT encoder: 4 stages of .transformerEncoder glued by .patchMerging) and a lightweight MLP decoder (a handful of .dense calls). The decoder stays trivially small across all B0–B5 encoder sizes: the design argument of the paper is precisely that a good pretrained transformer feature pyramid makes the segmentation head cheap. Variants: MiT-B0 encoder (2.6M, paper 3.7M), B2 (19M, paper 25M), B5 (61M, paper 82M), shared MLP decoder (1M single-scale approx), tiny fixture. The undercount is roughly uniform (about 30% for B0, 24–26% for B2 and B5) because real MiT uses depthwise convs inside the FFN and overlapping patch embeddings at inter-stage transitions; our .transformerEncoder and .patchMerging approximate the shape without those details. Comparison to DeepLab v3+: SegFormer’s decoder is \(\sim \)3M params across all sizes, DeepLab’s ASPP module is \(\sim \)15M and needs per-receptive-field tuning.

def segformerB2 : NetSpec where
  name   := "SegFormer MiT-B2 encoder"
  imageH := 224
  imageW := 224
  layers := [
    -- Stage 1: 224 → 56, 64 channels, 3 blocks
    .patchEmbed 3 64 4 (56 * 56),
    .transformerEncoder 64 1 256 3,
    -- Stage 2: 56 → 28, 128 channels, 4 blocks
    .patchMerging 64 128,
    .transformerEncoder 128 2 512 4,
    -- Stage 3: 28 → 14, 320 channels, 6 blocks
    .patchMerging 128 320,
    .transformerEncoder 320 5 1280 6,
    -- Stage 4: 14 → 7, 512 channels, 3 blocks
    .patchMerging 320 512,
    .transformerEncoder 512 8 2048 3
  ]

SAM (Bestiary/SAM.lean)

One new primitive (§ 93). Promptable segmentation (Kirillov et al. 2023, arXiv:2304.02643): a ViT image encoder runs once per image, a tiny prompt encoder tokenizes clicks / boxes / masks, and a lightweight transformer mask decoder cross-attends between image tokens, prompt tokens, and a handful of learned output queries. Image encoder is .patchEmbed + .transformerEncoder (the ViT kit); mask decoder is a small .transformerDecoder (from DETR, 4 queries, 2 blocks). Variants: SAM ViT-B encoder (88M, paper total 91M), ViT-L (307M, paper 308M), ViT-H (635M, paper 636M), shared 3.3M mask decoder, tinySAM fixture. The image encoder accounts for roughly 96–99% of each variant’s parameter budget (lowest for ViT-B, highest for ViT-H), which is why EfficientSAM (Xiong et al. 2023) focused its distillation there.

def samEncoderH : NetSpec where
  name   := "SAM ViT-H image encoder"
  imageH := 1024
  imageW := 1024
  layers := [
    -- 16×16 patches on 1024×1024 → 64×64 = 4096 tokens
    .patchEmbed 3 1280 16 4096,
    .transformerEncoder 1280 16 5120 32          -- 32 blocks, dim 1280
  ]
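The encoder's share of the parameter budget follows from the counts quoted in the entry. The sketch below works in units of 0.1M to stay in integer arithmetic; `encoderShare` is a hypothetical helper, and the figures are the entry's own (88M, 307M, 635M encoders against a 3.3M mask decoder).

```lean
/-- Percentage (rounded down) of parameters in the encoder; inputs in units of 0.1M. -/
def encoderShare (enc dec : Nat) : Nat := enc * 100 / (enc + dec)

/-- Token count for a square image cut into square patches. -/
def vitTokens (img patch : Nat) : Nat := (img / patch) * (img / patch)

#eval [880, 3070, 6350].map (encoderShare · 33)  -- [96, 98, 99]
#eval vitTokens 1024 16                           -- 4096
```

Even for the smallest variant the decoder is a rounding error next to the encoder, so any serious compression effort has to target the ViT.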

11.2.4 Image generation

Networks whose output is a novel image. Two families covered here: variational autoencoders and GANs. Backbones overlap heavily with segmentation (UNet again). Diffusion models, also generative, get their own subsection further down.

VAE (Bestiary/VAE.lean)

One new primitive (§ 96). The classical variational autoencoder (Kingma & Welling 2013, arXiv:1312.6114): encoder outputs \((\mu , \log \sigma ^2)\) (represented as a single tensor with doubled final width), the reparameterization trick \(z = \mu + \sigma \odot \epsilon \) samples \(z\) in training code, the decoder reconstructs. KL divergence between the learned latent distribution and a standard normal is the regularizer. Variants: MNIST MLP VAE (20-dim latent, textbook example), CIFAR conv VAE (\(4 \times 4 \times 4\) spatial latent, SD-style), tiny fixture. All shown as encoder + decoder pairs. Training-code details (the actual sampling step, the KL loss) live outside the NetSpec. The same architectural template scales up to Stable Diffusion’s VAE (see StableDiffusion.lean); 2013’s 20-dim MNIST latent and 2022’s \(64 \times 64 \times 4\) SD latent are the same idea at different scales. VQ-VAE (discrete-codebook variant) and \(\beta \)-VAE (scaled KL) are mentioned in the prose notes.

def mnistVAEEncoder : NetSpec where
  name   := "MNIST VAE encoder (MLP, 20-dim latent)"
  imageH := 28
  imageW := 28
  layers := [
    .flatten,
    .dense 784 400 .relu,
    -- Output is 2×20 = 40: first 20 are μ, last 20 are log σ².
    .dense 400 40 .identity
  ]
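The pieces that live in training code rather than in the NetSpec can be sketched in plain Lean. This is an illustration under stated assumptions: the helper names are hypothetical, the noise vector `eps` is passed in rather than sampled, and the KL term is the standard closed form for a diagonal Gaussian against a standard normal.

```lean
/-- Split the encoder's 2d-wide output into (μ, log σ²). -/
def splitLatent (d : Nat) (out : List Float) : List Float × List Float :=
  (out.take d, (out.drop d).take d)

/-- Reparameterization: z = μ + σ ⊙ ε, with σ = exp(½ · log σ²). -/
def reparam (mu logVar eps : List Float) : List Float :=
  List.zipWith (fun m p => m + Float.exp (0.5 * p.1) * p.2) mu (logVar.zip eps)

/-- KL(N(μ, σ²) ‖ N(0, I)) = ½ · Σ (σ² + μ² - 1 - log σ²). -/
def klToStdNormal (mu logVar : List Float) : Float :=
  let terms := List.zipWith (fun m lv => Float.exp lv + m * m - 1.0 - lv) mu logVar
  0.5 * terms.foldl (· + ·) 0.0

#eval reparam [0.0] [0.0] [1.0]            -- one element: 0 + exp(0) * 1 = 1
#eval klToStdNormal [0.0, 0.0] [0.0, 0.0]  -- 0: the posterior already equals the prior
```

Because the randomness enters only through `eps`, gradients flow through μ and log σ² as ordinary deterministic inputs.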

DCGAN (Bestiary/DCGAN.lean)

Zero new primitives. Deep Convolutional GAN (Radford et al. 2015, arXiv:1511.06434), the paper that made GAN training reliably work. Five design guidelines (strided convs instead of pooling, BN everywhere except \(G\)’s output / \(D\)’s input, no hidden FC layers, ReLU in \(G\) / LeakyReLU in \(D\), Adam with specific hyperparams) that became the default for every GAN paper since. Three NetSpecs: noise projector (dense \(100 \to 4 \times 4 \times 1024\), 1.65M), generator convs (11M, paper \(\sim \)12M), discriminator (11M, paper-exact). Transposed convs in \(G\) are approximated by standard convs of matching kernel and channels (same params, spatial doubling is forward-pass-only); same hack appears in VAE and Stable Diffusion entries. GAN training dynamics (mode collapse, equilibrium stability) are all training-procedure concerns living outside the NetSpec.

def dcganGenerator : NetSpec where
  name   := "DCGAN generator (64×64 RGB)"
  imageH := 4                                     -- starting spatial (post-projection)
  imageW := 4
  layers := [
    -- Each .convBn stands in for a 4×4 transposed conv doubling spatial;
    -- same params (ic × oc × k²), upsampling is a forward-pass detail.
    .convBn 1024 512 4 1 .same,                  --  4×4 →  8×8
    .convBn 512  256 4 1 .same,                  --  8×8 → 16×16
    .convBn 256  128 4 1 .same,                  -- 16×16 → 32×32
    .convBn 128   64 4 1 .same,                  -- 32×32 → 64×64
    .conv2d   64   3 4 .same .identity           -- RGB output (tanh in real)
  ]
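The quoted 11M and 1.65M can be reproduced by hand. The sketch below counts convolution weights only (biases and BN parameters ignored, which is an assumption) and the projector's weights plus biases; the helper names are hypothetical.

```lean
/-- Weights in a k×k conv or transposed conv: ic × oc × k². -/
def ganConvW (ic oc k : Nat) : Nat := ic * oc * k * k

/-- The five generator layers from the listing above. -/
def generatorW : Nat :=
  ganConvW 1024 512 4 + ganConvW 512 256 4 + ganConvW 256 128 4 +
  ganConvW 128 64 4 + ganConvW 64 3 4

/-- Noise projector: dense 100 → 4·4·1024, weights plus biases. -/
def projectorP : Nat := 100 * (4 * 4 * 1024) + 4 * 4 * 1024

#eval generatorW  -- 11144192
#eval projectorP  -- 1654784
```

The first layer alone holds about three quarters of the generator's weights, since both channel counts are at their largest there.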

Pix2Pix (Bestiary/Pix2Pix.lean)

Two new primitives (§ 95, § 96). The paired-data ancestor of CycleGAN, from the same lab about four months earlier (Isola et al. 2017, arXiv:1611.07004). UNet generator (8 levels, \(\sim \)70M approx vs paper \(\sim \)54M; our .unetDown / .unetUp use 2 convs per level where Pix2Pix uses 1 strided conv, so we overcount) + PatchGAN discriminator (identical to CycleGAN’s, 2.8M). Trained with GAN loss + L1 reconstruction: the L1 term is a direct supervision signal that exists only because pairs exist. When you don’t have pairs you fall back to CycleGAN’s cycle-consistency trick. Hardware context: 2016–2017, 54M UNet on \(256 \times 256\) images meant batch size 1 on a GTX 1080 Ti, which is how InstanceNorm became the default normalizer in this lineage: it was what fit.

def pix2pixGenerator : NetSpec where
  name   := "Pix2Pix generator (8-level UNet, 256×256 RGB)"
  imageH := 256
  imageW := 256
  layers := [
    -- Encoder: 8 × unetDown, channels 64 → 128 → 256 → 512 → 512 ×4
    .unetDown 3    64, .unetDown 64  128,
    .unetDown 128 256, .unetDown 256 512,
    .unetDown 512 512, .unetDown 512 512,
    .unetDown 512 512, .unetDown 512 512,         -- bottleneck at 1×1×512
    -- Decoder: 8 × unetUp mirroring the encoder
    .unetUp 512 512, .unetUp 512 512,
    .unetUp 512 512, .unetUp 512 512,
    .unetUp 512 512, .unetUp 512 256,
    .unetUp 256 128, .unetUp 128  64,
    .conv2d 64 3 1 .same .identity                -- 1×1 to RGB (tanh in real)
  ]

CycleGAN (Bestiary/CycleGAN.lean)

Zero new primitives. Unpaired image translation (Zhu et al. 2017, arXiv:1703.10593): two generators \(G : X \to Y\) and \(F : Y \to X\) plus two PatchGAN discriminators, trained with adversarial loss + cycle consistency loss \(\| F(G(x)) - x\| _1\). The cycle constraint is what makes unpaired training work — without it, \(G\) could mode-collapse every \(x\) to one target. Generator is a Johnson-style ResNet (Johnson et al. 2016, arXiv:1603.08155): convs down + 9 .residualBlocks at 256 channels + convs up (11.4M, paper \(\sim \)11M). Discriminator is a PatchGAN: 5 strided convs with a \(70 \times 70\) receptive field, outputs an \(N \times N\) grid of patch real/fake logits (2.8M, paper \(\sim \)2.8M). Four-network pattern shown as two specs (\(G\) and \(D\)); the other \(F\) and \(D_X\) are architecturally identical copies. The “one clever loss” does the work — the architecture is quite ordinary.

def cycleganGenerator : NetSpec where
  name   := "CycleGAN generator (9-block ResNet, 256×256 RGB)"
  imageH := 256
  imageW := 256
  layers := [
    .convBn 3 64 7 1 .same,                       -- 7×7 stem
    .convBn 64  128 3 2 .same,                    -- 256 → 128
    .convBn 128 256 3 2 .same,                    -- 128 →  64
    .residualBlock 256 256 9 1,                   -- 9 res blocks (bottleneck)
    .convBn 256 128 3 1 .same,                    -- upsample stand-in
    .convBn 128  64 3 1 .same,
    .conv2d 64 3 7 .same .identity                -- 7×7 to RGB
  ]
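Since the loss does the work here, a toy version of the cycle term is worth writing out. The sketch below treats images as flat lists of floats and the two generators as arbitrary functions; everything in it is illustrative, not repo code.

```lean
/-- L1 distance between two equally long vectors. -/
def l1 (a b : List Float) : Float :=
  (List.zipWith (fun x y => Float.abs (x - y)) a b).foldl (· + ·) 0.0

/-- Cycle-consistency term ‖F(G(x)) - x‖₁ for one sample. -/
def cycleLoss (G F : List Float → List Float) (x : List Float) : Float :=
  l1 (F (G x)) x

-- An invertible pair of maps reconstructs x exactly.
#eval cycleLoss (fun v => v.map (· * 2.0)) (fun v => v.map (· / 2.0)) [0.5, -1.0]  -- 0
-- A collapsed G (everything maps to zeros) cannot be undone by F.
#eval cycleLoss (fun v => v.map (fun _ => 0.0)) id [0.5, -1.0]                      -- 1.5
```

The second case is the failure the cycle constraint rules out: once \(G\) discards the input, no \(F\) can bring it back, and the loss stays positive.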

11.2.5 Reinforcement learning

Two-headed (policy + value) networks wrapped in a self-play + MCTS outer loop. The architectural side is a stack of residual CNN blocks; the complexity lives in the outer loop, not the network.

AlphaGo (Bestiary/AlphaGo.lean)

Zero new primitives. The original 2016 Lee Sedol system (Silver et al. 2016, Nature). Three separate networks: a 13-layer conv policy network (\(\sim \)3.9M) trained first on 30M positions from human expert games then fine-tuned via self-play, a 13-layer conv + FC-head value network (\(\sim \)4M), and a shallow linear rollout policy used inside MCTS for fast tree playouts. Input is 48 hand-crafted Go feature planes (liberties, ladder patterns, 3\(\times \)3 stone arrangements, etc.). AlphaGo Zero (next entry) throws all of this out (raw board only, one two-headed net, self-play only) and plays better (5185 Elo vs 3140). The one-sentence lesson is written into every follow-up paper: features the network can learn on its own, it will.

def alphaGoPolicyNet : NetSpec where
  name   := "AlphaGo policy network"
  imageH := 19
  imageW := 19
  layers := [
    -- 13 conv layers: one 5×5 stem + eleven 3×3 at 192 channels + the final 1×1
    .convBn 48 192 5 1 .same,                     -- 48 hand-crafted planes in
    .convBn 192 192 3 1 .same, .convBn 192 192 3 1 .same,
    .convBn 192 192 3 1 .same, .convBn 192 192 3 1 .same,
    .convBn 192 192 3 1 .same, .convBn 192 192 3 1 .same,
    .convBn 192 192 3 1 .same, .convBn 192 192 3 1 .same,
    .convBn 192 192 3 1 .same, .convBn 192 192 3 1 .same,
    .convBn 192 192 3 1 .same,
    -- 1×1 conv to per-position move logit (reshape + pass scalar downstream)
    .conv2d 192 1 1 .same .identity
  ]
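The \(\sim \)3.9M figure can be checked against the listing line by line. The sketch counts convolution weights only (biases and BN parameters ignored, an assumption) and takes the layer count from the listing itself: one 5×5 stem, eleven 3×3 layers, one 1×1 output layer.

```lean
/-- Weights in a k×k conv: ic × oc × k². -/
def agConvW (ic oc k : Nat) : Nat := ic * oc * k * k

/-- Stem + eleven 3×3 layers + final 1×1, as written in the listing. -/
def policyW : Nat :=
  agConvW 48 192 5 + 11 * agConvW 192 192 3 + agConvW 192 1 1

#eval policyW  -- 3880128
```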

AlphaZero (Bestiary/AlphaZero.lean)

Zero new primitives — convBn + residualBlock + conv2d + dense. Two-headed (policy + value) network (Silver et al. 2018, arXiv:1712.01815), expressed as two separate NetSpec values sharing the body in prose. Variants: AlphaGo Zero (Go), AlphaZero chess, tiny fixture.

def alphaGoZeroPolicy : NetSpec where
  name   := "AlphaGo Zero (policy head)"
  imageH := 19
  imageW := 19
  layers := [
    .convBn 17 256 3 1 .same,                     -- 17 raw-history planes → 256
    .residualBlock 256 256 19 1,                  -- 19 residual blocks (shared body)
    .conv2d 256 2 1 .same .identity,              -- policy head: 2 filters
    .flatten,
    .dense (2 * 19 * 19) 362 .identity            -- 361 board moves + pass
  ]

The value head shares the first two layers (body) and diverges at the head — one filter, two-layer MLP down to a single scalar (tanh is applied downstream; Activation doesn’t include it):

def alphaGoZeroValue : NetSpec where
  name   := "AlphaGo Zero (value head)"
  imageH := 19
  imageW := 19
  layers := [
    .convBn 17 256 3 1 .same,                     -- shared body
    .residualBlock 256 256 19 1,                  -- shared body
    .conv2d 256 1 1 .same .identity,              -- value head: 1 filter
    .flatten,
    .dense (1 * 19 * 19) 256 .relu,
    .dense 256 1 .identity                        -- → scalar (tanh applied downstream)
  ]

The chess and tiny variants follow the same shape with different (planes, blocks, board) parameters — 119/40/8 for chess, 17/3/9 for the tiny fixture.
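If the body were shared in code rather than in prose, it could look like the sketch below. `ToyLayer` is a deliberately tiny stand-in defined here for illustration; it is not the book's `Layer` type, and its constructor arities are simplified (no padding or activation arguments).

```lean
/-- Toy stand-in for a layer type; illustrative only. -/
inductive ToyLayer where
  | convBn (ic oc k stride : Nat)
  | residualBlock (ic oc n stride : Nat)
  | conv2d (ic oc k : Nat)
  | flatten
  | dense (i o : Nat)
  deriving Repr

/-- Shared body: stem conv plus 19 residual blocks. -/
def azBody : List ToyLayer :=
  [.convBn 17 256 3 1, .residualBlock 256 256 19 1]

def azPolicyHead : List ToyLayer :=
  [.conv2d 256 2 1, .flatten, .dense (2 * 19 * 19) 362]

def azValueHead : List ToyLayer :=
  [.conv2d 256 1 1, .flatten, .dense (1 * 19 * 19) 256, .dense 256 1]

def azPolicy : List ToyLayer := azBody ++ azPolicyHead
def azValue : List ToyLayer := azBody ++ azValueHead

#eval azPolicy.length  -- 5
#eval azValue.length   -- 6
```

Writing the two specs as `body ++ head` keeps the shared prefix in one place, at the cost of no longer reading as two flat lists.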

MuZero (Bestiary/MuZero.lean)

Zero new primitives — three ResNet-style networks (representation, dynamics, prediction) reusing convBn + residualBlock + dense. The architectural novelty is the three-network factoring, not any single layer type (Schrittwieser et al. 2020, arXiv:1911.08265). Five NetSpec values per variant (rep, dyn body, dyn reward head, pred policy, pred value). Variants: Go, Atari (representation only), tiny.

def muZeroGoRepresentation : NetSpec where
  name   := "MuZero Go — representation h"
  imageH := 19
  imageW := 19
  layers := [
    .convBn 17 256 3 1 .same,                     -- observation → hidden
    .residualBlock 256 256 16 1                   -- 16 ResBlocks (AlphaZero-shaped body)
  ]

The dynamics network \(g\) takes the hidden state plus a one-channel action plane and predicts the next hidden state. A separate reward head branches off the same trunk and emits a scalar:

def muZeroGoDynamics : NetSpec where
  name   := "MuZero Go — dynamics g (next-state path)"
  imageH := 19
  imageW := 19
  layers := [
    .convBn 257 256 3 1 .same,                    -- 256 hidden + 1 action plane
    .residualBlock 256 256 16 1                   -- → next hidden (256, 19, 19)
  ]

def muZeroGoDynamicsReward : NetSpec where
  name   := "MuZero Go — dynamics g (reward head)"
  imageH := 19
  imageW := 19
  layers := [
    .conv2d 256 1 1 .same .identity,
    .flatten,
    .dense (1 * 19 * 19) 128 .relu,
    .dense 128 1 .identity                        -- scalar reward
  ]

The prediction network \(f\) is a small two-headed body, the same shape as AlphaZero’s policy/value heads, with only two ResBlocks because the trunk has already done most of the work in \(h\):

def muZeroGoPredictionPolicy : NetSpec where
  name   := "MuZero Go — prediction f (policy head)"
  imageH := 19
  imageW := 19
  layers := [
    .residualBlock 256 256 2 1,                   -- small head body
    .conv2d 256 2 1 .same .identity,
    .flatten,
    .dense (2 * 19 * 19) 362 .identity            -- 361 moves + pass
  ]

def muZeroGoPredictionValue : NetSpec where
  name   := "MuZero Go — prediction f (value head)"
  imageH := 19
  imageW := 19
  layers := [
    .residualBlock 256 256 2 1,
    .conv2d 256 1 1 .same .identity,
    .flatten,
    .dense (1 * 19 * 19) 256 .relu,
    .dense 256 1 .identity                        -- scalar value
  ]
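How the three networks fit together at planning time can be sketched with ordinary functions. The code below is a toy under stated assumptions: `g` returns the next hidden state and a reward, `v` stands in for the value head of \(f\) (the policy head is omitted), and the starting state is whatever \(h\) produced from the observation. None of these names come from the repo.

```lean
/-- Unroll the learned model from a hidden state: step with `g`, read values with `v`.
    Returns one (reward, value) pair per action. -/
def rollout {Hidden Action : Type} (g : Hidden → Action → Hidden × Float)
    (v : Hidden → Float) : Hidden → List Action → List (Float × Float)
  | _, [] => []
  | s, a :: rest =>
    let (s', r) := g s a
    (r, v s') :: rollout g v s' rest

-- Toy instance: hidden states and actions are naturals, reward = action, value = state.
#eval rollout (fun (s a : Nat) => (s + a, Float.ofNat a))
              (fun (s : Nat) => Float.ofNat s) 0 [1, 2]
-- two pairs: (1, 1) then (2, 3), up to Float formatting
```

The real environment never appears after the first encoding, which is the point of the three-network factoring: search runs entirely inside the learned hidden space.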

11.2.6 Natural language processing

The architectures here all consume a sequence of token IDs and emit either a sequence of next-token logits (decoder-only: GPT) or a single classification head (encoder-only: BERT). At the layer level they reduce to one of three shapes: a stack of .transformerEncoder (with or without a causal mask), a state-space stack (.mambaBlock), or a conv/attention hybrid (.separableConv plus .transformerEncoder, QANet). What’s different from the vision side of the bestiary isn’t the layer kit — the kit is the same set of primitives Ch 10 introduced — it’s everything around the layers: the data pipeline, the loss, and the inference loop.

TinyGPT, in passing.

The demos/MainTinyGptShakespeare.lean trainer in this repo is a working char-level transformer at 212K params (\(T{=}64\), \(D{=}64\), 4 layers, 2 heads). It trains in \(\sim \)11 minutes on gfx1100 for 10K Adam steps, hits 1.45 nats/char on Karpathy’s tinyshakespeare, and sampling after that produces recognizable Shakespeare cadence with proper character names (KING RICHARD, ROMEO, POMPEY) and multi-line dialog. Three things change relative to the image trainers shown throughout Part 1, and they capture what NLP training is:

  • Input is a vocabulary index, not a pixel. The first layer of every NLP spec is .tokenPositionEmbed (one-hot \(\to \) learnable token embedding \(+\) learnable position embedding). Image stacks open with .conv2d or .patchEmbed acting on \(H \times W \times 3\) floats; the NLP stack starts with vocabulary IDs flowing through a lookup table. Replace patches with tokens and the rest of a ViT is shape-identical to a GPT.

  • Loss is per-token, averaged over the sequence. Image classifiers compute one cross-entropy at the end of the network: pool to a vector, dense to logits, compare to one ground-truth class. Language models compute cross-entropy at every position — each of the \(T\) tokens predicts the next one, and the per-token losses are averaged. Our trainer reuses the same cross-entropy path through a flag (useSeg in TrainConfig) that switches the loss from “one vector versus one class” to “\(T\) vectors versus \(T\) classes.” No new VJP machinery; the per-token softmax-CE backward is the same one Ch 3 derived, applied \(T\) times in parallel.

  • Inference is autoregressive. The training loop is a single forward pass per batch, just like image classification, but inference is a \(T\)-step loop: the network predicts token \(t{+}1\) from tokens \([0..t]\), the sampled token is appended to the context, and the network runs again for \(t{+}2\). Our tinygpt-shakespeare sample executable implements exactly this loop — \(\sim \)30 lines of plain Lean, no proof machinery, calling the same compiled vmfb the trainer used. The architectural cost is the causal mask: during training, attention at position \(t\) must not see positions \(t{+}1\) onward, or the model would trivially predict the answer it can see. We expose this via the causalMask flag on .transformerEncoder; the rest of the kit is shared with ViT.

That’s the entire delta. Same Adam optimizer, same VJP machinery, same compile-to-IREE pipeline. The bestiary entries below all reduce to either “GPT shape” (decoder-only, causal mask, autoregressive sampling) or “BERT shape” (encoder-only, bidirectional attention, classification head); the architectural differences between them are smaller than the chapter-to-chapter deltas in Part 2.
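The inference loop and the causal mask described above can be sketched in a few lines of plain Lean. This is a minimal illustration, not the repo's sample executable: `next` is a hypothetical stand-in for "run the compiled model on the context and sample one token", and `causalAllowed` and `generate` are names introduced here only for the example.

```lean
/-- Causal-mask predicate: the query at position `q` may attend to key
    position `k` only when `k` is not in the future. -/
def causalAllowed (q k : Nat) : Bool := decide (k ≤ q)

/-- Autoregressive decoding: ask `next` for one token given the running
    context, append it, and repeat `steps` times.
    `next` is a hypothetical stand-in for the compiled model plus sampler. -/
def generate (next : Array Nat → Nat) : Nat → Array Nat → Array Nat
  | 0,         ctx => ctx
  | steps + 1, ctx => generate next steps (ctx.push (next ctx))

-- Toy "model" that simply emits the current context length.
#eval generate (fun ctx => ctx.size) 3 #[7]   -- #[7, 1, 2, 3]

-- Row 2 of the mask over a length-4 sequence: position 3 is hidden.
#eval (List.range 4).map (causalAllowed 2)    -- [true, true, true, false]
```

A real sampler would also truncate the context to the last \(T\) tokens before each call; that detail is left out to keep the loop shape visible.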

Mamba (Bestiary/Mamba.lean)

One new primitive (§ 98). Variants: Mamba-130M / 370M / 790M / tiny, matching Gu & Dao’s (2023, arXiv:2312.00752) param counts within \(\sim \)5%.

def mamba130M : NetSpec where
  name   := "Mamba-130M"
  imageH := 2048                                  -- context length (L tokens)
  imageW := 1
  layers := [
    -- dim=768, state=16, expand=2, 24 blocks
    .mambaBlock 768 16 2 24,
    .dense 768 50280 .identity                    -- LM head (GPT-NeoX vocab)
  ]

BERT / RoBERTa (Bestiary/BERT.lean)

Zero new primitives — uses .transformerEncoder, same kit as ViT / DETR, plus .dense vocab\(\to \)dim standing in for the token-embedding table (the parameter count is faithful, but the shape semantics are a cheat, since a linear NetSpec can’t express the \(L \to L \times D\) lookup). Variants: BERT-base (109M, paper 110M), BERT-large (335M, paper 340M), RoBERTa-base (124M, paper 125M), RoBERTa-large (355M, paper-exact), tinyBERT fixture. The architectural lesson is that RoBERTa (Liu et al. 2019, arXiv:1907.11692) = BERT (Devlin et al. 2018, arXiv:1810.04805); all RoBERTa gains came from training procedure (dynamic masking, more data, bigger batches, dropped NSP) — none of which lives in the NetSpec.

def bertBase : NetSpec where
  name   := "BERT-base"
  imageH := 512                                   -- max context length
  imageW := 1
  layers := [
    -- Token-embedding lookup approximated as dense (vocab → D)
    .dense 30522 768 .identity,
    -- 12 encoder blocks, post-norm, GELU, mlpDim = 4·D
    .transformerEncoder 768 12 3072 12,
    -- [CLS] pooler (tanh applied downstream)
    .dense 768 768 .identity
  ]

GPT-1 / GPT-2 (Bestiary/GPT.lean)

Zero new primitives. GPT-1 (Radford et al. 2018, OpenAI TR) and GPT-2 (Radford et al. 2019, OpenAI TR) are decoder-only counterparts of BERT. Same .transformerEncoder, now read as a stack of decoder blocks with a causal attention mask (a training-time detail, not a parameter). No pooler; GPT-2 uses pre-norm instead of BERT’s post-norm (zero-parameter swap). Weight tying: the LM head reuses the token-embedding matrix, so our .dense vocab\(\to \)D stand-in already pays for both input and output sides. Variants: GPT-1 (116M, paper 117M), GPT-2 small (123M, paper 124M), medium (353M), large (772M, paper 774M), XL (1.56B, paper 1.5B), tinyGPT fixture. Reference implementation: Karpathy’s nanoGPT (\(\sim \)300 lines of PyTorch) targets GPT-2 small and is the canonical mental model.

def gpt2Small : NetSpec where
  name   := "GPT-2 small"                         -- the nanoGPT target
  imageH := 1024                                  -- context length
  imageW := 1
  layers := [
    -- Tied token embedding (also serves as LM head projection)
    .dense 50257 768 .identity,
    -- 12 decoder blocks, pre-norm (mask is attention-time, not a param)
    .transformerEncoder 768 12 3072 12
  ]
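The quoted totals for BERT-base and GPT-2 small can be reproduced with back-of-the-envelope arithmetic. The sketch below assumes textbook counting (biased Q/K/V/O projections, a two-layer FFN, two LayerNorms per block, and a bias-free embedding stand-in); `denseParams` and `blockParams` are hypothetical helpers written for this example, not the repo's own parameter counter, which may tally things differently.

```lean
/-- Dense layer parameters: weights plus an optional bias vector. -/
def denseParams (inDim outDim : Nat) (bias : Bool := true) : Nat :=
  inDim * outDim + (if bias then outDim else 0)

/-- One transformer block under textbook assumptions: four attention
    projections, a two-layer FFN, and two LayerNorms (scale + shift). -/
def blockParams (d mlp : Nat) : Nat :=
  4 * denseParams d d + denseParams d mlp + denseParams mlp d + 2 * (2 * d)

/-- Embedding stand-in + 12 blocks + [CLS] pooler. -/
def bertBaseParams : Nat :=
  denseParams 30522 768 false + 12 * blockParams 768 3072 + denseParams 768 768

/-- Tied embedding stand-in + 12 blocks, no pooler. -/
def gpt2SmallParams : Nat :=
  denseParams 50257 768 false + 12 * blockParams 768 3072

#eval bertBaseParams    -- 109085952  (≈ 109M)
#eval gpt2SmallParams   -- 123651840  (≈ 123M)
```

Under these assumptions the gap to the paper figures is roughly the learned position-embedding table, which the linear spec does not model.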

QANet (Bestiary/QANet.lean)

One new primitive (.separableConv). The 2018 reading-comprehension architecture (Yu et al. 2018, arXiv:1804.09541) that killed the BiLSTM for SQuAD-style tasks. Core contribution: an encoder block combining 4 depthwise-separable convs (local context) with a self-attention + FFN transformer block (global context) — the “conv + attention hybrid” shape MobileViT and ConvNeXt rediscovered in 2022. Expressed with primitives we already have: 4 .separableConv 128 128 1 calls + 1 .transformerEncoder 128 8 512 1 per block. Per-block count: \(\sim \)270K; the paper’s model encoder stack (7 blocks) lands at \(\sim \)1.9M, repeated 3 times in the full architecture. The BiDAF-style context-query attention and the character/word embedding tables are omitted from the spec and described in prose only. QANet’s headline number was a 3–4\(\times \) training speedup over BiLSTM-based competitors, a clean example of hardware (parallelization of convs vs. sequential RNN roll-outs) forcing an architectural choice. BERT landing 7 months later ended the SQuAD-as-benchmark era, but QANet’s hybrid shape lived on and resurfaced repeatedly in vision 4 years later.

def qanetEncoderBlock : NetSpec where
  name   := "QANet encoder block (4× sep-conv + self-attn + FFN)"
  imageH := 400                                   -- representative context length
  imageW := 1
  layers := [
    -- 4 depthwise-separable convs (local context, kernel 7 in paper)
    .separableConv 128 128 1,
    .separableConv 128 128 1,
    .separableConv 128 128 1,
    .separableConv 128 128 1,
    -- Self-attention + FFN + 2 LayerNorms (global context)
    .transformerEncoder 128 8 512 1
  ]

Nyströmformer (Bestiary/Nystromformer.lean)

Zero new primitives. An efficient-attention transformer (Xiong et al. 2021, arXiv:2102.03902) that replaces the \(O(n^2)\) softmax attention with an \(O(n)\) Nyström-approximation computed via \(m\) landmark tokens, \(m \ll n\). Crucially, this is a compute change, not a parameter change: the \(W_Q / W_K / W_V / W_O\) projections and the FFN are identical to a standard transformer. Our .transformerEncoder spec is therefore identical to BERT’s at each scale — base lands at 108M (vs BERT-base 109M), large at 333M (vs BERT-large 335M). The entry’s pedagogical value is prose-level: most of the 2020–2022 efficient-attention literature (Linformer, Performer, Longformer, BigBird, Reformer, FlashAttention) is parameter-identical to vanilla attention, and the bestiary can’t differentiate them at the layer level. Nyströmformer is worth highlighting because the specific trick — a 1928 result from numerical linear algebra, routed through kernel-method literature in the early 2000s, finally surfacing in transformers — is a genuinely long arc for one math idea.

def nystromformerBase : NetSpec where
  name   := "Nyströmformer base (BERT-base-shaped, O(n) attention)"
  imageH := 4096                                  -- long context
  imageW := 1
  layers := [
    -- Architecturally identical to BERT-base; Nyström approximation
    -- lives inside the attention kernel, not the NetSpec.
    .dense 30522 768 .identity,
    .transformerEncoder 768 12 3072 12
  ]
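For reference, the approximation that lives inside the attention kernel can be written out as follows. The notation is an assumption taken from the usual presentation of Xiong et al.: \(\tilde Q, \tilde K \in \mathbb{R}^{m \times d}\) are the landmark queries and keys (segment means in the paper) and \((\cdot)^{+}\) is the Moore–Penrose pseudoinverse.

\[
\mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d}}\right)
\;\approx\;
\underbrace{\mathrm{softmax}\!\left(\frac{Q\tilde K^{\top}}{\sqrt{d}}\right)}_{n \times m}
\;
\underbrace{\left(\mathrm{softmax}\!\left(\frac{\tilde Q\tilde K^{\top}}{\sqrt{d}}\right)\right)^{+}}_{m \times m}
\;
\underbrace{\mathrm{softmax}\!\left(\frac{\tilde Q K^{\top}}{\sqrt{d}}\right)}_{m \times n}
\]

Multiplying right to left against \(V\) never materializes an \(n \times n\) matrix, so the cost is \(O(nm)\), linear in \(n\) for fixed \(m\). Nothing in the formula introduces a learned tensor, which is why the spec above is indistinguishable from BERT-base.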

11.2.7 Diffusion

Generative models that predict the noise added to a clean image, with inference iterating the reverse (noise in, image out). The architecture is a UNet — often with cross-attention bolted in for text conditioning. The novelty lives in the noise schedule and the sampling loop, not in layer design.

DDPM on MNIST, in passing.

The demos/MainMnistDdpmTrain.lean trainer (with CIFAR-10, sinusoidal-time-embedding, and attention variants alongside it) is a working denoising diffusion model: a timestep-conditioned UNet trained to predict the Gaussian noise added to an image. The mnist-ddpm-sample executable then runs the reverse loop — a 50-step DDIM schedule subsampled from the \(T = 1000\) training schedule — starting from pure noise and producing MNIST digits (see demos/README.md for the rendered samples). The backbone is literally the segmentation UNet (§ 95, § 96) of the subsection above; three things change relative to the classifiers, segmentation nets, and detectors so far:

  • The target is the noise, and it is self-supervised — there are no labels. Every model so far compares against a human-provided label: a class, a per-pixel mask, a set of boxes. Diffusion manufactures its own target: sample a random timestep \(t\), sample Gaussian noise \(\varepsilon \), add it to a clean image per the noise schedule, and train the network to predict \(\varepsilon \). The loss is a plain per-pixel MSE between predicted and actual noise (lossKind := .floatTargetMse) — the same squared-error backward from Part 2, but against a target generated on the fly rather than read from disk.

  • The backbone is the segmentation UNet, conditioned on the timestep. One network has to denoise at every level \(t \in [0, T)\), so \(t\) is fed in: the tiny MNIST demo concatenates it as an extra input channel (hence the leading .unetDown 2 16 — one image channel plus one time channel), and the CIFAR sinusoidal variant embeds it (sinusoidal \(\to \) dense \(\to \) dense, reusing the § 100 primitive NeRF introduced). The encoder-decoder-with-skips shape is exactly the segmentation kit — diffusion adds only the conditioning, not new layers.

  • Inference is an iterative sampling loop, not one forward pass. Classifiers, segmentation, and detection each produce their output in a single network call. Diffusion starts from pure Gaussian noise and iterates the reverse process — predict the noise, step toward a cleaner image, repeat — here 50 DDIM steps. The *-ddpm-sample executables run that loop on the same vmfb the trainer produced and write a PPM; the novelty is entirely in the schedule and the loop, not in any layer.

The shape carries all the way up: the bestiary DDPM entry below is the same UNet at CIFAR scale, and Stable Diffusion is the same UNet run on VAE latents with text cross-attention spliced in. Encoder-decoder-with-skips plus a sampling loop is the whole generative pipeline.
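The training target and the sampling step described above can be written compactly. The notation is an assumption following Ho et al. (DDPM) and Song et al. (DDIM): \(\beta_t\) is the noise schedule and \(\bar\alpha_t = \prod_{i \le t} (1 - \beta_i)\). Whether the demo samples with the fully deterministic (\(\eta = 0\)) update shown here is also an assumption.

\[
x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \varepsilon,
\qquad \varepsilon \sim \mathcal{N}(0, I)
\]

\[
\mathcal{L} = \mathbb{E}_{x_0,\, t,\, \varepsilon}
\left\| \varepsilon - \varepsilon_\theta(x_t, t) \right\|^2
\]

\[
\hat x_0 = \frac{x_t - \sqrt{1 - \bar\alpha_t}\, \varepsilon_\theta(x_t, t)}{\sqrt{\bar\alpha_t}},
\qquad
x_s = \sqrt{\bar\alpha_s}\, \hat x_0 + \sqrt{1 - \bar\alpha_s}\, \varepsilon_\theta(x_t, t)
\]

Here \(s < t\) is the next timestep in the subsampled schedule, so a 50-step sampler visits 50 of the \(T = 1000\) training timesteps. The network \(\varepsilon_\theta\) appears only as a black box in all three lines, which is the sense in which the novelty sits in the schedule and the loop.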

DDPM (Bestiary/Diffusion.lean)

Three new primitives (§ 95, § 96, § 100). The denoiser in Ho et al.’s DDPM (2020, arXiv:2006.11239) is a UNet — literally the same § 95 and § 96 Ronneberger shipped in 2015. Everything “diffusion” lives in the training loop: a forward noise schedule adds Gaussian noise over \(T\) steps, a reverse process learns to predict the added noise, sampling iterates the reverse from pure noise back to a clean image. None of that is a layer. The timestep conditioning MLP (sinusoidal \(\to \) dense \(\to \) dense) is shown as a standalone NetSpec that reuses positionalEncoding from NeRF. Variants: CIFAR config (32\(\times \)32 backbone), 256\(\times \)256 high-res config, tiny fixture, timestep embed. Our simplified backbone undercounts vs paper (paper DDPM adds residual blocks, GroupNorm, attention-at-low-res, and per-block time-embedding projection on top of the UNet); the spec’s value is showing the architectural shape, not an exact param match. The lesson is the same one CLIP and NeRF taught: the novelty lives in the training procedure, not in layer design.

def ddpmCifar : NetSpec where
  name   := "DDPM (CIFAR-10, backbone approx)"
  imageH := 32
  imageW := 32
  layers := [
    -- Encoder: 32 → 16 → 8 → 4
    .unetDown 3   128,                            -- 32 → 16
    .unetDown 128 256,                            -- 16 →  8
    .unetDown 256 256,                            --  8 →  4
    .convBn 256 256 3 1 .same,                    -- bottleneck
    .convBn 256 256 3 1 .same,
    -- Decoder: 4 → 8 → 16 → 32
    .unetUp 256 256,                              --  4 →  8
    .unetUp 256 256,                              --  8 → 16
    .unetUp 256 128,                              -- 16 → 32
    .conv2d 128 3 1 .same .identity               -- predict added noise (3-ch)
  ]

Stable Diffusion (Bestiary/StableDiffusion.lean)

Three new primitives (§ 95, § 96, § 93). The paper that made generative image models consumer-reachable (Rombach et al. 2022, arXiv:2112.10752). Two architectural moves over DDPM, each individually small: (a) latent diffusion — run the diffusion process on 64\(\times \)64\(\times \)4 VAE latents instead of 512\(\times \)512 pixels, cutting spatial work \(\sim \)64\(\times \); (b) text conditioning via cross-attention — at each interior UNet resolution, insert a spatial transformer block that cross-attends from image tokens to CLIP text embeddings. Shown as six separate NetSpecs: VAE encoder, VAE decoder, CLIP text encoder (123M, matches CLIP ViT-L/14 exactly), UNet backbone (202M backbone approx; the real SD 1.5 UNet is 865M, and the missing \(\sim \)650M is the interleaved cross-attention), an explicit spatial-transformer-block spec (uses .transformerDecoder with nQueries = 0 — same primitive Whisper’s decoder uses, same mechanism as DETR’s decoder applied at image-feature resolution), and a tiny end-to-end fixture. Three components are pretrained and frozen during SD’s main training (the VAE encoder, the VAE decoder, and the CLIP text encoder); only the UNet trains. SDXL scales the UNet to 2.6B; SD 3 switches the UNet for a DiT transformer. The latent-diffusion plus text-conditioning template is the same.

def sdUNet15 : NetSpec where
  name   := "SD 1.5 UNet denoiser (backbone approx)"
  imageH := 64                                    -- latent resolution, not pixel
  imageW := 64
  layers := [
    .conv2d 4 320 3 .same .identity,              -- stem (latent → channels)
    -- Encoder: 64 → 32 → 16 → 8
    .unetDown 320  640,
    .unetDown 640  1280,
    .unetDown 1280 1280,
    -- Bottleneck at 8×8 (real SD inserts a spatial transformer here)
    .convBn 1280 1280 3 1 .same,
    .convBn 1280 1280 3 1 .same,
    -- Decoder: 8 → 16 → 32 → 64
    .unetUp 1280 1280,
    .unetUp 1280 1280,
    .unetUp 1280 640,
    .conv2d 640 4 3 .same .identity               -- predict noise (4-ch latent)
  ]

11.2.8 Beyond vision

Architectures where the task domain is neither a 2D image nor a token sequence: audio, 3D scene reconstruction, multimodal embedding, scientific. Several of these (NeRF, CLIP) have essentially no architectural novelty — the interesting work lives in the data, loss, or training procedure, and the bestiary entry exists to make that point.

WaveNet (Bestiary/WaveNet.lean)

One new primitive (§ 99). Dilated causal convolutions for raw audio sample prediction (van den Oord et al. 2016, arXiv:1609.03499): exponential receptive field, linear parameter growth. The foundation of neural TTS, and the audio counterpart of PixelCNN’s autoregressive convolutions. Variants: single stack (0.4M), 3-stack (paper setup, with a NetSpec-linearity approximation), music (single stack, 4.1M), tiny. One honest limitation: the residual-vs-skip dual-output pattern doesn’t linearize cleanly, so the 3-stack variant uses a simplified channel-flow approximation.

def waveNet : NetSpec where
  name   := "WaveNet (single stack, speech)"
  imageH := 16000                                 -- 1 second @ 16 kHz
  imageW := 1
  layers := [
    -- Input embedding: 256 mu-law bins → 32 residual channels
    .conv2d 256 32 1 .same .identity,
    -- One stack of 10 dilated residual blocks (dilations 2⁰..2⁹)
    .waveNetBlock 32 512 10,
    -- Output head: skip-sum → 1×1 → 1×1, 256-way categorical
    .conv2d 512 256 1 .same .relu,
    .conv2d 256 256 1 .same .identity
  ]
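The "exponential receptive field, linear parameter growth" claim is easy to check numerically. The sketch below assumes kernel width 2 (the filter width used in the WaveNet paper); `receptiveField` and `stackDilations` are hypothetical helpers for this illustration, not part of the bestiary.

```lean
/-- Receptive field of stacked dilated causal convolutions: a layer with
    kernel width `k` and dilation `d` extends the field by `(k - 1) * d`. -/
def receptiveField (k : Nat) (dilations : List Nat) : Nat :=
  dilations.foldl (fun acc d => acc + (k - 1) * d) 1

/-- Dilations 2⁰, 2¹, …, 2ⁿ⁻¹ for an `n`-layer stack. -/
def stackDilations (n : Nat) : List Nat :=
  (List.range n).map (fun i => 2 ^ i)

#eval stackDilations 10                       -- [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
#eval receptiveField 2 (stackDilations 10)    -- 1024 samples (64 ms at 16 kHz)
#eval receptiveField 2 (List.replicate 10 1)  -- 11 samples without dilation
```

Both stacks have ten layers and therefore the same parameter count; only the dilation pattern separates a 1024-sample field from an 11-sample one.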

Whisper (Bestiary/Whisper.lean)

One new primitive (§ 93). Encoder-decoder transformer over log-mel spectrograms (Radford et al. 2022, arXiv:2212.04356). Encoder is .transformerEncoder on the 1500 post-stem audio tokens; decoder is .transformerDecoder (from DETR) with nQueries = 0 — a clean seq2seq decoder with self-attn + cross-attn + FFN, text tokens coming from a separate .dense vocab\(\to \)D stand-in tied to the LM head. Variants: tiny (7M enc, paper 39M total), base (19M enc, paper 74M), small (85M enc + 153M dec+emb = 238M vs paper 244M), medium (302M enc, paper 769M), large (629M enc, paper 1.55B total), plus a shared decoder spec and tinyWhisper fixture. Encoder and decoder share the same depth and width per variant; adding the encoder to the decoder plus embeddings recovers paper totals within 2–3%. The multitask interface (swap a prefix token to change language or switch between transcribe / translate) is prompt-engineering, not architecture — Whisper’s architectural novelty is essentially zero.

def whisperSmall : NetSpec where
  name   := "Whisper small encoder"
  imageH := 1500                                  -- post-stem audio tokens (3000 / 2)
  imageW := 1
  layers := [
    -- 12 encoder blocks, dim 768, 12 heads, mlp 3072
    .transformerEncoder 768 12 3072 12
  ]

NeRF (Bestiary/NeRF.lean)

Two new primitives (§ 100, § 101). The “it’s literally just an MLP” paper (Mildenhall et al. 2020, arXiv:2003.08934). Under 600K params at canonical config; the architectural novelty is nonexistent. What makes NeRF work is the positional encoding + the volumetric-rendering loss — both outside the network, not layers in it. Variants: canonical (593K), fast (167K, hidden=128), tiny fixture.

def nerf : NetSpec where
  name   := "NeRF (canonical)"
  imageH := 1
  imageW := 1
  layers := [
    -- Positional encoding of (x,y,z): 3 coords × 2 × 10 freqs = 60-dim
    .positionalEncoding 3 10,
    -- 8-layer MLP with mid-skip + dual density/RGB heads.
    -- encodedDirDim = 2·2·4 = 16 (2D direction, 4 frequencies)
    .nerfMLP 60 16 256
  ]
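The 60-dim and 16-dim figures in the comments follow directly from the encoding itself. Here is a minimal sketch of a sin/cos positional encoding, assuming the \(\sin(2^k \pi p), \cos(2^k \pi p)\) form from the NeRF paper; `posEnc1` and `posEnc` are hypothetical names for this example and are not the .positionalEncoding primitive's implementation.

```lean
/-- Encode one scalar coordinate `p` at `freqs` octaves:
    sin(2^k·π·p), cos(2^k·π·p) for k = 0 … freqs-1. -/
def posEnc1 (freqs : Nat) (p : Float) : List Float :=
  (List.range freqs).foldr
    (fun k acc =>
      let w := Float.ofNat (2 ^ k) * 3.141592653589793
      Float.sin (w * p) :: Float.cos (w * p) :: acc)
    []

/-- Encode every coordinate and concatenate the results. -/
def posEnc (freqs : Nat) (coords : List Float) : List Float :=
  coords.foldr (fun p acc => posEnc1 freqs p ++ acc) []

#eval (posEnc 10 [0.1, 0.2, 0.3]).length   -- 60 = 3 coords × 2 × 10 freqs
#eval (posEnc 4 [0.5, 0.5]).length         -- 16 = 2 coords × 2 × 4 freqs
```

The function has no learned parameters, consistent with the point that the encoding sits outside the network.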

CLIP (Bestiary/CLIP.lean)

Zero new primitives. Two textbook encoders glued together by a contrastive loss (Radford et al. 2021, arXiv:2103.00020): a ResNet-50 or ViT-B for images, a 12-layer causal transformer for text, each with a final linear projection to a shared embedding space. Variants: CLIP-RN50, CLIP-ViT-B/32 (151M, paper-exact), CLIP-ViT-L/14 (427M, paper-exact), tiny fixture. The architectural lesson from CLIP is identical to NeRF’s: the novelty lives in data and objective, not in layer design.

def clipViTB32ImageEncoder : NetSpec where
  name   := "CLIP-ViT-B/32 (image encoder)"
  imageH := 224
  imageW := 224
  layers := [
    -- ViT-B/32: patch size 32 → 49 patches, dim 768, 12 blocks, 12 heads
    .patchEmbed 3 768 32 49,
    .transformerEncoder 768 12 3072 12,
    -- [CLS] token → shared 512-dim embedding space
    .dense 768 512 .identity
  ]

LLaVA / LLaVA-1.5 (Bestiary/LLaVA.lean)

Zero new primitives. The cleanest exhibit of the modern vision-language pattern (Liu et al. 2023, arXiv:2304.08485; LLaVA-1.5: arXiv:2310.03744): frozen CLIP ViT encoder + small MLP projector + (mostly) frozen LLaMA/Vicuna language model. Trained in two stages — projector-only pretrain, then joint instruction fine-tune. Shown as separate NetSpecs per component: vision encoder (303M, matching CLIP ViT-L/14), LLaVA-1 single-linear projector (4M), LLaVA-1.5 two-layer MLP projector (21M), LLM backbones at 7B and 13B. The LLM specs undercount real LLaMA by \(\sim \)23% because our .transformerEncoder uses a standard 2-projection FFN while LLaMA uses SwiGLU (3 projections); depth / width / heads still match. Key ratio: the projector is 0.3% of total LLaVA-1.5 7B parameters — the entire interesting work of the model lives in \(\sim \)21M of trainable bolt-on between two pretrained backbones. BLIP-2, Flamingo, and every modern VLM demo generalize this template with fancier adapters (Q-Former, Perceiver resampler, gated xattn-dense), but the frozen-backbone-plus-adapter shape is the same.

def llava15Projector : NetSpec where
  name   := "LLaVA-1.5 projector (2-layer MLP)"
  imageH := 576                                   -- visual token count (24×24)
  imageW := 1
  layers := [
    -- The entire LLaVA architectural contribution: map CLIP's 1024-dim
    -- image tokens into the LLM's 4096-dim embedding space. ~21M params
    -- — roughly 0.3% of total LLaVA-1.5 7B. GELU between is zero-param.
    .dense 1024 4096 .identity,
    .dense 4096 4096 .identity
  ]
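The \(\sim \)21M and 0.3% figures can be checked by hand. The sketch assumes each dense layer carries a bias and takes "7B" as a nominal \(7 \times 10^9\) total; `denseWithBias` and `llava15ProjectorParams` are hypothetical helpers introduced for the arithmetic.

```lean
/-- Dense layer parameters, assuming a weight matrix plus a bias vector. -/
def denseWithBias (inDim outDim : Nat) : Nat :=
  inDim * outDim + outDim

/-- The two projector layers: 1024 → 4096, then 4096 → 4096. -/
def llava15ProjectorParams : Nat :=
  denseWithBias 1024 4096 + denseWithBias 4096 4096

#eval llava15ProjectorParams                         -- 20979712  (≈ 21M)

-- Share of a nominal 7B-parameter model, in percent (about 0.3).
#eval llava15ProjectorParams.toFloat / 7.0e9 * 100.0
```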

AlphaFold 2 Evoformer (Bestiary/Evoformer.lean)

Two new primitives (§ 102, § 103). Dual-representation (MSA + pair) doesn’t fit a linear NetSpec cleanly; the bestiary shows the two bundled primitives and notes the limitation (Jumper et al. 2021, Nature). Variants: full (76M), mini, tiny.

def alphaFold2 : NetSpec where
  name   := "AlphaFold 2 (Evoformer + StructureModule)"
  imageH := 384                                   -- max residues (N_res)
  imageW := 1
  layers := [
    -- 48 blocks of dual-representation processing (MSA + pair).
    -- Each block bundles: MSA row-attn (w/ pair bias), MSA col-attn,
    -- MSA transition, outer-product mean → pair, triangle-mul updates
    -- (out + in), triangle self-attention (start + end), pair transition.
    .evoformerBlock 256 128 48,
    -- Structure Module: 8 recurrent IPA rounds (weights shared).
    .structureModule 384 128 8
  ]

The two primitives are conceptually separate networks chained back to back; the linear NetSpec flattens that.

Evoformer (48 blocks) is the perception/refinement stage. It maintains a dual representation: a multiple-sequence-alignment tensor \(\mathrm{MSA}\! \in \! \mathbb {R}^{s \times r \times c_m}\) (where \(s\) is the number of homologs found by the database search and \(r\) is the number of residues) and a pair tensor \(\mathrm{Pair}\! \in \! \mathbb {R}^{r \times r \times c_z}\) encoding predicted residue-residue relationships. Each block updates both via seven kinds of stacked operation, nine sublayers in all (row attention with pair bias, column attention, MSA transition, outer-product-mean into pair, two triangle multiplicative updates, two triangle attentions, pair transition). Forty-eight rounds of that turn raw evolutionary signal into a self-consistent geometry-implied pair representation.

Structure Module (8 recurrent rounds, shared weights) is the geometry stage. It consumes the final pair tensor + a per-residue “single” embedding and emits 3D backbone frames plus side-chain \(\chi \) angles. Each round runs Invariant Point Attention (IPA) over the residues, updates each residue’s local frame relative to its neighbors, and feeds back into the next round. The shared-weights recurrence is unusual — most networks unroll different parameters per layer — and reflects that each round is the same fixed-point iteration on the protein’s 3D structure.

The bundled-primitive shortcut means Lean sees one NetSpec with two layers; the actual graph is a much larger DAG with the dual representation and the cross-stage handoff. § 102 and § 103 document the bundling explicitly.
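The difference between the Structure Module's shared-weights recurrence and an ordinary layer stack is small enough to state in code. This is a toy sketch only: `iterateShared` and `applyStack` are hypothetical names, and the Newton step stands in for an IPA round purely to show a fixed-point iteration converging.

```lean
/-- Shared-weights recurrence: apply the same update `f` for each round. -/
def iterateShared {α : Type} (f : α → α) : Nat → α → α
  | 0,          x => x
  | rounds + 1, x => iterateShared f rounds (f x)

/-- Ordinary stack: every layer is its own function (its own parameters). -/
def applyStack {α : Type} (layers : List (α → α)) (x : α) : α :=
  layers.foldl (fun acc layer => layer acc) x

-- Eight rounds of one fixed update (a Newton step for √2) settle near 1.41421.
#eval iterateShared (fun x : Float => (x + 2.0 / x) / 2.0) 8 1.0

-- Two distinct layers applied in sequence: (3 + 1) * 2.
#eval applyStack (α := Nat) [fun x => x + 1, fun x => x * 2] 3   -- 8
```

In the first form the parameter cost is that of a single round regardless of how many rounds run, which is why the eight IPA iterations add depth without adding weights.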