Add support for dynamo based onnx export

This commit is contained in:
Ross Wightman 2024-03-13 12:05:26 -07:00
parent 2ec2f1aa73
commit ba641e07ae
2 changed files with 35 additions and 18 deletions

View File

@ -57,6 +57,8 @@ parser.add_argument('--training', default=False, action='store_true',
help='Export in training mode (default is eval)')
parser.add_argument('--verbose', default=False, action='store_true',
help='Extra stdout output')
parser.add_argument('--dynamo', default=False, action='store_true',
help='Use torch dynamo export.')
def main():
args = parser.parse_args()
@ -90,6 +92,7 @@ def main():
check_forward=args.check_forward,
training=args.training,
verbose=args.verbose,
use_dynamo=args.dynamo,
input_size=(3, args.img_size, args.img_size),
batch_size=args.batch_size,
)

View File

@ -28,6 +28,7 @@ def onnx_export(
dynamic_size: bool = False,
aten_fallback: bool = False,
keep_initializers: Optional[bool] = None,
use_dynamo: bool = False,
input_names: List[str] = None,
output_names: List[str] = None,
):
@ -53,7 +54,8 @@ def onnx_export(
# Opset >= 11 should allow for dynamic padding, however I cannot get it to work due to
# issues in the tracing of the dynamic padding or errors attempting to export the model after jit
# scripting it (an approach that should work). Perhaps in a future PyTorch or ONNX versions...
original_out = model(example_input)
with torch.no_grad():
original_out = model(example_input)
input_names = input_names or ["input0"]
output_names = output_names or ["output0"]
@ -68,20 +70,30 @@ def onnx_export(
else:
export_type = torch.onnx.OperatorExportTypes.ONNX
torch_out = torch.onnx._export(
model,
example_input,
output_file,
training=training_mode,
export_params=True,
verbose=verbose,
input_names=input_names,
output_names=output_names,
keep_initializers_as_inputs=keep_initializers,
dynamic_axes=dynamic_axes,
opset_version=opset,
operator_export_type=export_type
)
if use_dynamo:
export_options = torch.onnx.ExportOptions(dynamic_shapes=dynamic_size)
export_output = torch.onnx.dynamo_export(
model,
example_input,
export_options=export_options,
)
export_output.save(output_file)
torch_out = None
else:
torch_out = torch.onnx._export(
model,
example_input,
output_file,
training=training_mode,
export_params=True,
verbose=verbose,
input_names=input_names,
output_names=output_names,
keep_initializers_as_inputs=keep_initializers,
dynamic_axes=dynamic_axes,
opset_version=opset,
operator_export_type=export_type
)
if check:
onnx_model = onnx.load(output_file)
@ -89,7 +101,9 @@ def onnx_export(
if check_forward and not training:
import numpy as np
onnx_out = onnx_forward(output_file, example_input)
np.testing.assert_almost_equal(torch_out.data.numpy(), onnx_out, decimal=3)
np.testing.assert_almost_equal(original_out.data.numpy(), torch_out.data.numpy(), decimal=5)
if torch_out is not None:
np.testing.assert_almost_equal(torch_out.numpy(), onnx_out, decimal=3)
np.testing.assert_almost_equal(original_out.numpy(), torch_out.numpy(), decimal=5)
else:
np.testing.assert_almost_equal(original_out.numpy(), onnx_out, decimal=3)