from fasterai.core.criteria import *
from fasterai.core.schedule import *
from fasterai.regularize.all import *
from fastai.vision.all import *
Regularize Callback
Overview
Group Regularization is a technique that encourages structured sparsity in neural networks during training. Unlike standard L2 regularization (weight decay) which penalizes individual weights, group regularization penalizes groups of weights together—such as entire filters, kernels, or channels.
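As a toy illustration of the difference (a hypothetical 2×2 weight matrix, with rows standing in for filters):
import torch

w = torch.tensor([[0.6, 0.8],     # "filter" 0
                  [0.3, 0.4]])    # "filter" 1
l2_penalty    = (w ** 2).sum()            # 1.25: penalizes each weight independently
group_penalty = w.norm(p=2, dim=1).sum()  # 1.50: one (unsquared) L2 norm per row
# The unsquared group norm acts like an L1 penalty on whole filters,
# so it tends to drive entire groups to exactly zero rather than
# merely shrinking individual weights.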
Why Use Group Regularization?
When preparing a model for structured pruning, you want entire structures (filters, channels) to become unimportant, not just individual weights. Group regularization pushes these structures toward zero during training, making subsequent pruning:
- More effective - Pruned structures are already near-zero, minimizing accuracy loss
- Cleaner - Clear separation between important and unimportant structures
- Hardware-friendly - Structured sparsity maps well to GPU/CPU acceleration
The RegularizeCallback
The RegularizeCallback adds a regularization term to the loss function:
\[\mathcal{L}_{total} = \mathcal{L}_{task} + \lambda \sum_{g \in \text{groups}} \|W_g\|_p\]
Where \(\lambda\) is the regularization weight and \(W_g\) are weight groups at your chosen granularity.
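As a minimal sketch of what this penalty computes at filter granularity with p = 2 (an illustration, not the callback's actual implementation):
import torch.nn as nn

def group_penalty(model, lmbda):
    # One group per Conv2d output filter: L2 norm over (in, kh, kw)
    reg = 0.
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            reg += m.weight.norm(p=2, dim=(1, 2, 3)).sum()
    return lmbda * reg

# total_loss = task_loss + group_penalty(model, lmbda)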
1. Setup and Data
Let’s start by loading a dataset and establishing a baseline model without regularization.
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")
def label_func(f): return f[0].isupper()
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))
2. Baseline Training (No Regularization)
First, we train a model without any regularization to establish a baseline accuracy.
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
learn.fit_one_cycle(5)
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.695834 | 0.946244 | 0.740866 | 00:03 |
| 1 | 0.383204 | 0.314028 | 0.847091 | 00:03 |
| 2 | 0.222566 | 0.240774 | 0.899865 | 00:04 |
| 3 | 0.123860 | 0.220786 | 0.914073 | 00:03 |
| 4 | 0.070046 | 0.203204 | 0.920162 | 00:03 |
3. Training with Group Regularization
Now let’s train with RegularizeCallback. We’ll configure it with:
- criteria=squared_final: Uses squared weight magnitudes for regularization
- granularity='weight': Regularizes at the individual weight level (try 'filter' for structured pruning prep)
- weight=1e-3: Regularization strength (higher = more aggressive)
- schedule=one_cycle: Varies regularization strength during training
reg_cb = RegularizeCallback(squared_final, 'weight', 1e-3, schedule=one_cycle)
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
learn.fit_one_cycle(5, cbs=reg_cb)
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.776885 | 1.002960 | 0.771989 | 00:03 |
| 1 | 1.337281 | 2.333170 | 0.866712 | 00:03 |
| 2 | 3.507138 | 4.474680 | 0.876861 | 00:03 |
| 3 | 4.233951 | 4.420142 | 0.905277 | 00:03 |
| 4 | 4.203538 | 4.318278 | 0.926928 | 00:03 |
4. Comparing Results
After training, you should observe:
- Similar or slightly lower accuracy (regularization adds a constraint)
- Weights that are more concentrated around zero
- A cleaner weight distribution for subsequent pruning
Note that the losses reported above include the regularization term, so they are not directly comparable to the baseline losses.
To visualize the effect, you can plot weight histograms:
import matplotlib.pyplot as plt
# Get all conv weights
weights = torch.cat([m.weight.data.flatten() for m in learn.model.modules()
if isinstance(m, nn.Conv2d)])
plt.hist(weights.cpu().numpy(), bins=100, alpha=0.7)
plt.xlabel('Weight Value')
plt.ylabel('Count')
plt.title('Weight Distribution After Group Regularization')
plt.show()
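To quantify the concentration, you can also measure the fraction of weights whose magnitude falls below a small threshold. This reuses the weights tensor gathered above; the threshold of 1e-3 is an arbitrary illustration, not a fasterai constant:
# Fraction of conv weights with magnitude below a small threshold
threshold = 1e-3  # illustrative value; tune for your model
near_zero = (weights.abs() < threshold).float().mean().item()
print(f"{near_zero:.1%} of conv weights are near zero")
Running the same check on the baseline model gives a point of comparison: the regularized model should show a noticeably higher near-zero fraction.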
5. Parameter Guide
Choosing Granularity
| Granularity | Effect | Best For |
|---|---|---|
| 'weight' | Regularizes individual weights | Unstructured pruning, general sparsity |
| 'filter' | Regularizes entire Conv2d filters | Structured pruning (recommended) |
| 'kernel' | Regularizes 2D kernels within filters | Moderate structure |
| 'channel' | Regularizes input channels | Channel pruning |
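To make the grouping concrete, here is a small sketch (an illustration, not fasterai's internal code) of which dimensions of a Conv2d weight tensor, shaped (out_channels, in_channels, kh, kw), each granularity aggregates over:
import torch.nn as nn

conv = nn.Conv2d(16, 8, kernel_size=3)       # weight shape: (8, 16, 3, 3)
w = conv.weight
weight_scores  = w.abs()                     # 'weight': one score per individual weight
filter_scores  = w.norm(p=2, dim=(1, 2, 3))  # 'filter': one score per output filter, shape (8,)
kernel_scores  = w.norm(p=2, dim=(2, 3))     # 'kernel': one score per 2D kernel, shape (8, 16)
channel_scores = w.norm(p=2, dim=(0, 2, 3))  # 'channel': one score per input channel, shape (16,)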
Choosing Regularization Weight
| Weight Range | Effect |
|---|---|
| 1e-6 - 1e-5 | Very light regularization, minimal accuracy impact |
| 1e-5 - 1e-4 | Moderate regularization, good balance |
| 1e-4 - 1e-3 | Strong regularization, may reduce accuracy |
| > 1e-3 | Very aggressive, use with caution |
Tip: Start with 1e-5 and increase if weights don’t concentrate toward zero.
Recommended Workflow
# 1. Train with filter-level regularization
reg_cb = RegularizeCallback(
criteria=large_final, # or squared_final
granularity='filter', # for structured pruning
weight=1e-4,
schedule=one_cycle,
verbose=True
)
learn.fit(epochs, cbs=[reg_cb])
# 2. Prune the regularized model
from fasterai.prune.all import *
pruner = Pruner(learn.model, sparsity=0.3, context='local', criteria=large_final)
pruner.prune_model()
# 3. Fine-tune
learn.fit(fine_tune_epochs)
Summary
| Concept | Description |
|---|---|
| Group Regularization | Penalizes groups of weights to encourage structured sparsity |
| RegularizeCallback | fastai callback that adds regularization term to loss |
| Granularity | Level at which to group weights ('weight', 'filter', 'kernel', 'channel') |
| Schedule | Varies regularization strength during training |
| Typical Use | Pre-pruning preparation to make structured pruning more effective |
See Also
- Criteria - Importance measures used for regularization
- Schedules - Control regularization strength over training
- Pruner - Apply structured pruning after regularization
- Sparsifier - Apply unstructured sparsification