Knowledge Distillation Callback

How to apply knowledge distillation with fasterai

Overview

Knowledge Distillation transfers knowledge from a large, accurate “teacher” model to a smaller, faster “student” model. The student learns not just from ground truth labels, but also from the teacher’s soft predictions—capturing the teacher’s learned relationships between classes.
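
To make "soft predictions" concrete, here is a minimal plain-PyTorch sketch (not fasterai's API); the logits and the temperature are illustrative values:

import torch
import torch.nn.functional as F

# Hypothetical teacher logits for one image over three classes
teacher_logits = torch.tensor([[4.0, 1.5, 0.5]])

hard_label = teacher_logits.argmax(dim=1)            # hard target: just "class 0"
soft_preds = F.softmax(teacher_logits / 3.0, dim=1)  # temperature T=3 softens the distribution

print(hard_label)  # tensor([0])
print(soft_preds)  # ≈ tensor([[0.57, 0.25, 0.18]]) - the relative class similarities a student can learn from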

Why Use Knowledge Distillation?

Approach                   Model Size   Training Data                Accuracy
Train small model alone    Small        Labels only                  Lower
Distillation               Small        Labels + Teacher knowledge   Higher

Key Benefits

  • Smaller deployment models - Student can be much smaller than teacher
  • Better than training from scratch - Teacher provides richer supervision
  • No additional labeled data needed - Uses existing training set
  • Flexible loss functions - Soft targets, attention transfer, feature matching

In this tutorial, we’ll distill a ResNet-34 (teacher) into a ResNet-18 (student).

1. Setup and Data

from fastai.vision.all import *
from fasterai.distill.all import *  # KnowledgeDistillationCallback, SoftTarget, Attention (import path may vary across fasterai versions)

path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

# In the Oxford-IIIT Pets dataset, cat images have filenames starting with an uppercase letter
def label_func(f): return f[0].isupper()

dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))

2. Train the Teacher Model

First, we train the larger teacher model (ResNet-34) to achieve good accuracy on our dataset:

teacher = vision_learner(dls, resnet34, metrics=accuracy)
teacher.unfreeze()
teacher.fit_one_cycle(10, 1e-3)
epoch train_loss valid_loss accuracy time
0 0.694079 0.833391 0.835589 00:02
1 0.440978 15.979666 0.633965 00:02
2 0.335077 0.274627 0.884980 00:02
3 0.244045 0.368534 0.869418 00:02
4 0.175783 0.329631 0.861976 00:02
5 0.135747 0.273052 0.890392 00:02
6 0.132307 0.195349 0.926252 00:02
7 0.070904 0.154353 0.942490 00:02
8 0.041256 0.151464 0.949256 00:02
9 0.026115 0.155519 0.949256 00:02

3. Baseline: Student Without Distillation

Let’s train a ResNet-18 student model without distillation to establish a baseline:

Training from scratch with only ground truth labels:

student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
student.fit_one_cycle(10, 1e-3)
epoch train_loss valid_loss accuracy time
0 0.602561 0.585610 0.683356 00:02
1 0.576955 0.557465 0.714479 00:02
2 0.547720 0.502581 0.743572 00:02
3 0.514453 0.472438 0.771989 00:02
4 0.472709 0.419733 0.807848 00:02
5 0.426662 0.438521 0.779432 00:02
6 0.373338 0.449561 0.793640 00:02
7 0.301837 0.373818 0.833559 00:02
8 0.231140 0.327809 0.866035 00:02
9 0.176535 0.329182 0.860622 00:02

4. Student With Knowledge Distillation

Now let’s train the same architecture with help from the teacher using SoftTarget loss:

The SoftTarget loss combines two terms:

  • Classification loss - cross-entropy against the ground-truth labels
  • Distillation loss - KL divergence between the student's and the teacher's soft predictions
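
Conceptually, the combined objective looks like the following plain-PyTorch sketch (illustrative only; fasterai's SoftTarget may parameterize the temperature and weighting differently):

import torch
import torch.nn.functional as F

def soft_target_loss(student_logits, teacher_logits, targets, T=4.0, alpha=0.9):
    # Cross-entropy against the ground-truth labels
    ce = F.cross_entropy(student_logits, targets)
    # KL divergence between temperature-softened student and teacher predictions,
    # scaled by T**2 to keep its gradient magnitude comparable to the cross-entropy term
    kd = F.kl_div(F.log_softmax(student_logits / T, dim=1),
                  F.softmax(teacher_logits / T, dim=1),
                  reduction='batchmean') * T * T
    return (1 - alpha) * ce + alpha * kd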

student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
kd = KnowledgeDistillationCallback(teacher.model, SoftTarget)
student.fit_one_cycle(10, 1e-3, cbs=kd)
epoch train_loss valid_loss accuracy time
0 2.634963 2.308056 0.709743 00:02
1 2.549637 2.177948 0.707037 00:02
2 2.393359 2.461530 0.623139 00:02
3 2.162401 2.211867 0.746955 00:02
4 1.914652 1.680310 0.770636 00:02
5 1.637346 1.246407 0.838972 00:02
6 1.354778 1.168832 0.851827 00:02
7 1.039897 1.065769 0.864005 00:02
8 0.818041 0.926388 0.878890 00:02
9 0.662967 0.921144 0.879567 00:02

With teacher guidance, the student reaches roughly 88% accuracy, compared with about 86% for the baseline trained on labels alone.

5. Advanced: Attention Transfer

Beyond soft targets, fasterai supports more sophisticated distillation losses such as Attention Transfer from “Paying More Attention to Attention” (Zagoruyko & Komodakis, 2017). Here, the student learns to replicate the teacher’s attention maps at intermediate layers.

To use intermediate layer losses, specify which layers to match using their string names. Use get_model_layers to discover available layers.
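
For example (assuming get_model_layers accepts a model and returns its named layers; the exact names depend on the architecture), you could inspect both networks before choosing the pairing:

get_model_layers(resnet18(num_classes=2))  # student: 'layer1', 'layer2', 'layer3', 'layer4', ...
get_model_layers(teacher.model)            # fastai teacher: body at index 0, stages '0.4', '0.5', '0.6', '0.7', ...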

Here we match attention maps after each residual block:

student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
# Pair each student block ('layer1'-'layer4' in torchvision's ResNet-18)
# with the corresponding teacher stage ('0.4'-'0.7' inside the fastai vision_learner body)
kd = KnowledgeDistillationCallback(teacher.model, Attention,
                                   ['layer1', 'layer2', 'layer3', 'layer4'], ['0.4', '0.5', '0.6', '0.7'], weight=0.9)
student.fit_one_cycle(10, 1e-3, cbs=kd)
epoch train_loss valid_loss accuracy time
0 0.090912 0.082237 0.702977 00:02
1 0.085223 0.082258 0.669147 00:02
2 0.072878 0.064446 0.777402 00:02
3 0.064771 0.057709 0.813938 00:03
4 0.057675 0.054276 0.814614 00:03
5 0.048829 0.062366 0.777402 00:02
6 0.042154 0.044733 0.870771 00:02
7 0.032746 0.043048 0.875507 00:02
8 0.026087 0.041130 0.885656 00:02
9 0.021022 0.041199 0.884980 00:02

6. Parameter Guide

KnowledgeDistillationCallback Parameters

Parameter        Description
teacher          The trained teacher model
loss             Distillation loss function (SoftTarget, Attention, FitNet, etc.)
student_layers   (For intermediate losses) Layers in the student to extract features from
teacher_layers   (For intermediate losses) Corresponding layers in the teacher
weight           Weight of the distillation loss vs the classification loss

Available Loss Functions

Loss         Type           Description
SoftTarget   Output         Match the teacher’s softened predictions
Attention    Intermediate   Match attention maps (spatial activation patterns)
FitNet       Intermediate   Directly match feature maps (requires same dimensions)
RKD          Relational     Match distance/angle relationships between samples
PKT          Probabilistic  Match probability distributions in feature space
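
Any of these can be passed to the same callback. As a hedged sketch, the FitNet variant below reuses the layer pairing from the Attention example; FitNet additionally requires the paired feature maps to have identical shapes, which holds for the matching ResNet-18/ResNet-34 stages used here:

student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
kd = KnowledgeDistillationCallback(teacher.model, FitNet,
                                   ['layer1', 'layer2', 'layer3', 'layer4'],  # student stages
                                   ['0.4', '0.5', '0.6', '0.7'],              # matching teacher stages
                                   weight=0.9)
student.fit_one_cycle(10, 1e-3, cbs=kd)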

Summary

Concept                         Description
Knowledge Distillation          Training a small student to mimic a large teacher
KnowledgeDistillationCallback   fasterai callback that applies distillation during fastai training
SoftTarget                      Basic distillation using the teacher’s soft predictions
Attention Transfer              Advanced distillation using intermediate attention maps
Typical Benefit                 1-3% accuracy improvement over training the student alone

See Also

  • Distillation Losses - All available distillation loss functions
  • Pruner - Combine distillation with pruning for even smaller models
  • Sparsifier - Add sparsity to distilled models