Quantization-Aware Training (QAT) simulates low-precision inference during training, allowing the model to adapt to quantization effects. This produces more accurate quantized models than post-training quantization alone.
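To make "simulating low-precision inference" concrete, the sketch below shows the quantize-dequantize round trip (fake quantization) that QAT inserts into the forward pass. This is an illustrative sketch, not PyTorch's exact implementation, and the scale/zero-point values are arbitrary.

```python
import torch

# Illustrative fake-quantization round trip (not PyTorch's exact implementation):
# activations/weights are snapped to an INT8 grid and immediately dequantized,
# so training sees the same rounding and clamping error as INT8 inference.
def fake_quantize(x, scale, zero_point, qmin=-128, qmax=127):
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    return (q - zero_point) * scale

x = torch.randn(4)
print(fake_quantize(x, scale=0.05, zero_point=0))
```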
Why Use QAT?
| Approach | Accuracy | Speed | When to Use |
|---|---|---|---|
| Post-Training Quantization | Lower | Fast (no training) | Quick deployment, accuracy tolerant |
| Quantization-Aware Training | Higher | Slower (requires training) | Production models, accuracy critical |
Key Benefits
- 4x smaller models - FP32 → INT8 reduces model size by ~75%
- Faster inference - Integer operations are faster on most hardware
- Hardware compatibility - INT8 is widely supported (CPU, mobile, edge)
1. Setup and Data
First, let’s load a dataset and create a model. We’ll use a pretrained ResNet-34 from timm.
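A minimal setup sketch is shown below. The dataset (Imagenette), batch size, and data-loading arguments are assumptions for illustration; only the pretrained ResNet-34 from timm is given by the text above.

```python
# Setup sketch (assumptions: Imagenette as the dataset, fastai for data loading).
from fastai.vision.all import *
import timm

path = untar_data(URLs.IMAGENETTE_160)          # small image-classification dataset (assumed)
dls = ImageDataLoaders.from_folder(path, valid='val',
                                   item_tfms=Resize(160), bs=64)

# Pretrained ResNet-34 from timm, with the head sized to the number of classes
model = timm.create_model('resnet34', pretrained=True, num_classes=dls.c)
learn = Learner(dls, model, metrics=accuracy)
```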
2. Training with QuantizeCallback
The `QuantizeCallback` automatically:

1. Prepares the model for QAT by inserting fake-quantization modules
2. Trains with simulated INT8 precision
3. Converts the model to actual INT8 after training
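A minimal usage sketch follows. The import path is an assumption and depends on the library providing the callback; the three-epoch schedule matches the training log below.

```python
# Usage sketch: the import path below is assumed; QuantizeCallback itself is
# the callback described above.
from fasterai.quantize.all import QuantizeCallback  # assumed import path

learn.fit_one_cycle(3, cbs=QuantizeCallback())
```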
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.497659 | 0.410839 | 0.805819 | 00:03 |
| 1 | 0.306967 | 0.280245 | 0.870771 | 00:03 |
| 2 | 0.204825 | 0.260572 | 0.887010 | 00:03 |
3. Evaluating the Quantized Model
Let’s compare the original and quantized models in terms of size and accuracy.
```python
import os

import torch
from tqdm import tqdm

def get_model_size(model):
    torch.save(model.state_dict(), "temp.p")
    size = os.path.getsize("temp.p") / 1e6  # Size in MB
    os.remove("temp.p")
    return size

def compute_validation_accuracy(model, valid_dataloader, device=None):
    # Set the model to evaluation mode
    model.eval()
    # Default to CPU when no device is specified (quantized models run on CPU)
    if device is None:
        device = torch.device('cpu')
    # Move model to the specified device
    model = model.to(device)
    # Track correct predictions and total samples
    total_correct = 0
    total_samples = 0
    # Disable gradient computation for efficiency
    with torch.no_grad():
        for batch in tqdm(valid_dataloader):
            # Assuming batch is a tuple of (inputs, labels);
            # adjust this if your dataloader returns a different format
            inputs, labels = batch
            # Move inputs and labels to the same device as the model
            inputs = torch.Tensor(inputs).to(device)
            labels = labels.to(device)
            # Forward pass
            outputs = model(inputs)
            # Get predictions: argmax along the class dimension
            _, predicted = torch.max(outputs, 1)
            # Update counters
            total_samples += labels.size(0)
            total_correct += (predicted == labels).sum().item()
    # Compute accuracy as a percentage
    accuracy = (total_correct / total_samples) * 100
    return accuracy
```
Create an original (non-quantized) model for comparison:
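A sketch of this step is shown below, reusing the `learn_original` name that appears in the size comparison. It builds on the assumed setup from Section 1; the training schedule is also an assumption.

```python
# Sketch: build the same ResNet-34 without the quantization callback so it
# stays in FP32. For the size comparison the weights alone suffice; training
# it the same way (assumed schedule) allows a fair accuracy comparison.
model_fp32 = timm.create_model('resnet34', pretrained=True, num_classes=dls.c)
learn_original = Learner(dls, model_fp32, metrics=accuracy)
learn_original.fit_one_cycle(3)  # assumed: same number of epochs as the QAT run
```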
```python
print(f'Size of the original model: {get_model_size(learn_original.model):.2f} MB')
print(f'Size of the quantized model: {get_model_size(learn.model):.2f} MB')
```
Size of the original model: 85.27 MB
Size of the quantized model: 21.51 MB
Accuracy Verification
Despite the 4x size reduction, the quantized model maintains good accuracy:
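A usage sketch of the `compute_validation_accuracy` helper defined above; `learn.dls.valid` is assumed to be the validation DataLoader from the fastai setup.

```python
# Compare validation accuracy of the FP32 and INT8 models using the helper
# defined earlier. The dataloader name is an assumption about the setup.
acc_fp32 = compute_validation_accuracy(learn_original.model, learn.dls.valid)
acc_int8 = compute_validation_accuracy(learn.model, learn.dls.valid)

print(f'Original (FP32) accuracy:  {acc_fp32:.2f}%')
print(f'Quantized (INT8) accuracy: {acc_int8:.2f}%')
```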