KnowledgeDistillationCallback

How to apply knowledge distillation with fasterai

Overview

Knowledge Distillation transfers knowledge from a large, accurate “teacher” model to a smaller, faster “student” model. The student learns not just from ground truth labels, but also from the teacher’s soft predictions—capturing the teacher’s learned relationships between classes.
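
To make “soft predictions” concrete: instead of a one-hot label, the teacher provides a full probability distribution over classes, usually softened with a temperature. The sketch below uses made-up logits and a temperature of 4.0 purely for illustration:

import torch
import torch.nn.functional as F

teacher_logits = torch.tensor([4.0, 2.5, 0.5])    # hypothetical scores for 3 classes
hard_label     = torch.tensor([1.0, 0.0, 0.0])    # a one-hot label carries no class-similarity information

T = 4.0                                           # temperature > 1 softens the distribution
soft_targets = F.softmax(teacher_logits / T, dim=-1)
print(soft_targets)                               # ≈ tensor([0.48, 0.33, 0.20])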

Why Use Knowledge Distillation?

Approach                   Model Size   Training Data                Accuracy
Train small model alone    Small        Labels only                  Lower
Distillation               Small        Labels + Teacher knowledge   Higher

Key Benefits

  • Smaller deployment models - Student can be much smaller than teacher
  • Better than training from scratch - Teacher provides richer supervision
  • No additional labeled data needed - Uses existing training set
  • Flexible loss functions - Soft targets, attention transfer, feature matching

In this tutorial, we’ll distill a ResNet-34 (teacher) into a ResNet-18 (student).
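
For a sense of the size gap, a quick parameter count of the plain torchvision definitions (this check is not part of the tutorial's training code) shows the student is roughly half the size of the teacher:

from torchvision.models import resnet18, resnet34

count_params = lambda m: sum(p.numel() for p in m.parameters())
print(f"ResNet-34 (teacher): {count_params(resnet34()):,} parameters")   # about 21.8M
print(f"ResNet-18 (student): {count_params(resnet18()):,} parameters")   # about 11.7M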

1. Setup and Data

from fastai.vision.all import *
from fasterai.distill.all import *

path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

# In the Oxford-IIIT Pets dataset, cat breeds are capitalized and dog breeds are not
def label_func(f): return f[0].isupper()

dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))

2. Train the Teacher Model

First, we train the larger teacher model (ResNet-34) to achieve good accuracy on our dataset:

teacher = vision_learner(dls, resnet34, metrics=accuracy)
teacher.unfreeze()
teacher.fit_one_cycle(10, 1e-3)
epoch train_loss valid_loss accuracy time
0 0.663302 0.382650 0.881597 00:02
1 0.444977 1.731543 0.723951 00:02
2 0.456336 0.390448 0.847091 00:02
3 0.463871 0.314980 0.864005 00:02
4 0.399526 0.548000 0.845061 00:03
5 0.267582 0.222926 0.903248 00:02
6 0.177511 0.180466 0.933694 00:02
7 0.121694 0.195583 0.927605 00:02
8 0.077676 0.192459 0.936401 00:02
9 0.047532 0.180056 0.936401 00:02
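
As an optional sanity check (not part of the original recipe), you can confirm the teacher's validation metrics before distilling; fastai's Learner.validate returns the validation loss followed by the metrics:

# should report roughly the final-epoch numbers from the table above
teacher.validate()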

3. Baseline: Student Without Distillation

Let’s train a ResNet-18 student model without distillation to establish a baseline:

Training from scratch with only ground truth labels:

student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
student.fit_one_cycle(10, 1e-3)
epoch train_loss valid_loss accuracy time
0 0.611359 0.660552 0.676590 00:02
1 0.565523 0.669257 0.704330 00:02
2 0.537007 0.567621 0.728011 00:02
3 0.498747 0.541553 0.741543 00:02
4 0.449077 0.455508 0.783491 00:02
5 0.399169 0.393245 0.828823 00:02
6 0.342478 0.369859 0.834912 00:02
7 0.272756 0.334547 0.853857 00:02
8 0.187447 0.346933 0.859269 00:02
9 0.147805 0.358428 0.859946 00:02

4. Student With Knowledge Distillation

Now let’s train the same architecture with help from the teacher using SoftTarget loss:

The SoftTarget loss combines:

  • Classification loss - cross-entropy with the ground-truth labels
  • Distillation loss - KL divergence between the student's and the teacher's soft predictions
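
Conceptually, the combined objective looks something like the sketch below. This is only an illustration of the idea, not fasterai's actual SoftTarget implementation; the temperature T and mixing weight alpha are arbitrary example values:

import torch.nn.functional as F

def soft_target_loss(student_logits, teacher_logits, targets, T=4.0, alpha=0.9):
    # Distillation term: KL divergence between temperature-softened distributions,
    # scaled by T^2 to keep gradients comparable across temperatures (Hinton et al.)
    kd = F.kl_div(F.log_softmax(student_logits / T, dim=-1),
                  F.softmax(teacher_logits / T, dim=-1),
                  reduction='batchmean') * T * T
    # Classification term: standard cross-entropy with the ground-truth labels
    ce = F.cross_entropy(student_logits, targets)
    return alpha * kd + (1 - alpha) * ce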

student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
kd = KnowledgeDistillationCallback(teacher.model, SoftTarget, schedule=cos)
student.fit_one_cycle(10, 1e-3, cbs=kd)
epoch train_loss valid_loss accuracy time
0 0.622423 0.658045 0.692828 00:03
1 0.654330 1.211342 0.677267 00:02
2 0.736943 0.757770 0.736807 00:03
3 0.830559 0.949577 0.698241 00:02
4 0.882739 0.915873 0.793640 00:03
5 0.890884 0.799081 0.824763 00:02
6 0.817516 1.475584 0.737483 00:02
7 0.687356 0.730070 0.866035 00:02
8 0.523237 0.718984 0.866035 00:03
9 0.452811 0.703519 0.870771 00:03

With teacher guidance, the student reaches about 87.1% accuracy, roughly one point above the 86.0% baseline trained on labels alone.

5. Advanced: Attention Transfer

Beyond soft targets, fasterai supports more sophisticated distillation losses such as Attention Transfer from “Paying More Attention to Attention” (Zagoruyko & Komodakis). Here, the student learns to replicate the teacher’s attention maps at intermediate layers.
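
An attention map here is a spatial summary of a layer's activations, typically the channel-wise mean (or sum) of squared feature values, L2-normalized and flattened. A minimal sketch of the per-layer loss following Zagoruyko & Komodakis (not fasterai's exact code) might look like this:

import torch.nn.functional as F

def attention_map(fmap):
    # fmap: (batch, channels, H, W) -> normalized spatial map of shape (batch, H*W)
    return F.normalize(fmap.pow(2).mean(dim=1).flatten(1), dim=1)

def attention_loss(student_fmap, teacher_fmap):
    # L2 distance between the two attention maps (assumes matching spatial sizes)
    return (attention_map(student_fmap) - attention_map(teacher_fmap)).pow(2).mean()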

To use intermediate layer losses, specify which layers to match using their string names. Use get_model_layers to discover available layers.
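
For example (assuming get_model_layers simply takes a model and returns its named modules; the output shown in the comments is illustrative):

print(get_model_layers(student.model))   # e.g. 'conv1', 'layer1', 'layer1.0.conv1', ...
print(get_model_layers(teacher.model))   # layer names of the wrapped teacher body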

Here we match attention maps after each residual block. The plain ResNet-18 student exposes these blocks as layer1 through layer4, while the vision_learner-wrapped ResNet-34 teacher exposes them as '0.4' through '0.7' inside its body:

student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
kd = KnowledgeDistillationCallback(teacher.model, Attention, ['layer1', 'layer2', 'layer3', 'layer4'], ['0.4', '0.5', '0.6', '0.7'], weight=0.9)
student.fit_one_cycle(10, 1e-3, cbs=kd)
epoch train_loss valid_loss accuracy time
0 0.092506 0.091555 0.678620 00:03
1 0.083053 0.084819 0.648173 00:03
2 0.071733 0.073612 0.705007 00:02
3 0.062212 0.059138 0.815291 00:03
4 0.055396 0.053225 0.827470 00:03
5 0.047694 0.052672 0.821380 00:03
6 0.041354 0.048255 0.860622 00:03
7 0.031322 0.042128 0.874831 00:03
8 0.024217 0.042546 0.879567 00:03
9 0.019581 0.042967 0.886333 00:03

6. Parameter Guide

KnowledgeDistillationCallback Parameters

Parameter        Description
teacher          The trained teacher model
loss             Distillation loss function (SoftTarget, Attention, FitNet, etc.)
student_layers   (For intermediate losses) Layers in student to extract features from
teacher_layers   (For intermediate losses) Corresponding layers in teacher
weight           Weight of distillation loss vs classification loss
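
Putting these together: for an output-level loss only the first two arguments are needed, and weight balances the two loss terms (the value below is illustrative, not a recommended setting):

kd = KnowledgeDistillationCallback(teacher.model, SoftTarget, weight=0.9)
student.fit_one_cycle(10, 1e-3, cbs=kd)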

Available Loss Functions

Loss         Type            Description
SoftTarget   Output          Match teacher’s softened predictions
Attention    Intermediate    Match attention maps (spatial activation patterns)
FitNet       Intermediate    Directly match feature maps (requires same dimensions)
RKD          Relational      Match distance/angle relationships between samples
PKT          Probabilistic   Match probability distributions in feature space

Summary

Concept                         Description
Knowledge Distillation          Training a small student to mimic a large teacher
KnowledgeDistillationCallback   fastai callback for distillation during training
SoftTarget                      Basic distillation using teacher’s soft predictions
Attention Transfer              Advanced distillation using intermediate attention maps
Typical Benefit                 1-3% accuracy improvement over training student alone

See Also

  • Distillation Losses - All available distillation loss functions
  • Pruner - Combine distillation with pruning for even smaller models
  • Sparsifier - Add sparsity to distilled models