update symbolic rewriter 1.x (#1404)

* update symbolic rewriter 1.x

* typo fix
This commit is contained in:
q.yao 2022-11-23 10:18:25 +08:00 committed by GitHub
parent f6ea5d3315
commit cb37b092bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 9 deletions

View File

@ -1,10 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, Dict, List, Optional, Sequence, Union from typing import Callable, Dict, List, Optional, Sequence, Union
import torch
from torch.autograd import Function from torch.autograd import Function
from torch.onnx.symbolic_helper import parse_args from torch.onnx.symbolic_helper import parse_args
from torch.onnx.symbolic_registry import _registry as pytorch_registry
from torch.onnx.symbolic_registry import register_op
from mmdeploy.utils import IR, Backend, get_root_logger from mmdeploy.utils import IR, Backend, get_root_logger
from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry, from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry,
@ -95,7 +94,9 @@ class SymbolicRewriter:
is_pytorch = record_dict['is_pytorch'] is_pytorch = record_dict['is_pytorch']
if is_pytorch: if is_pytorch:
register_op(function_name, context_caller, '', opset) from torch.onnx import register_custom_op_symbolic
register_custom_op_symbolic(f'::{function_name}',
context_caller, opset)
# Save domain and version # Save domain and version
self._pytorch_symbolic.append((function_name, '', opset)) self._pytorch_symbolic.append((function_name, '', opset))
@ -132,11 +133,18 @@ class SymbolicRewriter:
def exit(self): def exit(self):
"""The implementation of symbolic unregister.""" """The implementation of symbolic unregister."""
# Unregister pytorch op # Unregister pytorch op
for function_name, domain, version in self._pytorch_symbolic: if hasattr(torch.onnx, 'unregister_custom_op_symbolic'):
# Same to ungister_op() in torch 1.9.0+ from torch.onnx import unregister_custom_op_symbolic
del pytorch_registry[(domain, version)][function_name] for function_name, domain, version in self._pytorch_symbolic:
if not pytorch_registry[(domain, version)]: unregister_custom_op_symbolic(f'::{function_name}', version)
del pytorch_registry[(domain, version)] else:
from torch.onnx.symbolic_registry import \
_registry as pytorch_registry
for function_name, domain, version in self._pytorch_symbolic:
# Same to unregister_op() in torch 1.9.0+
del pytorch_registry[(domain, version)][function_name]
if not pytorch_registry[(domain, version)]:
del pytorch_registry[(domain, version)]
# Unregister custom op # Unregister custom op
for origin_func, origin_symbolic in self._extra_symbolic: for origin_func, origin_symbolic in self._extra_symbolic:

View File

@ -137,7 +137,7 @@ def test_unregister():
assert nodes[0].op_type == 'cummax_default' assert nodes[0].op_type == 'cummax_default'
assert nodes[0].domain == 'mmdeploy' assert nodes[0].domain == 'mmdeploy'
with pytest.raises(RuntimeError): with pytest.raises((ValueError, RuntimeError)):
torch.onnx.export(model, x, output_file, opset_version=11) torch.onnx.export(model, x, output_file, opset_version=11)
model = TestModel2().eval() model = TestModel2().eval()