diff --git a/onnx_export.py b/onnx_export.py new file mode 100644 index 00000000..54f8f352 --- /dev/null +++ b/onnx_export.py @@ -0,0 +1,86 @@ +""" ONNX export script + +Export PyTorch models as ONNX graphs. + +This export script originally started as an adaptation of code snippets found at +https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html + +The default parameters work with PyTorch 1.6 and ONNX 1.7 and produce an optimal ONNX graph +for hosting in the ONNX runtime (see onnx_validate.py). To export an ONNX model compatible +with caffe2 (see caffe2_benchmark.py and caffe2_validate.py), the --keep-init and --aten-fallback +flags are currently required. + +Older versions of PyTorch/ONNX (tested PyTorch 1.4, ONNX 1.5) do not need extra flags for +caffe2 compatibility, but they produce a model that isn't as fast running on ONNX runtime. + +Most new releases of PyTorch and ONNX cause some sort of breakage in the export / usage of ONNX models. +Please do your research and search ONNX and PyTorch issue tracker before asking me. Thanks. + +Copyright 2020 Ross Wightman +""" +import argparse + +import timm +from timm.utils.onnx import onnx_export + +parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation') +parser.add_argument('output', metavar='ONNX_FILE', +                    help='output model filename') +parser.add_argument('--model', '-m', metavar='MODEL', default='mobilenetv3_large_100', +                    help='model architecture (default: mobilenetv3_large_100)') +parser.add_argument('--opset', type=int, default=None, +                    help='ONNX opset to use (default: 10)') +parser.add_argument('--keep-init', action='store_true', default=False, +                    help='Keep initializers as input. Needed for Caffe2 compatible export in newer PyTorch/ONNX.') +parser.add_argument('--aten-fallback', action='store_true', default=False, +                    help='Fallback to ATEN ops. 
Helps fix AdaptiveAvgPool issue with Caffe2 in newer PyTorch/ONNX.') +parser.add_argument('--dynamic-size', action='store_true', default=False, +                    help='Export model with dynamic width/height. Not recommended for "tf" models with SAME padding.') +parser.add_argument('--check-forward', action='store_true', default=False, +                    help='Do a full check of torch vs onnx forward after export.') +parser.add_argument('-b', '--batch-size', default=1, type=int, +                    metavar='N', help='mini-batch size (default: 1)') +parser.add_argument('--img-size', default=None, type=int, +                    metavar='N', help='Input image dimension, uses model default if empty') +parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', +                    help='Override mean pixel value of dataset') +parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', +                    help='Override std deviation of the dataset') +parser.add_argument('--num-classes', type=int, default=1000, +                    help='Number of classes in dataset') +parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', +                    help='path to checkpoint (default: none)') + + +def main(): +    args = parser.parse_args() + +    args.pretrained = True +    if args.checkpoint: +        args.pretrained = False + +    print("==> Creating PyTorch {} model".format(args.model)) +    # NOTE exportable=True flag disables autofn/jit scripted activations and uses Conv2dSameExport layers +    # for models using SAME padding +    model = timm.create_model( +        args.model, +        num_classes=args.num_classes, +        in_chans=3, +        pretrained=args.pretrained, +        checkpoint_path=args.checkpoint, +        exportable=True, +    ) + +    onnx_export( +        model, +        args.output, +        opset=args.opset, +        dynamic_size=args.dynamic_size, +        aten_fallback=args.aten_fallback, +        keep_initializers=args.keep_init, +        check_forward=args.check_forward, +    ) + + +if __name__ == '__main__': +    main() diff --git a/onnx_validate.py b/onnx_validate.py new file mode 100644 index 00000000..acd1c189 --- /dev/null +++ b/onnx_validate.py @@ -0,0 
+1,110 @@ +""" ONNX-runtime validation script + +This script was created to verify accuracy and performance of exported ONNX +models running with the onnxruntime. It utilizes the PyTorch dataloader/processing +pipeline for a fair comparison against the originals. + +Copyright 2020 Ross Wightman +""" +import argparse +import numpy as np +import onnxruntime +from timm.data import create_loader, resolve_data_config, create_dataset +from timm.utils import AverageMeter +import time + +parser = argparse.ArgumentParser(description='ONNX Validation') +parser.add_argument('data', metavar='DIR', +                    help='path to dataset') +parser.add_argument('--onnx-input', default='', type=str, metavar='PATH', +                    help='path to onnx model/weights file') +parser.add_argument('--onnx-output-opt', default='', type=str, metavar='PATH', +                    help='path to output optimized onnx graph') +parser.add_argument('--profile', action='store_true', default=False, +                    help='Enable profiler output.') +parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', +                    help='number of data loading workers (default: 2)') +parser.add_argument('-b', '--batch-size', default=256, type=int, +                    metavar='N', help='mini-batch size (default: 256)') +parser.add_argument('--img-size', default=None, type=int, +                    metavar='N', help='Input image dimension, uses model default if empty') +parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', +                    help='Override mean pixel value of dataset') +parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', +                    help='Override std deviation of the dataset') +parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT', +                    help='Override default crop pct of 0.875') +parser.add_argument('--interpolation', default='', type=str, metavar='NAME', +                    help='Image resize interpolation type (overrides model)') +parser.add_argument('--print-freq', '-p', default=10, type=int, +                    metavar='N', help='print frequency (default: 10)') + + +def 
main(): + args = parser.parse_args() + args.gpu_id = 0 + + # Set graph optimization level + sess_options = onnxruntime.SessionOptions() + sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + if args.profile: + sess_options.enable_profiling = True + if args.onnx_output_opt: + sess_options.optimized_model_filepath = args.onnx_output_opt + + session = onnxruntime.InferenceSession(args.onnx_input, sess_options) + + data_config = resolve_data_config(vars(args)) + loader = create_loader( + create_dataset('', args.data), + input_size=data_config['input_size'], + batch_size=args.batch_size, + use_prefetcher=False, + interpolation=data_config['interpolation'], + mean=data_config['mean'], + std=data_config['std'], + num_workers=args.workers, + crop_pct=data_config['crop_pct'] + ) + + input_name = session.get_inputs()[0].name + + batch_time = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + end = time.time() + for i, (input, target) in enumerate(loader): + # run the net and return prediction + output = session.run([], {input_name: input.data.numpy()}) + output = output[0] + + # measure accuracy and record loss + prec1, prec5 = accuracy_np(output, target.numpy()) + top1.update(prec1.item(), input.size(0)) + top5.update(prec5.item(), input.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + print( + f'Test: [{i}/{len(loader)}]\t' + f'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {input.size(0) / batch_time.avg:.3f}/s, ' + f'{100 * batch_time.avg / input.size(0):.3f} ms/sample) \t' + f'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' + f'Prec@5 {top5.val:.3f} ({top5.avg:.3f})' + ) + + print(f' * Prec@1 {top1.avg:.3f} ({100-top1.avg:.3f}) Prec@5 {top5.avg:.3f} ({100.-top5.avg:.3f})') + + +def accuracy_np(output, target): + max_indices = np.argsort(output, axis=1)[:, ::-1] + top5 = 100 * np.equal(max_indices[:, :5], target[:, 
np.newaxis]).sum(axis=1).mean() + top1 = 100 * np.equal(max_indices[:, 0], target).mean() + return top1, top5 + + +if __name__ == '__main__': + main() diff --git a/timm/layers/conv2d_same.py b/timm/layers/conv2d_same.py index 75f0f98d..7ac85b79 100644 --- a/timm/layers/conv2d_same.py +++ b/timm/layers/conv2d_same.py @@ -7,12 +7,22 @@ import torch.nn as nn import torch.nn.functional as F from typing import Tuple, Optional -from .padding import pad_same, get_padding_value +from .config import is_exportable, is_scriptable +from .padding import pad_same, pad_same_arg, get_padding_value + + +_USE_EXPORT_CONV = False def conv2d_same( - x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1), - padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1): + x, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + stride: Tuple[int, int] = (1, 1), + padding: Tuple[int, int] = (0, 0), + dilation: Tuple[int, int] = (1, 1), + groups: int = 1, +): x = pad_same(x, weight.shape[-2:], stride, dilation) return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) @@ -21,13 +31,66 @@ class Conv2dSame(nn.Conv2d): """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions """ - def __init__(self, in_channels, out_channels, kernel_size, stride=1, - padding=0, dilation=1, groups=1, bias=True): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): super(Conv2dSame, self).__init__( - in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) + in_channels, out_channels, kernel_size, + stride, 0, dilation, groups, bias, + ) def forward(self, x): - return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + return conv2d_same( + x, self.weight, self.bias, + self.stride, self.padding, self.dilation, self.groups, + ) + + +class Conv2dSameExport(nn.Conv2d): + 
""" ONNX export friendly Tensorflow like 'SAME' convolution wrapper for 2D convolutions + + NOTE: This does not currently work with torch.jit.script + """ + + # pylint: disable=unused-argument + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + super(Conv2dSameExport, self).__init__( + in_channels, out_channels, kernel_size, + stride, 0, dilation, groups, bias, + ) + self.pad = None + self.pad_input_size = (0, 0) + + def forward(self, x): + input_size = x.size()[-2:] + if self.pad is None: + pad_arg = pad_same_arg(input_size, self.weight.size()[-2:], self.stride, self.dilation) + self.pad = nn.ZeroPad2d(pad_arg) + self.pad_input_size = input_size + + x = self.pad(x) + return F.conv2d( + x, self.weight, self.bias, + self.stride, self.padding, self.dilation, self.groups, + ) def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): @@ -35,7 +98,12 @@ def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): kwargs.setdefault('bias', False) padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) if is_dynamic: - return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) + if _USE_EXPORT_CONV and is_exportable(): + # older PyTorch ver needed this to export same padding reasonably + assert not is_scriptable() # Conv2DSameExport does not work with jit + return Conv2dSameExport(in_chs, out_chs, kernel_size, **kwargs) + else: + return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) else: return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) diff --git a/timm/layers/padding.py b/timm/layers/padding.py index 34afc37c..d6971526 100644 --- a/timm/layers/padding.py +++ b/timm/layers/padding.py @@ -5,6 +5,7 @@ Hacked together by / Copyright 2020 Ross Wightman import math from typing import List, Tuple +import torch import torch.nn.functional as F @@ -15,8 +16,11 @@ def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> in 
# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution -def get_same_padding(x: int, k: int, s: int, d: int): - return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0) +def get_same_padding(x: int, kernel_size: int, stride: int, dilation: int): + if isinstance(x, torch.Tensor): + return torch.clamp(((x / stride).ceil() - 1) * stride + (kernel_size - 1) * dilation + 1 - x, min=0) + else: + return max((math.ceil(x / stride) - 1) * stride + (kernel_size - 1) * dilation + 1 - x, 0) # Can SAME padding for given args be done statically? @@ -24,12 +28,31 @@ def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_): return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 +def pad_same_arg( + input_size: List[int], + kernel_size: List[int], + stride: List[int], + dilation: List[int] = (1, 1), +) -> List[int]: + ih, iw = input_size + kh, kw = kernel_size + pad_h = get_same_padding(ih, kh, stride[0], dilation[0]) + pad_w = get_same_padding(iw, kw, stride[1], dilation[1]) + return [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] + + # Dynamically pad input x with 'SAME' padding for conv with specified args -def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0): +def pad_same( + x, + kernel_size: List[int], + stride: List[int], + dilation: List[int] = (1, 1), + value: float = 0, +): ih, iw = x.size()[-2:] - pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1]) - if pad_h > 0 or pad_w > 0: - x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value) + pad_h = get_same_padding(ih, kernel_size[0], stride[0], dilation[0]) + pad_w = get_same_padding(iw, kernel_size[1], stride[1], dilation[1]) + x = F.pad(x, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2), value=value) return x diff --git a/timm/layers/pos_embed_sincos.py b/timm/layers/pos_embed_sincos.py index 7f340021..c7beb3d6 100644 --- 
a/timm/layers/pos_embed_sincos.py +++ b/timm/layers/pos_embed_sincos.py @@ -8,6 +8,8 @@ from typing import List, Tuple, Optional, Union import torch from torch import nn as nn +from .trace_utils import _assert + def pixel_freq_bands( num_bands: int, @@ -425,7 +427,7 @@ class RotaryEmbeddingCat(nn.Module): def get_embed(self, shape: Optional[List[int]] = None): if self.bands is not None: # rebuild embeddings every call, use if target shape changes - assert shape is not None + _assert(shape is not None, 'valid shape needed') embeds = build_rotary_pos_embed( shape, self.bands, diff --git a/timm/utils/onnx.py b/timm/utils/onnx.py new file mode 100644 index 00000000..58cb2d2a --- /dev/null +++ b/timm/utils/onnx.py @@ -0,0 +1,95 @@ +from typing import Optional, Tuple, List + +import torch + + +def onnx_forward(onnx_file, example_input): + import onnxruntime + + sess_options = onnxruntime.SessionOptions() + session = onnxruntime.InferenceSession(onnx_file, sess_options) + input_name = session.get_inputs()[0].name + output = session.run([], {input_name: example_input.numpy()}) + output = output[0] + return output + + +def onnx_export( + model: torch.nn.Module, + output_file: str, + example_input: Optional[torch.Tensor] = None, + training: bool = False, + verbose: bool = False, + check: bool = True, + check_forward: bool = False, + batch_size: int = 64, + input_size: Tuple[int, int, int] = None, + opset: Optional[int] = None, + dynamic_size: bool = False, + aten_fallback: bool = False, + keep_initializers: Optional[bool] = None, + input_names: List[str] = None, + output_names: List[str] = None, +): + import onnx + + if training: + training_mode = torch.onnx.TrainingMode.TRAINING + model.train() + else: + training_mode = torch.onnx.TrainingMode.EVAL + model.eval() + + if example_input is None: + if not input_size: + assert hasattr(model, 'default_cfg') + input_size = model.default_cfg.get('input_size') + example_input = torch.randn((batch_size,) + input_size, 
requires_grad=training) + + # Run model once before export trace, sets padding for models with Conv2dSameExport. This means + # that the padding for models with Conv2dSameExport (most models with tf_ prefix) is fixed for + # the input img_size specified in this script. + + # Opset >= 11 should allow for dynamic padding, however I cannot get it to work due to + # issues in the tracing of the dynamic padding or errors attempting to export the model after jit + # scripting it (an approach that should work). Perhaps in a future PyTorch or ONNX versions... + original_out = model(example_input) + + input_names = input_names or ["input0"] + output_names = output_names or ["output0"] + + dynamic_axes = {'input0': {0: 'batch'}, 'output0': {0: 'batch'}} + if dynamic_size: + dynamic_axes['input0'][2] = 'height' + dynamic_axes['input0'][3] = 'width' + + if aten_fallback: + export_type = torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK + else: + export_type = torch.onnx.OperatorExportTypes.ONNX + + torch_out = torch.onnx._export( + model, + example_input, + output_file, + training=training_mode, + export_params=True, + verbose=verbose, + input_names=input_names, + output_names=output_names, + keep_initializers_as_inputs=keep_initializers, + dynamic_axes=dynamic_axes, + opset_version=opset, + operator_export_type=export_type + ) + + if check: + onnx_model = onnx.load(output_file) + onnx.checker.check_model(onnx_model, full_check=True) # assuming throw on error + if check_forward and not training: + import numpy as np + onnx_out = onnx_forward(output_file, example_input) + np.testing.assert_almost_equal(torch_out.data.numpy(), onnx_out, decimal=3) + np.testing.assert_almost_equal(original_out.data.numpy(), torch_out.data.numpy(), decimal=5) + +