```python
from fastai.vision.all import *

path = untar_data(URLs.PETS)
files = get_image_files(path/"images")
def label_func(f): return f[0].isupper()
```
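As a quick sanity check of the labeling rule, recall the Oxford-IIIT Pets convention: cat image filenames start with an uppercase letter, dog filenames with a lowercase one. The filenames below are illustrative examples:

```python
# Pets convention: uppercase first letter -> cat, lowercase -> dog
def label_func(f): return f[0].isupper()

print(label_func("Abyssinian_1.jpg"))  # cat  -> True
print(label_func("beagle_12.jpg"))     # dog  -> False
```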
```python
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))
```

## Prune Callback
Let’s try our PruneCallback on the Pets dataset. We’ll first train a vanilla ResNet18 for one epoch to get a baseline for the expected performance.
```python
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
learn.fit_one_cycle(1)
```

| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.612992 | 0.329872 | 0.860622 | 00:02 |
```python
import torch_pruning as tp

base_macs, base_params = tp.utils.count_ops_and_params(learn.model, torch.randn(1,3,224,224).to(default_device()))
```

Let’s now try removing some filters from our model.
We’ll set the pruning ratio to 40 (i.e. remove 40% of filters), the context to global (i.e. filters can be removed from anywhere in the network), the criteria to large_final (i.e. keep the highest-valued filters), and the schedule to one_cycle (i.e. follow the One-Cycle schedule to remove filters over the course of training).
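For intuition, here is one way a one-cycle-style sparsity schedule could look: sparsity starts at 0% and climbs smoothly to the final pruning ratio as training progresses. This is a hypothetical sketch (a simple cosine ramp), not necessarily the exact curve fasterai implements:

```python
import math

def one_cycle_sparsity(pct, end_sparsity=40.0):
    # pct: fraction of training completed, in [0, 1].
    # Cosine ramp from 0% up to end_sparsity%: slow start, fast middle,
    # slow finish. Illustrative only; fasterai's schedule may differ.
    return end_sparsity * (1 - math.cos(math.pi * pct)) / 2

# Sparsity targets at a few points during training
print([round(one_cycle_sparsity(p), 1) for p in (0.0, 0.25, 0.5, 0.75, 1.0)])
# → [0.0, 5.9, 20.0, 34.1, 40.0]
```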
```python
pruner = Pruner(
    learn.model,
    criteria=large_final,
    pruning_ratio=40,
    context='global',
    schedule=one_cycle._scheduler,
)
```

```python
pr_cb = PruneCallback(pruning_ratio=40, context='global', criteria=large_final, schedule=one_cycle)
learn.fit_one_cycle(10, cbs=pr_cb)
```
Ignoring output layer: Linear(in_features=512, out_features=2, bias=False)
Total ignored layers: 1
```python
pruned_macs, pruned_params = tp.utils.count_ops_and_params(learn.model, torch.randn(1,3,224,224).to(default_device()))
```

We observe that our network has lost less than 1% of accuracy. But how many parameters have we removed, and how much compute does that save?
```python
print(f'The pruned model has {pruned_macs/base_macs:.2f} the compute of original model')
```

The pruned model has 0.63 the compute of original model

```python
print(f'The pruned model has {pruned_params/base_params:.2f} the parameters of original model')
```

The pruned model has 0.18 the parameters of original model
So at the price of a slight decrease in accuracy, we now have a model that is more than 5x smaller and requires about 1.6x less compute.
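Those headline factors follow directly from the two ratios printed above; inverting them gives the compression and speedup figures:

```python
macs_ratio = 0.63    # pruned MACs / baseline MACs, from the output above
params_ratio = 0.18  # pruned params / baseline params, from the output above

# Invert the ratios to express the gains as "Nx" factors
print(f"{1/params_ratio:.1f}x fewer parameters")  # → 5.6x fewer parameters
print(f"{1/macs_ratio:.1f}x less compute")        # → 1.6x less compute
```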