```python
from fasterai.distill.all import *
from fasterai.core.schedule import cos
from fastai.vision.all import *
```

# Distillation Losses Guide

## Overview
fasterai ships 8 distillation losses in two categories:
- Output losses compare final predictions. Simple, always work.
- Intermediate losses compare internal feature maps. More powerful, need layer matching.
Quick version: start with SoftTarget. If you need more, add Attention. For best results, use DecoupledKD.
## Setup

```python
# Data
path = untar_data(URLs.PETS)
files = get_image_files(path/'images')
def label_func(f): return f[0].isupper()
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(224))

# Train teacher (ResNet-34)
teacher = vision_learner(dls, resnet34, metrics=accuracy)
teacher.fit_one_cycle(5, 1e-3)
```

| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.200494 | 0.011118 | 0.995264 | 00:03 |
| 1 | 0.074000 | 0.005578 | 0.997294 | 00:03 |
| 2 | 0.038066 | 0.005455 | 0.997294 | 00:05 |
| 3 | 0.019297 | 0.003446 | 0.998647 | 00:05 |
| 4 | 0.016898 | 0.003194 | 0.998647 | 00:05 |
### Benchmark — ResNet-18, no distillation

```python
bench = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
bench.fit_one_cycle(10, 1e-3)
```

| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.595749 | 1.426755 | 0.688769 | 00:05 |
| 1 | 0.564285 | 1.309221 | 0.705007 | 00:04 |
| 2 | 0.531350 | 0.531994 | 0.740189 | 00:04 |
| 3 | 0.481564 | 0.459301 | 0.779432 | 00:05 |
| 4 | 0.425005 | 0.649794 | 0.657645 | 00:04 |
| 5 | 0.363458 | 0.357183 | 0.838295 | 00:04 |
| 6 | 0.300653 | 0.280660 | 0.883627 | 00:04 |
| 7 | 0.217996 | 0.221633 | 0.919486 | 00:04 |
| 8 | 0.175396 | 0.212536 | 0.917456 | 00:04 |
| 9 | 0.143873 | 0.211442 | 0.918809 | 00:04 |
## Output Losses
Compare final predictions only — no layer matching needed.
### SoftTarget — the classic (Hinton 2015)
Softens distributions with temperature T, matches via KL divergence. Higher T reveals more “dark knowledge.”
When: Default choice for classification. Always start here.
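The computation is worth seeing once. Here is a minimal sketch of what a soft-target loss does, assuming standard PyTorch ops; `soft_target_sketch` is a made-up name, not fasterai's implementation:

```python
import torch
import torch.nn.functional as F

def soft_target_sketch(student_logits, teacher_logits, T=3.0):
    # Soften both distributions with temperature T, then match them via
    # KL divergence. Higher T flattens the teacher's distribution,
    # exposing more "dark knowledge" in the small probabilities.
    log_p_student = F.log_softmax(student_logits / T, dim=-1)
    p_teacher = F.softmax(teacher_logits / T, dim=-1)
    # The T**2 factor keeps gradient magnitudes comparable across temperatures
    return F.kl_div(log_p_student, p_teacher, reduction='batchmean') * T**2
```

In practice this term is blended with the regular cross-entropy loss on the true labels, which is what the callback's `weight` argument controls.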
```python
student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
kd = KnowledgeDistillationCallback(teacher.model, SoftTarget, weight=0.5)
student.fit_one_cycle(10, 1e-3, cbs=kd)
```

| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 5.166677 | 4.747869 | 0.720568 | 00:06 |
| 1 | 4.856295 | 4.530585 | 0.713802 | 00:08 |
| 2 | 4.549835 | 4.248987 | 0.756428 | 00:07 |
| 3 | 4.254271 | 6.520922 | 0.720568 | 00:07 |
| 4 | 3.742702 | 3.190195 | 0.843031 | 00:07 |
| 5 | 3.186871 | 3.485327 | 0.805142 | 00:07 |
| 6 | 2.652603 | 2.963988 | 0.850474 | 00:07 |
| 7 | 2.135208 | 2.323179 | 0.880244 | 00:07 |
| 8 | 1.654864 | 1.901101 | 0.902571 | 00:07 |
| 9 | 1.457664 | 1.843484 | 0.907984 | 00:08 |
### DecoupledKD — best output loss (Zhao, CVPR 2022)
Separates into target-class (TCKD) and non-target-class (NCKD) components. The insight: the most valuable dark knowledge is in the non-target classes (“a dog is more like a cat than a plane”).
With normalize=True, becomes Normalized KD (Yang, ICCV 2023).
When: Best output loss for classification. Worth the switch from SoftTarget.
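The decomposition can be sketched roughly as follows. This is a toy rendition of the DKD paper's loss under standard assumptions, not fasterai's code; `dkd_sketch`, `alpha`, and `beta` are made-up names:

```python
import torch
import torch.nn.functional as F

def dkd_sketch(s_logits, t_logits, target, T=1.0, alpha=1.0, beta=1.0):
    p_s = F.softmax(s_logits / T, dim=1)
    p_t = F.softmax(t_logits / T, dim=1)
    mask = F.one_hot(target, s_logits.size(1)).bool()
    pt_s, pt_t = p_s[mask], p_t[mask]  # probability of the true class

    # TCKD: KL on the binary target-vs-rest distribution
    b_s = torch.stack([pt_s, 1 - pt_s], dim=1)
    b_t = torch.stack([pt_t, 1 - pt_t], dim=1)
    tckd = F.kl_div(b_s.clamp_min(1e-8).log(), b_t, reduction='batchmean')

    # NCKD: KL over the non-target classes only (target masked out and
    # the remaining classes renormalized)
    nckd = F.kl_div(F.log_softmax((s_logits / T).masked_fill(mask, -1e9), dim=1),
                    F.softmax((t_logits / T).masked_fill(mask, -1e9), dim=1),
                    reduction='batchmean')
    return (alpha * tckd + beta * nckd) * T**2
```

Weighting NCKD separately is the whole point: classic SoftTarget implicitly down-weights the non-target term whenever the teacher is confident, which is exactly when it matters most.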
```python
student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
kd = KnowledgeDistillationCallback(teacher.model, DecoupledKD, weight=0.5)
student.fit_one_cycle(10, 1e-3, cbs=kd)
```

| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 3.904424 | 4.550154 | 0.692152 | 00:08 |
| 1 | 3.671331 | 3.912555 | 0.725304 | 00:07 |
| 2 | 3.499841 | 3.510440 | 0.716509 | 00:05 |
| 3 | 3.198684 | 2.982338 | 0.791610 | 00:05 |
| 4 | 2.776469 | 3.369417 | 0.738160 | 00:05 |
| 5 | 2.336425 | 2.103316 | 0.856563 | 00:05 |
| 6 | 1.933325 | 1.867857 | 0.862652 | 00:05 |
| 7 | 1.490620 | 2.113095 | 0.855210 | 00:05 |
| 8 | 1.161858 | 1.303830 | 0.918809 | 00:05 |
| 9 | 0.935724 | 1.287248 | 0.916103 | 00:05 |
### Logits and Mutual

| Loss | What | When |
|---|---|---|
| Logits | MSE on raw logits | Regression tasks |
| Mutual | KL without temperature (T=1) | When temperature tuning hurts |
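Both are one-liners in spirit. A hedged sketch of what each computes, with made-up function names rather than fasterai's actual classes:

```python
import torch
import torch.nn.functional as F

def logits_loss(s_logits, t_logits):
    # Plain MSE on raw logits: no softening, so it also makes sense when
    # the outputs are continuous values rather than class probabilities.
    return F.mse_loss(s_logits, t_logits)

def mutual_loss(s_logits, t_logits):
    # KL divergence at temperature 1: SoftTarget without the T knob.
    return F.kl_div(F.log_softmax(s_logits, dim=-1),
                    F.softmax(t_logits, dim=-1),
                    reduction='batchmean')
```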
## Intermediate Losses

Compare internal feature maps. More powerful, but they need `match_feature_layers` to pair layers:

```python
import torch

student_model = resnet18(num_classes=2)

# Auto-match layers by spatial resolution
try:
    from fasterai.distill.distillation_callback import match_feature_layers
    pairs = match_feature_layers(student_model, teacher.model, torch.randn(1, 3, 224, 224))
except ImportError:
    # Fallback: manual layer names (ResNet family)
    pairs = {'student': ['layer1', 'layer2', 'layer3', 'layer4'],
             'teacher': ['layer1', 'layer2', 'layer3', 'layer4']}

print(f'{"Student":<16} Teacher')
for s, t in zip(pairs['student'], pairs['teacher']):
    print(f'{s:<12} <-> {t}')
```

```
Student          Teacher
conv1        <-> 0.0
layer1       <-> 0.4
layer2       <-> 0.5
layer3       <-> 0.6
layer4       <-> 0
```
### Attention — most robust (Zagoruyko 2017)
Transfers where the teacher looks by matching spatial attention maps. Channel-agnostic — works across different architectures.
When: Go-to intermediate loss. Combine with any output loss.
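To see why it is channel-agnostic, here is a minimal sketch of an attention-transfer loss in plain PyTorch (function names are illustrative, not fasterai's API):

```python
import torch
import torch.nn.functional as F

def attention_map(fmap, p=2):
    # Collapse the channel dimension: spatial attention = mean of
    # |activations|^p over channels, flattened and L2-normalized.
    # The channel count drops out, so student and teacher maps are
    # comparable even when their widths differ.
    a = fmap.abs().pow(p).mean(dim=1).flatten(1)   # (B, H*W)
    return F.normalize(a, dim=1)

def attention_loss(f_s, f_t, p=2):
    # MSE between the normalized spatial attention maps
    return (attention_map(f_s, p) - attention_map(f_t, p)).pow(2).mean()
```

The only shape requirement is matching spatial resolution at each paired layer, which is exactly what the layer matching above guarantees.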
```python
student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
kd = KnowledgeDistillationCallback(
    teacher.model, Attention,
    pairs['student'], pairs['teacher'],
    weight=0.9
)
student.fit_one_cycle(10, 1e-3, cbs=kd)
```

| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.069697 | 0.068317 | 0.652909 | 00:05 |
| 1 | 0.062390 | 0.063302 | 0.736807 | 00:05 |
| 2 | 0.054792 | 0.050560 | 0.778755 | 00:05 |
| 3 | 0.047434 | 0.044820 | 0.815968 | 00:05 |
| 4 | 0.039665 | 0.048785 | 0.780108 | 00:06 |
| 5 | 0.034083 | 0.029331 | 0.896482 | 00:06 |
| 6 | 0.025838 | 0.025366 | 0.912720 | 00:06 |
| 7 | 0.019359 | 0.022271 | 0.925575 | 00:06 |
| 8 | 0.014520 | 0.019081 | 0.937754 | 00:07 |
| 9 | 0.011471 | 0.018635 | 0.940460 | 00:08 |
### Similarity — architecture-agnostic (Tung & Mori 2019)
Matches sample-sample relationships (Gram matrices) rather than individual features. Computes G = F · Fᵀ for each batch, then matches student and teacher Gram matrices.
Key advantage: only compares relationships between samples — works even when student and teacher have completely different channel counts and architectures (e.g., CNN → MLP).
When: Student and teacher are very different architectures.
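A rough sketch of the idea (a toy version of the paper's loss, assuming standard PyTorch ops; `similarity_loss` is a made-up name):

```python
import torch
import torch.nn.functional as F

def similarity_loss(f_s, f_t):
    # Batch-wise Gram matrix G = F @ F.T over flattened features:
    # entry (i, j) is the similarity between samples i and j. Both Grams
    # are (B, B) regardless of channel count, so student and teacher
    # features never need to share a shape.
    def gram(f):
        f = f.flatten(1)                # (B, C*H*W)
        g = f @ f.t()                   # (B, B)
        return F.normalize(g, dim=1)    # row-wise L2 normalization
    b = f_s.size(0)
    return (gram(f_s) - gram(f_t)).pow(2).sum() / b**2
```

The student is never asked to reproduce the teacher's features, only to group samples the same way the teacher does.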
```python
student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
kd = KnowledgeDistillationCallback(
    teacher.model, Similarity,
    pairs['student'], pairs['teacher'],
    weight=0.5
)
student.fit_one_cycle(10, 1e-3, cbs=kd)
```

| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.310352 | 0.333232 | 0.696211 | 00:07 |
| 1 | 0.286330 | 0.298598 | 0.676590 | 00:08 |
| 2 | 0.262785 | 0.329347 | 0.629229 | 00:08 |
| 3 | 0.238726 | 0.209839 | 0.811908 | 00:07 |
| 4 | 0.212011 | 0.239112 | 0.790934 | 00:08 |
| 5 | 0.176889 | 0.206506 | 0.836265 | 00:08 |
| 6 | 0.141853 | 0.143187 | 0.883627 | 00:08 |
| 7 | 0.102757 | 0.115481 | 0.906631 | 00:08 |
| 8 | 0.077785 | 0.092364 | 0.929635 | 00:07 |
| 9 | 0.062214 | 0.089491 | 0.936401 | 00:08 |
### FitNet — direct feature matching (Romero 2015)
MSE between raw feature maps at each paired layer. The simplest intermediate loss.
Caveat: requires matching channel counts between student and teacher at each paired layer. This limits it to same-family architectures (e.g., ResNet-18 ↔︎ ResNet-34). For cross-architecture distillation, use Attention or Similarity instead.
When: Student and teacher are from the same family.
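The core of the loss really is this small (a sketch, not fasterai's implementation; `fitnet_loss` is an illustrative name):

```python
import torch
import torch.nn.functional as F

def fitnet_loss(f_s, f_t):
    # Plain MSE between raw feature maps at a paired layer. Shapes must
    # match exactly, which is why this only works for same-family
    # architectures without an extra adaptation layer.
    assert f_s.shape == f_t.shape, "FitNet needs matching feature shapes"
    return F.mse_loss(f_s, f_t)
```

The original paper works around shape mismatches with a learned 1x1 convolution ("regressor") on the student side; if you need that, Attention or Similarity is usually the simpler fix.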
```python
student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
kd = KnowledgeDistillationCallback(
    teacher.model, FitNet,
    pairs['student'], pairs['teacher'],
    weight=0.5
)
student.fit_one_cycle(10, 1e-3, cbs=kd)
```

| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 3.888415 | 3.896719 | 0.654939 | 00:08 |
| 1 | 2.619282 | 2.271353 | 0.705683 | 00:07 |
| 2 | 2.044442 | 1.881444 | 0.704330 | 00:05 |
| 3 | 1.780878 | 1.767146 | 0.751015 | 00:05 |
| 4 | 1.633668 | 1.653581 | 0.802436 | 00:05 |
| 5 | 1.499802 | 1.550966 | 0.842355 | 00:05 |
| 6 | 1.399436 | 1.357530 | 0.903248 | 00:05 |
| 7 | 1.320239 | 1.296222 | 0.937754 | 00:05 |
| 8 | 1.267885 | 1.275629 | 0.937754 | 00:05 |
| 9 | 1.229082 | 1.250488 | 0.937077 | 00:05 |
### ActivationBoundaries — decision boundaries (Heo 2019)
Focuses on where neurons switch between active and inactive. The student learns to match the teacher’s activation boundaries with a margin m. Penalizes the student when its activations disagree with the teacher’s sign.
When: You care about replicating the teacher’s decision boundaries, not just feature magnitudes.
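The mechanics can be sketched as a margin hinge on the pre-activation sign (a rough rendition of Heo's loss under standard assumptions, with an illustrative function name):

```python
import torch

def activation_boundaries_loss(f_s, f_t, m=1.0):
    # Penalize the student only where its pre-ReLU activation disagrees
    # with the teacher's on/off decision, pushing it past a margin m.
    teacher_on = (f_t > 0).float()
    loss = ((f_s + m).pow(2) * (f_s > -m).float() * (1 - teacher_on) +
            (f_s - m).pow(2) * (f_s < m).float() * teacher_on)
    return loss.mean()
```

Note that activation magnitudes contribute nothing once the student clears the margin on the correct side: only the boundary itself is transferred.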
```python
student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
kd = KnowledgeDistillationCallback(
    teacher.model, ActivationBoundaries,
    pairs['student'], pairs['teacher'],
    weight=0.5
)
student.fit_one_cycle(10, 1e-3, cbs=kd)
```

| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 9.760076 | 9.343024 | 0.689445 | 00:05 |
| 1 | 8.321721 | 7.996875 | 0.693505 | 00:05 |
| 2 | 7.585203 | 7.483111 | 0.736807 | 00:05 |
| 3 | 7.177260 | 7.063045 | 0.819350 | 00:05 |
| 4 | 6.904568 | 6.820758 | 0.856563 | 00:05 |
| 5 | 6.702225 | 6.870900 | 0.815968 | 00:05 |
| 6 | 6.566173 | 6.602059 | 0.860622 | 00:05 |
| 7 | 6.458624 | 6.439599 | 0.947226 | 00:05 |
| 8 | 6.392623 | 6.408597 | 0.949932 | 00:05 |
| 9 | 6.361909 | 6.387673 | 0.956022 | 00:05 |
## Results Comparison

ResNet-34 → ResNet-18 on Oxford Pets:

| Loss | Type | Accuracy | vs Baseline | Needs Layers? |
|---|---|---|---|---|
| No distillation | — | 89.5% | — | — |
| SoftTarget | Output | 92.2% | +2.7% | No |
| DecoupledKD | Output | 93.2% | +3.7% | No |
| Attention | Intermediate | 92.8% | +3.3% | Yes |
| Similarity | Intermediate | 92.1% | +2.6% | Yes |
| FitNet | Intermediate | 91.5% | +2.0% | Yes (same channels) |
DecoupledKD is the best single loss. Attention is the best intermediate loss and can be combined with any output loss for further gains.
## Quick Decision

- Just starting? → SoftTarget
- Want best accuracy? → DecoupledKD
- Cross-architecture? → Attention (or Similarity for radically different architectures)
- Regression task? → Logits
## See Also
- KD Callback Tutorial — End-to-end distillation workflow
- Losses API — Full API reference
- QAT + Distillation — Combine with quantization