# Sparsify Callback

```python
from fastai.vision.all import *
from fasterai.sparse.all import *  # provides SparsifyCallback, the pruning criteria and schedules used below

path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

def label_func(f): return f[0].isupper()

dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))
```
The most important part of our Callback happens in `before_batch`. There, we first compute the sparsity of our network according to our schedule and then remove the parameters accordingly.
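To make this concrete, here is a deliberately simplified sketch of what such a callback could look like. The linear ramp and the magnitude-based criterion below are stand-ins chosen for illustration; this is not fasterai's actual implementation.

```python
# Toy sketch only -- NOT the real SparsifyCallback. It illustrates the idea:
# at each batch, read the target sparsity from a schedule, then zero out the
# smallest-magnitude weights until that sparsity level is reached.
class ToySparsifyCallback(Callback):
    run_valid = False  # only act on training batches

    def __init__(self, end_sparsity): self.end_sparsity = end_sparsity

    def before_batch(self):
        # Sparsity target for this batch: a simple linear ramp over training,
        # standing in for the real schedule.
        current = self.end_sparsity * self.pct_train
        for m in self.learn.model.modules():
            if not isinstance(m, nn.Conv2d): continue
            w = m.weight.data
            k = int(w.numel() * current / 100)
            if k == 0: continue
            # `large_final`-style criterion: drop the k smallest-magnitude weights
            threshold = w.abs().flatten().kthvalue(k).values
            w.mul_((w.abs() > threshold).float())

# Usage would mirror any fastai callback, e.g.:
# learn.fit_one_cycle(5, cbs=ToySparsifyCallback(50))
```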
```python
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
learn.fit_one_cycle(5)
```
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.638129 | 0.776416 | 0.794317 | 00:14 |
1 | 0.382736 | 0.266816 | 0.893775 | 00:04 |
2 | 0.247629 | 0.310128 | 0.901218 | 00:03 |
3 | 0.135562 | 0.175250 | 0.935724 | 00:03 |
4 | 0.076362 | 0.165934 | 0.941813 | 00:03 |
Let's now try adding some sparsity to our model.
```python
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
```
The `SparsifyCallback` requires an extra argument compared to the `Sparsifier`: the pruning schedule to follow during training, so that the parameters are pruned accordingly.

You can use any scheduling function already available in fastai or come up with your own! For more information about the pruning schedules, take a look at the Schedules section.
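As an illustration, here is what a custom schedule could look like, assuming an annealing-style `(start, end, pos)` signature where `pos` is the fraction of training completed. The exact signature expected by `SparsifyCallback` is an assumption here, so check the Schedules section before relying on it.

```python
# Hypothetical custom schedule, assuming the (start, end, pos) convention:
# `pos` goes from 0 to 1 over training and the returned value is the sparsity
# to apply at that point.
def sched_cubic(start, end, pos):
    return end + (start - end) * (1 - pos) ** 3  # cubic ramp from `start` to `end`

# It could then be passed in place of `one_cycle` or `cos`:
# sp_cb = SparsifyCallback(sparsity=50, granularity='weight', context='local',
#                          criteria=large_final, schedule=sched_cubic)
```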
```python
sp_cb = SparsifyCallback(sparsity=50, granularity='weight', context='local', criteria=large_final, schedule=one_cycle)
learn.fit_one_cycle(5, cbs=sp_cb)
```
Pruning of weight until a sparsity of [50]%
Saving Weights at epoch 0
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.672287 | 1.130331 | 0.778078 | 00:05 |
1 | 0.401001 | 0.315872 | 0.860622 | 00:04 |
2 | 0.233937 | 0.191955 | 0.912043 | 00:03 |
3 | 0.130283 | 0.208986 | 0.921516 | 00:03 |
4 | 0.071710 | 0.202378 | 0.930988 | 00:03 |
Sparsity at the end of epoch 0: [1.96]%
Sparsity at the end of epoch 1: [20.07]%
Sparsity at the end of epoch 2: [45.86]%
Sparsity at the end of epoch 3: [49.74]%
Sparsity at the end of epoch 4: [50.0]%
Final Sparsity: [50.0]%
```
Sparsity Report:
--------------------------------------------------------------------------------
Layer       Type            Params         Zeros     Sparsity
--------------------------------------------------------------------------------
Layer 2     Conv2d           9,408         4,704       50.00%
Layer 8     Conv2d          36,864        18,432       50.00%
Layer 11    Conv2d          36,864        18,432       50.00%
Layer 14    Conv2d          36,864        18,432       50.00%
Layer 17    Conv2d          36,864        18,432       50.00%
Layer 21    Conv2d          73,728        36,864       50.00%
Layer 24    Conv2d         147,456        73,727       50.00%
Layer 27    Conv2d           8,192         4,096       50.00%
Layer 30    Conv2d         147,456        73,727       50.00%
Layer 33    Conv2d         147,456        73,727       50.00%
Layer 37    Conv2d         294,912       147,455       50.00%
Layer 40    Conv2d         589,824       294,909       50.00%
Layer 43    Conv2d          32,768        16,384       50.00%
Layer 46    Conv2d         589,824       294,909       50.00%
Layer 49    Conv2d         589,824       294,909       50.00%
Layer 53    Conv2d       1,179,648       589,818       50.00%
Layer 56    Conv2d       2,359,296     1,179,637       50.00%
Layer 59    Conv2d         131,072        65,535       50.00%
Layer 62    Conv2d       2,359,296     1,179,637       50.00%
Layer 65    Conv2d       2,359,296     1,179,637       50.00%
--------------------------------------------------------------------------------
Overall     all         11,166,912     5,583,403       50.00%
```
Surprisingly, our network, which is now composed of 50% zero-valued parameters, performs almost as well as the dense baseline.
The `SparsifyCallback` also accepts a list of sparsities, one for each layer of `layer_type` to be pruned. Below, we show how to prune only the intermediate layers of ResNet-18.
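The list needs one value per layer of `layer_type`; for the ResNet-18 used here that means 20 `nn.Conv2d` layers (as the report above shows), which we can verify with a quick count:

```python
# One sparsity value per prunable layer (nn.Conv2d by default); counting them
# tells us how long the list must be.
convs = [m for m in learn.model.modules() if isinstance(m, nn.Conv2d)]
print(len(convs))  # 20 for this ResNet-18
```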
```python
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

sparsities = [0, 0, 0, 0, 0, 0, 50, 50, 50, 50, 50, 50, 50, 50, 0, 0, 0, 0, 0, 0]

sp_cb = SparsifyCallback(sparsity=sparsities, granularity='weight', context='local', criteria=large_final, schedule=cos)
learn.fit_one_cycle(5, cbs=sp_cb)
```
Pruning of weight until a sparsity of [0, 0, 0, 0, 0, 0, 50, 50, 50, 50, 50, 50, 50, 50, 0, 0, 0, 0, 0, 0]%
Saving Weights at epoch 0
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.700158 | 0.762811 | 0.758457 | 00:04 |
1 | 0.420924 | 0.346230 | 0.873478 | 00:04 |
2 | 0.250773 | 0.207668 | 0.914073 | 00:04 |
3 | 0.141221 | 0.171472 | 0.933018 | 00:04 |
4 | 0.074364 | 0.165237 | 0.935724 | 00:04 |
Sparsity at the end of epoch 0: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 4.77, 4.77, 4.77, 4.77, 4.77, 4.77, 4.77, 4.77, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity at the end of epoch 1: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 17.27, 17.27, 17.27, 17.27, 17.27, 17.27, 17.27, 17.27, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity at the end of epoch 2: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 32.73, 32.73, 32.73, 32.73, 32.73, 32.73, 32.73, 32.73, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity at the end of epoch 3: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 45.23, 45.23, 45.23, 45.23, 45.23, 45.23, 45.23, 45.23, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity at the end of epoch 4: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Final Sparsity: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
```
Sparsity Report:
--------------------------------------------------------------------------------
Layer       Type            Params         Zeros     Sparsity
--------------------------------------------------------------------------------
Layer 2     Conv2d           9,408             0        0.00%
Layer 8     Conv2d          36,864             0        0.00%
Layer 11    Conv2d          36,864             0        0.00%
Layer 14    Conv2d          36,864             0        0.00%
Layer 17    Conv2d          36,864             0        0.00%
Layer 21    Conv2d          73,728             0        0.00%
Layer 24    Conv2d         147,456        73,727       50.00%
Layer 27    Conv2d           8,192         4,096       50.00%
Layer 30    Conv2d         147,456        73,727       50.00%
Layer 33    Conv2d         147,456        73,727       50.00%
Layer 37    Conv2d         294,912       147,455       50.00%
Layer 40    Conv2d         589,824       294,909       50.00%
Layer 43    Conv2d          32,768        16,384       50.00%
Layer 46    Conv2d         589,824       294,909       50.00%
Layer 49    Conv2d         589,824             0        0.00%
Layer 53    Conv2d       1,179,648             0        0.00%
Layer 56    Conv2d       2,359,296             0        0.00%
Layer 59    Conv2d         131,072             0        0.00%
Layer 62    Conv2d       2,359,296             0        0.00%
Layer 65    Conv2d       2,359,296             0        0.00%
--------------------------------------------------------------------------------
Overall     all         11,166,912       978,934        8.77%
```
On top of that, the `SparsifyCallback` can also take many optional arguments:

- `lth`: whether to train using the Lottery Ticket Hypothesis, i.e. reset the weights to their original values at each pruning step (more information in the Lottery Ticket Hypothesis section)
- `rewind_epoch`: the epoch used as a reference for the Lottery Ticket Hypothesis with Rewinding (defaults to 0)
- `reset_end`: whether to reset the weights to their original values after training (pruning masks are still applied)
- `save_tickets`: whether to save intermediate winning tickets
- `model`: pass a model or a part of the model if you don't want to apply pruning to the whole trained model
- `round_to`: if specified, the weights will be pruned to the closest multiple of `round_to`
- `layer_type`: the type of layer that you want to apply pruning to (defaults to nn.Conv2d)
For example, we pruned the convolution layers of our model here, but we could imagine pruning the Linear layers or even only the BatchNorm ones!
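As a sketch (keeping the other arguments from the examples above), targeting the Linear layers instead would only require changing `layer_type`:

```python
# Hypothetical variant: prune the Linear layers of the head instead of the convolutions.
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

sp_cb = SparsifyCallback(sparsity=50, granularity='weight', context='local',
                         criteria=large_final, schedule=one_cycle,
                         layer_type=nn.Linear)
learn.fit_one_cycle(5, cbs=sp_cb)
```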