BatchNorm Folding

Fold your BatchNorm layers

Overview

BatchNorm Folding is an optimization technique that merges batch normalization layers into preceding convolutional layers. This eliminates the batch norm computation entirely while maintaining mathematically equivalent results.

Why Fold BatchNorm?

| Aspect | Before Folding | After Folding |
|---|---|---|
| Layers | Conv → BN → ReLU | Conv → ReLU |
| Parameters | Conv weights + BN params | Modified Conv weights only |
| Inference Speed | Slower (extra ops) | Faster (no BN overhead) |
| Accuracy | Baseline | Identical (mathematically equivalent) |

How It Works

During inference, batch normalization reduces to a fixed per-channel affine transformation built from the learned parameters \(\gamma, \beta\) and the running statistics \(\mu, \sigma^2\): \[y = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta\]
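
For intuition, here is a quick check in plain PyTorch that an eval-mode BatchNorm2d is exactly this affine map (the layer, statistics, and the c helper below are made up for this demo):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
with torch.no_grad():                    # give the layer non-trivial stats and params
    bn.running_mean.uniform_(-1, 1)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 2.0)
    bn.bias.uniform_(-1, 1)
bn.eval()                                # eval mode: use running stats, not batch stats

x = torch.randn(1, 3, 8, 8)
c = lambda t: t[None, :, None, None]     # reshape per-channel vectors for broadcasting
manual = c(bn.weight) * (x - c(bn.running_mean)) / torch.sqrt(c(bn.running_var) + bn.eps) + c(bn.bias)
assert torch.allclose(bn(x), manual, atol=1e-5)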

Since convolution is itself an affine operation, the BN parameters can be folded into the conv weights and bias, with the scale applied per output channel: \[W_{new} = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}} \cdot W_{old}\] \[b_{new} = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}} \cdot (b_{old} - \mu) + \beta\]

The result is identical outputs with fewer operations.
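
Concretely, here is a minimal sketch of that fold for a single Conv2d followed by a BatchNorm2d. fold_conv_bn is a hypothetical helper written for illustration, not the internals of BN_Folder, which applies the same idea across a whole network:

import torch
import torch.nn as nn

def fold_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    # Sketch: fold an eval-mode BatchNorm2d into the Conv2d that precedes it.
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      conv.stride, conv.padding, conv.dilation, conv.groups,
                      bias=True)
    with torch.no_grad():
        # scale = gamma / sqrt(var + eps), one value per output channel
        scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
        fused.weight.copy_(conv.weight * scale[:, None, None, None])
        b_old = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
        # b_new = scale * (b_old - mu) + beta
        fused.bias.copy_(scale * (b_old - bn.running_mean) + bn.bias)
    return fused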

1. Setup and Training

First, let’s train a model with batch normalization layers.

Load Data

from fastai.vision.all import *  # untar_data, URLs, get_image_files, ImageDataLoaders, Resize, Learner, accuracy

path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

def label_func(f): return f[0].isupper()  # in the Pets dataset, cat breeds have capitalized filenames

dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))

Train the Model

learn = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
learn.fit_one_cycle(5)
epoch train_loss valid_loss accuracy time
0 0.604740 0.685939 0.685386 00:02
1 0.565022 0.724329 0.694858 00:02
2 0.512418 0.516759 0.736807 00:02
3 0.445161 0.466733 0.763193 00:02
4 0.362070 0.433802 0.792963 00:02

2. Fold BatchNorm Layers

Use BN_Folder to fold all batch normalization layers into their preceding convolutions:

bn = BN_Folder()
new_model = bn.fold(learn.model)

The batch norm layers have been replaced by Identity layers, and the convolution weights have been modified to incorporate the batch norm transformation.

new_model
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
  (bn1): Identity()
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): Identity()
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): Identity()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): Identity()
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): Identity()
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (bn1): Identity()
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): Identity()
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2))
        (1): Identity()
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): Identity()
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): Identity()
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (bn1): Identity()
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): Identity()
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2))
        (1): Identity()
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): Identity()
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): Identity()
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (bn1): Identity()
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): Identity()
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2))
        (1): Identity()
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): Identity()
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): Identity()
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=2, bias=True)
)

3. Comparing Results

Parameter Count

The folded model has slightly fewer parameters: each BN layer's γ and β vectors (two parameters per channel) are absorbed, while each folded conv gains a bias vector (one parameter per channel), for a net reduction of one parameter per normalized channel, 4,800 here:

count_parameters(learn.model)
11177538
count_parameters(new_model)
11172738
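
count_parameters here is assumed to be a small utility available in the environment; if yours doesn't define one, the usual one-liner is:

def count_parameters(model):
    # Total number of trainable parameters
    return sum(p.numel() for p in model.parameters() if p.requires_grad)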

Inference Speed

The folded model is faster because batch norm operations are eliminated:

x,y = dls.one_batch()
%timeit learn.model(x[0][None].cuda())
1.19 ms ± 4.31 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
%timeit new_model(x[0][None].cuda())
768 μs ± 1.79 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
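
%timeit is convenient in a notebook, but outside one you can time with explicit CUDA synchronization so queued GPU work doesn't skew the numbers. A minimal sketch, assuming a CUDA device (bench is a hypothetical helper):

import time
import torch

def bench(model, x, iters=1000, warmup=10):
    model.eval()
    with torch.no_grad():
        for _ in range(warmup):            # warm-up: fill caches, trigger lazy init
            model(x)
        torch.cuda.synchronize()           # finish queued work before starting the clock
        start = time.perf_counter()
        for _ in range(iters):
            model(x)
        torch.cuda.synchronize()           # wait for all timed work to complete
    return (time.perf_counter() - start) / iters * 1e3   # mean ms per call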

Accuracy Verification

Most importantly, the folded model produces the same results as the original: the validation loss and accuracy match the final training epoch exactly.

new_learn = Learner(dls, new_model, metrics=accuracy)
new_learn.validate()
[0.4338044822216034, 0.792963445186615]
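
You can also confirm equivalence at the output level directly, reusing the batch x from the timing cell above; tiny float differences from the re-ordered arithmetic are expected, hence the tolerance:

learn.model.eval(); new_model.eval()
with torch.no_grad():
    assert torch.allclose(learn.model(x), new_model(x), atol=1e-4)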

Summary

| Metric | Original | Folded | Improvement |
|---|---|---|---|
| Parameters | 11,177,538 | 11,172,738 | 4,800 fewer |
| Inference (single image) | 1.19 ms | 0.77 ms | ~35% faster |
| Accuracy | Baseline | Identical | No change |

When to Use BN Folding

| Scenario | Recommendation |
|---|---|
| Inference/deployment | ✅ Always fold - free speedup |
| Before quantization | ✅ Fold first - cleaner quantization |
| During training | ❌ Don't fold - BN helps training |
| Models without BN | N/A - nothing to fold |

See Also

  • Quantize Callback - Apply quantization after folding for maximum compression
  • ONNX Exporter - Export folded models to ONNX for deployment
  • Pruner - Combine with pruning for smaller, faster models