Distillation Losses
Overview
This module provides loss functions for knowledge distillation. These losses enable training a smaller “student” network to mimic a larger “teacher” network.
Loss categories:

- Output-based (SoftTarget, Logits, Mutual): compare final predictions
- Feature-based (Attention, FitNet, Similarity, ActivationBoundaries): compare intermediate representations
Output-Based Losses
These losses compare the final output predictions between student and teacher networks.
SoftTarget
```python
def SoftTarget(
    pred: torch.Tensor,          # Student predictions
    teacher_pred: torch.Tensor,  # Teacher predictions
    T: float = 5,                # Temperature for softening
    **kwargs,
) -> torch.Tensor:
```

Knowledge distillation with softened distributions (Hinton et al.).
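The idea can be sketched as follows: both sets of logits are softened with temperature `T` before taking the KL divergence. The helper name `soft_target` and the `batchmean` reduction are illustrative assumptions, not this module's implementation.

```python
import torch
import torch.nn.functional as F

def soft_target(pred, teacher_pred, T=5.0):
    # Soften both distributions with temperature T, then take the KL
    # divergence; the T*T factor keeps gradient magnitudes comparable
    # across temperatures.
    return F.kl_div(
        F.log_softmax(pred / T, dim=1),
        F.softmax(teacher_pred / T, dim=1),
        reduction="batchmean",
    ) * T * T

logits = torch.randn(4, 10)
loss = soft_target(logits, logits)  # identical logits -> loss near 0
```

Higher temperatures expose more of the teacher's "dark knowledge" in the non-target classes, at the cost of a flatter target distribution.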
Logits
```python
def Logits(
    pred: torch.Tensor,          # Student predictions
    teacher_pred: torch.Tensor,  # Teacher predictions
    **kwargs,
) -> torch.Tensor:
```

Direct logit matching between student and teacher.
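Direct logit matching is typically a mean-squared error on the raw (pre-softmax) logits; the sketch below assumes that form, and the lowercase name `logits_loss` is illustrative.

```python
import torch
import torch.nn.functional as F

def logits_loss(pred, teacher_pred):
    # Mean-squared error between raw student and teacher logits.
    return F.mse_loss(pred, teacher_pred)

loss = logits_loss(torch.zeros(2, 5), torch.ones(2, 5))  # MSE of 0s vs 1s -> 1.0
```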
Mutual
```python
def Mutual(
    pred: torch.Tensor,          # Student predictions
    teacher_pred: torch.Tensor,  # Teacher predictions
    **kwargs,
) -> torch.Tensor:
```

KL divergence between student and teacher.
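A sketch of a KL-divergence loss on the un-softened softmax outputs, as used in mutual (two-way) distillation. The direction of the divergence and the `batchmean` reduction are assumptions of this sketch.

```python
import torch
import torch.nn.functional as F

def mutual(pred, teacher_pred):
    # KL divergence between the peer networks' output distributions
    # (no temperature scaling, unlike SoftTarget).
    return F.kl_div(
        F.log_softmax(pred, dim=1),
        F.softmax(teacher_pred, dim=1),
        reduction="batchmean",
    )

p = torch.randn(3, 8)
loss = mutual(p, p)  # identical predictions -> loss near 0
```

In mutual learning both networks train simultaneously, each adding this term against the other's current predictions.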
Feature-Based Losses
These losses compare intermediate feature representations, enabling the student to learn internal representations similar to the teacher.
Attention
```python
def Attention(
    fm_s: dict[str, torch.Tensor],  # Student feature maps {name: tensor}
    fm_t: dict[str, torch.Tensor],  # Teacher feature maps {name: tensor}
    p: int = 2,                     # Power for attention computation
    **kwargs,
) -> torch.Tensor:
```

Attention transfer loss (Zagoruyko & Komodakis).
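Attention transfer collapses each feature map over its channel dimension into a spatial attention map, then matches normalized maps between student and teacher. A minimal sketch, assuming L2-normalized maps averaged over the named layers (the exact normalization and reduction are assumptions):

```python
import torch
import torch.nn.functional as F

def attention_map(fm, p=2):
    # Spatial attention map: mean over channels of |activations|^p,
    # flattened and L2-normalized per sample.
    return F.normalize(fm.abs().pow(p).mean(dim=1).flatten(1), dim=1)

def attention_loss(fm_s, fm_t, p=2):
    # Average squared distance between student and teacher attention maps.
    return sum(
        (attention_map(fm_s[k], p) - attention_map(fm_t[k], p)).pow(2).mean()
        for k in fm_s
    ) / len(fm_s)

fm = {"layer1": torch.randn(2, 16, 8, 8)}
loss = attention_loss(fm, fm)  # identical maps -> loss of 0
```

Because only spatial maps are compared, student and teacher may have different channel counts as long as spatial sizes match.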
ActivationBoundaries
```python
def ActivationBoundaries(
    fm_s: dict[str, torch.Tensor],  # Student feature maps
    fm_t: dict[str, torch.Tensor],  # Teacher feature maps
    m: float = 2,                   # Boundary margin
    **kwargs,
) -> torch.Tensor:
```

Boundary-based knowledge distillation (Heo et al.).
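A sketch of a margin-based boundary loss in the spirit of Heo et al.: the student is pushed at least margin `m` to the same side of the activation boundary (zero) as the teacher. It assumes the dicts hold pre-activation values; the exact hinge form used by this module is an assumption.

```python
import torch

def ab_loss(fm_s, fm_t, m=2.0):
    # Penalize student pre-activations whose sign disagrees with the
    # teacher's, with a squared hinge of margin m.
    losses = []
    for k in fm_s:
        s, t = fm_s[k], fm_t[k]
        # Teacher inactive but student not safely below -m: push it down.
        neg = (s + m).pow(2) * ((s > -m) & (t <= 0)).float()
        # Teacher active but student not safely above +m: push it up.
        pos = (s - m).pow(2) * ((s <= m) & (t > 0)).float()
        losses.append((neg + pos).mean())
    return sum(losses) / len(losses)

fm = {"block1": torch.full((2, 4, 3, 3), 5.0)}
loss = ab_loss(fm, fm)  # both safely on the same side of the boundary -> 0
```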
FitNet
```python
def FitNet(
    fm_s: dict[str, torch.Tensor],  # Student feature maps
    fm_t: dict[str, torch.Tensor],  # Teacher feature maps
    **kwargs,
) -> torch.Tensor:
```

FitNets: direct feature map matching (Romero et al.).
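In its simplest form this is a mean-squared error between corresponding feature maps. The sketch below assumes the student's maps already match the teacher's shapes; the original FitNets paper inserts a learned regressor layer when they do not.

```python
import torch
import torch.nn.functional as F

def fitnet_loss(fm_s, fm_t):
    # Direct MSE between corresponding feature maps, averaged over layers.
    # Assumes matching shapes (no regressor, unlike the full FitNets recipe).
    return sum(F.mse_loss(fm_s[k], fm_t[k]) for k in fm_s) / len(fm_s)

fm_s = {"mid": torch.zeros(2, 8, 4, 4)}
fm_t = {"mid": torch.ones(2, 8, 4, 4)}
loss = fitnet_loss(fm_s, fm_t)  # all-zeros vs all-ones -> MSE of 1.0
```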
Similarity
```python
def Similarity(
    fm_s: dict[str, torch.Tensor],  # Student feature maps
    fm_t: dict[str, torch.Tensor],  # Teacher feature maps
    pred: torch.Tensor,             # Student predictions (unused, for API consistency)
    p: int = 2,                     # Normalization power
    **kwargs,
) -> torch.Tensor:
```

Similarity-preserving knowledge distillation (Tung & Mori).
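Rather than matching features directly, this loss matches the pairwise similarity structure of the batch: a Gram matrix `G = f @ f.T` over flattened features, row-normalized, compared between student and teacher. A minimal sketch under those assumptions (the unused `pred` argument from the signature is omitted here):

```python
import torch
import torch.nn.functional as F

def similarity_loss(fm_s, fm_t, p=2):
    # Compare batch-wise similarity (Gram) matrices built from flattened
    # feature maps, row-normalized with power p.
    losses = []
    for k in fm_s:
        fs, ft = fm_s[k].flatten(1), fm_t[k].flatten(1)
        gs = F.normalize(fs @ fs.t(), p=p, dim=1)
        gt = F.normalize(ft @ ft.t(), p=p, dim=1)
        losses.append((gs - gt).pow(2).mean())
    return sum(losses) / len(losses)

fm = {"layer3": torch.randn(4, 8, 2, 2)}
loss = similarity_loss(fm, fm)  # identical feature maps -> loss of 0
```

Since only the batch-by-batch Gram matrices are compared, student and teacher feature dimensions need not match.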
See Also
- KnowledgeDistillationCallback - Apply these losses during training
- Distillation Tutorial - Practical examples with different losses
Loss Selection Guide
| Loss | Best For | Complexity |
|---|---|---|
| SoftTarget | General distillation, logit matching | Low |
| Logits | Direct logit regression | Low |
| Mutual | Mutual (two-way) distillation | Low |
| Attention | When attention patterns matter | Low |
| FitNet | Intermediate feature matching | Medium |
| Similarity | Preserving pairwise sample similarities | Medium |
| ActivationBoundaries | Transferring activation boundaries | Medium |
| PKT | Probability distribution matching | Medium |
| RKD | Relational knowledge transfer | High |