Sparsify Callback

Use the Sparsifier within the fastai Callback system.
from fastai.vision.all import *
from fasterai.sparse.all import *  # assumed fasterai import providing SparsifyCallback, pruning criteria and schedules

path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

# In the Oxford-IIIT Pet filenames, cat breeds are capitalized and dog breeds are not
def label_func(f): return f[0].isupper()

dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))

The most important part of our Callback happens in before_batch: there, we first compute the current sparsity of our network according to our schedule, and then prune the parameters accordingly.
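
To make the idea concrete, here is a minimal sketch of such a callback built on fastai's Callback API. It is illustrative only, not fasterai's actual implementation: it assumes a Sparsifier-like object exposing a prune_model(sparsity) method and a scheduling function with the signature sched_func(start, end, pct).

from fastai.callback.core import Callback

class SimpleSparsifyCallback(Callback):
    "Sketch: raise the sparsity level every training batch according to a schedule"
    def __init__(self, sparsifier, end_sparsity, sched_func):
        self.sparsifier, self.end_sparsity, self.sched_func = sparsifier, end_sparsity, sched_func

    def before_batch(self):
        if not self.training: return
        # self.pct_train goes from 0 to 1 over the whole training run
        current_sparsity = self.sched_func(0., self.end_sparsity, self.pct_train)
        # Zero out the lowest-ranked weights up to the current sparsity level
        self.sparsifier.prune_model(current_sparsity)

With that mechanism in mind, let's first train a regular, dense network to serve as a baseline.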

learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
learn.fit_one_cycle(5)
epoch train_loss valid_loss accuracy time
0 0.638129 0.776416 0.794317 00:14
1 0.382736 0.266816 0.893775 00:04
2 0.247629 0.310128 0.901218 00:03
3 0.135562 0.175250 0.935724 00:03
4 0.076362 0.165934 0.941813 00:03

Let's now try adding some sparsity to our model.

learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

The SparsifyCallback requires one additional argument compared to the Sparsifier: the pruning schedule to follow during training, so that the parameters can be pruned progressively.

You can use any scheduling function already available in fastai or come up with your own! For more information about the pruning schedules, take a look at the Schedules section.
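
As an illustration, a custom schedule could look like the function below. It is written with the fastai-style annealing signature f(start, end, pos), where pos goes from 0 to 1 over training; whether fasterai accepts a plain function like this directly, or requires wrapping it, depends on the version, so treat it as a sketch.

def sched_quadratic(start, end, pos):
    "Hypothetical schedule: grow sparsity quadratically from `start` to `end` as pos goes from 0 to 1"
    return start + (end - start) * pos**2

# It could then be passed in place of `one_cycle` or `cos`, e.g.:
# SparsifyCallback(sparsity=50, granularity='weight', context='local',
#                  criteria=large_final, schedule=sched_quadratic)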

sp_cb = SparsifyCallback(sparsity=50, granularity='weight', context='local', criteria=large_final, schedule=one_cycle)
learn.fit_one_cycle(5, cbs=sp_cb)
Pruning of weight until a sparsity of [50]%
Saving Weights at epoch 0
epoch train_loss valid_loss accuracy time
0 0.672287 1.130331 0.778078 00:05
1 0.401001 0.315872 0.860622 00:04
2 0.233937 0.191955 0.912043 00:03
3 0.130283 0.208986 0.921516 00:03
4 0.071710 0.202378 0.930988 00:03
Sparsity at the end of epoch 0: [1.96]%
Sparsity at the end of epoch 1: [20.07]%
Sparsity at the end of epoch 2: [45.86]%
Sparsity at the end of epoch 3: [49.74]%
Sparsity at the end of epoch 4: [50.0]%
Final Sparsity: [50.0]%

Sparsity Report:
--------------------------------------------------------------------------------
Layer                Type            Params     Zeros      Sparsity  
--------------------------------------------------------------------------------
Layer 2              Conv2d          9,408      4,704         50.00%
Layer 8              Conv2d          36,864     18,432        50.00%
Layer 11             Conv2d          36,864     18,432        50.00%
Layer 14             Conv2d          36,864     18,432        50.00%
Layer 17             Conv2d          36,864     18,432        50.00%
Layer 21             Conv2d          73,728     36,864        50.00%
Layer 24             Conv2d          147,456    73,727        50.00%
Layer 27             Conv2d          8,192      4,096         50.00%
Layer 30             Conv2d          147,456    73,727        50.00%
Layer 33             Conv2d          147,456    73,727        50.00%
Layer 37             Conv2d          294,912    147,455       50.00%
Layer 40             Conv2d          589,824    294,909       50.00%
Layer 43             Conv2d          32,768     16,384        50.00%
Layer 46             Conv2d          589,824    294,909       50.00%
Layer 49             Conv2d          589,824    294,909       50.00%
Layer 53             Conv2d          1,179,648  589,818       50.00%
Layer 56             Conv2d          2,359,296  1,179,637     50.00%
Layer 59             Conv2d          131,072    65,535        50.00%
Layer 62             Conv2d          2,359,296  1,179,637     50.00%
Layer 65             Conv2d          2,359,296  1,179,637     50.00%
--------------------------------------------------------------------------------
Overall              all             11,166,912 5,583,403     50.00%

Surprisingly, our network, in which 50% of the weights are now zero, performs reasonably well compared to our plain, dense network.

The SparsifyCallback also accepts a list of sparsities, one per layer of the pruned layer_type. Below, we show how to prune only the intermediate layers of ResNet-18.

learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
sparsities = [0, 0, 0, 0, 0, 0, 50, 50, 50, 50, 50, 50, 50, 50, 0, 0, 0, 0, 0, 0]
sp_cb = SparsifyCallback(sparsity=sparsities, granularity='weight', context='local', criteria=large_final, schedule=cos)
learn.fit_one_cycle(5, cbs=sp_cb)
Pruning of weight until a sparsity of [0, 0, 0, 0, 0, 0, 50, 50, 50, 50, 50, 50, 50, 50, 0, 0, 0, 0, 0, 0]%
Saving Weights at epoch 0
epoch train_loss valid_loss accuracy time
0 0.700158 0.762811 0.758457 00:04
1 0.420924 0.346230 0.873478 00:04
2 0.250773 0.207668 0.914073 00:04
3 0.141221 0.171472 0.933018 00:04
4 0.074364 0.165237 0.935724 00:04
Sparsity at the end of epoch 0: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 4.77, 4.77, 4.77, 4.77, 4.77, 4.77, 4.77, 4.77, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity at the end of epoch 1: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 17.27, 17.27, 17.27, 17.27, 17.27, 17.27, 17.27, 17.27, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity at the end of epoch 2: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 32.73, 32.73, 32.73, 32.73, 32.73, 32.73, 32.73, 32.73, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity at the end of epoch 3: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 45.23, 45.23, 45.23, 45.23, 45.23, 45.23, 45.23, 45.23, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity at the end of epoch 4: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Final Sparsity: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%

Sparsity Report:
--------------------------------------------------------------------------------
Layer                Type            Params     Zeros      Sparsity  
--------------------------------------------------------------------------------
Layer 2              Conv2d          9,408      0              0.00%
Layer 8              Conv2d          36,864     0              0.00%
Layer 11             Conv2d          36,864     0              0.00%
Layer 14             Conv2d          36,864     0              0.00%
Layer 17             Conv2d          36,864     0              0.00%
Layer 21             Conv2d          73,728     0              0.00%
Layer 24             Conv2d          147,456    73,727        50.00%
Layer 27             Conv2d          8,192      4,096         50.00%
Layer 30             Conv2d          147,456    73,727        50.00%
Layer 33             Conv2d          147,456    73,727        50.00%
Layer 37             Conv2d          294,912    147,455       50.00%
Layer 40             Conv2d          589,824    294,909       50.00%
Layer 43             Conv2d          32,768     16,384        50.00%
Layer 46             Conv2d          589,824    294,909       50.00%
Layer 49             Conv2d          589,824    0              0.00%
Layer 53             Conv2d          1,179,648  0              0.00%
Layer 56             Conv2d          2,359,296  0              0.00%
Layer 59             Conv2d          131,072    0              0.00%
Layer 62             Conv2d          2,359,296  0              0.00%
Layer 65             Conv2d          2,359,296  0              0.00%
--------------------------------------------------------------------------------
Overall              all             11,166,912 978,934        8.77%

On top of that, the SparsifyCallback can also take several optional arguments.

For example, we pruned the convolution layers of our model here, but we could imagine pruning the Linear layers, or even only the BatchNorm ones!
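
For instance, assuming the callback exposes a layer_type argument for selecting which layer class gets pruned (check the SparsifyCallback documentation for the exact list of optional arguments), targeting Linear layers instead of convolutions might look like this:

import torch.nn as nn

learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
# Hypothetical: prune nn.Linear layers instead of the default convolutions
sp_cb = SparsifyCallback(sparsity=50, granularity='weight', context='local',
                         criteria=large_final, schedule=one_cycle, layer_type=nn.Linear)
learn.fit_one_cycle(5, cbs=sp_cb)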