ONNX Export Tutorial

Export compressed models to ONNX for deployment

Overview

After compressing a model with fasterai, you’ll want to deploy it. ONNX (Open Neural Network Exchange) is the standard format for deploying models across different platforms and runtimes.

Why Export to ONNX?

Benefit Description
Portability Run on any platform: servers, mobile, edge devices, browsers
Performance ONNX Runtime is highly optimized for inference
Quantization Apply additional INT8 quantization during export
No Python needed Deploy without Python dependencies

The Deployment Pipeline

Train → Compress (prune/sparsify/quantize) → Fold BN → Export ONNX → Deploy

This tutorial walks through the complete pipeline.

1. Setup and Training

First, let’s train a model that we’ll later compress and export.

# Download the Oxford-IIIT Pets dataset (cached after the first run) and
# collect every image file path under images/
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

def label_func(f): return f[0].isupper()

dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))
> **Note:** If dataloader creation fails with `AcceleratorError: CUDA error: out of memory` ("Could not do one pass in your dataloader"), your GPU is out of memory. Free GPU memory (restart the kernel or delete unused models), reduce the batch size (`bs=...`), or create the `DataLoaders` on CPU by passing `device='cpu'`, then re-run the cell.
# Build a ResNet-18 classifier and fine-tune it for a few epochs
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()  # train the whole network, not just the new head
learn.fit_one_cycle(3)

2. Compress the Model

Apply sparsification to reduce model size. You could also use pruning, quantization, or any combination.

# Zero out 50% of individual weights, selected per layer ('local') by final
# magnitude, ramping the sparsity up along a one-cycle schedule during training
sp_cb = SparsifyCallback(sparsity=50, granularity='weight', context='local', criteria=large_final, schedule=one_cycle)
learn.fit_one_cycle(2, cbs=sp_cb)

3. Fold BatchNorm Layers

Before export, fold batch normalization layers into convolutions for faster inference:

# Fuse each BatchNorm layer's statistics into the preceding convolution's
# weights, then switch to inference mode (the trailing semicolon silences
# the repr in a notebook)
model = BN_Folder().fold(learn.model)
model.eval();

4. Export to ONNX

Now export the optimized model to ONNX format:

# Create example input (batch_size=1, channels=3, height=64, width=64)
sample = torch.randn(1, 3, 64, 64)

# Export to ONNX
# Move the model to CPU first so the traced graph carries no device-specific state
onnx_path = export_onnx(model.cpu(), sample, "model.onnx")
print(f"Exported to: {onnx_path}")

Verify the Export

Always verify that the ONNX model produces the same outputs as the PyTorch model:

# Run the same sample through both the PyTorch model and the exported ONNX
# model and compare the outputs
is_valid = verify_onnx(model, onnx_path, sample)
print(f"Verification {'passed' if is_valid else 'FAILED'}: ONNX outputs {'match' if is_valid else 'do not match'} PyTorch!")

5. Export with INT8 Quantization

For even smaller models and faster inference, apply INT8 quantization during export:

# Dynamic quantization (no calibration data needed)
# Weights are stored as INT8; activation ranges are computed on the fly at runtime
quantized_path = export_onnx(
    model.cpu(), sample, "model_int8.onnx",
    quantize=True,
    quantize_mode="dynamic"
)
print(f"Exported quantized model to: {quantized_path}")

For better accuracy, use static quantization with calibration data:

# Static quantization with calibration
# Activation ranges are pre-computed from representative data, which
# typically preserves accuracy better than dynamic quantization
quantized_path = export_onnx(
    model, sample, "model_int8_static.onnx",
    quantize=True,
    quantize_mode="static",
    calibration_data=dls.train  # Use training data for calibration
)

6. Compare Model Sizes

import os

def get_size_mb(path):
    """Return the size of the file at `path` in megabytes (10^6 bytes)."""
    n_bytes = os.path.getsize(path)
    return n_bytes / 1e6

# Save PyTorch model for comparison
torch.save(model.state_dict(), "model.pt")

pt_size = get_size_mb("model.pt")
onnx_size = get_size_mb("model.onnx")
int8_size = get_size_mb(quantized_path)  # the most recent INT8 export above

print(f"PyTorch model:    {pt_size:.2f} MB")
print(f"ONNX model:       {onnx_size:.2f} MB")
print(f"ONNX INT8 model:  {int8_size:.2f} MB ({pt_size/int8_size:.1f}x smaller)")

7. Running Inference with ONNX Runtime

Use the ONNXModel wrapper for easy inference:

# Load the ONNX model
onnx_model = ONNXModel("model.onnx", device="cpu")

# Run inference
test_input = torch.randn(1, 3, 64, 64)
output = onnx_model(test_input)

print(f"Output shape: {output.shape}")
print(f"Predictions: {output}")

Benchmark Inference Speed

import time

def benchmark(fn, input_tensor, warmup=10, runs=100):
    """Return the average latency of `fn(input_tensor)` in milliseconds.

    Performs `warmup` untimed calls first so one-time costs (caches, lazy
    initialization) don't skew the result, then averages `runs` timed calls.
    """
    for _ in range(warmup):
        fn(input_tensor)

    t0 = time.perf_counter()
    for _ in range(runs):
        fn(input_tensor)
    total_s = time.perf_counter() - t0

    return total_s / runs * 1000

test_input = torch.randn(1, 3, 64, 64)

# PyTorch
model.eval()
with torch.no_grad():  # disable autograd bookkeeping for a fair inference timing
    pt_time = benchmark(model, test_input)

# ONNX
onnx_model = ONNXModel("model.onnx")
onnx_time = benchmark(onnx_model, test_input)

# ONNX INT8
onnx_int8 = ONNXModel(quantized_path)
int8_time = benchmark(onnx_int8, test_input)

print(f"PyTorch inference: {pt_time:.2f} ms")
print(f"ONNX inference:    {onnx_time:.2f} ms ({pt_time/onnx_time:.1f}x faster)")
print(f"ONNX INT8:         {int8_time:.2f} ms ({pt_time/int8_time:.1f}x faster)")

8. Parameter Reference

export_onnx Parameters

Parameter Default Description
model Required PyTorch model to export
sample Required Example input tensor (with batch dimension)
output_path Required Output .onnx file path
opset_version 18 ONNX opset version
quantize False Apply INT8 quantization after export
quantize_mode "dynamic" "dynamic" (no calibration) or "static"
calibration_data None DataLoader for static quantization
optimize True Run ONNX graph optimizer
dynamic_batch True Allow variable batch size at runtime

Quantization Mode Comparison

Mode Calibration Accuracy Speed Use Case
dynamic Not needed Good Fast export Quick deployment
static Required Better Slower export Production models

Summary

Step Tool Purpose
Compress SparsifyCallback, PruneCallback, etc. Reduce model complexity
Fold BN BN_Folder Eliminate batch norm overhead
Export export_onnx Convert to deployment format
Verify verify_onnx Ensure correctness
Quantize quantize=True Further reduce size (~4x)
Deploy ONNXModel Run inference

Complete Pipeline Example

from fasterai.sparse.all import *
from fasterai.misc.all import *
from fasterai.export.all import *

# 1. Train with compression
# (fill in the remaining SparsifyCallback arguments as in section 2)
sp_cb = SparsifyCallback(sparsity=50, granularity='weight', ...)
learn.fit_one_cycle(5, cbs=sp_cb)

# 2. Fold batch norm
model = BN_Folder().fold(learn.model)

# 3. Export with quantization
# The sample input fixes the traced shapes (batch, channels, height, width)
sample = torch.randn(1, 3, 224, 224)
path = export_onnx(model, sample, "model_int8.onnx", quantize=True)

# 4. Verify
assert verify_onnx(model, path, sample)

# 5. Deploy
onnx_model = ONNXModel(path)
output = onnx_model(input_tensor)

See Also