mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
Move get_max_cuda_memory and set_multi_processing to public function (#250)
* move get_max_cuda_memory and set_multi_processing to a public function * fix lint * fix lint * fix lint * delete _set_multi_processing * fix error * rename
This commit is contained in:
parent
a976257ca9
commit
8d3bd4dfef
@ -3,6 +3,7 @@
|
|||||||
from .config import *
|
from .config import *
|
||||||
from .data import *
|
from .data import *
|
||||||
from .dataset import *
|
from .dataset import *
|
||||||
|
from .device import *
|
||||||
from .fileio import *
|
from .fileio import *
|
||||||
from .hooks import *
|
from .hooks import *
|
||||||
from .logging import *
|
from .logging import *
|
||||||
|
4
mmengine/device/__init__.py
Normal file
4
mmengine/device/__init__.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from .utils import get_max_cuda_memory
|
||||||
|
|
||||||
|
__all__ = ['get_max_cuda_memory']
|
27
mmengine/device/utils.py
Normal file
27
mmengine/device/utils.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def get_max_cuda_memory(device: Optional[torch.device] = None) -> int:
    """Return the peak GPU memory allocated by tensors in megabytes (MB).

    The peak is read with :func:`torch.cuda.max_memory_allocated` and the
    peak statistics of the same device are reset afterwards, so every call
    reports the peak reached since the previous reset rather than since the
    beginning of the program.

    Args:
        device (torch.device, optional): selected device. Returns
            statistic for the current device, given by
            :func:`~torch.cuda.current_device`, if ``device`` is None.
            Defaults to None.

    Returns:
        int: The maximum GPU memory occupied by tensors in whole megabytes
        (rounded down) for the given device.
    """
    peak_bytes = torch.cuda.max_memory_allocated(device=device)
    # Reset the counters of the device that was just queried. Calling
    # ``reset_peak_memory_stats()`` without an argument always targets the
    # current device, even when a different ``device`` was requested.
    torch.cuda.reset_peak_memory_stats(device=device)
    # Plain integer arithmetic is enough here; no tensor round-trip needed.
    return int(peak_bytes) // (1024 * 1024)
|
@ -6,6 +6,7 @@ from typing import List, Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from mmengine.device import get_max_cuda_memory
|
||||||
from mmengine.registry import LOG_PROCESSOR
|
from mmengine.registry import LOG_PROCESSOR
|
||||||
|
|
||||||
|
|
||||||
@ -345,13 +346,9 @@ class LogProcessor:
|
|||||||
The maximum GPU memory occupied by tensors in megabytes for a given
|
The maximum GPU memory occupied by tensors in megabytes for a given
|
||||||
device.
|
device.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
device = getattr(runner.model, 'output_device', None)
|
device = getattr(runner.model, 'output_device', None)
|
||||||
mem = torch.cuda.max_memory_allocated(device=device)
|
return get_max_cuda_memory(device)
|
||||||
mem_mb = torch.tensor([int(mem) // (1024 * 1024)],
|
|
||||||
dtype=torch.int,
|
|
||||||
device=device)
|
|
||||||
torch.cuda.reset_peak_memory_stats()
|
|
||||||
return int(mem_mb.item())
|
|
||||||
|
|
||||||
def _get_iter(self, runner, batch_idx: int = None) -> int:
|
def _get_iter(self, runner, batch_idx: int = None) -> int:
|
||||||
"""Get current iteration index.
|
"""Get current iteration index.
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import copy
|
import copy
|
||||||
import multiprocessing as mp
|
|
||||||
import os
|
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import platform
|
import platform
|
||||||
import random
|
import random
|
||||||
@ -34,7 +32,8 @@ from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
|
|||||||
count_registered_modules)
|
count_registered_modules)
|
||||||
from mmengine.registry.root import LOG_PROCESSOR
|
from mmengine.registry.root import LOG_PROCESSOR
|
||||||
from mmengine.utils import (TORCH_VERSION, digit_version,
|
from mmengine.utils import (TORCH_VERSION, digit_version,
|
||||||
find_latest_checkpoint, is_list_of, symlink)
|
find_latest_checkpoint, is_list_of,
|
||||||
|
set_multi_processing, symlink)
|
||||||
from mmengine.visualization import Visualizer
|
from mmengine.visualization import Visualizer
|
||||||
from .base_loop import BaseLoop
|
from .base_loop import BaseLoop
|
||||||
from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
|
from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
|
||||||
@ -582,12 +581,13 @@ class Runner:
|
|||||||
if env_cfg.get('cudnn_benchmark'):
|
if env_cfg.get('cudnn_benchmark'):
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
if env_cfg.get('mp_cfg') is not None:
|
mp_cfg: dict = env_cfg.get('mp_cfg', {})
|
||||||
self._set_multi_processing(**env_cfg.get('mp_cfg')) # type: ignore
|
set_multi_processing(**mp_cfg, distributed=self.distributed)
|
||||||
|
|
||||||
# init distributed env first, since logger depends on the dist info.
|
# init distributed env first, since logger depends on the dist info.
|
||||||
if self.distributed and env_cfg.get('dist_cfg') is not None:
|
if self.distributed:
|
||||||
init_dist(self.launcher, **env_cfg.get('dist_cfg')) # type: ignore
|
dist_cfg: dict = env_cfg.get('dist_cfg', {})
|
||||||
|
init_dist(self.launcher, **dist_cfg)
|
||||||
|
|
||||||
self._rank, self._world_size = get_dist_info()
|
self._rank, self._world_size = get_dist_info()
|
||||||
|
|
||||||
@ -597,59 +597,6 @@ class Runner:
|
|||||||
self._timestamp = time.strftime('%Y%m%d_%H%M%S',
|
self._timestamp = time.strftime('%Y%m%d_%H%M%S',
|
||||||
time.localtime(timestamp.item()))
|
time.localtime(timestamp.item()))
|
||||||
|
|
||||||
def _set_multi_processing(self,
|
|
||||||
mp_start_method: str = 'fork',
|
|
||||||
opencv_num_threads: int = 0) -> None:
|
|
||||||
"""Set multi-processing related environment.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
mp_start_method (str): Set the method which should be used to start
|
|
||||||
child processes. Defaults to 'fork'.
|
|
||||||
opencv_num_threads (int): Number of threads for opencv.
|
|
||||||
Defaults to 0.
|
|
||||||
"""
|
|
||||||
# set multi-process start method as `fork` to speed up the training
|
|
||||||
if platform.system() != 'Windows':
|
|
||||||
current_method = mp.get_start_method(allow_none=True)
|
|
||||||
if (current_method is not None
|
|
||||||
and current_method != mp_start_method):
|
|
||||||
warnings.warn(
|
|
||||||
f'Multi-processing start method `{mp_start_method}` is '
|
|
||||||
f'different from the previous setting `{current_method}`.'
|
|
||||||
f'It will be force set to `{mp_start_method}`. You can '
|
|
||||||
'change this behavior by changing `mp_start_method` in '
|
|
||||||
'your config.')
|
|
||||||
mp.set_start_method(mp_start_method, force=True)
|
|
||||||
|
|
||||||
try:
|
|
||||||
import cv2
|
|
||||||
|
|
||||||
# disable opencv multithreading to avoid system being overloaded
|
|
||||||
cv2.setNumThreads(opencv_num_threads)
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# setup OMP threads
|
|
||||||
# This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
|
|
||||||
if 'OMP_NUM_THREADS' not in os.environ and self.distributed:
|
|
||||||
omp_num_threads = 1
|
|
||||||
warnings.warn(
|
|
||||||
'Setting OMP_NUM_THREADS environment variable for each process'
|
|
||||||
f' to be {omp_num_threads} in default, to avoid your system '
|
|
||||||
'being overloaded, please further tune the variable for '
|
|
||||||
'optimal performance in your application as needed.')
|
|
||||||
os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
|
|
||||||
|
|
||||||
# setup MKL threads
|
|
||||||
if 'MKL_NUM_THREADS' not in os.environ and self.distributed:
|
|
||||||
mkl_num_threads = 1
|
|
||||||
warnings.warn(
|
|
||||||
'Setting MKL_NUM_THREADS environment variable for each process'
|
|
||||||
f' to be {mkl_num_threads} in default, to avoid your system '
|
|
||||||
'being overloaded, please further tune the variable for '
|
|
||||||
'optimal performance in your application as needed.')
|
|
||||||
os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
|
|
||||||
|
|
||||||
def set_randomness(self, seed, deterministic: bool = False) -> None:
|
def set_randomness(self, seed, deterministic: bool = False) -> None:
|
||||||
"""Set random seed to guarantee reproducible results.
|
"""Set random seed to guarantee reproducible results.
|
||||||
|
|
||||||
|
@ -12,6 +12,7 @@ from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
|
|||||||
from .parrots_wrapper import TORCH_VERSION
|
from .parrots_wrapper import TORCH_VERSION
|
||||||
from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist,
|
from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist,
|
||||||
scandir, symlink)
|
scandir, symlink)
|
||||||
|
from .setup_env import set_multi_processing
|
||||||
from .version_utils import digit_version, get_git_hash
|
from .version_utils import digit_version, get_git_hash
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -23,5 +24,6 @@ __all__ = [
|
|||||||
'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple',
|
'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple',
|
||||||
'is_method_overridden', 'has_method', 'mmcv_full_available',
|
'is_method_overridden', 'has_method', 'mmcv_full_available',
|
||||||
'digit_version', 'get_git_hash', 'TORCH_VERSION', 'load_url',
|
'digit_version', 'get_git_hash', 'TORCH_VERSION', 'load_url',
|
||||||
'find_latest_checkpoint', 'ManagerMeta', 'ManagerMixin'
|
'find_latest_checkpoint', 'ManagerMeta', 'ManagerMixin',
|
||||||
|
'set_multi_processing'
|
||||||
]
|
]
|
||||||
|
61
mmengine/utils/setup_env.py
Normal file
61
mmengine/utils/setup_env.py
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
|
|
||||||
|
def set_multi_processing(mp_start_method: str = 'fork',
                         opencv_num_threads: int = 0,
                         distributed: bool = False) -> None:
    """Set multi-processing related environment.

    Three things are configured: the multiprocessing start method (skipped
    on Windows), the number of OpenCV threads (skipped when ``cv2`` is not
    installed) and, in distributed mode only, default values for the
    ``OMP_NUM_THREADS`` and ``MKL_NUM_THREADS`` environment variables.
    Variables that are already set by the user are never overwritten.

    Args:
        mp_start_method (str): Set the method which should be used to start
            child processes. Defaults to 'fork'.
        opencv_num_threads (int): Number of threads for opencv.
            Defaults to 0.
        distributed (bool): True if distributed environment.
            Defaults to False.
    """
    # set multi-process start method as `fork` to speed up the training
    if platform.system() != 'Windows':
        current_method = mp.get_start_method(allow_none=True)
        if current_method is not None and current_method != mp_start_method:
            warnings.warn(
                f'Multi-processing start method `{mp_start_method}` is '
                f'different from the previous setting `{current_method}`. '
                f'It will be force set to `{mp_start_method}`. You can '
                'change this behavior by changing `mp_start_method` in '
                'your config.')
        mp.set_start_method(mp_start_method, force=True)

    try:
        import cv2
    except ImportError:
        # OpenCV is optional: nothing to configure when it is missing.
        pass
    else:
        # disable opencv multithreading to avoid system being overloaded
        cv2.setNumThreads(opencv_num_threads)

    # setup OMP and MKL threads
    # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
    if distributed:
        for env_name in ('OMP_NUM_THREADS', 'MKL_NUM_THREADS'):
            if env_name not in os.environ:
                _default_thread_env(env_name)


def _default_thread_env(env_name: str, num_threads: int = 1) -> None:
    """Export ``env_name`` as ``num_threads`` and warn about the default."""
    warnings.warn(
        f'Setting {env_name} environment variable for each process'
        f' to be {num_threads} in default, to avoid your system '
        'being overloaded, please further tune the variable for '
        'optimal performance in your application as needed.')
    os.environ[env_name] = str(num_threads)
|
58
tests/test_utils/test_setup_env.py
Normal file
58
tests/test_utils/test_setup_env.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import multiprocessing as mp
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
from mmengine.utils import set_multi_processing
|
||||||
|
|
||||||
|
|
||||||
|
def test_setup_multi_processes():
    """Check the defaults applied by ``set_multi_processing``.

    The test mutates process-wide state (multiprocessing start method,
    OpenCV thread count and the ``OMP_NUM_THREADS`` / ``MKL_NUM_THREADS``
    environment variables), so the previous values are saved first and
    restored in a ``finally`` block even when an assertion fails.
    """
    thread_env_names = ('OMP_NUM_THREADS', 'MKL_NUM_THREADS')

    # temp save system setting
    sys_start_method = mp.get_start_method(allow_none=True)
    sys_cv_threads = cv2.getNumThreads()
    # pop and temp save system env vars
    sys_thread_env = {
        name: os.environ.pop(name, None)
        for name in thread_env_names
    }

    try:
        # test distributed
        set_multi_processing(distributed=True)
        assert os.getenv('OMP_NUM_THREADS') == '1'
        assert os.getenv('MKL_NUM_THREADS') == '1'
        # when set to 0, the num threads will be 1
        assert cv2.getNumThreads() == 1
        if platform.system() != 'Windows':
            assert mp.get_start_method() == 'fork'

        # test non-distributed: the thread env vars must be left unset
        os.environ.pop('OMP_NUM_THREADS')
        os.environ.pop('MKL_NUM_THREADS')
        set_multi_processing(distributed=False)
        assert 'OMP_NUM_THREADS' not in os.environ
        assert 'MKL_NUM_THREADS' not in os.environ

        # test manually set env var
        os.environ['OMP_NUM_THREADS'] = '4'
        set_multi_processing(distributed=False)
        assert os.getenv('OMP_NUM_THREADS') == '4'

        # test manually set opencv threads and mp start method
        config = dict(
            mp_start_method='spawn', opencv_num_threads=4, distributed=True)
        set_multi_processing(**config)
        assert cv2.getNumThreads() == 4
        assert mp.get_start_method() == 'spawn'
    finally:
        # revert setting to avoid affecting other programs; passing ``None``
        # with ``force=True`` puts the start method back to "not set"
        mp.set_start_method(sys_start_method, force=True)
        cv2.setNumThreads(sys_cv_threads)
        for name, saved_value in sys_thread_env.items():
            if saved_value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = saved_value
|
Loading…
x
Reference in New Issue
Block a user