QAT + Knowledge Distillation

Combine quantization-aware training with knowledge distillation for best INT8 accuracy

Overview

Quantization-Aware Training (QAT) simulates INT8 quantization during training, letting the model learn to be robust to reduced precision. Adding Knowledge Distillation on top provides a teacher’s soft labels to guide the student through the quantization noise, typically recovering 80-95% of the accuracy lost to quantization.
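
Conceptually, the student is trained on a blend of the ordinary task loss and a soft-label distillation term, computed on logits that have already passed through fake quantization. Below is a minimal sketch of that combined objective in plain PyTorch; it is not fasterai's actual implementation, and the weight w and temperature T are illustrative:

import torch.nn.functional as F

def combined_loss(student_logits, teacher_logits, labels, w=0.5, T=4.0):
    # Standard task loss on the (fake-quantized) student outputs
    task = F.cross_entropy(student_logits, labels)
    # Distillation loss: KL divergence between temperature-softened distributions
    kd = F.kl_div(F.log_softmax(student_logits / T, dim=-1),
                  F.softmax(teacher_logits / T, dim=-1),
                  reduction='batchmean') * T * T
    # Blend the two: higher w leans more on the teacher
    return (1 - w) * task + w * kd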

Why combine them?

| Approach | Accuracy Recovery | Training Cost |
|---|---|---|
| Post-training quantization (PTQ) | 50-80% | None |
| QAT alone | 70-90% | Full training |
| QAT + Distillation | 80-95% | Full training + teacher forward pass |

How it works

Both are fastai callbacks — just pass them together:

learn.fit(epochs, cbs=[QuantizeCallback(...), KnowledgeDistillationCallback(...)])

The QAT callback inserts and manages the fake quantization nodes. The distillation callback blends the teacher’s soft labels into the loss. They don’t interfere with each other.
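
To make "fake quantization" concrete: each inserted node simulates the INT8 round trip in float, so the rounding error shows up in the forward pass and the optimizer learns around it. A simplified per-tensor sketch (illustrative only, not the callback's code):

import torch

def fake_quantize(x, num_bits=8):
    # Map the observed float range onto 2^num_bits integer levels, round, then map back to float
    qmin, qmax = 0, 2 ** num_bits - 1
    scale = (x.max() - x.min()) / (qmax - qmin)
    zero_point = qmin - torch.round(x.min() / scale)
    q = torch.clamp(torch.round(x / scale + zero_point), qmin, qmax)
    return (q - zero_point) * scale  # float tensor carrying the INT8 rounding error

x = torch.randn(4, 8)
print((x - fake_quantize(x)).abs().max())  # the "quantization noise" the student must absorb

With that picture in mind, the full example below first trains a full-precision teacher, then reuses it for the quantized runs.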

from fastai.vision.all import *
from fasterai.quantize.quantize_callback import QuantizeCallback
from fasterai.distill.distillation_callback import KnowledgeDistillationCallback
from fasterai.distill.losses import DecoupledKD, SoftTarget
from copy import deepcopy

# Data
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")
def label_func(f): return f[0].isupper()
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))

# Step 1: Train a teacher model
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
learn.fit(3)

# Save the teacher (full precision, best accuracy)
teacher = deepcopy(learn.model).eval()
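
It helps to record the teacher’s full-precision accuracy now, so the quantized runs below have a reference point (standard fastai validate() call):

# Full-precision reference accuracy for later comparison
fp32_loss, fp32_acc = learn.validate()
print(f"Teacher (FP32) accuracy: {fp32_acc:.4f}")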

Approach 1: QAT Only (Baseline)

First, let’s see what QAT alone gives us:

# QAT only — no distillation
learn_qat = vision_learner(dls, resnet18, metrics=accuracy)
learn_qat.unfreeze()
learn_qat.fit(3, cbs=[
    QuantizeCallback(backend='x86'),
])
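
Recording the QAT-only accuracy gives the baseline that distillation should improve on:

# Baseline: accuracy with fake quantization but no teacher guidance
qat_loss, qat_acc = learn_qat.validate()
print(f"QAT only accuracy: {qat_acc:.4f}")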

Approach 2: QAT + Knowledge Distillation

Now add the teacher — the distillation loss helps the student navigate the quantization noise:

# QAT + Knowledge Distillation — teacher guides the quantized student
learn_qat_kd = vision_learner(dls, resnet18, metrics=accuracy)
learn_qat_kd.unfreeze()
learn_qat_kd.fit(3, cbs=[
    QuantizeCallback(backend='x86'),
    KnowledgeDistillationCallback(teacher=teacher, loss=DecoupledKD, weight=0.5),
])
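
Validating all three models on the same validation set shows how much of the gap to the full-precision teacher the distillation term recovers (variable names carried over from the snippets above):

# Compare against the FP32 teacher and the QAT-only baseline
kd_loss_val, kd_acc = learn_qat_kd.validate()
print(f"FP32 teacher: {fp32_acc:.4f} | QAT only: {qat_acc:.4f} | QAT + KD: {kd_acc:.4f}")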

Choosing a Loss Function

Different distillation losses work well with QAT:

| Loss | Type | Notes |
|---|---|---|
| SoftTarget | Output | Simple, good default for QAT+KD |
| DecoupledKD | Output | Better dark knowledge preservation |
| DecoupledKD (normalize=True) | Output | NKD mode, best for many-class problems |

# SoftTarget — simplest option
KnowledgeDistillationCallback(teacher=teacher, loss=SoftTarget, weight=0.5)

# DecoupledKD with NKD — best accuracy
from functools import partial
KnowledgeDistillationCallback(teacher=teacher, loss=partial(DecoupledKD, normalize=True), weight=0.5)

The weight parameter (default 0.5) controls the balance between the standard loss and the distillation loss. Higher values lean more on the teacher.
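
If you are unsure which balance suits your dataset, a small sweep over weight is cheap to set up, since only the callback argument changes (illustrative values; everything else follows the training code above):

# Hypothetical sweep over the distillation weight
for w in (0.3, 0.5, 0.7):
    learn_w = vision_learner(dls, resnet18, metrics=accuracy)
    learn_w.unfreeze()
    learn_w.fit(3, cbs=[
        QuantizeCallback(backend='x86'),
        KnowledgeDistillationCallback(teacher=teacher, loss=SoftTarget, weight=w),
    ])
    print(w, learn_w.validate())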

Tips

  • Use the same architecture for teacher and student for best results — the teacher is the full-precision version of the same model
  • Lower learning rate — QAT benefits from smaller LR since the fake quantization adds noise. Try 1e-4 to 1e-3
  • More epochs — QAT typically needs more epochs than standard training to converge through the quantization noise
  • Teacher stays frozen — the teacher is always in eval() mode, no gradients flow through it
  • Convert after training — after QAT + KD training, convert to real INT8: Quantizer(backend='x86', method='qat').quantize(learn.model, dls.valid) (see the sketch below)
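
Putting the last tip into runnable form, here is a sketch of the conversion step after QAT + KD training. The Quantizer call is the one from the tip above; the import path, and the assumption that quantize() returns the converted model, may need adjusting to your fasterai version:

import torch
# Import path is an assumption; adjust to where Quantizer lives in your fasterai install
from fasterai.quantize.quantizer import Quantizer

# Convert the fake-quantized student into a real INT8 model
int8_model = Quantizer(backend='x86', method='qat').quantize(learn_qat_kd.model, dls.valid)

# Saving with plain torch.save is illustrative
torch.save(int8_model.state_dict(), 'model_int8.pth')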

Summary

| Component | Role |
|---|---|
| QuantizeCallback(backend) | Inserts fake quantization nodes during training |
| KnowledgeDistillationCallback(teacher, loss) | Blends teacher soft labels into the training loss |
| Both together | QAT robustness + distillation recovery, the best INT8 results |

The key insight: both are standard fastai callbacks — just pass them together in cbs=[...]. No special wiring needed.


See Also