add multi-processes script (#1238)

This commit is contained in:
MengzhangLI 2022-01-27 21:18:55 +08:00 committed by GitHub
parent ee47c41740
commit 2729a6d4d1
5 changed files with 155 additions and 2 deletions

View File

@ -2,5 +2,9 @@
from .collect_env import collect_env from .collect_env import collect_env
from .logger import get_root_logger from .logger import get_root_logger
from .misc import find_latest_checkpoint from .misc import find_latest_checkpoint
from .set_env import setup_multi_processes
__all__ = ['get_root_logger', 'collect_env', 'find_latest_checkpoint'] __all__ = [
'get_root_logger', 'collect_env', 'find_latest_checkpoint',
'setup_multi_processes'
]

55
mmseg/utils/set_env.py Normal file
View File

@ -0,0 +1,55 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import platform
import cv2
import torch.multiprocessing as mp
from ..utils import get_root_logger
def setup_multi_processes(cfg):
    """Apply multi-processing related runtime settings read from ``cfg``.

    The following optional top-level config keys are read with ``cfg.get``:

    - ``mp_start_method``: one of ``'fork'``, ``'spawn'`` or
      ``'forkserver'``. Ignored on Windows; any other value leaves the
      current start method untouched.
    - ``opencv_num_threads``: int passed to ``cv2.setNumThreads``.
    - ``omp_num_threads`` / ``mkl_num_threads``: ints exported as the
      ``OMP_NUM_THREADS`` / ``MKL_NUM_THREADS`` environment variables, only
      when ``cfg.data.workers_per_gpu > 1`` and the variable is not already
      present in the environment.

    Values of the wrong type are ignored and the current setting is logged.

    Args:
        cfg: Config object providing ``get(key, default)`` and a
            ``data.workers_per_gpu`` attribute.
    """
    logger = get_root_logger()

    # set multi-process start method
    if platform.system() != 'Windows':
        mp_start_method = cfg.get('mp_start_method', None)
        current_method = mp.get_start_method(allow_none=True)
        if mp_start_method in ('fork', 'spawn', 'forkserver'):
            if mp_start_method != current_method:
                logger.info(
                    f'Multi-processing start method `{mp_start_method}` is '
                    f'different from the previous setting `{current_method}`. '
                    f'It will be force set to `{mp_start_method}`.')
            else:
                logger.info(
                    f'Multi-processing start method is `{mp_start_method}`')
            mp.set_start_method(mp_start_method, force=True)
        else:
            # Missing or unsupported value: nothing is changed, so report
            # the method that is actually in effect rather than the
            # rejected config value.
            logger.info(
                f'Multi-processing start method is `{current_method}`')

    # disable opencv multithreading to avoid system being overloaded
    opencv_num_threads = cfg.get('opencv_num_threads', None)
    if isinstance(opencv_num_threads, int):
        logger.info(f'OpenCV num_threads is `{opencv_num_threads}`')
        cv2.setNumThreads(opencv_num_threads)
    else:
        # Call getNumThreads so the thread count is logged, not the
        # function object.
        logger.info(f'OpenCV num_threads is `{cv2.getNumThreads()}`')

    if cfg.data.workers_per_gpu > 1:
        # setup OMP and MKL threads
        # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py  # noqa
        thread_settings = (
            ('OMP_NUM_THREADS', 'omp_num_threads', 'OMP'),
            ('MKL_NUM_THREADS', 'mkl_num_threads', 'MKL'),
        )
        for env_key, cfg_key, label in thread_settings:
            if env_key in os.environ:
                # A value already exported by the user takes precedence.
                logger.info(f'{label} num threads is {os.environ[env_key]}')
                continue
            num_threads = cfg.get(cfg_key, None)
            if isinstance(num_threads, int):
                logger.info(f'{label} num threads is {num_threads}')
                os.environ[env_key] = str(num_threads)

View File

@ -0,0 +1,85 @@
# Copyright (c) OpenMMLab. All rights reserved.
import multiprocessing as mp
import os
import platform
import cv2
import pytest
from mmcv import Config
from mmseg.utils import setup_multi_processes
@pytest.mark.parametrize('workers_per_gpu', (0, 2))
@pytest.mark.parametrize(('valid', 'env_cfg'), [
    (True,
     dict(
         mp_start_method='fork',
         opencv_num_threads=0,
         omp_num_threads=1,
         mkl_num_threads=1)),
    (False,
     dict(
         mp_start_method=1,
         opencv_num_threads=0.1,
         omp_num_threads='s',
         mkl_num_threads='1')),
])
def test_setup_multi_processes(workers_per_gpu, valid, env_cfg):
    """Check ``setup_multi_processes`` for valid and invalid env configs.

    Process-wide state (multiprocessing start method, OpenCV thread count,
    ``OMP_NUM_THREADS`` / ``MKL_NUM_THREADS``) is saved before the call and
    always restored afterwards, even when an assertion fails.
    """
    # temp save system setting
    sys_start_method = mp.get_start_method(allow_none=True)
    sys_cv_threads = cv2.getNumThreads()
    # pop and temp save system env vars
    saved_env = {
        key: os.environ.pop(key, None)
        for key in ('OMP_NUM_THREADS', 'MKL_NUM_THREADS')
    }

    config = dict(data=dict(workers_per_gpu=workers_per_gpu))
    config.update(env_cfg)
    cfg = Config(config)

    try:
        setup_multi_processes(cfg)

        if valid:
            # when set to 0, the num threads will be 1
            # The expected value is computed first: writing
            # ``assert a == b if cond else 1`` would parse as
            # ``assert ((a == b) if cond else 1)`` and always pass for 0.
            requested_cv_threads = env_cfg['opencv_num_threads']
            expected_cv_threads = (
                requested_cv_threads if requested_cv_threads > 0 else 1)
            assert cv2.getNumThreads() == expected_cv_threads
            if platform.system() != 'Windows':
                assert mp.get_start_method() == env_cfg['mp_start_method']

            if workers_per_gpu > 1:
                # env vars are exported only with more than one worker
                assert os.getenv('OMP_NUM_THREADS') == str(
                    env_cfg['omp_num_threads'])
                assert os.getenv('MKL_NUM_THREADS') == str(
                    env_cfg['mkl_num_threads'])
            else:
                assert 'OMP_NUM_THREADS' not in os.environ
                assert 'MKL_NUM_THREADS' not in os.environ
        else:
            # invalid values must leave every setting untouched;
            # allow_none avoids fixing the default context as a side effect
            assert mp.get_start_method(allow_none=True) == sys_start_method
            assert cv2.getNumThreads() == sys_cv_threads
            assert 'OMP_NUM_THREADS' not in os.environ
            assert 'MKL_NUM_THREADS' not in os.environ
    finally:
        # revert setting to avoid affecting other programs
        # (method=None with force=True resets the start method to unset)
        mp.set_start_method(sys_start_method, force=True)
        cv2.setNumThreads(sys_cv_threads)
        for key, value in saved_env.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value

View File

@ -16,6 +16,7 @@ from mmcv.utils import DictAction
from mmseg.apis import multi_gpu_test, single_gpu_test from mmseg.apis import multi_gpu_test, single_gpu_test
from mmseg.datasets import build_dataloader, build_dataset from mmseg.datasets import build_dataloader, build_dataset
from mmseg.models import build_segmentor from mmseg.models import build_segmentor
from mmseg.utils import setup_multi_processes
def parse_args(): def parse_args():
@ -124,6 +125,10 @@ def main():
cfg = mmcv.Config.fromfile(args.config) cfg = mmcv.Config.fromfile(args.config)
if args.cfg_options is not None: if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options) cfg.merge_from_dict(args.cfg_options)
# set multi-process settings
setup_multi_processes(cfg)
# set cudnn_benchmark # set cudnn_benchmark
if cfg.get('cudnn_benchmark', False): if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True

View File

@ -16,7 +16,7 @@ from mmseg import __version__
from mmseg.apis import init_random_seed, set_random_seed, train_segmentor from mmseg.apis import init_random_seed, set_random_seed, train_segmentor
from mmseg.datasets import build_dataset from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor from mmseg.models import build_segmentor
from mmseg.utils import collect_env, get_root_logger from mmseg.utils import collect_env, get_root_logger, setup_multi_processes
def parse_args(): def parse_args():
@ -102,6 +102,10 @@ def main():
cfg = Config.fromfile(args.config) cfg = Config.fromfile(args.config)
if args.cfg_options is not None: if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options) cfg.merge_from_dict(args.cfg_options)
# set multi-process settings
setup_multi_processes(cfg)
# set cudnn_benchmark # set cudnn_benchmark
if cfg.get('cudnn_benchmark', False): if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True