mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[Fix] move target_wrapper into utils (#20)
* move target_wrapper into utils * fix for lint * add typehint and docstring * update unit test * fix isort * update import
This commit is contained in:
parent
d1528e5b34
commit
d157243077
@ -8,7 +8,7 @@ from .config_utils import (cfg_apply_marks, get_backend, get_backend_config,
|
|||||||
is_dynamic_batch, is_dynamic_shape, load_config)
|
is_dynamic_batch, is_dynamic_shape, load_config)
|
||||||
from .constants import Backend, Codebase, Task
|
from .constants import Backend, Codebase, Task
|
||||||
from .device import parse_cuda_device_id, parse_device_id
|
from .device import parse_cuda_device_id, parse_device_id
|
||||||
from .utils import get_root_logger
|
from .utils import get_root_logger, target_wrapper
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'is_dynamic_batch', 'is_dynamic_shape', 'get_task_type', 'get_codebase',
|
'is_dynamic_batch', 'is_dynamic_shape', 'get_task_type', 'get_codebase',
|
||||||
@ -17,5 +17,6 @@ __all__ = [
|
|||||||
'get_calib_config', 'get_calib_filename', 'get_common_config',
|
'get_calib_config', 'get_calib_filename', 'get_common_config',
|
||||||
'get_model_inputs', 'cfg_apply_marks', 'get_input_shape',
|
'get_model_inputs', 'cfg_apply_marks', 'get_input_shape',
|
||||||
'parse_device_id', 'parse_cuda_device_id', 'get_codebase_config',
|
'parse_device_id', 'parse_cuda_device_id', 'get_codebase_config',
|
||||||
'get_backend_config', 'get_root_logger', 'get_dynamic_axes'
|
'get_backend_config', 'get_root_logger', 'get_dynamic_axes',
|
||||||
|
'target_wrapper'
|
||||||
]
|
]
|
||||||
|
@ -1,9 +1,47 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import logging
|
import logging
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
import torch.multiprocessing as mp
|
||||||
from mmcv.utils import get_logger
|
from mmcv.utils import get_logger
|
||||||
|
|
||||||
|
|
||||||
|
def target_wrapper(target: Callable,
|
||||||
|
log_level: int,
|
||||||
|
ret_value: Optional[mp.Value] = None,
|
||||||
|
*args,
|
||||||
|
**kwargs):
|
||||||
|
"""The wrapper used to start a new subprocess.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target (Callable): The target function to be wrapped.
|
||||||
|
log_level (int): Log level for logging.
|
||||||
|
ret_value (mp.Value): The success flag of target.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
Any: The return of target.
|
||||||
|
"""
|
||||||
|
logger = logging.getLogger()
|
||||||
|
logging.basicConfig(
|
||||||
|
format='%(asctime)s,%(name)s %(levelname)-8s'
|
||||||
|
' [%(filename)s:%(lineno)d] %(message)s',
|
||||||
|
datefmt='%Y-%m-%d:%H:%M:%S')
|
||||||
|
logger.level
|
||||||
|
logger.setLevel(log_level)
|
||||||
|
if ret_value is not None:
|
||||||
|
ret_value.value = -1
|
||||||
|
try:
|
||||||
|
result = target(*args, **kwargs)
|
||||||
|
if ret_value is not None:
|
||||||
|
ret_value.value = 0
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(e)
|
||||||
|
traceback.print_exc(file=sys.stdout)
|
||||||
|
|
||||||
|
|
||||||
def get_root_logger(log_file=None, log_level=logging.INFO) -> logging.Logger:
|
def get_root_logger(log_file=None, log_level=logging.INFO) -> logging.Logger:
|
||||||
"""Get root logger.
|
"""Get root logger.
|
||||||
|
|
||||||
|
@ -1,11 +1,15 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
import mmcv
|
import mmcv
|
||||||
import pytest
|
import pytest
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
import mmdeploy.utils as util
|
import mmdeploy.utils as util
|
||||||
|
from mmdeploy.utils import target_wrapper
|
||||||
from mmdeploy.utils.constants import Backend, Codebase, Task
|
from mmdeploy.utils.constants import Backend, Codebase, Task
|
||||||
from mmdeploy.utils.export_info import dump_info
|
from mmdeploy.utils.export_info import dump_info
|
||||||
from mmdeploy.utils.test import get_random_name
|
from mmdeploy.utils.test import get_random_name
|
||||||
@ -390,6 +394,24 @@ def test_export_info():
|
|||||||
assert os.path.exists(deploy_json)
|
assert os.path.exists(deploy_json)
|
||||||
|
|
||||||
|
|
||||||
|
def test_target_wrapper():
|
||||||
|
|
||||||
|
def target():
|
||||||
|
return 0
|
||||||
|
|
||||||
|
log_level = logging.INFO
|
||||||
|
|
||||||
|
ret_value = mp.Value('d', 0, lock=False)
|
||||||
|
ret_value.value = -1
|
||||||
|
wrap_func = partial(target_wrapper, target, log_level, ret_value)
|
||||||
|
|
||||||
|
process = mp.Process(target=wrap_func)
|
||||||
|
process.start()
|
||||||
|
process.join()
|
||||||
|
|
||||||
|
assert ret_value.value == 0
|
||||||
|
|
||||||
|
|
||||||
def test_get_root_logger():
|
def test_get_root_logger():
|
||||||
from mmdeploy.utils import get_root_logger
|
from mmdeploy.utils import get_root_logger
|
||||||
logger = get_root_logger()
|
logger = get_root_logger()
|
||||||
|
@ -2,8 +2,6 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
import mmcv
|
import mmcv
|
||||||
@ -15,7 +13,8 @@ from mmdeploy.apis import (create_calib_table, extract_model,
|
|||||||
visualize_model)
|
visualize_model)
|
||||||
from mmdeploy.utils import (Backend, get_backend, get_calib_filename,
|
from mmdeploy.utils import (Backend, get_backend, get_calib_filename,
|
||||||
get_ir_config, get_model_inputs, get_onnx_config,
|
get_ir_config, get_model_inputs, get_onnx_config,
|
||||||
get_partition_config, get_root_logger, load_config)
|
get_partition_config, get_root_logger, load_config,
|
||||||
|
target_wrapper)
|
||||||
from mmdeploy.utils.export_info import dump_info
|
from mmdeploy.utils.export_info import dump_info
|
||||||
|
|
||||||
|
|
||||||
@ -48,20 +47,6 @@ def parse_args():
|
|||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
def target_wrapper(target, log_level, ret_value, *args, **kwargs):
|
|
||||||
logger = get_root_logger(log_level=log_level)
|
|
||||||
if ret_value is not None:
|
|
||||||
ret_value.value = -1
|
|
||||||
try:
|
|
||||||
result = target(*args, **kwargs)
|
|
||||||
if ret_value is not None:
|
|
||||||
ret_value.value = 0
|
|
||||||
return result
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(e)
|
|
||||||
traceback.print_exc(file=sys.stdout)
|
|
||||||
|
|
||||||
|
|
||||||
def create_process(name, target, args, kwargs, ret_value=None):
|
def create_process(name, target, args, kwargs, ret_value=None):
|
||||||
logger = get_root_logger()
|
logger = get_root_logger()
|
||||||
logger.info(f'{name} start.')
|
logger.info(f'{name} start.')
|
||||||
@ -198,7 +183,7 @@ def main():
|
|||||||
logger.error('ncnn support is not available.')
|
logger.error('ncnn support is not available.')
|
||||||
exit(-1)
|
exit(-1)
|
||||||
|
|
||||||
from mmdeploy.apis.ncnn import onnx2ncnn, get_output_model_file
|
from mmdeploy.apis.ncnn import get_output_model_file, onnx2ncnn
|
||||||
|
|
||||||
backend_files = []
|
backend_files = []
|
||||||
for onnx_path in onnx_files:
|
for onnx_path in onnx_files:
|
||||||
@ -218,9 +203,9 @@ def main():
|
|||||||
assert is_available_openvino(), \
|
assert is_available_openvino(), \
|
||||||
'OpenVINO is not available, please install OpenVINO first.'
|
'OpenVINO is not available, please install OpenVINO first.'
|
||||||
|
|
||||||
from mmdeploy.apis.openvino import (onnx2openvino,
|
from mmdeploy.apis.openvino import (get_input_info_from_cfg,
|
||||||
get_output_model_file,
|
get_output_model_file,
|
||||||
get_input_info_from_cfg)
|
onnx2openvino)
|
||||||
openvino_files = []
|
openvino_files = []
|
||||||
for onnx_path in onnx_files:
|
for onnx_path in onnx_files:
|
||||||
model_xml_path = get_output_model_file(onnx_path, args.work_dir)
|
model_xml_path = get_output_model_file(onnx_path, args.work_dir)
|
||||||
@ -236,8 +221,7 @@ def main():
|
|||||||
backend_files = openvino_files
|
backend_files = openvino_files
|
||||||
|
|
||||||
elif backend == Backend.PPLNN:
|
elif backend == Backend.PPLNN:
|
||||||
from mmdeploy.apis.pplnn import \
|
from mmdeploy.apis.pplnn import is_available as is_available_pplnn
|
||||||
is_available as is_available_pplnn
|
|
||||||
assert is_available_pplnn(), \
|
assert is_available_pplnn(), \
|
||||||
'PPLNN is not available, please install PPLNN first.'
|
'PPLNN is not available, please install PPLNN first.'
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user