compatible with torch 1.11 (#272)
parent
3fa2cc8845
commit
e01056cc0a
|
@ -32,6 +32,8 @@ def torch2onnx_impl(model: torch.nn.Module, input: torch.Tensor,
|
|||
output_names = onnx_cfg['output_names']
|
||||
axis_names = input_names + output_names
|
||||
dynamic_axes = get_dynamic_axes(deploy_cfg, axis_names)
|
||||
verbose = not onnx_cfg.get('strip_doc_string', True) or onnx_cfg.get(
|
||||
'verbose', False)
|
||||
|
||||
# patch model
|
||||
patched_model = patch_model(model, cfg=deploy_cfg, backend=backend)
|
||||
|
@ -50,7 +52,7 @@ def torch2onnx_impl(model: torch.nn.Module, input: torch.Tensor,
|
|||
dynamic_axes=dynamic_axes,
|
||||
keep_initializers_as_inputs=onnx_cfg[
|
||||
'keep_initializers_as_inputs'],
|
||||
strip_doc_string=onnx_cfg.get('strip_doc_string', True))
|
||||
verbose=verbose)
|
||||
|
||||
|
||||
def torch2onnx(img: Any,
|
||||
|
|
Loading…
Reference in New Issue