diff --git a/mmdeploy/__init__.py b/mmdeploy/__init__.py index e6064085a..4dcdfd926 100644 --- a/mmdeploy/__init__.py +++ b/mmdeploy/__init__.py @@ -1,17 +1,17 @@ -import logging import importlib +import logging -if importlib.util.find_spec("mmcls"): +if importlib.util.find_spec('mmcls'): from .mmcls import * # noqa: F401,F403 else: logging.debug('mmcls is not installed.') -if importlib.util.find_spec("mmdet"): +if importlib.util.find_spec('mmdet'): from .mmdet import * # noqa: F401,F403 else: logging.debug('mmdet is not installed.') -if importlib.util.find_spec("mmseg"): +if importlib.util.find_spec('mmseg'): from .mmseg import * # noqa: F401,F403 else: logging.debug('mmseg is not installed.') diff --git a/mmdeploy/utils/function_rewriter.py b/mmdeploy/utils/function_rewriter.py index a7dce49a9..6f6d2b9fd 100644 --- a/mmdeploy/utils/function_rewriter.py +++ b/mmdeploy/utils/function_rewriter.py @@ -52,7 +52,8 @@ def register_rewriter(func_name: str, func_args.update(kwargs) func_caller = type(func_name + '@' + backend, (FuncCaller, ), func_args) - return FUNCTION_REWRITERS.register_module()(func_caller) + FUNCTION_REWRITERS.register_module()(func_caller) + return func return wrap diff --git a/mmdeploy/utils/symbolic_register.py b/mmdeploy/utils/symbolic_register.py index e51d9852e..bb6e1584f 100644 --- a/mmdeploy/utils/symbolic_register.py +++ b/mmdeploy/utils/symbolic_register.py @@ -79,7 +79,8 @@ def register_symbolic(func_name: str, symbolic_args.update(kwargs) wrapper_name = '@'.join([func_name, backend, str(is_pytorch)]) wrapper = type(wrapper_name, (SymbolicWrapper, ), symbolic_args) - return SYMBOLICS_REGISTER.register_module(wrapper_name)(wrapper) + SYMBOLICS_REGISTER.register_module(wrapper_name)(wrapper) + return symbolic_impl return wrapper diff --git a/tests/test_utils/test_register.py b/tests/test_utils/test_register.py index 4663ba085..bdfd2c55e 100644 --- a/tests/test_utils/test_register.py +++ b/tests/test_utils/test_register.py @@ -9,6 +9,8 @@ def test_function_rewriter(): x = torch.tensor([1, 2, 3, 4, 5]) y = torch.tensor([2, 4, 6, 8, 10]) + @FUNCTION_REWRITERS.register_rewriter( + func_name='torch.mul', backend='tensorrt') @FUNCTION_REWRITERS.register_rewriter( func_name='torch.add', backend='tensorrt') def sub_func(rewriter, x, y): @@ -21,6 +23,9 @@ def test_function_rewriter(): result = torch.add(x, y) # replace add with sub torch.testing.assert_allclose(result, x - y) + result = torch.mul(x, y) + # replace add with sub + torch.testing.assert_allclose(result, x - y) result = torch.add(x, y) # recovery origin function @@ -120,6 +125,7 @@ def test_symbolic_register(): mmdeploy.TestFunc = TestFunc test_func = mmdeploy.TestFunc.apply + @SYMBOLICS_REGISTER.register_symbolic('mmdeploy.TestFunc', backend='ncnn') @SYMBOLICS_REGISTER.register_symbolic('mmdeploy.TestFunc') def symbolic_testfunc_default(symbolic_wrapper, g, x, val): assert hasattr(symbolic_wrapper, 'cfg') @@ -161,8 +167,18 @@ def test_symbolic_register(): assert nodes[1].op_type == 'cummax_default' assert nodes[1].domain == 'mmcv' + # ncnn + register_extra_symbolics(cfg=cfg, backend='ncnn', opset=11) + torch.onnx.export(model, x, output_file, opset_version=11) + onnx_model = onnx.load(output_file) + os.remove(output_file) + nodes = onnx_model.graph.node + assert nodes[0].op_type == 'symbolic_testfunc_default' + assert nodes[0].domain == 'mmcv' + assert nodes[1].op_type == 'cummax_default' + assert nodes[1].domain == 'mmcv' + # default - cfg = dict() register_extra_symbolics(cfg=cfg, backend='tensorrt', opset=11) torch.onnx.export(model, x, output_file, opset_version=11) onnx_model = onnx.load(output_file)