[Enhancement] Rewriter support pre-import function (#899)

* support preimport

* update rewriter

* fix batched nms ort
This commit is contained in:
q.yao 2022-08-30 12:32:43 +08:00 committed by GitHub
parent 47d4e6f733
commit 13920ec1a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 65 additions and 59 deletions

View File

@ -211,7 +211,7 @@ def multiclass_nms__default(ctx,
pre_top_k=pre_top_k,
keep_top_k=keep_top_k)
else:
return _multiclass_nms(
return ctx.origin_func(
boxes,
scores,
max_output_boxes_per_class=max_output_boxes_per_class,

View File

@ -1,17 +1,54 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, Dict, List, Optional, Union
from typing import (Any, Callable, Dict, List, MutableSequence, Optional,
Tuple, Union)
from mmdeploy.utils import IR, Backend, get_root_logger
from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry,
import_function)
def _set_func(origin_func_path: str, rewrite_func: Callable):
def _replace_all_obj(obj: Any,
new_obj: Any,
ignore_refs: Tuple[Any] = tuple(),
ignore_keys: Tuple[str] = tuple()):
"""Replace all object reference with new_object.
Args:
obj (Any): The object to be replaced.
new_obj (Any): The object to replace obj.
ignore_refs (Tuple[Any]): These refs will be ignored.
ignore_keys (Tuple[str]): object with these keys will be ignored.
"""
import gc
refs = gc.get_referrers(obj)
obj_id = id(obj)
for ref in refs:
if ref in ignore_refs:
continue
elif isinstance(ref, MutableSequence):
for i, v in enumerate(ref):
if id(v) == obj_id:
ref[i] == new_obj
elif isinstance(ref, Dict):
for k, v in ref.items():
if id(v) == obj_id and k not in ignore_keys:
ref[k] = new_obj
else:
# TODO: check if we can replace tuple
pass
def _set_func(origin_func_path: str,
rewrite_func: Callable,
ignore_refs: Tuple[Any] = tuple(),
ignore_keys: Tuple[str] = ('origin_func', )):
"""Rewrite a function by executing a python statement.
Args:
origin_func_path (str): The path to origin function.
rewrite_func (Callable): The new function instance.
ignore_refs (Tuple[Any]): These refs will be ignored.
ignore_keys (Tuple[str]): object with these keys will be ignored.
"""
# Import necessary module
@ -22,7 +59,19 @@ def _set_func(origin_func_path: str, rewrite_func: Callable):
break
except Exception:
continue
origin_func = eval(origin_func_path)
method_class = False
if len(split_path) > 1:
module_or_class = eval('.'.join(split_path[:-1]))
if isinstance(module_or_class, type):
method_class = True
# Assign function
if not method_class:
_replace_all_obj(
origin_func,
rewrite_func,
ignore_refs=ignore_refs,
ignore_keys=ignore_keys)
exec(f'{origin_func_path} = rewrite_func')
@ -37,12 +86,11 @@ def _del_func(path: str):
for i in range(len(split_path), 0, -1):
try:
exec('import {}'.format('.'.join(split_path[:i])))
exec(f'del {path}')
break
except Exception:
continue
exec(f'del {path}')
class FunctionRewriter:
"""A function rewriter which maintains rewritten functions.
@ -126,9 +174,10 @@ class FunctionRewriter:
if is_addition_function:
self._additional_functions.append(function_path)
else:
# Save origin function
self._origin_functions.append((function_path, origin_func))
# Save origin function
self._origin_functions.append(
dict(func_path=function_path, origin_func=origin_func))
# Create context_caller
rewrite_function = record_dict['_object']
@ -139,15 +188,20 @@ class FunctionRewriter:
**extra_kwargs).get_wrapped_caller()
# Cache new the function to avoid homonymic bug
new_functions.append((function_path, context_caller))
new_functions.append(
dict(func_path=function_path, origin_func=context_caller))
for function_path, new_function in new_functions:
for func_dict in new_functions:
function_path = func_dict['func_path']
new_function = func_dict['origin_func']
# Rewrite functions
_set_func(function_path, new_function)
def exit(self):
"""Recover the function rewrite."""
for func_path, func in self._origin_functions:
for func_dict in self._origin_functions:
func_path = func_dict['func_path']
func = func_dict['origin_func']
_set_func(func_path, func)
for func_path in self._additional_functions:
_del_func(func_path)

View File

@ -89,54 +89,6 @@ def test_rewrite_empty_function():
class TestHomonymicRewriter:
def test_rewrite_homonymic_functions(self):
import package
path1 = 'package.func'
path2 = 'package.module.func'
assert package.func() == 1
assert package.module.func() == 1
function_rewriter = FunctionRewriter()
@function_rewriter.register_rewriter(func_name=path1)
def func_2(ctx):
return 2
@function_rewriter.register_rewriter(
func_name=path2, backend=Backend.NCNN.value)
def func_3(ctx):
return 3
function_rewriter.enter(env=collect_env(Backend.NCNN, ir=IR.DEFAULT))
# This is a feature
assert package.func() == 2
assert package.module.func() == 3
function_rewriter.exit()
assert package.func() == 1
assert package.module.func() == 1
function_rewriter2 = FunctionRewriter()
@function_rewriter2.register_rewriter(
func_name=path1, backend=Backend.NCNN.value)
def func_4(ctx):
return 4
@function_rewriter2.register_rewriter(func_name=path2)
def func_5(ctx):
return 5
function_rewriter2.enter(env=collect_env(Backend.NCNN, ir=IR.DEFAULT))
# This is a feature
assert package.func() == 4
assert package.module.func() == 5
function_rewriter2.exit()
assert package.func() == 1
assert package.module.func() == 1
def test_rewrite_homonymic_methods(self):
import package
path1 = 'package.C.method'