from fasterai.core.criteria import *
from fasterai.core.schedule import *
from fasterai.regularize.all import *
from fastai.vision.all import *
Regularize Callback
Perform group regularization with the fastai Callback system.
Get your data
# Fetch the PETS dataset via fastai's `untar_data`, then collect the image
# files stored in its "images" sub-folder.
path = untar_data(URLs.PETS)
files = get_image_files(path / "images")
def label_func(f):
    """Label an item by the case of its first character.

    Returns True when ``f`` starts with an uppercase character, False
    otherwise. Raises IndexError when ``f`` is empty.
    """
    # NOTE(review): presumably ``f`` is the file-name string that
    # ``ImageDataLoaders.from_name_func`` passes in — confirm.
    leading_char = f[0]
    return leading_char.isupper()
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))Train a model without Regularization as a baseline
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
learn.fit_one_cycle(5)| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.683390 | 0.504752 | 0.850474 | 00:03 |
| 1 | 0.398581 | 0.278983 | 0.891746 | 00:03 |
| 2 | 0.227765 | 0.227970 | 0.907984 | 00:03 |
| 3 | 0.126593 | 0.196543 | 0.924899 | 00:03 |
| 4 | 0.067882 | 0.171512 | 0.940460 | 00:03 |
Create the RegularizeCallback
reg_cb = RegularizeCallback(squared_final, 'weight', 3e-5, schedule=one_cycle)learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()learn.fit_one_cycle(5, cbs=reg_cb)| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.645172 | 0.562196 | 0.835589 | 00:04 |
| 1 | 0.436420 | 0.302934 | 0.905954 | 00:04 |
| 2 | 0.336652 | 0.379853 | 0.900541 | 00:04 |
| 3 | 0.285935 | 0.322683 | 0.930988 | 00:04 |
| 4 | 0.225295 | 0.317049 | 0.935724 | 00:04 |