mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix #2472, torch.onnx.export_ (with return output) finally removed :(
This commit is contained in:
parent
681be882e8
commit
0cae8a4cd8
@ -43,6 +43,8 @@ parser.add_argument('-b', '--batch-size', default=1, type=int,
|
||||
metavar='N', help='mini-batch size (default: 1)')
|
||||
parser.add_argument('--img-size', default=None, type=int,
|
||||
metavar='N', help='Input image dimension, uses model default if empty')
|
||||
parser.add_argument('--input-size', default=None, nargs=3, type=int, metavar='N',
|
||||
help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
|
||||
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
|
||||
help='Override mean pixel value of dataset')
|
||||
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
|
||||
@ -82,6 +84,14 @@ def main():
|
||||
if args.reparam:
|
||||
model = reparameterize_model(model)
|
||||
|
||||
if args.input_size is not None:
|
||||
assert len(args.input_size) == 3, 'input-size should be N H W (channels, height, width)'
|
||||
input_size = args.input_size
|
||||
elif args.img_size is not None:
|
||||
input_size = (3, args.img_size, args.img_size)
|
||||
else:
|
||||
input_size = None
|
||||
|
||||
onnx_export(
|
||||
model,
|
||||
args.output,
|
||||
@ -93,7 +103,7 @@ def main():
|
||||
training=args.training,
|
||||
verbose=args.verbose,
|
||||
use_dynamo=args.dynamo,
|
||||
input_size=(3, args.img_size, args.img_size),
|
||||
input_size=input_size,
|
||||
batch_size=args.batch_size,
|
||||
)
|
||||
|
||||
|
@ -43,7 +43,7 @@ def onnx_export(
|
||||
|
||||
if example_input is None:
|
||||
if not input_size:
|
||||
assert hasattr(model, 'default_cfg')
|
||||
assert hasattr(model, 'default_cfg'), 'Cannot file model default config, input size must be provided'
|
||||
input_size = model.default_cfg.get('input_size')
|
||||
example_input = torch.randn((batch_size,) + input_size, requires_grad=training)
|
||||
|
||||
@ -80,7 +80,7 @@ def onnx_export(
|
||||
export_output.save(output_file)
|
||||
torch_out = None
|
||||
else:
|
||||
torch_out = torch.onnx._export(
|
||||
torch_out = torch.onnx.export(
|
||||
model,
|
||||
example_input,
|
||||
output_file,
|
||||
|
Loading…
x
Reference in New Issue
Block a user