Prune Callback

Use the pruner within the fastai Callback system

Overview

Structured Pruning removes entire filters, channels, or layers from neural networks, resulting in genuinely smaller and faster models. Unlike sparsification (which zeros individual weights), pruning physically removes parameters.
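To make the distinction concrete, here is a minimal plain-PyTorch sketch (illustrative only, not the callback's internals): zeroing weights leaves the layer's shape unchanged, while removing a filter produces a genuinely smaller layer.

import torch
import torch.nn as nn

conv = nn.Conv2d(16, 32, kernel_size=3)

# Sparsification: zero out a filter in place; the weight tensor keeps its shape.
masked = nn.Conv2d(16, 32, kernel_size=3)
masked.weight.data = conv.weight.data.clone()
masked.weight.data[0].zero_()                    # filter 0 is zeroed but still stored
print(masked.weight.shape)                       # torch.Size([32, 16, 3, 3])

# Structured pruning: rebuild the layer without filter 0; the tensor shrinks.
keep = torch.arange(1, 32)
pruned = nn.Conv2d(16, 31, kernel_size=3)
pruned.weight.data = conv.weight.data[keep].clone()
pruned.bias.data = conv.bias.data[keep].clone()
print(pruned.weight.shape)                       # torch.Size([31, 16, 3, 3])

In a full network, every layer that consumes the pruned output (the following convolution, its batch norm, residual connections) must be trimmed to match; tracking those dependencies is the bookkeeping the callback handles for you during training.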

Why Use Structured Pruning?

Approach           | What's Removed     | Model Size | Speed Benefit            | Hardware
Sparsification     | Individual weights | Same       | Requires sparse support  | Specialized
Structured Pruning | Entire filters     | Smaller    | Immediate                | Standard

Key Benefits

  • Real speedup - Fewer parameters = faster inference on any hardware
  • Smaller models - Reduced memory footprint for deployment
  • Gradual pruning - Remove filters progressively during training
  • Flexible targeting - Global or local pruning strategies

1. Setup and Baseline

from fastai.vision.all import *
import torch_pruning as tp  # used below to count MACs and parameters
from fasterai.prune.all import *  # assumed import path for PruneCallback, large_final, one_cycle

path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

def label_func(f): return f[0].isupper()  # cat breeds start with an uppercase letter

dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))

First, train a baseline ResNet-18 to establish a performance reference:

learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
learn.fit_one_cycle(1)
epoch train_loss valid_loss accuracy time
0 0.573910 0.346901 0.848444 00:02
# Count the baseline MACs and parameters with torch-pruning's utility
base_macs, base_params = tp.utils.count_ops_and_params(learn.model, torch.randn(1,3,224,224).to(default_device()))
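
To sanity-check the baseline, the counts returned above can be printed in human-readable units (a minimal sketch using the base_macs and base_params values just computed):

print(f"Baseline: {base_macs/1e9:.2f} GMACs, {base_params/1e6:.2f}M parameters")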

2. Training with PruneCallback

Now let’s train with gradual filter pruning. We’ll remove 40% of filters using a one-cycle schedule:

Configuration:

  • pruning_ratio=0.4 - Remove 40% of the filters
  • context='global' - Remove the least important filters from anywhere in the network
  • criteria=large_final - Keep filters with the largest final weights
  • schedule=one_cycle - Gradually increase pruning following a one-cycle pattern

pr_cb = PruneCallback(pruning_ratio=0.4, context='global', criteria=large_final, schedule=one_cycle)
learn.fit_one_cycle(10, cbs=pr_cb)
Ignoring output layer: 1.8
Total ignored layers: 1
epoch train_loss valid_loss accuracy time
0 0.350107 0.277946 0.875507 00:02
1 0.270526 0.309414 0.881597 00:03
2 0.247778 0.240875 0.903924 00:03
3 0.224332 0.608088 0.708390 00:03
4 0.193209 0.221060 0.897835 00:03
5 0.249345 0.259771 0.895805 00:04
6 0.266264 0.265805 0.890392 00:04
7 0.234256 0.263015 0.888363 00:02
8 0.224429 0.255041 0.890392 00:02
9 0.196133 0.255395 0.892422 00:03
Sparsity at the end of epoch 0: 0.39%
Sparsity at the end of epoch 1: 1.54%
Sparsity at the end of epoch 2: 5.60%
Sparsity at the end of epoch 3: 15.91%
Sparsity at the end of epoch 4: 29.13%
Sparsity at the end of epoch 5: 36.64%
Sparsity at the end of epoch 6: 39.12%
Sparsity at the end of epoch 7: 39.79%
Sparsity at the end of epoch 8: 39.96%
Sparsity at the end of epoch 9: 40.00%
# Re-count MACs and parameters now that filters have been physically removed
pruned_macs, pruned_params = tp.utils.count_ops_and_params(learn.model, torch.randn(1,3,224,224).to(default_device()))

3. Measuring Compression

The pruned model has fewer parameters and requires less compute:

print(f'The pruned model has {pruned_macs/base_macs:.2f}x the compute of the original model')
The pruned model has 0.63x the compute of the original model
print(f'The pruned model has {pruned_params/base_params:.2f}x the parameters of the original model')
The pruned model has 0.18x the parameters of the original model
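
Because filters are physically removed, the reduced compute translates into faster dense inference without any sparse kernels. A rough timing sketch (illustrative only; the synchronize calls are needed for meaningful numbers on a GPU):

import time

x = torch.randn(16, 3, 224, 224).to(default_device())
learn.model.eval()
with torch.no_grad():
    if torch.cuda.is_available(): torch.cuda.synchronize()
    start = time.perf_counter()
    _ = learn.model(x)
    if torch.cuda.is_available(): torch.cuda.synchronize()
print(f"Forward pass on a batch of 16: {(time.perf_counter() - start)*1000:.1f} ms")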

Summary

Metric         | Original | Pruned (40%) | Improvement
Parameters     | 100%     | ~18%         | 5.5x smaller
Compute (MACs) | 100%     | ~63%         | 1.6x fewer ops
Accuracy       | Baseline | ~1% drop     | Minimal impact
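
Since structured pruning changes the module's actual layer shapes, the pruned weights can no longer be loaded back into the original ResNet-18 definition from a state_dict alone. A simple way to keep the result (a sketch; the file name is arbitrary) is to save the full module object, or export the whole learner with fastai's learn.export():

torch.save(learn.model, 'resnet18_pruned.pth')                   # pickles the pruned architecture + weights
pruned = torch.load('resnet18_pruned.pth', weights_only=False)   # reload the full module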

Parameter Reference

Parameter     | Description                          | Example
pruning_ratio | Fraction of filters to remove        | 0.4 (remove 40%)
context       | Pruning scope                        | 'global' (whole model) or 'local' (per-layer)
criteria      | Importance measure                   | large_final, magnitude, taylor
schedule      | How pruning increases over training  | one_cycle, cos, linear
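
The same callback can also prune each layer independently. Below is a sketch of a per-layer run with a linear ramp, using the names listed in the table above (parameter values chosen purely for illustration):

# 'local' context prunes 30% of the filters within every prunable layer;
# the linear schedule ramps the ratio up evenly across training.
pr_cb_local = PruneCallback(pruning_ratio=0.3, context='local',
                            criteria=large_final, schedule=linear)
learn.fit_one_cycle(5, cbs=pr_cb_local)  # run on a freshly trained learner, not the already-pruned one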

See Also