diff --git a/onnx_export.py b/onnx_export.py
index 7b951b63..20264064 100644
--- a/onnx_export.py
+++ b/onnx_export.py
@@ -43,6 +43,8 @@ parser.add_argument('-b', '--batch-size', default=1, type=int, metavar='N',
                     help='mini-batch size (default: 1)')
 parser.add_argument('--img-size', default=None, type=int, metavar='N',
                     help='Input image dimension, uses model default if empty')
+parser.add_argument('--input-size', default=None, nargs=3, type=int, metavar='N',
+                    help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
 parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                     help='Override mean pixel value of dataset')
 parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
@@ -82,6 +84,14 @@ def main():
     if args.reparam:
         model = reparameterize_model(model)
 
+    if args.input_size is not None:
+        assert len(args.input_size) == 3, 'input-size should be N H W (channels, height, width)'
+        input_size = args.input_size
+    elif args.img_size is not None:
+        input_size = (3, args.img_size, args.img_size)
+    else:
+        input_size = None
+
     onnx_export(
         model,
         args.output,
@@ -93,7 +103,7 @@ def main():
         training=args.training,
         verbose=args.verbose,
         use_dynamo=args.dynamo,
-        input_size=(3, args.img_size, args.img_size),
+        input_size=input_size,
         batch_size=args.batch_size,
     )
 
diff --git a/timm/utils/onnx.py b/timm/utils/onnx.py
index 932bc8f9..5c4fd16f 100644
--- a/timm/utils/onnx.py
+++ b/timm/utils/onnx.py
@@ -43,7 +43,7 @@ def onnx_export(
 
     if example_input is None:
         if not input_size:
-            assert hasattr(model, 'default_cfg')
+            assert hasattr(model, 'default_cfg'), 'Cannot find model default config, input size must be provided'
             input_size = model.default_cfg.get('input_size')
         example_input = torch.randn((batch_size,) + input_size, requires_grad=training)
 
@@ -80,7 +80,7 @@ def onnx_export(
         export_output.save(output_file)
         torch_out = None
     else:
-        torch_out = torch.onnx._export(
+        torch_out = torch.onnx.export(
             model, example_input, output_file,