better register

pull/12/head
grimoire 2021-06-16 19:25:09 +08:00
parent 100ba694ce
commit 5b2459570a
4 changed files with 25 additions and 7 deletions

View File

@ -1,17 +1,17 @@
import logging
import importlib
import logging
if importlib.util.find_spec("mmcls"):
if importlib.util.find_spec('mmcls'):
from .mmcls import * # noqa: F401,F403
else:
logging.debug('mmcls is not installed.')
if importlib.util.find_spec("mmdet"):
if importlib.util.find_spec('mmdet'):
from .mmdet import * # noqa: F401,F403
else:
logging.debug('mmdet is not installed.')
if importlib.util.find_spec("mmseg"):
if importlib.util.find_spec('mmseg'):
from .mmseg import * # noqa: F401,F403
else:
logging.debug('mmseg is not installed.')

View File

@ -52,7 +52,8 @@ def register_rewriter(func_name: str,
func_args.update(kwargs)
func_caller = type(func_name + '@' + backend, (FuncCaller, ),
func_args)
return FUNCTION_REWRITERS.register_module()(func_caller)
FUNCTION_REWRITERS.register_module()(func_caller)
return func
return wrap

View File

@ -79,7 +79,8 @@ def register_symbolic(func_name: str,
symbolic_args.update(kwargs)
wrapper_name = '@'.join([func_name, backend, str(is_pytorch)])
wrapper = type(wrapper_name, (SymbolicWrapper, ), symbolic_args)
return SYMBOLICS_REGISTER.register_module(wrapper_name)(wrapper)
SYMBOLICS_REGISTER.register_module(wrapper_name)(wrapper)
return symbolic_impl
return wrapper

View File

@ -9,6 +9,8 @@ def test_function_rewriter():
x = torch.tensor([1, 2, 3, 4, 5])
y = torch.tensor([2, 4, 6, 8, 10])
@FUNCTION_REWRITERS.register_rewriter(
func_name='torch.mul', backend='tensorrt')
@FUNCTION_REWRITERS.register_rewriter(
func_name='torch.add', backend='tensorrt')
def sub_func(rewriter, x, y):
@ -21,6 +23,9 @@ def test_function_rewriter():
result = torch.add(x, y)
# replace add with sub
torch.testing.assert_allclose(result, x - y)
result = torch.mul(x, y)
# replace add with sub
torch.testing.assert_allclose(result, x - y)
result = torch.add(x, y)
# recovery origin function
@ -120,6 +125,7 @@ def test_symbolic_register():
mmdeploy.TestFunc = TestFunc
test_func = mmdeploy.TestFunc.apply
@SYMBOLICS_REGISTER.register_symbolic('mmdeploy.TestFunc', backend='ncnn')
@SYMBOLICS_REGISTER.register_symbolic('mmdeploy.TestFunc')
def symbolic_testfunc_default(symbolic_wrapper, g, x, val):
assert hasattr(symbolic_wrapper, 'cfg')
@ -161,8 +167,18 @@ def test_symbolic_register():
assert nodes[1].op_type == 'cummax_default'
assert nodes[1].domain == 'mmcv'
# ncnn
register_extra_symbolics(cfg=cfg, backend='ncnn', opset=11)
torch.onnx.export(model, x, output_file, opset_version=11)
onnx_model = onnx.load(output_file)
os.remove(output_file)
nodes = onnx_model.graph.node
assert nodes[0].op_type == 'symbolic_testfunc_default'
assert nodes[0].domain == 'mmcv'
assert nodes[1].op_type == 'cummax_default'
assert nodes[1].domain == 'mmcv'
# default
cfg = dict()
register_extra_symbolics(cfg=cfg, backend='tensorrt', opset=11)
torch.onnx.export(model, x, output_file, opset_version=11)
onnx_model = onnx.load(output_file)