mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
parent
597350c07b
commit
cc72c00e61
@ -36,11 +36,11 @@ def __build_backend_model(partition_name: str, backend: Backend,
|
||||
|
||||
# Use registry to store models with different partition methods
|
||||
# If a model doesn't need to partition, we don't need this registry
|
||||
__BACKEND_MODEl = mmcv.utils.Registry(
|
||||
__BACKEND_MODEL = mmcv.utils.Registry(
|
||||
'backend_detectors', build_func=__build_backend_model)
|
||||
|
||||
|
||||
@__BACKEND_MODEl.register_module('end2end')
|
||||
@__BACKEND_MODEL.register_module('end2end')
|
||||
class End2EndModel(BaseBackendModel):
|
||||
"""End to end model for inference of detection.
|
||||
|
||||
@ -273,7 +273,7 @@ class End2EndModel(BaseBackendModel):
|
||||
out_file=out_file)
|
||||
|
||||
|
||||
@__BACKEND_MODEl.register_module('single_stage')
|
||||
@__BACKEND_MODEL.register_module('single_stage')
|
||||
class PartitionSingleStageModel(End2EndModel):
|
||||
"""Partitioned single stage detection model.
|
||||
|
||||
@ -352,7 +352,7 @@ class PartitionSingleStageModel(End2EndModel):
|
||||
return self.partition0_postprocess(scores, bboxes)
|
||||
|
||||
|
||||
@__BACKEND_MODEl.register_module('two_stage')
|
||||
@__BACKEND_MODEL.register_module('two_stage')
|
||||
class PartitionTwoStageModel(End2EndModel):
|
||||
"""Partitioned two stage detection model.
|
||||
|
||||
@ -572,7 +572,7 @@ def build_object_detection_model(model_files: Sequence[str],
|
||||
if partition_config is not None:
|
||||
partition_type = partition_config.get('type', None)
|
||||
|
||||
backend_detector = __BACKEND_MODEl.build(
|
||||
backend_detector = __BACKEND_MODEL.build(
|
||||
partition_type,
|
||||
backend=backend,
|
||||
backend_files=model_files,
|
||||
|
@ -9,7 +9,7 @@ from .rewriter_utils import ContextCaller, RewriterRegistry, eval_with_import
|
||||
def _set_func(origin_func_name: str, rewrite_func: Callable):
|
||||
"""Rewrite a function by executing a python statement."""
|
||||
|
||||
# import necessary module
|
||||
# Import necessary module
|
||||
split_path = origin_func_name.split('.')
|
||||
for i in range(len(split_path), 0, -1):
|
||||
try:
|
||||
@ -17,7 +17,7 @@ def _set_func(origin_func_name: str, rewrite_func: Callable):
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
# assign function
|
||||
# Assign function
|
||||
exec(f'{origin_func_name} = rewrite_func')
|
||||
|
||||
|
||||
@ -72,7 +72,8 @@ class FunctionRewriter:
|
||||
functions_records = self._registry.get_records(backend)
|
||||
|
||||
self._origin_functions = list()
|
||||
for function_name, record_dict in functions_records.items():
|
||||
new_functions = list()
|
||||
for function_name, record_dict in functions_records:
|
||||
|
||||
# Check if the origin function exists
|
||||
try:
|
||||
@ -97,8 +98,12 @@ class FunctionRewriter:
|
||||
rewrite_function, origin_func, cfg,
|
||||
**extra_kwargs).get_wrapped_caller()
|
||||
|
||||
# Cache new the function to avoid homonymic bug
|
||||
new_functions.append((function_name, context_caller))
|
||||
|
||||
for function_name, new_function in new_functions:
|
||||
# Rewrite functions
|
||||
_set_func(function_name, context_caller)
|
||||
_set_func(function_name, new_function)
|
||||
|
||||
def exit(self):
|
||||
"""Recover the function rewrite."""
|
||||
|
@ -108,5 +108,5 @@ class ModuleRewriter:
|
||||
"""Collect models in registry."""
|
||||
self._records = {}
|
||||
records = self._registry.get_records(backend)
|
||||
for name, kwargs in records.items():
|
||||
for name, kwargs in records:
|
||||
self._records[eval_with_import(name)] = kwargs
|
||||
|
@ -1,5 +1,5 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Any, Callable, Dict
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
from mmdeploy.utils.constants import Backend
|
||||
|
||||
@ -57,12 +57,30 @@ class RewriterRegistry:
|
||||
if backend not in self._rewrite_records:
|
||||
self._rewrite_records[backend] = dict()
|
||||
|
||||
def get_records(self, backend: str) -> dict:
|
||||
def get_records(self, backend: str) -> List:
|
||||
"""Get all registered records in record table."""
|
||||
self._check_backend(backend)
|
||||
records = self._rewrite_records[Backend.DEFAULT.value].copy()
|
||||
|
||||
if backend != Backend.DEFAULT.value:
|
||||
records.update(self._rewrite_records[backend])
|
||||
# Update dict A with dict B.
|
||||
# Then convert the result dict to a list, while keeping the order
|
||||
# of A and B: the elements only belong to B should alwarys come
|
||||
# after the elements only belong to A.
|
||||
# The complexity is O(n + m).
|
||||
dict_a = self._rewrite_records[Backend.DEFAULT.value]
|
||||
dict_b = self._rewrite_records[backend]
|
||||
records = []
|
||||
for k, v in dict_a.items():
|
||||
if k in dict_b:
|
||||
records.append((k, dict_b[k]))
|
||||
else:
|
||||
records.append((k, v))
|
||||
for k, v in dict_b.items():
|
||||
if k not in dict_a:
|
||||
records.append((k, v))
|
||||
else:
|
||||
records = list(
|
||||
self._rewrite_records[Backend.DEFAULT.value].items())
|
||||
return records
|
||||
|
||||
def _register(self, name: str, backend: str, **kwargs):
|
||||
|
@ -77,7 +77,8 @@ class SymbolicRewriter:
|
||||
|
||||
self._pytorch_symbolic = list()
|
||||
self._extra_symbolic = list()
|
||||
for function_name, record_dict in symbolic_records.items():
|
||||
new_functions = list()
|
||||
for function_name, record_dict in symbolic_records:
|
||||
|
||||
symbolic_function = record_dict['_object']
|
||||
arg_descriptors = record_dict['arg_descriptors']
|
||||
@ -111,12 +112,18 @@ class SymbolicRewriter:
|
||||
# Only register functions that exist
|
||||
if origin_func is not None:
|
||||
origin_symbolic = getattr(origin_func, 'symbolic', None)
|
||||
context_caller.origin_func = origin_symbolic
|
||||
origin_func.symbolic = context_caller
|
||||
|
||||
# Save origin function
|
||||
self._extra_symbolic.append((origin_func, origin_symbolic))
|
||||
|
||||
# Cache new the function to avoid homonymic bug
|
||||
new_functions.append((origin_func, context_caller))
|
||||
|
||||
for origin_func, new_func in new_functions:
|
||||
origin_symbolic = getattr(origin_func, 'symbolic', None)
|
||||
new_func.origin_func = origin_symbolic
|
||||
origin_func.symbolic = new_func
|
||||
|
||||
def exit(self):
|
||||
"""The implementation of symbolic unregister."""
|
||||
# Unregister pytorch op
|
||||
|
4
tests/test_core/package/__init__.py
Normal file
4
tests/test_core/package/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .module import C, func
|
||||
|
||||
__all__ = ['func', 'C']
|
9
tests/test_core/package/module.py
Normal file
9
tests/test_core/package/module.py
Normal file
@ -0,0 +1,9 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
def func():
|
||||
return 1
|
||||
|
||||
|
||||
class C:
|
||||
|
||||
def method(self):
|
||||
return 1
|
@ -3,6 +3,7 @@ import torch
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER, RewriterContext
|
||||
from mmdeploy.core.rewriters.function_rewriter import FunctionRewriter
|
||||
from mmdeploy.utils.constants import Backend
|
||||
|
||||
|
||||
def test_function_rewriter():
|
||||
@ -83,3 +84,101 @@ def test_rewrite_empty_function():
|
||||
function_rewriter.enter()
|
||||
assert len(function_rewriter._origin_functions) == 0
|
||||
function_rewriter.exit()
|
||||
|
||||
|
||||
class TestHomonymicRewriter:
|
||||
|
||||
def test_rewrite_homonymic_functions(self):
|
||||
import package
|
||||
path1 = 'package.func'
|
||||
path2 = 'package.module.func'
|
||||
|
||||
assert package.func() == 1
|
||||
assert package.module.func() == 1
|
||||
|
||||
function_rewriter = FunctionRewriter()
|
||||
function_rewriter.add_backend(Backend.NCNN.value)
|
||||
|
||||
@function_rewriter.register_rewriter(func_name=path1)
|
||||
def func_2(ctx):
|
||||
return 2
|
||||
|
||||
@function_rewriter.register_rewriter(
|
||||
func_name=path2, backend=Backend.NCNN.value)
|
||||
def func_3(ctx):
|
||||
return 3
|
||||
|
||||
function_rewriter.enter(backend=Backend.NCNN.value)
|
||||
# This is a feature
|
||||
assert package.func() == 2
|
||||
assert package.module.func() == 3
|
||||
function_rewriter.exit()
|
||||
|
||||
assert package.func() == 1
|
||||
assert package.module.func() == 1
|
||||
|
||||
function_rewriter2 = FunctionRewriter()
|
||||
function_rewriter2.add_backend(Backend.NCNN.value)
|
||||
|
||||
@function_rewriter2.register_rewriter(
|
||||
func_name=path1, backend=Backend.NCNN.value)
|
||||
def func_4(ctx):
|
||||
return 4
|
||||
|
||||
@function_rewriter2.register_rewriter(func_name=path2)
|
||||
def func_5(ctx):
|
||||
return 5
|
||||
|
||||
function_rewriter2.enter(backend=Backend.NCNN.value)
|
||||
# This is a feature
|
||||
assert package.func() == 4
|
||||
assert package.module.func() == 5
|
||||
function_rewriter2.exit()
|
||||
|
||||
assert package.func() == 1
|
||||
assert package.module.func() == 1
|
||||
|
||||
def test_rewrite_homonymic_methods(self):
|
||||
import package
|
||||
path1 = 'package.C.method'
|
||||
path2 = 'package.module.C.method'
|
||||
|
||||
c = package.C()
|
||||
|
||||
function_rewriter = FunctionRewriter()
|
||||
function_rewriter.add_backend(Backend.NCNN.value)
|
||||
|
||||
assert c.method() == 1
|
||||
|
||||
@function_rewriter.register_rewriter(func_name=path1)
|
||||
def func_2(ctx, self):
|
||||
return 2
|
||||
|
||||
@function_rewriter.register_rewriter(
|
||||
func_name=path2, backend=Backend.NCNN.value)
|
||||
def func_3(ctx, self):
|
||||
return 3
|
||||
|
||||
function_rewriter.enter(backend=Backend.NCNN.value)
|
||||
assert c.method() == 3
|
||||
function_rewriter.exit()
|
||||
|
||||
assert c.method() == 1
|
||||
|
||||
function_rewriter2 = FunctionRewriter()
|
||||
function_rewriter2.add_backend(Backend.NCNN.value)
|
||||
|
||||
@function_rewriter2.register_rewriter(
|
||||
func_name=path1, backend=Backend.NCNN.value)
|
||||
def func_4(ctx, self):
|
||||
return 4
|
||||
|
||||
@function_rewriter2.register_rewriter(func_name=path2)
|
||||
def func_5(ctx, self):
|
||||
return 5
|
||||
|
||||
function_rewriter2.enter(backend=Backend.NCNN.value)
|
||||
assert c.method() == 4
|
||||
function_rewriter2.exit()
|
||||
|
||||
assert c.method() == 1
|
||||
|
@ -50,10 +50,10 @@ def test_get_records():
|
||||
def fake_add(a, b):
|
||||
return a * b
|
||||
|
||||
default_records = registry.get_records(Backend.DEFAULT.value)
|
||||
default_records = dict(registry.get_records(Backend.DEFAULT.value))
|
||||
assert default_records['add']['_object'](1, 1) == 2
|
||||
assert default_records['minus']['_object'](1, 1) == 0
|
||||
|
||||
tensorrt_records = registry.get_records(Backend.TENSORRT.value)
|
||||
tensorrt_records = dict(registry.get_records(Backend.TENSORRT.value))
|
||||
assert tensorrt_records['add']['_object'](1, 1) == 1
|
||||
assert tensorrt_records['minus']['_object'](1, 1) == 0
|
||||
|
Loading…
x
Reference in New Issue
Block a user