[Enhancement] Rewriter support pre-import function (#899)

* support preimport

* update rewriter

* fix batched nms ort
This commit is contained in:
q.yao 2022-08-30 12:32:43 +08:00 committed by GitHub
parent 47d4e6f733
commit 13920ec1a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 65 additions and 59 deletions

View File

@ -211,7 +211,7 @@ def multiclass_nms__default(ctx,
pre_top_k=pre_top_k, pre_top_k=pre_top_k,
keep_top_k=keep_top_k) keep_top_k=keep_top_k)
else: else:
return _multiclass_nms( return ctx.origin_func(
boxes, boxes,
scores, scores,
max_output_boxes_per_class=max_output_boxes_per_class, max_output_boxes_per_class=max_output_boxes_per_class,

View File

@ -1,17 +1,54 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, Dict, List, Optional, Union from typing import (Any, Callable, Dict, List, MutableSequence, Optional,
Tuple, Union)
from mmdeploy.utils import IR, Backend, get_root_logger from mmdeploy.utils import IR, Backend, get_root_logger
from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry, from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry,
import_function) import_function)
def _set_func(origin_func_path: str, rewrite_func: Callable): def _replace_all_obj(obj: Any,
new_obj: Any,
ignore_refs: Tuple[Any] = tuple(),
ignore_keys: Tuple[str] = tuple()):
"""Replace all object reference with new_object.
Args:
obj (Any): The object to be replaced.
new_obj (Any): The object to replace obj.
ignore_refs (Tuple[Any]): These refs will be ignored.
ignore_keys (Tuple[str]): object with these keys will be ignored.
"""
import gc
refs = gc.get_referrers(obj)
obj_id = id(obj)
for ref in refs:
if ref in ignore_refs:
continue
elif isinstance(ref, MutableSequence):
for i, v in enumerate(ref):
if id(v) == obj_id:
ref[i] == new_obj
elif isinstance(ref, Dict):
for k, v in ref.items():
if id(v) == obj_id and k not in ignore_keys:
ref[k] = new_obj
else:
# TODO: check if we can replace tuple
pass
def _set_func(origin_func_path: str,
rewrite_func: Callable,
ignore_refs: Tuple[Any] = tuple(),
ignore_keys: Tuple[str] = ('origin_func', )):
"""Rewrite a function by executing a python statement. """Rewrite a function by executing a python statement.
Args: Args:
origin_func_path (str): The path to origin function. origin_func_path (str): The path to origin function.
rewrite_func (Callable): The new function instance. rewrite_func (Callable): The new function instance.
ignore_refs (Tuple[Any]): These refs will be ignored.
ignore_keys (Tuple[str]): object with these keys will be ignored.
""" """
# Import necessary module # Import necessary module
@ -22,7 +59,19 @@ def _set_func(origin_func_path: str, rewrite_func: Callable):
break break
except Exception: except Exception:
continue continue
origin_func = eval(origin_func_path)
method_class = False
if len(split_path) > 1:
module_or_class = eval('.'.join(split_path[:-1]))
if isinstance(module_or_class, type):
method_class = True
# Assign function # Assign function
if not method_class:
_replace_all_obj(
origin_func,
rewrite_func,
ignore_refs=ignore_refs,
ignore_keys=ignore_keys)
exec(f'{origin_func_path} = rewrite_func') exec(f'{origin_func_path} = rewrite_func')
@ -37,12 +86,11 @@ def _del_func(path: str):
for i in range(len(split_path), 0, -1): for i in range(len(split_path), 0, -1):
try: try:
exec('import {}'.format('.'.join(split_path[:i]))) exec('import {}'.format('.'.join(split_path[:i])))
exec(f'del {path}')
break break
except Exception: except Exception:
continue continue
exec(f'del {path}')
class FunctionRewriter: class FunctionRewriter:
"""A function rewriter which maintains rewritten functions. """A function rewriter which maintains rewritten functions.
@ -126,9 +174,10 @@ class FunctionRewriter:
if is_addition_function: if is_addition_function:
self._additional_functions.append(function_path) self._additional_functions.append(function_path)
else:
# Save origin function # Save origin function
self._origin_functions.append((function_path, origin_func)) self._origin_functions.append(
dict(func_path=function_path, origin_func=origin_func))
# Create context_caller # Create context_caller
rewrite_function = record_dict['_object'] rewrite_function = record_dict['_object']
@ -139,15 +188,20 @@ class FunctionRewriter:
**extra_kwargs).get_wrapped_caller() **extra_kwargs).get_wrapped_caller()
# Cache new the function to avoid homonymic bug # Cache new the function to avoid homonymic bug
new_functions.append((function_path, context_caller)) new_functions.append(
dict(func_path=function_path, origin_func=context_caller))
for function_path, new_function in new_functions: for func_dict in new_functions:
function_path = func_dict['func_path']
new_function = func_dict['origin_func']
# Rewrite functions # Rewrite functions
_set_func(function_path, new_function) _set_func(function_path, new_function)
def exit(self): def exit(self):
"""Recover the function rewrite.""" """Recover the function rewrite."""
for func_path, func in self._origin_functions: for func_dict in self._origin_functions:
func_path = func_dict['func_path']
func = func_dict['origin_func']
_set_func(func_path, func) _set_func(func_path, func)
for func_path in self._additional_functions: for func_path in self._additional_functions:
_del_func(func_path) _del_func(func_path)

View File

@ -89,54 +89,6 @@ def test_rewrite_empty_function():
class TestHomonymicRewriter: class TestHomonymicRewriter:
def test_rewrite_homonymic_functions(self):
import package
path1 = 'package.func'
path2 = 'package.module.func'
assert package.func() == 1
assert package.module.func() == 1
function_rewriter = FunctionRewriter()
@function_rewriter.register_rewriter(func_name=path1)
def func_2(ctx):
return 2
@function_rewriter.register_rewriter(
func_name=path2, backend=Backend.NCNN.value)
def func_3(ctx):
return 3
function_rewriter.enter(env=collect_env(Backend.NCNN, ir=IR.DEFAULT))
# This is a feature
assert package.func() == 2
assert package.module.func() == 3
function_rewriter.exit()
assert package.func() == 1
assert package.module.func() == 1
function_rewriter2 = FunctionRewriter()
@function_rewriter2.register_rewriter(
func_name=path1, backend=Backend.NCNN.value)
def func_4(ctx):
return 4
@function_rewriter2.register_rewriter(func_name=path2)
def func_5(ctx):
return 5
function_rewriter2.enter(env=collect_env(Backend.NCNN, ir=IR.DEFAULT))
# This is a feature
assert package.func() == 4
assert package.module.func() == 5
function_rewriter2.exit()
assert package.func() == 1
assert package.module.func() == 1
def test_rewrite_homonymic_methods(self): def test_rewrite_homonymic_methods(self):
import package import package
path1 = 'package.C.method' path1 = 'package.C.method'