better register
parent
100ba694ce
commit
5b2459570a
|
@ -1,17 +1,17 @@
|
|||
import logging
|
||||
import importlib
|
||||
import logging
|
||||
|
||||
if importlib.util.find_spec("mmcls"):
|
||||
if importlib.util.find_spec('mmcls'):
|
||||
from .mmcls import * # noqa: F401,F403
|
||||
else:
|
||||
logging.debug('mmcls is not installed.')
|
||||
|
||||
if importlib.util.find_spec("mmdet"):
|
||||
if importlib.util.find_spec('mmdet'):
|
||||
from .mmdet import * # noqa: F401,F403
|
||||
else:
|
||||
logging.debug('mmdet is not installed.')
|
||||
|
||||
if importlib.util.find_spec("mmseg"):
|
||||
if importlib.util.find_spec('mmseg'):
|
||||
from .mmseg import * # noqa: F401,F403
|
||||
else:
|
||||
logging.debug('mmseg is not installed.')
|
||||
|
|
|
@ -52,7 +52,8 @@ def register_rewriter(func_name: str,
|
|||
func_args.update(kwargs)
|
||||
func_caller = type(func_name + '@' + backend, (FuncCaller, ),
|
||||
func_args)
|
||||
return FUNCTION_REWRITERS.register_module()(func_caller)
|
||||
FUNCTION_REWRITERS.register_module()(func_caller)
|
||||
return func
|
||||
|
||||
return wrap
|
||||
|
||||
|
|
|
@ -79,7 +79,8 @@ def register_symbolic(func_name: str,
|
|||
symbolic_args.update(kwargs)
|
||||
wrapper_name = '@'.join([func_name, backend, str(is_pytorch)])
|
||||
wrapper = type(wrapper_name, (SymbolicWrapper, ), symbolic_args)
|
||||
return SYMBOLICS_REGISTER.register_module(wrapper_name)(wrapper)
|
||||
SYMBOLICS_REGISTER.register_module(wrapper_name)(wrapper)
|
||||
return symbolic_impl
|
||||
|
||||
return wrapper
|
||||
|
||||
|
|
|
@ -9,6 +9,8 @@ def test_function_rewriter():
|
|||
x = torch.tensor([1, 2, 3, 4, 5])
|
||||
y = torch.tensor([2, 4, 6, 8, 10])
|
||||
|
||||
@FUNCTION_REWRITERS.register_rewriter(
|
||||
func_name='torch.mul', backend='tensorrt')
|
||||
@FUNCTION_REWRITERS.register_rewriter(
|
||||
func_name='torch.add', backend='tensorrt')
|
||||
def sub_func(rewriter, x, y):
|
||||
|
@ -21,6 +23,9 @@ def test_function_rewriter():
|
|||
result = torch.add(x, y)
|
||||
# replace add with sub
|
||||
torch.testing.assert_allclose(result, x - y)
|
||||
result = torch.mul(x, y)
|
||||
# replace add with sub
|
||||
torch.testing.assert_allclose(result, x - y)
|
||||
|
||||
result = torch.add(x, y)
|
||||
# recovery origin function
|
||||
|
@ -120,6 +125,7 @@ def test_symbolic_register():
|
|||
mmdeploy.TestFunc = TestFunc
|
||||
test_func = mmdeploy.TestFunc.apply
|
||||
|
||||
@SYMBOLICS_REGISTER.register_symbolic('mmdeploy.TestFunc', backend='ncnn')
|
||||
@SYMBOLICS_REGISTER.register_symbolic('mmdeploy.TestFunc')
|
||||
def symbolic_testfunc_default(symbolic_wrapper, g, x, val):
|
||||
assert hasattr(symbolic_wrapper, 'cfg')
|
||||
|
@ -161,8 +167,18 @@ def test_symbolic_register():
|
|||
assert nodes[1].op_type == 'cummax_default'
|
||||
assert nodes[1].domain == 'mmcv'
|
||||
|
||||
# ncnn
|
||||
register_extra_symbolics(cfg=cfg, backend='ncnn', opset=11)
|
||||
torch.onnx.export(model, x, output_file, opset_version=11)
|
||||
onnx_model = onnx.load(output_file)
|
||||
os.remove(output_file)
|
||||
nodes = onnx_model.graph.node
|
||||
assert nodes[0].op_type == 'symbolic_testfunc_default'
|
||||
assert nodes[0].domain == 'mmcv'
|
||||
assert nodes[1].op_type == 'cummax_default'
|
||||
assert nodes[1].domain == 'mmcv'
|
||||
|
||||
# default
|
||||
cfg = dict()
|
||||
register_extra_symbolics(cfg=cfg, backend='tensorrt', opset=11)
|
||||
torch.onnx.export(model, x, output_file, opset_version=11)
|
||||
onnx_model = onnx.load(output_file)
|
||||
|
|
Loading…
Reference in New Issue