Distillation Losses

Knowledge distillation loss functions

Overview

This module provides loss functions for knowledge distillation. These losses enable training a smaller “student” network to mimic a larger “teacher” network.

Loss Categories:

- Output-based (SoftTarget, Logits, Mutual): compare final predictions
- Feature-based (Attention, FitNet, Similarity, ActivationBoundaries): compare intermediate representations

Output-Based Losses

These losses compare the final output predictions between student and teacher networks.


source

SoftTarget


def SoftTarget(
    pred:torch.Tensor, # Student predictions
    teacher_pred:torch.Tensor, # Teacher predictions
    T:float=5, # Temperature for softening
    **kwargs, # Additional keyword arguments
)->torch.Tensor:

Knowledge distillation with softened distributions (Hinton et al.)
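As a minimal sketch of what a Hinton-style soft-target loss computes (the function name here is illustrative, not the module's actual implementation):

```python
import torch
import torch.nn.functional as F

def soft_target_loss(pred: torch.Tensor, teacher_pred: torch.Tensor,
                     T: float = 5.0) -> torch.Tensor:
    # Soften both distributions with temperature T, then take the KL
    # divergence; the T**2 factor keeps gradient magnitudes comparable
    # across temperatures (Hinton et al., 2015).
    log_p_student = F.log_softmax(pred / T, dim=1)
    p_teacher = F.softmax(teacher_pred / T, dim=1)
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * T * T

logits = torch.randn(4, 10)
# identical student and teacher logits give (near-)zero loss
assert soft_target_loss(logits, logits).item() < 1e-6
```

Higher `T` flattens both distributions, transferring more of the teacher's "dark knowledge" about relative class similarities.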


source

Logits


def Logits(
    pred:torch.Tensor, # Student predictions
    teacher_pred:torch.Tensor, # Teacher predictions
    **kwargs, # Additional keyword arguments
)->torch.Tensor:

Direct logit matching between student and teacher
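A sketch of direct logit matching under the usual formulation (mean squared error on raw, pre-softmax logits; the function name is illustrative):

```python
import torch
import torch.nn.functional as F

def logits_loss(pred: torch.Tensor, teacher_pred: torch.Tensor) -> torch.Tensor:
    # Match raw logits directly with mean squared error (Ba & Caruana,
    # 2014). No softmax or temperature is involved.
    return F.mse_loss(pred, teacher_pred)
```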


source

Mutual


def Mutual(
    pred:torch.Tensor, # Student predictions
    teacher_pred:torch.Tensor, # Teacher predictions
    **kwargs, # Additional keyword arguments
)->torch.Tensor:

KL divergence between student and teacher
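A sketch of a plain (untempered) KL loss of the kind used in deep mutual learning, where two networks distill from each other; the function name is illustrative:

```python
import torch
import torch.nn.functional as F

def mutual_loss(pred: torch.Tensor, teacher_pred: torch.Tensor) -> torch.Tensor:
    # KL divergence from the peer's softmax distribution to the
    # student's, without temperature scaling (cf. Zhang et al., 2018).
    return F.kl_div(F.log_softmax(pred, dim=1),
                    F.softmax(teacher_pred, dim=1),
                    reduction="batchmean")
```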


Feature-Based Losses

These losses compare intermediate feature representations, enabling the student to learn internal representations similar to the teacher.


source

Attention


def Attention(
    fm_s:dict[str, torch.Tensor], # Student feature maps {name: tensor}
    fm_t:dict[str, torch.Tensor], # Teacher feature maps {name: tensor}
    p:int=2, # Power for attention computation
    **kwargs, # Additional keyword arguments
)->torch.Tensor:

Attention transfer loss (Zagoruyko & Komodakis)
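A per-layer sketch of attention transfer (the module operates on dicts of named feature maps; this illustrative version takes a single pair of 4D tensors):

```python
import torch
import torch.nn.functional as F

def attention_map(fm: torch.Tensor, p: int = 2) -> torch.Tensor:
    # Collapse channels by summing |activation|^p, flatten the spatial
    # dims, and L2-normalize so maps from different layers are comparable.
    am = fm.abs().pow(p).sum(dim=1).flatten(1)
    return F.normalize(am, dim=1)

def attention_loss(fm_s: torch.Tensor, fm_t: torch.Tensor,
                   p: int = 2) -> torch.Tensor:
    # L2 distance between normalized student and teacher attention maps
    # (Zagoruyko & Komodakis, 2017). Spatial sizes must match; channel
    # counts may differ because channels are summed out.
    return (attention_map(fm_s, p) - attention_map(fm_t, p)).pow(2).mean()
```

Because channels are reduced before comparison, a thin student can be matched against a wide teacher without any projection layer.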


source

ActivationBoundaries


def ActivationBoundaries(
    fm_s:dict[str, torch.Tensor], # Student feature maps
    fm_t:dict[str, torch.Tensor], # Teacher feature maps
    m:float=2, # Boundary margin
    **kwargs, # Additional keyword arguments
)->torch.Tensor:

Boundary-based knowledge distillation (Heo et al.)
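A per-layer sketch of the activation-boundary idea: the student is trained to reproduce the *sign* of the teacher's pre-activations, with a margin, rather than their magnitudes (illustrative function, single tensor pair rather than the module's dict inputs):

```python
import torch

def ab_loss(fm_s: torch.Tensor, fm_t: torch.Tensor,
            m: float = 2.0) -> torch.Tensor:
    # Margin loss on activation boundaries (cf. Heo et al., 2019):
    # where the teacher is inactive (<= 0) the student is pushed below
    # -m; where the teacher is active (> 0) it is pushed above m.
    # Elements already past the margin on the correct side incur no loss.
    loss = ((fm_s + m).pow(2) * ((fm_s > -m) & (fm_t <= 0)).float()
            + (fm_s - m).pow(2) * ((fm_s < m) & (fm_t > 0)).float())
    return loss.mean()
```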


source

FitNet


def FitNet(
    fm_s:dict[str, torch.Tensor], # Student feature maps
    fm_t:dict[str, torch.Tensor], # Teacher feature maps
    **kwargs, # Additional keyword arguments
)->torch.Tensor:

FitNets: direct feature map matching (Romero et al.)
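A per-layer sketch of FitNet-style hint training (illustrative; the original paper inserts a learned regressor when student and teacher shapes differ, which this minimal version assumes away):

```python
import torch
import torch.nn.functional as F

def fitnet_loss(fm_s: torch.Tensor, fm_t: torch.Tensor) -> torch.Tensor:
    # Direct MSE between student and teacher feature maps (Romero et
    # al., 2015). Shapes must match; add a 1x1-conv regressor on the
    # student side when they do not.
    return F.mse_loss(fm_s, fm_t)
```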


source

Similarity


def Similarity(
    fm_s:dict[str, torch.Tensor], # Student feature maps
    fm_t:dict[str, torch.Tensor], # Teacher feature maps
    pred:torch.Tensor, # Student predictions (unused, for API consistency)
    p:int=2, # Normalization power
    **kwargs, # Additional keyword arguments
)->torch.Tensor:

Similarity-preserving knowledge distillation (Tung & Mori)
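A per-layer sketch of similarity-preserving distillation: instead of matching features directly, it matches the *batch-wise similarity structure* each network induces (illustrative function, single tensor pair):

```python
import torch
import torch.nn.functional as F

def similarity_loss(fm_s: torch.Tensor, fm_t: torch.Tensor,
                    p: int = 2) -> torch.Tensor:
    # Build batch similarity matrices G = f @ f.T from flattened
    # features, row-normalize them, and penalize their squared
    # Frobenius distance (Tung & Mori, 2019).
    b = fm_s.size(0)
    g_s = F.normalize(fm_s.flatten(1) @ fm_s.flatten(1).t(), p=p, dim=1)
    g_t = F.normalize(fm_t.flatten(1) @ fm_t.flatten(1).t(), p=p, dim=1)
    return (g_s - g_t).pow(2).sum() / (b * b)
```

Since only the b-by-b similarity matrices are compared, student and teacher feature dimensions are free to differ.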


See Also

Loss Selection Guide

| Loss | Best For | Complexity |
|------|----------|------------|
| SoftTarget | General distillation, logit matching | Low |
| Attention | When attention patterns matter | Low |
| FitNet | Intermediate feature matching | Medium |
| PKT | Probability distribution matching | Medium |
| RKD | Relational knowledge transfer | High |