# Sparsify Callback

> Use the sparsifier in the fastai Callback system
## SparsifyCallback

```python
SparsifyCallback (sparsity:Union[float,List[float]], granularity:str,
                  context:str, criteria:fasterai.core.criteria.Criteria,
                  schedule:fasterai.core.schedule.Schedule, lth:bool=False,
                  rewind_epoch:int=0, reset_end:bool=False,
                  save_tickets:bool=False,
                  model:Optional[torch.nn.modules.module.Module]=None,
                  round_to:Optional[int]=None, nm:bool=False,
                  layer_type:Type[torch.nn.modules.module.Module]=<class 'torch.nn.modules.conv.Conv2d'>)
```

Basic class handling tweaks of the training loop by changing a `Learner` in various events.
| | Type | Default | Details |
|---|---|---|---|
| sparsity | Union | | Target sparsity level(s) |
| granularity | str | | Type of pruning granularity (e.g., 'weight', 'filter') |
| context | str | | Pruning context ('global' or 'local') |
| criteria | Criteria | | Criteria for determining which weights to keep |
| schedule | Schedule | | Pruning schedule to use |
| lth | bool | False | Whether to use the Lottery Ticket Hypothesis approach |
| rewind_epoch | int | 0 | Epoch to rewind weights to for LTH |
| reset_end | bool | False | Whether to reset weights after pruning |
| save_tickets | bool | False | Whether to save pruned models as "winning tickets" |
| model | Optional | None | Model to sparsify (if None, uses learn.model) |
| round_to | Optional | None | Round pruning to a multiple of this value |
| nm | bool | False | Whether to use N:M structured sparsity |
| layer_type | Type | Conv2d | Layer type to apply pruning to |
The most important part of our Callback happens in `before_batch`. There, we first compute the sparsity of our network according to our schedule, and then remove the parameters accordingly.
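To make this concrete, here is a minimal, hypothetical sketch (our own names, not fasterai's actual code) of how a per-batch sparsity level can be derived from a schedule and the training progress:

```python
def current_sparsity(schedule, target, n_batches, batch_idx):
    # Map training progress to a position in [0, 1] and ask the
    # schedule how much sparsity to apply at this point.
    pos = batch_idx / max(n_batches - 1, 1)
    return schedule(0.0, target, pos)

# A linear schedule: sparsity grows evenly from 0% to the target.
def linear(start, end, pos):
    return start + pos * (end - start)

print(current_sparsity(linear, 50.0, 101, 50))  # halfway through -> 25.0
```

Each batch, the callback would then apply the freshly computed sparsity level to the masked layers before the forward pass.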
The `SparsifyCallback` requires one argument that the `Sparsifier` does not: the pruning schedule to follow during training, which tells it how to prune the parameters over time. You can use any scheduling function already available in fastai or come up with your own! For more information about pruning schedules, take a look at the Schedules section.
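As an illustration of writing your own, a schedule only needs to map a training-progress fraction `pos` in [0, 1] to a sparsity value. The cubic shape below (inspired by automated gradual pruning; the name and exact signature are our assumption, not fasterai's API) prunes aggressively early and tapers off near the target:

```python
def sched_agp_like(start, end, pos):
    # Cubic schedule: most pruning happens early in training,
    # then the sparsity plateaus near the target value.
    return end + (start - end) * (1 - pos) ** 3

print(sched_agp_like(0, 50, 0.0))  # 0.0  -> no pruning at the start
print(sched_agp_like(0, 50, 0.5))  # 43.75 -> most pruning already done
print(sched_agp_like(0, 50, 1.0))  # 50.0 -> target reached at the end
```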
On top of that, the `SparsifyCallback` can also take several optional arguments:

- `lth`: whether to train using the Lottery Ticket Hypothesis, i.e. reset the weights to their original values at each pruning step (more information in the Lottery Ticket Hypothesis section)
- `rewind_epoch`: the epoch used as a reference for the Lottery Ticket Hypothesis with Rewinding (defaults to 0)
- `reset_end`: whether you want to reset the weights to their original values after training (pruning masks are still applied)
- `save_tickets`: whether to save intermediate winning tickets
- `model`: pass a model or a part of the model if you don't want to apply pruning to the whole trained model
- `round_to`: if specified, the weights will be pruned to the closest multiple of `round_to`
- `layer_type`: the type of layer that you want to apply pruning to (defaults to `nn.Conv2d`)
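To illustrate what rounding to a multiple means, here is a hypothetical sketch (our own helper, not fasterai's implementation) of snapping the number of kept weights to a hardware-friendly multiple:

```python
def rounded_sparsity(n_weights, sparsity, round_to):
    # Number of weights we would keep at the requested sparsity (in %).
    n_keep = int(n_weights * (1 - sparsity / 100))
    # Snap the kept count to the nearest multiple of `round_to`,
    # keeping at least one full group of weights.
    n_keep = max(round_to, round_to * round(n_keep / round_to))
    return 100 * (1 - n_keep / n_weights)

print(rounded_sparsity(100, 50, 8))  # 52.0 -> keeps 48 weights instead of 50
```

Rounding like this can matter in practice because some accelerators only exploit sparsity when the remaining channels or weights align to fixed block sizes.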