# Pruner Tutorial

```python
import torch
import torch.nn as nn
from torchvision.models import resnet18
from fasterai.prune.pruner import Pruner
from fasterai.core.criteria import large_final, random
```

## Overview
The `Pruner` class performs structured pruning: it physically removes entire filters and channels from your neural network. Unlike sparsification (which zeroes weights but keeps the architecture intact), structured pruning creates a genuinely smaller model that runs faster on standard hardware.
## Sparsifier vs Pruner
| Aspect | Sparsifier | Pruner |
|---|---|---|
| What it removes | Individual weights → zeros | Entire filters → gone |
| Architecture | Unchanged (same shapes) | Smaller (fewer channels) |
| Speedup | Needs sparse hardware | Immediate on any hardware |
| Use case | Research, sparse accelerators | Production deployment |
When to use Pruner:

- You need a smaller model file
- You want faster inference without special hardware
- You’re deploying to edge devices or mobile
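The distinction can be seen without fasterai at all. A minimal sketch in plain PyTorch (toy layer sizes of my own choosing): zeroing filters leaves the parameter count untouched, while actually slicing them out shrinks the layer.

```python
import torch
import torch.nn as nn

# Sparsification: zero half the filters; the tensor shapes are unchanged
conv = nn.Conv2d(16, 32, kernel_size=3)
with torch.no_grad():
    conv.weight[:16] = 0.0                              # zero 16 of the 32 filters
print(sum(p.numel() for p in conv.parameters()))        # 4640 - same count as before

# Structured pruning: actually drop those 16 filters -> a smaller layer
pruned = nn.Conv2d(16, 16, kernel_size=3)
with torch.no_grad():
    pruned.weight.copy_(conv.weight[16:])               # keep only the surviving filters
    pruned.bias.copy_(conv.bias[16:])
print(sum(p.numel() for p in pruned.parameters()))      # 2320 - roughly half
```

The zeroed layer still stores (and multiplies) every weight; only the sliced layer actually saves memory and compute on standard hardware.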
## 1. Basic Pruning
Let’s start with a ResNet18 and prune 30% of its filters:
```python
model = resnet18(weights=None)

print('Before pruning:')
print(f'  conv1: {model.conv1}')
print(f'  layer1[0].conv1: {model.layer1[0].conv1}')
print(f'  Parameters: {sum(p.numel() for p in model.parameters()):,}')
```

```
Before pruning:
  conv1: Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  layer1[0].conv1: Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  Parameters: 11,689,512
```
```python
pruner = Pruner(
    model,
    pruning_ratio=30,       # Remove 30% of filters
    context='local',        # Prune each layer independently
    criteria=large_final,   # Keep filters with largest weights
)
pruner.prune_model()

print('\nAfter pruning:')
print(f'  conv1: {model.conv1}')
print(f'  layer1[0].conv1: {model.layer1[0].conv1}')
params_after = sum(p.numel() for p in model.parameters())
print(f'  Parameters: {params_after:,}')
print(f'  Reduction: {100*(1 - params_after/11689512):.1f}%')
```

```
Ignoring output layer: fc
Total ignored layers: 1

After pruning:
  conv1: Conv2d(3, 44, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  layer1[0].conv1: Conv2d(44, 44, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  Parameters: 5,820,556
  Reduction: 50.2%
```
Notice the channel counts changed: `Conv2d(3, 64, ...)` became `Conv2d(3, 44, ...)`. The model is genuinely smaller!

**Key point:** The Pruner automatically handles layer dependencies. When you remove output channels from one layer, it removes the corresponding input channels from the next layer.
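What "handles dependencies" means can be sketched by hand for two stacked convolutions (a toy example of the bookkeeping involved, not fasterai's internals): keeping a subset of the first conv's filters forces the matching input-channel slice of the second.

```python
import torch
import torch.nn as nn

conv_a = nn.Conv2d(3, 8, kernel_size=3, padding=1)
conv_b = nn.Conv2d(8, 16, kernel_size=3, padding=1)

keep = [0, 2, 4, 6]  # indices of the filters we keep in conv_a

# Slice conv_a's output channels...
conv_a_small = nn.Conv2d(3, len(keep), kernel_size=3, padding=1)
with torch.no_grad():
    conv_a_small.weight.copy_(conv_a.weight[keep])
    conv_a_small.bias.copy_(conv_a.bias[keep])

# ...and the matching input channels of conv_b, or the shapes no longer line up
conv_b_small = nn.Conv2d(len(keep), 16, kernel_size=3, padding=1)
with torch.no_grad():
    conv_b_small.weight.copy_(conv_b.weight[:, keep])
    conv_b_small.bias.copy_(conv_b.bias)

x = torch.randn(1, 3, 32, 32)
out = conv_b_small(conv_a_small(x))
print(out.shape)  # torch.Size([1, 16, 32, 32])
```

In a real network this propagation also runs through batch norm, residual connections, and concatenations, which is exactly the bookkeeping the Pruner automates.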
## 2. Local vs Global Pruning
The context parameter controls how filters are selected for pruning:
| Context | Behavior | Best for |
|---|---|---|
| `'local'` | Each layer loses same % of filters | Uniform compression |
| `'global'` | Compare importance across all layers | Maximum accuracy retention |
```python
# Local: each layer pruned independently to 50%
model_local = resnet18(weights=None)
pruner = Pruner(model_local, 50, 'local', large_final)
pruner.prune_model()

print('\nLocal pruning (each layer loses 50%):')
print(f'  layer1[0].conv1: Conv2d({model_local.layer1[0].conv1.in_channels}, {model_local.layer1[0].conv1.out_channels}, ...)')
print(f'  layer2[0].conv1: Conv2d({model_local.layer2[0].conv1.in_channels}, {model_local.layer2[0].conv1.out_channels}, ...)')
print(f'  layer3[0].conv1: Conv2d({model_local.layer3[0].conv1.in_channels}, {model_local.layer3[0].conv1.out_channels}, ...)')
print(f'  layer4[0].conv1: Conv2d({model_local.layer4[0].conv1.in_channels}, {model_local.layer4[0].conv1.out_channels}, ...)')
```

```
Ignoring output layer: fc
Total ignored layers: 1

Local pruning (each layer loses 50%):
  layer1[0].conv1: Conv2d(32, 32, ...)
  layer2[0].conv1: Conv2d(32, 64, ...)
  layer3[0].conv1: Conv2d(64, 128, ...)
  layer4[0].conv1: Conv2d(128, 256, ...)
```
```python
# Global: least important filters across entire network
model_global = resnet18(weights=None)
pruner = Pruner(model_global, 50, 'global', large_final)
pruner.prune_model()

print('\nGlobal pruning (importance compared across layers):')
print(f'  layer1[0].conv1: Conv2d({model_global.layer1[0].conv1.in_channels}, {model_global.layer1[0].conv1.out_channels}, ...)')
print(f'  layer2[0].conv1: Conv2d({model_global.layer2[0].conv1.in_channels}, {model_global.layer2[0].conv1.out_channels}, ...)')
print(f'  layer3[0].conv1: Conv2d({model_global.layer3[0].conv1.in_channels}, {model_global.layer3[0].conv1.out_channels}, ...)')
print(f'  layer4[0].conv1: Conv2d({model_global.layer4[0].conv1.in_channels}, {model_global.layer4[0].conv1.out_channels}, ...)')
```

```
Ignoring output layer: fc
Total ignored layers: 1

Global pruning (importance compared across layers):
  layer1[0].conv1: Conv2d(64, 64, ...)
  layer2[0].conv1: Conv2d(64, 128, ...)
  layer3[0].conv1: Conv2d(128, 69, ...)
  layer4[0].conv1: Conv2d(256, 512, ...)
```
With global pruning, early layers often keep more filters (they’re more important) while later layers with redundant features get pruned more aggressively.
## 3. Iterative Pruning
For high compression ratios, iterative pruning works better than one-shot. The model gradually adapts to having fewer parameters:
```python
model = resnet18(weights=None)
params_orig = sum(p.numel() for p in model.parameters())

# Iterative pruning: 5 steps to reach 50% pruning
pruner = Pruner(
    model,
    pruning_ratio=50,
    context='local',
    criteria=large_final,
    iterative_steps=5,  # Spread pruning over 5 steps
)

print('Iterative pruning (5 steps to reach 50%):')
for i in range(5):
    pruner.prune_model()
    params = sum(p.numel() for p in model.parameters())
    print(f'  Step {i+1}: {params:,} params ({100*(1-params/params_orig):.1f}% reduction)')
```

```
Ignoring output layer: fc
Total ignored layers: 1
Iterative pruning (5 steps to reach 50%):
  Step 1: 9,481,588 params (18.9% reduction)
  Step 2: 7,534,380 params (35.5% reduction)
  Step 3: 5,820,556 params (50.2% reduction)
  Step 4: 4,318,898 params (63.1% reduction)
  Step 5: 3,055,880 params (73.9% reduction)
```
In practice: When using PruneCallback during training, iterative pruning happens automatically - the model is pruned a little bit after each batch, allowing it to recover between steps.
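The step-by-step numbers above follow from simple arithmetic. Assuming each call removes roughly another 10 percentage points of the original channels, and noting that a conv's parameter count shrinks with both its input and output channel counts, the reduction is roughly quadratic in the removed fraction. A back-of-the-envelope check (ignoring the unpruned first conv and `fc` layer):

```python
# Each step keeps another 10% fewer of the original channels; conv parameters
# scale ~ kept**2 because input AND output channel counts both shrink
for step in range(1, 6):
    kept = 1 - 0.1 * step
    print(f'Step {step}: ~{100 * (1 - kept**2):.0f}% parameter reduction')
# Prints ~19%, 36%, 51%, 64%, 75% - close to the measured
# 18.9%, 35.5%, 50.2%, 63.1%, 73.9%
```

This also explains why a 30% channel-pruning ratio in section 1 produced a 50.2% parameter reduction: parameters fall faster than channel counts.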
## 4. Per-Layer Pruning Ratios
Different layers have different sensitivity to pruning. You can specify custom ratios using a dictionary:
```python
model = resnet18(weights=None)

# Conservative on early layers, aggressive on later layers
per_layer_ratios = {
    'layer1.0.conv1': 20, 'layer1.0.conv2': 20,  # 20% pruning
    'layer2.0.conv1': 40, 'layer2.0.conv2': 40,  # 40% pruning
    'layer3.0.conv1': 60, 'layer3.0.conv2': 60,  # 60% pruning
    'layer4.0.conv1': 80, 'layer4.0.conv2': 80,  # 80% pruning
}

pruner = Pruner(model, per_layer_ratios, 'local', large_final)
pruner.prune_model()

print('\nPer-layer pruning results:')
print(f'  layer1.0.conv1: {model.layer1[0].conv1.out_channels} channels (20% pruned from 64)')
print(f'  layer2.0.conv1: {model.layer2[0].conv1.out_channels} channels (40% pruned from 128)')
print(f'  layer3.0.conv1: {model.layer3[0].conv1.out_channels} channels (60% pruned from 256)')
print(f'  layer4.0.conv1: {model.layer4[0].conv1.out_channels} channels (80% pruned from 512)')
```

```
Ignoring output layer: fc
Total ignored layers: 1
Using per-layer pruning with 8 layer-specific ratios

Per-layer pruning results:
  layer1.0.conv1: 51 channels (20% pruned from 64)
  layer2.0.conv1: 76 channels (40% pruned from 128)
  layer3.0.conv1: 102 channels (60% pruned from 256)
  layer4.0.conv1: 102 channels (80% pruned from 512)
```
Tip: Use sensitivity analysis to determine which layers can tolerate more pruning. See the Sensitivity Tutorial for details.
## 5. Verifying the Pruned Model
After pruning, the model remains fully functional; it simply has fewer parameters:
```python
model = resnet18(weights=None)
model.eval()

# Prune 50%
pruner = Pruner(model, 50, 'global', large_final)
pruner.prune_model()

# Verify forward pass works
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    output = model(x)

print('\nForward pass verification:')
print(f'  Input shape: {x.shape}')
print(f'  Output shape: {output.shape}')
print('  Model works correctly after pruning!')
```

```
Ignoring output layer: fc
Total ignored layers: 1

Forward pass verification:
  Input shape: torch.Size([1, 3, 224, 224])
  Output shape: torch.Size([1, 1000])
  Model works correctly after pruning!
```
## 6. Importance Criteria
The criteria parameter determines how filter importance is calculated:
| Criteria | Method | Best for |
|---|---|---|
| `large_final` | Keep filters with largest L1 norm | General use, most common |
| `small_final` | Keep filters with smallest L1 norm | Unusual, for experimentation |
| `random` | Random selection | Baseline comparison |
```python
results = {}
for name, criteria in [('large_final', large_final), ('random', random)]:
    model = resnet18(weights=None)
    pruner = Pruner(model, 30, 'local', criteria)
    pruner.prune_model()
    results[name] = sum(p.numel() for p in model.parameters())

print('\nSame pruning ratio, different criteria:')
for name, params in results.items():
    print(f'  {name}: {params:,} parameters')
print('\nNote: Parameter counts are similar, but accuracy differs!')
print('large_final preserves important filters, random does not.')
```

```
Ignoring output layer: fc
Total ignored layers: 1
Ignoring output layer: fc
Total ignored layers: 1

Same pruning ratio, different criteria:
  large_final: 5,820,556 parameters
  random: 5,820,556 parameters

Note: Parameter counts are similar, but accuracy differs!
large_final preserves important filters, random does not.
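Why the counts match exactly: under local pruning, every criterion removes the same *number* of filters per layer; they differ only in *which* filters go. A toy illustration with eight hand-picked filter norms:

```python
import torch

norms = torch.tensor([0.1, 2.3, 0.4, 1.8, 0.9, 3.1, 0.2, 1.1])

# large_final-style: keep the 4 filters with the largest norms
large_keep = norms.argsort(descending=True)[:4]
# random-style: keep any 4
rand_keep = torch.randperm(8)[:4]

print(sorted(large_keep.tolist()))        # [1, 3, 5, 7]
print(len(large_keep) == len(rand_keep))  # True: same size, different members
```

The architecture (and so the parameter count) depends only on how many filters survive; accuracy depends on which ones do.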
## Summary
| Feature | Description |
|---|---|
| Structured pruning | Removes entire filters, creating genuinely smaller models |
| Local context | Each layer pruned by same percentage |
| Global context | Compare importance across all layers |
| Iterative pruning | Gradual pruning for better accuracy retention |
| Per-layer ratios | Dictionary of custom ratios per layer |
| Auto dependency | Handles layer connections automatically |
## Typical Workflow
```python
# 1. One-shot pruning for quick experiments
pruner = Pruner(model, 30, 'local', large_final)
pruner.prune_model()

# 2. During training with PruneCallback (recommended)
cb = PruneCallback(pruning_ratio=50, schedule=agp, context='global', criteria=large_final)
learn.fit(10, cbs=[cb])
```

## See Also
- PruneCallback Tutorial - Apply pruning during fastai training
- Sparsifier Tutorial - Unstructured pruning alternative
- Criteria - Importance measures for filter selection
- Schedules - Control pruning progression