From cc72c00e619bbbbd29ddc78b3c30dbbcefa6b9cc Mon Sep 17 00:00:00 2001 From: Yifan Zhou Date: Fri, 3 Dec 2021 18:31:58 +0800 Subject: [PATCH] [Fix]: Fix homonymic rewriters bugs (#242) * Fix bug * license --- .../mmdet/deploy/object_detection_model.py | 10 +- mmdeploy/core/rewriters/function_rewriter.py | 15 ++- mmdeploy/core/rewriters/module_rewriter.py | 2 +- mmdeploy/core/rewriters/rewriter_utils.py | 26 ++++- mmdeploy/core/rewriters/symbolic_rewriter.py | 13 ++- tests/test_core/package/__init__.py | 4 + tests/test_core/package/module.py | 9 ++ tests/test_core/test_function_rewriter.py | 99 +++++++++++++++++++ tests/test_core/test_rewriter_registry.py | 4 +- 9 files changed, 162 insertions(+), 20 deletions(-) create mode 100644 tests/test_core/package/__init__.py create mode 100644 tests/test_core/package/module.py diff --git a/mmdeploy/codebase/mmdet/deploy/object_detection_model.py b/mmdeploy/codebase/mmdet/deploy/object_detection_model.py index 9386466cb..8a8259f18 100644 --- a/mmdeploy/codebase/mmdet/deploy/object_detection_model.py +++ b/mmdeploy/codebase/mmdet/deploy/object_detection_model.py @@ -36,11 +36,11 @@ def __build_backend_model(partition_name: str, backend: Backend, # Use registry to store models with different partition methods # If a model doesn't need to partition, we don't need this registry -__BACKEND_MODEl = mmcv.utils.Registry( +__BACKEND_MODEL = mmcv.utils.Registry( 'backend_detectors', build_func=__build_backend_model) -@__BACKEND_MODEl.register_module('end2end') +@__BACKEND_MODEL.register_module('end2end') class End2EndModel(BaseBackendModel): """End to end model for inference of detection. @@ -273,7 +273,7 @@ class End2EndModel(BaseBackendModel): out_file=out_file) -@__BACKEND_MODEl.register_module('single_stage') +@__BACKEND_MODEL.register_module('single_stage') class PartitionSingleStageModel(End2EndModel): """Partitioned single stage detection model. @@ -352,7 +352,7 @@ class PartitionSingleStageModel(End2EndModel): return self.partition0_postprocess(scores, bboxes) -@__BACKEND_MODEl.register_module('two_stage') +@__BACKEND_MODEL.register_module('two_stage') class PartitionTwoStageModel(End2EndModel): """Partitioned two stage detection model. @@ -572,7 +572,7 @@ def build_object_detection_model(model_files: Sequence[str], if partition_config is not None: partition_type = partition_config.get('type', None) - backend_detector = __BACKEND_MODEl.build( + backend_detector = __BACKEND_MODEL.build( partition_type, backend=backend, backend_files=model_files, diff --git a/mmdeploy/core/rewriters/function_rewriter.py b/mmdeploy/core/rewriters/function_rewriter.py index 6911ad2e9..50491117d 100644 --- a/mmdeploy/core/rewriters/function_rewriter.py +++ b/mmdeploy/core/rewriters/function_rewriter.py @@ -9,7 +9,7 @@ from .rewriter_utils import ContextCaller, RewriterRegistry, eval_with_import def _set_func(origin_func_name: str, rewrite_func: Callable): """Rewrite a function by executing a python statement.""" - # import necessary module + # Import necessary module split_path = origin_func_name.split('.') for i in range(len(split_path), 0, -1): try: @@ -17,7 +17,7 @@ def _set_func(origin_func_name: str, rewrite_func: Callable): break except Exception: continue - # assign function + # Assign function exec(f'{origin_func_name} = rewrite_func') @@ -72,7 +72,8 @@ class FunctionRewriter: functions_records = self._registry.get_records(backend) self._origin_functions = list() - for function_name, record_dict in functions_records.items(): + new_functions = list() + for function_name, record_dict in functions_records: # Check if the origin function exists try: @@ -97,8 +98,12 @@ class FunctionRewriter: rewrite_function, origin_func, cfg, **extra_kwargs).get_wrapped_caller() - # Rewrite functions - _set_func(function_name, context_caller) + # Cache new the function to avoid homonymic bug + new_functions.append((function_name, context_caller)) + + for function_name, new_function in new_functions: + # Rewrite functions + _set_func(function_name, new_function) def exit(self): """Recover the function rewrite.""" diff --git a/mmdeploy/core/rewriters/module_rewriter.py b/mmdeploy/core/rewriters/module_rewriter.py index fae278900..c729b3bf5 100644 --- a/mmdeploy/core/rewriters/module_rewriter.py +++ b/mmdeploy/core/rewriters/module_rewriter.py @@ -108,5 +108,5 @@ class ModuleRewriter: """Collect models in registry.""" self._records = {} records = self._registry.get_records(backend) - for name, kwargs in records.items(): + for name, kwargs in records: self._records[eval_with_import(name)] = kwargs diff --git a/mmdeploy/core/rewriters/rewriter_utils.py b/mmdeploy/core/rewriters/rewriter_utils.py index 0612607f7..1f2792dc7 100644 --- a/mmdeploy/core/rewriters/rewriter_utils.py +++ b/mmdeploy/core/rewriters/rewriter_utils.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, List from mmdeploy.utils.constants import Backend @@ -57,12 +57,30 @@ class RewriterRegistry: if backend not in self._rewrite_records: self._rewrite_records[backend] = dict() - def get_records(self, backend: str) -> dict: + def get_records(self, backend: str) -> List: """Get all registered records in record table.""" self._check_backend(backend) - records = self._rewrite_records[Backend.DEFAULT.value].copy() + if backend != Backend.DEFAULT.value: - records.update(self._rewrite_records[backend]) + # Update dict A with dict B. + # Then convert the result dict to a list, while keeping the order + # of A and B: the elements only belong to B should alwarys come + # after the elements only belong to A. + # The complexity is O(n + m). + dict_a = self._rewrite_records[Backend.DEFAULT.value] + dict_b = self._rewrite_records[backend] + records = [] + for k, v in dict_a.items(): + if k in dict_b: + records.append((k, dict_b[k])) + else: + records.append((k, v)) + for k, v in dict_b.items(): + if k not in dict_a: + records.append((k, v)) + else: + records = list( + self._rewrite_records[Backend.DEFAULT.value].items()) return records def _register(self, name: str, backend: str, **kwargs): diff --git a/mmdeploy/core/rewriters/symbolic_rewriter.py b/mmdeploy/core/rewriters/symbolic_rewriter.py index 6d83a2901..3f4343146 100644 --- a/mmdeploy/core/rewriters/symbolic_rewriter.py +++ b/mmdeploy/core/rewriters/symbolic_rewriter.py @@ -77,7 +77,8 @@ class SymbolicRewriter: self._pytorch_symbolic = list() self._extra_symbolic = list() - for function_name, record_dict in symbolic_records.items(): + new_functions = list() + for function_name, record_dict in symbolic_records: symbolic_function = record_dict['_object'] arg_descriptors = record_dict['arg_descriptors'] @@ -111,12 +112,18 @@ class SymbolicRewriter: # Only register functions that exist if origin_func is not None: origin_symbolic = getattr(origin_func, 'symbolic', None) - context_caller.origin_func = origin_symbolic - origin_func.symbolic = context_caller # Save origin function self._extra_symbolic.append((origin_func, origin_symbolic)) + # Cache new the function to avoid homonymic bug + new_functions.append((origin_func, context_caller)) + + for origin_func, new_func in new_functions: + origin_symbolic = getattr(origin_func, 'symbolic', None) + new_func.origin_func = origin_symbolic + origin_func.symbolic = new_func + def exit(self): """The implementation of symbolic unregister.""" # Unregister pytorch op diff --git a/tests/test_core/package/__init__.py b/tests/test_core/package/__init__.py new file mode 100644 index 000000000..2e2785af7 --- /dev/null +++ b/tests/test_core/package/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .module import C, func + +__all__ = ['func', 'C'] diff --git a/tests/test_core/package/module.py b/tests/test_core/package/module.py new file mode 100644 index 000000000..3e00c4dff --- /dev/null +++ b/tests/test_core/package/module.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +def func(): + return 1 + + +class C: + + def method(self): + return 1 diff --git a/tests/test_core/test_function_rewriter.py b/tests/test_core/test_function_rewriter.py index 0a97a6280..222e29b35 100644 --- a/tests/test_core/test_function_rewriter.py +++ b/tests/test_core/test_function_rewriter.py @@ -3,6 +3,7 @@ import torch from mmdeploy.core import FUNCTION_REWRITER, RewriterContext from mmdeploy.core.rewriters.function_rewriter import FunctionRewriter +from mmdeploy.utils.constants import Backend def test_function_rewriter(): @@ -83,3 +84,101 @@ def test_rewrite_empty_function(): function_rewriter.enter() assert len(function_rewriter._origin_functions) == 0 function_rewriter.exit() + + +class TestHomonymicRewriter: + + def test_rewrite_homonymic_functions(self): + import package + path1 = 'package.func' + path2 = 'package.module.func' + + assert package.func() == 1 + assert package.module.func() == 1 + + function_rewriter = FunctionRewriter() + function_rewriter.add_backend(Backend.NCNN.value) + + @function_rewriter.register_rewriter(func_name=path1) + def func_2(ctx): + return 2 + + @function_rewriter.register_rewriter( + func_name=path2, backend=Backend.NCNN.value) + def func_3(ctx): + return 3 + + function_rewriter.enter(backend=Backend.NCNN.value) + # This is a feature + assert package.func() == 2 + assert package.module.func() == 3 + function_rewriter.exit() + + assert package.func() == 1 + assert package.module.func() == 1 + + function_rewriter2 = FunctionRewriter() + function_rewriter2.add_backend(Backend.NCNN.value) + + @function_rewriter2.register_rewriter( + func_name=path1, backend=Backend.NCNN.value) + def func_4(ctx): + return 4 + + @function_rewriter2.register_rewriter(func_name=path2) + def func_5(ctx): + return 5 + + function_rewriter2.enter(backend=Backend.NCNN.value) + # This is a feature + assert package.func() == 4 + assert package.module.func() == 5 + function_rewriter2.exit() + + assert package.func() == 1 + assert package.module.func() == 1 + + def test_rewrite_homonymic_methods(self): + import package + path1 = 'package.C.method' + path2 = 'package.module.C.method' + + c = package.C() + + function_rewriter = FunctionRewriter() + function_rewriter.add_backend(Backend.NCNN.value) + + assert c.method() == 1 + + @function_rewriter.register_rewriter(func_name=path1) + def func_2(ctx, self): + return 2 + + @function_rewriter.register_rewriter( + func_name=path2, backend=Backend.NCNN.value) + def func_3(ctx, self): + return 3 + + function_rewriter.enter(backend=Backend.NCNN.value) + assert c.method() == 3 + function_rewriter.exit() + + assert c.method() == 1 + + function_rewriter2 = FunctionRewriter() + function_rewriter2.add_backend(Backend.NCNN.value) + + @function_rewriter2.register_rewriter( + func_name=path1, backend=Backend.NCNN.value) + def func_4(ctx, self): + return 4 + + @function_rewriter2.register_rewriter(func_name=path2) + def func_5(ctx, self): + return 5 + + function_rewriter2.enter(backend=Backend.NCNN.value) + assert c.method() == 4 + function_rewriter2.exit() + + assert c.method() == 1 diff --git a/tests/test_core/test_rewriter_registry.py b/tests/test_core/test_rewriter_registry.py index 22c0d7d8c..b577d0262 100644 --- a/tests/test_core/test_rewriter_registry.py +++ b/tests/test_core/test_rewriter_registry.py @@ -50,10 +50,10 @@ def test_get_records(): def fake_add(a, b): return a * b - default_records = registry.get_records(Backend.DEFAULT.value) + default_records = dict(registry.get_records(Backend.DEFAULT.value)) assert default_records['add']['_object'](1, 1) == 2 assert default_records['minus']['_object'](1, 1) == 0 - tensorrt_records = registry.get_records(Backend.TENSORRT.value) + tensorrt_records = dict(registry.get_records(Backend.TENSORRT.value)) assert tensorrt_records['add']['_object'](1, 1) == 1 assert tensorrt_records['minus']['_object'](1, 1) == 0