mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[Enhancement] Rewriter support pre-import function (#899)
* support preimport * update rewriter * fix batched nms ort
This commit is contained in:
parent
47d4e6f733
commit
13920ec1a2
@ -211,7 +211,7 @@ def multiclass_nms__default(ctx,
|
|||||||
pre_top_k=pre_top_k,
|
pre_top_k=pre_top_k,
|
||||||
keep_top_k=keep_top_k)
|
keep_top_k=keep_top_k)
|
||||||
else:
|
else:
|
||||||
return _multiclass_nms(
|
return ctx.origin_func(
|
||||||
boxes,
|
boxes,
|
||||||
scores,
|
scores,
|
||||||
max_output_boxes_per_class=max_output_boxes_per_class,
|
max_output_boxes_per_class=max_output_boxes_per_class,
|
||||||
|
@ -1,17 +1,54 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from typing import Callable, Dict, List, Optional, Union
|
from typing import (Any, Callable, Dict, List, MutableSequence, Optional,
|
||||||
|
Tuple, Union)
|
||||||
|
|
||||||
from mmdeploy.utils import IR, Backend, get_root_logger
|
from mmdeploy.utils import IR, Backend, get_root_logger
|
||||||
from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry,
|
from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry,
|
||||||
import_function)
|
import_function)
|
||||||
|
|
||||||
|
|
||||||
def _set_func(origin_func_path: str, rewrite_func: Callable):
|
def _replace_all_obj(obj: Any,
|
||||||
|
new_obj: Any,
|
||||||
|
ignore_refs: Tuple[Any] = tuple(),
|
||||||
|
ignore_keys: Tuple[str] = tuple()):
|
||||||
|
"""Replace all object reference with new_object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obj (Any): The object to be replaced.
|
||||||
|
new_obj (Any): The object to replace obj.
|
||||||
|
ignore_refs (Tuple[Any]): These refs will be ignored.
|
||||||
|
ignore_keys (Tuple[str]): object with these keys will be ignored.
|
||||||
|
"""
|
||||||
|
import gc
|
||||||
|
refs = gc.get_referrers(obj)
|
||||||
|
obj_id = id(obj)
|
||||||
|
for ref in refs:
|
||||||
|
if ref in ignore_refs:
|
||||||
|
continue
|
||||||
|
elif isinstance(ref, MutableSequence):
|
||||||
|
for i, v in enumerate(ref):
|
||||||
|
if id(v) == obj_id:
|
||||||
|
ref[i] == new_obj
|
||||||
|
elif isinstance(ref, Dict):
|
||||||
|
for k, v in ref.items():
|
||||||
|
if id(v) == obj_id and k not in ignore_keys:
|
||||||
|
ref[k] = new_obj
|
||||||
|
else:
|
||||||
|
# TODO: check if we can replace tuple
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _set_func(origin_func_path: str,
|
||||||
|
rewrite_func: Callable,
|
||||||
|
ignore_refs: Tuple[Any] = tuple(),
|
||||||
|
ignore_keys: Tuple[str] = ('origin_func', )):
|
||||||
"""Rewrite a function by executing a python statement.
|
"""Rewrite a function by executing a python statement.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
origin_func_path (str): The path to origin function.
|
origin_func_path (str): The path to origin function.
|
||||||
rewrite_func (Callable): The new function instance.
|
rewrite_func (Callable): The new function instance.
|
||||||
|
ignore_refs (Tuple[Any]): These refs will be ignored.
|
||||||
|
ignore_keys (Tuple[str]): object with these keys will be ignored.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Import necessary module
|
# Import necessary module
|
||||||
@ -22,7 +59,19 @@ def _set_func(origin_func_path: str, rewrite_func: Callable):
|
|||||||
break
|
break
|
||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
|
origin_func = eval(origin_func_path)
|
||||||
|
method_class = False
|
||||||
|
if len(split_path) > 1:
|
||||||
|
module_or_class = eval('.'.join(split_path[:-1]))
|
||||||
|
if isinstance(module_or_class, type):
|
||||||
|
method_class = True
|
||||||
# Assign function
|
# Assign function
|
||||||
|
if not method_class:
|
||||||
|
_replace_all_obj(
|
||||||
|
origin_func,
|
||||||
|
rewrite_func,
|
||||||
|
ignore_refs=ignore_refs,
|
||||||
|
ignore_keys=ignore_keys)
|
||||||
exec(f'{origin_func_path} = rewrite_func')
|
exec(f'{origin_func_path} = rewrite_func')
|
||||||
|
|
||||||
|
|
||||||
@ -37,12 +86,11 @@ def _del_func(path: str):
|
|||||||
for i in range(len(split_path), 0, -1):
|
for i in range(len(split_path), 0, -1):
|
||||||
try:
|
try:
|
||||||
exec('import {}'.format('.'.join(split_path[:i])))
|
exec('import {}'.format('.'.join(split_path[:i])))
|
||||||
|
exec(f'del {path}')
|
||||||
break
|
break
|
||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
exec(f'del {path}')
|
|
||||||
|
|
||||||
|
|
||||||
class FunctionRewriter:
|
class FunctionRewriter:
|
||||||
"""A function rewriter which maintains rewritten functions.
|
"""A function rewriter which maintains rewritten functions.
|
||||||
@ -126,9 +174,10 @@ class FunctionRewriter:
|
|||||||
|
|
||||||
if is_addition_function:
|
if is_addition_function:
|
||||||
self._additional_functions.append(function_path)
|
self._additional_functions.append(function_path)
|
||||||
else:
|
|
||||||
# Save origin function
|
# Save origin function
|
||||||
self._origin_functions.append((function_path, origin_func))
|
self._origin_functions.append(
|
||||||
|
dict(func_path=function_path, origin_func=origin_func))
|
||||||
|
|
||||||
# Create context_caller
|
# Create context_caller
|
||||||
rewrite_function = record_dict['_object']
|
rewrite_function = record_dict['_object']
|
||||||
@ -139,15 +188,20 @@ class FunctionRewriter:
|
|||||||
**extra_kwargs).get_wrapped_caller()
|
**extra_kwargs).get_wrapped_caller()
|
||||||
|
|
||||||
# Cache new the function to avoid homonymic bug
|
# Cache new the function to avoid homonymic bug
|
||||||
new_functions.append((function_path, context_caller))
|
new_functions.append(
|
||||||
|
dict(func_path=function_path, origin_func=context_caller))
|
||||||
|
|
||||||
for function_path, new_function in new_functions:
|
for func_dict in new_functions:
|
||||||
|
function_path = func_dict['func_path']
|
||||||
|
new_function = func_dict['origin_func']
|
||||||
# Rewrite functions
|
# Rewrite functions
|
||||||
_set_func(function_path, new_function)
|
_set_func(function_path, new_function)
|
||||||
|
|
||||||
def exit(self):
|
def exit(self):
|
||||||
"""Recover the function rewrite."""
|
"""Recover the function rewrite."""
|
||||||
for func_path, func in self._origin_functions:
|
for func_dict in self._origin_functions:
|
||||||
|
func_path = func_dict['func_path']
|
||||||
|
func = func_dict['origin_func']
|
||||||
_set_func(func_path, func)
|
_set_func(func_path, func)
|
||||||
for func_path in self._additional_functions:
|
for func_path in self._additional_functions:
|
||||||
_del_func(func_path)
|
_del_func(func_path)
|
||||||
|
@ -89,54 +89,6 @@ def test_rewrite_empty_function():
|
|||||||
|
|
||||||
class TestHomonymicRewriter:
|
class TestHomonymicRewriter:
|
||||||
|
|
||||||
def test_rewrite_homonymic_functions(self):
|
|
||||||
import package
|
|
||||||
path1 = 'package.func'
|
|
||||||
path2 = 'package.module.func'
|
|
||||||
|
|
||||||
assert package.func() == 1
|
|
||||||
assert package.module.func() == 1
|
|
||||||
|
|
||||||
function_rewriter = FunctionRewriter()
|
|
||||||
|
|
||||||
@function_rewriter.register_rewriter(func_name=path1)
|
|
||||||
def func_2(ctx):
|
|
||||||
return 2
|
|
||||||
|
|
||||||
@function_rewriter.register_rewriter(
|
|
||||||
func_name=path2, backend=Backend.NCNN.value)
|
|
||||||
def func_3(ctx):
|
|
||||||
return 3
|
|
||||||
|
|
||||||
function_rewriter.enter(env=collect_env(Backend.NCNN, ir=IR.DEFAULT))
|
|
||||||
# This is a feature
|
|
||||||
assert package.func() == 2
|
|
||||||
assert package.module.func() == 3
|
|
||||||
function_rewriter.exit()
|
|
||||||
|
|
||||||
assert package.func() == 1
|
|
||||||
assert package.module.func() == 1
|
|
||||||
|
|
||||||
function_rewriter2 = FunctionRewriter()
|
|
||||||
|
|
||||||
@function_rewriter2.register_rewriter(
|
|
||||||
func_name=path1, backend=Backend.NCNN.value)
|
|
||||||
def func_4(ctx):
|
|
||||||
return 4
|
|
||||||
|
|
||||||
@function_rewriter2.register_rewriter(func_name=path2)
|
|
||||||
def func_5(ctx):
|
|
||||||
return 5
|
|
||||||
|
|
||||||
function_rewriter2.enter(env=collect_env(Backend.NCNN, ir=IR.DEFAULT))
|
|
||||||
# This is a feature
|
|
||||||
assert package.func() == 4
|
|
||||||
assert package.module.func() == 5
|
|
||||||
function_rewriter2.exit()
|
|
||||||
|
|
||||||
assert package.func() == 1
|
|
||||||
assert package.module.func() == 1
|
|
||||||
|
|
||||||
def test_rewrite_homonymic_methods(self):
|
def test_rewrite_homonymic_methods(self):
|
||||||
import package
|
import package
|
||||||
path1 = 'package.C.method'
|
path1 = 'package.C.method'
|
||||||
|
Loading…
x
Reference in New Issue
Block a user