Sparsify Callback

Use the sparsifier in the fastai Callback system.
from fastai.vision.all import *
from fasterai.sparse.all import *

Overview

Sparsification sets individual weights to zero during training, creating sparse networks that can be more efficient for inference. Unlike structured pruning (which removes entire filters), sparsification maintains the original architecture while introducing zeros.
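The core operation can be sketched in a few lines of plain Python (a toy illustration, not fasterai's implementation): rank weights by magnitude and zero out the smallest fraction, leaving the tensor's shape untouched.

```python
# Toy sketch of magnitude-based sparsification (not fasterai's code):
# zero the `sparsity`% of weights with the smallest absolute value.
def sparsify(weights, sparsity):
    k = int(len(weights) * sparsity / 100)              # how many weights to zero
    if k == 0:
        return list(weights)
    threshold = sorted(abs(w) for w in weights)[k - 1]  # k-th smallest magnitude
    # Ties at the threshold are all zeroed in this simple version.
    return [0.0 if abs(w) <= threshold else w for w in weights]

print(sparsify([0.9, -0.1, 0.4, -0.7, 0.05, 0.3], 50))
# → [0.9, 0.0, 0.4, -0.7, 0.0, 0.0]  (shape preserved, 50% zeros)
```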

Why Use Sparsification?

Approach             What’s Removed            Architecture   Hardware Support
Sparsification       Individual weights        Unchanged      Sparse accelerators
Structured Pruning   Entire filters/channels   Changed        Standard hardware

Key Benefits

  • Gradual sparsity - Weights are progressively zeroed during training
  • Maintained accuracy - Network adapts to sparsity during training
  • Flexible targeting - Choose which layers and how much to sparsify
  • Schedule control - Use one-cycle, cosine, or custom schedules

1. Setup and Data

path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

def label_func(f): return f[0].isupper()
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))

2. Baseline: Dense Model

First, let’s train a standard dense model to establish baseline accuracy:

learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
learn.fit_one_cycle(5)
epoch train_loss valid_loss accuracy time
0 0.681180 0.583755 0.838295 00:02
1 0.376350 0.247775 0.894452 00:02
2 0.280325 0.302185 0.876184 00:02
3 0.170461 0.214075 0.919486 00:02
4 0.089716 0.188505 0.923545 00:02

3. Training with SparsifyCallback

Now let’s train with 50% sparsity. The SparsifyCallback gradually introduces zeros during training according to the specified schedule:

learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

The callback requires a schedule parameter that controls how sparsity increases over training. You can use any fastai annealing function or define your own.
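For example, a custom schedule can follow the same shape as fastai's annealing functions — assuming here the signature f(start, end, pos), where pos runs from 0 (start of training) to 1 (end) — and be passed as schedule= in place of one_cycle. The delayed_cos name below is made up for this sketch.

```python
import math

# Hypothetical custom schedule (assumed signature f(start, end, pos), pos in [0, 1]):
# hold sparsity at `start` for the first 20% of training, then ramp to `end`
# along a cosine curve.
def delayed_cos(start, end, pos):
    if pos < 0.2:
        return start
    p = (pos - 0.2) / 0.8                                # rescale the rest to [0, 1]
    return start + (end - start) * (1 - math.cos(math.pi * p)) / 2

# Sparsity targets at a few points in training (ramping from 0% to 50%):
print([round(delayed_cos(0.0, 50, p), 1) for p in (0.0, 0.2, 0.6, 1.0)])
# → [0.0, 0.0, 25.0, 50.0]
```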

sp_cb = SparsifyCallback(sparsity=50, granularity='weight', context='local', criteria=large_final, schedule=one_cycle)
learn.fit_one_cycle(5, cbs=sp_cb)
Sparsifying weight until a sparsity of 50%
Saving Weights at epoch 0
epoch train_loss valid_loss accuracy time
0 0.668898 0.542240 0.828146 00:04
1 0.385380 0.240949 0.903248 00:05
2 0.228002 0.221298 0.910014 00:05
3 0.138280 0.177941 0.928281 00:05
4 0.076438 0.194101 0.929635 00:05
Sparsity at the end of epoch 0: 1.96%
Sparsity at the end of epoch 1: 20.07%
Sparsity at the end of epoch 2: 45.86%
Sparsity at the end of epoch 3: 49.74%
Sparsity at the end of epoch 4: 50.00%
Final Sparsity: 50.00%

Sparsity Report:
--------------------------------------------------------------------------------
Layer                          Type            Params     Zeros      Sparsity  
--------------------------------------------------------------------------------
0.0                            Conv2d          9,408      4,702         49.98%
0.4.0.conv1                    Conv2d          36,864     18,430        49.99%
0.4.0.conv2                    Conv2d          36,864     18,430        49.99%
0.4.1.conv1                    Conv2d          36,864     18,430        49.99%
0.4.1.conv2                    Conv2d          36,864     18,430        49.99%
0.5.0.conv1                    Conv2d          73,728     36,862        50.00%
0.5.0.conv2                    Conv2d          147,456    73,726        50.00%
0.5.0.downsample.0             Conv2d          8,192      4,094         49.98%
0.5.1.conv1                    Conv2d          147,456    73,726        50.00%
0.5.1.conv2                    Conv2d          147,456    73,726        50.00%
0.6.0.conv1                    Conv2d          294,912    147,453       50.00%
0.6.0.conv2                    Conv2d          589,824    294,908       50.00%
0.6.0.downsample.0             Conv2d          32,768     16,382        49.99%
0.6.1.conv1                    Conv2d          589,824    294,908       50.00%
0.6.1.conv2                    Conv2d          589,824    294,908       50.00%
0.7.0.conv1                    Conv2d          1,179,648  589,817       50.00%
0.7.0.conv2                    Conv2d          2,359,296  1,179,635     50.00%
0.7.0.downsample.0             Conv2d          131,072    65,534        50.00%
0.7.1.conv1                    Conv2d          2,359,296  1,179,635     50.00%
0.7.1.conv2                    Conv2d          2,359,296  1,179,634     50.00%
--------------------------------------------------------------------------------
Overall                        all             11,166,912 5,583,370     50.00%

Despite having 50% of weights set to zero, the sparse model performs comparably to the dense baseline!

3b. Per-Layer Sparsity

Different layers have different sensitivities to sparsification. Early layers often need more weights to preserve low-level features, while deeper layers can tolerate higher sparsity. You can specify per-layer targets using a dictionary:

# Define different sparsity targets for different layers
per_layer_sparsity = {
    '0.4.0.conv1': 30,   # Early layers: lower sparsity (more sensitive)
    '0.4.0.conv2': 30,
    '0.4.1.conv1': 30,
    '0.4.1.conv2': 30,
    '0.5.0.conv1': 50,   # Middle layers: medium sparsity
    '0.5.0.conv2': 50,
    '0.5.1.conv1': 50,
    '0.5.1.conv2': 50,
    '0.6.0.conv1': 70,   # Deeper layers: higher sparsity (more redundant)
    '0.6.0.conv2': 70,
    '0.6.1.conv1': 70,
    '0.6.1.conv2': 70,
    '0.7.0.conv1': 80,   # Deepest layers: highest sparsity
    '0.7.0.conv2': 80,
    '0.7.1.conv1': 80,
    '0.7.1.conv2': 80,
}
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

# Use dict for per-layer sparsity - requires 'local' context
sp_cb = SparsifyCallback(
    sparsity=per_layer_sparsity, 
    granularity='weight', 
    context='local',  # Required for per-layer sparsity
    criteria=large_final, 
    schedule=cos
)

learn.fit_one_cycle(5, cbs=sp_cb)
Sparsifying weight until a sparsity of {'0.4.0.conv1': 30, '0.4.0.conv2': 30, '0.4.1.conv1': 30, '0.4.1.conv2': 30, '0.5.0.conv1': 50, '0.5.0.conv2': 50, '0.5.1.conv1': 50, '0.5.1.conv2': 50, '0.6.0.conv1': 70, '0.6.0.conv2': 70, '0.6.1.conv1': 70, '0.6.1.conv2': 70, '0.7.0.conv1': 80, '0.7.0.conv2': 80, '0.7.1.conv1': 80, '0.7.1.conv2': 80}%
Saving Weights at epoch 0
epoch train_loss valid_loss accuracy time
0 0.728993 0.559671 0.801759 00:04
1 0.379864 0.328270 0.880920 00:05
2 0.226777 0.212459 0.914750 00:05
3 0.135246 0.216588 0.918133 00:05
4 0.077181 0.210505 0.926252 00:05
Sparsity at the end of epoch 0: avg=5.49%
Sparsity at the end of epoch 1: avg=19.87%
Sparsity at the end of epoch 2: avg=37.63%
Sparsity at the end of epoch 3: avg=52.01%
Sparsity at the end of epoch 4: avg=57.50%
Final Sparsity: {'0.4.0.conv1': 30.0, '0.4.0.conv2': 30.0, '0.4.1.conv1': 30.0, '0.4.1.conv2': 30.0, '0.5.0.conv1': 50.0, '0.5.0.conv2': 50.0, '0.5.1.conv1': 50.0, '0.5.1.conv2': 50.0, '0.6.0.conv1': 70.0, '0.6.0.conv2': 70.0, '0.6.1.conv1': 70.0, '0.6.1.conv2': 70.0, '0.7.0.conv1': 80.0, '0.7.0.conv2': 80.0, '0.7.1.conv1': 80.0, '0.7.1.conv2': 80.0}

Sparsity Report:
--------------------------------------------------------------------------------
Layer                          Type            Params     Zeros      Sparsity  
--------------------------------------------------------------------------------
0.0                            Conv2d          9,408      0              0.00%
0.4.0.conv1                    Conv2d          36,864     11,058        30.00%
0.4.0.conv2                    Conv2d          36,864     11,058        30.00%
0.4.1.conv1                    Conv2d          36,864     11,058        30.00%
0.4.1.conv2                    Conv2d          36,864     11,058        30.00%
0.5.0.conv1                    Conv2d          73,728     36,862        50.00%
0.5.0.conv2                    Conv2d          147,456    73,726        50.00%
0.5.0.downsample.0             Conv2d          8,192      0              0.00%
0.5.1.conv1                    Conv2d          147,456    73,726        50.00%
0.5.1.conv2                    Conv2d          147,456    73,726        50.00%
0.6.0.conv1                    Conv2d          294,912    206,435       70.00%
0.6.0.conv2                    Conv2d          589,824    412,871       70.00%
0.6.0.downsample.0             Conv2d          32,768     0              0.00%
0.6.1.conv1                    Conv2d          589,824    412,871       70.00%
0.6.1.conv2                    Conv2d          589,824    412,871       70.00%
0.7.0.conv1                    Conv2d          1,179,648  943,708       80.00%
0.7.0.conv2                    Conv2d          2,359,296  1,887,417     80.00%
0.7.0.downsample.0             Conv2d          131,072    0              0.00%
0.7.1.conv1                    Conv2d          2,359,296  1,887,417     80.00%
0.7.1.conv2                    Conv2d          2,359,296  1,887,417     80.00%
--------------------------------------------------------------------------------
Overall                        all             11,166,912 8,353,279     74.80%
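Note that the per-epoch log reports an unweighted average (57.50%) while the report's Overall row (74.80%) weights each layer by its parameter count — and the deepest, most heavily sparsified layers are by far the largest. A quick check of the mechanism with two layers copied from the report above:

```python
# (params, target sparsity %) for two layers from the Sparsity Report above
layers = {
    '0.4.0.conv1': (36_864, 30.0),      # small, lightly sparsified
    '0.7.0.conv2': (2_359_296, 80.0),   # huge, heavily sparsified
}
unweighted = sum(s for _, s in layers.values()) / len(layers)
weighted = sum(p * s for p, s in layers.values()) / sum(p for p, _ in layers.values())
print(round(unweighted, 2), round(weighted, 2))  # → 55.0 79.23
```

The parameter-weighted figure is pulled toward the big, sparse layers, which is why the Overall row exceeds the plain average.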

Key points about per-layer sparsity:

  • Use a dict mapping layer names to sparsity percentages
  • Requires context='local' (global context doesn’t support non-uniform sparsity)
  • Layer names match those shown in the Sparsity Report (e.g., '0.4.0.conv1')
  • Layers not in the dict are left dense (0% sparsity)
  • The schedule applies uniformly - all layers progress from 0% to their target together

Tip: Use learn.model to explore layer names, or run a uniform sparsity first to see the Sparsity Report with all layer names.
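For instance, PyTorch's named_modules() gives the names directly (toy model shown; with a Learner you would iterate learn.model.named_modules() the same way):

```python
import torch.nn as nn

# List the Conv2d layer names that a per-layer sparsity dict could target.
# Toy model for illustration; on a real Learner, iterate learn.model instead.
model = nn.Sequential(
    nn.Conv2d(3, 8, 3),
    nn.Sequential(nn.Conv2d(8, 16, 3), nn.ReLU()),
)
conv_names = [n for n, m in model.named_modules() if isinstance(m, nn.Conv2d)]
print(conv_names)  # → ['0', '1.0']
```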

4. Parameter Reference

Core Parameters

Parameter     Description                                       Example
sparsity      Target sparsity % (float, or dict for per-layer)  50 or {'layer1': 30, 'layer2': 70}
granularity   Level of sparsification                           'weight', 'vector', 'kernel', 'filter'
context       How to compute importance                         'local' (per-layer) or 'global' (whole model)
criteria      Importance measure                                large_final, small_final, magnitude
schedule      How sparsity increases over training              one_cycle, cos, lin
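The granularity options differ in what gets scored and zeroed as a unit. A minimal sketch of the contrast between 'weight' and 'filter' (plain Python, not fasterai's implementation):

```python
# A toy conv layer: 3 filters of 2 weights each.
filters = [[0.9, -0.8], [0.1, 0.05], [0.4, -0.3]]

# granularity='weight' would score each weight independently (as in section 1's
# sketch); granularity='filter' scores whole filters (here by L1 norm) and
# zeroes the weakest one as a unit, leaving the layer's shape unchanged.
scores = [sum(abs(w) for w in f) for f in filters]       # L1 norms ≈ [1.7, 0.15, 0.7]
weakest = scores.index(min(scores))
pruned = [[0.0] * len(f) if i == weakest else f for i, f in enumerate(filters)]
print(pruned)  # → [[0.9, -0.8], [0.0, 0.0], [0.4, -0.3]]
```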

Advanced Parameters

Parameter      Description
lth            Enable Lottery Ticket Hypothesis (reset weights after pruning)
rewind_epoch   Epoch to rewind weights to (for LTH)
reset_end      Reset weights to original values after training
save_tickets   Save intermediate winning tickets
model          Apply to specific submodule instead of whole model
round_to       Round sparsity to nearest multiple
layer_type     Type of layers to sparsify (default: nn.Conv2d)
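The lth/rewind_epoch mechanics boil down to: snapshot the weights early in training, and after pruning reset the surviving weights to that snapshot while keeping the mask. A sketch with plain Python lists standing in for tensors (not fasterai's code):

```python
weights = [0.5, -0.2, 0.8, 0.1]           # values at rewind_epoch
snapshot = list(weights)                  # saved for the rewind

# ... training continues, then pruning produces a binary mask ...
weights = [0.7, -0.05, 1.1, 0.02]         # weights after further training
mask = [1, 0, 1, 0]                       # keep the large-magnitude weights

# Rewind: surviving weights go back to their snapshot values, zeros stay zero.
rewound = [w if m else 0.0 for w, m in zip(snapshot, mask)]
print(rewound)  # → [0.5, 0.0, 0.8, 0.0]
```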

Summary

Concept             Description
Sparsification      Setting individual weights to zero while maintaining architecture
SparsifyCallback    fastai callback for gradual sparsification during training
Schedule            Controls how sparsity increases over training (one_cycle, cos, etc.)
Per-layer sparsity  Different sparsity targets for different layers
Typical result      50%+ sparsity with minimal accuracy loss

See Also