Knowledge Distillation Callback
Overview
Knowledge Distillation transfers knowledge from a large, accurate “teacher” model to a smaller, faster “student” model. The student learns not just from ground truth labels, but also from the teacher’s soft predictions—capturing the teacher’s learned relationships between classes.
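In its simplest form (the SoftTarget setup used later in this tutorial), the student's training objective blends the usual cross-entropy on the hard labels with a KL divergence between temperature-softened teacher and student outputs. Below is a minimal PyTorch sketch of that idea, not fasterai's internal implementation; T and alpha are illustrative hyperparameters:
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, targets, T=3.0, alpha=0.5):
    # Hard-label term: standard cross-entropy with the ground-truth classes
    ce = F.cross_entropy(student_logits, targets)
    # Soft-label term: KL divergence between temperature-softened distributions,
    # scaled by T^2 so its gradient magnitude stays comparable as T changes
    kd = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)
    # Blend the two terms; alpha controls the influence of the teacher
    return (1 - alpha) * ce + alpha * kd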
Why Use Knowledge Distillation?
| Approach | Model Size | Training Data | Accuracy |
|---|---|---|---|
| Train small model alone | Small | Labels only | Lower |
| Distillation | Small | Labels + Teacher knowledge | Higher |
Key Benefits
- Smaller deployment models - Student can be much smaller than teacher
- Better than training from scratch - Teacher provides richer supervision
- No additional labeled data needed - Uses existing training set
- Flexible loss functions - Soft targets, attention transfer, feature matching
In this tutorial, we’ll distill a ResNet-34 (teacher) into a ResNet-18 (student).
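To see what "smaller" means here: ResNet-34 has roughly twice as many parameters as ResNet-18. You can verify this with a quick check on the torchvision constructors (exact counts depend on the classification head you attach):
from torchvision.models import resnet18, resnet34

def n_params(m):
    # Count trainable parameters
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

print(f"ResNet-34 (teacher): {n_params(resnet34(num_classes=2)):,} parameters")
print(f"ResNet-18 (student): {n_params(resnet18(num_classes=2)):,} parameters")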
1. Setup and Data
We use the Oxford-IIIT Pets dataset and label each image as cat or not (cat breeds have filenames starting with an uppercase letter):
from fastai.vision.all import *
from fasterai.distill.all import *  # fasterai's distillation module; the import path may differ across versions

path = untar_data(URLs.PETS)
files = get_image_files(path/"images")
def label_func(f): return f[0].isupper()
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))
2. Train the Teacher Model
First, we train the larger teacher model (ResNet-34) to achieve good accuracy on our dataset:
teacher = vision_learner(dls, resnet34, metrics=accuracy)
teacher.unfreeze()
teacher.fit_one_cycle(10, 1e-3)
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.694079 | 0.833391 | 0.835589 | 00:02 |
| 1 | 0.440978 | 15.979666 | 0.633965 | 00:02 |
| 2 | 0.335077 | 0.274627 | 0.884980 | 00:02 |
| 3 | 0.244045 | 0.368534 | 0.869418 | 00:02 |
| 4 | 0.175783 | 0.329631 | 0.861976 | 00:02 |
| 5 | 0.135747 | 0.273052 | 0.890392 | 00:02 |
| 6 | 0.132307 | 0.195349 | 0.926252 | 00:02 |
| 7 | 0.070904 | 0.154353 | 0.942490 | 00:02 |
| 8 | 0.041256 | 0.151464 | 0.949256 | 00:02 |
| 9 | 0.026115 | 0.155519 | 0.949256 | 00:02 |
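During distillation the teacher is only queried for predictions, so it is worth keeping the trained weights around. If you want to reuse the teacher in a later session, fastai's standard save/load works (the checkpoint name is illustrative):
teacher.save('teacher_resnet34')    # saved under path/'models'; the name is illustrative
# teacher.load('teacher_resnet34')  # reload before distilling in a later session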
3. Baseline: Student Without Distillation
To establish a baseline, let's train a ResNet-18 student from scratch with only the ground-truth labels and no help from the teacher:
student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
student.fit_one_cycle(10, 1e-3)
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.602561 | 0.585610 | 0.683356 | 00:02 |
| 1 | 0.576955 | 0.557465 | 0.714479 | 00:02 |
| 2 | 0.547720 | 0.502581 | 0.743572 | 00:02 |
| 3 | 0.514453 | 0.472438 | 0.771989 | 00:02 |
| 4 | 0.472709 | 0.419733 | 0.807848 | 00:02 |
| 5 | 0.426662 | 0.438521 | 0.779432 | 00:02 |
| 6 | 0.373338 | 0.449561 | 0.793640 | 00:02 |
| 7 | 0.301837 | 0.373818 | 0.833559 | 00:02 |
| 8 | 0.231140 | 0.327809 | 0.866035 | 00:02 |
| 9 | 0.176535 | 0.329182 | 0.860622 | 00:02 |
4. Student With Knowledge Distillation
Now let’s train the same architecture with help from the teacher using SoftTarget loss:
The SoftTarget loss combines:
- Classification loss: cross-entropy with the ground-truth labels
- Distillation loss: KL divergence between the student's and teacher's soft predictions
student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
kd = KnowledgeDistillationCallback(teacher.model, SoftTarget)
student.fit_one_cycle(10, 1e-3, cbs=kd)
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 2.634963 | 2.308056 | 0.709743 | 00:02 |
| 1 | 2.549637 | 2.177948 | 0.707037 | 00:02 |
| 2 | 2.393359 | 2.461530 | 0.623139 | 00:02 |
| 3 | 2.162401 | 2.211867 | 0.746955 | 00:02 |
| 4 | 1.914652 | 1.680310 | 0.770636 | 00:02 |
| 5 | 1.637346 | 1.246407 | 0.838972 | 00:02 |
| 6 | 1.354778 | 1.168832 | 0.851827 | 00:02 |
| 7 | 1.039897 | 1.065769 | 0.864005 | 00:02 |
| 8 | 0.818041 | 0.926388 | 0.878890 | 00:02 |
| 9 | 0.662967 | 0.921144 | 0.879567 | 00:02 |
With teacher guidance, the student reaches about 88% validation accuracy, compared with roughly 86% for the baseline trained without distillation. (The reported losses are higher here because they include the distillation term, so they are not directly comparable to the baseline's loss values.)
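The balance between the classification and distillation terms is set by the callback's weight parameter (see the parameter guide below). For example, to give the distillation term more influence; the value 0.7 is illustrative, not a tuned recommendation:
# Re-create the student and raise the distillation weight (0.7 is an illustrative value)
student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
kd = KnowledgeDistillationCallback(teacher.model, SoftTarget, weight=0.7)
student.fit_one_cycle(10, 1e-3, cbs=kd)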
5. Advanced: Attention Transfer
Beyond soft targets, fasterai supports more sophisticated distillation losses such as Attention Transfer from “Paying More Attention to Attention” (Zagoruyko & Komodakis). Here, the student learns to replicate the teacher’s attention maps at intermediate layers.
To use intermediate layer losses, specify which layers to match using their string names. Use get_model_layers to discover available layers.
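For example, to list the names you can pass for the student and the teacher (a sketch, assuming get_model_layers is imported from fasterai and returns module names in the format used in the call below):
# Inspect candidate layer names for intermediate-layer distillation
print(get_model_layers(student.model))   # plain torchvision ResNet: 'layer1', 'layer2', ...
print(get_model_layers(teacher.model))   # vision_learner body: numeric names such as '0.4', '0.5', ...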
Here we match attention maps after each residual block:
student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
kd = KnowledgeDistillationCallback(teacher.model, Attention, ['layer1', 'layer2', 'layer3', 'layer4'], ['0.4', '0.5', '0.6', '0.7'], weight=0.9)
student.fit_one_cycle(10, 1e-3, cbs=kd)
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.090912 | 0.082237 | 0.702977 | 00:02 |
| 1 | 0.085223 | 0.082258 | 0.669147 | 00:02 |
| 2 | 0.072878 | 0.064446 | 0.777402 | 00:02 |
| 3 | 0.064771 | 0.057709 | 0.813938 | 00:03 |
| 4 | 0.057675 | 0.054276 | 0.814614 | 00:03 |
| 5 | 0.048829 | 0.062366 | 0.777402 | 00:02 |
| 6 | 0.042154 | 0.044733 | 0.870771 | 00:02 |
| 7 | 0.032746 | 0.043048 | 0.875507 | 00:02 |
| 8 | 0.026087 | 0.041130 | 0.885656 | 00:02 |
| 9 | 0.021022 | 0.041199 | 0.884980 | 00:02 |
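After distillation, the teacher is no longer needed at inference time; only the compact student is deployed. It can be exported like any other fastai Learner (file names below are illustrative):
import torch

# Export the full Learner (transforms + model) for inference
student.export('student_resnet18.pkl')
# Or keep just the raw PyTorch weights for deployment outside fastai
torch.save(student.model.state_dict(), 'student_resnet18_weights.pth')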
6. Parameter Guide
KnowledgeDistillationCallback Parameters
| Parameter | Description |
|---|---|
| teacher | The trained teacher model |
| loss | Distillation loss function (SoftTarget, Attention, FitNet, etc.) |
| student_layers | (For intermediate losses) Layers in the student to extract features from |
| teacher_layers | (For intermediate losses) Corresponding layers in the teacher |
| weight | Weight of distillation loss vs classification loss |
Available Loss Functions
| Loss | Type | Description |
|---|---|---|
| SoftTarget | Output | Match teacher’s softened predictions |
| Attention | Intermediate | Match attention maps (spatial activation patterns) |
| FitNet | Intermediate | Directly match feature maps (requires same dimensions) |
| RKD | Relational | Match distance/angle relationships between samples |
| PKT | Probabilistic | Match probability distributions in feature space |
Summary
| Concept | Description |
|---|---|
| Knowledge Distillation | Training a small student to mimic a large teacher |
| KnowledgeDistillationCallback | fastai callback for distillation during training |
| SoftTarget | Basic distillation using teacher’s soft predictions |
| Attention Transfer | Advanced distillation using intermediate attention maps |
| Typical Benefit | 1-3% accuracy improvement over training the student alone |
See Also
- Distillation Losses - All available distillation loss functions
- Pruner - Combine distillation with pruning for even smaller models
- Sparsifier - Add sparsity to distilled models