mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
Move get_max_cuda_memory and set_multi_processing to public function (#250)
* move get_max_cuda_memory and set_multi_processing to a public function * fix lint * fix lint * fix lint * delete _set_multi_processing * fix error * rename
This commit is contained in:
parent
a976257ca9
commit
8d3bd4dfef
@ -3,6 +3,7 @@
|
||||
from .config import *
|
||||
from .data import *
|
||||
from .dataset import *
|
||||
from .device import *
|
||||
from .fileio import *
|
||||
from .hooks import *
|
||||
from .logging import *
|
||||
|
4
mmengine/device/__init__.py
Normal file
4
mmengine/device/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .utils import get_max_cuda_memory
|
||||
|
||||
__all__ = ['get_max_cuda_memory']
|
27
mmengine/device/utils.py
Normal file
27
mmengine/device/utils.py
Normal file
@ -0,0 +1,27 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def get_max_cuda_memory(device: Optional[torch.device] = None) -> int:
|
||||
"""Returns the maximum GPU memory occupied by tensors in megabytes (MB) for
|
||||
a given device. By default, this returns the peak allocated memory since
|
||||
the beginning of this program.
|
||||
|
||||
Args:
|
||||
device (torch.device, optional): selected device. Returns
|
||||
statistic for the current device, given by
|
||||
:func:`~torch.cuda.current_device`, if ``device`` is None.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
int: The maximum GPU memory occupied by tensors in megabytes
|
||||
for a given device.
|
||||
"""
|
||||
mem = torch.cuda.max_memory_allocated(device=device)
|
||||
mem_mb = torch.tensor([int(mem) // (1024 * 1024)],
|
||||
dtype=torch.int,
|
||||
device=device)
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
return int(mem_mb.item())
|
@ -6,6 +6,7 @@ from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from mmengine.device import get_max_cuda_memory
|
||||
from mmengine.registry import LOG_PROCESSOR
|
||||
|
||||
|
||||
@ -345,13 +346,9 @@ class LogProcessor:
|
||||
The maximum GPU memory occupied by tensors in megabytes for a given
|
||||
device.
|
||||
"""
|
||||
|
||||
device = getattr(runner.model, 'output_device', None)
|
||||
mem = torch.cuda.max_memory_allocated(device=device)
|
||||
mem_mb = torch.tensor([int(mem) // (1024 * 1024)],
|
||||
dtype=torch.int,
|
||||
device=device)
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
return int(mem_mb.item())
|
||||
return get_max_cuda_memory(device)
|
||||
|
||||
def _get_iter(self, runner, batch_idx: int = None) -> int:
|
||||
"""Get current iteration index.
|
||||
|
@ -1,7 +1,5 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import os.path as osp
|
||||
import platform
|
||||
import random
|
||||
@ -34,7 +32,8 @@ from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
|
||||
count_registered_modules)
|
||||
from mmengine.registry.root import LOG_PROCESSOR
|
||||
from mmengine.utils import (TORCH_VERSION, digit_version,
|
||||
find_latest_checkpoint, is_list_of, symlink)
|
||||
find_latest_checkpoint, is_list_of,
|
||||
set_multi_processing, symlink)
|
||||
from mmengine.visualization import Visualizer
|
||||
from .base_loop import BaseLoop
|
||||
from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
|
||||
@ -582,12 +581,13 @@ class Runner:
|
||||
if env_cfg.get('cudnn_benchmark'):
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
if env_cfg.get('mp_cfg') is not None:
|
||||
self._set_multi_processing(**env_cfg.get('mp_cfg')) # type: ignore
|
||||
mp_cfg: dict = env_cfg.get('mp_cfg', {})
|
||||
set_multi_processing(**mp_cfg, distributed=self.distributed)
|
||||
|
||||
# init distributed env first, since logger depends on the dist info.
|
||||
if self.distributed and env_cfg.get('dist_cfg') is not None:
|
||||
init_dist(self.launcher, **env_cfg.get('dist_cfg')) # type: ignore
|
||||
if self.distributed:
|
||||
dist_cfg: dict = env_cfg.get('dist_cfg', {})
|
||||
init_dist(self.launcher, **dist_cfg)
|
||||
|
||||
self._rank, self._world_size = get_dist_info()
|
||||
|
||||
@ -597,59 +597,6 @@ class Runner:
|
||||
self._timestamp = time.strftime('%Y%m%d_%H%M%S',
|
||||
time.localtime(timestamp.item()))
|
||||
|
||||
def _set_multi_processing(self,
|
||||
mp_start_method: str = 'fork',
|
||||
opencv_num_threads: int = 0) -> None:
|
||||
"""Set multi-processing related environment.
|
||||
|
||||
Args:
|
||||
mp_start_method (str): Set the method which should be used to start
|
||||
child processes. Defaults to 'fork'.
|
||||
opencv_num_threads (int): Number of threads for opencv.
|
||||
Defaults to 0.
|
||||
"""
|
||||
# set multi-process start method as `fork` to speed up the training
|
||||
if platform.system() != 'Windows':
|
||||
current_method = mp.get_start_method(allow_none=True)
|
||||
if (current_method is not None
|
||||
and current_method != mp_start_method):
|
||||
warnings.warn(
|
||||
f'Multi-processing start method `{mp_start_method}` is '
|
||||
f'different from the previous setting `{current_method}`.'
|
||||
f'It will be force set to `{mp_start_method}`. You can '
|
||||
'change this behavior by changing `mp_start_method` in '
|
||||
'your config.')
|
||||
mp.set_start_method(mp_start_method, force=True)
|
||||
|
||||
try:
|
||||
import cv2
|
||||
|
||||
# disable opencv multithreading to avoid system being overloaded
|
||||
cv2.setNumThreads(opencv_num_threads)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# setup OMP threads
|
||||
# This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
|
||||
if 'OMP_NUM_THREADS' not in os.environ and self.distributed:
|
||||
omp_num_threads = 1
|
||||
warnings.warn(
|
||||
'Setting OMP_NUM_THREADS environment variable for each process'
|
||||
f' to be {omp_num_threads} in default, to avoid your system '
|
||||
'being overloaded, please further tune the variable for '
|
||||
'optimal performance in your application as needed.')
|
||||
os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
|
||||
|
||||
# setup MKL threads
|
||||
if 'MKL_NUM_THREADS' not in os.environ and self.distributed:
|
||||
mkl_num_threads = 1
|
||||
warnings.warn(
|
||||
'Setting MKL_NUM_THREADS environment variable for each process'
|
||||
f' to be {mkl_num_threads} in default, to avoid your system '
|
||||
'being overloaded, please further tune the variable for '
|
||||
'optimal performance in your application as needed.')
|
||||
os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
|
||||
|
||||
def set_randomness(self, seed, deterministic: bool = False) -> None:
|
||||
"""Set random seed to guarantee reproducible results.
|
||||
|
||||
|
@ -12,6 +12,7 @@ from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
|
||||
from .parrots_wrapper import TORCH_VERSION
|
||||
from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist,
|
||||
scandir, symlink)
|
||||
from .setup_env import set_multi_processing
|
||||
from .version_utils import digit_version, get_git_hash
|
||||
|
||||
__all__ = [
|
||||
@ -23,5 +24,6 @@ __all__ = [
|
||||
'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple',
|
||||
'is_method_overridden', 'has_method', 'mmcv_full_available',
|
||||
'digit_version', 'get_git_hash', 'TORCH_VERSION', 'load_url',
|
||||
'find_latest_checkpoint', 'ManagerMeta', 'ManagerMixin'
|
||||
'find_latest_checkpoint', 'ManagerMeta', 'ManagerMixin',
|
||||
'set_multi_processing'
|
||||
]
|
||||
|
61
mmengine/utils/setup_env.py
Normal file
61
mmengine/utils/setup_env.py
Normal file
@ -0,0 +1,61 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os
|
||||
import platform
|
||||
import warnings
|
||||
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
|
||||
def set_multi_processing(mp_start_method: str = 'fork',
|
||||
opencv_num_threads: int = 0,
|
||||
distributed: bool = False) -> None:
|
||||
"""Set multi-processing related environment.
|
||||
|
||||
Args:
|
||||
mp_start_method (str): Set the method which should be used to start
|
||||
child processes. Defaults to 'fork'.
|
||||
opencv_num_threads (int): Number of threads for opencv.
|
||||
Defaults to 0.
|
||||
distributed (bool): True if distributed environment.
|
||||
Defaults to False.
|
||||
"""
|
||||
# set multi-process start method as `fork` to speed up the training
|
||||
if platform.system() != 'Windows':
|
||||
current_method = mp.get_start_method(allow_none=True)
|
||||
if (current_method is not None and current_method != mp_start_method):
|
||||
warnings.warn(
|
||||
f'Multi-processing start method `{mp_start_method}` is '
|
||||
f'different from the previous setting `{current_method}`.'
|
||||
f'It will be force set to `{mp_start_method}`. You can '
|
||||
'change this behavior by changing `mp_start_method` in '
|
||||
'your config.')
|
||||
mp.set_start_method(mp_start_method, force=True)
|
||||
|
||||
try:
|
||||
import cv2
|
||||
|
||||
# disable opencv multithreading to avoid system being overloaded
|
||||
cv2.setNumThreads(opencv_num_threads)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# setup OMP threads
|
||||
# This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
|
||||
if 'OMP_NUM_THREADS' not in os.environ and distributed:
|
||||
omp_num_threads = 1
|
||||
warnings.warn(
|
||||
'Setting OMP_NUM_THREADS environment variable for each process'
|
||||
f' to be {omp_num_threads} in default, to avoid your system '
|
||||
'being overloaded, please further tune the variable for '
|
||||
'optimal performance in your application as needed.')
|
||||
os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
|
||||
|
||||
# setup MKL threads
|
||||
if 'MKL_NUM_THREADS' not in os.environ and distributed:
|
||||
mkl_num_threads = 1
|
||||
warnings.warn(
|
||||
'Setting MKL_NUM_THREADS environment variable for each process'
|
||||
f' to be {mkl_num_threads} in default, to avoid your system '
|
||||
'being overloaded, please further tune the variable for '
|
||||
'optimal performance in your application as needed.')
|
||||
os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
|
58
tests/test_utils/test_setup_env.py
Normal file
58
tests/test_utils/test_setup_env.py
Normal file
@ -0,0 +1,58 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import platform
|
||||
|
||||
import cv2
|
||||
|
||||
from mmengine.utils import set_multi_processing
|
||||
|
||||
|
||||
def test_setup_multi_processes():
|
||||
# temp save system setting
|
||||
sys_start_mehod = mp.get_start_method(allow_none=True)
|
||||
sys_cv_threads = cv2.getNumThreads()
|
||||
# pop and temp save system env vars
|
||||
sys_omp_threads = os.environ.pop('OMP_NUM_THREADS', default=None)
|
||||
sys_mkl_threads = os.environ.pop('MKL_NUM_THREADS', default=None)
|
||||
|
||||
# test distributed
|
||||
set_multi_processing(distributed=True)
|
||||
assert os.getenv('OMP_NUM_THREADS') == '1'
|
||||
assert os.getenv('MKL_NUM_THREADS') == '1'
|
||||
# when set to 0, the num threads will be 1
|
||||
assert cv2.getNumThreads() == 1
|
||||
if platform.system() != 'Windows':
|
||||
assert mp.get_start_method() == 'fork'
|
||||
|
||||
# test num workers <= 1
|
||||
os.environ.pop('OMP_NUM_THREADS')
|
||||
os.environ.pop('MKL_NUM_THREADS')
|
||||
set_multi_processing(distributed=False)
|
||||
assert 'OMP_NUM_THREADS' not in os.environ
|
||||
assert 'MKL_NUM_THREADS' not in os.environ
|
||||
|
||||
# test manually set env var
|
||||
os.environ['OMP_NUM_THREADS'] = '4'
|
||||
set_multi_processing(distributed=False)
|
||||
assert os.getenv('OMP_NUM_THREADS') == '4'
|
||||
|
||||
# test manually set opencv threads and mp start method
|
||||
config = dict(
|
||||
mp_start_method='spawn', opencv_num_threads=4, distributed=True)
|
||||
set_multi_processing(**config)
|
||||
assert cv2.getNumThreads() == 4
|
||||
assert mp.get_start_method() == 'spawn'
|
||||
|
||||
# revert setting to avoid affecting other programs
|
||||
if sys_start_mehod:
|
||||
mp.set_start_method(sys_start_mehod, force=True)
|
||||
cv2.setNumThreads(sys_cv_threads)
|
||||
if sys_omp_threads:
|
||||
os.environ['OMP_NUM_THREADS'] = sys_omp_threads
|
||||
else:
|
||||
os.environ.pop('OMP_NUM_THREADS')
|
||||
if sys_mkl_threads:
|
||||
os.environ['MKL_NUM_THREADS'] = sys_mkl_threads
|
||||
else:
|
||||
os.environ.pop('MKL_NUM_THREADS')
|
Loading…
x
Reference in New Issue
Block a user