Distillation Losses Guide

8 loss functions for knowledge distillation — when to use which

Overview

fasterai ships 8 distillation losses in two categories:

  • Output losses compare final predictions. Simple, always work.
  • Intermediate losses compare internal feature maps. More powerful, need layer matching.

Quick version: start with SoftTarget. If you need more, add Attention. For best results, use DecoupledKD.

Setup

from fasterai.distill.all import *
from fastai.vision.all import *
# Data
path = untar_data(URLs.PETS)
files = get_image_files(path/'images')
def label_func(f): return f[0].isupper()  # Pets filenames: cat breeds start with an uppercase letter
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(224))

# Train teacher (ResNet-34)
teacher = vision_learner(dls, resnet34, metrics=accuracy)
teacher.fit_one_cycle(5, 1e-3)
epoch train_loss valid_loss accuracy time
0 0.200494 0.011118 0.995264 00:03
1 0.074000 0.005578 0.997294 00:03
2 0.038066 0.005455 0.997294 00:05
3 0.019297 0.003446 0.998647 00:05
4 0.016898 0.003194 0.998647 00:05

Benchmark - ResNet-18 (no distillation)

bench = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
bench.fit_one_cycle(10, 1e-3)
epoch train_loss valid_loss accuracy time
0 0.595749 1.426755 0.688769 00:05
1 0.564285 1.309221 0.705007 00:04
2 0.531350 0.531994 0.740189 00:04
3 0.481564 0.459301 0.779432 00:05
4 0.425005 0.649794 0.657645 00:04
5 0.363458 0.357183 0.838295 00:04
6 0.300653 0.280660 0.883627 00:04
7 0.217996 0.221633 0.919486 00:04
8 0.175396 0.212536 0.917456 00:04
9 0.143873 0.211442 0.918809 00:04

Output Losses

Compare final predictions only — no layer matching needed.

SoftTarget — the classic (Hinton 2015)

Softens distributions with temperature T, matches via KL divergence. Higher T reveals more “dark knowledge.”

When: Default choice for classification. Always start here.
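
The core computation is small. A minimal sketch of the idea in plain PyTorch (fasterai's actual SoftTarget implementation may differ in details such as reduction and weighting; the weight argument in the callback below controls how the term is mixed with the regular loss):

import torch.nn.functional as F

def soft_target_loss(student_logits, teacher_logits, T=4.0):
    # Soften both distributions with temperature T, then match them with KL.
    # The T**2 factor keeps gradient magnitudes stable across temperatures (Hinton 2015).
    log_p_s = F.log_softmax(student_logits / T, dim=-1)
    p_t = F.softmax(teacher_logits / T, dim=-1)
    return F.kl_div(log_p_s, p_t, reduction='batchmean') * T**2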

student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
kd = KnowledgeDistillationCallback(teacher.model, SoftTarget, weight=0.5)
student.fit_one_cycle(10, 1e-3, cbs=kd)
epoch train_loss valid_loss accuracy time
0 5.166677 4.747869 0.720568 00:06
1 4.856295 4.530585 0.713802 00:08
2 4.549835 4.248987 0.756428 00:07
3 4.254271 6.520922 0.720568 00:07
4 3.742702 3.190195 0.843031 00:07
5 3.186871 3.485327 0.805142 00:07
6 2.652603 2.963988 0.850474 00:07
7 2.135208 2.323179 0.880244 00:07
8 1.654864 1.901101 0.902571 00:07
9 1.457664 1.843484 0.907984 00:08

DecoupledKD — best output loss (Zhao, CVPR 2022)

Separates into target-class (TCKD) and non-target-class (NCKD) components. The insight: the most valuable dark knowledge is in the non-target classes (“a dog is more like a cat than a plane”).

With normalize=True, becomes Normalized KD (Yang, ICCV 2023).

When: Best output loss for classification. Worth the switch from SoftTarget.
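
In code, the decomposition looks roughly like this. A sketch only, not fasterai's implementation; the term weights alpha and beta are illustrative names (the paper's defaults favor the NCKD term):

import torch
import torch.nn.functional as F

def dkd_loss(student_logits, teacher_logits, target, T=4.0, alpha=1.0, beta=8.0):
    # TCKD: compare target vs. non-target probability mass as a binary distribution.
    p_s = F.softmax(student_logits / T, dim=1)
    p_t = F.softmax(teacher_logits / T, dim=1)
    mask = F.one_hot(target, p_s.size(1)).bool()
    b_s = torch.stack([p_s[mask], 1 - p_s[mask]], dim=1)
    b_t = torch.stack([p_t[mask], 1 - p_t[mask]], dim=1)
    tckd = F.kl_div(b_s.log(), b_t, reduction='batchmean') * T**2
    # NCKD: distribution over the non-target classes only, obtained by
    # pushing the target logit far down before the softmax.
    log_n_s = F.log_softmax(student_logits / T - 1000.0 * mask.float(), dim=1)
    n_t = F.softmax(teacher_logits / T - 1000.0 * mask.float(), dim=1)
    nckd = F.kl_div(log_n_s, n_t, reduction='batchmean') * T**2
    return alpha * tckd + beta * nckd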

student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
kd = KnowledgeDistillationCallback(teacher.model, DecoupledKD, weight=0.5)
student.fit_one_cycle(10, 1e-3, cbs=kd)
epoch train_loss valid_loss accuracy time
0 3.904424 4.550154 0.692152 00:08
1 3.671331 3.912555 0.725304 00:07
2 3.499841 3.510440 0.716509 00:05
3 3.198684 2.982338 0.791610 00:05
4 2.776469 3.369417 0.738160 00:05
5 2.336425 2.103316 0.856563 00:05
6 1.933325 1.867857 0.862652 00:05
7 1.490620 2.113095 0.855210 00:05
8 1.161858 1.303830 0.918809 00:05
9 0.935724 1.287248 0.916103 00:05

Logits and Mutual

Loss     What                           When
Logits   MSE on raw logits              Regression tasks
Mutual   KL without temperature (T=1)   When temperature tuning hurts
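
Both of the losses above are one-liners in spirit. A hypothetical sketch (function names are illustrative, not fasterai's):

import torch.nn.functional as F

def logits_loss(student_logits, teacher_logits):
    # Plain MSE on raw logits: no softmax, so it also suits regression heads.
    return F.mse_loss(student_logits, teacher_logits)

def mutual_loss(student_logits, teacher_logits):
    # SoftTarget with the temperature pinned at 1: straight KL divergence.
    return F.kl_div(F.log_softmax(student_logits, dim=-1),
                    F.softmax(teacher_logits, dim=-1),
                    reduction='batchmean')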

Intermediate Losses

Compare internal feature maps. More powerful, but they need student and teacher layers paired up first; match_feature_layers can do this automatically:

import torch
student_model = resnet18(num_classes=2)

# Auto-match layers by spatial resolution
try:
    from fasterai.distill.distillation_callback import match_feature_layers
    pairs = match_feature_layers(student_model, teacher.model, torch.randn(1, 3, 224, 224))
except ImportError:
    # Fallback: manual layer names (ResNet family)
    pairs = {'student': ['layer1', 'layer2', 'layer3', 'layer4'],
             'teacher': ['layer1', 'layer2', 'layer3', 'layer4']}

print(f'{"Student":<16} Teacher')
for s, t in zip(pairs['student'], pairs['teacher']):
    print(f'{s:<12} <-> {t}')
Student          Teacher
conv1        <-> 0.0
layer1       <-> 0.4
layer2       <-> 0.5
layer3       <-> 0.6
layer4       <-> 0.7

Attention — most robust (Zagoruyko 2017)

Transfers where the teacher looks by matching spatial attention maps. Channel-agnostic — works across different architectures.

When: Go-to intermediate loss. Combine with any output loss.
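
The attention map collapses the channel dimension, which is what makes the loss channel-agnostic. A minimal sketch (fasterai's exact normalization may differ):

import torch.nn.functional as F

def attention_map(fmap, p=2):
    # Spatial attention: sum |activations|^p over channels, flatten, L2-normalize.
    am = fmap.abs().pow(p).sum(dim=1).flatten(1)
    return F.normalize(am, dim=1)

def attention_loss(student_fmap, teacher_fmap, p=2):
    # Only spatial sizes must agree; channel counts are free to differ.
    return (attention_map(student_fmap, p) - attention_map(teacher_fmap, p)).pow(2).mean()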

student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
kd = KnowledgeDistillationCallback(
    teacher.model, Attention,
    pairs['student'], pairs['teacher'],
    weight=0.9
)
student.fit_one_cycle(10, 1e-3, cbs=kd)
epoch train_loss valid_loss accuracy time
0 0.069697 0.068317 0.652909 00:05
1 0.062390 0.063302 0.736807 00:05
2 0.054792 0.050560 0.778755 00:05
3 0.047434 0.044820 0.815968 00:05
4 0.039665 0.048785 0.780108 00:06
5 0.034083 0.029331 0.896482 00:06
6 0.025838 0.025366 0.912720 00:06
7 0.019359 0.022271 0.925575 00:06
8 0.014520 0.019081 0.937754 00:07
9 0.011471 0.018635 0.940460 00:08

Similarity — architecture-agnostic (Tung & Mori 2019)

Matches sample-to-sample relationships rather than individual features: for each batch it computes the Gram matrix G = F · Fᵀ of the flattened features, then aligns the student's matrix with the teacher's.

Key advantage: only compares relationships between samples — works even when student and teacher have completely different channel counts and architectures (e.g., CNN → MLP).

When: Student and teacher are very different architectures.
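
A sketch of the idea (the paper row-normalizes the Gram matrix; fasterai's version may differ in scaling):

import torch.nn.functional as F

def similarity_loss(student_fmap, teacher_fmap):
    # G = F @ F.T relates every sample in the batch to every other sample,
    # so the two networks never need compatible channel counts.
    def gram(fmap):
        f = fmap.flatten(1)                        # (batch, features)
        return F.normalize(f @ f.t(), p=2, dim=1)  # row-normalized (batch, batch)
    return (gram(student_fmap) - gram(teacher_fmap)).pow(2).mean()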

student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
kd = KnowledgeDistillationCallback(
    teacher.model, Similarity,
    pairs['student'], pairs['teacher'],
    weight=0.5
)
student.fit_one_cycle(10, 1e-3, cbs=kd)
epoch train_loss valid_loss accuracy time
0 0.310352 0.333232 0.696211 00:07
1 0.286330 0.298598 0.676590 00:08
2 0.262785 0.329347 0.629229 00:08
3 0.238726 0.209839 0.811908 00:07
4 0.212011 0.239112 0.790934 00:08
5 0.176889 0.206506 0.836265 00:08
6 0.141853 0.143187 0.883627 00:08
7 0.102757 0.115481 0.906631 00:08
8 0.077785 0.092364 0.929635 00:07
9 0.062214 0.089491 0.936401 00:08

FitNet — direct feature matching (Romero 2015)

MSE between raw feature maps at each paired layer. The simplest intermediate loss.

Caveat: requires matching channel counts between student and teacher at each paired layer. This limits it to same-family architectures (e.g., ResNet-18 ↔︎ ResNet-34). For cross-architecture distillation, use Attention or Similarity instead.

When: Student and teacher are from the same family.
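
The loss itself is just MSE; the shape requirement below is where the same-family restriction shows up. A sketch (the original paper adds a learned regressor to adapt mismatched shapes, which is not assumed here):

import torch.nn.functional as F

def fitnet_loss(student_fmap, teacher_fmap):
    # Raw feature matching: every dimension must line up, hence same-family nets.
    assert student_fmap.shape == teacher_fmap.shape, 'feature shapes must match'
    return F.mse_loss(student_fmap, teacher_fmap)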

student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
kd = KnowledgeDistillationCallback(
    teacher.model, FitNet,
    pairs['student'], pairs['teacher'],
    weight=0.5
)
student.fit_one_cycle(10, 1e-3, cbs=kd)
epoch train_loss valid_loss accuracy time
0 3.888415 3.896719 0.654939 00:08
1 2.619282 2.271353 0.705683 00:07
2 2.044442 1.881444 0.704330 00:05
3 1.780878 1.767146 0.751015 00:05
4 1.633668 1.653581 0.802436 00:05
5 1.499802 1.550966 0.842355 00:05
6 1.399436 1.357530 0.903248 00:05
7 1.320239 1.296222 0.937754 00:05
8 1.267885 1.275629 0.937754 00:05
9 1.229082 1.250488 0.937077 00:05

ActivationBoundaries — decision boundaries (Heo 2019)

Focuses on where neurons switch between active and inactive. The student learns to match the teacher’s activation boundaries with a margin m. Penalizes the student when its activations disagree with the teacher’s sign.

When: You care about replicating the teacher’s decision boundaries, not just feature magnitudes.
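
A sketch of the margin idea on pre-activations, following Heo 2019 (fasterai's formulation may differ):

import torch.nn.functional as F

def activation_boundaries_loss(student_pre, teacher_pre, m=2.0):
    # Where the teacher neuron fires, push the student above +m;
    # where it stays off, push the student below -m (squared hinge on each side).
    active = (teacher_pre > 0).float()
    return (active * F.relu(m - student_pre).pow(2)
            + (1 - active) * F.relu(m + student_pre).pow(2)).mean()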

student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
kd = KnowledgeDistillationCallback(
    teacher.model, ActivationBoundaries,
    pairs['student'], pairs['teacher'],
    weight=0.5
)
student.fit_one_cycle(10, 1e-3, cbs=kd)
epoch train_loss valid_loss accuracy time
0 9.760076 9.343024 0.689445 00:05
1 8.321721 7.996875 0.693505 00:05
2 7.585203 7.483111 0.736807 00:05
3 7.177260 7.063045 0.819350 00:05
4 6.904568 6.820758 0.856563 00:05
5 6.702225 6.870900 0.815968 00:05
6 6.566173 6.602059 0.860622 00:05
7 6.458624 6.439599 0.947226 00:05
8 6.392623 6.408597 0.949932 00:05
9 6.361909 6.387673 0.956022 00:05

Results Comparison

ResNet-34 → ResNet-18 on Oxford Pets:

Loss              Type          Accuracy  vs Baseline  Needs Layers?
No distillation   -             89.5%     -            -
SoftTarget        Output        92.2%     +2.7%        No
DecoupledKD       Output        93.2%     +3.7%        No
Attention         Intermediate  92.8%     +3.3%        Yes
Similarity        Intermediate  92.1%     +2.6%        Yes
FitNet            Intermediate  91.5%     +2.0%        Yes (same channels)

DecoupledKD is the best single loss. Attention is the best intermediate loss and can be combined with any output loss for further gains.
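
One way to combine them, assuming your fasterai version lets two distillation callbacks stack additively (worth verifying against the docs before relying on it):

# Hypothetical combination: an output loss plus an intermediate loss,
# each contributing its own weighted term. Verify that stacked
# KnowledgeDistillationCallbacks compose in your fasterai version.
student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
kd_out = KnowledgeDistillationCallback(teacher.model, DecoupledKD, weight=0.5)
kd_mid = KnowledgeDistillationCallback(teacher.model, Attention,
                                       pairs['student'], pairs['teacher'],
                                       weight=0.9)
student.fit_one_cycle(10, 1e-3, cbs=[kd_out, kd_mid])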

Quick Decision

Just starting?          → SoftTarget
Want best accuracy?     → DecoupledKD
Cross-architecture?     → Attention (or Similarity for very different)
Regression task?        → Logits

See Also