Lottery Ticket Hypothesis

How to find winning tickets with fastai

The Lottery Ticket Hypothesis

The Lottery Ticket Hypothesis is a really intriguing discovery made in 2019 by Frankle & Carbin. It states that:

A randomly-initialized, dense neural network contains a subnetwork that is initialised such that — when trained in isolation — it can match the test accuracy of the original network after training for at most the same number of iterations.

This means that, once we find that subnetwork, every other parameter in the network becomes useless.

The authors propose to find such a subnetwork as follows:

  1. Initialize the neural network
  2. Train it to convergence
  3. Prune the smallest magnitude weights by creating a mask \(m\)
  4. Reinitialize the weights to their original value, i.e. their value at iteration \(0\).
  5. Repeat from step 2 until reaching the desired level of sparsity.
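The five steps above can be sketched on a toy problem. This is a minimal pure-Python illustration, not the fasterai implementation: the linear model, the helper names `train` and `prune_smallest`, the learning rate and the number of rounds are all assumptions chosen for the example.

```python
import random

def train(weights, mask, data, lr=0.1, steps=200):
    """Gradient descent on a linear model y = w.x; masked weights stay at zero."""
    w = [wi * mi for wi, mi in zip(weights, mask)]
    for _ in range(steps):
        grad = [0.0] * len(w)
        for x, y in data:
            err = sum(wi * xi for wi, xi in zip(w, x)) - y
            for j, xj in enumerate(x):
                grad[j] += 2 * err * xj / len(data)
        w = [(wi - lr * gi) * mi for wi, gi, mi in zip(w, grad, mask)]
    return w

def prune_smallest(w, mask, n_prune):
    """Remove from the mask the n_prune surviving weights of smallest magnitude."""
    alive = sorted((abs(wi), i) for i, (wi, mi) in enumerate(zip(w, mask)) if mi)
    new_mask = list(mask)
    for _, i in alive[:n_prune]:
        new_mask[i] = 0
    return new_mask

random.seed(0)
true_w = [3.0, -2.0, 0.0, 0.0, 0.0, 0.0]          # only two weights matter
data = []
for _ in range(64):
    x = [random.gauss(0, 1) for _ in true_w]
    data.append((x, sum(t * xi for t, xi in zip(true_w, x))))

init_w = [random.gauss(0, 0.1) for _ in true_w]   # step 1: initialize
mask = [1] * len(init_w)
for _ in range(2):                                # step 5: repeat
    trained = train(init_w, mask, data)           # step 2, always from init_w (step 4)
    mask = prune_smallest(trained, mask, 2)       # step 3: prune by magnitude

ticket = [wi * mi for wi, mi in zip(init_w, mask)]  # initial values, masked
final = train(init_w, mask, data)                   # train the ticket in isolation
print(mask, [round(v, 3) for v in final])
```

In this toy setting the surviving mask keeps exactly the two informative weights, and the masked subnetwork trained from its original initialization recovers them.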

from fastai.vision.all import *
from fasterai.sparse.all import *
from copy import deepcopy
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

def label_func(f): return f[0].isupper()

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64), device=device)

What we are trying to show empirically is that, in a neural network A, there exists a subnetwork B able to reach an accuracy \(a_B \geq a_A\) in a training time \(t_B \leq t_A\).

Let’s get the baseline for network A:

learn = Learner(dls, resnet18(num_classes=2), metrics=accuracy)

Let’s save the original weights before training:

initial_weights = deepcopy(learn.model.state_dict())
learn.fit(5, 1e-3)
epoch train_loss valid_loss accuracy time
0 0.581151 1.603907 0.666441 00:03
1 0.547510 0.717316 0.691475 00:03
2 0.517336 0.621597 0.628552 00:03
3 0.477595 1.084812 0.438430 00:03
4 0.446734 0.736970 0.602842 00:03
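The `deepcopy` used when saving `initial_weights` matters: the tensors returned by `state_dict()` share storage with the model parameters, so a plain reference would silently follow the weights as they train. The sketch below shows the same effect with plain Python lists standing in for tensors; the dictionary and its key are illustrative assumptions.

```python
from copy import deepcopy

# Plain lists stand in for tensors; the dict stands in for a state dict.
params = {"conv1.weight": [0.5, -0.2, 0.1]}

alias = dict(params)          # new dict, but the same underlying lists
snapshot = deepcopy(params)   # independent copy of the values

# Simulate an in-place optimizer update.
for i in range(len(params["conv1.weight"])):
    params["conv1.weight"][i] -= 0.1

print(alias["conv1.weight"])     # follows the update
print(snapshot["conv1.weight"])  # still the initial values
```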

We now have our baseline: an accuracy \(a_A\) of about \(60\%\) (the final-epoch accuracy in the table above) and a training time \(t_A\) of \(5\) epochs.

To find the lottery ticket, we will perform iterative pruning but, at each pruning step, re-initialize the remaining weights to their original values (i.e. their values before training).

We restart from the same initialization as network A, so that any difference cannot be attributed to a luckier initialization.

learn = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
learn.model.load_state_dict(initial_weights)
<All keys matched successfully>

We can pass the parameter lth=True to reset the weights of the network to their original value after each pruning step, i.e. step 4 of the LTH procedure. To empirically validate the LTH, we need to retrain the found “lottery ticket” after the pruning phase. Lottery tickets are usually found following an iterative pruning schedule. We set start_pct=0.25 so that, over the \(20\) epochs of training used here, the pruning process begins after \(5\) epochs.

schedule = Schedule(sched_iterative, start_pct=0.25)
sp_cb = SparsifyCallback(50, 'weight', 'local', large_final, schedule, lth=True)

As our iterative schedule makes \(3\) pruning steps by default, we have to train our network for the \(5\) initial epochs plus \(3 \times t_B\) epochs, so \(5 + 3 \times 5 = 20\) epochs, in order to get our lottery ticket. After each step, the remaining weights will be reinitialized to their original value.
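The epoch arithmetic can be checked with a small helper. This is only a sketch of the bookkeeping described here, not fasterai's schedule code: the function name `pruning_rounds` and its parameters are hypothetical, and it assumes the target sparsity is split evenly across the pruning steps.

```python
def pruning_rounds(start_epoch, epochs_per_round, end_sparsity=50.0, n_steps=3):
    """Return the total epoch count and (first_epoch, last_epoch, sparsity) per round."""
    rounds = []
    for step in range(1, n_steps + 1):
        first = start_epoch + (step - 1) * epochs_per_round
        last = first + epochs_per_round - 1
        rounds.append((first, last, round(end_sparsity * step / n_steps, 2)))
    total = start_epoch + n_steps * epochs_per_round
    return total, rounds

total, rounds = pruning_rounds(start_epoch=5, epochs_per_round=5)
print(total)
for first, last, sparsity in rounds:
    print(f"epochs {first}-{last}: {sparsity}% sparsity")
```

With 5 dense epochs followed by 3 rounds of 5 epochs, this gives 20 epochs and sparsity levels of 16.67%, 33.33% and 50%, matching the levels reported in the training log.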

learn.fit(20, 1e-3, cbs=sp_cb)
Pruning of weight until a sparsity of [50]%
Saving Weights at epoch 0
epoch train_loss valid_loss accuracy time
0 0.587076 0.864903 0.665765 00:03
1 0.545250 0.605309 0.709743 00:03
2 0.507854 0.609370 0.651556 00:03
3 0.472215 0.516507 0.736807 00:03
4 0.438759 0.501989 0.757104 00:03
5 0.543971 0.591650 0.699594 00:03
6 0.505276 0.576168 0.680650 00:03
7 0.459986 0.495335 0.753045 00:03
8 0.424159 0.707351 0.736807 00:03
9 0.398294 0.436202 0.817997 00:03
10 0.509464 0.726151 0.547361 00:03
11 0.437742 0.879882 0.706360 00:03
12 0.388987 0.436665 0.794993 00:03
13 0.343624 0.385158 0.836942 00:03
14 0.324372 0.526213 0.796346 00:03
15 0.392040 0.495435 0.766576 00:03
16 0.358236 0.450453 0.775372 00:03
17 0.320370 0.423908 0.795670 00:03
18 0.277983 0.469020 0.774019 00:03
19 0.262043 0.533755 0.774019 00:03
Sparsity at the end of epoch 0: [0.0]%
Sparsity at the end of epoch 1: [0.0]%
Sparsity at the end of epoch 2: [0.0]%
Sparsity at the end of epoch 3: [0.0]%
Sparsity at the end of epoch 4: [0.0]%
Resetting Weights to their epoch 0 values
Sparsity at the end of epoch 5: [16.67]%
Sparsity at the end of epoch 6: [16.67]%
Sparsity at the end of epoch 7: [16.67]%
Sparsity at the end of epoch 8: [16.67]%
Sparsity at the end of epoch 9: [16.67]%
Resetting Weights to their epoch 0 values
Sparsity at the end of epoch 10: [33.33]%
Sparsity at the end of epoch 11: [33.33]%
Sparsity at the end of epoch 12: [33.33]%
Sparsity at the end of epoch 13: [33.33]%
Sparsity at the end of epoch 14: [33.33]%
Resetting Weights to their epoch 0 values
Sparsity at the end of epoch 15: [50.0]%
Sparsity at the end of epoch 16: [50.0]%
Sparsity at the end of epoch 17: [50.0]%
Sparsity at the end of epoch 18: [50.0]%
Sparsity at the end of epoch 19: [50.0]%
Final Sparsity: [50.0]%

Sparsity Report:
--------------------------------------------------------------------------------
Layer                Type            Params     Zeros      Sparsity  
--------------------------------------------------------------------------------
Layer 1              Conv2d          9,408      4,704         50.00%
Layer 7              Conv2d          36,864     18,432        50.00%
Layer 10             Conv2d          36,864     18,432        50.00%
Layer 13             Conv2d          36,864     18,432        50.00%
Layer 16             Conv2d          36,864     18,432        50.00%
Layer 20             Conv2d          73,728     36,864        50.00%
Layer 23             Conv2d          147,456    73,728        50.00%
Layer 26             Conv2d          8,192      4,096         50.00%
Layer 29             Conv2d          147,456    73,728        50.00%
Layer 32             Conv2d          147,456    73,728        50.00%
Layer 36             Conv2d          294,912    147,456       50.00%
Layer 39             Conv2d          589,824    294,912       50.00%
Layer 42             Conv2d          32,768     16,384        50.00%
Layer 45             Conv2d          589,824    294,912       50.00%
Layer 48             Conv2d          589,824    294,912       50.00%
Layer 52             Conv2d          1,179,648  589,824       50.00%
Layer 55             Conv2d          2,359,296  1,179,648     50.00%
Layer 58             Conv2d          131,072    65,536        50.00%
Layer 61             Conv2d          2,359,296  1,179,648     50.00%
Layer 64             Conv2d          2,359,296  1,179,648     50.00%
--------------------------------------------------------------------------------
Overall              all             11,166,912 5,583,456     50.00%

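In this run, we indeed have a network B at \(50\%\) sparsity whose final accuracy (about \(77\%\)) exceeds the final-epoch accuracy of the dense network A (about \(60\%\)), with its last training round lasting the same \(5\) epochs.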

Lottery Ticket Hypothesis with Rewinding

In some cases, the LTH fails for deeper networks. The authors then propose a solution: rewind the weights to a later training iteration instead of their initialization value.

learn = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
learn.model.load_state_dict(initial_weights)
<All keys matched successfully>

This can be done in fasterai by passing the rewind_epoch parameter, which saves the weights at that epoch and then resets the weights to those saved values after each pruning step.
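The idea of rewinding can be sketched independently of fasterai. In the example below, `RewindBuffer` is a hypothetical helper, plain lists stand in for weight tensors, and doubling the weights stands in for one epoch of training; none of this is the library's actual code.

```python
from copy import deepcopy

class RewindBuffer:
    """Hypothetical helper: remembers the weights seen at `rewind_epoch`."""

    def __init__(self, rewind_epoch=0):
        self.rewind_epoch = rewind_epoch
        self.saved = None

    def on_epoch_start(self, epoch, weights):
        # Snapshot the weights once, when the rewind epoch begins.
        if epoch == self.rewind_epoch and self.saved is None:
            self.saved = deepcopy(weights)

    def rewind(self, mask):
        # Restore the saved values, keeping pruned positions at zero.
        return [w * m for w, m in zip(self.saved, mask)]

weights = [0.4, -0.3, 0.2, -0.1]
buffer = RewindBuffer(rewind_epoch=1)
for epoch in range(3):
    buffer.on_epoch_start(epoch, weights)
    weights = [w * 2 for w in weights]   # stand-in for one epoch of training

mask = [1, 1, 0, 0]                      # result of a pruning step
rewound = buffer.rewind(mask)
print(rewound)
```

With `rewind_epoch=0` the same helper reduces to the standard LTH reset to the initialization values.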

sp_cb = SparsifyCallback(50, 'weight', 'local', large_final, schedule, lth=True, rewind_epoch=1)
learn.fit(20, 1e-3, cbs=sp_cb)
Pruning of weight until a sparsity of [50]%
epoch train_loss valid_loss accuracy time
0 0.584980 0.575644 0.702977 00:03
1 0.542651 0.749559 0.667118 00:03
2 0.499945 0.610964 0.661705 00:03
3 0.458487 0.528082 0.746955 00:03
4 0.417654 0.488568 0.797700 00:03
5 0.500780 0.706954 0.545332 00:03
6 0.459245 0.925210 0.429635 00:03
7 0.412464 0.509823 0.731394 00:03
8 0.377349 0.603450 0.762517 00:03
9 0.346884 0.829551 0.717185 00:03
10 0.441250 0.669327 0.713802 00:03
11 0.382227 0.491025 0.779432 00:03
12 0.342169 0.627792 0.749662 00:03
13 0.309651 0.433117 0.792963 00:03
14 0.264981 0.480634 0.800406 00:03
15 0.339205 0.479244 0.776049 00:03
16 0.301984 0.743302 0.671854 00:03
17 0.265180 0.590050 0.793640 00:03
18 0.239833 0.550390 0.742896 00:04
19 0.208132 0.569378 0.775372 00:03
Sparsity at the end of epoch 0: [0.0]%
Saving Weights at epoch 1
Sparsity at the end of epoch 1: [0.0]%
Sparsity at the end of epoch 2: [0.0]%
Sparsity at the end of epoch 3: [0.0]%
Sparsity at the end of epoch 4: [0.0]%
Resetting Weights to their epoch 1 values
Sparsity at the end of epoch 5: [16.67]%
Sparsity at the end of epoch 6: [16.67]%
Sparsity at the end of epoch 7: [16.67]%
Sparsity at the end of epoch 8: [16.67]%
Sparsity at the end of epoch 9: [16.67]%
Resetting Weights to their epoch 1 values
Sparsity at the end of epoch 10: [33.33]%
Sparsity at the end of epoch 11: [33.33]%
Sparsity at the end of epoch 12: [33.33]%
Sparsity at the end of epoch 13: [33.33]%
Sparsity at the end of epoch 14: [33.33]%
Resetting Weights to their epoch 1 values
Sparsity at the end of epoch 15: [50.0]%
Sparsity at the end of epoch 16: [50.0]%
Sparsity at the end of epoch 17: [50.0]%
Sparsity at the end of epoch 18: [50.0]%
Sparsity at the end of epoch 19: [50.0]%
Final Sparsity: [50.0]%

Sparsity Report:
--------------------------------------------------------------------------------
Layer                Type            Params     Zeros      Sparsity  
--------------------------------------------------------------------------------
Layer 1              Conv2d          9,408      4,704         50.00%
Layer 7              Conv2d          36,864     18,432        50.00%
Layer 10             Conv2d          36,864     18,432        50.00%
Layer 13             Conv2d          36,864     18,432        50.00%
Layer 16             Conv2d          36,864     18,432        50.00%
Layer 20             Conv2d          73,728     36,864        50.00%
Layer 23             Conv2d          147,456    73,728        50.00%
Layer 26             Conv2d          8,192      4,096         50.00%
Layer 29             Conv2d          147,456    73,728        50.00%
Layer 32             Conv2d          147,456    73,728        50.00%
Layer 36             Conv2d          294,912    147,456       50.00%
Layer 39             Conv2d          589,824    294,912       50.00%
Layer 42             Conv2d          32,768     16,384        50.00%
Layer 45             Conv2d          589,824    294,912       50.00%
Layer 48             Conv2d          589,824    294,912       50.00%
Layer 52             Conv2d          1,179,648  589,824       50.00%
Layer 55             Conv2d          2,359,296  1,179,648     50.00%
Layer 58             Conv2d          131,072    65,536        50.00%
Layer 61             Conv2d          2,359,296  1,179,648     50.00%
Layer 64             Conv2d          2,359,296  1,179,648     50.00%
--------------------------------------------------------------------------------
Overall              all             11,166,912 5,583,456     50.00%

Super-Masks

Researchers from Uber AI investigated the LTH and found the existence of what they call “Super-Masks”, i.e. masks that, applied to an untrained neural network, allow it to reach better-than-random results.

learn = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
learn.model.load_state_dict(initial_weights)
<All keys matched successfully>

To find supermasks, the authors perform the LTH method and then apply the mask to the original, untrained network. In fasterai, you can pass the parameter reset_end=True, which resets the weights to their original value at the end of training while keeping the pruned weights (i.e. the mask) unchanged.
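Applying a mask to untrained weights can be sketched as follows. The layer names, the weight values and the helper `magnitude_mask` are made up for illustration; the mask is computed per layer from the magnitude of the trained weights, which is an assumption about what the `'local'` and `large_final` arguments select rather than a copy of fasterai's code.

```python
def magnitude_mask(layer, sparsity):
    """Per-layer mask: zero out the `sparsity` fraction of smallest-magnitude weights."""
    n_prune = int(len(layer) * sparsity)
    order = sorted(range(len(layer)), key=lambda i: abs(layer[i]))
    pruned = set(order[:n_prune])
    return [0 if i in pruned else 1 for i in range(len(layer))]

# Toy "networks": one list of weights per layer.
initial = {"conv1": [0.2, -0.1, 0.05, 0.3], "fc": [-0.4, 0.1]}
trained = {"conv1": [0.9, -0.05, 0.6, 0.01], "fc": [-0.2, 1.5]}

# The mask comes from the trained weights...
masks = {name: magnitude_mask(w, 0.5) for name, w in trained.items()}

# ...but is applied to the initial, untrained weights.
supermask_net = {
    name: [w * m for w, m in zip(initial[name], masks[name])]
    for name in initial
}
print(masks)
print(supermask_net)
```

The resulting network has never been trained with these values; only the choice of which weights survive carries information from training.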

sp_cb = SparsifyCallback(50, 'weight', 'local', large_final, schedule, lth=True, reset_end=True)
learn.fit(10, 1e-3, cbs=sp_cb)
Pruning of weight until a sparsity of [50]%
Saving Weights at epoch 0
epoch train_loss valid_loss accuracy time
0 0.596297 0.816344 0.671177 00:04
1 0.548523 0.587115 0.707713 00:04
2 0.571606 0.630654 0.640054 00:04
3 0.537812 0.528453 0.756428 00:04
4 0.513824 0.774966 0.520298 00:04
5 0.538326 0.601889 0.649526 00:04
6 0.481332 0.565981 0.725981 00:04
7 0.482080 0.843460 0.364005 00:03
8 0.442359 0.516618 0.784844 00:03
9 0.401376 0.530230 0.769283 00:04
Sparsity at the end of epoch 0: [0.0]%
Sparsity at the end of epoch 1: [0.0]%
Resetting Weights to their epoch 0 values
Sparsity at the end of epoch 2: [16.67]%
Sparsity at the end of epoch 3: [16.67]%
Sparsity at the end of epoch 4: [16.67]%
Resetting Weights to their epoch 0 values
Sparsity at the end of epoch 5: [33.33]%
Sparsity at the end of epoch 6: [33.33]%
Resetting Weights to their epoch 0 values
Sparsity at the end of epoch 7: [50.0]%
Sparsity at the end of epoch 8: [50.0]%
Sparsity at the end of epoch 9: [50.0]%
Final Sparsity: [50.0]%

Sparsity Report:
--------------------------------------------------------------------------------
Layer                Type            Params     Zeros      Sparsity  
--------------------------------------------------------------------------------
Layer 1              Conv2d          9,408      4,704         50.00%
Layer 7              Conv2d          36,864     18,432        50.00%
Layer 10             Conv2d          36,864     18,432        50.00%
Layer 13             Conv2d          36,864     18,432        50.00%
Layer 16             Conv2d          36,864     18,432        50.00%
Layer 20             Conv2d          73,728     36,864        50.00%
Layer 23             Conv2d          147,456    73,728        50.00%
Layer 26             Conv2d          8,192      4,096         50.00%
Layer 29             Conv2d          147,456    73,728        50.00%
Layer 32             Conv2d          147,456    73,728        50.00%
Layer 36             Conv2d          294,912    147,456       50.00%
Layer 39             Conv2d          589,824    294,912       50.00%
Layer 42             Conv2d          32,768     16,384        50.00%
Layer 45             Conv2d          589,824    294,912       50.00%
Layer 48             Conv2d          589,824    294,912       50.00%
Layer 52             Conv2d          1,179,648  589,824       50.00%
Layer 55             Conv2d          2,359,296  1,179,648     50.00%
Layer 58             Conv2d          131,072    65,536        50.00%
Layer 61             Conv2d          2,359,296  1,179,648     50.00%
Layer 64             Conv2d          2,359,296  1,179,648     50.00%
--------------------------------------------------------------------------------
Overall              all             11,166,912 5,583,456     50.00%
learn.model.conv1.weight
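
The masked network, with its weights reset to their initial values, can then be evaluated; learn.validate() returns the validation loss followed by the accuracy:

learn.validate()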
(#2) [0.643438994884491,0.6576454639434814]