from fastai.vision.all import *
from fasterai.sparse.all import *
Lottery Ticket Hypothesis
The Lottery Ticket Hypothesis
The Lottery Ticket Hypothesis is a really intriguing discovery made in 2019 by Frankle & Carbin. It states that:
A randomly-initialized, dense neural network contains a subnetwork that is initialised such that — when trained in isolation — it can match the test accuracy of the original network after training for at most the same number of iterations.
Meaning that, once we find that subnetwork, every other parameter in the network becomes useless.
The way the authors propose to find such a subnetwork is as follows (a minimal code sketch follows the list):
- Initialize the neural network
- Train it to convergence
- Prune the smallest magnitude weights by creating a mask \(m\)
- Reinitialize the remaining weights to their original values, i.e. their values at iteration \(0\).
- Repeat from step 2 until reaching the desired level of sparsity.
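To make this concrete, here is a minimal PyTorch-style sketch of that loop. It is illustrative only, not the fasterai implementation; train_fn is assumed to be a user-supplied routine that trains the model while keeping the masked weights at zero.
import copy
import torch

def find_lottery_ticket(model, train_fn, n_rounds=3, prune_frac=0.2):
    init_state = copy.deepcopy(model.state_dict())             # 1. save the random initialization
    masks = {n: torch.ones_like(p) for n, p in model.named_parameters() if p.dim() > 1}
    for _ in range(n_rounds):
        train_fn(model, masks)                                 # 2. train, keeping masked weights at zero
        for name, param in model.named_parameters():
            if name not in masks:
                continue
            remaining = param.detach().abs()[masks[name] == 1] # magnitudes of still-active weights
            threshold = torch.quantile(remaining, prune_frac)  # 3. prune the smallest-magnitude fraction
            masks[name] = torch.where(param.detach().abs() <= threshold,
                                      torch.zeros_like(masks[name]), masks[name])
        model.load_state_dict(init_state)                      # 4. rewind surviving weights to iteration 0
    return masks                                               # 5. repeat until the target sparsity is reached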
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")
def label_func(f): return f[0].isupper()
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64), device=device)
What we are trying to show is that, inside a neural network A, there exists a subnetwork B able to reach an accuracy \(a_B > a_A\) in a training time \(t_B < t_A\).
Let’s get the baseline for network A:
learn = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
Let’s save the original weights:
initial_weights = deepcopy(learn.model.state_dict())
learn.fit(5, 1e-3)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.581151 | 1.603907 | 0.666441 | 00:03 |
1 | 0.547510 | 0.717316 | 0.691475 | 00:03 |
2 | 0.517336 | 0.621597 | 0.628552 | 00:03 |
3 | 0.477595 | 1.084812 | 0.438430 | 00:03 |
4 | 0.446734 | 0.736970 | 0.602842 | 00:03 |
We now have our baseline accuracy \(a_A\) of roughly \(60\%\) and our training time \(t_A\) of \(5\) epochs.
To find the lottery ticket, we will perform iterative pruning, but at each pruning step we will re-initialize the remaining weights to their original values (i.e. their values before training).
We will restart from the same initialization to make sure we don't simply get lucky.
learn = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
learn.model.load_state_dict(initial_weights)
<All keys matched successfully>
We can pass the parameter lth=True to make the weights of the network reset to their original values after each pruning step, i.e. step 4) of the LTH. To empirically validate the LTH, we need to retrain the found “lottery ticket” after the pruning phase. Lottery tickets are usually found following an iterative pruning schedule. Here we set the schedule's start_pct parameter to \(0.25\), so pruning only begins after the first \(25\%\) of training, i.e. after \(5\) of the \(20\) epochs.
schedule = Schedule(sched_iterative, start_pct=0.25)
sp_cb = SparsifyCallback(50, 'weight', 'local', large_final, schedule, lth=True)
As our iterative schedule makes \(3\) pruning steps by default, we have to train our network for the \(5\) warm-up epochs plus \(3 \times t_B\), i.e. \(5 + 3 \times 5 = 20\) epochs, in order to get our lottery ticket. After each pruning step, the remaining weights will be reinitialized to their original values.
learn.fit(20, 1e-3, cbs=sp_cb)
Pruning of weight until a sparsity of [50]%
Saving Weights at epoch 0
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.587076 | 0.864903 | 0.665765 | 00:03 |
1 | 0.545250 | 0.605309 | 0.709743 | 00:03 |
2 | 0.507854 | 0.609370 | 0.651556 | 00:03 |
3 | 0.472215 | 0.516507 | 0.736807 | 00:03 |
4 | 0.438759 | 0.501989 | 0.757104 | 00:03 |
5 | 0.543971 | 0.591650 | 0.699594 | 00:03 |
6 | 0.505276 | 0.576168 | 0.680650 | 00:03 |
7 | 0.459986 | 0.495335 | 0.753045 | 00:03 |
8 | 0.424159 | 0.707351 | 0.736807 | 00:03 |
9 | 0.398294 | 0.436202 | 0.817997 | 00:03 |
10 | 0.509464 | 0.726151 | 0.547361 | 00:03 |
11 | 0.437742 | 0.879882 | 0.706360 | 00:03 |
12 | 0.388987 | 0.436665 | 0.794993 | 00:03 |
13 | 0.343624 | 0.385158 | 0.836942 | 00:03 |
14 | 0.324372 | 0.526213 | 0.796346 | 00:03 |
15 | 0.392040 | 0.495435 | 0.766576 | 00:03 |
16 | 0.358236 | 0.450453 | 0.775372 | 00:03 |
17 | 0.320370 | 0.423908 | 0.795670 | 00:03 |
18 | 0.277983 | 0.469020 | 0.774019 | 00:03 |
19 | 0.262043 | 0.533755 | 0.774019 | 00:03 |
Sparsity at the end of epoch 0: [0.0]%
Sparsity at the end of epoch 1: [0.0]%
Sparsity at the end of epoch 2: [0.0]%
Sparsity at the end of epoch 3: [0.0]%
Sparsity at the end of epoch 4: [0.0]%
Resetting Weights to their epoch 0 values
Sparsity at the end of epoch 5: [16.67]%
Sparsity at the end of epoch 6: [16.67]%
Sparsity at the end of epoch 7: [16.67]%
Sparsity at the end of epoch 8: [16.67]%
Sparsity at the end of epoch 9: [16.67]%
Resetting Weights to their epoch 0 values
Sparsity at the end of epoch 10: [33.33]%
Sparsity at the end of epoch 11: [33.33]%
Sparsity at the end of epoch 12: [33.33]%
Sparsity at the end of epoch 13: [33.33]%
Sparsity at the end of epoch 14: [33.33]%
Resetting Weights to their epoch 0 values
Sparsity at the end of epoch 15: [50.0]%
Sparsity at the end of epoch 16: [50.0]%
Sparsity at the end of epoch 17: [50.0]%
Sparsity at the end of epoch 18: [50.0]%
Sparsity at the end of epoch 19: [50.0]%
Final Sparsity: [50.0]%
Sparsity Report:
--------------------------------------------------------------------------------
Layer Type Params Zeros Sparsity
--------------------------------------------------------------------------------
Layer 1 Conv2d 9,408 4,704 50.00%
Layer 7 Conv2d 36,864 18,432 50.00%
Layer 10 Conv2d 36,864 18,432 50.00%
Layer 13 Conv2d 36,864 18,432 50.00%
Layer 16 Conv2d 36,864 18,432 50.00%
Layer 20 Conv2d 73,728 36,864 50.00%
Layer 23 Conv2d 147,456 73,728 50.00%
Layer 26 Conv2d 8,192 4,096 50.00%
Layer 29 Conv2d 147,456 73,728 50.00%
Layer 32 Conv2d 147,456 73,728 50.00%
Layer 36 Conv2d 294,912 147,456 50.00%
Layer 39 Conv2d 589,824 294,912 50.00%
Layer 42 Conv2d 32,768 16,384 50.00%
Layer 45 Conv2d 589,824 294,912 50.00%
Layer 48 Conv2d 589,824 294,912 50.00%
Layer 52 Conv2d 1,179,648 589,824 50.00%
Layer 55 Conv2d 2,359,296 1,179,648 50.00%
Layer 58 Conv2d 131,072 65,536 50.00%
Layer 61 Conv2d 2,359,296 1,179,648 50.00%
Layer 64 Conv2d 2,359,296 1,179,648 50.00%
--------------------------------------------------------------------------------
Overall all 11,166,912 5,583,456 50.00%
We indeed end up with a network B whose accuracy \(a_B > a_A\), obtained in the same training time.
Lottery Ticket Hypothesis with Rewinding
In some cases, the LTH fails for deeper networks. The authors then propose a solution: rewinding the weights to a later training iteration instead of all the way back to their initialization values.
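Conceptually, rewinding only changes the reset step: instead of restoring the weights saved at initialization, we restore a checkpoint taken after a few epochs of training. A minimal sketch under the same assumptions as the earlier loop, with hypothetical train_one_epoch and prune_step routines supplied by the caller (again, not the fasterai internals):
import copy

def lth_with_rewind(model, train_one_epoch, prune_step, masks,
                    n_epochs, rewind_epoch, prune_epochs):
    rewind_state = None                                       # assumes rewind_epoch precedes the first pruning step
    for epoch in range(n_epochs):
        train_one_epoch(model, masks)                         # masked weights stay at zero
        if epoch == rewind_epoch:
            rewind_state = copy.deepcopy(model.state_dict())  # checkpoint used as the rewind target
        if epoch in prune_epochs:
            prune_step(model, masks)                          # grow the masks by magnitude pruning
            model.load_state_dict(rewind_state)               # rewind to epoch rewind_epoch, not epoch 0
    return masks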
learn = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
learn.model.load_state_dict(initial_weights)
<All keys matched successfully>
This can be done in fasterai by passing the rewind_epoch parameter, which saves the weights at that epoch and then resets the weights to those values at each pruning step.
sp_cb = SparsifyCallback(50, 'weight', 'local', large_final, schedule, lth=True, rewind_epoch=1)
learn.fit(20, 1e-3, cbs=sp_cb)
Pruning of weight until a sparsity of [50]%
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.584980 | 0.575644 | 0.702977 | 00:03 |
1 | 0.542651 | 0.749559 | 0.667118 | 00:03 |
2 | 0.499945 | 0.610964 | 0.661705 | 00:03 |
3 | 0.458487 | 0.528082 | 0.746955 | 00:03 |
4 | 0.417654 | 0.488568 | 0.797700 | 00:03 |
5 | 0.500780 | 0.706954 | 0.545332 | 00:03 |
6 | 0.459245 | 0.925210 | 0.429635 | 00:03 |
7 | 0.412464 | 0.509823 | 0.731394 | 00:03 |
8 | 0.377349 | 0.603450 | 0.762517 | 00:03 |
9 | 0.346884 | 0.829551 | 0.717185 | 00:03 |
10 | 0.441250 | 0.669327 | 0.713802 | 00:03 |
11 | 0.382227 | 0.491025 | 0.779432 | 00:03 |
12 | 0.342169 | 0.627792 | 0.749662 | 00:03 |
13 | 0.309651 | 0.433117 | 0.792963 | 00:03 |
14 | 0.264981 | 0.480634 | 0.800406 | 00:03 |
15 | 0.339205 | 0.479244 | 0.776049 | 00:03 |
16 | 0.301984 | 0.743302 | 0.671854 | 00:03 |
17 | 0.265180 | 0.590050 | 0.793640 | 00:03 |
18 | 0.239833 | 0.550390 | 0.742896 | 00:04 |
19 | 0.208132 | 0.569378 | 0.775372 | 00:03 |
Sparsity at the end of epoch 0: [0.0]%
Saving Weights at epoch 1
Sparsity at the end of epoch 1: [0.0]%
Sparsity at the end of epoch 2: [0.0]%
Sparsity at the end of epoch 3: [0.0]%
Sparsity at the end of epoch 4: [0.0]%
Resetting Weights to their epoch 1 values
Sparsity at the end of epoch 5: [16.67]%
Sparsity at the end of epoch 6: [16.67]%
Sparsity at the end of epoch 7: [16.67]%
Sparsity at the end of epoch 8: [16.67]%
Sparsity at the end of epoch 9: [16.67]%
Resetting Weights to their epoch 1 values
Sparsity at the end of epoch 10: [33.33]%
Sparsity at the end of epoch 11: [33.33]%
Sparsity at the end of epoch 12: [33.33]%
Sparsity at the end of epoch 13: [33.33]%
Sparsity at the end of epoch 14: [33.33]%
Resetting Weights to their epoch 1 values
Sparsity at the end of epoch 15: [50.0]%
Sparsity at the end of epoch 16: [50.0]%
Sparsity at the end of epoch 17: [50.0]%
Sparsity at the end of epoch 18: [50.0]%
Sparsity at the end of epoch 19: [50.0]%
Final Sparsity: [50.0]%
Sparsity Report:
--------------------------------------------------------------------------------
Layer Type Params Zeros Sparsity
--------------------------------------------------------------------------------
Layer 1 Conv2d 9,408 4,704 50.00%
Layer 7 Conv2d 36,864 18,432 50.00%
Layer 10 Conv2d 36,864 18,432 50.00%
Layer 13 Conv2d 36,864 18,432 50.00%
Layer 16 Conv2d 36,864 18,432 50.00%
Layer 20 Conv2d 73,728 36,864 50.00%
Layer 23 Conv2d 147,456 73,728 50.00%
Layer 26 Conv2d 8,192 4,096 50.00%
Layer 29 Conv2d 147,456 73,728 50.00%
Layer 32 Conv2d 147,456 73,728 50.00%
Layer 36 Conv2d 294,912 147,456 50.00%
Layer 39 Conv2d 589,824 294,912 50.00%
Layer 42 Conv2d 32,768 16,384 50.00%
Layer 45 Conv2d 589,824 294,912 50.00%
Layer 48 Conv2d 589,824 294,912 50.00%
Layer 52 Conv2d 1,179,648 589,824 50.00%
Layer 55 Conv2d 2,359,296 1,179,648 50.00%
Layer 58 Conv2d 131,072 65,536 50.00%
Layer 61 Conv2d 2,359,296 1,179,648 50.00%
Layer 64 Conv2d 2,359,296 1,179,648 50.00%
--------------------------------------------------------------------------------
Overall all 11,166,912 5,583,456 50.00%
Super-Masks
Researchers from Uber AI investigated the LTH and found the existence of what they call “Super-Masks”, i.e. masks that, when applied to an untrained neural network, already reach better-than-random accuracy.
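In other words, the mask itself carries useful information: simply zeroing out the pruned weights of an untrained network is enough to beat chance. A minimal sketch of applying such a mask, reusing the masks dictionary convention from the sketches above (illustrative, not the fasterai API):
import torch

@torch.no_grad()
def apply_supermask(model, masks):
    # zero out the pruned weights of an untrained model; all other weights keep their initial values
    for name, param in model.named_parameters():
        if name in masks:
            param.mul_(masks[name])
    return model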
learn = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
learn.model.load_state_dict(initial_weights)
<All keys matched successfully>
To find Super-Masks, the authors perform the LTH method and then apply the resulting mask to the original, untrained network. In fasterai, you can pass the parameter reset_end=True, which resets the weights to their original values at the end of training while keeping the pruned weights (i.e. the mask) unchanged.
sp_cb = SparsifyCallback(50, 'weight', 'local', large_final, schedule, lth=True, reset_end=True)
learn.fit(10, 1e-3, cbs=sp_cb)
Pruning of weight until a sparsity of [50]%
Saving Weights at epoch 0
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.596297 | 0.816344 | 0.671177 | 00:04 |
1 | 0.548523 | 0.587115 | 0.707713 | 00:04 |
2 | 0.571606 | 0.630654 | 0.640054 | 00:04 |
3 | 0.537812 | 0.528453 | 0.756428 | 00:04 |
4 | 0.513824 | 0.774966 | 0.520298 | 00:04 |
5 | 0.538326 | 0.601889 | 0.649526 | 00:04 |
6 | 0.481332 | 0.565981 | 0.725981 | 00:04 |
7 | 0.482080 | 0.843460 | 0.364005 | 00:03 |
8 | 0.442359 | 0.516618 | 0.784844 | 00:03 |
9 | 0.401376 | 0.530230 | 0.769283 | 00:04 |
Sparsity at the end of epoch 0: [0.0]%
Sparsity at the end of epoch 1: [0.0]%
Resetting Weights to their epoch 0 values
Sparsity at the end of epoch 2: [16.67]%
Sparsity at the end of epoch 3: [16.67]%
Sparsity at the end of epoch 4: [16.67]%
Resetting Weights to their epoch 0 values
Sparsity at the end of epoch 5: [33.33]%
Sparsity at the end of epoch 6: [33.33]%
Resetting Weights to their epoch 0 values
Sparsity at the end of epoch 7: [50.0]%
Sparsity at the end of epoch 8: [50.0]%
Sparsity at the end of epoch 9: [50.0]%
Final Sparsity: [50.0]%
Sparsity Report:
--------------------------------------------------------------------------------
Layer Type Params Zeros Sparsity
--------------------------------------------------------------------------------
Layer 1 Conv2d 9,408 4,704 50.00%
Layer 7 Conv2d 36,864 18,432 50.00%
Layer 10 Conv2d 36,864 18,432 50.00%
Layer 13 Conv2d 36,864 18,432 50.00%
Layer 16 Conv2d 36,864 18,432 50.00%
Layer 20 Conv2d 73,728 36,864 50.00%
Layer 23 Conv2d 147,456 73,728 50.00%
Layer 26 Conv2d 8,192 4,096 50.00%
Layer 29 Conv2d 147,456 73,728 50.00%
Layer 32 Conv2d 147,456 73,728 50.00%
Layer 36 Conv2d 294,912 147,456 50.00%
Layer 39 Conv2d 589,824 294,912 50.00%
Layer 42 Conv2d 32,768 16,384 50.00%
Layer 45 Conv2d 589,824 294,912 50.00%
Layer 48 Conv2d 589,824 294,912 50.00%
Layer 52 Conv2d 1,179,648 589,824 50.00%
Layer 55 Conv2d 2,359,296 1,179,648 50.00%
Layer 58 Conv2d 131,072 65,536 50.00%
Layer 61 Conv2d 2,359,296 1,179,648 50.00%
Layer 64 Conv2d 2,359,296 1,179,648 50.00%
--------------------------------------------------------------------------------
Overall all 11,166,912 5,583,456 50.00%
We can inspect the weights of the first convolution layer to check that the pruned weights are zero while the remaining ones are back at their initial values:
learn.model.conv1.weight
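To quantify this, we can for instance compute the fraction of exactly-zero weights in that layer, which should sit close to the \(50\%\) reported above:
(learn.model.conv1.weight == 0).float().mean()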
Let’s now evaluate this untrained but masked network on the validation set:
learn.validate()
(#2) [0.643438994884491,0.6576454639434814]