Quantize Callback

Quantization-Aware Callback

Overview

Quantization-Aware Training (QAT) simulates low-precision inference during training, allowing the model to adapt to quantization effects. This produces more accurate quantized models than post-training quantization alone.
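
Concretely, the QAT machinery inserts "fake quantization" ops that round and clamp weights and activations to the INT8 grid during the forward pass while keeping everything in floating point. A minimal plain-tensor sketch of that round trip (the helper name and the fixed scale/zero_point are illustrative, not part of any API used below):

import torch

def fake_quantize(x, scale, zero_point, qmin=-128, qmax=127):
    # Quantize to the INT8 grid, then dequantize straight back to float,
    # so training sees the rounding/clamping error that real INT8 inference will introduce
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    return (q - zero_point) * scale

x = torch.randn(4)
print(x)
print(fake_quantize(x, scale=0.1, zero_point=0))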

Why Use QAT?

| Approach | Accuracy | Speed | When to Use |
|----------|----------|-------|-------------|
| Post-Training Quantization | Lower | Fast (no training) | Quick deployment, accuracy-tolerant |
| Quantization-Aware Training | Higher | Slower (requires training) | Production models, accuracy critical |
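
For contrast, post-training quantization needs no training loop at all. The sketch below uses PyTorch's built-in dynamic quantization and is purely illustrative of the "no training" row above; it is independent of QuantizeCallback, and because it only quantizes the layer types you list (here just nn.Linear), its size savings on a conv-heavy ResNet are modest:

import timm
import torch
import torch.nn as nn

model = timm.create_model('resnet34', pretrained=True)
# Dynamic PTQ: the listed layer types get INT8 weights up front;
# activations are quantized on the fly at inference time, with no retraining
quantized_model = torch.ao.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)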

Key Benefits

  • 4x smaller models - FP32 → INT8 reduces model size by ~75%
  • Faster inference - Integer operations are faster on most hardware
  • Maintained accuracy - QAT minimizes accuracy degradation
  • Hardware compatibility - INT8 is widely supported (CPU, mobile, edge)

1. Setup and Data

First, let’s load a dataset and create a model. We’ll use a pretrained ResNet-34 from timm.
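
A minimal sketch of a compatible two-class setup, assuming fastai's ImageDataLoaders over the Oxford-IIIT Pet dataset labeled cat vs. dog (any two-class image DataLoaders bound to dls will work; the batch size and transforms below are illustrative):

from fastai.vision.all import *
import timm

# Illustrative two-class dataset: Oxford-IIIT Pet images labeled cat vs. dog
path = untar_data(URLs.PETS)
def is_cat(x): return x[0].isupper()  # cat-breed filenames start with an uppercase letter
dls = ImageDataLoaders.from_name_func(
    path, get_image_files(path/"images"), valid_pct=0.2, seed=42,
    label_func=is_cat, item_tfms=Resize(224), bs=64,
)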

2. Training with QuantizeCallback

The QuantizeCallback automatically:

  1. Prepares the model for QAT by inserting fake quantization modules
  2. Trains with simulated INT8 precision
  3. Converts the model to actual INT8 after training

Simply add the callback to your training loop:

# Create a pretrained ResNet-34 from timm and wrap it in a fastai Learner
pretrained_resnet_34 = timm.create_model('resnet34', pretrained=True)
learn = Learner(dls, pretrained_resnet_34, metrics=accuracy)
# Replace the 1000-class ImageNet head with a 2-class head
learn.model.fc = nn.Linear(512, 2)
# Train with simulated INT8 precision; the model is converted to INT8 after training
learn.fit_one_cycle(3, 1e-3, cbs=QuantizeCallback())
| epoch | train_loss | valid_loss | accuracy | time |
|-------|------------|------------|----------|------|
| 0 | 0.497659 | 0.410839 | 0.805819 | 00:03 |
| 1 | 0.306967 | 0.280245 | 0.870771 | 00:03 |
| 2 | 0.204825 | 0.260572 | 0.887010 | 00:03 |

3. Evaluating the Quantized Model

Let’s compare the original and quantized models in terms of size and accuracy.

from tqdm import tqdm
import os
import torch

def get_model_size(model):
    # Serialize the state dict to a temporary file and measure its size on disk
    torch.save(model.state_dict(), "temp.p")
    size = os.path.getsize("temp.p") / 1e6  # Size in MB
    os.remove("temp.p")
    return size

def compute_validation_accuracy(model, valid_dataloader, device=None):
    # Set the model to evaluation mode
    model.eval()

    # Default to CPU: converted INT8 models run on CPU quantization backends
    if device is None:
        device = torch.device('cpu')

    # Move the model to the chosen device
    model = model.to(device)

    # Track correct predictions and total samples
    total_correct = 0
    total_samples = 0

    # Disable gradient computation for efficiency
    with torch.no_grad():
        for batch in tqdm(valid_dataloader):
            # Assuming the dataloader yields (inputs, labels) tuples;
            # adjust this if yours returns a different format
            inputs, labels = batch

            # Move inputs and labels to the same device as the model
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(inputs)

            # Get predictions (argmax along the class dimension)
            _, predicted = torch.max(outputs, 1)

            # Update counters
            total_samples += labels.size(0)
            total_correct += (predicted == labels).sum().item()

    # Compute accuracy as a percentage
    accuracy = (total_correct / total_samples) * 100

    return accuracy

Size Comparison

Create an original (non-quantized) model for comparison:

pretrained_resnet_34 = timm.create_model('resnet34', pretrained=True)
learn_original = Learner(dls, pretrained_resnet_34, metrics=accuracy)
learn_original.model.fc = nn.Linear(512, 2)

print(f'Size of the original model: {get_model_size(learn_original.model):.2f} MB')
print(f'Size of the quantized model: {get_model_size(learn.model):.2f} MB')
Size of the original model: 85.27 MB
Size of the quantized model: 21.51 MB

Accuracy Verification

Despite the 4x size reduction, the quantized model maintains good accuracy:

compute_validation_accuracy(learn.model, dls.valid)
100%|████████████████████████████████████████████████████| 24/24 [00:06<00:00,  3.99it/s]
88.70094722598105

4. Parameter Guide

QuantizeCallback Parameters

| Parameter | Default | Description |
|-----------|---------|-------------|
| backend | 'x86' | Quantization backend: 'x86' (Intel/AMD), 'qnnpack' (ARM/mobile), 'fbgemm' (server) |
| qconfig | None | Custom quantization config. If None, the backend default is used |

Backend Selection Guide

| Backend | Best For | Hardware |
|---------|----------|----------|
| 'x86' | Desktop/server CPUs | Intel, AMD processors |
| 'qnnpack' | Mobile deployment | ARM processors, Android, iOS |
| 'fbgemm' | Server inference | Facebook’s optimized backend |
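
Both parameters can be passed straight to the callback. A hedged sketch, assuming backend and qconfig are accepted as keyword arguments as listed in the parameter table above:

# Target a mobile/ARM deployment with the qnnpack backend
learn.fit_one_cycle(3, 1e-3, cbs=QuantizeCallback(backend='qnnpack'))

# Or supply an explicit QAT qconfig instead of the backend default
import torch.ao.quantization as tq
learn.fit_one_cycle(3, 1e-3, cbs=QuantizeCallback(qconfig=tq.get_default_qat_qconfig('qnnpack')))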

Tips for Best Results

  1. Train longer - QAT benefits from more epochs to adapt to quantization noise
  2. Lower learning rate - Use 1/10th the normal LR for fine-tuning (see the sketch after this list)
  3. Calibrate batch norm - Run a few batches through the model before final conversion
  4. Test on target hardware - Quantization benefits vary by platform
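
A sketch of tips 1 and 2 applied to the training run from section 2, assuming the same dls: more epochs and roughly a tenth of the earlier 1e-3 learning rate (the exact schedule is a judgment call for your dataset):

# Longer QAT run at a lower learning rate
learn = Learner(dls, timm.create_model('resnet34', pretrained=True), metrics=accuracy)
learn.model.fc = nn.Linear(512, 2)
learn.fit_one_cycle(10, 1e-4, cbs=QuantizeCallback())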

Summary

| Concept | Description |
|---------|-------------|
| Quantization-Aware Training | Training with simulated low precision to prepare the model for INT8 inference |
| QuantizeCallback | fastai callback that handles QAT preparation, training, and conversion |
| Size Reduction | ~4x smaller model (FP32 → INT8) |
| Backend | Target hardware platform ('x86', 'qnnpack', 'fbgemm') |
| Typical Use | Production deployment where model size and inference speed matter |

See Also

  • Quantizer - Lower-level quantization API with more control
  • Sparsifier - Combine with sparsification for even smaller models
  • BN Folding - Fold batch norm layers before quantization for best results
  • PyTorch Quantization Docs - Official PyTorch quantization documentation