mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
update symbolic rewriter 1.x (#1404)
* update symbolic rewriter 1.x * typo fix
This commit is contained in:
parent
f6ea5d3315
commit
cb37b092bd
@ -1,10 +1,9 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Callable, Dict, List, Optional, Sequence, Union
|
||||
|
||||
import torch
|
||||
from torch.autograd import Function
|
||||
from torch.onnx.symbolic_helper import parse_args
|
||||
from torch.onnx.symbolic_registry import _registry as pytorch_registry
|
||||
from torch.onnx.symbolic_registry import register_op
|
||||
|
||||
from mmdeploy.utils import IR, Backend, get_root_logger
|
||||
from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry,
|
||||
@ -95,7 +94,9 @@ class SymbolicRewriter:
|
||||
|
||||
is_pytorch = record_dict['is_pytorch']
|
||||
if is_pytorch:
|
||||
register_op(function_name, context_caller, '', opset)
|
||||
from torch.onnx import register_custom_op_symbolic
|
||||
register_custom_op_symbolic(f'::{function_name}',
|
||||
context_caller, opset)
|
||||
|
||||
# Save domain and version
|
||||
self._pytorch_symbolic.append((function_name, '', opset))
|
||||
@ -132,11 +133,18 @@ class SymbolicRewriter:
|
||||
def exit(self):
|
||||
"""The implementation of symbolic unregister."""
|
||||
# Unregister pytorch op
|
||||
for function_name, domain, version in self._pytorch_symbolic:
|
||||
# Same to ungister_op() in torch 1.9.0+
|
||||
del pytorch_registry[(domain, version)][function_name]
|
||||
if not pytorch_registry[(domain, version)]:
|
||||
del pytorch_registry[(domain, version)]
|
||||
if hasattr(torch.onnx, 'unregister_custom_op_symbolic'):
|
||||
from torch.onnx import unregister_custom_op_symbolic
|
||||
for function_name, domain, version in self._pytorch_symbolic:
|
||||
unregister_custom_op_symbolic(f'::{function_name}', version)
|
||||
else:
|
||||
from torch.onnx.symbolic_registry import \
|
||||
_registry as pytorch_registry
|
||||
for function_name, domain, version in self._pytorch_symbolic:
|
||||
# Same to unregister_op() in torch 1.9.0+
|
||||
del pytorch_registry[(domain, version)][function_name]
|
||||
if not pytorch_registry[(domain, version)]:
|
||||
del pytorch_registry[(domain, version)]
|
||||
|
||||
# Unregister custom op
|
||||
for origin_func, origin_symbolic in self._extra_symbolic:
|
||||
|
@ -137,7 +137,7 @@ def test_unregister():
|
||||
assert nodes[0].op_type == 'cummax_default'
|
||||
assert nodes[0].domain == 'mmdeploy'
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
with pytest.raises((ValueError, RuntimeError)):
|
||||
torch.onnx.export(model, x, output_file, opset_version=11)
|
||||
|
||||
model = TestModel2().eval()
|
||||
|
Loading…
x
Reference in New Issue
Block a user