# Prune Callback

> Use the pruner in fastai Callback system
## Overview
Structured pruning removes entire filters, channels, or layers from a neural network, producing a genuinely smaller and faster model. Unlike sparsification, which zeroes individual weights while leaving tensor shapes intact, pruning physically removes parameters.
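The distinction is easy to see on a single convolution. Here is a minimal PyTorch sketch (illustrative only, not fasterai's implementation):

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(16, 32, kernel_size=3)

# Sparsification: zero out individual weights; shapes are unchanged,
# so standard hardware still performs the full computation
with torch.no_grad():
    mask = (torch.rand_like(conv.weight) > 0.5).float()
    conv.weight.mul_(mask)                     # weight is still (32, 16, 3, 3)

# Structured pruning: drop whole filters and build a genuinely smaller layer
keep = list(range(0, 32, 2))                   # indices of the 16 filters we keep
smaller = nn.Conv2d(16, len(keep), kernel_size=3)
with torch.no_grad():
    smaller.weight.copy_(conv.weight[keep])    # weight is now (16, 16, 3, 3)
    smaller.bias.copy_(conv.bias[keep])
```

Any layer consuming the pruned layer's output must be adjusted to the new channel count; that bookkeeping is what the callback handles for you during training.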
## Why Use Structured Pruning?
| Approach | What’s Removed | Model Size | Speed Benefit | Hardware |
|---|---|---|---|---|
| Sparsification | Individual weights | Same | Requires sparse support | Specialized |
| Structured Pruning | Entire filters | Smaller | Immediate | Standard |
### Key Benefits
- **Real speedup**: fewer parameters mean faster inference on any hardware
- **Smaller models**: reduced memory footprint for deployment
- **Gradual pruning**: remove filters progressively during training
- **Flexible targeting**: global or local pruning strategies
## 1. Setup and Baseline

First, load the Oxford-IIIT Pets dataset and train a one-epoch baseline ResNet-18 to establish the expected performance:

```python
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

# Cat breeds are capitalized in the filenames, dog breeds are not
def label_func(f): return f[0].isupper()

dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))
```

```python
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
learn.fit_one_cycle(1)
```

| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.573910 | 0.346901 | 0.848444 | 00:02 |
Count the baseline model's operations (MACs) and parameters with the `tp` (torch-pruning) utility:

```python
base_macs, base_params = tp.utils.count_ops_and_params(learn.model, torch.randn(1,3,224,224).to(default_device()))
```

## 2. Training with PruneCallback
Now let's train with gradual filter pruning, removing 40% of the filters over a one-cycle schedule.

Configuration:

- `pruning_ratio=0.4`: remove 40% of filters
- `context='global'`: remove the least important filters from anywhere in the network
- `criteria=large_final`: keep the filters with the largest final weights
- `schedule=one_cycle`: ramp pruning up gradually following a one-cycle pattern
```python
pr_cb = PruneCallback(pruning_ratio=0.4, context='global', criteria=large_final, schedule=one_cycle)
learn.fit_one_cycle(10, cbs=pr_cb)
```

```
Ignoring output layer: 1.8
Total ignored layers: 1
```

The callback automatically skips the output layer (here `1.8`, the classifier head), so the number of output classes is preserved.
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.350107 | 0.277946 | 0.875507 | 00:02 |
| 1 | 0.270526 | 0.309414 | 0.881597 | 00:03 |
| 2 | 0.247778 | 0.240875 | 0.903924 | 00:03 |
| 3 | 0.224332 | 0.608088 | 0.708390 | 00:03 |
| 4 | 0.193209 | 0.221060 | 0.897835 | 00:03 |
| 5 | 0.249345 | 0.259771 | 0.895805 | 00:04 |
| 6 | 0.266264 | 0.265805 | 0.890392 | 00:04 |
| 7 | 0.234256 | 0.263015 | 0.888363 | 00:02 |
| 8 | 0.224429 | 0.255041 | 0.890392 | 00:02 |
| 9 | 0.196133 | 0.255395 | 0.892422 | 00:03 |
```
Sparsity at the end of epoch 0: 0.39%
Sparsity at the end of epoch 1: 1.54%
Sparsity at the end of epoch 2: 5.60%
Sparsity at the end of epoch 3: 15.91%
Sparsity at the end of epoch 4: 29.13%
Sparsity at the end of epoch 5: 36.64%
Sparsity at the end of epoch 6: 39.12%
Sparsity at the end of epoch 7: 39.79%
Sparsity at the end of epoch 8: 39.96%
Sparsity at the end of epoch 9: 40.00%
```
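Notice the one-cycle shape in the log: pruning starts gently, accelerates mid-training, and saturates at the 40% target. A rough sketch of a schedule with this shape (`one_cycle_like`, its steepness, and its midpoint are made up here for illustration; this is not fasterai's exact formula):

```python
import math

# Hypothetical S-shaped ramp from ~0 to `target` as t goes from 0 to 1
def one_cycle_like(t, target=0.40, steepness=14, midpoint=0.45):
    return target / (1 + math.exp(-steepness * (t - midpoint)))

for epoch in range(1, 11):
    print(f'end of epoch {epoch}: {one_cycle_like(epoch / 10):.2%}')
```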
Re-count the MACs and parameters of the pruned model:

```python
pruned_macs, pruned_params = tp.utils.count_ops_and_params(learn.model, torch.randn(1,3,224,224).to(default_device()))
```

## 3. Measuring Compression
The pruned model has fewer parameters and requires less compute:
```python
print(f'The pruned model has {pruned_macs/base_macs:.2f} the compute of original model')
print(f'The pruned model has {pruned_params/base_params:.2f} the parameters of original model')
```

```
The pruned model has 0.63 the compute of original model
The pruned model has 0.18 the parameters of original model
```
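The compression is physical: the convolution layers now have fewer output channels. A quick (hypothetical) way to inspect this:

```python
import torch.nn as nn

# Print the remaining filter count of each convolution; with context='global'
# the counts vary per layer, since filters are removed wherever they matter least
for name, module in learn.model.named_modules():
    if isinstance(module, nn.Conv2d):
        print(f'{name}: {module.out_channels} filters')
```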
## Summary
| Metric | Original | Pruned (40%) | Improvement |
|---|---|---|---|
| Parameters | 100% | ~18% | 5.5x smaller |
| Compute (MACs) | 100% | ~63% | 1.6x fewer ops |
| Accuracy | 84.8% (1 epoch) | 89.2% (10 epochs) | No degradation |
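To check that the MACs reduction shows up as wall-clock speedup on your hardware, you can time a forward pass; `time_inference` below is a hypothetical helper, not part of fasterai:

```python
import time
import torch

def time_inference(model, n=50, size=(1, 3, 224, 224)):
    "Rough average forward-pass latency in milliseconds."
    x = torch.randn(size).to(default_device())
    model.eval()
    with torch.no_grad():
        for _ in range(5):                      # warm-up iterations
            model(x)
        if torch.cuda.is_available():
            torch.cuda.synchronize()            # make GPU timing meaningful
        start = time.perf_counter()
        for _ in range(n):
            model(x)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
    return (time.perf_counter() - start) / n * 1000

print(f'{time_inference(learn.model):.1f} ms per image')
```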
## Parameter Reference

| Parameter | Description | Example |
|---|---|---|
| `pruning_ratio` | Fraction of filters to remove | `0.4` |
| `context` | Pruning scope | `'global'` (whole model) or `'local'` (per-layer) |
| `criteria` | Filter importance measure | `large_final`, `magnitude`, `taylor` |
| `schedule` | How pruning ramps up over training | `one_cycle`, `cos`, `linear` |
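For example, a per-layer variant that removes 30% of the filters in every layer with a cosine ramp (hypothetical settings, combining the options listed above):

```python
# Prune each layer independently ('local'), ranking filters by magnitude
pr_cb_local = PruneCallback(pruning_ratio=0.3, context='local', criteria=magnitude, schedule=cos)
learn.fit_one_cycle(5, cbs=pr_cb_local)
```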
## See Also
- Pruner - Lower-level API for one-shot pruning
- Sparsifier - For unstructured sparsification
- Schedules - Available pruning schedules
- Criteria - Filter importance measures
- YOLO Pruning Tutorial - Pruning detection models