Prune Callback

Use the pruner within the fastai Callback system

Let’s try our PruneCallback on the Pets dataset

from fastai.vision.all import *
from fasterai.prune.all import *
import torch_pruning as tp

path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

# Cat-breed filenames are capitalized, dog-breed filenames are not
def label_func(f): return f[0].isupper()

dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))
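
Optionally, we can peek at a batch to sanity-check the labels (True means cat, since cat-breed filenames are capitalized); show_batch is standard fastai:

# Display a few samples with their boolean labels
dls.show_batch(max_n=4)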

We’ll first train a vanilla ResNet18 for a single epoch to get an idea of the baseline performance

learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
learn.fit_one_cycle(1)
epoch train_loss valid_loss accuracy time
0 0.612992 0.329872 0.860622 00:02
# Measure the baseline compute (MACs) and parameter count
base_macs, base_params = tp.utils.count_ops_and_params(learn.model, torch.randn(1,3,224,224).to(default_device()))
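
For readability, we can also print the baseline cost. This formatting is just a convenience for the tutorial, not part of fasterai or torch_pruning:

# Pretty-print the baseline compute and size
print(f'Baseline: {base_macs/1e9:.2f} GMACs, {base_params/1e6:.2f}M parameters')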

Let’s now try removing some filters from our model

We’ll set the pruning ratio to 40 (i.e. remove 40% of filters), the context to global (i.e. filters can be removed from anywhere in the network), the criteria to large_final (i.e. keep the filters with the largest values), and the schedule to one_cycle (i.e. follow the One-Cycle schedule to remove filters along training).
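
To build intuition for the schedule, here is a minimal sketch of a one-cycle-style ramp. Note that sched_sketch is a hypothetical helper written for illustration, not fasterai's actual one_cycle implementation; the cosine formula only approximates the shape of the curve.

import math

# Hypothetical helper (illustration only, NOT fasterai's formula):
# map training progress in [0, 1] to the % of filters pruned so far
def sched_sketch(progress, target=40.0):
    return target * (1 - math.cos(math.pi * progress)) / 2

for p in [0.0, 0.25, 0.5, 0.75, 1.0]:
    print(f'{p:.0%} of training -> prune {sched_sketch(p):.1f}% of filters')

The ratio ramps up slowly, accelerates mid-training, and levels off near the target, which lets the network adapt gradually instead of losing 40% of its filters at once.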

The PruneCallback builds the underlying Pruner from these arguments and steps it during training:

pr_cb = PruneCallback(pruning_ratio=40, context='global', criteria=large_final, schedule=one_cycle)
learn.fit_one_cycle(10, cbs=pr_cb)
Ignoring output layer: Linear(in_features=512, out_features=2, bias=False)
Total ignored layers: 1
# Measure compute and parameter count after pruning
pruned_macs, pruned_params = tp.utils.count_ops_and_params(learn.model, torch.randn(1,3,224,224).to(default_device()))

We observe that our network has lost less than 1% of accuracy. But how many parameters have we removed, and how much compute does that save?

print(f'The pruned model has {pruned_macs/base_macs:.2f}x the compute of the original model')
The pruned model has 0.63x the compute of the original model
print(f'The pruned model has {pruned_params/base_params:.2f}x the parameters of the original model')
The pruned model has 0.18x the parameters of the original model

So at the price of a slight decrease in accuracy, we now have a model that is roughly 5x smaller and requires about 1.6x less compute.
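
Since this is structured pruning, torch_pruning physically removes the filters from the network, so the smaller architecture persists when the model is saved. As a final sketch, fastai's standard export works unchanged (the filename here is arbitrary):

# Export the pruned Learner; the saved model keeps the reduced architecture
learn.export('pruned_resnet18.pkl')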