Pruner Tutorial

Structured pruning to create smaller, faster models

Overview

The Pruner class performs structured pruning: it physically removes entire filters and channels from your neural network. Unlike sparsification (which zeros weights but keeps the architecture), structured pruning creates a genuinely smaller model that runs faster on standard hardware.

Sparsifier vs Pruner

| Aspect | Sparsifier | Pruner |
|---|---|---|
| What it removes | Individual weights → zeros | Entire filters → gone |
| Architecture | Unchanged (same shapes) | Smaller (fewer channels) |
| Speedup | Needs sparse hardware | Immediate on any hardware |
| Use case | Research, sparse accelerators | Production deployment |

When to use Pruner:
- You need a smaller model file
- You want faster inference without special hardware
- You're deploying to edge devices or mobile

import torch
import torch.nn as nn
from torchvision.models import resnet18
from fasterai.prune.pruner import Pruner
from fasterai.core.criteria import large_final, random

1. Basic Pruning

Let’s start with a ResNet18 and prune 30% of its filters:

model = resnet18(weights=None)

print('Before pruning:')
print(f'  conv1: {model.conv1}')
print(f'  layer1[0].conv1: {model.layer1[0].conv1}')
print(f'  Parameters: {sum(p.numel() for p in model.parameters()):,}')
Before pruning:
  conv1: Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  layer1[0].conv1: Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  Parameters: 11,689,512
pruner = Pruner(
    model,
    pruning_ratio=30,      # Remove 30% of filters
    context='local',       # Prune each layer independently
    criteria=large_final   # Keep filters with the largest L1 norm
)
pruner.prune_model()

print('\nAfter pruning:')
print(f'  conv1: {model.conv1}')
print(f'  layer1[0].conv1: {model.layer1[0].conv1}')
params_after = sum(p.numel() for p in model.parameters())
print(f'  Parameters: {params_after:,}')
print(f'  Reduction: {100*(1 - params_after/11689512):.1f}%')
Ignoring output layer: fc
Total ignored layers: 1

After pruning:
  conv1: Conv2d(3, 44, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  layer1[0].conv1: Conv2d(44, 44, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  Parameters: 5,820,556
  Reduction: 50.2%

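Notice the channel counts changed: Conv2d(3, 64, ...) became Conv2d(3, 44, ...), so the model is genuinely smaller. Removing 30% of the filters cut the parameter count by about 50% because a convolution has in_channels × out_channels × kernel_height × kernel_width weights, and pruning shrinks both channel dimensions of most layers (0.7 × 0.7 ≈ 0.49 of the weights remain).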
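The per-layer arithmetic can be checked with plain PyTorch. This sketch compares the 3×3 convolution in layer1[0].conv1 before (64 → 64 channels) and after (44 → 44 channels) pruning; the helper name conv_weight_count is our own, not part of fasterai.

```python
import torch.nn as nn

def conv_weight_count(c_in, c_out, k=3):
    """Number of weights in a bias-free k x k Conv2d: c_in * c_out * k * k."""
    return sum(p.numel() for p in nn.Conv2d(c_in, c_out, k, bias=False).parameters())

before = conv_weight_count(64, 64)  # 64 * 64 * 9
after = conv_weight_count(44, 44)   # 44 * 44 * 9
print(f'{before:,} -> {after:,} weights ({100 * (1 - after / before):.1f}% fewer)')
# 36,864 -> 17,424 weights (52.7% fewer)
```

The network-wide reduction (50.2%) is slightly lower than this single-layer figure because some dimensions never shrink, such as the 3 RGB input channels of conv1 and the 1000 outputs of the ignored fc layer.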

Key point: The Pruner automatically handles layer dependencies. When you remove output channels from one layer, it removes the corresponding input channels from the next layer.
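To make the dependency handling concrete, here is a minimal manual sketch for a single conv → BatchNorm → conv chain: dropping output filters from the first convolution also means dropping the matching BatchNorm channels and the matching input channels (weight slices along dim 1) of the next convolution. The helper prune_conv_bn_conv is hypothetical and assumes bias-free convolutions with groups=1, as in ResNet; it covers only one straight chain and ignores residual connections, which the Pruner has to handle across the whole network.

```python
import copy
import torch
import torch.nn as nn

def prune_conv_bn_conv(conv1, bn, conv2, keep):
    """Hypothetical helper: keep only output filters `keep` of conv1, plus the
    matching BatchNorm channels and the matching input channels of conv2.
    Assumes bias-free convolutions with groups=1 (as in ResNet)."""
    assert conv1.bias is None and conv2.bias is None
    n = len(keep)
    new_conv1 = nn.Conv2d(conv1.in_channels, n, conv1.kernel_size,
                          stride=conv1.stride, padding=conv1.padding, bias=False)
    new_bn = nn.BatchNorm2d(n, eps=bn.eps, momentum=bn.momentum)
    new_conv2 = nn.Conv2d(n, conv2.out_channels, conv2.kernel_size,
                          stride=conv2.stride, padding=conv2.padding, bias=False)
    with torch.no_grad():
        new_conv1.weight.copy_(conv1.weight[keep])     # dim 0 = output filters
        new_bn.weight.copy_(bn.weight[keep])
        new_bn.bias.copy_(bn.bias[keep])
        new_bn.running_mean.copy_(bn.running_mean[keep])
        new_bn.running_var.copy_(bn.running_var[keep])
        new_conv2.weight.copy_(conv2.weight[:, keep])  # dim 1 = input channels
    return new_conv1, new_bn, new_conv2

torch.manual_seed(0)
conv1 = nn.Conv2d(8, 16, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(16)
conv2 = nn.Conv2d(16, 8, 3, padding=1, bias=False)
original = nn.Sequential(conv1, bn, nn.ReLU(), conv2).eval()

# Keep the 11 filters of conv1 with the largest L1 norm (roughly 30% removed)
scores = conv1.weight.abs().sum(dim=(1, 2, 3))
keep = scores.topk(11).indices.sort().values
c1, b, c2 = prune_conv_bn_conv(conv1, bn, conv2, keep)
pruned = nn.Sequential(c1, b, nn.ReLU(), c2).eval()

x = torch.randn(2, 8, 16, 16)
print(pruned(x).shape)  # torch.Size([2, 8, 16, 16])

# Zeroing the dropped filters instead (what a sparsifier does) gives the same
# output here, because the fresh eval-mode BatchNorm maps an all-zero channel to zero.
masked = copy.deepcopy(original)
dropped = torch.ones(16, dtype=torch.bool)
dropped[keep] = False
with torch.no_grad():
    masked[0].weight[dropped] = 0
    print(torch.allclose(masked(x), pruned(x), atol=1e-5))  # True
```

The final check ties back to the Sparsifier vs Pruner comparison: both versions compute the same function, but only the pruned one actually has smaller tensors.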

2. Local vs Global Pruning

The context parameter controls how filters are selected for pruning:

| Context | Behavior | Best for |
|---|---|---|
| 'local' | Each layer loses the same % of filters | Uniform compression |
| 'global' | Importance is compared across all layers | Typically better accuracy retention |

# Local: each layer independently loses 50% of its filters
model_local = resnet18(weights=None)
pruner = Pruner(model_local, 50, 'local', large_final)
pruner.prune_model()

print('\nLocal pruning (each layer loses 50%):')
print(f'  layer1[0].conv1: Conv2d({model_local.layer1[0].conv1.in_channels}, {model_local.layer1[0].conv1.out_channels}, ...)')
print(f'  layer2[0].conv1: Conv2d({model_local.layer2[0].conv1.in_channels}, {model_local.layer2[0].conv1.out_channels}, ...)')
print(f'  layer3[0].conv1: Conv2d({model_local.layer3[0].conv1.in_channels}, {model_local.layer3[0].conv1.out_channels}, ...)')
print(f'  layer4[0].conv1: Conv2d({model_local.layer4[0].conv1.in_channels}, {model_local.layer4[0].conv1.out_channels}, ...)')
Ignoring output layer: fc
Total ignored layers: 1

Local pruning (each layer loses 50%):
  layer1[0].conv1: Conv2d(32, 32, ...)
  layer2[0].conv1: Conv2d(32, 64, ...)
  layer3[0].conv1: Conv2d(64, 128, ...)
  layer4[0].conv1: Conv2d(128, 256, ...)
# Global: remove the least important filters, ranked across the entire network
model_global = resnet18(weights=None)
pruner = Pruner(model_global, 50, 'global', large_final)
pruner.prune_model()

print('\nGlobal pruning (importance compared across layers):')
print(f'  layer1[0].conv1: Conv2d({model_global.layer1[0].conv1.in_channels}, {model_global.layer1[0].conv1.out_channels}, ...)')
print(f'  layer2[0].conv1: Conv2d({model_global.layer2[0].conv1.in_channels}, {model_global.layer2[0].conv1.out_channels}, ...)')
print(f'  layer3[0].conv1: Conv2d({model_global.layer3[0].conv1.in_channels}, {model_global.layer3[0].conv1.out_channels}, ...)')
print(f'  layer4[0].conv1: Conv2d({model_global.layer4[0].conv1.in_channels}, {model_global.layer4[0].conv1.out_channels}, ...)')
Ignoring output layer: fc
Total ignored layers: 1

Global pruning (importance compared across layers):
  layer1[0].conv1: Conv2d(64, 64, ...)
  layer2[0].conv1: Conv2d(64, 128, ...)
  layer3[0].conv1: Conv2d(128, 69, ...)
  layer4[0].conv1: Conv2d(256, 512, ...)

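With global pruning, the resulting widths are no longer uniform: in this run layer1 and layer4 kept all of their filters, while layer3[0].conv1 dropped from 256 to 69. Because these models are untrained (weights=None), the scores here only reflect the random initialization. On a trained network, layers whose filters matter more keep more of them, and layers with redundant filters are pruned more aggressively.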
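The difference between the two contexts comes down to where the threshold is computed. In this minimal sketch (hypothetical helpers local_keep and global_keep, operating on made-up per-filter importance scores rather than a real model, and not fasterai's implementation), local selection keeps the top filters within each layer, while global selection ranks every filter in the network against a single threshold:

```python
import torch

def local_keep(scores, ratio):
    """Hypothetical: every layer keeps its own top (1 - ratio) share of filters."""
    keep = {}
    for name, s in scores.items():
        n_keep = max(1, int(len(s) * (1 - ratio)))
        keep[name] = s.topk(n_keep).indices.sort().values
    return keep

def global_keep(scores, ratio):
    """Hypothetical: rank all filters together and keep those above one shared threshold."""
    all_scores = torch.cat(list(scores.values()))
    n_keep = max(1, int(len(all_scores) * (1 - ratio)))
    threshold = all_scores.topk(n_keep).values.min()
    return {name: (s >= threshold).nonzero().flatten() for name, s in scores.items()}

# Made-up per-filter importance scores for two layers on different scales
scores = {
    'early': torch.tensor([5.0, 4.0, 3.0, 2.0]),
    'late': torch.tensor([0.5, 0.4, 0.3, 0.2]),
}
print({k: len(v) for k, v in local_keep(scores, 0.5).items()})   # {'early': 2, 'late': 2}
print({k: len(v) for k, v in global_keep(scores, 0.5).items()})  # {'early': 4, 'late': 0}
```

The toy output also shows a risk of naive global ranking: if one layer's scores live on a smaller scale, it can lose every filter. A practical implementation needs a per-layer minimum or per-layer score normalization; this sketch does neither.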

3. Iterative Pruning

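For high compression ratios, iterative pruning usually works better than one-shot pruning, provided the model is fine-tuned between steps so it can gradually adapt to having fewer parameters. The example below only shows the mechanics (no training happens between steps): with iterative_steps=5, each call to prune_model() advances one step toward the final pruning_ratio: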

model = resnet18(weights=None)
params_orig = sum(p.numel() for p in model.parameters())

# Iterative pruning: 5 steps to reach 50% pruning
pruner = Pruner(
    model, 
    pruning_ratio=50,
    context='local', 
    criteria=large_final,
    iterative_steps=5  # Spread pruning over 5 steps
)

print('Iterative pruning (5 steps to reach 50%):')
for i in range(5):
    pruner.prune_model()
    params = sum(p.numel() for p in model.parameters())
    print(f'  Step {i+1}: {params:,} params ({100*(1-params/params_orig):.1f}% reduction)')
Ignoring output layer: fc
Total ignored layers: 1
Iterative pruning (5 steps to reach 50%):
  Step 1: 9,481,588 params (18.9% reduction)
  Step 2: 7,534,380 params (35.5% reduction)
  Step 3: 5,820,556 params (50.2% reduction)
  Step 4: 4,318,898 params (63.1% reduction)
  Step 5: 3,055,880 params (73.9% reduction)

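Note that pruning_ratio counts filters, not parameters: step 3 reproduces the one-shot 30% result from Basic Pruning (5,820,556 parameters), and step 5 removes 50% of the filters, which removes about 74% of the parameters.

In practice: when using PruneCallback during training, iterative pruning happens automatically; the model is pruned a little bit after each batch, allowing it to recover between steps.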
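The per-step counts above are consistent with an evenly spaced schedule (10%, 20%, ... 50% of filters). The sketch below computes such a linear schedule alongside a cubic schedule in the style of AGP (Zhu & Gupta, 2017), which front-loads pruning while the network still has spare capacity. Both function names are hypothetical, and this is not necessarily how fasterai's agp schedule is implemented. The parameter estimate (1 - r)² ignores dimensions that never shrink, which is why it slightly overshoots the measured reductions.

```python
def linear_schedule(final_ratio, steps):
    """Hypothetical: % of filters removed after each step, growing evenly."""
    return [final_ratio * (i + 1) / steps for i in range(steps)]

def agp_schedule(final_ratio, steps, initial_ratio=0.0):
    """Hypothetical cubic schedule in the style of AGP:
    r_t = r_f + (r_i - r_f) * (1 - t / n) ** 3, so most pruning happens early."""
    return [final_ratio + (initial_ratio - final_ratio) * (1 - (i + 1) / steps) ** 3
            for i in range(steps)]

for step, (lin, cubic) in enumerate(zip(linear_schedule(50, 5), agp_schedule(50, 5)), start=1):
    # Conv weights scale with in_channels * out_channels, so roughly (1 - r)^2 remain
    est = 100 * (1 - (1 - lin / 100) ** 2)
    print(f'step {step}: linear {lin:.0f}% filters (~{est:.0f}% fewer params), '
          f'cubic {cubic:.1f}% filters')
```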

4. Per-Layer Pruning Ratios

Different layers have different sensitivity to pruning. You can specify custom ratios with a dictionary that maps module names (as listed by model.named_modules()) to pruning percentages:

model = resnet18(weights=None)

# Conservative on early layers, aggressive on later layers
per_layer_ratios = {
    'layer1.0.conv1': 20,  'layer1.0.conv2': 20,  # 20% pruning
    'layer2.0.conv1': 40,  'layer2.0.conv2': 40,  # 40% pruning  
    'layer3.0.conv1': 60,  'layer3.0.conv2': 60,  # 60% pruning
    'layer4.0.conv1': 80,  'layer4.0.conv2': 80,  # 80% pruning
}

pruner = Pruner(model, per_layer_ratios, 'local', large_final)
pruner.prune_model()

print('\nPer-layer pruning results:')
print(f'  layer1.0.conv1: {model.layer1[0].conv1.out_channels} channels (20% pruned from 64)')
print(f'  layer2.0.conv1: {model.layer2[0].conv1.out_channels} channels (40% pruned from 128)')
print(f'  layer3.0.conv1: {model.layer3[0].conv1.out_channels} channels (60% pruned from 256)')
print(f'  layer4.0.conv1: {model.layer4[0].conv1.out_channels} channels (80% pruned from 512)')
Ignoring output layer: fc
Total ignored layers: 1
Using per-layer pruning with 8 layer-specific ratios

Per-layer pruning results:
  layer1.0.conv1: 51 channels (20% pruned from 64)
  layer2.0.conv1: 76 channels (40% pruned from 128)
  layer3.0.conv1: 102 channels (60% pruned from 256)
  layer4.0.conv1: 102 channels (80% pruned from 512)

Tip: Use sensitivity analysis to determine which layers can tolerate more pruning. See the Sensitivity Tutorial for details.

5. Verifying the Pruned Model

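After pruning, the model remains functional: it accepts the same inputs and produces outputs of the same shape, with fewer parameters. A pruned trained model usually needs fine-tuning to recover its accuracy.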

model = resnet18(weights=None)
model.eval()

# Prune 50%
pruner = Pruner(model, 50, 'global', large_final)
pruner.prune_model()

# Verify forward pass works
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    output = model(x)

print('\nForward pass verification:')
print(f'  Input shape:  {x.shape}')
print(f'  Output shape: {output.shape}')
print('  Forward pass runs and the output shape is unchanged.')
Ignoring output layer: fc
Total ignored layers: 1

Forward pass verification:
  Input shape:  torch.Size([1, 3, 224, 224])
  Output shape: torch.Size([1, 1000])
  Forward pass runs and the output shape is unchanged.
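Since a smaller model file is one of the reasons to prune, it can be useful to measure the serialized size as well. This sketch uses two toy conv stacks in place of the ResNet18 above ('narrow' has the kind of shapes structured pruning produces); count_params and state_dict_mb are our own helper names. To use them on the tutorial's model, call them before and after pruner.prune_model(). Keep in mind that a pruned state_dict only loads into the pruned architecture, since its tensor shapes no longer match a fresh resnet18(); saving the whole module with torch.save(model) avoids that (recent PyTorch versions then need weights_only=False in torch.load).

```python
import io
import torch
import torch.nn as nn

def count_params(model):
    return sum(p.numel() for p in model.parameters())

def state_dict_mb(model):
    """Size of the serialized state_dict in MB (1 MB = 1e6 bytes)."""
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    return len(buffer.getvalue()) / 1e6

# Toy stacks: 'narrow' mimics a pruned version of 'wide' (64 -> 44 middle channels)
wide = nn.Sequential(nn.Conv2d(64, 64, 3, bias=False), nn.Conv2d(64, 64, 3, bias=False))
narrow = nn.Sequential(nn.Conv2d(64, 44, 3, bias=False), nn.Conv2d(44, 64, 3, bias=False))

for name, m in [('wide', wide), ('narrow', narrow)]:
    print(f'{name}: {count_params(m):,} params, {state_dict_mb(m):.3f} MB')
```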

6. Importance Criteria

The criteria parameter determines how filter importance is calculated:

| Criteria | Method | Best for |
|---|---|---|
| large_final | Keep filters with largest L1 norm | General use, most common |
| small_final | Keep filters with smallest L1 norm | Unusual, for experimentation |
| random | Random selection | Baseline comparison |

results = {}
for name, criteria in [('large_final', large_final), ('random', random)]:
    model = resnet18(weights=None)
    pruner = Pruner(model, 30, 'local', criteria)
    pruner.prune_model()
    results[name] = sum(p.numel() for p in model.parameters())

print('\nSame pruning ratio, different criteria:')
for name, params in results.items():
    print(f'  {name}: {params:,} parameters')
print('\nNote: the parameter counts are identical because the ratio fixes how many filters are kept.')
print('The criterion only decides which filters survive, which matters for accuracy on a trained model.')
Ignoring output layer: fc
Total ignored layers: 1
Ignoring output layer: fc
Total ignored layers: 1

Same pruning ratio, different criteria:
  large_final: 5,820,556 parameters
  random: 5,820,556 parameters

Note: the parameter counts are identical because the ratio fixes how many filters are kept.
The criterion only decides which filters survive, which matters for accuracy on a trained model.
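Since every criterion keeps the same number of filters, the criteria differ only in which filters survive. The sketch below makes that explicit for a single weight tensor; select_filters is a hypothetical helper that scores each output filter by its L1 norm, as described in the table, and is not fasterai's own implementation of large_final, small_final or random.

```python
import torch

def select_filters(weight, ratio, criterion, generator=None):
    """Hypothetical: sorted indices of the output filters to keep.
    weight has shape (out_channels, in_channels, kH, kW); ratio is in [0, 1)."""
    n_filters = weight.shape[0]
    n_keep = max(1, int(n_filters * (1 - ratio)))
    l1 = weight.abs().sum(dim=tuple(range(1, weight.dim())))  # one L1 norm per filter
    if criterion == 'large':
        idx = l1.topk(n_keep).indices                  # idea behind large_final
    elif criterion == 'small':
        idx = l1.topk(n_keep, largest=False).indices   # idea behind small_final
    elif criterion == 'random':
        idx = torch.randperm(n_filters, generator=generator)[:n_keep]
    else:
        raise ValueError(f'unknown criterion: {criterion!r}')
    return idx.sort().values

# Four toy filters whose L1 norms are 1, 2, 3, 4
w = torch.arange(1.0, 5.0).view(4, 1, 1, 1)
for c in ('large', 'small', 'random'):
    print(c, select_filters(w, 0.5, c, generator=torch.Generator().manual_seed(0)).tolist())
# large -> [2, 3], small -> [0, 1], random -> some 2 of the 4 indices
```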

Summary

| Feature | Description |
|---|---|
| Structured pruning | Removes entire filters, creating genuinely smaller models |
| Local context | Each layer pruned by same percentage |
| Global context | Compare importance across all layers |
| Iterative pruning | Gradual pruning for better accuracy retention |
| Per-layer ratios | Dictionary of custom ratios per layer |
| Auto dependency | Handles layer connections automatically |

Typical Workflow

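# 1. One-shot pruning for quick experiments (assumes `model` is a torch.nn.Module)
pruner = Pruner(model, 30, 'local', large_final)
pruner.prune_model()

# 2. During training with PruneCallback (recommended)
# Assumes `learn` is an existing fastai Learner and that PruneCallback and the
# agp schedule have been imported from fasterai
cb = PruneCallback(pruning_ratio=50, schedule=agp, context='global', criteria=large_final)
learn.fit(10, cbs=[cb])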

See Also