[Fix] move target_wrapper into utils (#20)

* move target_wrapper into utils

* fix for lint

* add typehint and docstring

* update unit test

* fix isort

* update import
This commit is contained in:
q.yao 2022-01-11 15:43:47 +08:00 committed by GitHub
parent d1528e5b34
commit d157243077
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 69 additions and 24 deletions

View File

@ -8,7 +8,7 @@ from .config_utils import (cfg_apply_marks, get_backend, get_backend_config,
is_dynamic_batch, is_dynamic_shape, load_config) is_dynamic_batch, is_dynamic_shape, load_config)
from .constants import Backend, Codebase, Task from .constants import Backend, Codebase, Task
from .device import parse_cuda_device_id, parse_device_id from .device import parse_cuda_device_id, parse_device_id
from .utils import get_root_logger from .utils import get_root_logger, target_wrapper
__all__ = [ __all__ = [
'is_dynamic_batch', 'is_dynamic_shape', 'get_task_type', 'get_codebase', 'is_dynamic_batch', 'is_dynamic_shape', 'get_task_type', 'get_codebase',
@ -17,5 +17,6 @@ __all__ = [
'get_calib_config', 'get_calib_filename', 'get_common_config', 'get_calib_config', 'get_calib_filename', 'get_common_config',
'get_model_inputs', 'cfg_apply_marks', 'get_input_shape', 'get_model_inputs', 'cfg_apply_marks', 'get_input_shape',
'parse_device_id', 'parse_cuda_device_id', 'get_codebase_config', 'parse_device_id', 'parse_cuda_device_id', 'get_codebase_config',
'get_backend_config', 'get_root_logger', 'get_dynamic_axes' 'get_backend_config', 'get_root_logger', 'get_dynamic_axes',
'target_wrapper'
] ]

View File

@ -1,9 +1,47 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import logging import logging
import sys
import traceback
from typing import Callable, Optional
import torch.multiprocessing as mp
from mmcv.utils import get_logger from mmcv.utils import get_logger
def target_wrapper(target: Callable,
log_level: int,
ret_value: Optional[mp.Value] = None,
*args,
**kwargs):
"""The wrapper used to start a new subprocess.
Args:
target (Callable): The target function to be wrapped.
log_level (int): Log level for logging.
ret_value (mp.Value): The success flag of target.
Return:
Any: The return of target.
"""
logger = logging.getLogger()
logging.basicConfig(
format='%(asctime)s,%(name)s %(levelname)-8s'
' [%(filename)s:%(lineno)d] %(message)s',
datefmt='%Y-%m-%d:%H:%M:%S')
logger.level
logger.setLevel(log_level)
if ret_value is not None:
ret_value.value = -1
try:
result = target(*args, **kwargs)
if ret_value is not None:
ret_value.value = 0
return result
except Exception as e:
logging.error(e)
traceback.print_exc(file=sys.stdout)
def get_root_logger(log_file=None, log_level=logging.INFO) -> logging.Logger: def get_root_logger(log_file=None, log_level=logging.INFO) -> logging.Logger:
"""Get root logger. """Get root logger.

View File

@ -1,11 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import logging
import os import os
import tempfile import tempfile
from functools import partial
import mmcv import mmcv
import pytest import pytest
import torch.multiprocessing as mp
import mmdeploy.utils as util import mmdeploy.utils as util
from mmdeploy.utils import target_wrapper
from mmdeploy.utils.constants import Backend, Codebase, Task from mmdeploy.utils.constants import Backend, Codebase, Task
from mmdeploy.utils.export_info import dump_info from mmdeploy.utils.export_info import dump_info
from mmdeploy.utils.test import get_random_name from mmdeploy.utils.test import get_random_name
@ -390,6 +394,24 @@ def test_export_info():
assert os.path.exists(deploy_json) assert os.path.exists(deploy_json)
def test_target_wrapper():
def target():
return 0
log_level = logging.INFO
ret_value = mp.Value('d', 0, lock=False)
ret_value.value = -1
wrap_func = partial(target_wrapper, target, log_level, ret_value)
process = mp.Process(target=wrap_func)
process.start()
process.join()
assert ret_value.value == 0
def test_get_root_logger(): def test_get_root_logger():
from mmdeploy.utils import get_root_logger from mmdeploy.utils import get_root_logger
logger = get_root_logger() logger = get_root_logger()

View File

@ -2,8 +2,6 @@
import argparse import argparse
import logging import logging
import os.path as osp import os.path as osp
import sys
import traceback
from functools import partial from functools import partial
import mmcv import mmcv
@ -15,7 +13,8 @@ from mmdeploy.apis import (create_calib_table, extract_model,
visualize_model) visualize_model)
from mmdeploy.utils import (Backend, get_backend, get_calib_filename, from mmdeploy.utils import (Backend, get_backend, get_calib_filename,
get_ir_config, get_model_inputs, get_onnx_config, get_ir_config, get_model_inputs, get_onnx_config,
get_partition_config, get_root_logger, load_config) get_partition_config, get_root_logger, load_config,
target_wrapper)
from mmdeploy.utils.export_info import dump_info from mmdeploy.utils.export_info import dump_info
@ -48,20 +47,6 @@ def parse_args():
return args return args
def target_wrapper(target, log_level, ret_value, *args, **kwargs):
logger = get_root_logger(log_level=log_level)
if ret_value is not None:
ret_value.value = -1
try:
result = target(*args, **kwargs)
if ret_value is not None:
ret_value.value = 0
return result
except Exception as e:
logger.error(e)
traceback.print_exc(file=sys.stdout)
def create_process(name, target, args, kwargs, ret_value=None): def create_process(name, target, args, kwargs, ret_value=None):
logger = get_root_logger() logger = get_root_logger()
logger.info(f'{name} start.') logger.info(f'{name} start.')
@ -198,7 +183,7 @@ def main():
logger.error('ncnn support is not available.') logger.error('ncnn support is not available.')
exit(-1) exit(-1)
from mmdeploy.apis.ncnn import onnx2ncnn, get_output_model_file from mmdeploy.apis.ncnn import get_output_model_file, onnx2ncnn
backend_files = [] backend_files = []
for onnx_path in onnx_files: for onnx_path in onnx_files:
@ -218,9 +203,9 @@ def main():
assert is_available_openvino(), \ assert is_available_openvino(), \
'OpenVINO is not available, please install OpenVINO first.' 'OpenVINO is not available, please install OpenVINO first.'
from mmdeploy.apis.openvino import (onnx2openvino, from mmdeploy.apis.openvino import (get_input_info_from_cfg,
get_output_model_file, get_output_model_file,
get_input_info_from_cfg) onnx2openvino)
openvino_files = [] openvino_files = []
for onnx_path in onnx_files: for onnx_path in onnx_files:
model_xml_path = get_output_model_file(onnx_path, args.work_dir) model_xml_path = get_output_model_file(onnx_path, args.work_dir)
@ -236,8 +221,7 @@ def main():
backend_files = openvino_files backend_files = openvino_files
elif backend == Backend.PPLNN: elif backend == Backend.PPLNN:
from mmdeploy.apis.pplnn import \ from mmdeploy.apis.pplnn import is_available as is_available_pplnn
is_available as is_available_pplnn
assert is_available_pplnn(), \ assert is_available_pplnn(), \
'PPLNN is not available, please install PPLNN first.' 'PPLNN is not available, please install PPLNN first.'