better register
parent
100ba694ce
commit
5b2459570a
|
@ -1,17 +1,17 @@
|
||||||
import logging
|
|
||||||
import importlib
|
import importlib
|
||||||
|
import logging
|
||||||
|
|
||||||
if importlib.util.find_spec("mmcls"):
|
if importlib.util.find_spec('mmcls'):
|
||||||
from .mmcls import * # noqa: F401,F403
|
from .mmcls import * # noqa: F401,F403
|
||||||
else:
|
else:
|
||||||
logging.debug('mmcls is not installed.')
|
logging.debug('mmcls is not installed.')
|
||||||
|
|
||||||
if importlib.util.find_spec("mmdet"):
|
if importlib.util.find_spec('mmdet'):
|
||||||
from .mmdet import * # noqa: F401,F403
|
from .mmdet import * # noqa: F401,F403
|
||||||
else:
|
else:
|
||||||
logging.debug('mmdet is not installed.')
|
logging.debug('mmdet is not installed.')
|
||||||
|
|
||||||
if importlib.util.find_spec("mmseg"):
|
if importlib.util.find_spec('mmseg'):
|
||||||
from .mmseg import * # noqa: F401,F403
|
from .mmseg import * # noqa: F401,F403
|
||||||
else:
|
else:
|
||||||
logging.debug('mmseg is not installed.')
|
logging.debug('mmseg is not installed.')
|
||||||
|
|
|
@ -52,7 +52,8 @@ def register_rewriter(func_name: str,
|
||||||
func_args.update(kwargs)
|
func_args.update(kwargs)
|
||||||
func_caller = type(func_name + '@' + backend, (FuncCaller, ),
|
func_caller = type(func_name + '@' + backend, (FuncCaller, ),
|
||||||
func_args)
|
func_args)
|
||||||
return FUNCTION_REWRITERS.register_module()(func_caller)
|
FUNCTION_REWRITERS.register_module()(func_caller)
|
||||||
|
return func
|
||||||
|
|
||||||
return wrap
|
return wrap
|
||||||
|
|
||||||
|
|
|
@ -79,7 +79,8 @@ def register_symbolic(func_name: str,
|
||||||
symbolic_args.update(kwargs)
|
symbolic_args.update(kwargs)
|
||||||
wrapper_name = '@'.join([func_name, backend, str(is_pytorch)])
|
wrapper_name = '@'.join([func_name, backend, str(is_pytorch)])
|
||||||
wrapper = type(wrapper_name, (SymbolicWrapper, ), symbolic_args)
|
wrapper = type(wrapper_name, (SymbolicWrapper, ), symbolic_args)
|
||||||
return SYMBOLICS_REGISTER.register_module(wrapper_name)(wrapper)
|
SYMBOLICS_REGISTER.register_module(wrapper_name)(wrapper)
|
||||||
|
return symbolic_impl
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
|
@ -9,6 +9,8 @@ def test_function_rewriter():
|
||||||
x = torch.tensor([1, 2, 3, 4, 5])
|
x = torch.tensor([1, 2, 3, 4, 5])
|
||||||
y = torch.tensor([2, 4, 6, 8, 10])
|
y = torch.tensor([2, 4, 6, 8, 10])
|
||||||
|
|
||||||
|
@FUNCTION_REWRITERS.register_rewriter(
|
||||||
|
func_name='torch.mul', backend='tensorrt')
|
||||||
@FUNCTION_REWRITERS.register_rewriter(
|
@FUNCTION_REWRITERS.register_rewriter(
|
||||||
func_name='torch.add', backend='tensorrt')
|
func_name='torch.add', backend='tensorrt')
|
||||||
def sub_func(rewriter, x, y):
|
def sub_func(rewriter, x, y):
|
||||||
|
@ -21,6 +23,9 @@ def test_function_rewriter():
|
||||||
result = torch.add(x, y)
|
result = torch.add(x, y)
|
||||||
# replace add with sub
|
# replace add with sub
|
||||||
torch.testing.assert_allclose(result, x - y)
|
torch.testing.assert_allclose(result, x - y)
|
||||||
|
result = torch.mul(x, y)
|
||||||
|
# replace add with sub
|
||||||
|
torch.testing.assert_allclose(result, x - y)
|
||||||
|
|
||||||
result = torch.add(x, y)
|
result = torch.add(x, y)
|
||||||
# recovery origin function
|
# recovery origin function
|
||||||
|
@ -120,6 +125,7 @@ def test_symbolic_register():
|
||||||
mmdeploy.TestFunc = TestFunc
|
mmdeploy.TestFunc = TestFunc
|
||||||
test_func = mmdeploy.TestFunc.apply
|
test_func = mmdeploy.TestFunc.apply
|
||||||
|
|
||||||
|
@SYMBOLICS_REGISTER.register_symbolic('mmdeploy.TestFunc', backend='ncnn')
|
||||||
@SYMBOLICS_REGISTER.register_symbolic('mmdeploy.TestFunc')
|
@SYMBOLICS_REGISTER.register_symbolic('mmdeploy.TestFunc')
|
||||||
def symbolic_testfunc_default(symbolic_wrapper, g, x, val):
|
def symbolic_testfunc_default(symbolic_wrapper, g, x, val):
|
||||||
assert hasattr(symbolic_wrapper, 'cfg')
|
assert hasattr(symbolic_wrapper, 'cfg')
|
||||||
|
@ -161,8 +167,18 @@ def test_symbolic_register():
|
||||||
assert nodes[1].op_type == 'cummax_default'
|
assert nodes[1].op_type == 'cummax_default'
|
||||||
assert nodes[1].domain == 'mmcv'
|
assert nodes[1].domain == 'mmcv'
|
||||||
|
|
||||||
|
# ncnn
|
||||||
|
register_extra_symbolics(cfg=cfg, backend='ncnn', opset=11)
|
||||||
|
torch.onnx.export(model, x, output_file, opset_version=11)
|
||||||
|
onnx_model = onnx.load(output_file)
|
||||||
|
os.remove(output_file)
|
||||||
|
nodes = onnx_model.graph.node
|
||||||
|
assert nodes[0].op_type == 'symbolic_testfunc_default'
|
||||||
|
assert nodes[0].domain == 'mmcv'
|
||||||
|
assert nodes[1].op_type == 'cummax_default'
|
||||||
|
assert nodes[1].domain == 'mmcv'
|
||||||
|
|
||||||
# default
|
# default
|
||||||
cfg = dict()
|
|
||||||
register_extra_symbolics(cfg=cfg, backend='tensorrt', opset=11)
|
register_extra_symbolics(cfg=cfg, backend='tensorrt', opset=11)
|
||||||
torch.onnx.export(model, x, output_file, opset_version=11)
|
torch.onnx.export(model, x, output_file, opset_version=11)
|
||||||
onnx_model = onnx.load(output_file)
|
onnx_model = onnx.load(output_file)
|
||||||
|
|
Loading…
Reference in New Issue