diff --git a/mmdeploy/core/rewriters/symbolic_rewriter.py b/mmdeploy/core/rewriters/symbolic_rewriter.py index dd47cd8d5..dfcc28f76 100644 --- a/mmdeploy/core/rewriters/symbolic_rewriter.py +++ b/mmdeploy/core/rewriters/symbolic_rewriter.py @@ -1,10 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import Callable, Dict, List, Optional, Sequence, Union +import torch from torch.autograd import Function from torch.onnx.symbolic_helper import parse_args -from torch.onnx.symbolic_registry import _registry as pytorch_registry -from torch.onnx.symbolic_registry import register_op from mmdeploy.utils import IR, Backend, get_root_logger from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry, @@ -95,7 +94,9 @@ class SymbolicRewriter: is_pytorch = record_dict['is_pytorch'] if is_pytorch: - register_op(function_name, context_caller, '', opset) + from torch.onnx import register_custom_op_symbolic + register_custom_op_symbolic(f'::{function_name}', + context_caller, opset) # Save domain and version self._pytorch_symbolic.append((function_name, '', opset)) @@ -132,11 +133,18 @@ class SymbolicRewriter: def exit(self): """The implementation of symbolic unregister.""" # Unregister pytorch op - for function_name, domain, version in self._pytorch_symbolic: - # Same to ungister_op() in torch 1.9.0+ - del pytorch_registry[(domain, version)][function_name] - if not pytorch_registry[(domain, version)]: - del pytorch_registry[(domain, version)] + if hasattr(torch.onnx, 'unregister_custom_op_symbolic'): + from torch.onnx import unregister_custom_op_symbolic + for function_name, domain, version in self._pytorch_symbolic: + unregister_custom_op_symbolic(f'::{function_name}', version) + else: + from torch.onnx.symbolic_registry import \ + _registry as pytorch_registry + for function_name, domain, version in self._pytorch_symbolic: + # Same to unregister_op() in torch 1.9.0+ + del pytorch_registry[(domain, version)][function_name] + if not pytorch_registry[(domain, version)]: + del pytorch_registry[(domain, version)] # Unregister custom op for origin_func, origin_symbolic in self._extra_symbolic: diff --git a/tests/test_core/test_symbolic_register.py b/tests/test_core/test_symbolic_register.py index 7b71320a2..b012f6a8b 100644 --- a/tests/test_core/test_symbolic_register.py +++ b/tests/test_core/test_symbolic_register.py @@ -137,7 +137,7 @@ def test_unregister(): assert nodes[0].op_type == 'cummax_default' assert nodes[0].domain == 'mmdeploy' - with pytest.raises(RuntimeError): + with pytest.raises((ValueError, RuntimeError)): torch.onnx.export(model, x, output_file, opset_version=11) model = TestModel2().eval()