Knowledge Distillation Callback

How to apply knowledge distillation with fasterai

Overview

Knowledge Distillation transfers knowledge from a large, accurate “teacher” model to a smaller, faster “student” model. The student learns not just from ground truth labels, but also from the teacher’s soft predictions—capturing the teacher’s learned relationships between classes.
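
For intuition, here is a tiny, self-contained illustration of what “soft predictions” means (the logits and the temperature of 3.0 are made-up values for the example): raising the softmax temperature reveals how the teacher ranks the wrong classes relative to each other, which is extra signal a hard label does not carry.

import torch
import torch.nn.functional as F

teacher_logits = torch.tensor([[6.0, 2.5, 0.5]])    # scores for 3 classes (made-up values)

hard = F.softmax(teacher_logits, dim=-1)             # ~[0.97, 0.03, 0.00]: almost a one-hot label
soft = F.softmax(teacher_logits / 3.0, dim=-1)       # ~[0.68, 0.21, 0.11]: class similarities stay visible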

Why Use Knowledge Distillation?

  • Train a small model alone - small model, labels only, lower accuracy
  • Distillation - small model, labels + teacher knowledge, higher accuracy

Key Benefits

  • Smaller deployment models - Student can be much smaller than teacher
  • Better than training from scratch - Teacher provides richer supervision
  • No additional labeled data needed - Uses existing training set
  • Flexible loss functions - Soft targets, attention transfer, feature matching

In this tutorial, we’ll distill a ResNet-34 (teacher) into a ResNet-18 (student).

1. Setup and Data

from fastai.vision.all import *
# fasterai's distillation tools; the exact import path may differ across fasterai versions
from fasterai.distill.all import *

path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

# In the Oxford-IIIT Pets dataset, cat breed filenames are capitalized,
# so the first letter separates cats from dogs
def label_func(f): return f[0].isupper()

dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))

2. Train the Teacher Model

First, we train the larger teacher model (ResNet-34) to achieve good accuracy on our dataset:

teacher = vision_learner(dls, resnet34, metrics=accuracy)
teacher.unfreeze()
teacher.fit_one_cycle(10, 1e-3)
epoch train_loss valid_loss accuracy time
0 0.693071 0.702811 0.826793 00:04
1 0.431739 1.089537 0.816644 00:03
2 0.466526 0.410398 0.823410 00:03
3 0.517755 0.434329 0.790934 00:04
4 0.573171 0.647482 0.763193 00:04
5 0.429925 0.355466 0.832882 00:03
6 0.304983 0.311687 0.857916 00:03
7 0.236466 0.258148 0.887686 00:03
8 0.176511 0.247277 0.891069 00:03
9 0.139511 0.241842 0.903248 00:04
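
As a sanity check before distilling, fastai's validate can be used to confirm the teacher's final numbers (it returns the validation loss followed by the metrics):

val_loss, val_acc = teacher.validate()
print(val_loss, val_acc)   # should be close to the last row above (~0.24 loss, ~0.90 accuracy)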

3. Baseline: Student Without Distillation

Let’s train a ResNet-18 student model without distillation to establish a baseline:

Training from scratch with only ground truth labels:

student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
student.fit_one_cycle(10, 1e-3)
epoch train_loss valid_loss accuracy time
0 0.631313 0.593756 0.686062 00:03
1 0.564619 0.690962 0.677943 00:03
2 0.540129 0.512200 0.731394 00:03
3 0.500723 0.532254 0.736807 00:03
4 0.463844 0.638763 0.728011 00:03
5 0.415164 0.430835 0.806495 00:03
6 0.358220 0.398086 0.818674 00:03
7 0.282297 0.369941 0.841678 00:03
8 0.211329 0.379604 0.838295 00:03
9 0.165122 0.372304 0.850474 00:02

4. Student With Knowledge Distillation

Now let’s train the same architecture with help from the teacher using SoftTarget loss:

The SoftTarget loss combines:

  • Classification loss - Cross-Entropy with the ground-truth labels
  • Distillation loss - KL divergence between the student's and teacher's soft predictions
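
Conceptually, the combined objective looks roughly like the following sketch (the temperature T and weight w are assumed values, and this is not fasterai's exact implementation):

import torch.nn.functional as F

def soft_target_loss(student_logits, teacher_logits, targets, T=3.0, w=0.5):
    ce = F.cross_entropy(student_logits, targets)             # ground-truth term
    kd = F.kl_div(F.log_softmax(student_logits / T, dim=-1),
                  F.softmax(teacher_logits / T, dim=-1),
                  reduction="batchmean") * T * T              # soft-target term, rescaled by T^2
    return (1 - w) * ce + w * kd

In practice you only pass SoftTarget to the callback; fasterai combines the two terms during training: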

student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
kd = KnowledgeDistillationCallback(teacher.model, SoftTarget, schedule=cos)
student.fit_one_cycle(10, 1e-3, cbs=kd)
epoch train_loss valid_loss accuracy time
0 0.607547 0.760900 0.557510 00:04
1 0.623009 0.695733 0.633288 00:04
2 0.638243 0.723522 0.673207 00:04
3 0.641051 0.616777 0.774696 00:03
4 0.635822 0.602102 0.800406 00:04
5 0.570080 0.609991 0.820027 00:04
6 0.478322 0.569599 0.825440 00:04
7 0.388707 0.471958 0.854533 00:04
8 0.317506 0.449543 0.859946 00:04
9 0.278542 0.444559 0.858593 00:04

With teacher guidance, the student reaches about 86% validation accuracy, compared with roughly 85% for the same architecture trained without distillation.

5. Advanced: Attention Transfer

Beyond soft targets, fasterai supports more sophisticated distillation losses such as Attention Transfer from “Paying More Attention to Attention” (Zagoruyko & Komodakis). Here, the student learns to replicate the teacher’s attention maps at intermediate layers.
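
The core idea can be sketched in a few lines (an illustration of the technique, not fasterai's Attention loss): collapse each feature map over its channels into a single spatial map, normalize it, and penalize the difference between the student's and the teacher's maps.

import torch
import torch.nn.functional as F

def attention_map(fmap):                                  # fmap: (batch, channels, H, W)
    amap = fmap.pow(2).mean(dim=1)                        # channel-wise energy -> (batch, H, W)
    return F.normalize(amap.flatten(1), dim=1)            # L2-normalize each sample's map

def attention_loss(student_fmap, teacher_fmap):
    # the two maps must share the same spatial size, hence the resolution matching below
    return (attention_map(student_fmap) - attention_map(teacher_fmap)).pow(2).mean()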

Use match_feature_layers to automatically find compatible layer pairs between student and teacher — no manual naming needed:

The function matches layers by spatial resolution — it runs one forward pass through each model, groups layers by output (H, W), and picks the best match at each resolution:

student_model = resnet18(num_classes=2)

# Auto-match layers by spatial resolution
pairs = match_feature_layers(student_model, teacher.model, torch.randn(1, 3, 64, 64))
print(pairs)
# {'student': ['conv1', 'layer1', 'layer2', 'layer3', 'layer4'],
#  'teacher': ['0.0', '0.4', '0.5', '0.6', '0']}
# The teacher's layer names are numeric because vision_learner wraps the
# ResNet-34 body and head in a Sequential
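
The matching idea itself can be approximated with forward hooks; this is a rough sketch (the output_sizes helper below is hypothetical, not fasterai's code): record each module's output spatial size during one forward pass, then pair student and teacher layers that produce the same (H, W).

import torch

def output_sizes(model, x):
    """Map each module name to the spatial (H, W) of its output for one forward pass."""
    sizes, hooks = {}, []

    def make_hook(name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor) and output.dim() == 4:
                sizes[name] = tuple(output.shape[-2:])
        return hook

    for name, module in model.named_modules():
        if name:                                           # skip the root module itself
            hooks.append(module.register_forward_hook(make_hook(name)))

    was_training = model.training
    model.eval()                                           # avoid batch-norm issues with a single sample
    with torch.no_grad():
        model(x)
    model.train(was_training)

    for h in hooks:
        h.remove()
    return sizes                                           # e.g. {'layer1': (16, 16), 'layer2': (8, 8), ...}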

student = Learner(dls, student_model, metrics=accuracy)
kd = KnowledgeDistillationCallback(
    teacher.model, Attention,
    pairs['student'], pairs['teacher'],
    weight=0.9
)
student.fit_one_cycle(10, 1e-3, cbs=kd)
epoch train_loss valid_loss accuracy time
0 0.094315 0.090826 0.682679 00:03
1 0.082324 0.135551 0.663058 00:04
2 0.070702 0.098358 0.700947 00:04
3 0.061594 0.065866 0.762517 00:04
4 0.054387 0.068833 0.782138 00:04
5 0.046002 0.084695 0.777402 00:04
6 0.039880 0.057967 0.815968 00:04
7 0.032698 0.046938 0.857916 00:04
8 0.025579 0.046442 0.861976 00:04
9 0.022248 0.046658 0.865359 00:04

6. Parameter Guide

KnowledgeDistillationCallback Parameters

  • teacher - The trained teacher model
  • loss - Distillation loss function (SoftTarget, Attention, FitNet, etc.)
  • activations_student - (For intermediate losses) Layers in the student to extract features from
  • activations_teacher - (For intermediate losses) Corresponding layers in the teacher
  • weight - Weight of the distillation loss relative to the classification loss
  • schedule - Optional schedule for the weight progression (e.g. cos)
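
For instance, an output-level setup that down-weights the distillation term and schedules it over training might look like this, reusing the student and teacher from earlier (the keyword names follow the list above; weight=0.5 is just an illustrative choice):

kd = KnowledgeDistillationCallback(
    teacher.model,          # trained teacher
    SoftTarget,             # output-level distillation loss
    weight=0.5,             # balance between distillation and classification terms
    schedule=cos,           # progression of the weight over training
)
student.fit_one_cycle(10, 1e-3, cbs=kd)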

Available Loss Functions

  • SoftTarget (output-level) - Match the teacher’s softened predictions
  • Attention (intermediate) - Match attention maps (spatial activation patterns)
  • FitNet (intermediate) - Directly match feature maps (requires same dimensions)
  • RKD (relational) - Match distance/angle relationships between samples
  • PKT (probabilistic) - Match probability distributions in feature space

Summary

  • KnowledgeDistillationCallback - fastai callback for distillation during training
  • match_feature_layers(student, teacher, x) - Auto-match layers by spatial resolution
  • SoftTarget - Basic distillation using soft predictions
  • Attention - Attention transfer using intermediate layers

See Also

  • Distillation Losses - All available distillation loss functions
  • Pruner - Combine distillation with pruning for even smaller models
  • Sparsifier - Add sparsity to distilled models