Sparsify Callback
Overview
The SparsifyCallback integrates weight sparsification into the fastai training loop. Unlike pruning (which removes structures), sparsification zeros out individual weights while maintaining the original network shape.
Key Features:
- Gradual sparsification according to a schedule
- Support for Lottery Ticket Hypothesis (LTH) training
- Multiple granularity levels (weight, vector, kernel, filter)
- Global or local sparsification context
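To make the distinction concrete, here is a minimal PyTorch sketch (independent of fasterai) of the idea: a binary mask zeroes out individual weights, but the tensor, and therefore the network architecture, keeps its original shape.

```python
import torch.nn as nn

conv = nn.Conv2d(3, 16, kernel_size=3)

# Keep the ~50% largest-magnitude weights and zero out the rest
threshold = conv.weight.abs().flatten().median()
mask = (conv.weight.abs() >= threshold).float()
conv.weight.data *= mask

print(conv.weight.shape)                  # unchanged: torch.Size([16, 3, 3, 3])
print((conv.weight == 0).float().mean())  # roughly half the weights are now zero
```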
SparsifyCallback
```python
def SparsifyCallback(
    sparsity:Union[float, list[float]], # Target sparsity level(s)
    granularity:str, # Type of pruning granularity (e.g., 'weight', 'filter')
    context:str, # Pruning context ('global' or 'local')
    criteria:Criteria, # Criteria for determining which weights to keep
    schedule:Schedule, # Pruning schedule to use
    lth:bool=False, # Whether to use the Lottery Ticket Hypothesis approach
    rewind_epoch:int=0, # Epoch to rewind weights to for LTH
    reset_end:bool=False, # Whether to reset weights after pruning
    save_tickets:bool=False, # Whether to save pruned models as "winning tickets"
    model:Optional[nn.Module]=None, # Model to sparsify (if None, uses learn.model)
    round_to:Optional[int]=None, # Round pruning to a multiple of this value
    nm:bool=False, # Whether to use N:M structured sparsity
    layer_type:Type[nn.Module]=nn.Conv2d, # Layer type to apply pruning to
):
```
Basic class handling tweaks of the training loop by changing a Learner in various events
The most important part of our callback happens in `before_batch`. There, we first compute the sparsity of our network according to our schedule, and then prune the parameters accordingly.
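In pseudocode, that logic looks roughly like this. This is a simplified sketch, not the library's actual implementation; apart from fastai's `pct_train`, the attribute names and call signatures are illustrative assumptions:

```python
# Simplified, illustrative view of the callback's before_batch step
def before_batch(self):
    if self.training:
        # fastai exposes training progress in [0, 1] as self.pct_train
        current_sparsity = self.schedule(self.sparsity, self.pct_train)
        # Zero out the least important weights to reach that sparsity level
        self.sparsifier.prune_model(current_sparsity)
```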
The SparsifyCallback requires one argument more than the Sparsifier: the pruning schedule to follow during training, so that the parameters are pruned accordingly.
You can use any scheduling function already available in fastai or come up with your own! For more information about the pruning schedules, take a look at the Schedules section.
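As an illustration, a fastai-style annealing function takes `(start, end, pos)` and returns the scheduled value at training progress `pos` in [0, 1]. How it is wrapped into a fasterai `Schedule` is sketched here as an assumption; see the Schedules section for the exact constructor:

```python
from fasterai.core.schedule import Schedule

# fastai-style annealing function: interpolate between `start` and `end`
# as training progress `pos` goes from 0 to 1. The cubic ramp prunes
# gently at first and aggressively towards the end of training.
def sched_cubic(start, end, pos):
    return start + (end - start) * pos**3

# Hypothetical wrapping into a fasterai Schedule object
cubic = Schedule(sched_cubic)
```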
On top of that, the SparsifyCallback can also take several optional arguments:
- `lth`: whether to train using the Lottery Ticket Hypothesis, i.e. reset the weights to their original values at each pruning step (more information in the Lottery Ticket Hypothesis section)
- `rewind_epoch`: the epoch used as a reference for the Lottery Ticket Hypothesis with Rewinding (defaults to 0)
- `reset_end`: whether to reset the weights to their original values after training (pruning masks are still applied)
- `save_tickets`: whether to save intermediate winning tickets
- `model`: pass a model or a part of the model if you don't want to apply pruning to the whole trained model
- `round_to`: if specified, the weights will be pruned to the closest multiple of `round_to`
- `layer_type`: the type of layer to apply pruning to (defaults to nn.Conv2d)

An example combining several of these options closes the Usage Example section below.
Usage Example
```python
from fasterai.sparse.sparsify_callback import SparsifyCallback
from fasterai.core.schedule import cos, one_cycle
from fasterai.core.criteria import large_final
# Gradually sparsify to 50% using cosine schedule
cb = SparsifyCallback(
sparsity=50,
granularity='weight',
context='global',
criteria=large_final,
schedule=cos
)
learn.fit(10, cbs=[cb])
```

With Lottery Ticket Hypothesis

```python
# Train with LTH - rewind weights to epoch 2 values after each pruning step
cb = SparsifyCallback(
sparsity=90,
granularity='weight',
context='global',
criteria=large_final,
schedule=one_cycle,
lth=True,
rewind_epoch=2
)
learn.fit(20, cbs=[cb])
```
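The optional arguments listed earlier compose in the same way. As an illustration (building only on the signature documented above), restricting pruning to linear layers and rounding the remaining weights to a hardware-friendly multiple:

```python
import torch.nn as nn

# Prune only nn.Linear layers, keeping the remaining weight count
# at a multiple of 8 (useful for hardware efficiency)
cb = SparsifyCallback(
    sparsity=50,
    granularity='weight',
    context='local',
    criteria=large_final,
    schedule=cos,
    round_to=8,
    layer_type=nn.Linear
)
learn.fit(10, cbs=[cb])
```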
See Also
- Sparsifier - Core sparsification class used by this callback
- Schedules - Control sparsification progression (one_shot, agp, etc.)
- Criteria - Importance measures (large_final, movement, etc.)
- Lottery Ticket Tutorial - Finding winning tickets with sparsification