[Enhance] Add setup multi-processing both in train and test. (#671)
parent 833152b1f4
commit f552419e45

mmcls/utils/__init__.py
@@ -1,5 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .collect_env import collect_env
 from .logger import get_root_logger, load_json_log
+from .setup_env import setup_multi_processes
 
-__all__ = ['collect_env', 'get_root_logger', 'load_json_log']
+__all__ = [
+    'collect_env', 'get_root_logger', 'load_json_log', 'setup_multi_processes'
+]

mmcls/utils/setup_env.py
@@ -0,0 +1,47 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import platform
+import warnings
+
+import cv2
+import torch.multiprocessing as mp
+
+
+def setup_multi_processes(cfg):
+    """Set up multi-processing environment variables."""
+    # set the multi-process start method to `fork` to speed up training
+    if platform.system() != 'Windows':
+        mp_start_method = cfg.get('mp_start_method', 'fork')
+        current_method = mp.get_start_method(allow_none=True)
+        if current_method is not None and current_method != mp_start_method:
+            warnings.warn(
+                f'Multi-processing start method `{mp_start_method}` is '
+                f'different from the previous setting `{current_method}`. '
+                f'It will be forcibly set to `{mp_start_method}`. You can '
+                f'change this behavior via `mp_start_method` in your config.')
+        mp.set_start_method(mp_start_method, force=True)
+
+    # disable opencv multithreading to avoid system being overloaded
+    opencv_num_threads = cfg.get('opencv_num_threads', 0)
+    cv2.setNumThreads(opencv_num_threads)
+
+    # setup OMP threads
+    # This code is referenced from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py  # noqa
+    if 'OMP_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1:
+        omp_num_threads = 1
+        warnings.warn(
+            f'Setting OMP_NUM_THREADS environment variable for each process '
+            f'to {omp_num_threads} by default, to avoid your system being '
+            f'overloaded. Please further tune the variable for optimal '
+            f'performance in your application as needed.')
+        os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
+
+    # setup MKL threads
+    if 'MKL_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1:
+        mkl_num_threads = 1
+        warnings.warn(
+            f'Setting MKL_NUM_THREADS environment variable for each process '
+            f'to {mkl_num_threads} by default, to avoid your system being '
+            f'overloaded. Please further tune the variable for optimal '
+            f'performance in your application as needed.')
+        os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
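
The new keys are read straight from the config object, so they can be overridden per experiment. Below is a minimal usage sketch (not part of this change; the key names come from the diff above, the particular values are illustrative) showing how setup_multi_processes picks them up:

# illustrative sketch: override the defaults read by setup_multi_processes
from mmcv import Config

from mmcls.utils import setup_multi_processes

cfg = Config(
    dict(
        data=dict(workers_per_gpu=2),  # >1 triggers the OMP/MKL defaults
        mp_start_method='spawn',       # instead of the default `fork`
        opencv_num_threads=4))         # instead of the default 0 (disabled)
setup_multi_processes(cfg)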

setup.cfg
@@ -14,7 +14,7 @@ line_length = 79
 multi_line_output = 0
 known_standard_library = pkg_resources,setuptools
 known_first_party = mmcls
-known_third_party = PIL,matplotlib,mmcv,mmdet,modelindex,numpy,onnxruntime,packaging,pytest,pytorch_sphinx_theme,requests,rich,sphinx,tensorflow,torch,torchvision,ts
+known_third_party = PIL,cv2,matplotlib,mmcv,mmdet,modelindex,numpy,onnxruntime,packaging,pytest,pytorch_sphinx_theme,requests,rich,sphinx,tensorflow,torch,torchvision,ts
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY

tests/test_utils/test_setup_env.py
@@ -0,0 +1,68 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import multiprocessing as mp
+import os
+import platform
+
+import cv2
+from mmcv import Config
+
+from mmcls.utils import setup_multi_processes
+
+
+def test_setup_multi_processes():
+    # temporarily save the system settings
+    sys_start_method = mp.get_start_method(allow_none=True)
+    sys_cv_threads = cv2.getNumThreads()
+    # pop and temporarily save the system env vars
+    sys_omp_threads = os.environ.pop('OMP_NUM_THREADS', default=None)
+    sys_mkl_threads = os.environ.pop('MKL_NUM_THREADS', default=None)
+
+    # test config without setting env
+    config = dict(data=dict(workers_per_gpu=2))
+    cfg = Config(config)
+    setup_multi_processes(cfg)
+    assert os.getenv('OMP_NUM_THREADS') == '1'
+    assert os.getenv('MKL_NUM_THREADS') == '1'
+    # when set to 0, the number of threads will be 1
+    assert cv2.getNumThreads() == 1
+    if platform.system() != 'Windows':
+        assert mp.get_start_method() == 'fork'
+
+    # test num workers <= 1
+    os.environ.pop('OMP_NUM_THREADS')
+    os.environ.pop('MKL_NUM_THREADS')
+    config = dict(data=dict(workers_per_gpu=0))
+    cfg = Config(config)
+    setup_multi_processes(cfg)
+    assert 'OMP_NUM_THREADS' not in os.environ
+    assert 'MKL_NUM_THREADS' not in os.environ
+
+    # test manually set env var
+    os.environ['OMP_NUM_THREADS'] = '4'
+    config = dict(data=dict(workers_per_gpu=2))
+    cfg = Config(config)
+    setup_multi_processes(cfg)
+    assert os.getenv('OMP_NUM_THREADS') == '4'
+
+    # test manually set opencv threads and mp start method
+    config = dict(
+        data=dict(workers_per_gpu=2),
+        opencv_num_threads=4,
+        mp_start_method='spawn')
+    cfg = Config(config)
+    setup_multi_processes(cfg)
+    assert cv2.getNumThreads() == 4
+    assert mp.get_start_method() == 'spawn'
+
+    # revert the settings to avoid affecting other programs
+    if sys_start_method:
+        mp.set_start_method(sys_start_method, force=True)
+    cv2.setNumThreads(sys_cv_threads)
+    if sys_omp_threads:
+        os.environ['OMP_NUM_THREADS'] = sys_omp_threads
+    else:
+        os.environ.pop('OMP_NUM_THREADS')
+    if sys_mkl_threads:
+        os.environ['MKL_NUM_THREADS'] = sys_mkl_threads
+    else:
+        os.environ.pop('MKL_NUM_THREADS')
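
As the "manually set env var" case above asserts, thread-count variables that are already exported are respected, and only missing ones fall back to the conservative default of 1. A minimal sketch of that behaviour, reusing the same Config-from-dict pattern as the test (values are illustrative):

import os

from mmcv import Config

from mmcls.utils import setup_multi_processes

cfg = Config(dict(data=dict(workers_per_gpu=4)))

# an explicitly exported value is left untouched ...
os.environ['OMP_NUM_THREADS'] = '8'
setup_multi_processes(cfg)
assert os.environ['OMP_NUM_THREADS'] == '8'

# ... while a missing one falls back to 1 when workers_per_gpu > 1
del os.environ['OMP_NUM_THREADS']
setup_multi_processes(cfg)
assert os.environ['OMP_NUM_THREADS'] == '1'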

tools/test.py
@@ -14,6 +14,7 @@ from mmcv.runner import get_dist_info, init_dist, load_checkpoint
 from mmcls.apis import multi_gpu_test, single_gpu_test
 from mmcls.datasets import build_dataloader, build_dataset
 from mmcls.models import build_classifier
+from mmcls.utils import setup_multi_processes
 
 # TODO import `wrap_fp16_model` from mmcv and delete them from mmcls
 try:

@@ -119,6 +120,10 @@ def main():
     cfg = mmcv.Config.fromfile(args.config)
     if args.cfg_options is not None:
         cfg.merge_from_dict(args.cfg_options)
+
+    # set multi-process settings
+    setup_multi_processes(cfg)
+
     # set cudnn_benchmark
     if cfg.get('cudnn_benchmark', False):
         torch.backends.cudnn.benchmark = True

tools/train.py
@@ -15,7 +15,7 @@ from mmcls import __version__
 from mmcls.apis import init_random_seed, set_random_seed, train_model
 from mmcls.datasets import build_dataset
 from mmcls.models import build_classifier
-from mmcls.utils import collect_env, get_root_logger
+from mmcls.utils import collect_env, get_root_logger, setup_multi_processes
 
 
 def parse_args():

@@ -90,6 +90,10 @@ def main():
     cfg = Config.fromfile(args.config)
     if args.cfg_options is not None:
         cfg.merge_from_dict(args.cfg_options)
+
+    # set multi-process settings
+    setup_multi_processes(cfg)
+
     # set cudnn_benchmark
     if cfg.get('cudnn_benchmark', False):
         torch.backends.cudnn.benchmark = True