mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
add multi-processes script (#1238)
This commit is contained in:
parent
ee47c41740
commit
2729a6d4d1
@ -2,5 +2,9 @@
|
||||
from .collect_env import collect_env
|
||||
from .logger import get_root_logger
|
||||
from .misc import find_latest_checkpoint
|
||||
from .set_env import setup_multi_processes
|
||||
|
||||
__all__ = ['get_root_logger', 'collect_env', 'find_latest_checkpoint']
|
||||
__all__ = [
|
||||
'get_root_logger', 'collect_env', 'find_latest_checkpoint',
|
||||
'setup_multi_processes'
|
||||
]
|
||||
|
55
mmseg/utils/set_env.py
Normal file
55
mmseg/utils/set_env.py
Normal file
@ -0,0 +1,55 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os
|
||||
import platform
|
||||
|
||||
import cv2
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from ..utils import get_root_logger
|
||||
|
||||
|
||||
def setup_multi_processes(cfg):
|
||||
"""Setup multi-processing environment variables."""
|
||||
logger = get_root_logger()
|
||||
|
||||
# set multi-process start method
|
||||
if platform.system() != 'Windows':
|
||||
mp_start_method = cfg.get('mp_start_method', None)
|
||||
current_method = mp.get_start_method(allow_none=True)
|
||||
if mp_start_method in ('fork', 'spawn', 'forkserver'):
|
||||
logger.info(
|
||||
f'Multi-processing start method `{mp_start_method}` is '
|
||||
f'different from the previous setting `{current_method}`.'
|
||||
f'It will be force set to `{mp_start_method}`.')
|
||||
mp.set_start_method(mp_start_method, force=True)
|
||||
else:
|
||||
logger.info(
|
||||
f'Multi-processing start method is `{mp_start_method}`')
|
||||
|
||||
# disable opencv multithreading to avoid system being overloaded
|
||||
opencv_num_threads = cfg.get('opencv_num_threads', None)
|
||||
if isinstance(opencv_num_threads, int):
|
||||
logger.info(f'OpenCV num_threads is `{opencv_num_threads}`')
|
||||
cv2.setNumThreads(opencv_num_threads)
|
||||
else:
|
||||
logger.info(f'OpenCV num_threads is `{cv2.getNumThreads}')
|
||||
|
||||
if cfg.data.workers_per_gpu > 1:
|
||||
# setup OMP threads
|
||||
# This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
|
||||
omp_num_threads = cfg.get('omp_num_threads', None)
|
||||
if 'OMP_NUM_THREADS' not in os.environ:
|
||||
if isinstance(omp_num_threads, int):
|
||||
logger.info(f'OMP num threads is {omp_num_threads}')
|
||||
os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
|
||||
else:
|
||||
logger.info(f'OMP num threads is {os.environ["OMP_NUM_THREADS"] }')
|
||||
|
||||
# setup MKL threads
|
||||
if 'MKL_NUM_THREADS' not in os.environ:
|
||||
mkl_num_threads = cfg.get('mkl_num_threads', None)
|
||||
if isinstance(mkl_num_threads, int):
|
||||
logger.info(f'MKL num threads is {mkl_num_threads}')
|
||||
os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
|
||||
else:
|
||||
logger.info(f'MKL num threads is {os.environ["MKL_NUM_THREADS"]}')
|
85
tests/test_utils/test_set_env.py
Normal file
85
tests/test_utils/test_set_env.py
Normal file
@ -0,0 +1,85 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import platform
|
||||
|
||||
import cv2
|
||||
import pytest
|
||||
from mmcv import Config
|
||||
|
||||
from mmseg.utils import setup_multi_processes
|
||||
|
||||
|
||||
@pytest.mark.parametrize('workers_per_gpu', (0, 2))
|
||||
@pytest.mark.parametrize(('valid', 'env_cfg'), [(True,
|
||||
dict(
|
||||
mp_start_method='fork',
|
||||
opencv_num_threads=0,
|
||||
omp_num_threads=1,
|
||||
mkl_num_threads=1)),
|
||||
(False,
|
||||
dict(
|
||||
mp_start_method=1,
|
||||
opencv_num_threads=0.1,
|
||||
omp_num_threads='s',
|
||||
mkl_num_threads='1'))])
|
||||
def test_setup_multi_processes(workers_per_gpu, valid, env_cfg):
|
||||
# temp save system setting
|
||||
sys_start_mehod = mp.get_start_method(allow_none=True)
|
||||
sys_cv_threads = cv2.getNumThreads()
|
||||
# pop and temp save system env vars
|
||||
sys_omp_threads = os.environ.pop('OMP_NUM_THREADS', default=None)
|
||||
sys_mkl_threads = os.environ.pop('MKL_NUM_THREADS', default=None)
|
||||
|
||||
config = dict(data=dict(workers_per_gpu=workers_per_gpu))
|
||||
config.update(env_cfg)
|
||||
cfg = Config(config)
|
||||
setup_multi_processes(cfg)
|
||||
|
||||
# test when cfg is valid and workers_per_gpu > 0
|
||||
# setup_multi_processes will work
|
||||
if valid and workers_per_gpu > 0:
|
||||
# test config without setting env
|
||||
|
||||
assert os.getenv('OMP_NUM_THREADS') == str(env_cfg['omp_num_threads'])
|
||||
assert os.getenv('MKL_NUM_THREADS') == str(env_cfg['mkl_num_threads'])
|
||||
# when set to 0, the num threads will be 1
|
||||
assert cv2.getNumThreads() == env_cfg[
|
||||
'opencv_num_threads'] if env_cfg['opencv_num_threads'] > 0 else 1
|
||||
if platform.system() != 'Windows':
|
||||
assert mp.get_start_method() == env_cfg['mp_start_method']
|
||||
|
||||
# revert setting to avoid affecting other programs
|
||||
if sys_start_mehod:
|
||||
mp.set_start_method(sys_start_mehod, force=True)
|
||||
cv2.setNumThreads(sys_cv_threads)
|
||||
if sys_omp_threads:
|
||||
os.environ['OMP_NUM_THREADS'] = sys_omp_threads
|
||||
else:
|
||||
os.environ.pop('OMP_NUM_THREADS')
|
||||
if sys_mkl_threads:
|
||||
os.environ['MKL_NUM_THREADS'] = sys_mkl_threads
|
||||
else:
|
||||
os.environ.pop('MKL_NUM_THREADS')
|
||||
|
||||
elif valid and workers_per_gpu == 0:
|
||||
|
||||
if platform.system() != 'Windows':
|
||||
assert mp.get_start_method() == env_cfg['mp_start_method']
|
||||
assert cv2.getNumThreads() == env_cfg[
|
||||
'opencv_num_threads'] if env_cfg['opencv_num_threads'] > 0 else 1
|
||||
assert 'OMP_NUM_THREADS' not in os.environ
|
||||
assert 'MKL_NUM_THREADS' not in os.environ
|
||||
if sys_start_mehod:
|
||||
mp.set_start_method(sys_start_mehod, force=True)
|
||||
cv2.setNumThreads(sys_cv_threads)
|
||||
if sys_omp_threads:
|
||||
os.environ['OMP_NUM_THREADS'] = sys_omp_threads
|
||||
if sys_mkl_threads:
|
||||
os.environ['MKL_NUM_THREADS'] = sys_mkl_threads
|
||||
|
||||
else:
|
||||
assert mp.get_start_method() == sys_start_mehod
|
||||
assert cv2.getNumThreads() == sys_cv_threads
|
||||
assert 'OMP_NUM_THREADS' not in os.environ
|
||||
assert 'MKL_NUM_THREADS' not in os.environ
|
@ -16,6 +16,7 @@ from mmcv.utils import DictAction
|
||||
from mmseg.apis import multi_gpu_test, single_gpu_test
|
||||
from mmseg.datasets import build_dataloader, build_dataset
|
||||
from mmseg.models import build_segmentor
|
||||
from mmseg.utils import setup_multi_processes
|
||||
|
||||
|
||||
def parse_args():
|
||||
@ -124,6 +125,10 @@ def main():
|
||||
cfg = mmcv.Config.fromfile(args.config)
|
||||
if args.cfg_options is not None:
|
||||
cfg.merge_from_dict(args.cfg_options)
|
||||
|
||||
# set multi-process settings
|
||||
setup_multi_processes(cfg)
|
||||
|
||||
# set cudnn_benchmark
|
||||
if cfg.get('cudnn_benchmark', False):
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
@ -16,7 +16,7 @@ from mmseg import __version__
|
||||
from mmseg.apis import init_random_seed, set_random_seed, train_segmentor
|
||||
from mmseg.datasets import build_dataset
|
||||
from mmseg.models import build_segmentor
|
||||
from mmseg.utils import collect_env, get_root_logger
|
||||
from mmseg.utils import collect_env, get_root_logger, setup_multi_processes
|
||||
|
||||
|
||||
def parse_args():
|
||||
@ -102,6 +102,10 @@ def main():
|
||||
cfg = Config.fromfile(args.config)
|
||||
if args.cfg_options is not None:
|
||||
cfg.merge_from_dict(args.cfg_options)
|
||||
|
||||
# set multi-process settings
|
||||
setup_multi_processes(cfg)
|
||||
|
||||
# set cudnn_benchmark
|
||||
if cfg.get('cudnn_benchmark', False):
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
Loading…
x
Reference in New Issue
Block a user