mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[Enhancement] Rewriter support pre-import function (#899)
* support preimport * update rewriter * fix batched nms ort
This commit is contained in:
parent
47d4e6f733
commit
13920ec1a2
@ -211,7 +211,7 @@ def multiclass_nms__default(ctx,
|
||||
pre_top_k=pre_top_k,
|
||||
keep_top_k=keep_top_k)
|
||||
else:
|
||||
return _multiclass_nms(
|
||||
return ctx.origin_func(
|
||||
boxes,
|
||||
scores,
|
||||
max_output_boxes_per_class=max_output_boxes_per_class,
|
||||
|
@ -1,17 +1,54 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
from typing import (Any, Callable, Dict, List, MutableSequence, Optional,
|
||||
Tuple, Union)
|
||||
|
||||
from mmdeploy.utils import IR, Backend, get_root_logger
|
||||
from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry,
|
||||
import_function)
|
||||
|
||||
|
||||
def _set_func(origin_func_path: str, rewrite_func: Callable):
|
||||
def _replace_all_obj(obj: Any,
|
||||
new_obj: Any,
|
||||
ignore_refs: Tuple[Any] = tuple(),
|
||||
ignore_keys: Tuple[str] = tuple()):
|
||||
"""Replace all object reference with new_object.
|
||||
|
||||
Args:
|
||||
obj (Any): The object to be replaced.
|
||||
new_obj (Any): The object to replace obj.
|
||||
ignore_refs (Tuple[Any]): These refs will be ignored.
|
||||
ignore_keys (Tuple[str]): object with these keys will be ignored.
|
||||
"""
|
||||
import gc
|
||||
refs = gc.get_referrers(obj)
|
||||
obj_id = id(obj)
|
||||
for ref in refs:
|
||||
if ref in ignore_refs:
|
||||
continue
|
||||
elif isinstance(ref, MutableSequence):
|
||||
for i, v in enumerate(ref):
|
||||
if id(v) == obj_id:
|
||||
ref[i] == new_obj
|
||||
elif isinstance(ref, Dict):
|
||||
for k, v in ref.items():
|
||||
if id(v) == obj_id and k not in ignore_keys:
|
||||
ref[k] = new_obj
|
||||
else:
|
||||
# TODO: check if we can replace tuple
|
||||
pass
|
||||
|
||||
|
||||
def _set_func(origin_func_path: str,
|
||||
rewrite_func: Callable,
|
||||
ignore_refs: Tuple[Any] = tuple(),
|
||||
ignore_keys: Tuple[str] = ('origin_func', )):
|
||||
"""Rewrite a function by executing a python statement.
|
||||
|
||||
Args:
|
||||
origin_func_path (str): The path to origin function.
|
||||
rewrite_func (Callable): The new function instance.
|
||||
ignore_refs (Tuple[Any]): These refs will be ignored.
|
||||
ignore_keys (Tuple[str]): object with these keys will be ignored.
|
||||
"""
|
||||
|
||||
# Import necessary module
|
||||
@ -22,7 +59,19 @@ def _set_func(origin_func_path: str, rewrite_func: Callable):
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
origin_func = eval(origin_func_path)
|
||||
method_class = False
|
||||
if len(split_path) > 1:
|
||||
module_or_class = eval('.'.join(split_path[:-1]))
|
||||
if isinstance(module_or_class, type):
|
||||
method_class = True
|
||||
# Assign function
|
||||
if not method_class:
|
||||
_replace_all_obj(
|
||||
origin_func,
|
||||
rewrite_func,
|
||||
ignore_refs=ignore_refs,
|
||||
ignore_keys=ignore_keys)
|
||||
exec(f'{origin_func_path} = rewrite_func')
|
||||
|
||||
|
||||
@ -37,12 +86,11 @@ def _del_func(path: str):
|
||||
for i in range(len(split_path), 0, -1):
|
||||
try:
|
||||
exec('import {}'.format('.'.join(split_path[:i])))
|
||||
exec(f'del {path}')
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
exec(f'del {path}')
|
||||
|
||||
|
||||
class FunctionRewriter:
|
||||
"""A function rewriter which maintains rewritten functions.
|
||||
@ -126,9 +174,10 @@ class FunctionRewriter:
|
||||
|
||||
if is_addition_function:
|
||||
self._additional_functions.append(function_path)
|
||||
else:
|
||||
# Save origin function
|
||||
self._origin_functions.append((function_path, origin_func))
|
||||
|
||||
# Save origin function
|
||||
self._origin_functions.append(
|
||||
dict(func_path=function_path, origin_func=origin_func))
|
||||
|
||||
# Create context_caller
|
||||
rewrite_function = record_dict['_object']
|
||||
@ -139,15 +188,20 @@ class FunctionRewriter:
|
||||
**extra_kwargs).get_wrapped_caller()
|
||||
|
||||
# Cache new the function to avoid homonymic bug
|
||||
new_functions.append((function_path, context_caller))
|
||||
new_functions.append(
|
||||
dict(func_path=function_path, origin_func=context_caller))
|
||||
|
||||
for function_path, new_function in new_functions:
|
||||
for func_dict in new_functions:
|
||||
function_path = func_dict['func_path']
|
||||
new_function = func_dict['origin_func']
|
||||
# Rewrite functions
|
||||
_set_func(function_path, new_function)
|
||||
|
||||
def exit(self):
|
||||
"""Recover the function rewrite."""
|
||||
for func_path, func in self._origin_functions:
|
||||
for func_dict in self._origin_functions:
|
||||
func_path = func_dict['func_path']
|
||||
func = func_dict['origin_func']
|
||||
_set_func(func_path, func)
|
||||
for func_path in self._additional_functions:
|
||||
_del_func(func_path)
|
||||
|
@ -89,54 +89,6 @@ def test_rewrite_empty_function():
|
||||
|
||||
class TestHomonymicRewriter:
|
||||
|
||||
def test_rewrite_homonymic_functions(self):
|
||||
import package
|
||||
path1 = 'package.func'
|
||||
path2 = 'package.module.func'
|
||||
|
||||
assert package.func() == 1
|
||||
assert package.module.func() == 1
|
||||
|
||||
function_rewriter = FunctionRewriter()
|
||||
|
||||
@function_rewriter.register_rewriter(func_name=path1)
|
||||
def func_2(ctx):
|
||||
return 2
|
||||
|
||||
@function_rewriter.register_rewriter(
|
||||
func_name=path2, backend=Backend.NCNN.value)
|
||||
def func_3(ctx):
|
||||
return 3
|
||||
|
||||
function_rewriter.enter(env=collect_env(Backend.NCNN, ir=IR.DEFAULT))
|
||||
# This is a feature
|
||||
assert package.func() == 2
|
||||
assert package.module.func() == 3
|
||||
function_rewriter.exit()
|
||||
|
||||
assert package.func() == 1
|
||||
assert package.module.func() == 1
|
||||
|
||||
function_rewriter2 = FunctionRewriter()
|
||||
|
||||
@function_rewriter2.register_rewriter(
|
||||
func_name=path1, backend=Backend.NCNN.value)
|
||||
def func_4(ctx):
|
||||
return 4
|
||||
|
||||
@function_rewriter2.register_rewriter(func_name=path2)
|
||||
def func_5(ctx):
|
||||
return 5
|
||||
|
||||
function_rewriter2.enter(env=collect_env(Backend.NCNN, ir=IR.DEFAULT))
|
||||
# This is a feature
|
||||
assert package.func() == 4
|
||||
assert package.module.func() == 5
|
||||
function_rewriter2.exit()
|
||||
|
||||
assert package.func() == 1
|
||||
assert package.module.func() == 1
|
||||
|
||||
def test_rewrite_homonymic_methods(self):
|
||||
import package
|
||||
path1 = 'package.C.method'
|
||||
|
Loading…
x
Reference in New Issue
Block a user