[Fix]: Fix homonymic rewriters bugs (#242)

* Fix bug

* license
This commit is contained in:
Yifan Zhou 2021-12-03 18:31:58 +08:00 committed by GitHub
parent 597350c07b
commit cc72c00e61
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 162 additions and 20 deletions

View File

@ -36,11 +36,11 @@ def __build_backend_model(partition_name: str, backend: Backend,
# Use registry to store models with different partition methods
# If a model doesn't need to partition, we don't need this registry
__BACKEND_MODEl = mmcv.utils.Registry(
__BACKEND_MODEL = mmcv.utils.Registry(
'backend_detectors', build_func=__build_backend_model)
@__BACKEND_MODEl.register_module('end2end')
@__BACKEND_MODEL.register_module('end2end')
class End2EndModel(BaseBackendModel):
"""End to end model for inference of detection.
@ -273,7 +273,7 @@ class End2EndModel(BaseBackendModel):
out_file=out_file)
@__BACKEND_MODEl.register_module('single_stage')
@__BACKEND_MODEL.register_module('single_stage')
class PartitionSingleStageModel(End2EndModel):
"""Partitioned single stage detection model.
@ -352,7 +352,7 @@ class PartitionSingleStageModel(End2EndModel):
return self.partition0_postprocess(scores, bboxes)
@__BACKEND_MODEl.register_module('two_stage')
@__BACKEND_MODEL.register_module('two_stage')
class PartitionTwoStageModel(End2EndModel):
"""Partitioned two stage detection model.
@ -572,7 +572,7 @@ def build_object_detection_model(model_files: Sequence[str],
if partition_config is not None:
partition_type = partition_config.get('type', None)
backend_detector = __BACKEND_MODEl.build(
backend_detector = __BACKEND_MODEL.build(
partition_type,
backend=backend,
backend_files=model_files,

View File

@ -9,7 +9,7 @@ from .rewriter_utils import ContextCaller, RewriterRegistry, eval_with_import
def _set_func(origin_func_name: str, rewrite_func: Callable):
"""Rewrite a function by executing a python statement."""
# import necessary module
# Import necessary module
split_path = origin_func_name.split('.')
for i in range(len(split_path), 0, -1):
try:
@ -17,7 +17,7 @@ def _set_func(origin_func_name: str, rewrite_func: Callable):
break
except Exception:
continue
# assign function
# Assign function
exec(f'{origin_func_name} = rewrite_func')
@ -72,7 +72,8 @@ class FunctionRewriter:
functions_records = self._registry.get_records(backend)
self._origin_functions = list()
for function_name, record_dict in functions_records.items():
new_functions = list()
for function_name, record_dict in functions_records:
# Check if the origin function exists
try:
@ -97,8 +98,12 @@ class FunctionRewriter:
rewrite_function, origin_func, cfg,
**extra_kwargs).get_wrapped_caller()
# Rewrite functions
_set_func(function_name, context_caller)
# Cache new the function to avoid homonymic bug
new_functions.append((function_name, context_caller))
for function_name, new_function in new_functions:
# Rewrite functions
_set_func(function_name, new_function)
def exit(self):
"""Recover the function rewrite."""

View File

@ -108,5 +108,5 @@ class ModuleRewriter:
"""Collect models in registry."""
self._records = {}
records = self._registry.get_records(backend)
for name, kwargs in records.items():
for name, kwargs in records:
self._records[eval_with_import(name)] = kwargs

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Callable, Dict
from typing import Any, Callable, Dict, List
from mmdeploy.utils.constants import Backend
@ -57,12 +57,30 @@ class RewriterRegistry:
if backend not in self._rewrite_records:
self._rewrite_records[backend] = dict()
def get_records(self, backend: str) -> dict:
def get_records(self, backend: str) -> List:
"""Get all registered records in record table."""
self._check_backend(backend)
records = self._rewrite_records[Backend.DEFAULT.value].copy()
if backend != Backend.DEFAULT.value:
records.update(self._rewrite_records[backend])
# Update dict A with dict B.
# Then convert the result dict to a list, while keeping the order
# of A and B: the elements only belong to B should alwarys come
# after the elements only belong to A.
# The complexity is O(n + m).
dict_a = self._rewrite_records[Backend.DEFAULT.value]
dict_b = self._rewrite_records[backend]
records = []
for k, v in dict_a.items():
if k in dict_b:
records.append((k, dict_b[k]))
else:
records.append((k, v))
for k, v in dict_b.items():
if k not in dict_a:
records.append((k, v))
else:
records = list(
self._rewrite_records[Backend.DEFAULT.value].items())
return records
def _register(self, name: str, backend: str, **kwargs):

View File

@ -77,7 +77,8 @@ class SymbolicRewriter:
self._pytorch_symbolic = list()
self._extra_symbolic = list()
for function_name, record_dict in symbolic_records.items():
new_functions = list()
for function_name, record_dict in symbolic_records:
symbolic_function = record_dict['_object']
arg_descriptors = record_dict['arg_descriptors']
@ -111,12 +112,18 @@ class SymbolicRewriter:
# Only register functions that exist
if origin_func is not None:
origin_symbolic = getattr(origin_func, 'symbolic', None)
context_caller.origin_func = origin_symbolic
origin_func.symbolic = context_caller
# Save origin function
self._extra_symbolic.append((origin_func, origin_symbolic))
# Cache new the function to avoid homonymic bug
new_functions.append((origin_func, context_caller))
for origin_func, new_func in new_functions:
origin_symbolic = getattr(origin_func, 'symbolic', None)
new_func.origin_func = origin_symbolic
origin_func.symbolic = new_func
def exit(self):
"""The implementation of symbolic unregister."""
# Unregister pytorch op

View File

@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .module import C, func
__all__ = ['func', 'C']

View File

@ -0,0 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
def func():
return 1
class C:
def method(self):
return 1

View File

@ -3,6 +3,7 @@ import torch
from mmdeploy.core import FUNCTION_REWRITER, RewriterContext
from mmdeploy.core.rewriters.function_rewriter import FunctionRewriter
from mmdeploy.utils.constants import Backend
def test_function_rewriter():
@ -83,3 +84,101 @@ def test_rewrite_empty_function():
function_rewriter.enter()
assert len(function_rewriter._origin_functions) == 0
function_rewriter.exit()
class TestHomonymicRewriter:
def test_rewrite_homonymic_functions(self):
import package
path1 = 'package.func'
path2 = 'package.module.func'
assert package.func() == 1
assert package.module.func() == 1
function_rewriter = FunctionRewriter()
function_rewriter.add_backend(Backend.NCNN.value)
@function_rewriter.register_rewriter(func_name=path1)
def func_2(ctx):
return 2
@function_rewriter.register_rewriter(
func_name=path2, backend=Backend.NCNN.value)
def func_3(ctx):
return 3
function_rewriter.enter(backend=Backend.NCNN.value)
# This is a feature
assert package.func() == 2
assert package.module.func() == 3
function_rewriter.exit()
assert package.func() == 1
assert package.module.func() == 1
function_rewriter2 = FunctionRewriter()
function_rewriter2.add_backend(Backend.NCNN.value)
@function_rewriter2.register_rewriter(
func_name=path1, backend=Backend.NCNN.value)
def func_4(ctx):
return 4
@function_rewriter2.register_rewriter(func_name=path2)
def func_5(ctx):
return 5
function_rewriter2.enter(backend=Backend.NCNN.value)
# This is a feature
assert package.func() == 4
assert package.module.func() == 5
function_rewriter2.exit()
assert package.func() == 1
assert package.module.func() == 1
def test_rewrite_homonymic_methods(self):
import package
path1 = 'package.C.method'
path2 = 'package.module.C.method'
c = package.C()
function_rewriter = FunctionRewriter()
function_rewriter.add_backend(Backend.NCNN.value)
assert c.method() == 1
@function_rewriter.register_rewriter(func_name=path1)
def func_2(ctx, self):
return 2
@function_rewriter.register_rewriter(
func_name=path2, backend=Backend.NCNN.value)
def func_3(ctx, self):
return 3
function_rewriter.enter(backend=Backend.NCNN.value)
assert c.method() == 3
function_rewriter.exit()
assert c.method() == 1
function_rewriter2 = FunctionRewriter()
function_rewriter2.add_backend(Backend.NCNN.value)
@function_rewriter2.register_rewriter(
func_name=path1, backend=Backend.NCNN.value)
def func_4(ctx, self):
return 4
@function_rewriter2.register_rewriter(func_name=path2)
def func_5(ctx, self):
return 5
function_rewriter2.enter(backend=Backend.NCNN.value)
assert c.method() == 4
function_rewriter2.exit()
assert c.method() == 1

View File

@ -50,10 +50,10 @@ def test_get_records():
def fake_add(a, b):
return a * b
default_records = registry.get_records(Backend.DEFAULT.value)
default_records = dict(registry.get_records(Backend.DEFAULT.value))
assert default_records['add']['_object'](1, 1) == 2
assert default_records['minus']['_object'](1, 1) == 0
tensorrt_records = registry.get_records(Backend.TENSORRT.value)
tensorrt_records = dict(registry.get_records(Backend.TENSORRT.value))
assert tensorrt_records['add']['_object'](1, 1) == 1
assert tensorrt_records['minus']['_object'](1, 1) == 0