```python
from fastai.vision.all import *
from fasterai.quantize.quantize_callback import QuantizeCallback
from fasterai.distill.distillation_callback import KnowledgeDistillationCallback
from fasterai.distill.losses import DecoupledKD, SoftTarget
from copy import deepcopy

# Data: Oxford-IIIT Pets, labeled cat vs. dog by filename case
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")
def label_func(f): return f[0].isupper()
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))
```

# QAT + Knowledge Distillation
## Overview
Quantization-Aware Training (QAT) simulates INT8 quantization during training, letting the model learn to be robust to reduced precision. Adding Knowledge Distillation on top provides a teacher's soft labels to guide the student through the quantization noise, recovering 80-95% of the accuracy lost to quantization.
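To make "simulates INT8 quantization" concrete: a fake-quantize node rounds each value to the INT8 grid and immediately converts it back to float. A minimal sketch of that round trip (illustrative only, not fasterai's or PyTorch's actual node; in real QAT, `scale` and `zero_point` are calibrated or learned):

```python
import torch

def fake_quantize(x, scale, zero_point, qmin=-128, qmax=127):
    # Quantize to the INT8 grid, then immediately dequantize: the tensor stays
    # float so gradients can flow (via a straight-through estimator in practice),
    # but every value has been snapped to what INT8 inference would produce.
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    return (q - zero_point) * scale
```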
## Why combine them?
| Approach | Accuracy Recovery | Training Cost |
|---|---|---|
| Post-training quantization (PTQ) | 50-80% | None |
| QAT alone | 70-90% | Full training |
| QAT + Distillation | 80-95% | Full training + teacher forward pass |
## How it works
Both are fastai callbacks — just pass them together:
```python
learn.fit(epochs, cbs=[QuantizeCallback(...), KnowledgeDistillationCallback(...)])
```
The QAT callback handles fake quantization nodes. The distillation callback blends the teacher’s soft labels into the loss. They don’t interfere with each other.
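Under the hood, the distillation callback effectively replaces the batch loss with a weighted blend of the task loss and a teacher-matching term. A rough sketch of that blend, with hypothetical names (fasterai's internals may differ):

```python
import torch.nn.functional as F

def blended_loss(student_logits, teacher_logits, targets, kd_loss, weight=0.5):
    # Standard classification loss on the hard labels...
    task = F.cross_entropy(student_logits, targets)
    # ...plus a distillation term that pulls the student toward the teacher.
    distill = kd_loss(student_logits, teacher_logits)
    return (1 - weight) * task + weight * distill
```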
```python
# Step 1: Train a teacher model
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
learn.fit(3)

# Save the teacher (full precision, best accuracy)
teacher = deepcopy(learn.model).eval()
```

## Approach 1: QAT Only (Baseline)
First, let’s see what QAT alone gives us:
```python
# QAT only — no distillation
learn_qat = vision_learner(dls, resnet18, metrics=accuracy)
learn_qat.unfreeze()
learn_qat.fit(3, cbs=[
    QuantizeCallback(backend='x86'),
])
```

## Approach 2: QAT + Knowledge Distillation
Now add the teacher — the distillation loss helps the student navigate the quantization noise:
```python
# QAT + Knowledge Distillation — teacher guides the quantized student
learn_qat_kd = vision_learner(dls, resnet18, metrics=accuracy)
learn_qat_kd.unfreeze()
learn_qat_kd.fit(3, cbs=[
    QuantizeCallback(backend='x86'),
    KnowledgeDistillationCallback(teacher=teacher, loss=DecoupledKD, weight=0.5),
])
```

## Choosing a Loss Function
Different distillation losses work well with QAT:
| Loss | Type | Notes |
|---|---|---|
| `SoftTarget` | Output | Simple, good default for QAT+KD |
| `DecoupledKD` | Output | Better dark knowledge preservation |
| `DecoupledKD` with `normalize=True` | Output | NKD mode — best for many-class problems |
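For intuition, an output-space loss like `SoftTarget` is essentially Hinton's temperature-scaled KL divergence between student and teacher logits. A minimal sketch (the temperature `T` and the exact scaling are illustrative assumptions, not fasterai's signature):

```python
import torch.nn.functional as F

def soft_target(student_logits, teacher_logits, T=4.0):
    # Soften both distributions with temperature T, then match them with KL.
    # The T**2 factor keeps gradient magnitudes comparable across temperatures.
    log_p_student = F.log_softmax(student_logits / T, dim=-1)
    p_teacher = F.softmax(teacher_logits / T, dim=-1)
    return F.kl_div(log_p_student, p_teacher, reduction='batchmean') * T**2
```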
```python
# SoftTarget — simplest option
KnowledgeDistillationCallback(teacher=teacher, loss=SoftTarget, weight=0.5)

# DecoupledKD with NKD — best accuracy
from functools import partial
KnowledgeDistillationCallback(teacher=teacher, loss=partial(DecoupledKD, normalize=True), weight=0.5)
```

The `weight` parameter (default 0.5) controls the balance between the standard loss and the distillation loss. Higher values lean more on the teacher.
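The decoupling behind `DecoupledKD` splits the KL term into a target-class part (TCKD) and a non-target-class part (NCKD), which is what lets it preserve dark knowledge better: the non-target distribution gets its own weight. A rough sketch of the idea, following Zhao et al. (2022); the `T`, `alpha`, and `beta` values here are illustrative assumptions, not fasterai's parameters:

```python
import torch
import torch.nn.functional as F

def decoupled_kd(student_logits, teacher_logits, target, T=4.0, alpha=1.0, beta=8.0):
    n_classes = student_logits.size(1)
    mask = F.one_hot(target, n_classes).bool()
    p_s = F.softmax(student_logits / T, dim=1)
    p_t = F.softmax(teacher_logits / T, dim=1)

    # TCKD: binary KL over (target class, everything else)
    bin_s = torch.stack([p_s[mask], 1 - p_s[mask]], dim=1)
    bin_t = torch.stack([p_t[mask], 1 - p_t[mask]], dim=1)
    tckd = F.kl_div(bin_s.log(), bin_t, reduction='batchmean')

    # NCKD: KL over the non-target classes only (the "dark knowledge"),
    # renormalized by masking the target logit out before the softmax
    log_s_nt = F.log_softmax(student_logits / T - 1000.0 * mask, dim=1)
    p_t_nt = F.softmax(teacher_logits / T - 1000.0 * mask, dim=1)
    nckd = F.kl_div(log_s_nt, p_t_nt, reduction='batchmean')

    return (alpha * tckd + beta * nckd) * T**2
```

Weighting `beta` above `alpha` emphasizes the non-target classes, which is where most of the teacher's dark knowledge lives.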
## Tips
- Use the same architecture for teacher and student for best results — the teacher is the full-precision version of the same model
- Lower learning rate — QAT benefits from a smaller LR since the fake quantization adds noise. Try `1e-4` to `1e-3`
- More epochs — QAT typically needs more epochs than standard training to converge through the quantization noise
- Teacher stays frozen — the teacher is always in `eval()` mode, no gradients flow through it
- Convert after training — after QAT + KD training, convert to real INT8:

```python
from fasterai.quantize.quantizer import Quantizer  # import path assumed to mirror quantize_callback above

Quantizer(backend='x86', method='qat').quantize(learn.model, dls.valid)
```
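A quick sanity check after conversion is to compare serialized model size; INT8 weights should come out roughly 4x smaller than their FP32 counterparts. This uses plain PyTorch serialization, not a fasterai API:

```python
import io
import torch

def serialized_mb(model):
    # Serialize the state dict to an in-memory buffer and report its size
    buf = io.BytesIO()
    torch.save(model.state_dict(), buf)
    return buf.getbuffer().nbytes / 1e6

print(f"Model size: {serialized_mb(learn.model):.1f} MB")
```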
## Summary
| Component | Role |
|---|---|
| `QuantizeCallback(backend)` | Inserts fake quantization nodes during training |
| `KnowledgeDistillationCallback(teacher, loss)` | Blends teacher soft labels into the training loss |
| Both together | QAT accuracy + distillation recovery — best INT8 results |
The key insight: both are standard fastai callbacks — just pass them together in `cbs=[...]`. No special wiring needed.
## See Also
- Quantization Tutorial - PTQ and torchao quantization
- QAT Callback - QAT details
- Distillation Tutorial - Distillation losses and setup