mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[Fix]: Fix rewriter conflict when processing derived class (#289)
* Fix rewriter * lint * rename function and update docstring * use is class * Update docstring
This commit is contained in:
parent
0dea300714
commit
78b37bbd32
@ -3,14 +3,19 @@ import logging
|
||||
from typing import Callable, Dict
|
||||
|
||||
from mmdeploy.utils.constants import Backend
|
||||
from .rewriter_utils import ContextCaller, RewriterRegistry, eval_with_import
|
||||
from .rewriter_utils import ContextCaller, RewriterRegistry, import_function
|
||||
|
||||
|
||||
def _set_func(origin_func_name: str, rewrite_func: Callable):
|
||||
"""Rewrite a function by executing a python statement."""
|
||||
def _set_func(origin_func_path: str, rewrite_func: Callable):
|
||||
"""Rewrite a function by executing a python statement.
|
||||
|
||||
Args:
|
||||
origin_func_path (str): The path to origin function.
|
||||
rewrite_func (Callable): The new function instance.
|
||||
"""
|
||||
|
||||
# Import necessary module
|
||||
split_path = origin_func_name.split('.')
|
||||
split_path = origin_func_path.split('.')
|
||||
for i in range(len(split_path), 0, -1):
|
||||
try:
|
||||
exec('import {}'.format('.'.join(split_path[:i])))
|
||||
@ -18,7 +23,25 @@ def _set_func(origin_func_name: str, rewrite_func: Callable):
|
||||
except Exception:
|
||||
continue
|
||||
# Assign function
|
||||
exec(f'{origin_func_name} = rewrite_func')
|
||||
exec(f'{origin_func_path} = rewrite_func')
|
||||
|
||||
|
||||
def _del_func(path: str):
|
||||
"""Delete a function that is denoted by a path.
|
||||
|
||||
Args:
|
||||
path (str): The path to evaluate.
|
||||
"""
|
||||
|
||||
split_path = path.split('.')
|
||||
for i in range(len(split_path), 0, -1):
|
||||
try:
|
||||
exec('import {}'.format('.'.join(split_path[:i])))
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
exec(f'del {path}')
|
||||
|
||||
|
||||
class FunctionRewriter:
|
||||
@ -72,23 +95,38 @@ class FunctionRewriter:
|
||||
functions_records = self._registry.get_records(backend)
|
||||
|
||||
self._origin_functions = list()
|
||||
self._additional_functions = list()
|
||||
new_functions = list()
|
||||
for function_name, record_dict in functions_records:
|
||||
for function_path, record_dict in functions_records:
|
||||
|
||||
# Check if the origin function exists
|
||||
try:
|
||||
origin_func = eval_with_import(function_name)
|
||||
origin_func, origin_class = import_function(function_path)
|
||||
except Exception:
|
||||
origin_func = None
|
||||
logging.warning(
|
||||
f'Can not find {function_name}, function rewrite will '
|
||||
f'Can not find {function_path}, function rewrite will '
|
||||
'not be applied')
|
||||
|
||||
# Only rewrite functions that exist
|
||||
if origin_func is not None:
|
||||
|
||||
# Save origin function
|
||||
self._origin_functions.append((function_name, origin_func))
|
||||
is_addition_function = False
|
||||
if origin_class is not None:
|
||||
function_name = function_path.split('.')[-1]
|
||||
try:
|
||||
origin_class.__getattribute__(origin_class,
|
||||
function_name)
|
||||
except Exception:
|
||||
# The function is a method and it is derived from base
|
||||
# class.
|
||||
is_addition_function = True
|
||||
|
||||
if is_addition_function:
|
||||
self._additional_functions.append(function_path)
|
||||
else:
|
||||
# Save origin function
|
||||
self._origin_functions.append((function_path, origin_func))
|
||||
|
||||
# Create context_caller
|
||||
rewrite_function = record_dict['_object']
|
||||
@ -99,13 +137,15 @@ class FunctionRewriter:
|
||||
**extra_kwargs).get_wrapped_caller()
|
||||
|
||||
# Cache new the function to avoid homonymic bug
|
||||
new_functions.append((function_name, context_caller))
|
||||
new_functions.append((function_path, context_caller))
|
||||
|
||||
for function_name, new_function in new_functions:
|
||||
for function_path, new_function in new_functions:
|
||||
# Rewrite functions
|
||||
_set_func(function_name, new_function)
|
||||
_set_func(function_path, new_function)
|
||||
|
||||
def exit(self):
|
||||
"""Recover the function rewrite."""
|
||||
for func_name, func in self._origin_functions:
|
||||
_set_func(func_name, func)
|
||||
for func_path, func in self._origin_functions:
|
||||
_set_func(func_path, func)
|
||||
for func_path in self._additional_functions:
|
||||
_del_func(func_path)
|
||||
|
@ -1,5 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Any, Callable, Dict, List
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from mmdeploy.utils.constants import Backend
|
||||
|
||||
@ -11,7 +12,7 @@ def eval_with_import(path: str) -> Any:
|
||||
path (str): The path to evaluate.
|
||||
|
||||
Returns:
|
||||
Any: The result of evaluate.
|
||||
Any: The result of evaluation.
|
||||
"""
|
||||
split_path = path.split('.')
|
||||
for i in range(len(split_path), 0, -1):
|
||||
@ -23,6 +24,38 @@ def eval_with_import(path: str) -> Any:
|
||||
return eval(path)
|
||||
|
||||
|
||||
def import_function(path: str) -> Tuple[Callable, Optional[type]]:
|
||||
"""Import and evaluate a function. If the function is defined in a class,
|
||||
evaluate the class additionally.
|
||||
|
||||
Args:
|
||||
path (str): The path to evaluate.
|
||||
|
||||
Returns:
|
||||
Callable: The function of evaluation.
|
||||
type: The class of evaluation if the function is defined in a class, or
|
||||
None.
|
||||
"""
|
||||
split_path = path.split('.')
|
||||
for i in range(len(split_path), 0, -1):
|
||||
try:
|
||||
exec('import {}'.format('.'.join(split_path[:i])))
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
obj = eval(path)
|
||||
|
||||
# The path that might be a class
|
||||
previous_obj = eval('.'.join(split_path[:-1]))
|
||||
|
||||
# Check if the path leads to a class
|
||||
if inspect.isclass(previous_obj):
|
||||
return obj, previous_obj
|
||||
else:
|
||||
return obj, None
|
||||
|
||||
|
||||
class RewriterRegistry:
|
||||
"""A registry that recoreds rewrite objects.
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .module import C, func
|
||||
from .module import C2, C, func
|
||||
|
||||
__all__ = ['func', 'C']
|
||||
__all__ = ['func', 'C', 'C2']
|
||||
|
@ -7,3 +7,7 @@ class C:
|
||||
|
||||
def method(self):
|
||||
return 1
|
||||
|
||||
|
||||
class C2(C):
|
||||
pass
|
||||
|
@ -182,3 +182,49 @@ class TestHomonymicRewriter:
|
||||
function_rewriter2.exit()
|
||||
|
||||
assert c.method() == 1
|
||||
|
||||
|
||||
def test_rewrite_derived_methods():
|
||||
import package
|
||||
path1 = 'package.C.method'
|
||||
path2 = 'package.C2.method'
|
||||
|
||||
base_obj = package.C()
|
||||
derived_obj = package.C2()
|
||||
|
||||
assert base_obj.method() == 1
|
||||
assert derived_obj.method() == 1
|
||||
|
||||
function_rewriter = FunctionRewriter()
|
||||
function_rewriter.add_backend(Backend.NCNN.value)
|
||||
|
||||
@function_rewriter.register_rewriter(func_name=path1)
|
||||
def func_2(ctx, self):
|
||||
return 2
|
||||
|
||||
@function_rewriter.register_rewriter(
|
||||
func_name=path2, backend=Backend.NCNN.value)
|
||||
def func_3(ctx, self):
|
||||
return 3
|
||||
|
||||
function_rewriter.enter()
|
||||
assert base_obj.method() == 2
|
||||
assert derived_obj.method() == 2
|
||||
function_rewriter.exit()
|
||||
|
||||
function_rewriter.enter(backend=Backend.NCNN.value)
|
||||
assert base_obj.method() == 2
|
||||
assert derived_obj.method() == 3
|
||||
function_rewriter.exit()
|
||||
|
||||
assert base_obj.method() == 1
|
||||
assert derived_obj.method() == 1
|
||||
|
||||
# Check if the recovery is correct
|
||||
function_rewriter.enter()
|
||||
assert base_obj.method() == 2
|
||||
assert derived_obj.method() == 2
|
||||
function_rewriter.exit()
|
||||
|
||||
assert base_obj.method() == 1
|
||||
assert derived_obj.method() == 1
|
||||
|
Loading…
x
Reference in New Issue
Block a user