Move get_max_cuda_memory and set_multi_processing to public functions (#250)

* move get_max_cuda_memory and set_multi_processing to public functions

* fix lint

* fix lint

* fix lint

* delete _set_multi_processing

* fix error

* rename
This commit is contained in:
Haian Huang(深度眸) 2022-05-24 19:36:55 +08:00 committed by GitHub
parent a976257ca9
commit 8d3bd4dfef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 164 additions and 67 deletions

View File

@ -3,6 +3,7 @@
from .config import * from .config import *
from .data import * from .data import *
from .dataset import * from .dataset import *
from .device import *
from .fileio import * from .fileio import *
from .hooks import * from .hooks import *
from .logging import * from .logging import *

View File

@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .utils import get_max_cuda_memory
__all__ = ['get_max_cuda_memory']

27
mmengine/device/utils.py Normal file
View File

@ -0,0 +1,27 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import torch
def get_max_cuda_memory(device: Optional[torch.device] = None) -> int:
    """Return the peak GPU memory occupied by tensors in megabytes (MB) for
    a given device.

    The peak-memory statistics of the queried device are reset after they
    are read, so each call reports the peak observed since the previous
    call (or since the beginning of the program for the first call).

    Args:
        device (torch.device, optional): selected device. Returns
            statistic for the current device, given by
            :func:`~torch.cuda.current_device`, if ``device`` is None.
            Defaults to None.

    Returns:
        int: The maximum GPU memory occupied by tensors in megabytes
        for a given device.
    """
    peak_bytes = torch.cuda.max_memory_allocated(device=device)
    # Reset the statistics of the device that was just queried. Calling the
    # reset without ``device`` would act on the current device instead, which
    # may not be the one whose peak was read above.
    torch.cuda.reset_peak_memory_stats(device=device)
    # Integer division converts bytes to whole megabytes.
    return int(peak_bytes) // (1024 * 1024)

View File

@ -6,6 +6,7 @@ from typing import List, Optional, Tuple
import torch import torch
from mmengine.device import get_max_cuda_memory
from mmengine.registry import LOG_PROCESSOR from mmengine.registry import LOG_PROCESSOR
@ -345,13 +346,9 @@ class LogProcessor:
The maximum GPU memory occupied by tensors in megabytes for a given The maximum GPU memory occupied by tensors in megabytes for a given
device. device.
""" """
device = getattr(runner.model, 'output_device', None) device = getattr(runner.model, 'output_device', None)
mem = torch.cuda.max_memory_allocated(device=device) return get_max_cuda_memory(device)
mem_mb = torch.tensor([int(mem) // (1024 * 1024)],
dtype=torch.int,
device=device)
torch.cuda.reset_peak_memory_stats()
return int(mem_mb.item())
def _get_iter(self, runner, batch_idx: int = None) -> int: def _get_iter(self, runner, batch_idx: int = None) -> int:
"""Get current iteration index. """Get current iteration index.

View File

@ -1,7 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import copy import copy
import multiprocessing as mp
import os
import os.path as osp import os.path as osp
import platform import platform
import random import random
@ -34,7 +32,8 @@ from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
count_registered_modules) count_registered_modules)
from mmengine.registry.root import LOG_PROCESSOR from mmengine.registry.root import LOG_PROCESSOR
from mmengine.utils import (TORCH_VERSION, digit_version, from mmengine.utils import (TORCH_VERSION, digit_version,
find_latest_checkpoint, is_list_of, symlink) find_latest_checkpoint, is_list_of,
set_multi_processing, symlink)
from mmengine.visualization import Visualizer from mmengine.visualization import Visualizer
from .base_loop import BaseLoop from .base_loop import BaseLoop
from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model, from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
@ -582,12 +581,13 @@ class Runner:
if env_cfg.get('cudnn_benchmark'): if env_cfg.get('cudnn_benchmark'):
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
if env_cfg.get('mp_cfg') is not None: mp_cfg: dict = env_cfg.get('mp_cfg', {})
self._set_multi_processing(**env_cfg.get('mp_cfg')) # type: ignore set_multi_processing(**mp_cfg, distributed=self.distributed)
# init distributed env first, since logger depends on the dist info. # init distributed env first, since logger depends on the dist info.
if self.distributed and env_cfg.get('dist_cfg') is not None: if self.distributed:
init_dist(self.launcher, **env_cfg.get('dist_cfg')) # type: ignore dist_cfg: dict = env_cfg.get('dist_cfg', {})
init_dist(self.launcher, **dist_cfg)
self._rank, self._world_size = get_dist_info() self._rank, self._world_size = get_dist_info()
@ -597,59 +597,6 @@ class Runner:
self._timestamp = time.strftime('%Y%m%d_%H%M%S', self._timestamp = time.strftime('%Y%m%d_%H%M%S',
time.localtime(timestamp.item())) time.localtime(timestamp.item()))
def _set_multi_processing(self,
                          mp_start_method: str = 'fork',
                          opencv_num_threads: int = 0) -> None:
    """Set multi-processing related environment.

    On non-Windows platforms the multiprocessing start method is forced to
    ``mp_start_method`` (a warning is emitted when this overrides a
    different, previously configured method). If OpenCV can be imported,
    its thread count is set to ``opencv_num_threads``. When
    ``self.distributed`` is true, ``OMP_NUM_THREADS`` and
    ``MKL_NUM_THREADS`` are set to ``1`` unless they are already present in
    the environment.

    Args:
        mp_start_method (str): Set the method which should be used to start
            child processes. Defaults to 'fork'.
        opencv_num_threads (int): Number of threads for opencv.
            Defaults to 0.
    """
    # set multi-process start method as `fork` to speed up the training
    if platform.system() != 'Windows':
        current_method = mp.get_start_method(allow_none=True)
        if (current_method is not None
                and current_method != mp_start_method):
            warnings.warn(
                f'Multi-processing start method `{mp_start_method}` is '
                f'different from the previous setting `{current_method}`. '
                f'It will be force set to `{mp_start_method}`. You can '
                'change this behavior by changing `mp_start_method` in '
                'your config.')
        mp.set_start_method(mp_start_method, force=True)

    try:
        import cv2
    except ImportError:
        # OpenCV is optional; nothing to configure when it is missing.
        pass
    else:
        # disable opencv multithreading to avoid system being overloaded
        cv2.setNumThreads(opencv_num_threads)

    # setup OMP and MKL threads
    # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
    if self.distributed:
        num_threads = 1
        for env_name in ('OMP_NUM_THREADS', 'MKL_NUM_THREADS'):
            # Respect a value the user has already exported.
            if env_name in os.environ:
                continue
            warnings.warn(
                f'Setting {env_name} environment variable for each process'
                f' to be {num_threads} in default, to avoid your system '
                'being overloaded, please further tune the variable for '
                'optimal performance in your application as needed.')
            os.environ[env_name] = str(num_threads)
def set_randomness(self, seed, deterministic: bool = False) -> None: def set_randomness(self, seed, deterministic: bool = False) -> None:
"""Set random seed to guarantee reproducible results. """Set random seed to guarantee reproducible results.

View File

@ -12,6 +12,7 @@ from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
from .parrots_wrapper import TORCH_VERSION from .parrots_wrapper import TORCH_VERSION
from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist, from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist,
scandir, symlink) scandir, symlink)
from .setup_env import set_multi_processing
from .version_utils import digit_version, get_git_hash from .version_utils import digit_version, get_git_hash
__all__ = [ __all__ = [
@ -23,5 +24,6 @@ __all__ = [
'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple', 'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple',
'is_method_overridden', 'has_method', 'mmcv_full_available', 'is_method_overridden', 'has_method', 'mmcv_full_available',
'digit_version', 'get_git_hash', 'TORCH_VERSION', 'load_url', 'digit_version', 'get_git_hash', 'TORCH_VERSION', 'load_url',
'find_latest_checkpoint', 'ManagerMeta', 'ManagerMixin' 'find_latest_checkpoint', 'ManagerMeta', 'ManagerMixin',
'set_multi_processing'
] ]

View File

@ -0,0 +1,61 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import platform
import warnings
import torch.multiprocessing as mp
def set_multi_processing(mp_start_method: str = 'fork',
                         opencv_num_threads: int = 0,
                         distributed: bool = False) -> None:
    """Set multi-processing related environment.

    On non-Windows platforms the multiprocessing start method is forced to
    ``mp_start_method`` (a warning is emitted when this overrides a
    different, previously configured method). If OpenCV can be imported,
    its thread count is set to ``opencv_num_threads``. When ``distributed``
    is true, ``OMP_NUM_THREADS`` and ``MKL_NUM_THREADS`` are set to ``1``
    unless they are already present in the environment.

    Args:
        mp_start_method (str): Set the method which should be used to start
            child processes. Defaults to 'fork'.
        opencv_num_threads (int): Number of threads for opencv.
            Defaults to 0.
        distributed (bool): True if distributed environment.
            Defaults to False.
    """
    # set multi-process start method as `fork` to speed up the training
    if platform.system() != 'Windows':
        current_method = mp.get_start_method(allow_none=True)
        if current_method is not None and current_method != mp_start_method:
            warnings.warn(
                f'Multi-processing start method `{mp_start_method}` is '
                f'different from the previous setting `{current_method}`. '
                f'It will be force set to `{mp_start_method}`. You can '
                'change this behavior by changing `mp_start_method` in '
                'your config.')
        mp.set_start_method(mp_start_method, force=True)

    try:
        import cv2
    except ImportError:
        # OpenCV is optional; nothing to configure when it is missing.
        pass
    else:
        # disable opencv multithreading to avoid system being overloaded
        cv2.setNumThreads(opencv_num_threads)

    # setup OMP and MKL threads
    # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
    if distributed:
        num_threads = 1
        for env_name in ('OMP_NUM_THREADS', 'MKL_NUM_THREADS'):
            # Respect a value the user has already exported.
            if env_name in os.environ:
                continue
            warnings.warn(
                f'Setting {env_name} environment variable for each process'
                f' to be {num_threads} in default, to avoid your system '
                'being overloaded, please further tune the variable for '
                'optimal performance in your application as needed.')
            os.environ[env_name] = str(num_threads)

View File

@ -0,0 +1,58 @@
# Copyright (c) OpenMMLab. All rights reserved.
import multiprocessing as mp
import os
import platform
import cv2
from mmengine.utils import set_multi_processing
def test_setup_multi_processes():
    """Check that ``set_multi_processing`` configures the process environment.

    Covers the distributed defaults, the non-distributed case, a manually
    exported thread variable, and explicit OpenCV / start-method settings.
    The process-wide state touched by the test is restored in a ``finally``
    block so that a failing assertion cannot leak it into other tests.
    """
    # temp save system setting
    sys_start_method = mp.get_start_method(allow_none=True)
    sys_cv_threads = cv2.getNumThreads()
    # pop and temp save system env vars
    saved_env = {
        name: os.environ.pop(name, None)
        for name in ('OMP_NUM_THREADS', 'MKL_NUM_THREADS')
    }

    try:
        # test distributed
        set_multi_processing(distributed=True)
        assert os.getenv('OMP_NUM_THREADS') == '1'
        assert os.getenv('MKL_NUM_THREADS') == '1'
        # when set to 0, the num threads will be 1
        assert cv2.getNumThreads() == 1
        if platform.system() != 'Windows':
            assert mp.get_start_method() == 'fork'

        # test non-distributed: the thread env vars must stay unset
        os.environ.pop('OMP_NUM_THREADS')
        os.environ.pop('MKL_NUM_THREADS')
        set_multi_processing(distributed=False)
        assert 'OMP_NUM_THREADS' not in os.environ
        assert 'MKL_NUM_THREADS' not in os.environ

        # test manually set env var
        os.environ['OMP_NUM_THREADS'] = '4'
        set_multi_processing(distributed=False)
        assert os.getenv('OMP_NUM_THREADS') == '4'

        # test manually set opencv threads and mp start method
        config = dict(
            mp_start_method='spawn', opencv_num_threads=4, distributed=True)
        set_multi_processing(**config)
        assert cv2.getNumThreads() == 4
        assert mp.get_start_method() == 'spawn'
    finally:
        # revert setting to avoid affecting other programs
        if sys_start_method:
            mp.set_start_method(sys_start_method, force=True)
        cv2.setNumThreads(sys_cv_threads)
        for name, value in saved_env.items():
            if value is not None:
                os.environ[name] = value
            else:
                # The variable may be absent if the test failed early.
                os.environ.pop(name, None)