update symbolic rewriter 1.x (#1404)

* update symbolic rewriter 1.x

* typo fix
This commit is contained in:
q.yao 2022-11-23 10:18:25 +08:00 committed by GitHub
parent f6ea5d3315
commit cb37b092bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 9 deletions

View File

@ -1,10 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, Dict, List, Optional, Sequence, Union
import torch
from torch.autograd import Function
from torch.onnx.symbolic_helper import parse_args
from torch.onnx.symbolic_registry import _registry as pytorch_registry
from torch.onnx.symbolic_registry import register_op
from mmdeploy.utils import IR, Backend, get_root_logger
from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry,
@ -95,7 +94,9 @@ class SymbolicRewriter:
is_pytorch = record_dict['is_pytorch']
if is_pytorch:
register_op(function_name, context_caller, '', opset)
from torch.onnx import register_custom_op_symbolic
register_custom_op_symbolic(f'::{function_name}',
context_caller, opset)
# Save domain and version
self._pytorch_symbolic.append((function_name, '', opset))
@ -132,11 +133,18 @@ class SymbolicRewriter:
def exit(self):
"""The implementation of symbolic unregister."""
# Unregister pytorch op
for function_name, domain, version in self._pytorch_symbolic:
# Same to ungister_op() in torch 1.9.0+
del pytorch_registry[(domain, version)][function_name]
if not pytorch_registry[(domain, version)]:
del pytorch_registry[(domain, version)]
if hasattr(torch.onnx, 'unregister_custom_op_symbolic'):
from torch.onnx import unregister_custom_op_symbolic
for function_name, domain, version in self._pytorch_symbolic:
unregister_custom_op_symbolic(f'::{function_name}', version)
else:
from torch.onnx.symbolic_registry import \
_registry as pytorch_registry
for function_name, domain, version in self._pytorch_symbolic:
# Same to unregister_op() in torch 1.9.0+
del pytorch_registry[(domain, version)][function_name]
if not pytorch_registry[(domain, version)]:
del pytorch_registry[(domain, version)]
# Unregister custom op
for origin_func, origin_symbolic in self._extra_symbolic:

View File

@ -137,7 +137,7 @@ def test_unregister():
assert nodes[0].op_type == 'cummax_default'
assert nodes[0].domain == 'mmdeploy'
with pytest.raises(RuntimeError):
with pytest.raises((ValueError, RuntimeError)):
torch.onnx.export(model, x, output_file, opset_version=11)
model = TestModel2().eval()