From d157243077310dcea5e78700c017408ed778838f Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Tue, 11 Jan 2022 15:43:47 +0800 Subject: [PATCH] [Fix] move target_wrapper into utils (#20) * move target_wrapper into utils * fix for lint * add typehint and docstring * update unit test * fix isort * update import --- mmdeploy/utils/__init__.py | 5 +++-- mmdeploy/utils/utils.py | 38 +++++++++++++++++++++++++++++++++++ tests/test_utils/test_util.py | 22 ++++++++++++++++++++ tools/deploy.py | 28 ++++++-------------------- 4 files changed, 69 insertions(+), 24 deletions(-) diff --git a/mmdeploy/utils/__init__.py b/mmdeploy/utils/__init__.py index 625e3bfd1..57a4bc222 100644 --- a/mmdeploy/utils/__init__.py +++ b/mmdeploy/utils/__init__.py @@ -8,7 +8,7 @@ from .config_utils import (cfg_apply_marks, get_backend, get_backend_config, is_dynamic_batch, is_dynamic_shape, load_config) from .constants import Backend, Codebase, Task from .device import parse_cuda_device_id, parse_device_id -from .utils import get_root_logger +from .utils import get_root_logger, target_wrapper __all__ = [ 'is_dynamic_batch', 'is_dynamic_shape', 'get_task_type', 'get_codebase', @@ -17,5 +17,6 @@ __all__ = [ 'get_calib_config', 'get_calib_filename', 'get_common_config', 'get_model_inputs', 'cfg_apply_marks', 'get_input_shape', 'parse_device_id', 'parse_cuda_device_id', 'get_codebase_config', - 'get_backend_config', 'get_root_logger', 'get_dynamic_axes' + 'get_backend_config', 'get_root_logger', 'get_dynamic_axes', + 'target_wrapper' ] diff --git a/mmdeploy/utils/utils.py b/mmdeploy/utils/utils.py index ae8182a46..9917dd477 100644 --- a/mmdeploy/utils/utils.py +++ b/mmdeploy/utils/utils.py @@ -1,9 +1,47 @@ # Copyright (c) OpenMMLab. All rights reserved. import logging +import sys +import traceback +from typing import Callable, Optional +import torch.multiprocessing as mp from mmcv.utils import get_logger +def target_wrapper(target: Callable, + log_level: int, + ret_value: Optional[mp.Value] = None, + *args, + **kwargs): + """The wrapper used to start a new subprocess. + + Args: + target (Callable): The target function to be wrapped. + log_level (int): Log level for logging. + ret_value (mp.Value): The success flag of target. + + Return: + Any: The return of target. + """ + logger = logging.getLogger() + logging.basicConfig( + format='%(asctime)s,%(name)s %(levelname)-8s' + ' [%(filename)s:%(lineno)d] %(message)s', + datefmt='%Y-%m-%d:%H:%M:%S') + logger.level + logger.setLevel(log_level) + if ret_value is not None: + ret_value.value = -1 + try: + result = target(*args, **kwargs) + if ret_value is not None: + ret_value.value = 0 + return result + except Exception as e: + logging.error(e) + traceback.print_exc(file=sys.stdout) + + def get_root_logger(log_file=None, log_level=logging.INFO) -> logging.Logger: """Get root logger. diff --git a/tests/test_utils/test_util.py b/tests/test_utils/test_util.py index 965908edb..5eec87b6e 100644 --- a/tests/test_utils/test_util.py +++ b/tests/test_utils/test_util.py @@ -1,11 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. +import logging import os import tempfile +from functools import partial import mmcv import pytest +import torch.multiprocessing as mp import mmdeploy.utils as util +from mmdeploy.utils import target_wrapper from mmdeploy.utils.constants import Backend, Codebase, Task from mmdeploy.utils.export_info import dump_info from mmdeploy.utils.test import get_random_name @@ -390,6 +394,24 @@ def test_export_info(): assert os.path.exists(deploy_json) +def test_target_wrapper(): + + def target(): + return 0 + + log_level = logging.INFO + + ret_value = mp.Value('d', 0, lock=False) + ret_value.value = -1 + wrap_func = partial(target_wrapper, target, log_level, ret_value) + + process = mp.Process(target=wrap_func) + process.start() + process.join() + + assert ret_value.value == 0 + + def test_get_root_logger(): from mmdeploy.utils import get_root_logger logger = get_root_logger() diff --git a/tools/deploy.py b/tools/deploy.py index eec13f6c9..9d06d2468 100644 --- a/tools/deploy.py +++ b/tools/deploy.py @@ -2,8 +2,6 @@ import argparse import logging import os.path as osp -import sys -import traceback from functools import partial import mmcv @@ -15,7 +13,8 @@ from mmdeploy.apis import (create_calib_table, extract_model, visualize_model) from mmdeploy.utils import (Backend, get_backend, get_calib_filename, get_ir_config, get_model_inputs, get_onnx_config, - get_partition_config, get_root_logger, load_config) + get_partition_config, get_root_logger, load_config, + target_wrapper) from mmdeploy.utils.export_info import dump_info @@ -48,20 +47,6 @@ def parse_args(): return args -def target_wrapper(target, log_level, ret_value, *args, **kwargs): - logger = get_root_logger(log_level=log_level) - if ret_value is not None: - ret_value.value = -1 - try: - result = target(*args, **kwargs) - if ret_value is not None: - ret_value.value = 0 - return result - except Exception as e: - logger.error(e) - traceback.print_exc(file=sys.stdout) - - def create_process(name, target, args, kwargs, ret_value=None): logger = get_root_logger() logger.info(f'{name} start.') @@ -198,7 +183,7 @@ def main(): logger.error('ncnn support is not available.') exit(-1) - from mmdeploy.apis.ncnn import onnx2ncnn, get_output_model_file + from mmdeploy.apis.ncnn import get_output_model_file, onnx2ncnn backend_files = [] for onnx_path in onnx_files: @@ -218,9 +203,9 @@ def main(): assert is_available_openvino(), \ 'OpenVINO is not available, please install OpenVINO first.' - from mmdeploy.apis.openvino import (onnx2openvino, + from mmdeploy.apis.openvino import (get_input_info_from_cfg, get_output_model_file, - get_input_info_from_cfg) + onnx2openvino) openvino_files = [] for onnx_path in onnx_files: model_xml_path = get_output_model_file(onnx_path, args.work_dir) @@ -236,8 +221,7 @@ def main(): backend_files = openvino_files elif backend == Backend.PPLNN: - from mmdeploy.apis.pplnn import \ - is_available as is_available_pplnn + from mmdeploy.apis.pplnn import is_available as is_available_pplnn assert is_available_pplnn(), \ 'PPLNN is not available, please install PPLNN first.'