diff --git a/onnx_export.py b/onnx_export.py index 554d314b..7b951b63 100644 --- a/onnx_export.py +++ b/onnx_export.py @@ -57,6 +57,8 @@ parser.add_argument('--training', default=False, action='store_true', help='Export in training mode (default is eval)') parser.add_argument('--verbose', default=False, action='store_true', help='Extra stdout output') +parser.add_argument('--dynamo', default=False, action='store_true', + help='Use torch dynamo export.') def main(): args = parser.parse_args() @@ -90,6 +92,7 @@ def main(): check_forward=args.check_forward, training=args.training, verbose=args.verbose, + use_dynamo=args.dynamo, input_size=(3, args.img_size, args.img_size), batch_size=args.batch_size, ) diff --git a/timm/utils/onnx.py b/timm/utils/onnx.py index 58cb2d2a..932bc8f9 100644 --- a/timm/utils/onnx.py +++ b/timm/utils/onnx.py @@ -28,6 +28,7 @@ def onnx_export( dynamic_size: bool = False, aten_fallback: bool = False, keep_initializers: Optional[bool] = None, + use_dynamo: bool = False, input_names: List[str] = None, output_names: List[str] = None, ): @@ -53,7 +54,8 @@ def onnx_export( # Opset >= 11 should allow for dynamic padding, however I cannot get it to work due to # issues in the tracing of the dynamic padding or errors attempting to export the model after jit # scripting it (an approach that should work). Perhaps in a future PyTorch or ONNX versions... - original_out = model(example_input) + with torch.no_grad(): + original_out = model(example_input) input_names = input_names or ["input0"] output_names = output_names or ["output0"] @@ -68,20 +70,30 @@ def onnx_export( else: export_type = torch.onnx.OperatorExportTypes.ONNX - torch_out = torch.onnx._export( - model, - example_input, - output_file, - training=training_mode, - export_params=True, - verbose=verbose, - input_names=input_names, - output_names=output_names, - keep_initializers_as_inputs=keep_initializers, - dynamic_axes=dynamic_axes, - opset_version=opset, - operator_export_type=export_type - ) + if use_dynamo: + export_options = torch.onnx.ExportOptions(dynamic_shapes=dynamic_size) + export_output = torch.onnx.dynamo_export( + model, + example_input, + export_options=export_options, + ) + export_output.save(output_file) + torch_out = None + else: + torch_out = torch.onnx._export( + model, + example_input, + output_file, + training=training_mode, + export_params=True, + verbose=verbose, + input_names=input_names, + output_names=output_names, + keep_initializers_as_inputs=keep_initializers, + dynamic_axes=dynamic_axes, + opset_version=opset, + operator_export_type=export_type + ) if check: onnx_model = onnx.load(output_file) @@ -89,7 +101,9 @@ def onnx_export( if check_forward and not training: import numpy as np onnx_out = onnx_forward(output_file, example_input) - np.testing.assert_almost_equal(torch_out.data.numpy(), onnx_out, decimal=3) - np.testing.assert_almost_equal(original_out.data.numpy(), torch_out.data.numpy(), decimal=5) - + if torch_out is not None: + np.testing.assert_almost_equal(torch_out.numpy(), onnx_out, decimal=3) + np.testing.assert_almost_equal(original_out.numpy(), torch_out.numpy(), decimal=5) + else: + np.testing.assert_almost_equal(original_out.numpy(), onnx_out, decimal=3)