Merge pull request #1766 from huggingface/onnx_export

Add long-missing ONNX utils and export code
Ross Wightman 2023-04-12 09:06:16 -07:00 committed by GitHub
commit 3bd35c7004
6 changed files with 399 additions and 15 deletions

onnx_export.py (new file, 86 lines)

@@ -0,0 +1,86 @@
""" ONNX export script
Export PyTorch models as ONNX graphs.
This export script originally started as an adaptation of code snippets found at
https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
The default parameters work with PyTorch 1.6 and ONNX 1.7 and produce an optimal ONNX graph
for hosting in the ONNX runtime (see onnx_validate.py). To export an ONNX model compatible
with caffe2 (see caffe2_benchmark.py and caffe2_validate.py), the --keep-init and --aten-fallback
flags are currently required.
Older versions of PyTorch/ONNX (tested PyTorch 1.4, ONNX 1.5) do not need extra flags for
caffe2 compatibility, but they produce a model that isn't as fast running on ONNX runtime.
Most new releases of PyTorch and ONNX cause some sort of breakage in the export / usage of ONNX models.
Please do your research and search the ONNX and PyTorch issue trackers before asking me. Thanks.
Copyright 2020 Ross Wightman
"""
import argparse
import timm
from timm.utils.onnx import onnx_export
parser = argparse.ArgumentParser(description='PyTorch ONNX Export')
parser.add_argument('output', metavar='ONNX_FILE',
help='output model filename')
parser.add_argument('--model', '-m', metavar='MODEL', default='mobilenetv3_large_100',
help='model architecture (default: mobilenetv3_large_100)')
parser.add_argument('--opset', type=int, default=None,
help='ONNX opset to use (default: None, uses the PyTorch export default)')
parser.add_argument('--keep-init', action='store_true', default=False,
help='Keep initializers as input. Needed for Caffe2 compatible export in newer PyTorch/ONNX.')
parser.add_argument('--aten-fallback', action='store_true', default=False,
help='Fallback to ATEN ops. Helps fix AdaptiveAvgPool issue with Caffe2 in newer PyTorch/ONNX.')
parser.add_argument('--dynamic-size', action='store_true', default=False,
help='Export model with dynamic width/height. Not recommended for "tf" models with SAME padding.')
parser.add_argument('--check-forward', action='store_true', default=False,
help='Do a full check of torch vs onnx forward after export.')
parser.add_argument('-b', '--batch-size', default=1, type=int,
metavar='N', help='mini-batch size (default: 1)')
parser.add_argument('--img-size', default=None, type=int,
metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
help='Override std deviation of dataset')
parser.add_argument('--num-classes', type=int, default=1000,
help='Number of classes in dataset')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
help='path to checkpoint (default: none)')
def main():
args = parser.parse_args()
args.pretrained = True
if args.checkpoint:
args.pretrained = False
print("==> Creating PyTorch {} model".format(args.model))
# NOTE exportable=True flag disables autofn/jit scripted activations and uses Conv2dSameExport layers
# for models using SAME padding
model = timm.create_model(
args.model,
num_classes=args.num_classes,
in_chans=3,
pretrained=args.pretrained,
checkpoint_path=args.checkpoint,
exportable=True,
)
onnx_export(
model,
args.output,
opset=args.opset,
dynamic_size=args.dynamic_size,
aten_fallback=args.aten_fallback,
keep_initializers=args.keep_init,
check_forward=args.check_forward,
)
if __name__ == '__main__':
main()
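For readers who prefer the library API over the command line, a minimal sketch of roughly what the script above does; the model name, output path, and argument values are illustrative placeholders, and onnx_export is the helper added in timm/utils/onnx.py further down this diff:

import timm
from timm.utils.onnx import onnx_export

# exportable=True disables scripted/JIT activation fns and switches SAME-padded
# ('tf_' prefixed) models to export-friendly conv layers, as noted in main() above.
model = timm.create_model('mobilenetv3_large_100', pretrained=True, exportable=True)
model.eval()

onnx_export(
    model,
    'mobilenetv3_large_100.onnx',  # placeholder output path
    batch_size=2,                  # example input is randn(batch_size, *input_size)
    opset=None,                    # let torch.onnx pick its default opset
    dynamic_size=False,            # keep height/width static; batch stays dynamic
    check_forward=True,            # also compare torch vs onnxruntime outputs (needs onnxruntime)
)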

onnx_validate.py (new file, 110 lines)

@@ -0,0 +1,110 @@
""" ONNX-runtime validation script
This script was created to verify accuracy and performance of exported ONNX
models running with the onnxruntime. It utilizes the PyTorch dataloader/processing
pipeline for a fair comparison against the originals.
Copyright 2020 Ross Wightman
"""
import argparse
import numpy as np
import onnxruntime
from timm.data import create_loader, resolve_data_config, create_dataset
from timm.utils import AverageMeter
import time
parser = argparse.ArgumentParser(description='ONNX Validation')
parser.add_argument('data', metavar='DIR',
help='path to dataset')
parser.add_argument('--onnx-input', default='', type=str, metavar='PATH',
help='path to onnx model/weights file')
parser.add_argument('--onnx-output-opt', default='', type=str, metavar='PATH',
help='path to output optimized onnx graph')
parser.add_argument('--profile', action='store_true', default=False,
help='Enable profiler output.')
parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
help='number of data loading workers (default: 2)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=None, type=int,
metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
help='Override std deviation of dataset')
parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT',
help='Override default crop pct of 0.875')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
metavar='N', help='print frequency (default: 10)')
def main():
args = parser.parse_args()
args.gpu_id = 0
# Set graph optimization level
sess_options = onnxruntime.SessionOptions()
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
if args.profile:
sess_options.enable_profiling = True
if args.onnx_output_opt:
sess_options.optimized_model_filepath = args.onnx_output_opt
session = onnxruntime.InferenceSession(args.onnx_input, sess_options)
data_config = resolve_data_config(vars(args))
loader = create_loader(
create_dataset('', args.data),
input_size=data_config['input_size'],
batch_size=args.batch_size,
use_prefetcher=False,
interpolation=data_config['interpolation'],
mean=data_config['mean'],
std=data_config['std'],
num_workers=args.workers,
crop_pct=data_config['crop_pct']
)
input_name = session.get_inputs()[0].name
batch_time = AverageMeter()
top1 = AverageMeter()
top5 = AverageMeter()
end = time.time()
for i, (input, target) in enumerate(loader):
# run the net and return prediction
output = session.run([], {input_name: input.data.numpy()})
output = output[0]
# measure accuracy and record loss
prec1, prec5 = accuracy_np(output, target.numpy())
top1.update(prec1.item(), input.size(0))
top5.update(prec5.item(), input.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % args.print_freq == 0:
print(
f'Test: [{i}/{len(loader)}]\t'
f'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {input.size(0) / batch_time.avg:.3f}/s, '
f'{1000 * batch_time.avg / input.size(0):.3f} ms/sample) \t'
f'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
f'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'
)
print(f' * Prec@1 {top1.avg:.3f} ({100-top1.avg:.3f}) Prec@5 {top5.avg:.3f} ({100.-top5.avg:.3f})')
def accuracy_np(output, target):
max_indices = np.argsort(output, axis=1)[:, ::-1]
top5 = 100 * np.equal(max_indices[:, :5], target[:, np.newaxis]).sum(axis=1).mean()
top1 = 100 * np.equal(max_indices[:, 0], target).mean()
return top1, top5
if __name__ == '__main__':
main()
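As a quick sanity check of the metric used above, a tiny worked example of the same top-1/top-5 computation on invented logits; accuracy_np is restated so the snippet runs standalone:

import numpy as np

def accuracy_np(output, target):
    # identical to the function in onnx_validate.py above
    max_indices = np.argsort(output, axis=1)[:, ::-1]
    top5 = 100 * np.equal(max_indices[:, :5], target[:, np.newaxis]).sum(axis=1).mean()
    top1 = 100 * np.equal(max_indices[:, 0], target).mean()
    return top1, top5

# Two samples, ten classes. Sample 0 is predicted correctly (class 3 has the top
# logit); sample 1 ranks its true class (7) third, so it only counts for top-5.
output = np.zeros((2, 10), dtype=np.float32)
output[0, 3] = 5.0
output[1, [1, 2, 7]] = [3.0, 2.0, 1.0]
target = np.array([3, 7])

top1, top5 = accuracy_np(output, target)
print(f'top1={top1:.1f} top5={top5:.1f}')  # top1=50.0 top5=100.0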

timm/layers/conv2d_same.py

@@ -7,12 +7,22 @@ import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional
-from .padding import pad_same, get_padding_value
+from .config import is_exportable, is_scriptable
+from .padding import pad_same, pad_same_arg, get_padding_value
+_USE_EXPORT_CONV = False
def conv2d_same(
-x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1),
-padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1):
+x,
+weight: torch.Tensor,
+bias: Optional[torch.Tensor] = None,
+stride: Tuple[int, int] = (1, 1),
+padding: Tuple[int, int] = (0, 0),
+dilation: Tuple[int, int] = (1, 1),
+groups: int = 1,
+):
x = pad_same(x, weight.shape[-2:], stride, dilation)
return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
@@ -21,13 +31,66 @@ class Conv2dSame(nn.Conv2d):
""" Tensorflow like 'SAME' convolution wrapper for 2D convolutions
"""
-def __init__(self, in_channels, out_channels, kernel_size, stride=1,
-padding=0, dilation=1, groups=1, bias=True):
+def __init__(
+self,
+in_channels,
+out_channels,
+kernel_size,
+stride=1,
+padding=0,
+dilation=1,
+groups=1,
+bias=True,
+):
super(Conv2dSame, self).__init__(
-in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
+in_channels, out_channels, kernel_size,
+stride, 0, dilation, groups, bias,
+)
def forward(self, x):
-return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+return conv2d_same(
+x, self.weight, self.bias,
+self.stride, self.padding, self.dilation, self.groups,
+)
class Conv2dSameExport(nn.Conv2d):
""" ONNX export friendly Tensorflow like 'SAME' convolution wrapper for 2D convolutions
NOTE: This does not currently work with torch.jit.script
"""
# pylint: disable=unused-argument
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
):
super(Conv2dSameExport, self).__init__(
in_channels, out_channels, kernel_size,
stride, 0, dilation, groups, bias,
)
self.pad = None
self.pad_input_size = (0, 0)
def forward(self, x):
input_size = x.size()[-2:]
if self.pad is None:
pad_arg = pad_same_arg(input_size, self.weight.size()[-2:], self.stride, self.dilation)
self.pad = nn.ZeroPad2d(pad_arg)
self.pad_input_size = input_size
x = self.pad(x)
return F.conv2d(
x, self.weight, self.bias,
self.stride, self.padding, self.dilation, self.groups,
)
def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
@@ -35,7 +98,12 @@ def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
kwargs.setdefault('bias', False)
padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)
if is_dynamic:
-return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
+if _USE_EXPORT_CONV and is_exportable():
+# older PyTorch ver needed this to export same padding reasonably
+assert not is_scriptable()  # Conv2dSameExport does not work with jit
+return Conv2dSameExport(in_chs, out_chs, kernel_size, **kwargs)
+else:
+return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
else:
return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
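To make the difference between the two wrappers concrete, a minimal sketch; the import path timm.layers.conv2d_same is an assumption based on the relative imports above, and the sizes are illustrative:

import torch
from timm.layers.conv2d_same import Conv2dSame, Conv2dSameExport  # assumed module path

# Conv2dSame recomputes SAME padding from the input shape on every forward pass,
# so it copes with varying input sizes but traces data-dependent padding ops.
conv_dyn = Conv2dSame(3, 8, kernel_size=3, stride=2)

# Conv2dSameExport freezes the padding into an nn.ZeroPad2d on the first call,
# which exports cleanly to ONNX but pins the padding to that input resolution.
conv_exp = Conv2dSameExport(3, 8, kernel_size=3, stride=2)
conv_exp.load_state_dict(conv_dyn.state_dict())  # share weights for comparison

x = torch.randn(1, 3, 224, 224)
print(conv_dyn(x).shape)                         # torch.Size([1, 8, 112, 112])
print(torch.allclose(conv_dyn(x), conv_exp(x)))  # True: same math once padding is fixed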

timm/layers/padding.py

@@ -5,6 +5,7 @@ Hacked together by / Copyright 2020 Ross Wightman
import math
from typing import List, Tuple
import torch
import torch.nn.functional as F
@@ -15,8 +16,11 @@ def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution
-def get_same_padding(x: int, k: int, s: int, d: int):
-return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)
+def get_same_padding(x: int, kernel_size: int, stride: int, dilation: int):
+if isinstance(x, torch.Tensor):
+return torch.clamp(((x / stride).ceil() - 1) * stride + (kernel_size - 1) * dilation + 1 - x, min=0)
+else:
+return max((math.ceil(x / stride) - 1) * stride + (kernel_size - 1) * dilation + 1 - x, 0)
# Can SAME padding for given args be done statically?
@@ -24,12 +28,31 @@ def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
def pad_same_arg(
input_size: List[int],
kernel_size: List[int],
stride: List[int],
dilation: List[int] = (1, 1),
) -> List[int]:
ih, iw = input_size
kh, kw = kernel_size
pad_h = get_same_padding(ih, kh, stride[0], dilation[0])
pad_w = get_same_padding(iw, kw, stride[1], dilation[1])
return [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
# Dynamically pad input x with 'SAME' padding for conv with specified args
-def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0):
+def pad_same(
+x,
+kernel_size: List[int],
+stride: List[int],
+dilation: List[int] = (1, 1),
+value: float = 0,
+):
ih, iw = x.size()[-2:]
-pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1])
-if pad_h > 0 or pad_w > 0:
-x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value)
+pad_h = get_same_padding(ih, kernel_size[0], stride[0], dilation[0])
+pad_w = get_same_padding(iw, kernel_size[1], stride[1], dilation[1])
+x = F.pad(x, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2), value=value)
return x
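A small worked example of the asymmetric SAME-padding arithmetic above; the helpers below restate the formulas from this file so the numbers can be checked standalone:

import math

def get_same_padding(x, kernel_size, stride, dilation=1):
    # total padding needed along one dimension (the int path of the function above)
    return max((math.ceil(x / stride) - 1) * stride + (kernel_size - 1) * dilation + 1 - x, 0)

def pad_same_arg(input_size, kernel_size, stride, dilation=(1, 1)):
    ih, iw = input_size
    kh, kw = kernel_size
    pad_h = get_same_padding(ih, kh, stride[0], dilation[0])
    pad_w = get_same_padding(iw, kw, stride[1], dilation[1])
    # (left, right, top, bottom) ordering, as consumed by F.pad / nn.ZeroPad2d
    return [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]

# 224x224 input, 3x3 kernel, stride 2: one extra pixel per dimension, placed on
# the right/bottom, matching TensorFlow 'SAME' behaviour.
print(pad_same_arg((224, 224), (3, 3), (2, 2)))  # [0, 1, 0, 1]

# Same input and kernel at stride 1: two pixels per dimension, split evenly.
print(pad_same_arg((224, 224), (3, 3), (1, 1)))  # [1, 1, 1, 1]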

timm/layers/pos_embed_sincos.py

@@ -8,6 +8,8 @@ from typing import List, Tuple, Optional, Union
import torch
from torch import nn as nn
from .trace_utils import _assert
def pixel_freq_bands(
num_bands: int,
@@ -425,7 +427,7 @@ class RotaryEmbeddingCat(nn.Module):
def get_embed(self, shape: Optional[List[int]] = None):
if self.bands is not None:
# rebuild embeddings every call, use if target shape changes
-assert shape is not None
+_assert(shape is not None, 'valid shape needed')
embeds = build_rotary_pos_embed(
shape,
self.bands,

timm/utils/onnx.py (new file, 95 lines)

@@ -0,0 +1,95 @@
from typing import Optional, Tuple, List
import torch
def onnx_forward(onnx_file, example_input):
import onnxruntime
sess_options = onnxruntime.SessionOptions()
session = onnxruntime.InferenceSession(onnx_file, sess_options)
input_name = session.get_inputs()[0].name
output = session.run([], {input_name: example_input.numpy()})
output = output[0]
return output
def onnx_export(
model: torch.nn.Module,
output_file: str,
example_input: Optional[torch.Tensor] = None,
training: bool = False,
verbose: bool = False,
check: bool = True,
check_forward: bool = False,
batch_size: int = 64,
input_size: Optional[Tuple[int, int, int]] = None,
opset: Optional[int] = None,
dynamic_size: bool = False,
aten_fallback: bool = False,
keep_initializers: Optional[bool] = None,
input_names: Optional[List[str]] = None,
output_names: Optional[List[str]] = None,
):
import onnx
if training:
training_mode = torch.onnx.TrainingMode.TRAINING
model.train()
else:
training_mode = torch.onnx.TrainingMode.EVAL
model.eval()
if example_input is None:
if not input_size:
assert hasattr(model, 'default_cfg')
input_size = model.default_cfg.get('input_size')
example_input = torch.randn((batch_size,) + input_size, requires_grad=training)
# Run model once before export trace, sets padding for models with Conv2dSameExport. This means
# that the padding for models with Conv2dSameExport (most models with tf_ prefix) is fixed for
# the input img_size specified in this script.
# Opset >= 11 should allow for dynamic padding, however I cannot get it to work due to
# issues in the tracing of the dynamic padding or errors attempting to export the model after jit
# scripting it (an approach that should work). Perhaps in a future PyTorch or ONNX versions...
original_out = model(example_input)
input_names = input_names or ["input0"]
output_names = output_names or ["output0"]
dynamic_axes = {'input0': {0: 'batch'}, 'output0': {0: 'batch'}}
if dynamic_size:
dynamic_axes['input0'][2] = 'height'
dynamic_axes['input0'][3] = 'width'
if aten_fallback:
export_type = torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
else:
export_type = torch.onnx.OperatorExportTypes.ONNX
torch_out = torch.onnx._export(
model,
example_input,
output_file,
training=training_mode,
export_params=True,
verbose=verbose,
input_names=input_names,
output_names=output_names,
keep_initializers_as_inputs=keep_initializers,
dynamic_axes=dynamic_axes,
opset_version=opset,
operator_export_type=export_type
)
if check:
onnx_model = onnx.load(output_file)
onnx.checker.check_model(onnx_model, full_check=True) # assuming throw on error
if check_forward and not training:
import numpy as np
onnx_out = onnx_forward(output_file, example_input)
np.testing.assert_almost_equal(torch_out.data.numpy(), onnx_out, decimal=3)
np.testing.assert_almost_equal(original_out.data.numpy(), torch_out.data.numpy(), decimal=5)
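Since onnx_export always marks the batch dimension (and, with dynamic_size=True, height/width) as dynamic axes, an exported graph can be run at different batch sizes in onnxruntime; a minimal sketch with a placeholder file name and an assumed 3x224x224 input:

import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession('model.onnx')  # placeholder path to an exported model
input_name = session.get_inputs()[0].name

# The 'batch' axis was exported as dynamic, so different batch sizes work without
# re-exporting; height/width stay fixed unless dynamic_size was used at export time.
for bs in (1, 4):
    x = np.random.randn(bs, 3, 224, 224).astype(np.float32)
    out = session.run(None, {input_name: x})[0]
    print(out.shape)  # (bs, num_classes)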