Schedules

Make your neural network sparse with fastai

Neural network pruning usually follows one of three schedules: one-shot pruning, iterative pruning, or gradual pruning.

In fasterai, all three schedules can be applied with the same callback. We’ll cover each one below.

SparsifyCallback has several parameters that ‘shape’ the pruning schedule:

  * start_sparsity: the initial sparsity of the model, usually kept at 0 because, after initialization, the weights are generally non-zero.
  * end_sparsity: the target sparsity at the end of training.
  * start_epoch: when pruning begins. We can start pruning right from the beginning or let the model train for a while before removing weights.
  * sched_func: the general shape of the schedule, i.e. how the sparsity evolves during training. You can use one of the schedules available in fastai or even come up with your own!
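To make the role of these parameters concrete, here is a plain-Python sketch of how they could combine into a per-epoch sparsity target. `sched_lin` and `sched_cos` are modelled on fastai’s annealing functions of the same names, which take `(start, end, pos)` with `pos` going from 0 to 1; `sparsity_at` is a hypothetical helper, not fasterai’s implementation.

```python
import math

# Annealing functions modelled on fastai's sched_lin and sched_cos:
# signature (start, end, pos), with pos going from 0 to 1
def sched_lin(start, end, pos):
    return start + pos * (end - start)

def sched_cos(start, end, pos):
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2

def sparsity_at(epoch, n_epochs, start_epoch, start_sparsity, end_sparsity, sched_func):
    """Hypothetical helper: target sparsity (%) at the end of `epoch` (0-indexed)."""
    if epoch < start_epoch:
        return start_sparsity  # dense training phase, nothing pruned yet
    # fraction of the pruning phase completed at the end of this epoch
    pos = (epoch - start_epoch + 1) / (n_epochs - start_epoch)
    return sched_func(start_sparsity, end_sparsity, pos)

for epoch in range(10):
    lin = sparsity_at(epoch, 10, 3, 0, 90, sched_lin)
    cos = sparsity_at(epoch, 10, 3, 0, 90, sched_cos)
    print(f"epoch {epoch}: linear {lin:5.1f}%  cosine {cos:5.1f}%")
```

With `start_epoch=3`, both schedules keep the network dense for three epochs; the linear one then removes weights at a constant rate, while the cosine one starts slowly, speeds up, and flattens out again near the target.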


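from fastai.vision.all import *
from fasterai.sparse.all import *

# Download and extract the Oxford-IIIT Pet dataset
path = untar_data(URLs.PETS)

files = get_image_files(path/"images")

# from_name_func passes the file name: cat breeds are capitalised, dog breeds are not
def label_func(f): return f[0].isupper()

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64), device=device)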

We will first train a network without any pruning, which will serve as a baseline.

learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

learn.fit_one_cycle(10)
epoch train_loss valid_loss accuracy time
0 0.726983 0.501899 0.843031 00:03
1 0.447103 0.355498 0.862652 00:03
2 0.288626 0.217356 0.910690 00:03
3 0.213532 0.294699 0.891746 00:03
4 0.178062 0.243345 0.909337 00:03
5 0.132118 0.236394 0.917456 00:03
6 0.092771 0.194155 0.933018 00:03
7 0.055006 0.248453 0.922192 00:03
8 0.029332 0.199173 0.937754 00:03
9 0.018201 0.205575 0.941813 00:03

One-Shot Pruning

The simplest way to perform pruning is called One-Shot Pruning. It consists of the following three steps:

  1. You first need to train a network.
  2. You then remove some of the weights (depending on your criteria, needs, …).
  3. You fine-tune the remaining weights to recover from the loss of parameters.
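The three steps can be sketched on a toy list of weights. This is a minimal plain-Python illustration, assuming weights are ranked by magnitude (in the spirit of the large_final criterion used in this tutorial) and using `fine_tune` as a hypothetical stand-in for further training. In a real network, fine-tuning must also keep the pruned weights at zero.

```python
def magnitude_prune(weights, sparsity):
    """Zero out the `sparsity`% of weights with the smallest magnitude."""
    n_prune = int(len(weights) * sparsity / 100)  # rounded down
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = set(order[:n_prune])  # indices of the smallest |w|
    return [0.0 if i in pruned else w for i, w in enumerate(weights)]

def one_shot_prune(trained_weights, sparsity, fine_tune):
    """Hypothetical one-shot recipe; step 1 (training) is already done."""
    pruned = magnitude_prune(trained_weights, sparsity)  # step 2: prune once
    return fine_tune(pruned)  # step 3: recover from the lost parameters

weights = [0.5, -0.1, 0.3, -0.9, 0.05]
# the identity function stands in for fine-tuning here
print(one_shot_prune(weights, 60, fine_tune=lambda w: w))  # [0.5, 0.0, 0.0, -0.9, 0.0]
```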

With fasterai, this is really easy to do. Let’s illustrate it by an example:

learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

In this case, your network needs to be trained before pruning. This training can be done independently of the pruning callback, or simulated with the start_epoch argument, which delays the pruning process.

You thus only need to create the Callback with the one_shot schedule and set the start_epoch argument, i.e. how many epochs you want to train your network before pruning it.

sp_cb = SparsifyCallback(sparsity=90, granularity='weight', context='local', criteria=large_final, schedule=one_shot)

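Let’s train the model for 10 epochs in total, the same amount of training as the baseline. As the sparsity log shows, the weights stay dense for the first four epochs (0 to 3) and are pruned to 90% in a single step by the end of epoch 4.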

learn.fit(10, cbs=sp_cb)
Pruning of weight until a sparsity of [90]%
Saving Weights at epoch 0
epoch train_loss valid_loss accuracy time
0 0.528262 0.357351 0.837618 00:04
1 0.363505 0.513674 0.774019 00:03
2 0.250641 0.268071 0.893775 00:03
3 0.211224 0.268274 0.896482 00:04
4 0.203224 0.264878 0.885656 00:03
5 0.236911 0.248686 0.893099 00:04
6 0.178409 0.239735 0.903924 00:04
7 0.128331 0.258517 0.896482 00:04
8 0.110764 0.218122 0.912720 00:04
9 0.077449 0.279049 0.907307 00:04
Sparsity at the end of epoch 0: [0.0]%
Sparsity at the end of epoch 1: [0.0]%
Sparsity at the end of epoch 2: [0.0]%
Sparsity at the end of epoch 3: [0.0]%
Sparsity at the end of epoch 4: [90.0]%
Sparsity at the end of epoch 5: [90.0]%
Sparsity at the end of epoch 6: [90.0]%
Sparsity at the end of epoch 7: [90.0]%
Sparsity at the end of epoch 8: [90.0]%
Sparsity at the end of epoch 9: [90.0]%
Final Sparsity: [90]%

Sparsity Report:
--------------------------------------------------------------------------------
Layer                Type            Params     Zeros      Sparsity  
--------------------------------------------------------------------------------
Layer 2              Conv2d          9,408      8,467         90.00%
Layer 8              Conv2d          36,864     33,177        90.00%
Layer 11             Conv2d          36,864     33,177        90.00%
Layer 14             Conv2d          36,864     33,177        90.00%
Layer 17             Conv2d          36,864     33,177        90.00%
Layer 21             Conv2d          73,728     66,355        90.00%
Layer 24             Conv2d          147,456    132,710       90.00%
Layer 27             Conv2d          8,192      7,372         89.99%
Layer 30             Conv2d          147,456    132,710       90.00%
Layer 33             Conv2d          147,456    132,710       90.00%
Layer 37             Conv2d          294,912    265,420       90.00%
Layer 40             Conv2d          589,824    530,841       90.00%
Layer 43             Conv2d          32,768     29,491        90.00%
Layer 46             Conv2d          589,824    530,841       90.00%
Layer 49             Conv2d          589,824    530,841       90.00%
Layer 53             Conv2d          1,179,648  1,061,683     90.00%
Layer 56             Conv2d          2,359,296  2,123,366     90.00%
Layer 59             Conv2d          131,072    117,964       90.00%
Layer 62             Conv2d          2,359,296  2,123,366     90.00%
Layer 65             Conv2d          2,359,296  2,123,366     90.00%
--------------------------------------------------------------------------------
Overall              all             11,166,912 10,050,211    90.00%
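The counts in the report can be checked by hand. The sketch below assumes that each layer loses 90% of its parameters, rounded down; this rule is inferred from the report itself, not taken from fasterai’s code. It reproduces the 7,372 zeros (89.99%) of the 8,192-parameter layer and the overall totals.

```python
# Conv2d parameter counts, copied from the sparsity report above (top to bottom)
params = [9408, 36864, 36864, 36864, 36864, 73728, 147456, 8192, 147456, 147456,
          294912, 589824, 32768, 589824, 589824, 1179648, 2359296, 131072,
          2359296, 2359296]

def zeros_for(n_params, sparsity=90):
    """Assumed rule: prune sparsity% of the layer's weights, rounded down."""
    return n_params * sparsity // 100

zeros = [zeros_for(p) for p in params]

# The 8,192-parameter layer (Layer 27) cannot reach exactly 90%
print(zeros[7], f"{100 * zeros[7] / params[7]:.2f}%")  # 7372 89.99%

total_params, total_zeros = sum(params), sum(zeros)
print(f"{total_params:,} {total_zeros:,} {100 * total_zeros / total_params:.2f}%")
# 11,166,912 10,050,211 90.00%
```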

Iterative Pruning

Researchers have come up with a better way to prune than removing all the weights at once (as in One-Shot Pruning): performing several iterations of pruning and fine-tuning. This approach is called Iterative Pruning.

  1. You first need to train a network.
  2. You then remove a part of the weights (depending on your criteria, needs, …).
  3. You fine-tune the remaining weights to recover from the loss of parameters.
  4. Go back to step 2, until the target sparsity is reached.

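The loop can be sketched on a toy list of weights. `iterative_prune` and `magnitude_prune` are hypothetical helpers: each round raises the target sparsity by an equal share of the final target. Because already-pruned weights have magnitude 0, they are always selected again first, so each round only removes new weights (assuming `fine_tune` keeps pruned weights at zero).

```python
def magnitude_prune(weights, sparsity):
    """Zero out the `sparsity`% of weights with the smallest magnitude."""
    n_prune = int(len(weights) * sparsity / 100)
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = set(order[:n_prune])
    return [0.0 if i in pruned else w for i, w in enumerate(weights)]

def iterative_prune(weights, target, n_steps, fine_tune):
    """Hypothetical iterative recipe: prune, fine-tune, repeat n_steps times."""
    history = []
    for k in range(1, n_steps + 1):
        step_target = target * k / n_steps                 # sparsity for this round
        weights = magnitude_prune(weights, step_target)    # step 2: prune a part
        weights = fine_tune(weights)                       # step 3: recover
        history.append(100 * sum(x == 0 for x in weights) / len(weights))
    return weights, history

trained = [0.9, -0.8, 0.7, -0.6, 0.5, -0.4, 0.3, -0.2, 0.15, 0.1]
final, history = iterative_prune(trained, target=90, n_steps=3, fine_tune=lambda w: w)
print(history)  # [30.0, 60.0, 90.0]
```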
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

In this case, your network needs to be trained before pruning.

You only need to create the Callback with the iterative schedule and set the start_epoch argument, i.e. how many epochs you want to train your network before pruning it.

The iterative schedule has an n_steps parameter, i.e. how many iterations of pruning/fine-tuning you want to perform. To modify its value, we can use the partial function like this:

iterative = partial(iterative, n_steps=5)
sp_cb = SparsifyCallback(sparsity=90, granularity='weight', context='local', criteria=large_final, schedule=iterative)
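`functools.partial` returns a new callable with some arguments already bound. The sketch below applies it to a hypothetical step schedule with fastai’s `(start, end, pos)` signature; `sched_step` is an illustration of the idea, not fasterai’s `iterative` schedule.

```python
import math
from functools import partial

def sched_step(start, end, pos, n_steps=3):
    """Hypothetical step schedule: sparsity rises in n_steps equal jumps."""
    step = math.floor(pos * n_steps)  # completed steps at progress pos in [0, 1]
    return start + (end - start) * step / n_steps

# Bind n_steps once, the same way as partial(iterative, n_steps=5)
five_steps = partial(sched_step, n_steps=5)
print([five_steps(0, 90, p) for p in (0.1, 0.5, 0.99, 1.0)])  # [0.0, 36.0, 72.0, 90.0]
```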

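Let’s train the model for 10 epochs in total, the same amount of training as the baseline. This time, the sparsity log shows the weights being pruned in steps: 30% by the end of epoch 2, 60% by the end of epoch 4 and 90% by the end of epoch 7.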

learn.fit(10, cbs=sp_cb)
Pruning of weight until a sparsity of [90]%
Saving Weights at epoch 0
epoch train_loss valid_loss accuracy time
0 0.610821 0.347474 0.842355 00:03
1 0.398858 0.563407 0.859269 00:03
2 0.302287 0.310804 0.871448 00:04
3 0.222971 0.373315 0.855886 00:04
4 0.175162 0.256725 0.901894 00:04
5 0.140398 0.234549 0.916779 00:04
6 0.100763 0.231317 0.912720 00:04
7 0.252507 0.318962 0.876861 00:04
8 0.186143 0.240475 0.903924 00:04
9 0.151650 0.222071 0.910014 00:04
Sparsity at the end of epoch 0: [0.0]%
Sparsity at the end of epoch 1: [0.0]%
Sparsity at the end of epoch 2: [30.0]%
Sparsity at the end of epoch 3: [30.0]%
Sparsity at the end of epoch 4: [60.0]%
Sparsity at the end of epoch 5: [60.0]%
Sparsity at the end of epoch 6: [60.0]%
Sparsity at the end of epoch 7: [90.0]%
Sparsity at the end of epoch 8: [90.0]%
Sparsity at the end of epoch 9: [90.0]%
Final Sparsity: [90.0]%

Sparsity Report:
--------------------------------------------------------------------------------
Layer                Type            Params     Zeros      Sparsity  
--------------------------------------------------------------------------------
Layer 2              Conv2d          9,408      8,467         90.00%
Layer 8              Conv2d          36,864     33,177        90.00%
Layer 11             Conv2d          36,864     33,177        90.00%
Layer 14             Conv2d          36,864     33,177        90.00%
Layer 17             Conv2d          36,864     33,177        90.00%
Layer 21             Conv2d          73,728     66,355        90.00%
Layer 24             Conv2d          147,456    132,710       90.00%
Layer 27             Conv2d          8,192      7,372         89.99%
Layer 30             Conv2d          147,456    132,710       90.00%
Layer 33             Conv2d          147,456    132,710       90.00%
Layer 37             Conv2d          294,912    265,420       90.00%
Layer 40             Conv2d          589,824    530,841       90.00%
Layer 43             Conv2d          32,768     29,491        90.00%
Layer 46             Conv2d          589,824    530,841       90.00%
Layer 49             Conv2d          589,824    530,841       90.00%
Layer 53             Conv2d          1,179,648  1,061,683     90.00%
Layer 56             Conv2d          2,359,296  2,123,366     90.00%
Layer 59             Conv2d          131,072    117,964       90.00%
Layer 62             Conv2d          2,359,296  2,123,366     90.00%
Layer 65             Conv2d          2,359,296  2,123,366     90.00%
--------------------------------------------------------------------------------
Overall              all             11,166,912 10,050,211    90.00%

Gradual Pruning

Here is, for example, how to apply the Automated Gradual Pruning (AGP) schedule.

learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
sp_cb = SparsifyCallback(sparsity=90, granularity='weight', context='local', criteria=large_final, schedule=agp)

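Let’s train the model for 10 epochs in total, the same amount of training as the baseline. Here, pruning starts at epoch 2 and the sparsity increases at every epoch, quickly at first and then more slowly, until it reaches 90% at the end of training.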

learn.fit(10, cbs=sp_cb)
Pruning of weight until a sparsity of [90]%
Saving Weights at epoch 0
epoch train_loss valid_loss accuracy time
0 0.606718 0.444957 0.817321 00:03
1 0.404780 0.325938 0.863329 00:03
2 0.308993 0.298605 0.873478 00:04
3 0.232805 0.356629 0.865359 00:04
4 0.184159 0.285140 0.891069 00:04
5 0.176350 0.300211 0.882273 00:04
6 0.163846 0.267521 0.891069 00:04
7 0.158822 0.218856 0.909337 00:04
8 0.115987 0.269380 0.909337 00:04
9 0.079022 0.240360 0.915426 00:04
Sparsity at the end of epoch 0: [0.0]%
Sparsity at the end of epoch 1: [0.0]%
Sparsity at the end of epoch 2: [29.71]%
Sparsity at the end of epoch 3: [52.03]%
Sparsity at the end of epoch 4: [68.03]%
Sparsity at the end of epoch 5: [78.75]%
Sparsity at the end of epoch 6: [85.25]%
Sparsity at the end of epoch 7: [88.59]%
Sparsity at the end of epoch 8: [89.82]%
Sparsity at the end of epoch 9: [90.0]%
Final Sparsity: [90.0]%

Sparsity Report:
--------------------------------------------------------------------------------
Layer                Type            Params     Zeros      Sparsity  
--------------------------------------------------------------------------------
Layer 2              Conv2d          9,408      8,467         90.00%
Layer 8              Conv2d          36,864     33,177        90.00%
Layer 11             Conv2d          36,864     33,177        90.00%
Layer 14             Conv2d          36,864     33,177        90.00%
Layer 17             Conv2d          36,864     33,177        90.00%
Layer 21             Conv2d          73,728     66,355        90.00%
Layer 24             Conv2d          147,456    132,710       90.00%
Layer 27             Conv2d          8,192      7,372         89.99%
Layer 30             Conv2d          147,456    132,710       90.00%
Layer 33             Conv2d          147,456    132,710       90.00%
Layer 37             Conv2d          294,912    265,420       90.00%
Layer 40             Conv2d          589,824    530,841       90.00%
Layer 43             Conv2d          32,768     29,491        90.00%
Layer 46             Conv2d          589,824    530,841       90.00%
Layer 49             Conv2d          589,824    530,841       90.00%
Layer 53             Conv2d          1,179,648  1,061,683     90.00%
Layer 56             Conv2d          2,359,296  2,123,366     90.00%
Layer 59             Conv2d          131,072    117,964       90.00%
Layer 62             Conv2d          2,359,296  2,123,366     90.00%
Layer 65             Conv2d          2,359,296  2,123,366     90.00%
--------------------------------------------------------------------------------
Overall              all             11,166,912 10,050,211    90.00%
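The per-epoch sparsities logged above fit the cubic ramp of Automated Gradual Pruning as described by Zhu and Gupta (2017), s_t = s_f + (s_i - s_f)(1 - p)^3, where p is the fraction of the pruning phase completed. The sketch below assumes the eight logged pruning epochs (2 to 9) correspond to p = 1/8, 2/8, …, 1; it is not fasterai’s own code, but it reproduces the logged values.

```python
def agp_sparsity(start, end, pos):
    """Cubic AGP ramp: prunes quickly at first, then slowly near the target."""
    return end + (start - end) * (1 - pos) ** 3

# Assumption: epochs 2..9 of the log correspond to pos = 1/8, ..., 8/8
print([round(agp_sparsity(0, 90, k / 8), 2) for k in range(1, 9)])
# [29.71, 52.03, 68.03, 78.75, 85.25, 88.59, 89.82, 90.0]
```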

Even though they are often considered different pruning methods, these three schedules can all be captured by the same callback. Here is how the sparsity of the network evolves under each of them:

As an example, say we want to train our network for 3 epochs without pruning and then for 7 epochs with pruning.

Here is what the different pruning schedules would then look like:
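Since the sparsity depends only on how far training has progressed, the three schedules can be tabulated for this 3 + 7 epoch setup in a few lines of plain Python. The schedule functions below are simplified, hypothetical versions (one-shot pruning right after the dense phase, iterative pruning in three equal steps, and the cubic AGP ramp), not fasterai’s implementations.

```python
import math

N_EPOCHS, DENSE_EPOCHS, END = 10, 3, 90  # 3 dense epochs, then 7 pruning epochs

# Hypothetical schedules with the (start, end, pos) signature, pos in [0, 1]
def one_shot_sketch(start, end, pos):
    return end if pos > 0 else start  # everything at once

def iterative_sketch(start, end, pos, n_steps=3):
    return start + (end - start) * math.ceil(pos * n_steps) / n_steps  # equal jumps

def agp_sketch(start, end, pos):
    return end + (start - end) * (1 - pos) ** 3  # cubic ramp

def sparsity_per_epoch(sched):
    """Sparsity (%) at the end of each epoch for a given schedule."""
    out = []
    for epoch in range(N_EPOCHS):
        pos = max(0, epoch - DENSE_EPOCHS + 1) / (N_EPOCHS - DENSE_EPOCHS)
        out.append(round(sched(0, END, pos), 2))
    return out

for sched in (one_shot_sketch, iterative_sketch, agp_sketch):
    print(f"{sched.__name__:>16}", sparsity_per_epoch(sched))
```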

You can also come up with your own pruning schedule!
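As an illustration, here is a hypothetical custom schedule with the same `(start, end, pos)` signature as fastai’s annealing functions: a cosine ramp that reaches the target sparsity halfway through the pruning phase and then holds it, leaving the second half for fine-tuning at the final sparsity. How such a function is passed to SparsifyCallback (directly, or wrapped in a schedule object) may depend on your fasterai version, so treat that wiring as an assumption.

```python
import math

def sched_cos_then_hold(start, end, pos):
    """Hypothetical schedule: cosine ramp over the first half, then hold `end`."""
    if pos >= 0.5:
        return end  # target reached: remaining training only fine-tunes
    # cosine ramp from start (pos = 0) to end (pos = 0.5)
    return start + (1 - math.cos(math.pi * pos / 0.5)) * (end - start) / 2

print([round(sched_cos_then_hold(0, 90, p), 2) for p in (0.0, 0.25, 0.5, 0.75, 1.0)])
```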