mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[Fix] move target_wrapper into utils (#20)
* move target_wrapper into utils * fix for lint * add typehint and docstring * update unit test * fix isort * update import
This commit is contained in:
parent
d1528e5b34
commit
d157243077
@ -8,7 +8,7 @@ from .config_utils import (cfg_apply_marks, get_backend, get_backend_config,
|
||||
is_dynamic_batch, is_dynamic_shape, load_config)
|
||||
from .constants import Backend, Codebase, Task
|
||||
from .device import parse_cuda_device_id, parse_device_id
|
||||
from .utils import get_root_logger
|
||||
from .utils import get_root_logger, target_wrapper
|
||||
|
||||
__all__ = [
|
||||
'is_dynamic_batch', 'is_dynamic_shape', 'get_task_type', 'get_codebase',
|
||||
@ -17,5 +17,6 @@ __all__ = [
|
||||
'get_calib_config', 'get_calib_filename', 'get_common_config',
|
||||
'get_model_inputs', 'cfg_apply_marks', 'get_input_shape',
|
||||
'parse_device_id', 'parse_cuda_device_id', 'get_codebase_config',
|
||||
'get_backend_config', 'get_root_logger', 'get_dynamic_axes'
|
||||
'get_backend_config', 'get_root_logger', 'get_dynamic_axes',
|
||||
'target_wrapper'
|
||||
]
|
||||
|
@ -1,9 +1,47 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import logging
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch.multiprocessing as mp
|
||||
from mmcv.utils import get_logger
|
||||
|
||||
|
||||
def target_wrapper(target: Callable,
|
||||
log_level: int,
|
||||
ret_value: Optional[mp.Value] = None,
|
||||
*args,
|
||||
**kwargs):
|
||||
"""The wrapper used to start a new subprocess.
|
||||
|
||||
Args:
|
||||
target (Callable): The target function to be wrapped.
|
||||
log_level (int): Log level for logging.
|
||||
ret_value (mp.Value): The success flag of target.
|
||||
|
||||
Return:
|
||||
Any: The return of target.
|
||||
"""
|
||||
logger = logging.getLogger()
|
||||
logging.basicConfig(
|
||||
format='%(asctime)s,%(name)s %(levelname)-8s'
|
||||
' [%(filename)s:%(lineno)d] %(message)s',
|
||||
datefmt='%Y-%m-%d:%H:%M:%S')
|
||||
logger.level
|
||||
logger.setLevel(log_level)
|
||||
if ret_value is not None:
|
||||
ret_value.value = -1
|
||||
try:
|
||||
result = target(*args, **kwargs)
|
||||
if ret_value is not None:
|
||||
ret_value.value = 0
|
||||
return result
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
traceback.print_exc(file=sys.stdout)
|
||||
|
||||
|
||||
def get_root_logger(log_file=None, log_level=logging.INFO) -> logging.Logger:
|
||||
"""Get root logger.
|
||||
|
||||
|
@ -1,11 +1,15 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from functools import partial
|
||||
|
||||
import mmcv
|
||||
import pytest
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
import mmdeploy.utils as util
|
||||
from mmdeploy.utils import target_wrapper
|
||||
from mmdeploy.utils.constants import Backend, Codebase, Task
|
||||
from mmdeploy.utils.export_info import dump_info
|
||||
from mmdeploy.utils.test import get_random_name
|
||||
@ -390,6 +394,24 @@ def test_export_info():
|
||||
assert os.path.exists(deploy_json)
|
||||
|
||||
|
||||
def test_target_wrapper():
|
||||
|
||||
def target():
|
||||
return 0
|
||||
|
||||
log_level = logging.INFO
|
||||
|
||||
ret_value = mp.Value('d', 0, lock=False)
|
||||
ret_value.value = -1
|
||||
wrap_func = partial(target_wrapper, target, log_level, ret_value)
|
||||
|
||||
process = mp.Process(target=wrap_func)
|
||||
process.start()
|
||||
process.join()
|
||||
|
||||
assert ret_value.value == 0
|
||||
|
||||
|
||||
def test_get_root_logger():
|
||||
from mmdeploy.utils import get_root_logger
|
||||
logger = get_root_logger()
|
||||
|
@ -2,8 +2,6 @@
|
||||
import argparse
|
||||
import logging
|
||||
import os.path as osp
|
||||
import sys
|
||||
import traceback
|
||||
from functools import partial
|
||||
|
||||
import mmcv
|
||||
@ -15,7 +13,8 @@ from mmdeploy.apis import (create_calib_table, extract_model,
|
||||
visualize_model)
|
||||
from mmdeploy.utils import (Backend, get_backend, get_calib_filename,
|
||||
get_ir_config, get_model_inputs, get_onnx_config,
|
||||
get_partition_config, get_root_logger, load_config)
|
||||
get_partition_config, get_root_logger, load_config,
|
||||
target_wrapper)
|
||||
from mmdeploy.utils.export_info import dump_info
|
||||
|
||||
|
||||
@ -48,20 +47,6 @@ def parse_args():
|
||||
return args
|
||||
|
||||
|
||||
def target_wrapper(target, log_level, ret_value, *args, **kwargs):
|
||||
logger = get_root_logger(log_level=log_level)
|
||||
if ret_value is not None:
|
||||
ret_value.value = -1
|
||||
try:
|
||||
result = target(*args, **kwargs)
|
||||
if ret_value is not None:
|
||||
ret_value.value = 0
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
traceback.print_exc(file=sys.stdout)
|
||||
|
||||
|
||||
def create_process(name, target, args, kwargs, ret_value=None):
|
||||
logger = get_root_logger()
|
||||
logger.info(f'{name} start.')
|
||||
@ -198,7 +183,7 @@ def main():
|
||||
logger.error('ncnn support is not available.')
|
||||
exit(-1)
|
||||
|
||||
from mmdeploy.apis.ncnn import onnx2ncnn, get_output_model_file
|
||||
from mmdeploy.apis.ncnn import get_output_model_file, onnx2ncnn
|
||||
|
||||
backend_files = []
|
||||
for onnx_path in onnx_files:
|
||||
@ -218,9 +203,9 @@ def main():
|
||||
assert is_available_openvino(), \
|
||||
'OpenVINO is not available, please install OpenVINO first.'
|
||||
|
||||
from mmdeploy.apis.openvino import (onnx2openvino,
|
||||
from mmdeploy.apis.openvino import (get_input_info_from_cfg,
|
||||
get_output_model_file,
|
||||
get_input_info_from_cfg)
|
||||
onnx2openvino)
|
||||
openvino_files = []
|
||||
for onnx_path in onnx_files:
|
||||
model_xml_path = get_output_model_file(onnx_path, args.work_dir)
|
||||
@ -236,8 +221,7 @@ def main():
|
||||
backend_files = openvino_files
|
||||
|
||||
elif backend == Backend.PPLNN:
|
||||
from mmdeploy.apis.pplnn import \
|
||||
is_available as is_available_pplnn
|
||||
from mmdeploy.apis.pplnn import is_available as is_available_pplnn
|
||||
assert is_available_pplnn(), \
|
||||
'PPLNN is not available, please install PPLNN first.'
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user