diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..77ca0d7c8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,121 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +data/ +data +.vscode +.idea +.DS_Store + +# custom +*.pkl +*.pkl.json +*.log.json +work_dirs/ + +# Pytorch +*.pth +*.py~ +*.sh~ diff --git a/.isort.cfg b/.isort.cfg new file mode 100644 index 000000000..7a34bf769 --- /dev/null +++ b/.isort.cfg @@ -0,0 +1,2 @@ +[settings] +known_third_party = diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..695ed1922 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,47 @@ +exclude: ^tests/data/ +repos: + - repo: https://gitlab.com/pycqa/flake8.git + rev: 3.8.3 + hooks: + - id: flake8 + - repo: https://github.com/asottile/seed-isort-config + rev: v2.2.0 + hooks: + - id: seed-isort-config + - repo: https://github.com/timothycrosley/isort + rev: 4.3.21 + hooks: + - id: isort + - repo: https://github.com/pre-commit/mirrors-yapf + rev: v0.30.0 + hooks: + - id: yapf + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v3.1.0 + hooks: + - id: trailing-whitespace + - id: check-yaml + - id: end-of-file-fixer + - id: requirements-txt-fixer + - id: double-quote-string-fixer + - id: check-merge-conflict + - id: fix-encoding-pragma + args: ["--remove"] + - id: mixed-line-ending + args: ["--fix=lf"] + - repo: https://github.com/jumanjihouse/pre-commit-hooks + rev: 2.1.4 + hooks: + - id: markdownlint + args: + [ + "-r", + "~MD002,~MD013,~MD029,~MD033,~MD034", + "-t", + "allow_different_nesting", + ] + - repo: https://github.com/myint/docformatter + rev: v1.3.1 + hooks: + - id: docformatter + args: ["--in-place", "--wrap-descriptions", "79"] diff --git a/README.md b/README.md new file mode 100644 index 000000000..aeba4aab3 --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +# MMDeployment + +WIP diff --git a/mmdeploy/__init__.py b/mmdeploy/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mmdeploy/utils/__init__.py b/mmdeploy/utils/__init__.py new file mode 100644 index 000000000..7135100a0 --- /dev/null +++ b/mmdeploy/utils/__init__.py @@ -0,0 +1,8 @@ +from .function_rewriter import FUNCTION_REWRITERS, RewriterContext +from .module_rewriter import MODULE_REWRITERS, patch_model +from .symbolic_register import SYMBOLICS_REGISTER, register_extra_symbolics + +__all__ = [ + 'RewriterContext', 'FUNCTION_REWRITERS', 'MODULE_REWRITERS', 'patch_model', + 'SYMBOLICS_REGISTER', 'register_extra_symbolics' +] diff --git a/mmdeploy/utils/function_rewriter.py b/mmdeploy/utils/function_rewriter.py new file mode 100644 index 000000000..0e8f305ef --- /dev/null +++ b/mmdeploy/utils/function_rewriter.py @@ -0,0 +1,121 @@ +import logging + +from mmcv.utils import Registry + +from .register_utils import eval_with_import + + +# builder of register +def build_caller(func_name, backend, cfg, registry, **kwargs): + # func_caller = registry.get(func_name + '@' + backend) + func_caller = registry.module_dict[func_name + '@' + backend] + assert func_caller is not None, '{} with {} not exist.'.format( + func_name, backend) + return func_caller(cfg, **kwargs) + + +# create register +FUNCTION_REWRITERS = Registry('func_rewriters', build_func=build_caller) + + +# caller wrapper +class FuncCaller(object): + func_name = None + backend = None + func = None + + def __init__(self, cfg, **kwargs): + self.cfg = cfg + [setattr(self, k, v) for k, v in kwargs.items()] + + def __call__(self, *args, **kwargs): + return self.func(*args, **kwargs) + + +# caller decorator +def register_rewriter(func_name, backend='default', **kwargs): + + def wrap(func): + func_args = dict(func_name=func_name, backend=backend, func=func) + func_args.update(kwargs) + func_caller = type(func_name + '@' + backend, (FuncCaller, ), + func_args) + return FUNCTION_REWRITERS.register_module()(func_caller) + + return wrap + + +FUNCTION_REWRITERS.register_rewriter = register_rewriter + + +def apply_rewriter(regist_func): + + def wrapper(*args, **kwargs): + return regist_func(*args, **kwargs) + + return wrapper + + +class RewriterHook(object): + + def __init__(self, regist_name, cfg, **kwargs): + func_name, backend = regist_name.split('@') + self.func_name = func_name + self.backend = backend + self.regist_func = FUNCTION_REWRITERS.build( + func_name, backend=self.backend, cfg=cfg, **kwargs) + try: + self.origin_func = eval_with_import(self.func_name) + except Exception: + self.origin_func = None + logging.warning( + 'Can not found {}, function rewrite will not be applied'. + format(self.func_name)) + + def _set_func(self, rewrite_func): + if self.origin_func is not None: + # import necessary module + split_path = self.func_name.split('.') + for i in range(len(split_path), 0, -1): + try: + exec('import {}'.format('.'.join(split_path[:i]))) + break + except Exception: + continue + # assign function + exec('{} = rewrite_func'.format(self.func_name)) + + def __enter__(self): + self._set_func(apply_rewriter(self.regist_func)) + + def __exit__(self, type, val, tb): + self._set_func(self.origin_func) + + +class RewriterContext(object): + + def __init__(self, cfg, backend='default', **kwargs): + self.cfg = cfg + func_backend_dict = {} + for regist_name in FUNCTION_REWRITERS.module_dict: + regist_func, regist_backend = regist_name.split('@') + # only build `backend` or `default` + if regist_backend not in [backend, 'default']: + continue + if regist_func not in func_backend_dict or func_backend_dict[ + regist_func] == 'default': + func_backend_dict[regist_func] = regist_backend + + self.hooks = [ + RewriterHook(k + '@' + v, cfg, **kwargs) + for k, v in func_backend_dict.items() + ] + + def __enter__(self): + for hook in self.hooks: + hook.__enter__() + return self + + def __exit__(self, type, val, tb): + for hook in self.hooks: + hook.__exit__(type, val, tb) diff --git a/mmdeploy/utils/module_rewriter.py b/mmdeploy/utils/module_rewriter.py new file mode 100644 index 000000000..72c507a58 --- /dev/null +++ b/mmdeploy/utils/module_rewriter.py @@ -0,0 +1,69 @@ +from mmcv.utils import Registry +from .register_utils import eval_with_import +from copy import deepcopy + + +def build_rewrite_module(module, cfg, backend, registry, **kwargs): + + backend_dict = registry.module_eval_dict.get(type(module), None) + if backend_dict is None: + return module + + RewriteModuleClass = None + for backend in [backend, 'default']: + RewriteModuleClass = backend_dict.get(backend, None) + if RewriteModuleClass is not None: + break + + if RewriteModuleClass is None: + return module + + return RewriteModuleClass(module, cfg, **kwargs) + + +class RewriteModuleRegistry(Registry): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._module_eval_dict = dict() + + def register_rewrite_module(self, + module_type, + backend='default', + **kwargs): + register_name = module_type + '@' + backend + return self.register_module(register_name) + + @property + def module_eval_dict(self): + return self._module_eval_dict + + def _register_module(self, module_class, module_name=None, force=False): + super()._register_module(module_class, module_name, force) + + module_type, backend = module_name.split('@') + module_type_cls = eval_with_import(module_type) + if module_type_cls not in self._module_eval_dict: + self._module_eval_dict[module_type_cls] = dict() + + assert (backend not in self._module_eval_dict[module_type_cls] + ), '{} with backend:{} has already been registed.'.format( + module_type, backend) + self._module_eval_dict[module_type_cls][backend] = self.module_dict[ + module_name] + + +# create register +MODULE_REWRITERS = RewriteModuleRegistry( + 'module_rewriters', build_func=build_rewrite_module, scope='.') + + +def patch_model(model, cfg, backend='default', **kwargs): + + def _patch_impl(model, cfg, **kwargs): + for name, module in model.named_children(): + model._modules[name] = _patch_impl(module, cfg, **kwargs) + return MODULE_REWRITERS.build( + module=model, cfg=cfg, backend=backend, **kwargs) + + return _patch_impl(deepcopy(model), cfg, **kwargs) diff --git a/mmdeploy/utils/register_utils.py b/mmdeploy/utils/register_utils.py new file mode 100644 index 000000000..a711d186c --- /dev/null +++ b/mmdeploy/utils/register_utils.py @@ -0,0 +1,9 @@ +def eval_with_import(path): + split_path = path.split('.') + for i in range(len(split_path), 0, -1): + try: + exec('import {}'.format('.'.join(split_path[:i]))) + break + except Exception: + continue + return eval(path) diff --git a/mmdeploy/utils/symbolic_register.py b/mmdeploy/utils/symbolic_register.py new file mode 100644 index 000000000..11d47828d --- /dev/null +++ b/mmdeploy/utils/symbolic_register.py @@ -0,0 +1,86 @@ +import logging + +from mmcv.utils import Registry +from torch.autograd import Function +from torch.onnx.symbolic_helper import parse_args +from torch.onnx.symbolic_registry import register_op + +from .register_utils import eval_with_import + + +def set_symbolic(cfg, registry, backend='default', opset=11, **kwargs): + + # find valid symbolic + valid_symbolic_dict = {} + for module_name, symbolic_impl in registry.module_dict.items(): + func_name, symbolic_backend, is_pytorch = module_name.split('@') + if symbolic_backend == backend or (symbolic_backend == 'default' + and func_name + not in valid_symbolic_dict): + valid_symbolic_dict[func_name] = (symbolic_impl, + is_pytorch == 'True') + + # build symbolic + for func_name in valid_symbolic_dict: + symbolic_impl, is_pytorch = valid_symbolic_dict[func_name] + arg_descriptors = symbolic_impl.arg_descriptors + symbolic_impl = symbolic_impl(cfg=cfg, **kwargs) + if arg_descriptors is not None and len(arg_descriptors) > 0: + symbolic_impl = parse_args(*arg_descriptors)(symbolic_impl) + if is_pytorch: + register_op(func_name, symbolic_impl, '', opset) + else: + try: + func = eval_with_import(func_name) + assert issubclass( + func, + Function), '{} is not an torch.autograd.Function'.format( + func_name) + func.symbolic = symbolic_impl + except Exception: + logging.warning( + 'Can not add symbolic for `{}`'.format(func_name)) + + +SYMBOLICS_REGISTER = Registry('symbolics', build_func=set_symbolic, scope=None) + + +class SymbolicWrapper: + func_name = '' + backend = '' + is_pytorch = False + symbolic = None + arg_descriptors = None + + def __init__(self, cfg, **kwargs): + self.cfg = cfg + + def __call__(self, *args, **kwargs): + return self.symbolic(*args, **kwargs) + + +def register_symbolic(func_name, + backend='default', + is_pytorch=False, + arg_descriptors=None, + **kwargs): + + def wrapper(symbolic_impl): + symbolic_args = dict( + func_name=func_name, + backend=backend, + symbolic=symbolic_impl, + arg_descriptors=arg_descriptors) + symbolic_args.update(kwargs) + wrapper_name = '@'.join([func_name, backend, str(is_pytorch)]) + wrapper = type(wrapper_name, (SymbolicWrapper, ), symbolic_args) + return SYMBOLICS_REGISTER.register_module(wrapper_name)(wrapper) + + return wrapper + + +SYMBOLICS_REGISTER.register_symbolic = register_symbolic + + +def register_extra_symbolics(cfg, backend='default', opset=11): + SYMBOLICS_REGISTER.build(cfg=cfg, backend=backend) diff --git a/setup.py b/setup.py new file mode 100644 index 000000000..2068abdc8 --- /dev/null +++ b/setup.py @@ -0,0 +1,4 @@ +from setuptools import setup + +if __name__ == '__main__': + setup(name='mmdeploy', version=0.1) diff --git a/tests/test_utils/test_register.py b/tests/test_utils/test_register.py new file mode 100644 index 000000000..dda60599e --- /dev/null +++ b/tests/test_utils/test_register.py @@ -0,0 +1,158 @@ +import torch +import os + + +def test_function_rewriter(): + from mmdeploy.utils import FUNCTION_REWRITERS, RewriterContext + + x = torch.tensor([1, 2, 3, 4, 5]) + y = torch.tensor([2, 4, 6, 8, 10]) + + @FUNCTION_REWRITERS.register_rewriter( + func_name='torch.add', backend='tensorrt') + def sub_func(rewriter, x, y): + return x - y + + cfg = dict() + with RewriterContext(cfg, backend='tensorrt'): + result = torch.add(x, y) + # replace add with sub + torch.testing.assert_allclose(result, x - y) + + result = torch.add(x, y) + # recovery origin function + torch.testing.assert_allclose(result, x + y) + + with RewriterContext(cfg): + result = torch.add(x, y) + # replace should not happen with wrong backend + torch.testing.assert_allclose(result, x + y) + + @FUNCTION_REWRITERS.register_rewriter( + func_name='torch.Tensor.add', backend='default') + def mul_func_class(rewriter, x, y): + return x * y + + with RewriterContext(cfg, backend='tensorrt'): + result = x.add(y) + # replace add with multi + torch.testing.assert_allclose(result, x * y) + + result = x.add(y) + # recovery origin function + torch.testing.assert_allclose(result, x + y) + + with RewriterContext(cfg): + result = x.add(y) + # replace add with multi + torch.testing.assert_allclose(result, x * y) + + +def test_module_rewriter(): + from mmdeploy.utils import MODULE_REWRITERS, patch_model + from torchvision.models.resnet import resnet50 + + @MODULE_REWRITERS.register_rewrite_module( + module_type='torchvision.models.resnet.Bottleneck', backend='tensorrt') + class BottleneckWrapper(torch.nn.Module): + + def __init__(self, module, cfg, **kwargs): + super().__init__() + self.module = module + + def forward(self, *args, **kwargs): + return self.module(*args, **kwargs) * 2 + + x = torch.rand(1, 64, 32, 32) + model = resnet50().eval() + bottle_neck = model.layer1[0] + result = bottle_neck(x) + + # rewrite module + cfg = dict() + + rewrited_model = patch_model(model, cfg=cfg, backend='tensorrt') + rewrited_bottle_nect = rewrited_model.layer1[0] + rewrited_result = rewrited_bottle_nect(x) + torch.testing.assert_allclose(rewrited_result, result * 2) + + # wrong backend should not be rewrited + + rewrited_model = patch_model(model, cfg=cfg) + rewrited_bottle_nect = rewrited_model.layer1[0] + rewrited_result = rewrited_bottle_nect(x) + torch.testing.assert_allclose(rewrited_result, result) + + +def test_symbolic_register(): + import mmdeploy + from mmdeploy.utils import SYMBOLICS_REGISTER, register_extra_symbolics + from torch.autograd import Function + import onnx + + class TestFunc(Function): + + @staticmethod + def symbolic(g, x, val): + return g.op('mmcv::symbolic_old', x, val_i=val) + + @staticmethod + def forward(ctx, x, val): + return x + val + + # put TestFunc in an module so we can found it + # could be any module + mmdeploy.TestFunc = TestFunc + test_func = mmdeploy.TestFunc.apply + + @SYMBOLICS_REGISTER.register_symbolic('mmdeploy.TestFunc') + def symbolic_testfunc_default(symbolic_wrapper, g, x, val): + return g.op('mmcv::symbolic_testfunc_default', x, val_i=val) + + @SYMBOLICS_REGISTER.register_symbolic( + 'mmdeploy.TestFunc', backend='tensorrt') + def symbolic_testfunc_tensorrt(symbolic_wrapper, g, x, val): + return g.op('mmcv::symbolic_testfunc_tensorrt', x, val_i=val) + + @SYMBOLICS_REGISTER.register_symbolic( + 'cummax', is_pytorch=True, arg_descriptors=['v', 'i']) + def symbolic_cummax(symbolic_wrapper, g, input, dim): + return g.op('mmcv::cummax_default', input, dim_i=dim, outputs=2) + + class TestModel(torch.nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.cummax(test_func(x, 5), dim=1) + + model = TestModel().eval() + + # dummy input + x = torch.rand(2, 3, 4) + output_file = 'demo.onnx' + + # default + cfg = dict() + register_extra_symbolics(cfg=cfg, opset=11) + torch.onnx.export(model, x, output_file, opset_version=11) + onnx_model = onnx.load(output_file) + os.remove(output_file) + nodes = onnx_model.graph.node + assert nodes[0].op_type == 'symbolic_testfunc_default' + assert nodes[0].domain == 'mmcv' + assert nodes[1].op_type == 'cummax_default' + assert nodes[1].domain == 'mmcv' + + # default + cfg = dict() + register_extra_symbolics(cfg=cfg, backend='tensorrt', opset=11) + torch.onnx.export(model, x, output_file, opset_version=11) + onnx_model = onnx.load(output_file) + os.remove(output_file) + nodes = onnx_model.graph.node + assert nodes[0].op_type == 'symbolic_testfunc_tensorrt' + assert nodes[0].domain == 'mmcv' + assert nodes[1].op_type == 'cummax_default' + assert nodes[1].domain == 'mmcv'