mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
parent
597350c07b
commit
cc72c00e61
@ -36,11 +36,11 @@ def __build_backend_model(partition_name: str, backend: Backend,
|
|||||||
|
|
||||||
# Use registry to store models with different partition methods
|
# Use registry to store models with different partition methods
|
||||||
# If a model doesn't need to partition, we don't need this registry
|
# If a model doesn't need to partition, we don't need this registry
|
||||||
__BACKEND_MODEl = mmcv.utils.Registry(
|
__BACKEND_MODEL = mmcv.utils.Registry(
|
||||||
'backend_detectors', build_func=__build_backend_model)
|
'backend_detectors', build_func=__build_backend_model)
|
||||||
|
|
||||||
|
|
||||||
@__BACKEND_MODEl.register_module('end2end')
|
@__BACKEND_MODEL.register_module('end2end')
|
||||||
class End2EndModel(BaseBackendModel):
|
class End2EndModel(BaseBackendModel):
|
||||||
"""End to end model for inference of detection.
|
"""End to end model for inference of detection.
|
||||||
|
|
||||||
@ -273,7 +273,7 @@ class End2EndModel(BaseBackendModel):
|
|||||||
out_file=out_file)
|
out_file=out_file)
|
||||||
|
|
||||||
|
|
||||||
@__BACKEND_MODEl.register_module('single_stage')
|
@__BACKEND_MODEL.register_module('single_stage')
|
||||||
class PartitionSingleStageModel(End2EndModel):
|
class PartitionSingleStageModel(End2EndModel):
|
||||||
"""Partitioned single stage detection model.
|
"""Partitioned single stage detection model.
|
||||||
|
|
||||||
@ -352,7 +352,7 @@ class PartitionSingleStageModel(End2EndModel):
|
|||||||
return self.partition0_postprocess(scores, bboxes)
|
return self.partition0_postprocess(scores, bboxes)
|
||||||
|
|
||||||
|
|
||||||
@__BACKEND_MODEl.register_module('two_stage')
|
@__BACKEND_MODEL.register_module('two_stage')
|
||||||
class PartitionTwoStageModel(End2EndModel):
|
class PartitionTwoStageModel(End2EndModel):
|
||||||
"""Partitioned two stage detection model.
|
"""Partitioned two stage detection model.
|
||||||
|
|
||||||
@ -572,7 +572,7 @@ def build_object_detection_model(model_files: Sequence[str],
|
|||||||
if partition_config is not None:
|
if partition_config is not None:
|
||||||
partition_type = partition_config.get('type', None)
|
partition_type = partition_config.get('type', None)
|
||||||
|
|
||||||
backend_detector = __BACKEND_MODEl.build(
|
backend_detector = __BACKEND_MODEL.build(
|
||||||
partition_type,
|
partition_type,
|
||||||
backend=backend,
|
backend=backend,
|
||||||
backend_files=model_files,
|
backend_files=model_files,
|
||||||
|
@ -9,7 +9,7 @@ from .rewriter_utils import ContextCaller, RewriterRegistry, eval_with_import
|
|||||||
def _set_func(origin_func_name: str, rewrite_func: Callable):
|
def _set_func(origin_func_name: str, rewrite_func: Callable):
|
||||||
"""Rewrite a function by executing a python statement."""
|
"""Rewrite a function by executing a python statement."""
|
||||||
|
|
||||||
# import necessary module
|
# Import necessary module
|
||||||
split_path = origin_func_name.split('.')
|
split_path = origin_func_name.split('.')
|
||||||
for i in range(len(split_path), 0, -1):
|
for i in range(len(split_path), 0, -1):
|
||||||
try:
|
try:
|
||||||
@ -17,7 +17,7 @@ def _set_func(origin_func_name: str, rewrite_func: Callable):
|
|||||||
break
|
break
|
||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
# assign function
|
# Assign function
|
||||||
exec(f'{origin_func_name} = rewrite_func')
|
exec(f'{origin_func_name} = rewrite_func')
|
||||||
|
|
||||||
|
|
||||||
@ -72,7 +72,8 @@ class FunctionRewriter:
|
|||||||
functions_records = self._registry.get_records(backend)
|
functions_records = self._registry.get_records(backend)
|
||||||
|
|
||||||
self._origin_functions = list()
|
self._origin_functions = list()
|
||||||
for function_name, record_dict in functions_records.items():
|
new_functions = list()
|
||||||
|
for function_name, record_dict in functions_records:
|
||||||
|
|
||||||
# Check if the origin function exists
|
# Check if the origin function exists
|
||||||
try:
|
try:
|
||||||
@ -97,8 +98,12 @@ class FunctionRewriter:
|
|||||||
rewrite_function, origin_func, cfg,
|
rewrite_function, origin_func, cfg,
|
||||||
**extra_kwargs).get_wrapped_caller()
|
**extra_kwargs).get_wrapped_caller()
|
||||||
|
|
||||||
# Rewrite functions
|
# Cache new the function to avoid homonymic bug
|
||||||
_set_func(function_name, context_caller)
|
new_functions.append((function_name, context_caller))
|
||||||
|
|
||||||
|
for function_name, new_function in new_functions:
|
||||||
|
# Rewrite functions
|
||||||
|
_set_func(function_name, new_function)
|
||||||
|
|
||||||
def exit(self):
|
def exit(self):
|
||||||
"""Recover the function rewrite."""
|
"""Recover the function rewrite."""
|
||||||
|
@ -108,5 +108,5 @@ class ModuleRewriter:
|
|||||||
"""Collect models in registry."""
|
"""Collect models in registry."""
|
||||||
self._records = {}
|
self._records = {}
|
||||||
records = self._registry.get_records(backend)
|
records = self._registry.get_records(backend)
|
||||||
for name, kwargs in records.items():
|
for name, kwargs in records:
|
||||||
self._records[eval_with_import(name)] = kwargs
|
self._records[eval_with_import(name)] = kwargs
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from typing import Any, Callable, Dict
|
from typing import Any, Callable, Dict, List
|
||||||
|
|
||||||
from mmdeploy.utils.constants import Backend
|
from mmdeploy.utils.constants import Backend
|
||||||
|
|
||||||
@ -57,12 +57,30 @@ class RewriterRegistry:
|
|||||||
if backend not in self._rewrite_records:
|
if backend not in self._rewrite_records:
|
||||||
self._rewrite_records[backend] = dict()
|
self._rewrite_records[backend] = dict()
|
||||||
|
|
||||||
def get_records(self, backend: str) -> dict:
|
def get_records(self, backend: str) -> List:
|
||||||
"""Get all registered records in record table."""
|
"""Get all registered records in record table."""
|
||||||
self._check_backend(backend)
|
self._check_backend(backend)
|
||||||
records = self._rewrite_records[Backend.DEFAULT.value].copy()
|
|
||||||
if backend != Backend.DEFAULT.value:
|
if backend != Backend.DEFAULT.value:
|
||||||
records.update(self._rewrite_records[backend])
|
# Update dict A with dict B.
|
||||||
|
# Then convert the result dict to a list, while keeping the order
|
||||||
|
# of A and B: the elements only belong to B should alwarys come
|
||||||
|
# after the elements only belong to A.
|
||||||
|
# The complexity is O(n + m).
|
||||||
|
dict_a = self._rewrite_records[Backend.DEFAULT.value]
|
||||||
|
dict_b = self._rewrite_records[backend]
|
||||||
|
records = []
|
||||||
|
for k, v in dict_a.items():
|
||||||
|
if k in dict_b:
|
||||||
|
records.append((k, dict_b[k]))
|
||||||
|
else:
|
||||||
|
records.append((k, v))
|
||||||
|
for k, v in dict_b.items():
|
||||||
|
if k not in dict_a:
|
||||||
|
records.append((k, v))
|
||||||
|
else:
|
||||||
|
records = list(
|
||||||
|
self._rewrite_records[Backend.DEFAULT.value].items())
|
||||||
return records
|
return records
|
||||||
|
|
||||||
def _register(self, name: str, backend: str, **kwargs):
|
def _register(self, name: str, backend: str, **kwargs):
|
||||||
|
@ -77,7 +77,8 @@ class SymbolicRewriter:
|
|||||||
|
|
||||||
self._pytorch_symbolic = list()
|
self._pytorch_symbolic = list()
|
||||||
self._extra_symbolic = list()
|
self._extra_symbolic = list()
|
||||||
for function_name, record_dict in symbolic_records.items():
|
new_functions = list()
|
||||||
|
for function_name, record_dict in symbolic_records:
|
||||||
|
|
||||||
symbolic_function = record_dict['_object']
|
symbolic_function = record_dict['_object']
|
||||||
arg_descriptors = record_dict['arg_descriptors']
|
arg_descriptors = record_dict['arg_descriptors']
|
||||||
@ -111,12 +112,18 @@ class SymbolicRewriter:
|
|||||||
# Only register functions that exist
|
# Only register functions that exist
|
||||||
if origin_func is not None:
|
if origin_func is not None:
|
||||||
origin_symbolic = getattr(origin_func, 'symbolic', None)
|
origin_symbolic = getattr(origin_func, 'symbolic', None)
|
||||||
context_caller.origin_func = origin_symbolic
|
|
||||||
origin_func.symbolic = context_caller
|
|
||||||
|
|
||||||
# Save origin function
|
# Save origin function
|
||||||
self._extra_symbolic.append((origin_func, origin_symbolic))
|
self._extra_symbolic.append((origin_func, origin_symbolic))
|
||||||
|
|
||||||
|
# Cache new the function to avoid homonymic bug
|
||||||
|
new_functions.append((origin_func, context_caller))
|
||||||
|
|
||||||
|
for origin_func, new_func in new_functions:
|
||||||
|
origin_symbolic = getattr(origin_func, 'symbolic', None)
|
||||||
|
new_func.origin_func = origin_symbolic
|
||||||
|
origin_func.symbolic = new_func
|
||||||
|
|
||||||
def exit(self):
|
def exit(self):
|
||||||
"""The implementation of symbolic unregister."""
|
"""The implementation of symbolic unregister."""
|
||||||
# Unregister pytorch op
|
# Unregister pytorch op
|
||||||
|
4
tests/test_core/package/__init__.py
Normal file
4
tests/test_core/package/__init__.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from .module import C, func
|
||||||
|
|
||||||
|
__all__ = ['func', 'C']
|
9
tests/test_core/package/module.py
Normal file
9
tests/test_core/package/module.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
def func():
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
class C:
|
||||||
|
|
||||||
|
def method(self):
|
||||||
|
return 1
|
@ -3,6 +3,7 @@ import torch
|
|||||||
|
|
||||||
from mmdeploy.core import FUNCTION_REWRITER, RewriterContext
|
from mmdeploy.core import FUNCTION_REWRITER, RewriterContext
|
||||||
from mmdeploy.core.rewriters.function_rewriter import FunctionRewriter
|
from mmdeploy.core.rewriters.function_rewriter import FunctionRewriter
|
||||||
|
from mmdeploy.utils.constants import Backend
|
||||||
|
|
||||||
|
|
||||||
def test_function_rewriter():
|
def test_function_rewriter():
|
||||||
@ -83,3 +84,101 @@ def test_rewrite_empty_function():
|
|||||||
function_rewriter.enter()
|
function_rewriter.enter()
|
||||||
assert len(function_rewriter._origin_functions) == 0
|
assert len(function_rewriter._origin_functions) == 0
|
||||||
function_rewriter.exit()
|
function_rewriter.exit()
|
||||||
|
|
||||||
|
|
||||||
|
class TestHomonymicRewriter:
|
||||||
|
|
||||||
|
def test_rewrite_homonymic_functions(self):
|
||||||
|
import package
|
||||||
|
path1 = 'package.func'
|
||||||
|
path2 = 'package.module.func'
|
||||||
|
|
||||||
|
assert package.func() == 1
|
||||||
|
assert package.module.func() == 1
|
||||||
|
|
||||||
|
function_rewriter = FunctionRewriter()
|
||||||
|
function_rewriter.add_backend(Backend.NCNN.value)
|
||||||
|
|
||||||
|
@function_rewriter.register_rewriter(func_name=path1)
|
||||||
|
def func_2(ctx):
|
||||||
|
return 2
|
||||||
|
|
||||||
|
@function_rewriter.register_rewriter(
|
||||||
|
func_name=path2, backend=Backend.NCNN.value)
|
||||||
|
def func_3(ctx):
|
||||||
|
return 3
|
||||||
|
|
||||||
|
function_rewriter.enter(backend=Backend.NCNN.value)
|
||||||
|
# This is a feature
|
||||||
|
assert package.func() == 2
|
||||||
|
assert package.module.func() == 3
|
||||||
|
function_rewriter.exit()
|
||||||
|
|
||||||
|
assert package.func() == 1
|
||||||
|
assert package.module.func() == 1
|
||||||
|
|
||||||
|
function_rewriter2 = FunctionRewriter()
|
||||||
|
function_rewriter2.add_backend(Backend.NCNN.value)
|
||||||
|
|
||||||
|
@function_rewriter2.register_rewriter(
|
||||||
|
func_name=path1, backend=Backend.NCNN.value)
|
||||||
|
def func_4(ctx):
|
||||||
|
return 4
|
||||||
|
|
||||||
|
@function_rewriter2.register_rewriter(func_name=path2)
|
||||||
|
def func_5(ctx):
|
||||||
|
return 5
|
||||||
|
|
||||||
|
function_rewriter2.enter(backend=Backend.NCNN.value)
|
||||||
|
# This is a feature
|
||||||
|
assert package.func() == 4
|
||||||
|
assert package.module.func() == 5
|
||||||
|
function_rewriter2.exit()
|
||||||
|
|
||||||
|
assert package.func() == 1
|
||||||
|
assert package.module.func() == 1
|
||||||
|
|
||||||
|
def test_rewrite_homonymic_methods(self):
|
||||||
|
import package
|
||||||
|
path1 = 'package.C.method'
|
||||||
|
path2 = 'package.module.C.method'
|
||||||
|
|
||||||
|
c = package.C()
|
||||||
|
|
||||||
|
function_rewriter = FunctionRewriter()
|
||||||
|
function_rewriter.add_backend(Backend.NCNN.value)
|
||||||
|
|
||||||
|
assert c.method() == 1
|
||||||
|
|
||||||
|
@function_rewriter.register_rewriter(func_name=path1)
|
||||||
|
def func_2(ctx, self):
|
||||||
|
return 2
|
||||||
|
|
||||||
|
@function_rewriter.register_rewriter(
|
||||||
|
func_name=path2, backend=Backend.NCNN.value)
|
||||||
|
def func_3(ctx, self):
|
||||||
|
return 3
|
||||||
|
|
||||||
|
function_rewriter.enter(backend=Backend.NCNN.value)
|
||||||
|
assert c.method() == 3
|
||||||
|
function_rewriter.exit()
|
||||||
|
|
||||||
|
assert c.method() == 1
|
||||||
|
|
||||||
|
function_rewriter2 = FunctionRewriter()
|
||||||
|
function_rewriter2.add_backend(Backend.NCNN.value)
|
||||||
|
|
||||||
|
@function_rewriter2.register_rewriter(
|
||||||
|
func_name=path1, backend=Backend.NCNN.value)
|
||||||
|
def func_4(ctx, self):
|
||||||
|
return 4
|
||||||
|
|
||||||
|
@function_rewriter2.register_rewriter(func_name=path2)
|
||||||
|
def func_5(ctx, self):
|
||||||
|
return 5
|
||||||
|
|
||||||
|
function_rewriter2.enter(backend=Backend.NCNN.value)
|
||||||
|
assert c.method() == 4
|
||||||
|
function_rewriter2.exit()
|
||||||
|
|
||||||
|
assert c.method() == 1
|
||||||
|
@ -50,10 +50,10 @@ def test_get_records():
|
|||||||
def fake_add(a, b):
|
def fake_add(a, b):
|
||||||
return a * b
|
return a * b
|
||||||
|
|
||||||
default_records = registry.get_records(Backend.DEFAULT.value)
|
default_records = dict(registry.get_records(Backend.DEFAULT.value))
|
||||||
assert default_records['add']['_object'](1, 1) == 2
|
assert default_records['add']['_object'](1, 1) == 2
|
||||||
assert default_records['minus']['_object'](1, 1) == 0
|
assert default_records['minus']['_object'](1, 1) == 0
|
||||||
|
|
||||||
tensorrt_records = registry.get_records(Backend.TENSORRT.value)
|
tensorrt_records = dict(registry.get_records(Backend.TENSORRT.value))
|
||||||
assert tensorrt_records['add']['_object'](1, 1) == 1
|
assert tensorrt_records['add']['_object'](1, 1) == 1
|
||||||
assert tensorrt_records['minus']['_object'](1, 1) == 0
|
assert tensorrt_records['minus']['_object'](1, 1) == 0
|
||||||
|
Loading…
x
Reference in New Issue
Block a user