diff --git a/mmdeploy/codebase/mmdet/core/post_processing/bbox_nms.py b/mmdeploy/codebase/mmdet/core/post_processing/bbox_nms.py index c177bf9df..f04d55891 100644 --- a/mmdeploy/codebase/mmdet/core/post_processing/bbox_nms.py +++ b/mmdeploy/codebase/mmdet/core/post_processing/bbox_nms.py @@ -211,7 +211,7 @@ def multiclass_nms__default(ctx, pre_top_k=pre_top_k, keep_top_k=keep_top_k) else: - return _multiclass_nms( + return ctx.origin_func( boxes, scores, max_output_boxes_per_class=max_output_boxes_per_class, diff --git a/mmdeploy/core/rewriters/function_rewriter.py b/mmdeploy/core/rewriters/function_rewriter.py index e80ed41d0..b623476f3 100644 --- a/mmdeploy/core/rewriters/function_rewriter.py +++ b/mmdeploy/core/rewriters/function_rewriter.py @@ -1,17 +1,54 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Callable, Dict, List, Optional, Union +from typing import (Any, Callable, Dict, List, MutableSequence, Optional, + Tuple, Union) from mmdeploy.utils import IR, Backend, get_root_logger from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry, import_function) -def _set_func(origin_func_path: str, rewrite_func: Callable): +def _replace_all_obj(obj: Any, + new_obj: Any, + ignore_refs: Tuple[Any] = tuple(), + ignore_keys: Tuple[str] = tuple()): + """Replace all object reference with new_object. + + Args: + obj (Any): The object to be replaced. + new_obj (Any): The object to replace obj. + ignore_refs (Tuple[Any]): These refs will be ignored. + ignore_keys (Tuple[str]): object with these keys will be ignored. + """ + import gc + refs = gc.get_referrers(obj) + obj_id = id(obj) + for ref in refs: + if ref in ignore_refs: + continue + elif isinstance(ref, MutableSequence): + for i, v in enumerate(ref): + if id(v) == obj_id: + ref[i] == new_obj + elif isinstance(ref, Dict): + for k, v in ref.items(): + if id(v) == obj_id and k not in ignore_keys: + ref[k] = new_obj + else: + # TODO: check if we can replace tuple + pass + + +def _set_func(origin_func_path: str, + rewrite_func: Callable, + ignore_refs: Tuple[Any] = tuple(), + ignore_keys: Tuple[str] = ('origin_func', )): """Rewrite a function by executing a python statement. Args: origin_func_path (str): The path to origin function. rewrite_func (Callable): The new function instance. + ignore_refs (Tuple[Any]): These refs will be ignored. + ignore_keys (Tuple[str]): object with these keys will be ignored. """ # Import necessary module @@ -22,7 +59,19 @@ def _set_func(origin_func_path: str, rewrite_func: Callable): break except Exception: continue + origin_func = eval(origin_func_path) + method_class = False + if len(split_path) > 1: + module_or_class = eval('.'.join(split_path[:-1])) + if isinstance(module_or_class, type): + method_class = True # Assign function + if not method_class: + _replace_all_obj( + origin_func, + rewrite_func, + ignore_refs=ignore_refs, + ignore_keys=ignore_keys) exec(f'{origin_func_path} = rewrite_func') @@ -37,12 +86,11 @@ def _del_func(path: str): for i in range(len(split_path), 0, -1): try: exec('import {}'.format('.'.join(split_path[:i]))) + exec(f'del {path}') break except Exception: continue - exec(f'del {path}') - class FunctionRewriter: """A function rewriter which maintains rewritten functions. @@ -126,9 +174,10 @@ class FunctionRewriter: if is_addition_function: self._additional_functions.append(function_path) - else: - # Save origin function - self._origin_functions.append((function_path, origin_func)) + + # Save origin function + self._origin_functions.append( + dict(func_path=function_path, origin_func=origin_func)) # Create context_caller rewrite_function = record_dict['_object'] @@ -139,15 +188,20 @@ class FunctionRewriter: **extra_kwargs).get_wrapped_caller() # Cache new the function to avoid homonymic bug - new_functions.append((function_path, context_caller)) + new_functions.append( + dict(func_path=function_path, origin_func=context_caller)) - for function_path, new_function in new_functions: + for func_dict in new_functions: + function_path = func_dict['func_path'] + new_function = func_dict['origin_func'] # Rewrite functions _set_func(function_path, new_function) def exit(self): """Recover the function rewrite.""" - for func_path, func in self._origin_functions: + for func_dict in self._origin_functions: + func_path = func_dict['func_path'] + func = func_dict['origin_func'] _set_func(func_path, func) for func_path in self._additional_functions: _del_func(func_path) diff --git a/tests/test_core/test_function_rewriter.py b/tests/test_core/test_function_rewriter.py index 97a814e92..ca7a681c3 100644 --- a/tests/test_core/test_function_rewriter.py +++ b/tests/test_core/test_function_rewriter.py @@ -89,54 +89,6 @@ def test_rewrite_empty_function(): class TestHomonymicRewriter: - def test_rewrite_homonymic_functions(self): - import package - path1 = 'package.func' - path2 = 'package.module.func' - - assert package.func() == 1 - assert package.module.func() == 1 - - function_rewriter = FunctionRewriter() - - @function_rewriter.register_rewriter(func_name=path1) - def func_2(ctx): - return 2 - - @function_rewriter.register_rewriter( - func_name=path2, backend=Backend.NCNN.value) - def func_3(ctx): - return 3 - - function_rewriter.enter(env=collect_env(Backend.NCNN, ir=IR.DEFAULT)) - # This is a feature - assert package.func() == 2 - assert package.module.func() == 3 - function_rewriter.exit() - - assert package.func() == 1 - assert package.module.func() == 1 - - function_rewriter2 = FunctionRewriter() - - @function_rewriter2.register_rewriter( - func_name=path1, backend=Backend.NCNN.value) - def func_4(ctx): - return 4 - - @function_rewriter2.register_rewriter(func_name=path2) - def func_5(ctx): - return 5 - - function_rewriter2.enter(env=collect_env(Backend.NCNN, ir=IR.DEFAULT)) - # This is a feature - assert package.func() == 4 - assert package.module.func() == 5 - function_rewriter2.exit() - - assert package.func() == 1 - assert package.module.func() == 1 - def test_rewrite_homonymic_methods(self): import package path1 = 'package.C.method'