mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
update symbolic rewriter 1.x (#1404)
* update symbolic rewriter 1.x * typo fix
This commit is contained in:
parent
f6ea5d3315
commit
cb37b092bd
@ -1,10 +1,9 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from typing import Callable, Dict, List, Optional, Sequence, Union
|
from typing import Callable, Dict, List, Optional, Sequence, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
from torch.autograd import Function
|
from torch.autograd import Function
|
||||||
from torch.onnx.symbolic_helper import parse_args
|
from torch.onnx.symbolic_helper import parse_args
|
||||||
from torch.onnx.symbolic_registry import _registry as pytorch_registry
|
|
||||||
from torch.onnx.symbolic_registry import register_op
|
|
||||||
|
|
||||||
from mmdeploy.utils import IR, Backend, get_root_logger
|
from mmdeploy.utils import IR, Backend, get_root_logger
|
||||||
from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry,
|
from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry,
|
||||||
@ -95,7 +94,9 @@ class SymbolicRewriter:
|
|||||||
|
|
||||||
is_pytorch = record_dict['is_pytorch']
|
is_pytorch = record_dict['is_pytorch']
|
||||||
if is_pytorch:
|
if is_pytorch:
|
||||||
register_op(function_name, context_caller, '', opset)
|
from torch.onnx import register_custom_op_symbolic
|
||||||
|
register_custom_op_symbolic(f'::{function_name}',
|
||||||
|
context_caller, opset)
|
||||||
|
|
||||||
# Save domain and version
|
# Save domain and version
|
||||||
self._pytorch_symbolic.append((function_name, '', opset))
|
self._pytorch_symbolic.append((function_name, '', opset))
|
||||||
@ -132,11 +133,18 @@ class SymbolicRewriter:
|
|||||||
def exit(self):
|
def exit(self):
|
||||||
"""The implementation of symbolic unregister."""
|
"""The implementation of symbolic unregister."""
|
||||||
# Unregister pytorch op
|
# Unregister pytorch op
|
||||||
for function_name, domain, version in self._pytorch_symbolic:
|
if hasattr(torch.onnx, 'unregister_custom_op_symbolic'):
|
||||||
# Same to ungister_op() in torch 1.9.0+
|
from torch.onnx import unregister_custom_op_symbolic
|
||||||
del pytorch_registry[(domain, version)][function_name]
|
for function_name, domain, version in self._pytorch_symbolic:
|
||||||
if not pytorch_registry[(domain, version)]:
|
unregister_custom_op_symbolic(f'::{function_name}', version)
|
||||||
del pytorch_registry[(domain, version)]
|
else:
|
||||||
|
from torch.onnx.symbolic_registry import \
|
||||||
|
_registry as pytorch_registry
|
||||||
|
for function_name, domain, version in self._pytorch_symbolic:
|
||||||
|
# Same to unregister_op() in torch 1.9.0+
|
||||||
|
del pytorch_registry[(domain, version)][function_name]
|
||||||
|
if not pytorch_registry[(domain, version)]:
|
||||||
|
del pytorch_registry[(domain, version)]
|
||||||
|
|
||||||
# Unregister custom op
|
# Unregister custom op
|
||||||
for origin_func, origin_symbolic in self._extra_symbolic:
|
for origin_func, origin_symbolic in self._extra_symbolic:
|
||||||
|
@ -137,7 +137,7 @@ def test_unregister():
|
|||||||
assert nodes[0].op_type == 'cummax_default'
|
assert nodes[0].op_type == 'cummax_default'
|
||||||
assert nodes[0].domain == 'mmdeploy'
|
assert nodes[0].domain == 'mmdeploy'
|
||||||
|
|
||||||
with pytest.raises(RuntimeError):
|
with pytest.raises((ValueError, RuntimeError)):
|
||||||
torch.onnx.export(model, x, output_file, opset_version=11)
|
torch.onnx.export(model, x, output_file, opset_version=11)
|
||||||
|
|
||||||
model = TestModel2().eval()
|
model = TestModel2().eval()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user