[Fix] move target_wrapper into utils (#20)

* move target_wrapper into utils

* fix for lint

* add typehint and docstring

* update unit test

* fix isort

* update import
This commit is contained in:
q.yao 2022-01-11 15:43:47 +08:00 committed by GitHub
parent d1528e5b34
commit d157243077
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 69 additions and 24 deletions

View File

@ -8,7 +8,7 @@ from .config_utils import (cfg_apply_marks, get_backend, get_backend_config,
is_dynamic_batch, is_dynamic_shape, load_config)
from .constants import Backend, Codebase, Task
from .device import parse_cuda_device_id, parse_device_id
from .utils import get_root_logger
from .utils import get_root_logger, target_wrapper
__all__ = [
'is_dynamic_batch', 'is_dynamic_shape', 'get_task_type', 'get_codebase',
@ -17,5 +17,6 @@ __all__ = [
'get_calib_config', 'get_calib_filename', 'get_common_config',
'get_model_inputs', 'cfg_apply_marks', 'get_input_shape',
'parse_device_id', 'parse_cuda_device_id', 'get_codebase_config',
'get_backend_config', 'get_root_logger', 'get_dynamic_axes'
'get_backend_config', 'get_root_logger', 'get_dynamic_axes',
'target_wrapper'
]

View File

@ -1,9 +1,47 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import sys
import traceback
from typing import Callable, Optional
import torch.multiprocessing as mp
from mmcv.utils import get_logger
def target_wrapper(target: Callable,
log_level: int,
ret_value: Optional[mp.Value] = None,
*args,
**kwargs):
"""The wrapper used to start a new subprocess.
Args:
target (Callable): The target function to be wrapped.
log_level (int): Log level for logging.
ret_value (mp.Value): The success flag of target.
Return:
Any: The return of target.
"""
logger = logging.getLogger()
logging.basicConfig(
format='%(asctime)s,%(name)s %(levelname)-8s'
' [%(filename)s:%(lineno)d] %(message)s',
datefmt='%Y-%m-%d:%H:%M:%S')
logger.level
logger.setLevel(log_level)
if ret_value is not None:
ret_value.value = -1
try:
result = target(*args, **kwargs)
if ret_value is not None:
ret_value.value = 0
return result
except Exception as e:
logging.error(e)
traceback.print_exc(file=sys.stdout)
def get_root_logger(log_file=None, log_level=logging.INFO) -> logging.Logger:
"""Get root logger.

View File

@ -1,11 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os
import tempfile
from functools import partial
import mmcv
import pytest
import torch.multiprocessing as mp
import mmdeploy.utils as util
from mmdeploy.utils import target_wrapper
from mmdeploy.utils.constants import Backend, Codebase, Task
from mmdeploy.utils.export_info import dump_info
from mmdeploy.utils.test import get_random_name
@ -390,6 +394,24 @@ def test_export_info():
assert os.path.exists(deploy_json)
def test_target_wrapper():
def target():
return 0
log_level = logging.INFO
ret_value = mp.Value('d', 0, lock=False)
ret_value.value = -1
wrap_func = partial(target_wrapper, target, log_level, ret_value)
process = mp.Process(target=wrap_func)
process.start()
process.join()
assert ret_value.value == 0
def test_get_root_logger():
from mmdeploy.utils import get_root_logger
logger = get_root_logger()

View File

@ -2,8 +2,6 @@
import argparse
import logging
import os.path as osp
import sys
import traceback
from functools import partial
import mmcv
@ -15,7 +13,8 @@ from mmdeploy.apis import (create_calib_table, extract_model,
visualize_model)
from mmdeploy.utils import (Backend, get_backend, get_calib_filename,
get_ir_config, get_model_inputs, get_onnx_config,
get_partition_config, get_root_logger, load_config)
get_partition_config, get_root_logger, load_config,
target_wrapper)
from mmdeploy.utils.export_info import dump_info
@ -48,20 +47,6 @@ def parse_args():
return args
def target_wrapper(target, log_level, ret_value, *args, **kwargs):
logger = get_root_logger(log_level=log_level)
if ret_value is not None:
ret_value.value = -1
try:
result = target(*args, **kwargs)
if ret_value is not None:
ret_value.value = 0
return result
except Exception as e:
logger.error(e)
traceback.print_exc(file=sys.stdout)
def create_process(name, target, args, kwargs, ret_value=None):
logger = get_root_logger()
logger.info(f'{name} start.')
@ -198,7 +183,7 @@ def main():
logger.error('ncnn support is not available.')
exit(-1)
from mmdeploy.apis.ncnn import onnx2ncnn, get_output_model_file
from mmdeploy.apis.ncnn import get_output_model_file, onnx2ncnn
backend_files = []
for onnx_path in onnx_files:
@ -218,9 +203,9 @@ def main():
assert is_available_openvino(), \
'OpenVINO is not available, please install OpenVINO first.'
from mmdeploy.apis.openvino import (onnx2openvino,
from mmdeploy.apis.openvino import (get_input_info_from_cfg,
get_output_model_file,
get_input_info_from_cfg)
onnx2openvino)
openvino_files = []
for onnx_path in onnx_files:
model_xml_path = get_output_model_file(onnx_path, args.work_dir)
@ -236,8 +221,7 @@ def main():
backend_files = openvino_files
elif backend == Backend.PPLNN:
from mmdeploy.apis.pplnn import \
is_available as is_available_pplnn
from mmdeploy.apis.pplnn import is_available as is_available_pplnn
assert is_available_pplnn(), \
'PPLNN is not available, please install PPLNN first.'