Sparsify Callback

Use the sparsifier with the fastai Callback system
from fastai.vision.all import *
from fasterai.sparse.all import *

Overview

Sparsification sets individual weights to zero during training, creating sparse networks that can be more efficient for inference. Unlike structured pruning (which removes entire filters), sparsification maintains the original architecture while introducing zeros.
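
For intuition, here is a minimal plain-PyTorch sketch (not the callback itself) of what sparsification does to a single weight tensor: the smallest-magnitude weights are set to zero while the tensor keeps its shape.

import torch

w = torch.randn(4, 4)                               # a small dense weight tensor
k = w.numel() // 2                                  # zero out the 50% smallest-magnitude weights
threshold = w.abs().flatten().kthvalue(k).values    # magnitude of the k-th smallest weight
w_sparse = torch.where(w.abs() > threshold, w, torch.zeros_like(w))
print(w_sparse)                                     # same 4x4 shape, half the entries are now zero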

Why Use Sparsification?

Approach            What’s Removed             Architecture  Hardware Support
Sparsification      Individual weights         Unchanged     Sparse accelerators
Structured Pruning  Entire filters/channels    Changed       Standard hardware

Key Benefits

  • Gradual sparsity - Weights are progressively zeroed during training
  • Maintained accuracy - Network adapts to sparsity during training
  • Flexible targeting - Choose which layers and how much to sparsify
  • Schedule control - Use one-cycle, cosine, or custom schedules

1. Setup and Data

path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

def label_func(f): return f[0].isupper()  # in the Pets dataset, cat images have filenames starting with an uppercase letter
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))

2. Baseline: Dense Model

First, let’s train a standard dense model to establish baseline accuracy:

learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
learn.fit_one_cycle(5)
epoch train_loss valid_loss accuracy time
0 0.651048 2.844684 0.715156 00:02
1 0.385961 0.352643 0.885656 00:02
2 0.252964 0.289159 0.889716 00:01
3 0.140969 0.208180 0.922869 00:02
4 0.078336 0.197539 0.923545 00:01

3. Training with SparsifyCallback

Now let’s train with 50% sparsity. The SparsifyCallback gradually introduces zeros during training according to the specified schedule:

learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

The callback requires a schedule parameter that controls how sparsity increases over training. You can use any fastai annealing function or define your own.
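
For example, a custom schedule might look like the sketch below. It assumes the same (start, end, pos) signature as fastai's built-in annealing functions, where pos runs from 0 to 1 over training and the function returns the sparsity to apply at that point.

def sched_quadratic(start, end, pos):
    "Anneal sparsity from `start` to `end` along a quadratic curve (slow start, fast finish)."
    return start + (end - start) * pos**2

Such a function could then be passed as the schedule argument in place of one_cycle below.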

sp_cb = SparsifyCallback(sparsity=50, granularity='weight', context='local', criteria=large_final, schedule=one_cycle)
learn.fit_one_cycle(5, cbs=sp_cb)
Pruning of weight until a sparsity of [50]%
Saving Weights at epoch 0
epoch train_loss valid_loss accuracy time
0 0.703576 1.049896 0.809202 00:02
1 0.389609 0.374904 0.873478 00:02
2 0.252554 0.238450 0.900541 00:02
3 0.143394 0.197716 0.920162 00:02
4 0.081234 0.182868 0.932341 00:02
Sparsity at the end of epoch 0: [1.96]%
Sparsity at the end of epoch 1: [20.07]%
Sparsity at the end of epoch 2: [45.86]%
Sparsity at the end of epoch 3: [49.74]%
Sparsity at the end of epoch 4: [50.0]%
Final Sparsity: [50.0]%

Sparsity Report:
--------------------------------------------------------------------------------
Layer                Type            Params     Zeros      Sparsity  
--------------------------------------------------------------------------------
Layer 0              Conv2d          9,408      4,704         50.00%
Layer 1              Conv2d          36,864     18,432        50.00%
Layer 2              Conv2d          36,864     18,432        50.00%
Layer 3              Conv2d          36,864     18,432        50.00%
Layer 4              Conv2d          36,864     18,432        50.00%
Layer 5              Conv2d          73,728     36,864        50.00%
Layer 6              Conv2d          147,456    73,727        50.00%
Layer 7              Conv2d          8,192      4,096         50.00%
Layer 8              Conv2d          147,456    73,727        50.00%
Layer 9              Conv2d          147,456    73,727        50.00%
Layer 10             Conv2d          294,912    147,455       50.00%
Layer 11             Conv2d          589,824    294,909       50.00%
Layer 12             Conv2d          32,768     16,384        50.00%
Layer 13             Conv2d          589,824    294,909       50.00%
Layer 14             Conv2d          589,824    294,909       50.00%
Layer 15             Conv2d          1,179,648  589,818       50.00%
Layer 16             Conv2d          2,359,296  1,179,637     50.00%
Layer 17             Conv2d          131,072    65,535        50.00%
Layer 18             Conv2d          2,359,296  1,179,637     50.00%
Layer 19             Conv2d          2,359,296  1,179,637     50.00%
--------------------------------------------------------------------------------
Overall              all             11,166,912 5,583,403     50.00%

Despite having 50% of weights set to zero, the sparse model performs comparably to the dense baseline!
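
If you want to double-check the reported numbers, here is a small sketch that reads the sparsity straight from the trained weights (assuming the learn object from above):

zeros, total = 0, 0
for m in learn.model.modules():
    if isinstance(m, nn.Conv2d):                    # the callback targets Conv2d layers by default
        zeros += (m.weight == 0).sum().item()
        total += m.weight.numel()
print(f"Conv2d weight sparsity: {100 * zeros / total:.2f}%")   # ~50% after the run above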

4. Per-Layer Sparsity

You can specify a different sparsity level for each layer by passing a list. This is useful for preserving capacity in critical layers (like the early and late layers) while sparsifying intermediate layers more aggressively, as in the example below.
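
The list needs one value per layer that the callback targets (nn.Conv2d layers by default, see layer_type below). A quick sketch for counting them on a learner built as above:

n_convs = len([m for m in learn.model.modules() if isinstance(m, nn.Conv2d)])
print(n_convs)                                      # 20 Conv2d layers for this resnet18-based learner

With 20 layers, the list below zeroes only the middle blocks: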

learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
sparsities = [0, 0, 0, 0, 0, 0, 50, 50, 50, 50, 50, 50, 50, 50, 0, 0, 0, 0, 0, 0]
sp_cb = SparsifyCallback(sparsity=sparsities, granularity='weight', context='local', criteria=large_final, schedule=cos)
learn.fit_one_cycle(5, cbs=sp_cb)
Pruning of weight until a sparsity of [0, 0, 0, 0, 0, 0, 50, 50, 50, 50, 50, 50, 50, 50, 0, 0, 0, 0, 0, 0]%
Saving Weights at epoch 0
epoch train_loss valid_loss accuracy time
0 0.692132 2.065659 0.738836 00:02
1 0.400833 0.257213 0.891746 00:02
2 0.240173 0.262235 0.893099 00:02
3 0.136817 0.186124 0.927605 00:02
4 0.073321 0.186889 0.934371 00:02
Sparsity at the end of epoch 0: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 4.77, 4.77, 4.77, 4.77, 4.77, 4.77, 4.77, 4.77, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity at the end of epoch 1: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 17.27, 17.27, 17.27, 17.27, 17.27, 17.27, 17.27, 17.27, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity at the end of epoch 2: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 32.73, 32.73, 32.73, 32.73, 32.73, 32.73, 32.73, 32.73, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity at the end of epoch 3: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 45.23, 45.23, 45.23, 45.23, 45.23, 45.23, 45.23, 45.23, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity at the end of epoch 4: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Final Sparsity: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%

Sparsity Report:
--------------------------------------------------------------------------------
Layer                Type            Params     Zeros      Sparsity  
--------------------------------------------------------------------------------
Layer 0              Conv2d          9,408      0              0.00%
Layer 1              Conv2d          36,864     0              0.00%
Layer 2              Conv2d          36,864     0              0.00%
Layer 3              Conv2d          36,864     0              0.00%
Layer 4              Conv2d          36,864     0              0.00%
Layer 5              Conv2d          73,728     0              0.00%
Layer 6              Conv2d          147,456    73,727        50.00%
Layer 7              Conv2d          8,192      4,096         50.00%
Layer 8              Conv2d          147,456    73,727        50.00%
Layer 9              Conv2d          147,456    73,727        50.00%
Layer 10             Conv2d          294,912    147,454       50.00%
Layer 11             Conv2d          589,824    294,909       50.00%
Layer 12             Conv2d          32,768     16,384        50.00%
Layer 13             Conv2d          589,824    294,909       50.00%
Layer 14             Conv2d          589,824    0              0.00%
Layer 15             Conv2d          1,179,648  0              0.00%
Layer 16             Conv2d          2,359,296  0              0.00%
Layer 17             Conv2d          131,072    0              0.00%
Layer 18             Conv2d          2,359,296  0              0.00%
Layer 19             Conv2d          2,359,296  0              0.00%
--------------------------------------------------------------------------------
Overall              all             11,166,912 978,933        8.77%

5. Parameter Reference

Core Parameters

Parameter    Description                                          Example
sparsity     Target sparsity % (single value or list per layer)   50 or [0, 50, 50, 0]
granularity  Level of sparsification                              'weight', 'vector', 'kernel', 'filter'
context      How to compute importance                            'local' (per-layer) or 'global' (whole model)
criteria     Importance measure                                   large_final, small_final, magnitude
schedule     How sparsity increases over training                 one_cycle, cos, linear
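
These parameters compose freely. As an illustration (a sketch, not one of the runs above), the callback could prune at the filter level with importance ranked globally across the whole model:

sp_cb = SparsifyCallback(sparsity=30, granularity='filter', context='global',
                         criteria=large_final, schedule=cos)   # zero 30% of filters, ranked model-wide
learn.fit_one_cycle(5, cbs=sp_cb)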

Advanced Parameters

Parameter     Description
lth           Enable Lottery Ticket Hypothesis (reset weights after pruning)
rewind_epoch  Epoch to rewind weights to (for LTH)
reset_end     Reset weights to original values after training
save_tickets  Save intermediate winning tickets
model         Apply to a specific submodule instead of the whole model
round_to      Round sparsity to the nearest multiple
layer_type    Type of layers to sparsify (default: nn.Conv2d)
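
For instance, a sketch of a Lottery Ticket Hypothesis style run combining the lth and rewind_epoch parameters above (the values here are illustrative, not from the tutorial runs):

sp_cb = SparsifyCallback(sparsity=50, granularity='weight', context='local',
                         criteria=large_final, schedule=one_cycle,
                         lth=True, rewind_epoch=1)             # rewind weights to their epoch-1 values
learn.fit_one_cycle(5, cbs=sp_cb)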

Summary

Concept             Description
Sparsification      Setting individual weights to zero while maintaining architecture
SparsifyCallback    fastai callback for gradual sparsification during training
Schedule            Controls how sparsity increases over training (one_cycle, cos, etc.)
Per-layer sparsity  Different sparsity targets for different layers
Typical result      50%+ sparsity with minimal accuracy loss

See Also