Verified Deep Learning with Lean 4
Formal Backpropagation from MLP to Attention, via MLIR

7 MobileNetV2

A standard conv does two jobs at once

A standard 3\(\times \)3 convolution from \(C_\text {in}\) channels to \(C_\text {out}\) channels computes, for every output cell:

\[ y_{c_\text {out},\, i,\, j} \; =\; \sum _{c_\text {in},\, k_h,\, k_w} W_{c_\text {out},\, c_\text {in},\, k_h,\, k_w} \cdot x_{c_\text {in},\, i+k_h,\, j+k_w}. \]

That sum is doing two things. The \(k_h, k_w\) part is spatial mixing: it combines a pixel with its neighbors. The \(c_\text {in}\) part is cross-channel mixing: it combines the \(C_\text {in}\) channels into each of \(C_\text {out}\) new channels. Both happen inside the same kernel, and they are paid for together: \(C_\text {in} \times C_\text {out} \times k^2\) weights per layer.

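Howard et al. 2017 (MobileNetV1) and Sandler et al. 2018 (MobileNetV2, arXiv:1801.04381) asked: what if we factored those two responsibilities and only paid the full cost of one of them? The spatial mixing benefits from a real \(3 \times 3\) neighborhood but does not need to cross channels; the channel mixing needs to cross channels but not a neighborhood, so a \(1 \times 1\) kernel is enough. Separating them replaces the product \(C_\text {in} \times C_\text {out} \times k^2\) with the sum \(C_\text {in} \times k^2 + C_\text {in} \times C_\text {out}\): a large parameter saving without losing much expressive power.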

Depthwise: one kernel per channel, no cross-channel sum

A depthwise convolution does the spatial half on its own, per channel. Each output channel comes from one input channel, with its own \(k \times k\) kernel:

\[ y_{c,\, i,\, j} \; =\; \sum _{k_h,\, k_w} W_{c,\, k_h,\, k_w} \cdot x_{c,\, i+k_h,\, j+k_w}. \]

There is only one summation level (over the kernel positions), not two, and no mixing between channels at all: channel \(c\) of the output only ever looks at channel \(c\) of the input. The weight tensor has shape \([C, k, k]\) instead of \([C_\text {out}, C_\text {in}, k, k]\):

\[ \text{depthwise weights: } C \times 3 \times 3 \quad \text{vs}\quad \text{standard weights: } C_\text {out} \times C_\text {in} \times 3 \times 3. \]

For a typical \(C_\text {in} = C_\text {out} = 64\) layer, that’s 576 weights vs 36,864 weights, a 64\(\times \) reduction. The Jacobian from Chapter 4 simplifies in step: with one fewer \(\Sigma \) level, the depthwise input VJP is a slightly simpler reversed-kernel convolution, the weight VJP is a slightly simpler instance of the transpose trick, and the bias VJP is the same per-channel sum. Theorems 37 through 39 prove the depthwise VJPs.
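To make the per-channel structure and the three VJPs concrete, here is a minimal pure-Python sketch. It is our own illustration, not the book’s Lean or MLIR code, and it assumes a single image stored as nested lists, stride 1, and the unpadded window of the formula above; the names depthwise_forward and depthwise_vjps are hypothetical. The input VJP is written as a scatter, which is equivalent to correlating the zero-padded output gradient with the spatially reversed kernel. Because the forward map is bilinear in (x, w), all three VJPs can be checked exactly on integer data through the identity <y, dy> = <x, dx> = <w, dw>.

```python
def depthwise_forward(x, w, b):
    """y[c][i][j] = b[c] + sum_{a,e} w[c][a][e] * x[c][i+a][j+e]  (stride 1, no padding).

    x: C x H x W, w: C x k x k, b: length C. Channel c of y reads only channel c of x.
    """
    C, H, Wd, k = len(x), len(x[0]), len(x[0][0]), len(w[0])
    return [[[b[c] + sum(w[c][a][e] * x[c][i + a][j + e]
                         for a in range(k) for e in range(k))
              for j in range(Wd - k + 1)]
             for i in range(H - k + 1)]
            for c in range(C)]


def depthwise_vjps(x, w, dy):
    """Return (dx, dw, db) for depthwise_forward, given the output gradient dy."""
    C, H, Wd, k = len(x), len(x[0]), len(x[0][0]), len(w[0])
    Ho, Wo = len(dy[0]), len(dy[0][0])
    # Input VJP: send each dy[c][i][j] back through the window it was computed from.
    dx = [[[0] * Wd for _ in range(H)] for _ in range(C)]
    for c in range(C):
        for i in range(Ho):
            for j in range(Wo):
                for a in range(k):
                    for e in range(k):
                        dx[c][i + a][j + e] += dy[c][i][j] * w[c][a][e]
    # Weight VJP: per-channel correlation of the input with the output gradient.
    dw = [[[sum(dy[c][i][j] * x[c][i + a][j + e]
                for i in range(Ho) for j in range(Wo))
            for e in range(k)] for a in range(k)] for c in range(C)]
    # Bias VJP: per-channel sum of the output gradient.
    db = [sum(sum(row) for row in dy[c]) for c in range(C)]
    return dx, dw, db


def dot(u, v):
    """Inner product of two equally shaped nested lists."""
    if isinstance(u, list):
        return sum(dot(p, q) for p, q in zip(u, v))
    return u * v


# Two channels, 4x4 input, 3x3 kernels -> 2x2 output per channel.
x = [[[(16 * c + 4 * i + j) % 7 - 3 for j in range(4)] for i in range(4)] for c in range(2)]
w = [[[(c + 3 * a + e) % 5 - 2 for e in range(3)] for a in range(3)] for c in range(2)]
dy = [[[(c + i + 2 * j) % 3 - 1 for j in range(2)] for i in range(2)] for c in range(2)]

y0 = depthwise_forward(x, w, [0, 0])
dx, dw, db = depthwise_vjps(x, w, dy)
print(dot(y0, dy), dot(x, dx), dot(w, dw))   # the three numbers agree
```

Integer data keeps the check exact, so any indexing mistake in a VJP shows up as a mismatch rather than as a small floating-point discrepancy.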

The depthwise-separable sandwich

A depthwise conv alone never lets channels see each other. That’s a problem—the whole point of stacking convolutions is to build hierarchical features that combine information across channels. The fix is to follow each depthwise with a pointwise \(1 \times 1\) convolution, which is purely cross-channel mixing (no spatial structure—each output pixel is a learned linear combination of the same input pixel’s channels). That two-step combination,

\[ \text{depthwise}\ 3 \times 3 \quad \longrightarrow \quad \text{pointwise}\ 1 \times 1, \]

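is called a depthwise-separable convolution, and it is what replaces the standard \(3 \times 3\) conv in MobileNet and in the architectures built on it. A standard \(3 \times 3\) over 64\(\to \)64 costs 36,864 weights; the separable version costs 576 (depthwise) + 4,096 (pointwise) = 4,672, about 7.9\(\times \) fewer. In general the separable-to-standard ratio is \(1/C_\text {out} + 1/k^2\), so for wide layers the saving approaches \(k^2 = 9\).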
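The arithmetic generalizes to any layer width. A small sketch with hypothetical helper names counts the weights of both variants (biases ignored); dividing one by the other reproduces the \(1/C_\text {out} + 1/k^2\) ratio.

```python
def standard_conv_weights(c_in, c_out, k=3):
    """Weights of a dense k x k conv: every output channel sees every input channel."""
    return c_out * c_in * k * k


def separable_conv_weights(c_in, c_out, k=3):
    """Weights of a depthwise k x k (one kernel per input channel) plus a pointwise 1 x 1."""
    depthwise = c_in * k * k
    pointwise = c_out * c_in
    return depthwise + pointwise


std = standard_conv_weights(64, 64)    # 36864
sep = separable_conv_weights(64, 64)   # 576 + 4096 = 4672
print(std, sep, round(std / sep, 2))   # 36864 4672 7.89

# ratio = (c_in*k^2 + c_in*c_out) / (c_out*c_in*k^2) = 1/c_out + 1/k^2
assert abs(sep / std - (1 / 64 + 1 / 9)) < 1e-12
```

The \(1/C_\text {out}\) term is the depthwise share and the \(1/k^2\) term the pointwise share, so once channels are wide the pointwise \(1 \times 1\) dominates the separable cost.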

Inverted residual: depthwise at the wide point

MobileNetV2 wraps the separable block in one more idea. ResNet’s bottleneck block is wide \(\to \) narrow \(\to \) wide (project down for the expensive spatial conv, project back up for the residual). MobileNetV2 flips it: narrow \(\to \) wide \(\to \) narrow, with the depthwise spatial conv at the wide expansion. The intuition is that depthwise is so cheap that you can afford to make it wider; the narrow bottleneck channels carry the actual information across the residual.

Concretely, an inverted-residual block with input channels \(C_\text {in}\), output channels \(C_\text {out}\), and expansion ratio \(t\):

  1. \(1 \times 1\) conv \(C_\text {in} \to t \cdot C_\text {in}\) (the expand)

  2. \(3 \times 3\) depthwise at \(t \cdot C_\text {in}\) channels (the spatial mix at the wide point)

  3. \(1 \times 1\) conv \(t \cdot C_\text {in} \to C_\text {out}\) (the project, narrowing back down)

  4. Residual skip when \(C_\text {in} = C_\text {out}\) and stride equals 1
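The four steps can be traced with a small shape-and-weight walker. This is a hypothetical helper, not the spec’s primitive; it assumes the depthwise conv uses “same” padding (so stride \(s\) maps \(H\) to \(\lceil H/s \rceil\)), counts conv weights only (no BN or bias), and skips the expand when \(t = 1\), as reference MobileNetV2 implementations do for the first block.

```python
def inverted_residual_plan(c_in, c_out, t, stride, h, w):
    """Return ([(name, conv_weights, (C, H, W) after the layer)], has_residual)."""
    hidden = t * c_in
    layers = []
    if t != 1:  # 1. expand 1x1: c_in -> t*c_in (assumed omitted when t == 1)
        layers.append(("expand 1x1", c_in * hidden, (hidden, h, w)))
    ho, wo = -(-h // stride), -(-w // stride)  # ceil division for 'same' padding
    # 2. depthwise 3x3 at the wide point: one 3x3 kernel per channel
    layers.append(("depthwise 3x3", hidden * 3 * 3, (hidden, ho, wo)))
    # 3. project 1x1 back down: t*c_in -> c_out
    layers.append(("project 1x1", hidden * c_out, (c_out, ho, wo)))
    # 4. skip connection only when input and output shapes match
    has_residual = c_in == c_out and stride == 1
    return layers, has_residual


layers, res = inverted_residual_plan(24, 24, t=6, stride=1, h=56, w=56)
for name, n_weights, shape in layers:
    print(f"{name:14s} {n_weights:6d} weights -> {shape}")
print("residual:", res)
```

In this 24-channel example the 144-channel depthwise costs 1,296 weights, less than either \(1 \times 1\) (3,456 each), which is the “depthwise is cheap enough to run wide” argument in numbers.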

This is the .invertedResidual ic oc t stride n primitive in the spec: one Layer constructor that emits n such blocks, each with its three convs, the BN+ReLU6 between them, and the conditional residual add. The VJP for the whole block composes the depthwise VJP from this chapter with the standard conv2d VJPs from Chapter 4, the BN VJP from Chapter 5, and the additive fan-in for the residual from Chapter 6. Once again there is no new math beyond the foundation rules, just the depthwise primitive plus composition.
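A toy version of that composition shows the two rules at work: each layer’s VJP is applied in reverse order, and the residual fan-in adds the skip-path gradient to the branch gradient. To keep it short, the sketch shrinks the block to one pixel, so every \(1 \times 1\) conv is a plain matrix and the \(3 \times 3\) depthwise collapses to a per-channel scale; BN is omitted and all names are hypothetical. It illustrates the chain rule, not the spec’s generated backward, and a central finite difference confirms the result away from the relu6 kinks.

```python
def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def matTvec(M, v):
    """M^T v, the VJP of the linear map v -> M v."""
    return [sum(M[i][j] * v[i] for i in range(len(M))) for j in range(len(M[0]))]


def relu6(v):
    return [min(max(x, 0.0), 6.0) for x in v]


def block_forward(x, E, d, P):
    h1 = matvec(E, x)                          # expand 1x1
    a1 = relu6(h1)
    h2 = [di * ai for di, ai in zip(d, a1)]    # depthwise: at one pixel, a per-channel scale
    a2 = relu6(h2)
    z = matvec(P, a2)                          # project 1x1, linear (no activation)
    y = [xi + zi for xi, zi in zip(x, z)]      # residual add (c_in == c_out, stride 1)
    return y, (h1, h2)


def block_vjp(x, E, d, P, dy):
    _, (h1, h2) = block_forward(x, E, d, P)
    da2 = matTvec(P, dy)                                            # project VJP
    dh2 = [g if 0.0 < h < 6.0 else 0.0 for g, h in zip(da2, h2)]    # relu6 VJP
    da1 = [di * g for di, g in zip(d, dh2)]                         # depthwise input VJP
    dh1 = [g if 0.0 < h < 6.0 else 0.0 for g, h in zip(da1, h1)]    # relu6 VJP
    dx_branch = matTvec(E, dh1)                                     # expand VJP
    return [g + gb for g, gb in zip(dy, dx_branch)]                 # fan-in: skip + branch


x = [0.5, -0.3]
E = [[1.0, 0.2], [-0.5, 1.5], [2.0, -1.0], [0.3, 0.3], [-1.2, 0.4], [0.8, 2.5]]  # t = 3
d = [3.0, 1.0, 2.0, 4.0, 1.0, 20.0]
P = [[0.5, -1.0, 0.25, 2.0, 1.0, -0.5], [1.0, 0.5, -0.75, 0.1, 0.3, 0.2]]
dy = [1.0, -2.0]


def loss(xv):  # <y, dy>; its gradient with respect to x is exactly the VJP
    y, _ = block_forward(xv, E, d, P)
    return sum(yi * gi for yi, gi in zip(y, dy))


def bump(v, i, s):
    return [vj + s * (j == i) for j, vj in enumerate(v)]


g = block_vjp(x, E, d, P, dy)
eps = 1e-4
fd = [(loss(bump(x, i, eps)) - loss(bump(x, i, -eps))) / (2 * eps) for i in range(len(x))]
print(g, fd)   # agree up to floating-point rounding
```

The final line of block_vjp is the whole residual story: the skip path contributes dy unchanged, and everything else is the branch’s VJP chain.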

Stacking the blocks

MobileNetV2 stacks 17 inverted-residual blocks across seven stages (at depths 1, 2, 3, 4, 3, 3, 1) with widths progressing from 16 to 320 channels, then a final \(1 \times 1\) to 1280 channels, global average pool, and a 1280-to-10 dense head—the same stem-stages-head template as ResNet-34. Total: 2.24M parameters, 9.5\(\times \) fewer than ResNet-34, at a modest accuracy cost (87% vs 90% on Imagenette in the example below). For mobile and embedded deployment that trade is the whole reason MobileNet exists.
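The totals can be reproduced from the stage table. The sketch below rests on assumptions that are ours, not stated in the spec: conv layers carry no bias, each BN has two learnable parameters per channel (scale and shift), the first block of a stage does the striding and width change while the remaining blocks map \(c \to c\) at stride 1, and a block with \(t = 1\) has no expand conv (the convention of common reference implementations such as torchvision’s). Under those assumptions the count lands on the 2,236,682 parameters and 52 BN layers reported in the training log below, and twice the BN channel total matches the log’s 34,112 BN stat floats if those are a running mean and variance per channel.

```python
# (t, c_out, n, s) per stage, read off the .invertedResidual lines of the spec
STAGES = [(1, 16, 1, 1), (6, 24, 2, 2), (6, 32, 3, 2), (6, 64, 4, 2),
          (6, 96, 3, 1), (6, 160, 3, 2), (6, 320, 1, 1)]


def mobilenet_v2_counts(num_classes=10, stem=32, last=1280):
    """Return (params, blocks, bn_layers, bn_channels, output_stride)."""
    def bn(c):                              # assumed: learnable scale + shift per channel
        return 2 * c

    params = 3 * stem * 3 * 3 + bn(stem)    # stem: 3x3 conv, stride 2, no bias
    bn_layers, bn_channels, blocks, out_stride = 1, stem, 0, 2
    c_in = stem
    for t, c_out, n, s in STAGES:
        for i in range(n):
            out_stride *= s if i == 0 else 1    # assumed: only a stage's first block strides
            hidden = t * c_in
            if t != 1:                          # expand 1x1 (assumed absent when t == 1)
                params += c_in * hidden + bn(hidden)
                bn_layers += 1
                bn_channels += hidden
            params += hidden * 3 * 3 + bn(hidden)   # depthwise 3x3
            params += hidden * c_out + bn(c_out)    # project 1x1
            bn_layers += 2
            bn_channels += hidden + c_out
            c_in = c_out
            blocks += 1
    params += c_in * last + bn(last)            # final 1x1 to 1280
    bn_layers += 1
    bn_channels += last
    params += last * num_classes + num_classes  # dense head with bias
    return params, blocks, bn_layers, bn_channels, out_stride


params, blocks, bn_layers, bn_channels, out_stride = mobilenet_v2_counts()
print(params, blocks, bn_layers, 2 * bn_channels, 224 // out_stride)
# 2236682 17 52 34112 7
```

With num_classes=1000 the same function gives 3,504,872, the familiar 3.5M of the ImageNet-sized model; only the head changes.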

The theorems

Definition 36 Depthwise conv forward

Concrete per-channel cross-correlation (Phase 7).

Theorem 37 Depthwise input VJP

Phase 2 (Apr 2026) proof using the same recipe as conv2d_has_vjp3 with one fewer \(\Sigma \) level (no cross-channel mixing in depthwise).

Proof

Mechanical; see Proofs.depthwise_has_vjp3.

Theorem 38 Depthwise weight VJP

Phase 7: reuses \(\mathsf{HasVJP3}\) directly since DepthwiseKernel is definitionally \(\mathsf{Tensor3}\). Derived from foundation rules.

Proof
Theorem 39 Depthwise bias VJP

Phase 9. Derived from foundation rules.

Proof

7.1 Example: MobileNet V2 on Imagenette

Depthwise convolution is rarely used on its own; you almost always see it inside a larger block. In MobileNet V2 that block is a \(1 \times 1\) expand, a \(3 \times 3\) depthwise, and a \(1 \times 1\) project, with a residual connection around the whole thing when the input and output shapes match. That is the inverted-residual block that defines MobileNet V2, and it is a large part of why depthwise convolutions became so widely used.

The point of depthwise is parameter efficiency. A normal \(3 \times 3\) conv over 64 channels uses \(3 \times 3 \times 64 \times 64 = 36864\) weights. A depthwise version uses \(3 \times 3 \times 64 = 576\) weights. 64\(\times \) fewer. The depthwise-separable block recovers the missing cross-channel expressivity with the surrounding \(1 \times 1\) convs, at a combined cost far below a full \(3 \times 3\). MobileNet V2 is what you get when you stack 17 of these blocks and call it a network.

The architecture

-- The LeanMlir spec DSL: NetSpec, TrainConfig, and the layer constructors
import LeanMlir

-- The network: stem, seven inverted-residual stages, 1×1 to 1280, pool, 10-class head
def mobilenetV2 : NetSpec where
  name   := "MobileNet-v2"
  imageH := 224
  imageW := 224
  layers := [
    .convBn 3 32 3 2 .same,                    -- stem 224→112
    .invertedResidual  32  16 1 1 1,           -- 112, expand 1×
    .invertedResidual  16  24 6 2 2,           -- 112→56, t=6
    .invertedResidual  24  32 6 2 3,           -- 56→28, t=6
    .invertedResidual  32  64 6 2 4,           -- 28→14, t=6
    .invertedResidual  64  96 6 1 3,           -- 14, t=6
    .invertedResidual  96 160 6 2 3,           -- 14→7, t=6
    .invertedResidual 160 320 6 1 1,           -- 7, t=6
    .convBn 320 1280 1 1 .same,                -- 1×1 to 1280
    .globalAvgPool,
    .dense 1280 10 .identity
  ]

-- The training recipe (the same as ResNet-34's, Appendix A)
def mobilenetV2Config : TrainConfig where
  learningRate   := 0.001
  batchSize      := 32
  epochs         := 80
  useAdam        := true
  weightDecay    := 0.0001
  cosineDecay    := true
  warmupEpochs   := 3
  augment        := true
  labelSmoothing := 0.1

-- Entry point: the optional first argument is the dataset directory
def main (args : List String) : IO Unit :=
  mobilenetV2.train mobilenetV2Config (args.head?.getD "data/imagenette")

The .invertedResidual primitive takes (ic, oc, expandRatio, stride, nBlocks). Internally each block is: \(1 \times 1\) conv to ic*expandRatio channels, \(3 \times 3\) depthwise at that width, \(1 \times 1\) project to oc, plus a residual skip when ic == oc and stride == 1. The VJP is the depthwise VJP of § 37 composed with the dense / BN / biPath theorems we already proved. The training config is identical to ResNet-34’s (Appendix A): Adam + cosine + warmup + weight decay + augmentation + label smoothing, the same production recipe.
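How the nBlocks argument unrolls is not spelled out here. A common convention, assumed in this sketch, is that only the first of the n blocks uses the given ic and stride, while the remaining n-1 blocks map oc to oc at stride 1. The hypothetical helper expand_blocks applies that convention to the spec’s seven tuples and marks which blocks get the residual skip.

```python
# (ic, oc, expandRatio, stride, nBlocks), copied from the .invertedResidual lines
SPEC = [(32, 16, 1, 1, 1), (16, 24, 6, 2, 2), (24, 32, 6, 2, 3), (32, 64, 6, 2, 4),
        (64, 96, 6, 1, 3), (96, 160, 6, 2, 3), (160, 320, 6, 1, 1)]


def expand_blocks(spec):
    """Unroll each (ic, oc, t, stride, n) into n blocks: (ic, oc, t, stride, has_residual)."""
    blocks = []
    for ic, oc, t, stride, n in spec:
        for i in range(n):
            b_ic = ic if i == 0 else oc          # assumed: later blocks are oc -> oc
            b_stride = stride if i == 0 else 1   # assumed: only the first block strides
            blocks.append((b_ic, oc, t, b_stride, b_ic == oc and b_stride == 1))
    return blocks


blocks = expand_blocks(SPEC)
print(len(blocks), sum(b[4] for b in blocks))   # 17 10
```

Under this convention 10 of the 17 blocks use the additive fan-in; the first block of every stage changes the width and therefore has no skip.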

Results

$ IREE_BACKEND=rocm IREE_CHIP=gfx1100 \
    ./.lake/build/bin/mobilenet-v2-train
MobileNet-v2: 2236682 params
Generating train step MLIR...
  741020 chars
Compiling vmfbs...
  forward compiled
  eval forward (fixed BN) compiled
  compiled
  session loaded
  train: 9469 images (256×256)
  2236682 params + m + v (25 MB)
training: 295 batches/epoch, batch=32, Adam, lr=0.001000,
          cosine, label_smooth=0.1, wd=1e-4
  BN layers: 52, BN stat floats: 34112
  step 0/295: loss=2.327938 (832ms)
Epoch  1/80: loss=1.968284 lr=0.000333 (243839ms)
Epoch  2/80: loss=(dropping) lr=0.000667 ...
...
Epoch 79/80: loss=0.532640 lr=0.000002 (244123ms)
Epoch 80/80: loss=0.532729 lr=0.000000 (244070ms)
  val accuracy (running BN): 3400/3904 = 87.09%

(Full log in logs/mnv2_train.log.)
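Several numbers in the log can be cross-checked by hand. The sketch assumes the last partial batch is dropped, fp32 storage for the parameters and both Adam moments, and a linear warmup of the form base_lr × epoch / warmupEpochs; these are our readings of the log, not documented behavior, and the final epoch’s wall time stands in for every epoch.

```python
train_images, batch = 9469, 32
steps_per_epoch = train_images // batch        # 295: the 29 leftover images are dropped
params = 2_236_682
adam_mib = 3 * params * 4 / 2**20              # params + m + v in fp32, about 25.6 MiB
epoch_ms = 244_070                             # "Epoch 80/80" wall time
ms_per_step = epoch_ms / steps_per_epoch       # about 827 ms per step
total_hours = 80 * epoch_ms / 3.6e6            # about 5.4 h for 80 epochs
base_lr, warmup = 0.001, 3
warmup_lrs = [base_lr * e / warmup for e in (1, 2)]   # epochs 1 and 2 of the warmup
print(steps_per_epoch, round(adam_mib, 1), round(ms_per_step), round(total_hours, 2),
      [round(v, 6) for v in warmup_lrs])
# 295 25.6 827 5.42 [0.000333, 0.000667]
```

The warmup values match the logged lr=0.000333 and lr=0.000667, and the per-step and total times match the 830 ms and 5.4 hours quoted in the observations.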

Final val accuracy is 87.09%, 3.2 points below ResNet-34’s 90.29% on the same dataset and training config, with 9.5\(\times \) fewer parameters (2.24M vs 21.29M). That’s the depthwise-separable trade: you give up a few accuracy points to get an order-of-magnitude reduction in parameter count. For mobile and embedded deployment, that trade is usually worth making.

A few observations worth calling out:

- Per-step time is lower, not higher: 830 ms vs ResNet-34’s 1400 ms. Even though the network has more conv layers (17 inverted-residual blocks of up to three convs each; counting the stem and the final \(1 \times 1\), that is 52 conv layers, one per BN layer, vs ResNet-34’s 34), the depthwise convs are so cheap that overall throughput improves. Fewer parameters and faster training, at a modest accuracy cost.

- The MLIR is actually bigger, not smaller: 741,020 chars vs ResNet-34’s 517,912. That’s counterintuitive: the network has fewer params but emits more StableHLO. The reason is that each inverted-residual block has three separate convs (expand, depthwise, project) plus their BN ops, so there are more operations even though each operation has fewer weights. The StableHLO is emitted op by op, so the char count tracks op count, not param count.

- 52 BN layers, up from ResNet-34’s 36, for the same reason: more internal conv layers means more BN-after-conv pairs. Each one is still covered by the BN VJP of § 34.

- 9.5\(\times \) fewer parameters, 5.4 hours total training vs ResNet-34’s 9.5 hours. The speedup comes entirely from the smaller per-step time; the epoch count is the same.

7.2 MLIR: Depthwise Convolution

What is already proven. A depthwise convolution applies one \(k\times k\) kernel per channel with no cross-channel sum, so its reverse-mode derivative (§ 37, depthwise_has_vjp3) is a reversed-kernel convolution minus the channel transpose the full-conv VJP of Chapter 4 needs: with no \(\sum _{c_\text {in}}\) to take the adjoint of, there is no in/out channel axis to swap. The inverted-residual block composes this with the pointwise convs, BatchNorm, and relu6; mobilenetv2_has_vjp_at chains the whole network through vjp_comp_at.

The gap, and how we close it. The emitted graph is given a denotation, and that denotation is shown equal to depthwise_has_vjp3’s backward. Here is what the printer emits for the depthwise input gradient (two channels, \(3\times 3\) kernels, \(4\times 4\) spatial):

func.func @dw_back(%dy: tensor<1x2x4x4xf32>, %W: tensor<2x3x3xf32>)
    -> tensor<1x2x4x4xf32> {
  %We = stablehlo.reshape %W
          : (tensor<2x3x3xf32>) -> tensor<2x1x3x3xf32>
  %Wr = stablehlo.reverse %We, dims = [2, 3] : tensor<2x1x3x3xf32>
  %dx = stablehlo.convolution(%dy, %Wr)
      dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1],
      window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
                lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
      {batch_group_count = 1 : i64, feature_group_count = 2 : i64}
      : (tensor<1x2x4x4xf32>, tensor<2x1x3x3xf32>)
        -> tensor<1x2x4x4xf32>
  return %dx : tensor<1x2x4x4xf32>
}

Read it against Chapter 4’s conv_back: that one began with a transpose %W, dims = [1, 0, 2, 3] to swap the in- and out-channel axes before reversing; this one has no transpose. The reshape only adds the singleton input-channel axis the grouped form expects (\([c,k,k] \to [c,1,k,k]\)), the reverse flips the two spatial axes, and feature_group_count = 2 — equal to the channel count — tells the convolution to run each channel through its own filter, no mixing. That single attribute is the depthwise structure; the bridge theorem is the claim that this graph computes depthwise_has_vjp3’s backward.
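The role of feature_group_count can also be checked outside MLIR. The pure-Python sketch below is our own model of a grouped, stride-1, zero-padded cross-correlation in the [b, f, 0, 1] x [o, i, 0, 1] layout for a single image (batch axis dropped); it is not StableHLO’s implementation, and the helper names are hypothetical. It follows the same recipe as @dw_back: reshape the [c, k, k] kernel to [c, 1, k, k], reverse the two spatial axes, and convolve with pad 1 and as many groups as channels. An exact integer dot-product test confirms that the result is the adjoint of the “same”-padded depthwise forward.

```python
def grouped_conv(x, w, groups, pad):
    """x: [Cin][H][W], w: [Cout][Cin // groups][k][k]; stride 1, zero padding `pad`.

    Output channel o reads only the Cin // groups input channels of its group,
    which is what feature_group_count means; groups == Cin gives a depthwise conv.
    """
    H, Wd = len(x[0]), len(x[0][0])
    cout, cpg, k = len(w), len(w[0]), len(w[0][0])
    opg = cout // groups                       # output channels per group

    def at(c, i, j):                           # zero-padded read
        i, j = i - pad, j - pad
        return x[c][i][j] if 0 <= i < H and 0 <= j < Wd else 0

    Ho, Wo = H + 2 * pad - k + 1, Wd + 2 * pad - k + 1
    return [[[sum(w[o][ci][a][b] * at((o // opg) * cpg + ci, i + a, j + b)
                  for ci in range(cpg) for a in range(k) for b in range(k))
              for j in range(Wo)] for i in range(Ho)] for o in range(cout)]


def reverse_spatial(w):                        # like stablehlo.reverse with dims = [2, 3]
    return [[[row[::-1] for row in kern[::-1]] for kern in wo] for wo in w]


def dot(u, v):
    if isinstance(u, list):
        return sum(dot(p, q) for p, q in zip(u, v))
    return u * v


C = 2
W = [[[(c + 2 * a + b) % 5 - 2 for b in range(3)] for a in range(3)] for c in range(C)]
We = [[W[c]] for c in range(C)]                # reshape [2,3,3] -> [2,1,3,3]
x = [[[(7 * c + 3 * i + j) % 6 - 2 for j in range(4)] for i in range(4)] for c in range(C)]
dy = [[[(c + 2 * i + j) % 4 - 1 for j in range(4)] for i in range(4)] for c in range(C)]

y = grouped_conv(x, We, groups=C, pad=1)                     # depthwise forward, 'same'
dx = grouped_conv(dy, reverse_spatial(We), groups=C, pad=1)  # the @dw_back recipe
print(dot(y, dy) == dot(x, dx))                              # True: dx is the exact adjoint
```

Setting groups=1 with a [2, 2, 3, 3] kernel would turn the same function back into a full convolution, which is why that one attribute carries the whole depthwise structure.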

Caveats.

  • The depthwise convolution bridge is unconditional — it is linear.

  • relu6 is a smooth-point bridge — MobileNetV2’s clamp to \([0,6]\) has no derivative where a pre-activation sits exactly at \(0\) or \(6\), and the bridge may fail only on that measure-zero set.

  • Representative scale (two channels at \(4\times 4\)).
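To see exactly what the smooth-point caveat excludes, here is a scalar relu6 and its derivative in plain Python (hypothetical names). The derivative is 1 strictly inside \((0, 6)\), 0 outside \([0, 6]\), and undefined at exactly \(0\) and \(6\); the sketch returns None there rather than picking a subgradient.

```python
def relu6(x):
    """MobileNetV2's activation: clamp to [0, 6]."""
    return min(max(x, 0.0), 6.0)


def relu6_grad(x):
    """Derivative of relu6, or None at the two kinks where none exists."""
    if x == 0.0 or x == 6.0:
        return None
    return 1.0 if 0.0 < x < 6.0 else 0.0


def central_difference(f, x, h=1e-6):
    return (f(x + h) - f(x - h)) / (2 * h)


for x in (-2.0, 0.0, 3.0, 6.0, 7.5):
    print(x, relu6(x), relu6_grad(x))

# Away from the kinks, the finite difference agrees with relu6_grad.
assert all(abs(central_difference(relu6, x) - relu6_grad(x)) < 1e-6 for x in (-2.0, 3.0, 7.5))
```

At a kink the central difference returns about 0.5, the average of the two one-sided slopes, which matches neither branch; that is the measure-zero set on which the bridge makes no claim.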

7.3 ImageNet recipe

Why the phase-2 trainer? The ImageNet runs in this book use the phase-2 (Lean\(\to \)JAX) trainer: at this scale its job is to validate the framework’s logic end to end. Whether the phase-3 verified-IREE codegen can reach ImageNet under its own codegen rules is an open question — phase-2 is how we establish these baselines in the meantime.

The Imagenette run above trained on \(\sim \)9.5K images. The same backbone scales to full 1000-class ImageNet exactly the way ResNet-34 did in Chapter 6: the inverted-residual stack is unchanged, only the head widens to 1000 and the dataset and schedule grow. The spec mirrors jax/MainMobilenetV2Imagenet.lean:

-- Same inverted-residual backbone, 1000 classes instead of 10.
def mobilenetV2Imagenet : NetSpec where
  name   := "MobileNetV2 (ImageNet, bf16)"
  imageH := 224
  imageW := 224
  layers := [
    .convBn 3 32 3 2 .same,
    .invertedResidual  32  16 1 1 1,
    .invertedResidual  16  24 6 2 2,
    .invertedResidual  24  32 6 2 3,
    .invertedResidual  32  64 6 2 4,
    .invertedResidual  64  96 6 1 3,
    .invertedResidual  96 160 6 2 3,
    .invertedResidual 160 320 6 1 1,
    .convBn 320 1280 1 1 .same,
    .globalAvgPool,
    .dense 1280 1000 .identity      -- 1000-class head
  ]

-- SGD recipe, MobileNet-tuned: smaller weight decay, bf16 conv on.
def mobilenetV2ImagenetConfig : TrainConfig where
  learningRate := 0.1
  batchSize    := 256
  epochs       := 90               -- the real run; validate at 30 first
  useAdam      := false             -- SGD + momentum (the paper used RMSProp)
  momentum     := 0.9
  weightDecay  := 4e-5              -- smaller than R34's 1e-4
  cosineDecay    := true
  warmupEpochs   := 5
  augment        := true            -- random-resized-crop + flip
  labelSmoothing := 0.1
  bf16           := true
  bf16Conv       := true            -- reaches the inverted-residual blocks

The one knob specific to MobileNet is bf16Conv. ResNet’s convolutions cast to bfloat16 the instant the flag is on, but MobileNet’s compute lives inside inverted-residual blocks — a \(1\times 1\) expansion, a depthwise \(3\times 3\), a \(1\times 1\) projection — whose convolutions originally ran in fp32 regardless. Routing them through the same convdt cast as the plain convolutions is what lets bf16 reach the bulk of the network: the payoff is the whole inverted-residual block running \(\sim 2\times \) faster on cuDNN, where the \(1\times 1\)s (which are really matmuls) love bfloat16 and the depthwise \(3\times 3\) is a wash. The weight decay also drops to \(4\times 10^{-5}\) — MobileNet’s depthwise filters are tiny (\(9\) weights per channel), and the \(10^{-4}\) that suits ResNet over-regularizes them.
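To make the cast concrete: bfloat16 keeps float32’s sign and 8 exponent bits but only the top 7 of its 23 mantissa bits, so converting amounts to rounding away the low 16 bits. The sketch below is our own, not the convdt cast the text refers to; it rounds to nearest with ties to even on the bit pattern and does not handle NaN.

```python
import struct


def to_bf16(x):
    """Round a value to the nearest bfloat16 (ties to even) and return it as a float.

    bfloat16 is the top 16 bits of a float32: sign, 8 exponent bits, 7 mantissa bits.
    NaN inputs are not handled in this sketch.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)          # round to nearest, ties to even
    bits = (bits >> 16) << 16                    # drop the low 16 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


for v in (1.0, 1.0 + 2**-8, 1.0 + 3 * 2**-9, 3.14159265, 1e-3):
    r = to_bf16(v)
    print(v, r, abs(r - v) / abs(v))   # relative error stays below 2**-8
```

Because the exponent range is unchanged, the cast loses precision but never range, which is why it can be switched on for whole blocks without rescaling.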

Compute budget. On four RTX 4060 Ti GPUs (CUDA, bf16, batch 256 across the four GPUs), steady-state throughput is \(\sim \)106 ms per step, about \(9.1\) minutes per epoch, so the 30-epoch validation tier is a \(\sim \)4.5-hour run and the full 90-epoch schedule projects to roughly 14 wall-clock hours. Curiously, that is only about 13% faster per epoch than the \(9.5\times \)-larger ResNet-34 at \(\sim \)10.5 min: depthwise convolution is memory-bound and has low arithmetic intensity, so MobileNet’s order-of-magnitude FLOP saving does not translate into a proportional wall-clock win on this hardware.
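A rough roofline-style estimate shows why the depthwise layers are the memory-bound ones. The hypothetical helper below divides FLOPs by the compulsory memory traffic of one stride-1 conv (read input and weights once, write output once, 2 bytes per bf16 element), ignoring caches, fusion, and layout; it is an estimate, not a profile of the 4060 Ti runs. For a depthwise conv the ratio sits near \(k^2/2 = 4.5\) FLOP/byte regardless of width, while a \(1 \times 1\) conv’s ratio grows with its channel counts.

```python
def conv_intensity(c_in, c_out, h, w, k, depthwise, bytes_per_elem=2):
    """FLOPs per byte of compulsory traffic for a stride-1 'same' conv (rough estimate)."""
    if depthwise:
        assert c_in == c_out, "depthwise keeps the channel count"
        macs = c_out * h * w * k * k
        weights = c_out * k * k
    else:
        macs = c_out * c_in * h * w * k * k
        weights = c_out * c_in * k * k
    traffic = bytes_per_elem * (c_in * h * w + weights + c_out * h * w)
    return 2 * macs / traffic          # one multiply-add = 2 FLOPs


# Stage-5 shapes: 96 channels expanded 6x to 576 at 14x14.
dw = conv_intensity(576, 576, 14, 14, k=3, depthwise=True)
pw = conv_intensity(96, 576, 14, 14, k=1, depthwise=False)
print(f"depthwise 3x3: {dw:.1f} FLOP/byte, expand 1x1: {pw:.1f} FLOP/byte")
```

With these shapes the estimate gives roughly 4.4 FLOP/byte for the depthwise layer against about 58 for the expansion, a gap of more than 10\(\times \), which is consistent with the \(1 \times 1\)s benefiting from bf16 compute while the depthwise \(3 \times 3\) stays bandwidth-limited.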

GPU | Precision | Per epoch | 90 epochs | Val top-1 | Val top-5
4\(\times \) 4060 Ti (CUDA) | bf16 | \(\sim \)9.1 min | \(\sim \)14 hr | \(\mathbf{68.33\% }\) | \(\mathbf{88.17\% }\)

The full 90-epoch run completed at \(\mathbf{68.33\% }\) top-1 / \(\mathbf{88.17\% }\) top-5 on the full 50,000-image validation split, a few points under the \(\sim \)72% reported in the MobileNetV2 paper. We attribute the gap to the missing RMSProp optimizer [TODO: wire RMSProp] and the original, longer schedule, not to the precision or the pipeline. The SGD peak LR of \(0.1\) took off cleanly (no collapse), so the \(0.05\) fallback was never needed. The validation curve has the usual shape: a fast warmup climb, a long middle, and a final lift as the cosine schedule anneals to zero:

\begin{tikzpicture} 
\begin{axis}[
    width=0.92\linewidth, height=6.5cm,
    xlabel={Epoch}, ylabel={Validation accuracy (\%)},
    xmin=0, xmax=91, ymin=0, ymax=95,
    xtick={0,15,30,45,60,75,90}, ytick={20,40,60,80},
    legend pos=south east, legend cell align={left},
    grid=major, grid style={gray!18},
    tick label style={font=\small}, label style={font=\small},
    every axis plot/.append style={line width=1pt, mark size=1pt},
]
\addplot[blue, mark=*, mark options={fill=blue}] coordinates {
(1,5.26) (2,18.66) (3,28.21) (4,34.35) (5,38.94) (6,43.01) (10,50.42) (11,50.84) (12,51.74) (13,52.23) (16,53.62) (17,53.55) (18,54.72) (19,54.48) (23,55.12) (24,56.02) (25,55.95) (26,55.95) (37,58.19) (38,57.67) (42,59.10) (43,59.18) (44,59.27) (45,59.06) (49,60.27) (50,60.75) (51,60.64) (52,60.90) (55,61.72) (56,61.68) (57,62.15) (58,62.01) (61,62.96) (62,63.03) (63,63.29) (64,63.40) (68,64.49) (69,64.74) (70,65.12) (71,65.24) (74,66.23) (75,66.41) (76,66.30) (77,66.65) (78,67.02) (82,67.68) (83,67.94) (84,68.01) (85,68.03) (86,68.23) (87,68.31) (88,68.34) (89,68.35) (90,68.39)
};
\addlegendentry{top-1}
\addplot[orange, mark=*, mark options={fill=orange}] coordinates {
(1,15.91) (2,40.03) (3,53.26) (4,60.68) (5,65.31) (6,69.60) (10,75.79) (11,76.22) (12,76.99) (13,77.39) (16,78.53) (17,78.22) (18,79.20) (19,79.00) (23,79.77) (24,80.10) (25,80.37) (26,80.47) (37,81.79) (38,81.64) (42,82.37) (43,82.35) (44,82.53) (45,82.52) (49,83.24) (50,83.58) (51,83.54) (52,83.68) (55,84.19) (56,84.26) (57,84.62) (58,84.40) (61,85.09) (62,85.08) (63,85.20) (64,85.38) (68,85.88) (69,86.04) (70,86.30) (71,86.33) (74,86.98) (75,87.11) (76,87.13) (77,87.25) (78,87.39) (82,87.75) (83,87.86) (84,87.90) (85,87.95) (86,88.06) (87,88.10) (88,88.18) (89,88.16) (90,88.15)
};
\addlegendentry{top-5}
\end{axis}
\end{tikzpicture}

MobileNetV2 / ImageNet-1k validation accuracy per epoch (bf16, 4\(\times \) 4060 Ti).

The general pattern (deeper and narrower, per-channel conv, surrounding \(1 \times 1\) expansion and projection) carries forward to MobileNet V3 and the EfficientNet family. Both add their own modifications on top (SE blocks, Swish-style activations, compound scaling), but the core depthwise-separable scaffolding is exactly what this chapter formalized.