from fasterai.core.criteria import *
from fasterai.core.schedule import *
from fasterai.regularize.all import *
from fastai.vision.all import *
Regularize Callback
Perform group regularization using the fastai Callback system
Get your data
# Fetch the Oxford-IIIT PETS dataset with fastai's `untar_data`, then
# collect every image file under its "images" sub-directory.
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")
def label_func(f):
    """Label an item by the case of the first character of its name.

    Args:
        f: presumably the file-name string that
            ``ImageDataLoaders.from_name_func`` hands to its label function
            (TODO confirm against the fastai docs).

    Returns:
        True when the first character is uppercase, otherwise False.
        An empty name raises IndexError, as indexing any empty string does.
    """
    # NOTE(review): in PETS, uppercase-initial names are presumably the cat
    # breeds, which would make this a cat-vs-dog label — verify.
    leading_char = f[0]
    return leading_char.isupper()
# Build the DataLoaders: each file is labelled by `label_func`, and every item
# goes through the `Resize(64)` item transform.
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))
Train a model without regularization as a baseline
# Baseline: a resnet18 learner tracked with the accuracy metric, trained for
# 5 epochs with the one-cycle schedule and no regularization callback.
learn = vision_learner(dls, resnet18, metrics=accuracy)
# Unfreeze so every layer is trained, not just the new head.
learn.unfreeze()
learn.fit_one_cycle(5)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.683390 | 0.504752 | 0.850474 | 00:03 |
1 | 0.398581 | 0.278983 | 0.891746 | 00:03 |
2 | 0.227765 | 0.227970 | 0.907984 | 00:03 |
3 | 0.126593 | 0.196543 | 0.924899 | 00:03 |
4 | 0.067882 | 0.171512 | 0.940460 | 00:03 |
Create the RegularizeCallback and train a fresh model with it
# Regularization callback: `squared_final` criterion, 'weight' granularity and
# a 3e-5 coefficient, varied over training by the `one_cycle` schedule.
# NOTE(review): the meanings of the positional args are inferred from the
# fasterai API names — confirm against the RegularizeCallback docs.
reg_cb = RegularizeCallback(squared_final, 'weight', 3e-5, schedule=one_cycle)

# Train a fresh learner (same setup as the baseline) so the two runs are
# comparable; the only difference is the callback passed to fit_one_cycle.
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
learn.fit_one_cycle(5, cbs=reg_cb)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.645172 | 0.562196 | 0.835589 | 00:04 |
1 | 0.436420 | 0.302934 | 0.905954 | 00:04 |
2 | 0.336652 | 0.379853 | 0.900541 | 00:04 |
3 | 0.285935 | 0.322683 | 0.930988 | 00:04 |
4 | 0.225295 | 0.317049 | 0.935724 | 00:04 |