Regularize Callback

Perform group regularization with the fastai callback system
from fasterai.core.criteria import *
from fasterai.core.schedule import *
from fasterai.regularize.all import *
from fastai.vision.all import *

Overview

Group Regularization is a technique that encourages structured sparsity in neural networks during training. Unlike standard L2 regularization (weight decay) which penalizes individual weights, group regularization penalizes groups of weights together—such as entire filters, kernels, or channels.

Why Use Group Regularization?

When preparing a model for structured pruning, you want entire structures (filters, channels) to become unimportant, not just individual weights. Group regularization pushes these structures toward zero during training, making subsequent pruning:

  1. More effective - Pruned structures are already near-zero, minimizing accuracy loss
  2. Cleaner - Clear separation between important and unimportant structures
  3. Hardware-friendly - Structured sparsity maps well to GPU/CPU acceleration

The RegularizeCallback

The RegularizeCallback adds a regularization term to the loss function:

\[\mathcal{L}_{total} = \mathcal{L}_{task} + \lambda \sum_{g \in \text{groups}} \|W_g\|_p\]

Where \(\lambda\) is the regularization weight and \(W_g\) are weight groups at your chosen granularity.
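The penalty term can be sketched in a few lines of plain PyTorch. The helper below is an illustration of the idea only, not fasterai's actual implementation: it sums the p-norms of weight groups at either the filter or individual-weight level, and `group_penalty`, the model, and `lam` are all names introduced here for the example.

```python
import torch
import torch.nn as nn

def group_penalty(model, granularity='filter', p=2):
    """Sum of p-norms over conv weight groups (a sketch, not fasterai's implementation)."""
    total = 0.
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            w = m.weight
            if granularity == 'filter':
                groups = w.flatten(1)               # one group per output filter
            else:                                   # 'weight': each weight is its own group
                groups = w.flatten().unsqueeze(1)
            total = total + groups.norm(p=p, dim=1).sum()
    return total

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 4, 3))
lam = 1e-3
task_loss = torch.tensor(0.5)                       # placeholder for the task loss
loss = task_loss + lam * group_penalty(model, 'filter')
```

With `granularity='filter'` the penalty is a sum of per-filter L2 norms (a group lasso), which pushes whole filters toward zero rather than scattering small weights everywhere.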

1. Setup and Data

Let’s start by loading a dataset and establishing a baseline model without regularization.

path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

def label_func(f): return f[0].isupper()

dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))

2. Baseline Training (No Regularization)

First, we train a model without any regularization to establish a baseline accuracy.

learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

learn.fit_one_cycle(5)
epoch train_loss valid_loss accuracy time
0 0.695834 0.946244 0.740866 00:03
1 0.383204 0.314028 0.847091 00:03
2 0.222566 0.240774 0.899865 00:04
3 0.123860 0.220786 0.914073 00:03
4 0.070046 0.203204 0.920162 00:03

3. Training with Group Regularization

Now let’s train with RegularizeCallback. We’ll configure it with:

  • criteria=squared_final: Uses squared weight magnitudes for regularization
  • granularity='weight': Regularizes at individual weight level (try 'filter' for structured pruning prep)
  • weight=1e-3: Regularization strength (higher = more aggressive)
  • schedule=one_cycle: Varies regularization strength during training
reg_cb = RegularizeCallback(squared_final, 'weight', 1e-3, schedule=one_cycle)
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
learn.fit_one_cycle(5, cbs=reg_cb)
epoch train_loss valid_loss accuracy time
0 0.776885 1.002960 0.771989 00:03
1 1.337281 2.333170 0.866712 00:03
2 3.507138 4.474680 0.876861 00:03
3 4.233951 4.420142 0.905277 00:03
4 4.203538 4.318278 0.926928 00:03
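After a regularized run, one quick sanity check is to rank filters by their L2 norm: structures driven near zero are the ones structured pruning can remove cheaply. The snippet below uses a toy conv layer with hypothetical weights in place of a trained model; the scaling step just simulates what regularization is expected to do.

```python
import torch
import torch.nn as nn

# Toy layer standing in for a trained model's conv (hypothetical weights).
torch.manual_seed(0)
conv = nn.Conv2d(3, 8, 3)
with torch.no_grad():
    conv.weight[:4] *= 0.01   # simulate regularization driving half the filters near zero

# Per-filter L2 norms: near-zero norms mark filters that pruning can drop with little accuracy cost.
filter_norms = conv.weight.detach().flatten(1).norm(dim=1)
n_prunable = (filter_norms < 0.1 * filter_norms.max()).sum().item()
print(f"{n_prunable}/8 filters look prunable")
```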

4. Comparing Results

After training, you should observe:

  • Similar or slightly lower accuracy (regularization adds a constraint)
  • Weights that are more concentrated around zero
  • A cleaner weight distribution for subsequent pruning

Note that the much larger losses reported in the regularized run reflect the added penalty term rather than worse predictions, as the accuracy column shows.

To visualize the effect, you can plot weight histograms:

import matplotlib.pyplot as plt

# Get all conv weights
weights = torch.cat([m.weight.data.flatten() for m in learn.model.modules() 
                     if isinstance(m, nn.Conv2d)])

plt.hist(weights.cpu().numpy(), bins=100, alpha=0.7)
plt.xlabel('Weight Value')
plt.ylabel('Count')
plt.title('Weight Distribution After Group Regularization')
plt.show()
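Beyond the histogram, a single number summarizing concentration around zero is often handy. The sketch below uses a small stand-in model (an assumption for self-containment); with a real learner you would iterate `learn.model.modules()` the same way, and compare the fraction before and after regularized training.

```python
import torch
import torch.nn as nn

# Stand-in for `learn.model`; iterate a real learner's modules the same way.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Conv2d(8, 4, 3))
weights = torch.cat([m.weight.detach().flatten()
                     for m in model.modules() if isinstance(m, nn.Conv2d)])

# Fraction of weights inside a small band around zero; regularized training should raise this.
near_zero = (weights.abs() < 1e-2).float().mean().item()
print(f"{near_zero:.1%} of conv weights lie within ±0.01 of zero")
```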

5. Parameter Guide

Choosing Granularity

Granularity   Effect                                   Best For
'weight'      Regularizes individual weights           Unstructured pruning, general sparsity
'filter'      Regularizes entire Conv2d filters        Structured pruning (recommended)
'kernel'      Regularizes 2D kernels within filters    Moderate structure
'channel'     Regularizes input channels               Channel pruning
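The granularities in the table correspond to different ways of slicing the same Conv2d weight tensor. The reshapes below sketch that idea in plain PyTorch (an illustration of the grouping, not fasterai's internals):

```python
import torch

w = torch.randn(16, 8, 3, 3)   # Conv2d weight layout: (out_ch, in_ch, kH, kW)

# How each granularity groups the same tensor:
filter_groups  = w.flatten(1)                  # 16 groups of 8*3*3=72 weights, one per output filter
kernel_groups  = w.flatten(2)                  # 16x8 kernels of 3*3=9 weights each
channel_groups = w.transpose(0, 1).flatten(1)  # 8 groups of 16*3*3=144 weights, one per input channel
```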

Choosing Regularization Weight

Weight Range   Effect
1e-6 - 1e-5    Very light regularization, minimal accuracy impact
1e-5 - 1e-4    Moderate regularization, good balance
1e-4 - 1e-3    Strong regularization, may reduce accuracy
> 1e-3         Very aggressive, use with caution

Tip: Start with 1e-5 and increase if weights don’t concentrate toward zero.
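One way to pick a starting point is to compare the size of the regularization term against a typical task loss before training. The sketch below does this for a small untrained stand-in model (not the tutorial's resnet18), using a squared-magnitude penalty; the "start light" heuristic here, i.e. choosing a lambda where the term is a small fraction of the loss, is a rule of thumb assumed for illustration, not a fasterai guideline.

```python
import torch
import torch.nn as nn

# Untrained stand-in model (an illustration, not the tutorial's resnet18).
torch.manual_seed(0)
model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.Conv2d(16, 16, 3))
penalty = sum(m.weight.detach().pow(2).sum().item()
              for m in model.modules() if isinstance(m, nn.Conv2d))

# Compare the regularization term against a typical task loss (~1.0) for each candidate weight.
for lam in (1e-6, 1e-5, 1e-4, 1e-3):
    print(f"lambda={lam:.0e}  reg term={lam * penalty:.5f}")
```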

Summary

Concept                Description
Group Regularization   Penalizes groups of weights to encourage structured sparsity
RegularizeCallback     fastai callback that adds a regularization term to the loss
Granularity            Level at which to group weights ('weight', 'filter', 'kernel', 'channel')
Schedule               Varies regularization strength during training
Typical Use            Pre-pruning preparation to make structured pruning more effective

See Also

  • Criteria - Importance measures used for regularization
  • Schedules - Control regularization strength over training
  • Pruner - Apply structured pruning after regularization
  • Sparsifier - Apply unstructured sparsification