103 lines
4.6 KiB
Python
103 lines
4.6 KiB
Python
""" ONNX export script
|
|
|
|
Export PyTorch models as ONNX graphs.
|
|
|
|
This export script originally started as an adaptation of code snippets found at
|
|
https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
|
|
|
|
The default parameters work with PyTorch 1.6 and ONNX 1.7 and produce an optimal ONNX graph
|
|
for hosting in the ONNX runtime (see onnx_validate.py). To export an ONNX model compatible
|
|
with caffe2 (see caffe2_benchmark.py and caffe2_validate.py), the --keep-init and --aten-fallback
|
|
flags are currently required.
|
|
|
|
Older versions of PyTorch/ONNX (tested PyTorch 1.4, ONNX 1.5) do not need extra flags for
|
|
caffe2 compatibility, but they produce a model that isn't as fast running on ONNX runtime.
|
|
|
|
Most new releases of PyTorch and ONNX cause some sort of breakage in the export / usage of ONNX models.
|
|
Please do your research and search the ONNX and PyTorch issue trackers before asking me. Thanks.
|
|
|
|
Copyright 2020 Ross Wightman
|
|
"""
|
|
import argparse
|
|
|
|
import timm
|
|
from timm.utils.model import reparameterize_model
|
|
from timm.utils.onnx import onnx_export
|
|
|
|
# Command-line interface for the ONNX export script; parsed by main().
parser = argparse.ArgumentParser(description='PyTorch ImageNet ONNX Export')
parser.add_argument('output', metavar='ONNX_FILE',
                    help='output model filename')
parser.add_argument('--model', '-m', metavar='MODEL', default='mobilenetv3_large_100',
                    help='model architecture (default: mobilenetv3_large_100)')
# default=None is forwarded as-is to onnx_export, so the exporter picks the opset.
parser.add_argument('--opset', type=int, default=None,
                    help='ONNX opset to use (default: None, use the exporter default)')
parser.add_argument('--keep-init', action='store_true', default=False,
                    help='Keep initializers as input. Needed for Caffe2 compatible export in newer PyTorch/ONNX.')
parser.add_argument('--aten-fallback', action='store_true', default=False,
                    help='Fallback to ATEN ops. Helps fix AdaptiveAvgPool issue with Caffe2 in newer PyTorch/ONNX.')
parser.add_argument('--dynamic-size', action='store_true', default=False,
                    help='Export model with dynamic width/height. Not recommended for "tf" models with SAME padding.')
parser.add_argument('--check-forward', action='store_true', default=False,
                    help='Do a full check of torch vs onnx forward after export.')
parser.add_argument('-b', '--batch-size', default=1, type=int,
                    metavar='N', help='mini-batch size (default: 1)')
parser.add_argument('--img-size', default=None, type=int,
                    metavar='N', help='Input image dimension, uses model default if empty')
# NOTE(review): --mean / --std are parsed but main() does not currently consume them;
# they are kept so existing command lines keep working.
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                    help='Override std deviation of dataset')
parser.add_argument('--num-classes', type=int, default=1000,
                    help='Number classes in dataset')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                    help='path to checkpoint (default: none)')
parser.add_argument('--reparam', default=False, action='store_true',
                    help='Reparameterize model')
parser.add_argument('--training', default=False, action='store_true',
                    help='Export in training mode (default is eval)')
parser.add_argument('--verbose', default=False, action='store_true',
                    help='Extra stdout output')
parser.add_argument('--dynamo', default=False, action='store_true',
                    help='Use torch dynamo export.')
|
|
|
|
def main():
    """Parse the command line, build the requested timm model and export it to ONNX.

    Pretrained weights are requested only when no ``--checkpoint`` path is given;
    otherwise ``timm.create_model`` is pointed at that checkpoint instead. The
    model is optionally reparameterized (``--reparam``) before being handed to
    ``onnx_export`` along with the export options from the CLI.
    """
    args = parser.parse_args()

    # Use pretrained weights unless an explicit checkpoint path was supplied
    # (--checkpoint defaults to '', which is falsy).
    args.pretrained = not args.checkpoint

    print(f"==> Creating PyTorch {args.model} model")
    # NOTE exportable=True flag disables autofn/jit scripted activations and uses Conv2dSameExport layers
    # for models using SAME padding
    model = timm.create_model(
        args.model,
        num_classes=args.num_classes,
        in_chans=3,
        pretrained=args.pretrained,
        checkpoint_path=args.checkpoint,
        exportable=True,
    )

    if args.reparam:
        model = reparameterize_model(model)

    # Only build an explicit (C, H, W) input size when --img-size was given.
    # Passing (3, None, None) is a non-empty tuple with None dims rather than
    # "unset", so pass None instead; presumably onnx_export then falls back to
    # the model's default input size, as the --img-size help promises
    # (confirm in timm.utils.onnx).
    input_size = None
    if args.img_size is not None:
        input_size = (3, args.img_size, args.img_size)

    onnx_export(
        model,
        args.output,
        opset=args.opset,
        dynamic_size=args.dynamic_size,
        aten_fallback=args.aten_fallback,
        keep_initializers=args.keep_init,
        check_forward=args.check_forward,
        training=args.training,
        verbose=args.verbose,
        use_dynamo=args.dynamo,
        input_size=input_size,
        batch_size=args.batch_size,
    )


if __name__ == '__main__':
    main()
|