better register

pull/12/head
grimoire 2021-06-16 19:25:09 +08:00
parent 100ba694ce
commit 5b2459570a
4 changed files with 25 additions and 7 deletions

View File

@ -1,17 +1,17 @@
import logging
import importlib import importlib
import logging
if importlib.util.find_spec("mmcls"): if importlib.util.find_spec('mmcls'):
from .mmcls import * # noqa: F401,F403 from .mmcls import * # noqa: F401,F403
else: else:
logging.debug('mmcls is not installed.') logging.debug('mmcls is not installed.')
if importlib.util.find_spec("mmdet"): if importlib.util.find_spec('mmdet'):
from .mmdet import * # noqa: F401,F403 from .mmdet import * # noqa: F401,F403
else: else:
logging.debug('mmdet is not installed.') logging.debug('mmdet is not installed.')
if importlib.util.find_spec("mmseg"): if importlib.util.find_spec('mmseg'):
from .mmseg import * # noqa: F401,F403 from .mmseg import * # noqa: F401,F403
else: else:
logging.debug('mmseg is not installed.') logging.debug('mmseg is not installed.')

View File

@ -52,7 +52,8 @@ def register_rewriter(func_name: str,
func_args.update(kwargs) func_args.update(kwargs)
func_caller = type(func_name + '@' + backend, (FuncCaller, ), func_caller = type(func_name + '@' + backend, (FuncCaller, ),
func_args) func_args)
return FUNCTION_REWRITERS.register_module()(func_caller) FUNCTION_REWRITERS.register_module()(func_caller)
return func
return wrap return wrap

View File

@ -79,7 +79,8 @@ def register_symbolic(func_name: str,
symbolic_args.update(kwargs) symbolic_args.update(kwargs)
wrapper_name = '@'.join([func_name, backend, str(is_pytorch)]) wrapper_name = '@'.join([func_name, backend, str(is_pytorch)])
wrapper = type(wrapper_name, (SymbolicWrapper, ), symbolic_args) wrapper = type(wrapper_name, (SymbolicWrapper, ), symbolic_args)
return SYMBOLICS_REGISTER.register_module(wrapper_name)(wrapper) SYMBOLICS_REGISTER.register_module(wrapper_name)(wrapper)
return symbolic_impl
return wrapper return wrapper

View File

@ -9,6 +9,8 @@ def test_function_rewriter():
x = torch.tensor([1, 2, 3, 4, 5]) x = torch.tensor([1, 2, 3, 4, 5])
y = torch.tensor([2, 4, 6, 8, 10]) y = torch.tensor([2, 4, 6, 8, 10])
@FUNCTION_REWRITERS.register_rewriter(
func_name='torch.mul', backend='tensorrt')
@FUNCTION_REWRITERS.register_rewriter( @FUNCTION_REWRITERS.register_rewriter(
func_name='torch.add', backend='tensorrt') func_name='torch.add', backend='tensorrt')
def sub_func(rewriter, x, y): def sub_func(rewriter, x, y):
@ -21,6 +23,9 @@ def test_function_rewriter():
result = torch.add(x, y) result = torch.add(x, y)
# replace add with sub # replace add with sub
torch.testing.assert_allclose(result, x - y) torch.testing.assert_allclose(result, x - y)
result = torch.mul(x, y)
# replace add with sub
torch.testing.assert_allclose(result, x - y)
result = torch.add(x, y) result = torch.add(x, y)
# recovery origin function # recovery origin function
@ -120,6 +125,7 @@ def test_symbolic_register():
mmdeploy.TestFunc = TestFunc mmdeploy.TestFunc = TestFunc
test_func = mmdeploy.TestFunc.apply test_func = mmdeploy.TestFunc.apply
@SYMBOLICS_REGISTER.register_symbolic('mmdeploy.TestFunc', backend='ncnn')
@SYMBOLICS_REGISTER.register_symbolic('mmdeploy.TestFunc') @SYMBOLICS_REGISTER.register_symbolic('mmdeploy.TestFunc')
def symbolic_testfunc_default(symbolic_wrapper, g, x, val): def symbolic_testfunc_default(symbolic_wrapper, g, x, val):
assert hasattr(symbolic_wrapper, 'cfg') assert hasattr(symbolic_wrapper, 'cfg')
@ -161,8 +167,18 @@ def test_symbolic_register():
assert nodes[1].op_type == 'cummax_default' assert nodes[1].op_type == 'cummax_default'
assert nodes[1].domain == 'mmcv' assert nodes[1].domain == 'mmcv'
# ncnn
register_extra_symbolics(cfg=cfg, backend='ncnn', opset=11)
torch.onnx.export(model, x, output_file, opset_version=11)
onnx_model = onnx.load(output_file)
os.remove(output_file)
nodes = onnx_model.graph.node
assert nodes[0].op_type == 'symbolic_testfunc_default'
assert nodes[0].domain == 'mmcv'
assert nodes[1].op_type == 'cummax_default'
assert nodes[1].domain == 'mmcv'
# default # default
cfg = dict()
register_extra_symbolics(cfg=cfg, backend='tensorrt', opset=11) register_extra_symbolics(cfg=cfg, backend='tensorrt', opset=11)
torch.onnx.export(model, x, output_file, opset_version=11) torch.onnx.export(model, x, output_file, opset_version=11)
onnx_model = onnx.load(output_file) onnx_model = onnx.load(output_file)