add register
parent
b6d8d48d7c
commit
f90ebf8c2c
|
@ -0,0 +1,121 @@
|
|||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# celery beat schedule file
|
||||
celerybeat-schedule
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
|
||||
data/
|
||||
data
|
||||
.vscode
|
||||
.idea
|
||||
.DS_Store
|
||||
|
||||
# custom
|
||||
*.pkl
|
||||
*.pkl.json
|
||||
*.log.json
|
||||
work_dirs/
|
||||
|
||||
# Pytorch
|
||||
*.pth
|
||||
*.py~
|
||||
*.sh~
|
|
@ -0,0 +1,2 @@
|
|||
[settings]
|
||||
known_third_party =
|
|
@ -0,0 +1,47 @@
|
|||
exclude: ^tests/data/
|
||||
repos:
|
||||
- repo: https://gitlab.com/pycqa/flake8.git
|
||||
rev: 3.8.3
|
||||
hooks:
|
||||
- id: flake8
|
||||
- repo: https://github.com/asottile/seed-isort-config
|
||||
rev: v2.2.0
|
||||
hooks:
|
||||
- id: seed-isort-config
|
||||
- repo: https://github.com/timothycrosley/isort
|
||||
rev: 4.3.21
|
||||
hooks:
|
||||
- id: isort
|
||||
- repo: https://github.com/pre-commit/mirrors-yapf
|
||||
rev: v0.30.0
|
||||
hooks:
|
||||
- id: yapf
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v3.1.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
- id: check-yaml
|
||||
- id: end-of-file-fixer
|
||||
- id: requirements-txt-fixer
|
||||
- id: double-quote-string-fixer
|
||||
- id: check-merge-conflict
|
||||
- id: fix-encoding-pragma
|
||||
args: ["--remove"]
|
||||
- id: mixed-line-ending
|
||||
args: ["--fix=lf"]
|
||||
- repo: https://github.com/jumanjihouse/pre-commit-hooks
|
||||
rev: 2.1.4
|
||||
hooks:
|
||||
- id: markdownlint
|
||||
args:
|
||||
[
|
||||
"-r",
|
||||
"~MD002,~MD013,~MD029,~MD033,~MD034",
|
||||
"-t",
|
||||
"allow_different_nesting",
|
||||
]
|
||||
- repo: https://github.com/myint/docformatter
|
||||
rev: v1.3.1
|
||||
hooks:
|
||||
- id: docformatter
|
||||
args: ["--in-place", "--wrap-descriptions", "79"]
|
|
@ -0,0 +1,8 @@
|
|||
from .function_rewriter import FUNCTION_REWRITERS, RewriterContext
|
||||
from .module_rewriter import MODULE_REWRITERS, patch_model
|
||||
from .symbolic_register import SYMBOLICS_REGISTER, register_extra_symbolics
|
||||
|
||||
__all__ = [
|
||||
'RewriterContext', 'FUNCTION_REWRITERS', 'MODULE_REWRITERS', 'patch_model',
|
||||
'SYMBOLICS_REGISTER', 'register_extra_symbolics'
|
||||
]
|
|
@ -0,0 +1,121 @@
|
|||
import logging
|
||||
|
||||
from mmcv.utils import Registry
|
||||
|
||||
from .register_utils import eval_with_import
|
||||
|
||||
|
||||
# builder of register
|
||||
def build_caller(func_name, backend, cfg, registry, **kwargs):
|
||||
# func_caller = registry.get(func_name + '@' + backend)
|
||||
func_caller = registry.module_dict[func_name + '@' + backend]
|
||||
assert func_caller is not None, '{} with {} not exist.'.format(
|
||||
func_name, backend)
|
||||
return func_caller(cfg, **kwargs)
|
||||
|
||||
|
||||
# create register
|
||||
FUNCTION_REWRITERS = Registry('func_rewriters', build_func=build_caller)
|
||||
|
||||
|
||||
# caller wrapper
|
||||
class FuncCaller(object):
|
||||
func_name = None
|
||||
backend = None
|
||||
func = None
|
||||
|
||||
def __init__(self, cfg, **kwargs):
|
||||
self.cfg = cfg
|
||||
[setattr(self, k, v) for k, v in kwargs.items()]
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.func(*args, **kwargs)
|
||||
|
||||
|
||||
# caller decorator
|
||||
def register_rewriter(func_name, backend='default', **kwargs):
|
||||
|
||||
def wrap(func):
|
||||
func_args = dict(func_name=func_name, backend=backend, func=func)
|
||||
func_args.update(kwargs)
|
||||
func_caller = type(func_name + '@' + backend, (FuncCaller, ),
|
||||
func_args)
|
||||
return FUNCTION_REWRITERS.register_module()(func_caller)
|
||||
|
||||
return wrap
|
||||
|
||||
|
||||
FUNCTION_REWRITERS.register_rewriter = register_rewriter
|
||||
|
||||
|
||||
def apply_rewriter(regist_func):
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
return regist_func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class RewriterHook(object):
|
||||
|
||||
def __init__(self, regist_name, cfg, **kwargs):
|
||||
func_name, backend = regist_name.split('@')
|
||||
self.func_name = func_name
|
||||
self.backend = backend
|
||||
self.regist_func = FUNCTION_REWRITERS.build(
|
||||
func_name, backend=self.backend, cfg=cfg, **kwargs)
|
||||
try:
|
||||
self.origin_func = eval_with_import(self.func_name)
|
||||
except Exception:
|
||||
self.origin_func = None
|
||||
logging.warning(
|
||||
'Can not found {}, function rewrite will not be applied'.
|
||||
format(self.func_name))
|
||||
|
||||
def _set_func(self, rewrite_func):
|
||||
if self.origin_func is not None:
|
||||
# import necessary module
|
||||
split_path = self.func_name.split('.')
|
||||
for i in range(len(split_path), 0, -1):
|
||||
try:
|
||||
exec('import {}'.format('.'.join(split_path[:i])))
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
# assign function
|
||||
exec('{} = rewrite_func'.format(self.func_name))
|
||||
|
||||
def __enter__(self):
|
||||
self._set_func(apply_rewriter(self.regist_func))
|
||||
|
||||
def __exit__(self, type, val, tb):
|
||||
self._set_func(self.origin_func)
|
||||
|
||||
|
||||
class RewriterContext(object):
|
||||
|
||||
def __init__(self, cfg, backend='default', **kwargs):
|
||||
self.cfg = cfg
|
||||
func_backend_dict = {}
|
||||
for regist_name in FUNCTION_REWRITERS.module_dict:
|
||||
regist_func, regist_backend = regist_name.split('@')
|
||||
# only build `backend` or `default`
|
||||
if regist_backend not in [backend, 'default']:
|
||||
continue
|
||||
if regist_func not in func_backend_dict or func_backend_dict[
|
||||
regist_func] == 'default':
|
||||
func_backend_dict[regist_func] = regist_backend
|
||||
|
||||
self.hooks = [
|
||||
RewriterHook(k + '@' + v, cfg, **kwargs)
|
||||
for k, v in func_backend_dict.items()
|
||||
]
|
||||
|
||||
def __enter__(self):
|
||||
for hook in self.hooks:
|
||||
hook.__enter__()
|
||||
return self
|
||||
|
||||
def __exit__(self, type, val, tb):
|
||||
for hook in self.hooks:
|
||||
hook.__exit__(type, val, tb)
|
|
@ -0,0 +1,69 @@
|
|||
from mmcv.utils import Registry
|
||||
from .register_utils import eval_with_import
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
def build_rewrite_module(module, cfg, backend, registry, **kwargs):
|
||||
|
||||
backend_dict = registry.module_eval_dict.get(type(module), None)
|
||||
if backend_dict is None:
|
||||
return module
|
||||
|
||||
RewriteModuleClass = None
|
||||
for backend in [backend, 'default']:
|
||||
RewriteModuleClass = backend_dict.get(backend, None)
|
||||
if RewriteModuleClass is not None:
|
||||
break
|
||||
|
||||
if RewriteModuleClass is None:
|
||||
return module
|
||||
|
||||
return RewriteModuleClass(module, cfg, **kwargs)
|
||||
|
||||
|
||||
class RewriteModuleRegistry(Registry):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._module_eval_dict = dict()
|
||||
|
||||
def register_rewrite_module(self,
|
||||
module_type,
|
||||
backend='default',
|
||||
**kwargs):
|
||||
register_name = module_type + '@' + backend
|
||||
return self.register_module(register_name)
|
||||
|
||||
@property
|
||||
def module_eval_dict(self):
|
||||
return self._module_eval_dict
|
||||
|
||||
def _register_module(self, module_class, module_name=None, force=False):
|
||||
super()._register_module(module_class, module_name, force)
|
||||
|
||||
module_type, backend = module_name.split('@')
|
||||
module_type_cls = eval_with_import(module_type)
|
||||
if module_type_cls not in self._module_eval_dict:
|
||||
self._module_eval_dict[module_type_cls] = dict()
|
||||
|
||||
assert (backend not in self._module_eval_dict[module_type_cls]
|
||||
), '{} with backend:{} has already been registed.'.format(
|
||||
module_type, backend)
|
||||
self._module_eval_dict[module_type_cls][backend] = self.module_dict[
|
||||
module_name]
|
||||
|
||||
|
||||
# create register
|
||||
MODULE_REWRITERS = RewriteModuleRegistry(
|
||||
'module_rewriters', build_func=build_rewrite_module, scope='.')
|
||||
|
||||
|
||||
def patch_model(model, cfg, backend='default', **kwargs):
|
||||
|
||||
def _patch_impl(model, cfg, **kwargs):
|
||||
for name, module in model.named_children():
|
||||
model._modules[name] = _patch_impl(module, cfg, **kwargs)
|
||||
return MODULE_REWRITERS.build(
|
||||
module=model, cfg=cfg, backend=backend, **kwargs)
|
||||
|
||||
return _patch_impl(deepcopy(model), cfg, **kwargs)
|
|
@ -0,0 +1,9 @@
|
|||
def eval_with_import(path):
|
||||
split_path = path.split('.')
|
||||
for i in range(len(split_path), 0, -1):
|
||||
try:
|
||||
exec('import {}'.format('.'.join(split_path[:i])))
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
return eval(path)
|
|
@ -0,0 +1,86 @@
|
|||
import logging
|
||||
|
||||
from mmcv.utils import Registry
|
||||
from torch.autograd import Function
|
||||
from torch.onnx.symbolic_helper import parse_args
|
||||
from torch.onnx.symbolic_registry import register_op
|
||||
|
||||
from .register_utils import eval_with_import
|
||||
|
||||
|
||||
def set_symbolic(cfg, registry, backend='default', opset=11, **kwargs):
|
||||
|
||||
# find valid symbolic
|
||||
valid_symbolic_dict = {}
|
||||
for module_name, symbolic_impl in registry.module_dict.items():
|
||||
func_name, symbolic_backend, is_pytorch = module_name.split('@')
|
||||
if symbolic_backend == backend or (symbolic_backend == 'default'
|
||||
and func_name
|
||||
not in valid_symbolic_dict):
|
||||
valid_symbolic_dict[func_name] = (symbolic_impl,
|
||||
is_pytorch == 'True')
|
||||
|
||||
# build symbolic
|
||||
for func_name in valid_symbolic_dict:
|
||||
symbolic_impl, is_pytorch = valid_symbolic_dict[func_name]
|
||||
arg_descriptors = symbolic_impl.arg_descriptors
|
||||
symbolic_impl = symbolic_impl(cfg=cfg, **kwargs)
|
||||
if arg_descriptors is not None and len(arg_descriptors) > 0:
|
||||
symbolic_impl = parse_args(*arg_descriptors)(symbolic_impl)
|
||||
if is_pytorch:
|
||||
register_op(func_name, symbolic_impl, '', opset)
|
||||
else:
|
||||
try:
|
||||
func = eval_with_import(func_name)
|
||||
assert issubclass(
|
||||
func,
|
||||
Function), '{} is not an torch.autograd.Function'.format(
|
||||
func_name)
|
||||
func.symbolic = symbolic_impl
|
||||
except Exception:
|
||||
logging.warning(
|
||||
'Can not add symbolic for `{}`'.format(func_name))
|
||||
|
||||
|
||||
SYMBOLICS_REGISTER = Registry('symbolics', build_func=set_symbolic, scope=None)
|
||||
|
||||
|
||||
class SymbolicWrapper:
|
||||
func_name = ''
|
||||
backend = ''
|
||||
is_pytorch = False
|
||||
symbolic = None
|
||||
arg_descriptors = None
|
||||
|
||||
def __init__(self, cfg, **kwargs):
|
||||
self.cfg = cfg
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.symbolic(*args, **kwargs)
|
||||
|
||||
|
||||
def register_symbolic(func_name,
|
||||
backend='default',
|
||||
is_pytorch=False,
|
||||
arg_descriptors=None,
|
||||
**kwargs):
|
||||
|
||||
def wrapper(symbolic_impl):
|
||||
symbolic_args = dict(
|
||||
func_name=func_name,
|
||||
backend=backend,
|
||||
symbolic=symbolic_impl,
|
||||
arg_descriptors=arg_descriptors)
|
||||
symbolic_args.update(kwargs)
|
||||
wrapper_name = '@'.join([func_name, backend, str(is_pytorch)])
|
||||
wrapper = type(wrapper_name, (SymbolicWrapper, ), symbolic_args)
|
||||
return SYMBOLICS_REGISTER.register_module(wrapper_name)(wrapper)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
SYMBOLICS_REGISTER.register_symbolic = register_symbolic
|
||||
|
||||
|
||||
def register_extra_symbolics(cfg, backend='default', opset=11):
|
||||
SYMBOLICS_REGISTER.build(cfg=cfg, backend=backend)
|
|
@ -0,0 +1,4 @@
|
|||
from setuptools import setup
|
||||
|
||||
if __name__ == '__main__':
|
||||
setup(name='mmdeploy', version=0.1)
|
|
@ -0,0 +1,158 @@
|
|||
import torch
|
||||
import os
|
||||
|
||||
|
||||
def test_function_rewriter():
|
||||
from mmdeploy.utils import FUNCTION_REWRITERS, RewriterContext
|
||||
|
||||
x = torch.tensor([1, 2, 3, 4, 5])
|
||||
y = torch.tensor([2, 4, 6, 8, 10])
|
||||
|
||||
@FUNCTION_REWRITERS.register_rewriter(
|
||||
func_name='torch.add', backend='tensorrt')
|
||||
def sub_func(rewriter, x, y):
|
||||
return x - y
|
||||
|
||||
cfg = dict()
|
||||
with RewriterContext(cfg, backend='tensorrt'):
|
||||
result = torch.add(x, y)
|
||||
# replace add with sub
|
||||
torch.testing.assert_allclose(result, x - y)
|
||||
|
||||
result = torch.add(x, y)
|
||||
# recovery origin function
|
||||
torch.testing.assert_allclose(result, x + y)
|
||||
|
||||
with RewriterContext(cfg):
|
||||
result = torch.add(x, y)
|
||||
# replace should not happen with wrong backend
|
||||
torch.testing.assert_allclose(result, x + y)
|
||||
|
||||
@FUNCTION_REWRITERS.register_rewriter(
|
||||
func_name='torch.Tensor.add', backend='default')
|
||||
def mul_func_class(rewriter, x, y):
|
||||
return x * y
|
||||
|
||||
with RewriterContext(cfg, backend='tensorrt'):
|
||||
result = x.add(y)
|
||||
# replace add with multi
|
||||
torch.testing.assert_allclose(result, x * y)
|
||||
|
||||
result = x.add(y)
|
||||
# recovery origin function
|
||||
torch.testing.assert_allclose(result, x + y)
|
||||
|
||||
with RewriterContext(cfg):
|
||||
result = x.add(y)
|
||||
# replace add with multi
|
||||
torch.testing.assert_allclose(result, x * y)
|
||||
|
||||
|
||||
def test_module_rewriter():
|
||||
from mmdeploy.utils import MODULE_REWRITERS, patch_model
|
||||
from torchvision.models.resnet import resnet50
|
||||
|
||||
@MODULE_REWRITERS.register_rewrite_module(
|
||||
module_type='torchvision.models.resnet.Bottleneck', backend='tensorrt')
|
||||
class BottleneckWrapper(torch.nn.Module):
|
||||
|
||||
def __init__(self, module, cfg, **kwargs):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.module(*args, **kwargs) * 2
|
||||
|
||||
x = torch.rand(1, 64, 32, 32)
|
||||
model = resnet50().eval()
|
||||
bottle_neck = model.layer1[0]
|
||||
result = bottle_neck(x)
|
||||
|
||||
# rewrite module
|
||||
cfg = dict()
|
||||
|
||||
rewrited_model = patch_model(model, cfg=cfg, backend='tensorrt')
|
||||
rewrited_bottle_nect = rewrited_model.layer1[0]
|
||||
rewrited_result = rewrited_bottle_nect(x)
|
||||
torch.testing.assert_allclose(rewrited_result, result * 2)
|
||||
|
||||
# wrong backend should not be rewrited
|
||||
|
||||
rewrited_model = patch_model(model, cfg=cfg)
|
||||
rewrited_bottle_nect = rewrited_model.layer1[0]
|
||||
rewrited_result = rewrited_bottle_nect(x)
|
||||
torch.testing.assert_allclose(rewrited_result, result)
|
||||
|
||||
|
||||
def test_symbolic_register():
|
||||
import mmdeploy
|
||||
from mmdeploy.utils import SYMBOLICS_REGISTER, register_extra_symbolics
|
||||
from torch.autograd import Function
|
||||
import onnx
|
||||
|
||||
class TestFunc(Function):
|
||||
|
||||
@staticmethod
|
||||
def symbolic(g, x, val):
|
||||
return g.op('mmcv::symbolic_old', x, val_i=val)
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x, val):
|
||||
return x + val
|
||||
|
||||
# put TestFunc in an module so we can found it
|
||||
# could be any module
|
||||
mmdeploy.TestFunc = TestFunc
|
||||
test_func = mmdeploy.TestFunc.apply
|
||||
|
||||
@SYMBOLICS_REGISTER.register_symbolic('mmdeploy.TestFunc')
|
||||
def symbolic_testfunc_default(symbolic_wrapper, g, x, val):
|
||||
return g.op('mmcv::symbolic_testfunc_default', x, val_i=val)
|
||||
|
||||
@SYMBOLICS_REGISTER.register_symbolic(
|
||||
'mmdeploy.TestFunc', backend='tensorrt')
|
||||
def symbolic_testfunc_tensorrt(symbolic_wrapper, g, x, val):
|
||||
return g.op('mmcv::symbolic_testfunc_tensorrt', x, val_i=val)
|
||||
|
||||
@SYMBOLICS_REGISTER.register_symbolic(
|
||||
'cummax', is_pytorch=True, arg_descriptors=['v', 'i'])
|
||||
def symbolic_cummax(symbolic_wrapper, g, input, dim):
|
||||
return g.op('mmcv::cummax_default', input, dim_i=dim, outputs=2)
|
||||
|
||||
class TestModel(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return torch.cummax(test_func(x, 5), dim=1)
|
||||
|
||||
model = TestModel().eval()
|
||||
|
||||
# dummy input
|
||||
x = torch.rand(2, 3, 4)
|
||||
output_file = 'demo.onnx'
|
||||
|
||||
# default
|
||||
cfg = dict()
|
||||
register_extra_symbolics(cfg=cfg, opset=11)
|
||||
torch.onnx.export(model, x, output_file, opset_version=11)
|
||||
onnx_model = onnx.load(output_file)
|
||||
os.remove(output_file)
|
||||
nodes = onnx_model.graph.node
|
||||
assert nodes[0].op_type == 'symbolic_testfunc_default'
|
||||
assert nodes[0].domain == 'mmcv'
|
||||
assert nodes[1].op_type == 'cummax_default'
|
||||
assert nodes[1].domain == 'mmcv'
|
||||
|
||||
# default
|
||||
cfg = dict()
|
||||
register_extra_symbolics(cfg=cfg, backend='tensorrt', opset=11)
|
||||
torch.onnx.export(model, x, output_file, opset_version=11)
|
||||
onnx_model = onnx.load(output_file)
|
||||
os.remove(output_file)
|
||||
nodes = onnx_model.graph.node
|
||||
assert nodes[0].op_type == 'symbolic_testfunc_tensorrt'
|
||||
assert nodes[0].domain == 'mmcv'
|
||||
assert nodes[1].op_type == 'cummax_default'
|
||||
assert nodes[1].domain == 'mmcv'
|
Loading…
Reference in New Issue