BatchNorm Folding
Overview
BatchNorm Folding is an optimization technique that merges batch normalization layers into preceding convolutional layers. This eliminates the batch norm computation entirely while maintaining mathematically equivalent results.
Why Fold BatchNorm?
| Aspect | Before Folding | After Folding |
|---|---|---|
| Layers | Conv → BN → ReLU | Conv → ReLU |
| Parameters | Conv weights + BN params | Modified Conv weights only |
| Inference Speed | Slower (extra ops) | Faster (no BN overhead) |
| Accuracy | Baseline | Identical (mathematically equivalent) |
How It Works
During inference, batch normalization applies a fixed affine transformation using its running statistics:

\[y = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta\]

where \(\mu\) and \(\sigma^2\) are the running mean and variance, \(\gamma\) and \(\beta\) are the learned scale and shift, and \(\epsilon\) is a small constant for numerical stability.
Since convolution is also an affine operation, we can fold the BN parameters into the conv weights and bias:

\[W_{new} = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}} \cdot W_{old}\]

\[b_{new} = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}} \cdot (b_{old} - \mu) + \beta\]
The result is identical outputs with fewer operations.
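To make the two formulas concrete, here is a minimal sketch of folding a single BatchNorm2d into the Conv2d that precedes it. fold_conv_bn is a hypothetical standalone helper written for illustration, not the library's implementation:

import torch
import torch.nn as nn

def fold_conv_bn(conv, bn):
    # Hypothetical helper: build a new Conv2d whose weights and bias absorb the BN transform
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      stride=conv.stride, padding=conv.padding, bias=True)
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)            # gamma / sqrt(var + eps)
    fused.weight.data = conv.weight.data * scale[:, None, None, None]  # W_new
    b_old = conv.bias.data if conv.bias is not None else torch.zeros_like(bn.running_mean)
    fused.bias.data = scale * (b_old - bn.running_mean) + bn.bias.data # b_new
    return fused

# Sanity check: populate the BN running stats with one training-mode pass,
# then compare outputs in eval mode (where BN uses those running stats)
conv, bn = nn.Conv2d(3, 8, 3, padding=1, bias=False), nn.BatchNorm2d(8)
bn(conv(torch.randn(32, 3, 16, 16)))
conv.eval(); bn.eval()
x = torch.randn(1, 3, 16, 16)
assert torch.allclose(bn(conv(x)), fold_conv_bn(conv, bn)(x), atol=1e-5)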
1. Setup and Training
First, let’s train a model with batch normalization layers.
Load Data

from fastai.vision.all import *

path = untar_data(URLs.PETS)
files = get_image_files(path/"images")
def label_func(f): return f[0].isupper()
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))

Train the Model
learn = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
learn.fit_one_cycle(5)

| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.604740 | 0.685939 | 0.685386 | 00:02 |
| 1 | 0.565022 | 0.724329 | 0.694858 | 00:02 |
| 2 | 0.512418 | 0.516759 | 0.736807 | 00:02 |
| 3 | 0.445161 | 0.466733 | 0.763193 | 00:02 |
| 4 | 0.362070 | 0.433802 | 0.792963 | 00:02 |
2. Fold BatchNorm Layers
Use BN_Folder to fold all batch normalization layers into their preceding convolutions:
bn = BN_Folder()
new_model = bn.fold(learn.model)

The batch norm layers have been replaced by Identity layers, and the convolution weights have been modified to incorporate the batch norm transformation.
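You can confirm this programmatically with a quick check that no BatchNorm2d modules remain (assuming torch.nn is available as nn):

import torch.nn as nn
assert not any(isinstance(m, nn.BatchNorm2d) for m in new_model.modules())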
new_model

ResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
(bn1): Identity()
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn1): Identity()
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn2): Identity()
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn1): Identity()
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn2): Identity()
)
)
(layer2): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(bn1): Identity()
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn2): Identity()
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2))
(1): Identity()
)
)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn1): Identity()
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn2): Identity()
)
)
(layer3): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(bn1): Identity()
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn2): Identity()
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2))
(1): Identity()
)
)
(1): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn1): Identity()
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn2): Identity()
)
)
(layer4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(bn1): Identity()
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn2): Identity()
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2))
(1): Identity()
)
)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn1): Identity()
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn2): Identity()
)
)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Linear(in_features=512, out_features=2, bias=True)
)
3. Comparing Results
Parameter Count
The folded model has slightly fewer parameters: each BN layer's scale and shift are absorbed into the preceding convolution's weights and its newly added bias:
count_parameters(learn.model)

11177538
count_parameters(new_model)

11172738
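If count_parameters is not defined in your environment, a minimal stand-in (an assumed definition, not necessarily the library's) is:

def count_parameters(model):
    # Total number of parameters in the model
    return sum(p.numel() for p in model.parameters())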
Inference Speed
The folded model is faster because batch norm operations are eliminated:
x, y = dls.one_batch()

%timeit learn.model(x[0][None].cuda())

1.19 ms ± 4.31 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
%timeit new_model(x[0][None].cuda())

768 μs ± 1.79 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
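For a fair comparison, both models should live on the same device and run in eval mode with gradient tracking disabled, along these lines before timing:

learn.model.cuda().eval()
new_model.cuda().eval()
with torch.no_grad():
    preds = new_model(x[0][None].cuda())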
Accuracy Verification
Most importantly, the folded model produces results identical to the original; note that the validation loss and accuracy below match the final training epoch exactly:
new_learn = Learner(dls, new_model, metrics=accuracy)
new_learn.validate()

[0.4338044822216034, 0.792963445186615]
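You can also verify tensor-level equivalence directly on a batch; a quick check (with both models in eval mode, since the folded weights reproduce BN's inference-time behavior; the tolerance may need loosening for deeper models):

with torch.no_grad():
    original = learn.model.eval()(x.cuda())
    folded = new_model.eval()(x.cuda())
assert torch.allclose(original, folded, atol=1e-4)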
Summary
| Metric | Original | Folded | Improvement |
|---|---|---|---|
| Parameters | 11,177,538 | 11,172,738 | 4,800 fewer |
| Inference (single image) | 1.19 ms | 0.77 ms | ~35% faster |
| Accuracy | Baseline | Identical | No change |
When to Use BN Folding
| Scenario | Recommendation |
|---|---|
| Inference/deployment | ✅ Always fold - free speedup |
| Before quantization | ✅ Fold first - cleaner quantization |
| During training | ❌ Don’t fold - BN helps training |
| Models without BN | N/A - Nothing to fold |
See Also
- Quantize Callback - Apply quantization after folding for maximum compression
- ONNX Exporter - Export folded models to ONNX for deployment
- Pruner - Combine with pruning for smaller, faster models