Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Add support for dynamo based onnx export

parent 2ec2f1aa73
commit ba641e07ae
@@ -57,6 +57,8 @@ parser.add_argument('--training', default=False, action='store_true',
                     help='Export in training mode (default is eval)')
 parser.add_argument('--verbose', default=False, action='store_true',
                     help='Extra stdout output')
+parser.add_argument('--dynamo', default=False, action='store_true',
+                    help='Use torch dynamo export.')
 
 def main():
     args = parser.parse_args()
@@ -90,6 +92,7 @@ def main():
         check_forward=args.check_forward,
         training=args.training,
         verbose=args.verbose,
+        use_dynamo=args.dynamo,
         input_size=(3, args.img_size, args.img_size),
         batch_size=args.batch_size,
     )
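For context, a minimal sketch of how the new flag travels from the command line into the exporter call; the parser below is a stripped-down stand-in for the script's real one, and only the --dynamo / use_dynamo plumbing comes from this commit:

import argparse

# Stand-in parser: only the flag added by this commit is defined here.
parser = argparse.ArgumentParser()
parser.add_argument('--dynamo', default=False, action='store_true',
                    help='Use torch dynamo export.')

args = parser.parse_args(['--dynamo'])
assert args.dynamo is True  # forwarded to the exporter as use_dynamo=args.dynamo

The function-side changes that consume use_dynamo follow.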
@@ -28,6 +28,7 @@ def onnx_export(
         dynamic_size: bool = False,
         aten_fallback: bool = False,
         keep_initializers: Optional[bool] = None,
+        use_dynamo: bool = False,
         input_names: List[str] = None,
         output_names: List[str] = None,
 ):
@@ -53,7 +54,8 @@ def onnx_export(
     # Opset >= 11 should allow for dynamic padding, however I cannot get it to work due to
     # issues in the tracing of the dynamic padding or errors attempting to export the model after jit
     # scripting it (an approach that should work). Perhaps in a future PyTorch or ONNX versions...
-    original_out = model(example_input)
+    with torch.no_grad():
+        original_out = model(example_input)
 
     input_names = input_names or ["input0"]
     output_names = output_names or ["output0"]
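Wrapping the reference forward pass in torch.no_grad() means original_out carries no autograd history, which is what lets the comparison code further down call .numpy() directly instead of going through .data. A small illustrative sketch (the model and shapes are arbitrary, not taken from this diff):

import torch

model = torch.nn.Linear(4, 2).eval()
x = torch.randn(1, 4)

with torch.no_grad():
    out = model(x)

# Under no_grad() the output does not require grad, so .numpy() works directly;
# without it, calling .numpy() on a grad-tracking tensor raises a RuntimeError.
print(out.numpy().shape)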
@@ -68,20 +70,30 @@ def onnx_export(
     else:
         export_type = torch.onnx.OperatorExportTypes.ONNX
 
-    torch_out = torch.onnx._export(
-        model,
-        example_input,
-        output_file,
-        training=training_mode,
-        export_params=True,
-        verbose=verbose,
-        input_names=input_names,
-        output_names=output_names,
-        keep_initializers_as_inputs=keep_initializers,
-        dynamic_axes=dynamic_axes,
-        opset_version=opset,
-        operator_export_type=export_type
-    )
+    if use_dynamo:
+        export_options = torch.onnx.ExportOptions(dynamic_shapes=dynamic_size)
+        export_output = torch.onnx.dynamo_export(
+            model,
+            example_input,
+            export_options=export_options,
+        )
+        export_output.save(output_file)
+        torch_out = None
+    else:
+        torch_out = torch.onnx._export(
+            model,
+            example_input,
+            output_file,
+            training=training_mode,
+            export_params=True,
+            verbose=verbose,
+            input_names=input_names,
+            output_names=output_names,
+            keep_initializers_as_inputs=keep_initializers,
+            dynamic_axes=dynamic_axes,
+            opset_version=opset,
+            operator_export_type=export_type
+        )
 
     if check:
         onnx_model = onnx.load(output_file)
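Outside the helper, the new branch corresponds to the standalone dynamo-based export API; a hedged sketch, assuming PyTorch >= 2.1 where torch.onnx.dynamo_export and torch.onnx.ExportOptions are available (the model choice and output path are illustrative):

import timm
import torch

model = timm.create_model('resnet18', pretrained=False).eval()
example_input = torch.randn(1, 3, 224, 224)

# Mirror the branch above: build ExportOptions, export via dynamo, then save.
export_options = torch.onnx.ExportOptions(dynamic_shapes=False)
export_output = torch.onnx.dynamo_export(model, example_input, export_options=export_options)
export_output.save('resnet18_dynamo.onnx')

Note that this path produces no traced torch_out, hence torch_out = None and the adjusted checks in the next hunk.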
@@ -89,7 +101,9 @@ def onnx_export(
         if check_forward and not training:
             import numpy as np
             onnx_out = onnx_forward(output_file, example_input)
-            np.testing.assert_almost_equal(torch_out.data.numpy(), onnx_out, decimal=3)
-            np.testing.assert_almost_equal(original_out.data.numpy(), torch_out.data.numpy(), decimal=5)
+            if torch_out is not None:
+                np.testing.assert_almost_equal(torch_out.numpy(), onnx_out, decimal=3)
+                np.testing.assert_almost_equal(original_out.numpy(), torch_out.numpy(), decimal=5)
+            else:
+                np.testing.assert_almost_equal(original_out.numpy(), onnx_out, decimal=3)
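Taken together, a hedged end-to-end sketch of calling the updated helper with the dynamo path and the forward check enabled; the positional parameters (model, output_file) and the import path are assumptions based on the surrounding code rather than part of this diff, and only use_dynamo= is new:

import timm
from timm.utils.onnx import onnx_export  # assumed import path

model = timm.create_model('resnet18', pretrained=False).eval()
onnx_export(
    model,
    'resnet18.onnx',
    input_size=(3, 224, 224),
    batch_size=1,
    check_forward=True,  # compares the eager output against the saved ONNX graph
    use_dynamo=True,     # new in this commit: route through torch.onnx.dynamo_export
)

With use_dynamo=True the forward check compares original_out directly against the ONNX runtime output, since no traced torch_out exists on that path.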