Fix #2472, torch.onnx.export_ (with return output) finally removed :(

This commit is contained in:
Ross Wightman 2025-04-15 12:03:03 -07:00 committed by Ross Wightman
parent 681be882e8
commit 0cae8a4cd8
2 changed files with 13 additions and 3 deletions

View File

@ -43,6 +43,8 @@ parser.add_argument('-b', '--batch-size', default=1, type=int,
metavar='N', help='mini-batch size (default: 1)')
parser.add_argument('--img-size', default=None, type=int,
metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--input-size', default=None, nargs=3, type=int, metavar='N',
help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
@ -82,6 +84,14 @@ def main():
if args.reparam:
model = reparameterize_model(model)
if args.input_size is not None:
assert len(args.input_size) == 3, 'input-size should be N H W (channels, height, width)'
input_size = args.input_size
elif args.img_size is not None:
input_size = (3, args.img_size, args.img_size)
else:
input_size = None
onnx_export(
model,
args.output,
@ -93,7 +103,7 @@ def main():
training=args.training,
verbose=args.verbose,
use_dynamo=args.dynamo,
input_size=(3, args.img_size, args.img_size),
input_size=input_size,
batch_size=args.batch_size,
)

View File

@ -43,7 +43,7 @@ def onnx_export(
if example_input is None:
if not input_size:
assert hasattr(model, 'default_cfg')
assert hasattr(model, 'default_cfg'), 'Cannot file model default config, input size must be provided'
input_size = model.default_cfg.get('input_size')
example_input = torch.randn((batch_size,) + input_size, requires_grad=training)
@ -80,7 +80,7 @@ def onnx_export(
export_output.save(output_file)
torch_out = None
else:
torch_out = torch.onnx._export(
torch_out = torch.onnx.export(
model,
example_input,
output_file,