```python
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

# In the Pets dataset, cat breeds are capitalized and dog breeds are lowercase
def label_func(f): return f[0].isupper()

dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))
```
## Prune Callback
Let’s try our `PruneCallback` on the Pets dataset. We’ll first train a vanilla ResNet18 for a single epoch to get an idea of the expected performance.
```python
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
learn.fit_one_cycle(1)
```
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.612992 | 0.329872 | 0.860622 | 00:02 |
```python
import torch_pruning as tp

base_macs, base_params = tp.utils.count_ops_and_params(learn.model, torch.randn(1,3,224,224).to(default_device()))
```
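To make these baseline numbers easier to read, we can print them in more familiar units (the formatting below is just one illustrative way to display them):

```python
# Show the baseline cost in readable units: GMACs and millions of parameters
print(f'Baseline: {base_macs/1e9:.2f} GMACs, {base_params/1e6:.2f}M parameters')
```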
Let’s now try removing some filters from our model. We’ll set the `pruning_ratio` to 40 (i.e. remove 40% of the filters), the `context` to `'global'` (i.e. filters can be removed from anywhere in the network), the `criteria` to `large_final` (i.e. keep the filters with the largest values), and the `schedule` to `one_cycle` (i.e. remove filters progressively along training, following the One-Cycle schedule).
```python
pruner = Pruner(
    learn.model,
    criteria=large_final,
    pruning_ratio=40,
    context='global',
    iterative_steps=...,
    schedule=one_cycle._scheduler,
)
```
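To build some intuition for what the `one_cycle` schedule does, here is a minimal sketch of a one-cycle-style pruning schedule. This is a hypothetical illustration, not fasterai’s actual implementation:

```python
import math

def one_cycle_ratio(step, total_steps, final_ratio=40):
    """Hypothetical one-cycle-style schedule: the pruning ratio ramps up
    smoothly from 0 to final_ratio over the course of training."""
    t = step / max(1, total_steps)
    return final_ratio * (1 - math.cos(math.pi * t)) / 2

# Halfway through training we are at half of the target ratio:
# one_cycle_ratio(50, 100) -> 20.0
```

In practice we don’t have to manage the schedule ourselves: the `PruneCallback` below updates the pruning ratio during training.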
```python
pr_cb = PruneCallback(pruning_ratio=40, context='global', criteria=large_final, schedule=one_cycle)
learn.fit_one_cycle(10, cbs=pr_cb)
```
Ignoring output layer: Linear(in_features=512, out_features=2, bias=False)
Total ignored layers: 1
```python
pruned_macs, pruned_params = tp.utils.count_ops_and_params(learn.model, torch.randn(1,3,224,224).to(default_device()))
```
We observe that our network has lost less than 1% of accuracy. But how many parameters have we removed, and how much compute does that save?
```python
print(f'The pruned model has {pruned_macs/base_macs:.2f} times the compute of the original model')
```

The pruned model has 0.63 times the compute of the original model
```python
print(f'The pruned model has {pruned_params/base_params:.2f} times the parameters of the original model')
```

The pruned model has 0.18 times the parameters of the original model
So, at the price of a slight decrease in accuracy, we now have a model that is about 5.5x smaller and requires about 1.6x less compute.
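Note that MACs are only a proxy for speed: the actual wall-clock gain depends on hardware and kernel implementations. A quick sanity check of inference latency might look like this (a rough sketch, not a rigorous benchmark):

```python
import time
import torch

def time_inference(model, n_iters=100):
    # Rough latency estimate; on GPU, add torch.cuda.synchronize() around the timed loop
    x = torch.randn(1, 3, 224, 224).to(default_device())
    model.eval()
    with torch.no_grad():
        for _ in range(10):  # warm-up iterations
            model(x)
        start = time.perf_counter()
        for _ in range(n_iters):
            model(x)
    return (time.perf_counter() - start) / n_iters

print(f'Average latency: {time_inference(learn.model) * 1000:.2f} ms per image')
```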