Sparsify Callback

Use the sparsifier within the fastai Callback system

from fastai.vision.all import *
from fasterai.sparse.all import *

Overview

Sparsification sets individual weights to zero during training, creating sparse networks that can be more efficient for inference. Unlike structured pruning (which removes entire filters), sparsification maintains the original architecture while introducing zeros.
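
To make the distinction concrete, here is a minimal PyTorch sketch of our own (not library code) that zeroes the smallest-magnitude weights of a conv layer while leaving its shape untouched:

import torch
import torch.nn as nn

conv = nn.Conv2d(16, 32, kernel_size=3)
w = conv.weight.data

# Keep the largest-magnitude half of the weights, zero the rest
threshold = w.abs().flatten().median()
mask = (w.abs() >= threshold).float()
conv.weight.data *= mask

print(conv.weight.shape)                  # unchanged: torch.Size([32, 16, 3, 3])
print((conv.weight == 0).float().mean())  # ~0.5: half the weights are now zero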

Why Use Sparsification?

Approach            What’s Removed            Architecture  Hardware Support
--------------------------------------------------------------------------------
Sparsification      Individual weights        Unchanged     Sparse accelerators
Structured Pruning  Entire filters/channels   Changed       Standard hardware

Key Benefits

  • Gradual sparsity - Weights are progressively zeroed during training
  • Maintained accuracy - Network adapts to sparsity during training
  • Flexible targeting - Choose which layers and how much to sparsify
  • Schedule control - Use one-cycle, cosine, or custom schedules

1. Setup and Data

# Download the Oxford-IIIT Pets dataset and gather the image files
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

# Cat breeds are capitalized in this dataset: the first letter gives the label
def label_func(f): return f[0].isupper()
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))

2. Baseline: Dense Model

First, let’s train a standard dense model to establish baseline accuracy:

# Standard transfer learning: train all layers of a ResNet-18
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
learn.fit_one_cycle(5)
epoch  train_loss  valid_loss  accuracy  time
0      0.732612    0.397222    0.839648  00:04
1      0.394582    0.260210    0.887686  00:04
2      0.218636    0.235590    0.907307  00:04
3      0.118740    0.200626    0.922869  00:04
4      0.078772    0.187712    0.922869  00:04

3. Training with SparsifyCallback

Now let’s train with 50% sparsity. The SparsifyCallback gradually introduces zeros during training according to the specified schedule:

# Fresh model so the sparsified run starts from the same point as the baseline
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

The callback requires a schedule parameter that controls how sparsity increases over training. You can use any fastai annealing function or define your own.
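
For example, assuming the callback accepts any fastai-style annealing function of the form f(start, end, pos), as the sentence above suggests, a custom cubic schedule in the spirit of Zhu & Gupta's AGP might look like the sketch below (sched_cubic is our own hypothetical name). The run that follows sticks with the built-in one_cycle schedule.

def sched_cubic(start, end, pos):
    "Hypothetical custom schedule: prune aggressively early, then taper off"
    return end + (start - end) * (1 - pos)**3

# Would be passed as: SparsifyCallback(..., schedule=sched_cubic)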

sp_cb = SparsifyCallback(sparsity=50, granularity='weight', context='local', criteria=large_final, schedule=one_cycle)
learn.fit_one_cycle(5, cbs=sp_cb)
Pruning of weight until a sparsity of 50%
Saving Weights at epoch 0
epoch  train_loss  valid_loss  accuracy  time
0      0.662926    1.296763    0.810555  00:07
1      0.376402    0.278251    0.883627  00:06
2      0.243227    0.213432    0.911367  00:07
3      0.130433    0.186261    0.930311  00:07
4      0.079553    0.165558    0.934371  00:06
Sparsity at the end of epoch 0: 1.96%
Sparsity at the end of epoch 1: 20.07%
Sparsity at the end of epoch 2: 45.86%
Sparsity at the end of epoch 3: 49.74%
Sparsity at the end of epoch 4: 50.00%
Final Sparsity: 50.00%

Sparsity Report:
--------------------------------------------------------------------------------
Layer                          Type            Params     Zeros      Sparsity  
--------------------------------------------------------------------------------
0.0                            Conv2d          9,408      4,704         50.00%
0.4.0.conv1                    Conv2d          36,864     18,432        50.00%
0.4.0.conv2                    Conv2d          36,864     18,432        50.00%
0.4.1.conv1                    Conv2d          36,864     18,432        50.00%
0.4.1.conv2                    Conv2d          36,864     18,432        50.00%
0.5.0.conv1                    Conv2d          73,728     36,864        50.00%
0.5.0.conv2                    Conv2d          147,456    73,727        50.00%
0.5.0.downsample.0             Conv2d          8,192      4,096         50.00%
0.5.1.conv1                    Conv2d          147,456    73,727        50.00%
0.5.1.conv2                    Conv2d          147,456    73,727        50.00%
0.6.0.conv1                    Conv2d          294,912    147,455       50.00%
0.6.0.conv2                    Conv2d          589,824    294,909       50.00%
0.6.0.downsample.0             Conv2d          32,768     16,384        50.00%
0.6.1.conv1                    Conv2d          589,824    294,909       50.00%
0.6.1.conv2                    Conv2d          589,824    294,909       50.00%
0.7.0.conv1                    Conv2d          1,179,648  589,818       50.00%
0.7.0.conv2                    Conv2d          2,359,296  1,179,637     50.00%
0.7.0.downsample.0             Conv2d          131,072    65,535        50.00%
0.7.1.conv1                    Conv2d          2,359,296  1,179,637     50.00%
0.7.1.conv2                    Conv2d          2,359,296  1,179,637     50.00%
--------------------------------------------------------------------------------
Overall                        all             11,166,912 5,583,403     50.00%

Despite having 50% of its weights set to zero, the sparse model matches the dense baseline, here even slightly exceeding it (93.4% vs. 92.3% accuracy)!
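
You can double-check the zeros with plain PyTorch, independently of the callback's own report. This assumes the standard vision_learner layout, where learn.model[0][0] is the stem convolution (listed as 0.0 in the report above):

# Fraction of exactly-zero weights in the stem convolution
stem_conv = learn.model[0][0]
print(f'{(stem_conv.weight == 0).float().mean().item():.2%}')  # ≈ 50.00%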

3b. Per-Layer Sparsity

Different layers have different sensitivities to sparsification. Early layers often need more weights to preserve low-level features, while deeper layers can tolerate higher sparsity. You can specify per-layer targets using a dictionary:

# Define different sparsity targets for different layers
per_layer_sparsity = {
    '0.4.0.conv1': 30,   # Early layers: lower sparsity (more sensitive)
    '0.4.0.conv2': 30,
    '0.4.1.conv1': 30,
    '0.4.1.conv2': 30,
    '0.5.0.conv1': 50,   # Middle layers: medium sparsity
    '0.5.0.conv2': 50,
    '0.5.1.conv1': 50,
    '0.5.1.conv2': 50,
    '0.6.0.conv1': 70,   # Deeper layers: higher sparsity (more redundant)
    '0.6.0.conv2': 70,
    '0.6.1.conv1': 70,
    '0.6.1.conv2': 70,
    '0.7.0.conv1': 80,   # Deepest layers: highest sparsity
    '0.7.0.conv2': 80,
    '0.7.1.conv1': 80,
    '0.7.1.conv2': 80,
}
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

# Use dict for per-layer sparsity - requires 'local' context
sp_cb = SparsifyCallback(
    sparsity=per_layer_sparsity, 
    granularity='weight', 
    context='local',  # Required for per-layer sparsity
    criteria=large_final, 
    schedule=cos
)

learn.fit_one_cycle(5, cbs=sp_cb)
Pruning of weight until a sparsity of {'0.4.0.conv1': 30, '0.4.0.conv2': 30, '0.4.1.conv1': 30, '0.4.1.conv2': 30, '0.5.0.conv1': 50, '0.5.0.conv2': 50, '0.5.1.conv1': 50, '0.5.1.conv2': 50, '0.6.0.conv1': 70, '0.6.0.conv2': 70, '0.6.1.conv1': 70, '0.6.1.conv2': 70, '0.7.0.conv1': 80, '0.7.0.conv2': 80, '0.7.1.conv1': 80, '0.7.1.conv2': 80}%
Saving Weights at epoch 0
epoch  train_loss  valid_loss  accuracy  time
0      0.702893    0.432825    0.829499  00:06
1      0.395077    0.314297    0.887010  00:06
2      0.229694    0.263221    0.892422  00:05
3      0.132596    0.182942    0.930311  00:06
4      0.077698    0.172972    0.935724  00:07
Sparsity at the end of epoch 0: avg=5.49%
Sparsity at the end of epoch 1: avg=19.87%
Sparsity at the end of epoch 2: avg=37.63%
Sparsity at the end of epoch 3: avg=52.01%
Sparsity at the end of epoch 4: avg=57.50%
Final Sparsity: {'0.4.0.conv1': 30.0, '0.4.0.conv2': 30.0, '0.4.1.conv1': 30.0, '0.4.1.conv2': 30.0, '0.5.0.conv1': 50.0, '0.5.0.conv2': 50.0, '0.5.1.conv1': 50.0, '0.5.1.conv2': 50.0, '0.6.0.conv1': 70.0, '0.6.0.conv2': 70.0, '0.6.1.conv1': 70.0, '0.6.1.conv2': 70.0, '0.7.0.conv1': 80.0, '0.7.0.conv2': 80.0, '0.7.1.conv1': 80.0, '0.7.1.conv2': 80.0}

Sparsity Report:
--------------------------------------------------------------------------------
Layer                          Type            Params     Zeros      Sparsity  
--------------------------------------------------------------------------------
0.0                            Conv2d          9,408      0              0.00%
0.4.0.conv1                    Conv2d          36,864     11,059        30.00%
0.4.0.conv2                    Conv2d          36,864     11,059        30.00%
0.4.1.conv1                    Conv2d          36,864     11,059        30.00%
0.4.1.conv2                    Conv2d          36,864     11,059        30.00%
0.5.0.conv1                    Conv2d          73,728     36,864        50.00%
0.5.0.conv2                    Conv2d          147,456    73,727        50.00%
0.5.0.downsample.0             Conv2d          8,192      0              0.00%
0.5.1.conv1                    Conv2d          147,456    73,727        50.00%
0.5.1.conv2                    Conv2d          147,456    73,727        50.00%
0.6.0.conv1                    Conv2d          294,912    206,436       70.00%
0.6.0.conv2                    Conv2d          589,824    412,872       70.00%
0.6.0.downsample.0             Conv2d          32,768     0              0.00%
0.6.1.conv1                    Conv2d          589,824    412,872       70.00%
0.6.1.conv2                    Conv2d          589,824    412,872       70.00%
0.7.0.conv1                    Conv2d          1,179,648  943,709       80.00%
0.7.0.conv2                    Conv2d          2,359,296  1,887,418     80.00%
0.7.0.downsample.0             Conv2d          131,072    0              0.00%
0.7.1.conv1                    Conv2d          2,359,296  1,887,418     80.00%
0.7.1.conv2                    Conv2d          2,359,296  1,887,418     80.00%
--------------------------------------------------------------------------------
Overall                        all             11,166,912 8,353,296     74.80%

Key points about per-layer sparsity:

  • Use a dict mapping layer names to sparsity percentages
  • Requires context='local' (global context doesn’t support non-uniform sparsity)
  • Layer names match those shown in the Sparsity Report (e.g., '0.4.0.conv1')
  • Layers not in the dict are left dense (0% sparsity)
  • The schedule applies uniformly - all layers progress from 0% to their target together

Tip: Use learn.model to explore layer names (as shown in the snippet below), or first run a uniform-sparsity training to see the Sparsity Report with all layer names.
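
One way to list the candidate layer names with plain PyTorch (this is not a fasterai API, just standard module introspection):

import torch.nn as nn

# Qualified names of all Conv2d modules, matching the Sparsity Report
conv_names = [name for name, m in learn.model.named_modules()
              if isinstance(m, nn.Conv2d)]
print(conv_names[:4])  # e.g. ['0.0', '0.4.0.conv1', '0.4.0.conv2', '0.4.1.conv1']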

4. Parameter Reference

Core Parameters

Parameter    Description                                       Example
--------------------------------------------------------------------------------
sparsity     Target sparsity % (float or dict for per-layer)   50 or {'layer1': 30, 'layer2': 70}
granularity  Level of sparsification                           'weight', 'vector', 'kernel', 'filter'
context      How to compute importance                         'local' (per-layer) or 'global' (whole model)
criteria     Importance measure                                large_final, small_final, magnitude
schedule     How sparsity increases over training              one_cycle, cos, lin
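
To illustrate how these options combine, here is a sketch built only from the parameters in the table (not run in this tutorial) that ranks whole filters globally across the model:

# Global, filter-level sparsification: filters from all layers compete for
# the same budget, so per-layer sparsity will vary around the 50% target
sp_cb = SparsifyCallback(
    sparsity=50,
    granularity='filter',  # zero entire filters rather than individual weights
    context='global',      # compare importance across the whole model
    criteria=large_final,
    schedule=cos,
)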

Advanced Parameters

Parameter     Description
--------------------------------------------------------------------------------
lth           Enable Lottery Ticket Hypothesis (rewind weights after pruning)
rewind_epoch  Epoch whose weights are restored when rewinding (for LTH)
reset_end     Reset weights to their original values after training
save_tickets  Save intermediate winning tickets
model         Apply to a specific submodule instead of the whole model
round_to      Round the number of retained weights to the nearest multiple
layer_type    Type of layers to sparsify (default: nn.Conv2d)
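
For instance, the Lottery Ticket options could be combined as follows (a sketch using only the parameters above; not run in this tutorial):

# Prune, then rewind the surviving weights to their epoch-1 values (LTH-style),
# saving each intermediate winning ticket along the way
sp_cb = SparsifyCallback(
    sparsity=50, granularity='weight', context='local',
    criteria=large_final, schedule=one_cycle,
    lth=True, rewind_epoch=1, save_tickets=True,
)
learn.fit_one_cycle(5, cbs=sp_cb)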

Summary

Concept             Description
--------------------------------------------------------------------------------
Sparsification      Setting individual weights to zero while maintaining the architecture
SparsifyCallback    fastai callback for gradual sparsification during training
Schedule            Controls how sparsity increases over training (one_cycle, cos, etc.)
Per-layer sparsity  Different sparsity targets for different layers
Typical result      50%+ sparsity with minimal accuracy loss

See Also