add register

pull/12/head
grimoire 2021-06-11 13:26:05 +08:00
parent b6d8d48d7c
commit f90ebf8c2c
12 changed files with 628 additions and 0 deletions

.gitignore vendored 100644

@@ -0,0 +1,121 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
data/
data
.vscode
.idea
.DS_Store
# custom
*.pkl
*.pkl.json
*.log.json
work_dirs/
# Pytorch
*.pth
*.py~
*.sh~

.isort.cfg 100644

@@ -0,0 +1,2 @@
[settings]
known_third_party =


@@ -0,0 +1,47 @@
exclude: ^tests/data/
repos:
  - repo: https://gitlab.com/pycqa/flake8.git
    rev: 3.8.3
    hooks:
      - id: flake8
  - repo: https://github.com/asottile/seed-isort-config
    rev: v2.2.0
    hooks:
      - id: seed-isort-config
  - repo: https://github.com/timothycrosley/isort
    rev: 4.3.21
    hooks:
      - id: isort
  - repo: https://github.com/pre-commit/mirrors-yapf
    rev: v0.30.0
    hooks:
      - id: yapf
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v3.1.0
    hooks:
      - id: trailing-whitespace
      - id: check-yaml
      - id: end-of-file-fixer
      - id: requirements-txt-fixer
      - id: double-quote-string-fixer
      - id: check-merge-conflict
      - id: fix-encoding-pragma
        args: ["--remove"]
      - id: mixed-line-ending
        args: ["--fix=lf"]
  - repo: https://github.com/jumanjihouse/pre-commit-hooks
    rev: 2.1.4
    hooks:
      - id: markdownlint
        args:
          [
            "-r",
            "~MD002,~MD013,~MD029,~MD033,~MD034",
            "-t",
            "allow_different_nesting",
          ]
  - repo: https://github.com/myint/docformatter
    rev: v1.3.1
    hooks:
      - id: docformatter
        args: ["--in-place", "--wrap-descriptions", "79"]

README.md 100644

@@ -0,0 +1,3 @@
# MMDeployment
WIP


@@ -0,0 +1,8 @@
from .function_rewriter import FUNCTION_REWRITERS, RewriterContext
from .module_rewriter import MODULE_REWRITERS, patch_model
from .symbolic_register import SYMBOLICS_REGISTER, register_extra_symbolics

__all__ = [
    'RewriterContext', 'FUNCTION_REWRITERS', 'MODULE_REWRITERS', 'patch_model',
    'SYMBOLICS_REGISTER', 'register_extra_symbolics'
]


@@ -0,0 +1,121 @@
import logging

from mmcv.utils import Registry

from .register_utils import eval_with_import


# build function of the rewriter registry
def build_caller(func_name, backend, cfg, registry, **kwargs):
    # func_caller = registry.get(func_name + '@' + backend)
    func_caller = registry.module_dict[func_name + '@' + backend]
    assert func_caller is not None, '{} with backend {} does not exist.'.format(
        func_name, backend)
    return func_caller(cfg, **kwargs)


# create the registry
FUNCTION_REWRITERS = Registry('func_rewriters', build_func=build_caller)


# wrapper around a registered rewrite function
class FuncCaller(object):
    func_name = None
    backend = None
    func = None

    def __init__(self, cfg, **kwargs):
        self.cfg = cfg
        for k, v in kwargs.items():
            setattr(self, k, v)

    def __call__(self, *args, **kwargs):
        return self.func(*args, **kwargs)


# decorator that registers a rewrite function for a backend
def register_rewriter(func_name, backend='default', **kwargs):

    def wrap(func):
        func_args = dict(func_name=func_name, backend=backend, func=func)
        func_args.update(kwargs)
        func_caller = type(func_name + '@' + backend, (FuncCaller, ),
                           func_args)
        return FUNCTION_REWRITERS.register_module()(func_caller)

    return wrap


FUNCTION_REWRITERS.register_rewriter = register_rewriter


def apply_rewriter(regist_func):

    def wrapper(*args, **kwargs):
        return regist_func(*args, **kwargs)

    return wrapper


class RewriterHook(object):

    def __init__(self, regist_name, cfg, **kwargs):
        func_name, backend = regist_name.split('@')
        self.func_name = func_name
        self.backend = backend
        self.regist_func = FUNCTION_REWRITERS.build(
            func_name, backend=self.backend, cfg=cfg, **kwargs)
        try:
            self.origin_func = eval_with_import(self.func_name)
        except Exception:
            self.origin_func = None
            logging.warning(
                'Cannot find {}, the function rewrite will not be applied.'.
                format(self.func_name))

    def _set_func(self, rewrite_func):
        if self.origin_func is not None:
            # import the module that owns the target function
            split_path = self.func_name.split('.')
            for i in range(len(split_path), 0, -1):
                try:
                    exec('import {}'.format('.'.join(split_path[:i])))
                    break
                except Exception:
                    continue
            # assign the function
            exec('{} = rewrite_func'.format(self.func_name))

    def __enter__(self):
        self._set_func(apply_rewriter(self.regist_func))

    def __exit__(self, type, val, tb):
        self._set_func(self.origin_func)


class RewriterContext(object):

    def __init__(self, cfg, backend='default', **kwargs):
        self.cfg = cfg
        func_backend_dict = {}
        for regist_name in FUNCTION_REWRITERS.module_dict:
            regist_func, regist_backend = regist_name.split('@')
            # only build rewriters registered for `backend` or 'default'
            if regist_backend not in [backend, 'default']:
                continue
            # a backend-specific rewrite overrides the 'default' one
            if regist_func not in func_backend_dict or func_backend_dict[
                    regist_func] == 'default':
                func_backend_dict[regist_func] = regist_backend
        self.hooks = [
            RewriterHook(k + '@' + v, cfg, **kwargs)
            for k, v in func_backend_dict.items()
        ]

    def __enter__(self):
        for hook in self.hooks:
            hook.__enter__()
        return self

    def __exit__(self, type, val, tb):
        for hook in self.hooks:
            hook.__exit__(type, val, tb)
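
For reference, a minimal usage sketch of the function rewriter (illustrative only, not part of this commit; torch.clamp and the 'tensorrt' backend name are arbitrary examples):

# illustrative sketch, not part of this commit
import torch

from mmdeploy.utils import FUNCTION_REWRITERS, RewriterContext


@FUNCTION_REWRITERS.register_rewriter(
    func_name='torch.clamp', backend='tensorrt')
def clamp_rewriter(rewriter, input, min=None, max=None):
    # `rewriter` is the generated FuncCaller instance; rewriter.cfg holds
    # the config passed to RewriterContext
    return input.clamp(min, max)


cfg = dict()
with RewriterContext(cfg, backend='tensorrt'):
    # torch.clamp is temporarily replaced by clamp_rewriter here
    y = torch.clamp(torch.rand(3), 0.0, 0.5)
# on exit the original torch.clamp is restored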


@@ -0,0 +1,69 @@
from copy import deepcopy

from mmcv.utils import Registry

from .register_utils import eval_with_import


def build_rewrite_module(module, cfg, backend, registry, **kwargs):
    backend_dict = registry.module_eval_dict.get(type(module), None)
    if backend_dict is None:
        return module
    # prefer the requested backend, fall back to 'default'
    RewriteModuleClass = None
    for backend in [backend, 'default']:
        RewriteModuleClass = backend_dict.get(backend, None)
        if RewriteModuleClass is not None:
            break
    if RewriteModuleClass is None:
        return module
    return RewriteModuleClass(module, cfg, **kwargs)


class RewriteModuleRegistry(Registry):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._module_eval_dict = dict()

    def register_rewrite_module(self,
                                module_type,
                                backend='default',
                                **kwargs):
        register_name = module_type + '@' + backend
        return self.register_module(register_name)

    @property
    def module_eval_dict(self):
        return self._module_eval_dict

    def _register_module(self, module_class, module_name=None, force=False):
        super()._register_module(module_class, module_name, force)
        module_type, backend = module_name.split('@')
        module_type_cls = eval_with_import(module_type)
        if module_type_cls not in self._module_eval_dict:
            self._module_eval_dict[module_type_cls] = dict()
        assert (backend not in self._module_eval_dict[module_type_cls]
                ), '{} with backend {} has already been registered.'.format(
                    module_type, backend)
        self._module_eval_dict[module_type_cls][backend] = self.module_dict[
            module_name]


# create the registry
MODULE_REWRITERS = RewriteModuleRegistry(
    'module_rewriters', build_func=build_rewrite_module, scope='.')


def patch_model(model, cfg, backend='default', **kwargs):

    def _patch_impl(model, cfg, **kwargs):
        for name, module in model.named_children():
            model._modules[name] = _patch_impl(module, cfg, **kwargs)
        return MODULE_REWRITERS.build(
            module=model, cfg=cfg, backend=backend, **kwargs)

    return _patch_impl(deepcopy(model), cfg, **kwargs)
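
For reference, a minimal usage sketch of the module rewriter (illustrative only, not part of this commit; torch.nn.ReLU and the 'tensorrt' backend name are arbitrary examples):

# illustrative sketch, not part of this commit
import torch

from mmdeploy.utils import MODULE_REWRITERS, patch_model


@MODULE_REWRITERS.register_rewrite_module(
    module_type='torch.nn.ReLU', backend='tensorrt')
class ReLUWrapper(torch.nn.Module):

    def __init__(self, module, cfg, **kwargs):
        super().__init__()
        self.module = module

    def forward(self, x):
        # delegate to the wrapped module; a real rewrite would change this
        return self.module(x)


model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
patched = patch_model(model, cfg=dict(), backend='tensorrt')
assert isinstance(patched[1], ReLUWrapper)  # the ReLU child was wrapped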


@@ -0,0 +1,9 @@
def eval_with_import(path):
    """Import the longest importable prefix of ``path``, then evaluate and
    return the object that ``path`` refers to."""
    split_path = path.split('.')
    for i in range(len(split_path), 0, -1):
        try:
            exec('import {}'.format('.'.join(split_path[:i])))
            break
        except Exception:
            continue
    return eval(path)
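
For reference, a small sketch of what eval_with_import returns (illustrative only, not part of this commit; the import path of the helper is assumed from the package layout):

# illustrative sketch, not part of this commit; the module path is an assumption
from mmdeploy.utils.register_utils import eval_with_import

conv_cls = eval_with_import('torch.nn.Conv2d')  # imports torch.nn, returns the class
add_func = eval_with_import('torch.add')  # imports torch, returns the function
assert conv_cls.__name__ == 'Conv2d'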


@@ -0,0 +1,86 @@
import logging

from mmcv.utils import Registry
from torch.autograd import Function
from torch.onnx.symbolic_helper import parse_args
from torch.onnx.symbolic_registry import register_op

from .register_utils import eval_with_import


def set_symbolic(cfg, registry, backend='default', opset=11, **kwargs):
    # find valid symbolics
    valid_symbolic_dict = {}
    for module_name, symbolic_impl in registry.module_dict.items():
        func_name, symbolic_backend, is_pytorch = module_name.split('@')
        if symbolic_backend == backend or (symbolic_backend == 'default'
                                           and func_name
                                           not in valid_symbolic_dict):
            valid_symbolic_dict[func_name] = (symbolic_impl,
                                              is_pytorch == 'True')

    # build symbolics
    for func_name in valid_symbolic_dict:
        symbolic_impl, is_pytorch = valid_symbolic_dict[func_name]
        arg_descriptors = symbolic_impl.arg_descriptors
        symbolic_impl = symbolic_impl(cfg=cfg, **kwargs)
        if arg_descriptors is not None and len(arg_descriptors) > 0:
            symbolic_impl = parse_args(*arg_descriptors)(symbolic_impl)
        if is_pytorch:
            register_op(func_name, symbolic_impl, '', opset)
        else:
            try:
                func = eval_with_import(func_name)
                assert issubclass(
                    func,
                    Function), '{} is not a torch.autograd.Function'.format(
                        func_name)
                func.symbolic = symbolic_impl
            except Exception:
                logging.warning(
                    'Cannot add symbolic for `{}`'.format(func_name))


SYMBOLICS_REGISTER = Registry('symbolics', build_func=set_symbolic, scope=None)


class SymbolicWrapper:
    func_name = ''
    backend = ''
    is_pytorch = False
    symbolic = None
    arg_descriptors = None

    def __init__(self, cfg, **kwargs):
        self.cfg = cfg

    def __call__(self, *args, **kwargs):
        return self.symbolic(*args, **kwargs)


# decorator that registers a symbolic function for a backend
def register_symbolic(func_name,
                      backend='default',
                      is_pytorch=False,
                      arg_descriptors=None,
                      **kwargs):

    def wrapper(symbolic_impl):
        symbolic_args = dict(
            func_name=func_name,
            backend=backend,
            symbolic=symbolic_impl,
            arg_descriptors=arg_descriptors)
        symbolic_args.update(kwargs)
        wrapper_name = '@'.join([func_name, backend, str(is_pytorch)])
        wrapper = type(wrapper_name, (SymbolicWrapper, ), symbolic_args)
        return SYMBOLICS_REGISTER.register_module(wrapper_name)(wrapper)

    return wrapper


SYMBOLICS_REGISTER.register_symbolic = register_symbolic


def register_extra_symbolics(cfg, backend='default', opset=11):
    SYMBOLICS_REGISTER.build(cfg=cfg, backend=backend, opset=opset)
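
For reference, a minimal usage sketch of the symbolic register (illustrative only, not part of this commit; the Function, op name and 'mmcv' domain are arbitrary examples, and the snippet assumes it runs as a script so the class lives in __main__):

# illustrative sketch, not part of this commit
import torch
from torch.autograd import Function

from mmdeploy.utils import SYMBOLICS_REGISTER, register_extra_symbolics


class Scale2x(Function):

    @staticmethod
    def forward(ctx, x):
        return x * 2


@SYMBOLICS_REGISTER.register_symbolic('__main__.Scale2x')
def scale2x_symbolic(symbolic_wrapper, g, x):
    # emit a custom op node instead of tracing forward()
    return g.op('mmcv::scale2x', x)


class Demo(torch.nn.Module):

    def forward(self, x):
        return Scale2x.apply(x)


register_extra_symbolics(cfg=dict(), backend='default', opset=11)
torch.onnx.export(Demo(), torch.rand(1, 3), 'scale2x.onnx', opset_version=11)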

setup.py 100644

@@ -0,0 +1,4 @@
from setuptools import setup

if __name__ == '__main__':
    setup(name='mmdeploy', version='0.1')
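
With this in place, the package can be installed for development with pip install -e .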


@@ -0,0 +1,158 @@
import os

import torch


def test_function_rewriter():
    from mmdeploy.utils import FUNCTION_REWRITERS, RewriterContext

    x = torch.tensor([1, 2, 3, 4, 5])
    y = torch.tensor([2, 4, 6, 8, 10])

    @FUNCTION_REWRITERS.register_rewriter(
        func_name='torch.add', backend='tensorrt')
    def sub_func(rewriter, x, y):
        return x - y

    cfg = dict()
    with RewriterContext(cfg, backend='tensorrt'):
        result = torch.add(x, y)
        # add is replaced with sub inside the context
        torch.testing.assert_allclose(result, x - y)

    result = torch.add(x, y)
    # the original function is restored outside the context
    torch.testing.assert_allclose(result, x + y)

    with RewriterContext(cfg):
        result = torch.add(x, y)
        # the rewrite should not apply for a non-matching backend
        torch.testing.assert_allclose(result, x + y)

    @FUNCTION_REWRITERS.register_rewriter(
        func_name='torch.Tensor.add', backend='default')
    def mul_func_class(rewriter, x, y):
        return x * y

    with RewriterContext(cfg, backend='tensorrt'):
        result = x.add(y)
        # add is replaced with mul inside the context
        torch.testing.assert_allclose(result, x * y)

    result = x.add(y)
    # the original function is restored outside the context
    torch.testing.assert_allclose(result, x + y)

    with RewriterContext(cfg):
        result = x.add(y)
        # a 'default' rewrite applies to any backend
        torch.testing.assert_allclose(result, x * y)


def test_module_rewriter():
    from mmdeploy.utils import MODULE_REWRITERS, patch_model
    from torchvision.models.resnet import resnet50

    @MODULE_REWRITERS.register_rewrite_module(
        module_type='torchvision.models.resnet.Bottleneck', backend='tensorrt')
    class BottleneckWrapper(torch.nn.Module):

        def __init__(self, module, cfg, **kwargs):
            super().__init__()
            self.module = module

        def forward(self, *args, **kwargs):
            return self.module(*args, **kwargs) * 2

    x = torch.rand(1, 64, 32, 32)
    model = resnet50().eval()
    bottle_neck = model.layer1[0]
    result = bottle_neck(x)

    # rewrite the module for the matching backend
    cfg = dict()
    rewritten_model = patch_model(model, cfg=cfg, backend='tensorrt')
    rewritten_bottleneck = rewritten_model.layer1[0]
    rewritten_result = rewritten_bottleneck(x)
    torch.testing.assert_allclose(rewritten_result, result * 2)

    # a non-matching backend should not be rewritten
    rewritten_model = patch_model(model, cfg=cfg)
    rewritten_bottleneck = rewritten_model.layer1[0]
    rewritten_result = rewritten_bottleneck(x)
    torch.testing.assert_allclose(rewritten_result, result)


def test_symbolic_register():
    import mmdeploy
    import onnx
    from mmdeploy.utils import SYMBOLICS_REGISTER, register_extra_symbolics
    from torch.autograd import Function

    class TestFunc(Function):

        @staticmethod
        def symbolic(g, x, val):
            return g.op('mmcv::symbolic_old', x, val_i=val)

        @staticmethod
        def forward(ctx, x, val):
            return x + val

    # attach TestFunc to a module so that eval_with_import can find it;
    # any importable module would do
    mmdeploy.TestFunc = TestFunc
    test_func = mmdeploy.TestFunc.apply

    @SYMBOLICS_REGISTER.register_symbolic('mmdeploy.TestFunc')
    def symbolic_testfunc_default(symbolic_wrapper, g, x, val):
        return g.op('mmcv::symbolic_testfunc_default', x, val_i=val)

    @SYMBOLICS_REGISTER.register_symbolic(
        'mmdeploy.TestFunc', backend='tensorrt')
    def symbolic_testfunc_tensorrt(symbolic_wrapper, g, x, val):
        return g.op('mmcv::symbolic_testfunc_tensorrt', x, val_i=val)

    @SYMBOLICS_REGISTER.register_symbolic(
        'cummax', is_pytorch=True, arg_descriptors=['v', 'i'])
    def symbolic_cummax(symbolic_wrapper, g, input, dim):
        return g.op('mmcv::cummax_default', input, dim_i=dim, outputs=2)

    class TestModel(torch.nn.Module):

        def __init__(self):
            super().__init__()

        def forward(self, x):
            return torch.cummax(test_func(x, 5), dim=1)

    model = TestModel().eval()

    # dummy input
    x = torch.rand(2, 3, 4)
    output_file = 'demo.onnx'

    # default backend
    cfg = dict()
    register_extra_symbolics(cfg=cfg, opset=11)
    torch.onnx.export(model, x, output_file, opset_version=11)
    onnx_model = onnx.load(output_file)
    os.remove(output_file)
    nodes = onnx_model.graph.node
    assert nodes[0].op_type == 'symbolic_testfunc_default'
    assert nodes[0].domain == 'mmcv'
    assert nodes[1].op_type == 'cummax_default'
    assert nodes[1].domain == 'mmcv'

    # tensorrt backend
    cfg = dict()
    register_extra_symbolics(cfg=cfg, backend='tensorrt', opset=11)
    torch.onnx.export(model, x, output_file, opset_version=11)
    onnx_model = onnx.load(output_file)
    os.remove(output_file)
    nodes = onnx_model.graph.node
    assert nodes[0].op_type == 'symbolic_testfunc_tensorrt'
    assert nodes[0].domain == 'mmcv'
    assert nodes[1].op_type == 'cummax_default'
    assert nodes[1].domain == 'mmcv'
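
These tests can be run with pytest; they require torch, torchvision, onnx and mmcv to be importable.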