mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Feature] Add distributed module (#59)
* [Feature] Add distributed module * fix IS_DIST error * all_reduce_dict does operations in-place * support 'mean' operation * provide local group process * add tmpdir argument for collect_results * add unit tests * refactor unit tests * simplify steps to create multiple processes * minor fix * describe the different of *gather* in mmengine and pytorch * minor fix * add unit tests for nccl * test nccl backend in multiple gpu * add get_default_group function to handle different torch versions * minor fix * [Feature] Add distributed module * fix IS_DIST error * all_reduce_dict does operations in-place * support 'mean' operation * provide local group process * add tmpdir argument for collect_results * add unit tests * refactor unit tests * simplify steps to create multiple processes * minor fix * describe the different of *gather* in mmengine and pytorch * minor fix * add unit tests for nccl * test nccl backend in multiple gpu * add get_default_group function to handle different torch versions * minor fix * minor fix * handle torch1.5 * handle torch1.5 * minor fix * fix typo * refactor unit tests * nccl does not support gather and gather_object * fix gather * fix collect_results_cpu * fix collect_results and refactor unit tests * fix collect_results unit tests * handle torch.cat in torch1.5 * refine docstring * refine docstring * fix comments * fix comments
This commit is contained in:
parent
817eb89ac2
commit
c6a8d72c5e
1
.gitignore
vendored
1
.gitignore
vendored
@ -10,7 +10,6 @@ __pycache__/
|
|||||||
.Python
|
.Python
|
||||||
build/
|
build/
|
||||||
develop-eggs/
|
develop-eggs/
|
||||||
dist/
|
|
||||||
downloads/
|
downloads/
|
||||||
eggs/
|
eggs/
|
||||||
.eggs/
|
.eggs/
|
||||||
|
@ -7,3 +7,8 @@ Data
|
|||||||
--------
|
--------
|
||||||
.. automodule:: mmengine.data
|
.. automodule:: mmengine.data
|
||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
Distributed
|
||||||
|
-----------
|
||||||
|
.. automodule:: mmengine.dist
|
||||||
|
:members:
|
||||||
|
@ -7,3 +7,8 @@ Data
|
|||||||
--------
|
--------
|
||||||
.. automodule:: mmengine.data
|
.. automodule:: mmengine.data
|
||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
Distributed
|
||||||
|
-----------
|
||||||
|
.. automodule:: mmengine.dist
|
||||||
|
:members:
|
||||||
|
19
mmengine/dist/__init__.py
vendored
Normal file
19
mmengine/dist/__init__.py
vendored
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from .dist import (all_gather_object, all_reduce, all_gather, all_reduce_dict,
|
||||||
|
collect_results, gather, broadcast, gather_object,
|
||||||
|
sync_random_seed, broadcast_object_list,
|
||||||
|
collect_results_cpu, collect_results_gpu)
|
||||||
|
from .utils import (get_dist_info, init_dist, init_local_group, get_backend,
|
||||||
|
get_world_size, get_rank, get_local_size, get_local_rank,
|
||||||
|
is_main_process, master_only, barrier, get_local_group,
|
||||||
|
is_distributed, get_default_group)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'all_gather_object', 'all_reduce', 'all_gather', 'all_reduce_dict',
|
||||||
|
'collect_results', 'collect_results_cpu', 'collect_results_gpu', 'gather',
|
||||||
|
'broadcast', 'gather_object', 'sync_random_seed', 'broadcast_object_list',
|
||||||
|
'get_dist_info', 'init_dist', 'init_local_group', 'get_backend',
|
||||||
|
'get_world_size', 'get_rank', 'get_local_size', 'get_local_group',
|
||||||
|
'get_local_rank', 'is_main_process', 'master_only', 'barrier',
|
||||||
|
'is_distributed', 'get_default_group'
|
||||||
|
]
|
1023
mmengine/dist/dist.py
vendored
Normal file
1023
mmengine/dist/dist.py
vendored
Normal file
File diff suppressed because it is too large
Load Diff
335
mmengine/dist/utils.py
vendored
Normal file
335
mmengine/dist/utils.py
vendored
Normal file
@ -0,0 +1,335 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import functools
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
from typing import Callable, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
from torch import distributed as dist
|
||||||
|
|
||||||
|
_LOCAL_PROCESS_GROUP = None
|
||||||
|
|
||||||
|
|
||||||
|
def is_distributed() -> bool:
|
||||||
|
"""Return True if distributed environment has been initialized."""
|
||||||
|
return dist.is_available() and dist.is_initialized()
|
||||||
|
|
||||||
|
|
||||||
|
def get_local_group() -> Optional[dist.ProcessGroup]:
|
||||||
|
"""Return local process group."""
|
||||||
|
if not is_distributed():
|
||||||
|
return None
|
||||||
|
|
||||||
|
if _LOCAL_PROCESS_GROUP is None:
|
||||||
|
raise RuntimeError('Local process group is not created, please use '
|
||||||
|
'`init_local_group` to setup local process group.')
|
||||||
|
|
||||||
|
return _LOCAL_PROCESS_GROUP
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_group() -> Optional[dist.ProcessGroup]:
|
||||||
|
"""Return default process group."""
|
||||||
|
|
||||||
|
return dist.distributed_c10d._get_default_group()
|
||||||
|
|
||||||
|
|
||||||
|
def init_dist(launcher, backend='nccl', **kwargs) -> None:
|
||||||
|
"""Initialize distributed environment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
launcher (str): Way to launcher multi processes. Supported launchers
|
||||||
|
are 'pytorch', 'mpi' and 'slurm'.
|
||||||
|
backend (str): Communication Backends. Supported backends are 'nccl',
|
||||||
|
'gloo' and 'mpi'. Defaults to 'nccl'.
|
||||||
|
**kwargs: keyword arguments are passed to ``init_process_group``.
|
||||||
|
"""
|
||||||
|
if mp.get_start_method(allow_none=True) is None:
|
||||||
|
mp.set_start_method('spawn')
|
||||||
|
if launcher == 'pytorch':
|
||||||
|
_init_dist_pytorch(backend, **kwargs)
|
||||||
|
elif launcher == 'mpi':
|
||||||
|
_init_dist_mpi(backend, **kwargs)
|
||||||
|
elif launcher == 'slurm':
|
||||||
|
_init_dist_slurm(backend, **kwargs)
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Invalid launcher type: {launcher}')
|
||||||
|
|
||||||
|
|
||||||
|
def _init_dist_pytorch(backend, **kwargs) -> None:
|
||||||
|
"""Initialize distributed environment with PyTorch launcher.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
backend (str): Backend of torch.distributed. Supported backends are
|
||||||
|
'nccl', 'gloo' and 'mpi'. Defaults to 'nccl'.
|
||||||
|
**kwargs: keyword arguments are passed to ``init_process_group``.
|
||||||
|
"""
|
||||||
|
# TODO: use local_rank instead of rank % num_gpus
|
||||||
|
rank = int(os.environ['RANK'])
|
||||||
|
num_gpus = torch.cuda.device_count()
|
||||||
|
torch.cuda.set_device(rank % num_gpus)
|
||||||
|
dist.init_process_group(backend=backend, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _init_dist_mpi(backend, **kwargs) -> None:
|
||||||
|
"""Initialize distributed environment with MPI launcher.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
backend (str): Backend of torch.distributed. Supported backends are
|
||||||
|
'nccl', 'gloo' and 'mpi'. Defaults to 'nccl'.
|
||||||
|
**kwargs: keyword arguments are passed to ``init_process_group``.
|
||||||
|
"""
|
||||||
|
# TODO: use local_rank instead of rank % num_gpus
|
||||||
|
rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
|
||||||
|
num_gpus = torch.cuda.device_count()
|
||||||
|
torch.cuda.set_device(rank % num_gpus)
|
||||||
|
dist.init_process_group(backend=backend, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _init_dist_slurm(backend, port=None) -> None:
|
||||||
|
"""Initialize slurm distributed training environment.
|
||||||
|
|
||||||
|
If argument ``port`` is not specified, then the master port will be system
|
||||||
|
environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
|
||||||
|
environment variable, then a default port ``29500`` will be used.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
backend (str): Backend of torch.distributed.
|
||||||
|
port (int, optional): Master port. Defaults to None.
|
||||||
|
|
||||||
|
TODO: https://github.com/open-mmlab/mmcv/pull/1682
|
||||||
|
"""
|
||||||
|
proc_id = int(os.environ['SLURM_PROCID'])
|
||||||
|
ntasks = int(os.environ['SLURM_NTASKS'])
|
||||||
|
node_list = os.environ['SLURM_NODELIST']
|
||||||
|
num_gpus = torch.cuda.device_count()
|
||||||
|
torch.cuda.set_device(proc_id % num_gpus)
|
||||||
|
addr = subprocess.getoutput(
|
||||||
|
f'scontrol show hostname {node_list} | head -n1')
|
||||||
|
# specify master port
|
||||||
|
if port is not None:
|
||||||
|
os.environ['MASTER_PORT'] = str(port)
|
||||||
|
elif 'MASTER_PORT' in os.environ:
|
||||||
|
pass # use MASTER_PORT in the environment variable
|
||||||
|
else:
|
||||||
|
# 29500 is torch.distributed default port
|
||||||
|
os.environ['MASTER_PORT'] = '29500'
|
||||||
|
# use MASTER_ADDR in the environment variable if it already exists
|
||||||
|
if 'MASTER_ADDR' not in os.environ:
|
||||||
|
os.environ['MASTER_ADDR'] = addr
|
||||||
|
os.environ['WORLD_SIZE'] = str(ntasks)
|
||||||
|
os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
|
||||||
|
os.environ['RANK'] = str(proc_id)
|
||||||
|
dist.init_process_group(backend=backend)
|
||||||
|
|
||||||
|
|
||||||
|
def init_local_group(node_rank: int, num_gpus_per_node: int):
|
||||||
|
"""Setup the local process group.
|
||||||
|
|
||||||
|
Setup a process group which only includes processes that on the same
|
||||||
|
machine as the current process.
|
||||||
|
|
||||||
|
The code is modified from
|
||||||
|
https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_rank (int): Rank of machines used for training.
|
||||||
|
num_gpus_per_node (int): Number of gpus used for training in a single
|
||||||
|
machine.
|
||||||
|
""" # noqa: W501
|
||||||
|
global _LOCAL_PROCESS_GROUP
|
||||||
|
assert _LOCAL_PROCESS_GROUP is None
|
||||||
|
|
||||||
|
ranks = list(
|
||||||
|
range(node_rank * num_gpus_per_node,
|
||||||
|
(node_rank + 1) * num_gpus_per_node))
|
||||||
|
_LOCAL_PROCESS_GROUP = dist.new_group(ranks)
|
||||||
|
|
||||||
|
|
||||||
|
def get_backend(group: Optional[dist.ProcessGroup] = None) -> Optional[str]:
|
||||||
|
"""Return the backend of the given process group.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Calling ``get_backend`` in non-distributed environment will return
|
||||||
|
None.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group (ProcessGroup, optional): The process group to work on. The
|
||||||
|
default is the general main process group. If another specific
|
||||||
|
group is specified, the calling process must be part of
|
||||||
|
:attr:`group`. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str or None: Return the backend of the given process group as a lower
|
||||||
|
case string if in distributed environment, otherwise None.
|
||||||
|
"""
|
||||||
|
if is_distributed():
|
||||||
|
# handle low versions of torch like 1.5.0 which does not support
|
||||||
|
# passing in None for group argument
|
||||||
|
if group is None:
|
||||||
|
group = get_default_group()
|
||||||
|
return dist.get_backend(group)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_world_size(group: Optional[dist.ProcessGroup] = None) -> int:
|
||||||
|
"""Return the number of the given process group.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Calling ``get_world_size`` in non-distributed environment will return
|
||||||
|
1.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group (ProcessGroup, optional): The process group to work on. If None,
|
||||||
|
the default process group will be used. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Return the number of processes of the given process group if in
|
||||||
|
distributed environment, otherwise 1.
|
||||||
|
"""
|
||||||
|
if is_distributed():
|
||||||
|
# handle low versions of torch like 1.5.0 which does not support
|
||||||
|
# passing in None for group argument
|
||||||
|
if group is None:
|
||||||
|
group = get_default_group()
|
||||||
|
return dist.get_world_size(group)
|
||||||
|
else:
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
def get_rank(group: Optional[dist.ProcessGroup] = None) -> int:
|
||||||
|
"""Return the rank of the given process group.
|
||||||
|
|
||||||
|
Rank is a unique identifier assigned to each process within a distributed
|
||||||
|
process group. They are always consecutive integers ranging from 0 to
|
||||||
|
``world_size``.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Calling ``get_rank`` in non-distributed environment will return 0.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group (ProcessGroup, optional): The process group to work on. If None,
|
||||||
|
the default process group will be used. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Return the rank of the process group if in distributed
|
||||||
|
environment, otherwise 0.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if is_distributed():
|
||||||
|
# handle low versions of torch like 1.5.0 which does not support
|
||||||
|
# passing in None for group argument
|
||||||
|
if group is None:
|
||||||
|
group = get_default_group()
|
||||||
|
return dist.get_rank(group)
|
||||||
|
else:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def get_local_size() -> int:
|
||||||
|
"""Return the number of the current node.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Return the number of processes in the current node if in
|
||||||
|
distributed environment, otherwise 1.
|
||||||
|
"""
|
||||||
|
if not is_distributed():
|
||||||
|
return 1
|
||||||
|
|
||||||
|
if _LOCAL_PROCESS_GROUP is None:
|
||||||
|
raise RuntimeError('Local process group is not created, please use '
|
||||||
|
'`init_local_group` to setup local process group.')
|
||||||
|
|
||||||
|
return dist.get_world_size(_LOCAL_PROCESS_GROUP)
|
||||||
|
|
||||||
|
|
||||||
|
def get_local_rank() -> int:
|
||||||
|
"""Return the rank of current process in the current node.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Return the rank of current process in the current node if in
|
||||||
|
distributed environment, otherwise 0
|
||||||
|
"""
|
||||||
|
if not is_distributed():
|
||||||
|
return 0
|
||||||
|
|
||||||
|
if _LOCAL_PROCESS_GROUP is None:
|
||||||
|
raise RuntimeError('Local process group is not created, please use '
|
||||||
|
'`init_local_group` to setup local process group.')
|
||||||
|
|
||||||
|
return dist.get_rank(_LOCAL_PROCESS_GROUP)
|
||||||
|
|
||||||
|
|
||||||
|
def get_dist_info(
|
||||||
|
group: Optional[dist.ProcessGroup] = None) -> Tuple[int, int]:
|
||||||
|
"""Get distributed information of the given process group.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Calling ``get_dist_info`` in non-distributed environment will return
|
||||||
|
(0, 1).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group (ProcessGroup, optional): The process group to work on. If None,
|
||||||
|
the default process group will be used. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[int, int]: Return a tuple containing the ``world_size`` and
|
||||||
|
``rank``.
|
||||||
|
"""
|
||||||
|
world_size = get_world_size(group)
|
||||||
|
rank = get_rank(group)
|
||||||
|
return rank, world_size
|
||||||
|
|
||||||
|
|
||||||
|
def is_main_process(group: Optional[dist.ProcessGroup] = None) -> bool:
|
||||||
|
"""Whether the current rank of the given process group is equal to 0.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group (ProcessGroup, optional): The process group to work on. If None,
|
||||||
|
the default process group will be used. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: Return True if the current rank of the given process group is
|
||||||
|
equal to 0, otherwise False.
|
||||||
|
"""
|
||||||
|
return get_rank(group) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def master_only(func: Callable) -> Callable:
|
||||||
|
"""Decorate those methods which should be executed in master process.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func (callable): Function to be decorated.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
callable: Return decorated function.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@functools.wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
if is_main_process():
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def barrier(group: Optional[dist.ProcessGroup] = None) -> None:
|
||||||
|
"""Synchronize all processes from the given process group.
|
||||||
|
|
||||||
|
This collective blocks processes until the whole group enters this
|
||||||
|
function.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Calling ``barrier`` in non-distributed environment will do nothing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group (ProcessGroup, optional): The process group to work on. If None,
|
||||||
|
the default process group will be used. Defaults to None.
|
||||||
|
"""
|
||||||
|
if is_distributed():
|
||||||
|
# handle low versions of torch like 1.5.0 which does not support
|
||||||
|
# passing in None for group argument
|
||||||
|
if group is None:
|
||||||
|
group = get_default_group()
|
||||||
|
dist.barrier(group)
|
376
tests/test_dist/test_dist.py
Normal file
376
tests/test_dist/test_dist.py
Normal file
@ -0,0 +1,376 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import os
|
||||||
|
import os.path as osp
|
||||||
|
import tempfile
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
|
import mmengine.dist as dist
|
||||||
|
from mmengine.dist.dist import sync_random_seed
|
||||||
|
from mmengine.utils import TORCH_VERSION, digit_version
|
||||||
|
|
||||||
|
|
||||||
|
def _test_all_reduce_non_dist():
|
||||||
|
data = torch.arange(2, dtype=torch.int64)
|
||||||
|
expected = torch.arange(2, dtype=torch.int64)
|
||||||
|
dist.all_reduce(data)
|
||||||
|
assert torch.allclose(data, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def _test_all_gather_non_dist():
|
||||||
|
data = torch.arange(2, dtype=torch.int64)
|
||||||
|
expected = torch.arange(2, dtype=torch.int64)
|
||||||
|
output = dist.all_gather(data)
|
||||||
|
assert torch.allclose(output[0], expected)
|
||||||
|
|
||||||
|
|
||||||
|
def _test_gather_non_dist():
|
||||||
|
data = torch.arange(2, dtype=torch.int64)
|
||||||
|
expected = torch.arange(2, dtype=torch.int64)
|
||||||
|
output = dist.gather(data)
|
||||||
|
assert torch.allclose(output[0], expected)
|
||||||
|
|
||||||
|
|
||||||
|
def _test_broadcast_non_dist():
|
||||||
|
data = torch.arange(2, dtype=torch.int64)
|
||||||
|
expected = torch.arange(2, dtype=torch.int64)
|
||||||
|
dist.broadcast(data)
|
||||||
|
assert torch.allclose(data, expected)
|
||||||
|
|
||||||
|
|
||||||
|
@patch('numpy.random.randint', return_value=10)
|
||||||
|
def _test_sync_random_seed_no_dist(mock):
|
||||||
|
assert sync_random_seed() == 10
|
||||||
|
|
||||||
|
|
||||||
|
def _test_broadcast_object_list_no_dist():
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
# input should be list of object
|
||||||
|
dist.broadcast_object_list('foo')
|
||||||
|
|
||||||
|
data = ['foo', 12, {1: 2}]
|
||||||
|
expected = ['foo', 12, {1: 2}]
|
||||||
|
dist.broadcast_object_list(data)
|
||||||
|
assert data == expected
|
||||||
|
|
||||||
|
|
||||||
|
def _test_all_reduce_dict_no_dist():
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
# input should be dict
|
||||||
|
dist.all_reduce_dict('foo')
|
||||||
|
|
||||||
|
data = {
|
||||||
|
'key1': torch.arange(2, dtype=torch.int64),
|
||||||
|
'key2': torch.arange(3, dtype=torch.int64)
|
||||||
|
}
|
||||||
|
expected = {
|
||||||
|
'key1': torch.arange(2, dtype=torch.int64),
|
||||||
|
'key2': torch.arange(3, dtype=torch.int64)
|
||||||
|
}
|
||||||
|
dist.all_reduce_dict(data)
|
||||||
|
for key in data:
|
||||||
|
assert torch.allclose(data[key], expected[key])
|
||||||
|
|
||||||
|
|
||||||
|
def _test_all_gather_object_no_dist():
|
||||||
|
data = 'foo'
|
||||||
|
expected = 'foo'
|
||||||
|
gather_objects = dist.all_gather_object(data)
|
||||||
|
assert gather_objects[0] == expected
|
||||||
|
|
||||||
|
|
||||||
|
def _test_gather_object_no_dist():
|
||||||
|
data = 'foo'
|
||||||
|
expected = 'foo'
|
||||||
|
gather_objects = dist.gather_object(data)
|
||||||
|
assert gather_objects[0] == expected
|
||||||
|
|
||||||
|
|
||||||
|
def _test_collect_results_non_dist():
|
||||||
|
data = ['foo', {1: 2}]
|
||||||
|
size = 2
|
||||||
|
expected = ['foo', {1: 2}]
|
||||||
|
|
||||||
|
# test `device=cpu`
|
||||||
|
output = dist.collect_results(data, size, device='cpu')
|
||||||
|
assert output == expected
|
||||||
|
|
||||||
|
# test `device=gpu`
|
||||||
|
output = dist.collect_results(data, size, device='cpu')
|
||||||
|
assert output == expected
|
||||||
|
|
||||||
|
|
||||||
|
def init_process(rank, world_size, functions, backend='gloo'):
|
||||||
|
"""Initialize the distributed environment."""
|
||||||
|
os.environ['MASTER_ADDR'] = '127.0.0.1'
|
||||||
|
os.environ['MASTER_PORT'] = '29505'
|
||||||
|
os.environ['RANK'] = str(rank)
|
||||||
|
dist.init_dist('pytorch', backend, rank=rank, world_size=world_size)
|
||||||
|
|
||||||
|
device = 'cpu' if backend == 'gloo' else 'cuda'
|
||||||
|
|
||||||
|
for func in functions:
|
||||||
|
func(device)
|
||||||
|
|
||||||
|
|
||||||
|
def main(functions, world_size=2, backend='gloo'):
|
||||||
|
try:
|
||||||
|
mp.spawn(
|
||||||
|
init_process,
|
||||||
|
args=(world_size, functions, backend),
|
||||||
|
nprocs=world_size)
|
||||||
|
except Exception:
|
||||||
|
pytest.fail(f'{backend} failed')
|
||||||
|
|
||||||
|
|
||||||
|
def _test_all_reduce_dist(device):
|
||||||
|
for tensor_type, reduce_op in zip([torch.int64, torch.float32],
|
||||||
|
['sum', 'mean']):
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
data = torch.tensor([1, 2], dtype=tensor_type).to(device)
|
||||||
|
else:
|
||||||
|
data = torch.tensor([3, 4], dtype=tensor_type).to(device)
|
||||||
|
|
||||||
|
if reduce_op == 'sum':
|
||||||
|
expected = torch.tensor([4, 6], dtype=tensor_type).to(device)
|
||||||
|
else:
|
||||||
|
expected = torch.tensor([2, 3], dtype=tensor_type).to(device)
|
||||||
|
|
||||||
|
dist.all_reduce(data, reduce_op)
|
||||||
|
assert torch.allclose(data, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def _test_all_gather_dist(device):
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
data = torch.tensor([0, 1]).to(device)
|
||||||
|
else:
|
||||||
|
data = torch.tensor([1, 2]).to(device)
|
||||||
|
|
||||||
|
expected = [
|
||||||
|
torch.tensor([0, 1]).to(device),
|
||||||
|
torch.tensor([1, 2]).to(device)
|
||||||
|
]
|
||||||
|
|
||||||
|
output = dist.all_gather(data)
|
||||||
|
assert torch.allclose(output[dist.get_rank()], expected[dist.get_rank()])
|
||||||
|
|
||||||
|
|
||||||
|
def _test_gather_dist(device):
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
data = torch.tensor([0, 1]).to(device)
|
||||||
|
else:
|
||||||
|
data = torch.tensor([1, 2]).to(device)
|
||||||
|
|
||||||
|
output = dist.gather(data)
|
||||||
|
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
expected = [
|
||||||
|
torch.tensor([0, 1]).to(device),
|
||||||
|
torch.tensor([1, 2]).to(device)
|
||||||
|
]
|
||||||
|
for i in range(2):
|
||||||
|
assert torch.allclose(output[i], expected[i])
|
||||||
|
else:
|
||||||
|
assert output == []
|
||||||
|
|
||||||
|
|
||||||
|
def _test_broadcast_dist(device):
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
data = torch.tensor([0, 1]).to(device)
|
||||||
|
else:
|
||||||
|
data = torch.tensor([1, 2]).to(device)
|
||||||
|
|
||||||
|
expected = torch.tensor([0, 1]).to(device)
|
||||||
|
dist.broadcast(data, 0)
|
||||||
|
assert torch.allclose(data, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def _test_sync_random_seed_dist(device):
|
||||||
|
with patch.object(
|
||||||
|
torch, 'tensor',
|
||||||
|
return_value=torch.tensor(1024).to(device)) as mock_tensor:
|
||||||
|
output = dist.sync_random_seed()
|
||||||
|
assert output == 1024
|
||||||
|
mock_tensor.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
def _test_broadcast_object_list_dist(device):
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
data = ['foo', 12, {1: 2}]
|
||||||
|
else:
|
||||||
|
data = [None, None, None]
|
||||||
|
|
||||||
|
expected = ['foo', 12, {1: 2}]
|
||||||
|
|
||||||
|
dist.broadcast_object_list(data)
|
||||||
|
|
||||||
|
assert data == expected
|
||||||
|
|
||||||
|
|
||||||
|
def _test_all_reduce_dict_dist(device):
|
||||||
|
for tensor_type, reduce_op in zip([torch.int64, torch.float32],
|
||||||
|
['sum', 'mean']):
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
data = {
|
||||||
|
'key1': torch.tensor([0, 1], dtype=tensor_type).to(device),
|
||||||
|
'key2': torch.tensor([1, 2], dtype=tensor_type).to(device)
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
data = {
|
||||||
|
'key1': torch.tensor([2, 3], dtype=tensor_type).to(device),
|
||||||
|
'key2': torch.tensor([3, 4], dtype=tensor_type).to(device)
|
||||||
|
}
|
||||||
|
|
||||||
|
if reduce_op == 'sum':
|
||||||
|
expected = {
|
||||||
|
'key1': torch.tensor([2, 4], dtype=tensor_type).to(device),
|
||||||
|
'key2': torch.tensor([4, 6], dtype=tensor_type).to(device)
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
expected = {
|
||||||
|
'key1': torch.tensor([1, 2], dtype=tensor_type).to(device),
|
||||||
|
'key2': torch.tensor([2, 3], dtype=tensor_type).to(device)
|
||||||
|
}
|
||||||
|
|
||||||
|
dist.all_reduce_dict(data, reduce_op)
|
||||||
|
|
||||||
|
for key in data:
|
||||||
|
assert torch.allclose(data[key], expected[key])
|
||||||
|
|
||||||
|
# `torch.cat` in torch1.5 can not concatenate different types so we
|
||||||
|
# fallback to convert them all to float type.
|
||||||
|
if digit_version(TORCH_VERSION) == digit_version('1.5.0'):
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
data = {
|
||||||
|
'key1': torch.tensor([0, 1], dtype=torch.float32).to(device),
|
||||||
|
'key2': torch.tensor([1, 2], dtype=torch.int32).to(device)
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
data = {
|
||||||
|
'key1': torch.tensor([2, 3], dtype=torch.float32).to(device),
|
||||||
|
'key2': torch.tensor([3, 4], dtype=torch.int32).to(device)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
'key1': torch.tensor([2, 4], dtype=torch.float32).to(device),
|
||||||
|
'key2': torch.tensor([4, 6], dtype=torch.float32).to(device)
|
||||||
|
}
|
||||||
|
|
||||||
|
dist.all_reduce_dict(data, 'sum')
|
||||||
|
|
||||||
|
for key in data:
|
||||||
|
assert torch.allclose(data[key], expected[key])
|
||||||
|
|
||||||
|
|
||||||
|
def _test_all_gather_object_dist(device):
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
data = 'foo'
|
||||||
|
else:
|
||||||
|
data = {1: 2}
|
||||||
|
|
||||||
|
expected = ['foo', {1: 2}]
|
||||||
|
output = dist.all_gather_object(data)
|
||||||
|
|
||||||
|
assert output == expected
|
||||||
|
|
||||||
|
|
||||||
|
def _test_gather_object_dist(device):
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
data = 'foo'
|
||||||
|
else:
|
||||||
|
data = {1: 2}
|
||||||
|
|
||||||
|
output = dist.gather_object(data, dst=0)
|
||||||
|
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
assert output == ['foo', {1: 2}]
|
||||||
|
else:
|
||||||
|
assert output is None
|
||||||
|
|
||||||
|
|
||||||
|
def _test_collect_results_dist(device):
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
data = ['foo', {1: 2}]
|
||||||
|
else:
|
||||||
|
data = [24, {'a': 'b'}]
|
||||||
|
|
||||||
|
size = 4
|
||||||
|
|
||||||
|
expected = ['foo', 24, {1: 2}, {'a': 'b'}]
|
||||||
|
|
||||||
|
# test `device=cpu`
|
||||||
|
output = dist.collect_results(data, size, device='cpu')
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
assert output == expected
|
||||||
|
else:
|
||||||
|
assert output is None
|
||||||
|
|
||||||
|
# test `device=cpu` and `tmpdir is not None`
|
||||||
|
tmpdir = tempfile.mkdtemp()
|
||||||
|
# broadcast tmpdir to all ranks to make it consistent
|
||||||
|
object_list = [tmpdir]
|
||||||
|
dist.broadcast_object_list(object_list)
|
||||||
|
output = dist.collect_results(
|
||||||
|
data, size, device='cpu', tmpdir=object_list[0])
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
assert output == expected
|
||||||
|
else:
|
||||||
|
assert output is None
|
||||||
|
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
# object_list[0] will be removed by `dist.collect_results`
|
||||||
|
assert not osp.exists(object_list[0])
|
||||||
|
|
||||||
|
# test `device=gpu`
|
||||||
|
output = dist.collect_results(data, size, device='gpu')
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
assert output == expected
|
||||||
|
else:
|
||||||
|
assert output is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_distributed_env():
|
||||||
|
_test_all_reduce_non_dist()
|
||||||
|
_test_all_gather_non_dist()
|
||||||
|
_test_gather_non_dist()
|
||||||
|
_test_broadcast_non_dist()
|
||||||
|
_test_sync_random_seed_no_dist()
|
||||||
|
_test_broadcast_object_list_no_dist()
|
||||||
|
_test_all_reduce_dict_no_dist()
|
||||||
|
_test_all_gather_object_no_dist()
|
||||||
|
_test_gather_object_no_dist()
|
||||||
|
_test_collect_results_non_dist()
|
||||||
|
|
||||||
|
|
||||||
|
def test_gloo_backend():
|
||||||
|
functions_to_test = [
|
||||||
|
_test_all_reduce_dist,
|
||||||
|
_test_all_gather_dist,
|
||||||
|
_test_gather_dist,
|
||||||
|
_test_broadcast_dist,
|
||||||
|
_test_sync_random_seed_dist,
|
||||||
|
_test_broadcast_object_list_dist,
|
||||||
|
_test_all_reduce_dict_dist,
|
||||||
|
_test_all_gather_object_dist,
|
||||||
|
_test_gather_object_dist,
|
||||||
|
]
|
||||||
|
main(functions_to_test, backend='gloo')
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl')
|
||||||
|
def test_nccl_backend():
|
||||||
|
functions_to_test = [
|
||||||
|
_test_all_reduce_dist,
|
||||||
|
_test_all_gather_dist,
|
||||||
|
_test_broadcast_dist,
|
||||||
|
_test_sync_random_seed_dist,
|
||||||
|
_test_broadcast_object_list_dist,
|
||||||
|
_test_all_reduce_dict_dist,
|
||||||
|
_test_all_gather_object_dist,
|
||||||
|
_test_collect_results_dist,
|
||||||
|
]
|
||||||
|
main(functions_to_test, backend='nccl')
|
152
tests/test_dist/test_utils.py
Normal file
152
tests/test_dist/test_utils.py
Normal file
@ -0,0 +1,152 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.distributed as torch_dist
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
|
import mmengine.dist as dist
|
||||||
|
|
||||||
|
|
||||||
|
def _test_get_backend_non_dist():
|
||||||
|
assert dist.get_backend() is None
|
||||||
|
|
||||||
|
|
||||||
|
def _test_get_world_size_non_dist():
|
||||||
|
assert dist.get_world_size() == 1
|
||||||
|
|
||||||
|
|
||||||
|
def _test_get_rank_non_dist():
|
||||||
|
assert dist.get_rank() == 0
|
||||||
|
|
||||||
|
|
||||||
|
def _test_local_size_non_dist():
|
||||||
|
assert dist.get_local_size() == 1
|
||||||
|
|
||||||
|
|
||||||
|
def _test_local_rank_non_dist():
|
||||||
|
assert dist.get_local_rank() == 0
|
||||||
|
|
||||||
|
|
||||||
|
def _test_get_dist_info_non_dist():
|
||||||
|
assert dist.get_dist_info() == (0, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def _test_is_main_process_non_dist():
|
||||||
|
assert dist.is_main_process()
|
||||||
|
|
||||||
|
|
||||||
|
def _test_master_only_non_dist():
|
||||||
|
|
||||||
|
@dist.master_only
|
||||||
|
def fun():
|
||||||
|
assert dist.get_rank() == 0
|
||||||
|
|
||||||
|
fun()
|
||||||
|
|
||||||
|
|
||||||
|
def _test_barrier_non_dist():
|
||||||
|
dist.barrier() # nothing is done
|
||||||
|
|
||||||
|
|
||||||
|
def init_process(rank, world_size, functions, backend='gloo'):
|
||||||
|
"""Initialize the distributed environment."""
|
||||||
|
os.environ['MASTER_ADDR'] = '127.0.0.1'
|
||||||
|
os.environ['MASTER_PORT'] = '29501'
|
||||||
|
os.environ['RANK'] = str(rank)
|
||||||
|
dist.init_dist('pytorch', backend, rank=rank, world_size=world_size)
|
||||||
|
dist.init_local_group(0, world_size)
|
||||||
|
|
||||||
|
for func in functions:
|
||||||
|
func()
|
||||||
|
|
||||||
|
|
||||||
|
def main(functions, world_size=2, backend='gloo'):
|
||||||
|
try:
|
||||||
|
mp.spawn(
|
||||||
|
init_process,
|
||||||
|
args=(world_size, functions, backend),
|
||||||
|
nprocs=world_size)
|
||||||
|
except Exception:
|
||||||
|
pytest.fail('error')
|
||||||
|
|
||||||
|
|
||||||
|
def _test_get_backend_dist():
|
||||||
|
assert dist.get_backend() == torch_dist.get_backend()
|
||||||
|
|
||||||
|
|
||||||
|
def _test_get_world_size_dist():
|
||||||
|
assert dist.get_world_size() == 2
|
||||||
|
|
||||||
|
|
||||||
|
def _test_get_rank_dist():
|
||||||
|
if torch_dist.get_rank() == 0:
|
||||||
|
assert dist.get_rank() == 0
|
||||||
|
else:
|
||||||
|
assert dist.get_rank() == 1
|
||||||
|
|
||||||
|
|
||||||
|
def _test_local_size_dist():
|
||||||
|
assert dist.get_local_size() == 2
|
||||||
|
|
||||||
|
|
||||||
|
def _test_local_rank_dist():
|
||||||
|
torch_dist.get_rank(dist.get_local_group()) == dist.get_local_rank()
|
||||||
|
|
||||||
|
|
||||||
|
def _test_get_dist_info_dist():
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
assert dist.get_dist_info() == (0, 2)
|
||||||
|
else:
|
||||||
|
assert dist.get_dist_info() == (1, 2)
|
||||||
|
|
||||||
|
|
||||||
|
def _test_is_main_process_dist():
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
assert dist.is_main_process()
|
||||||
|
else:
|
||||||
|
assert not dist.is_main_process()
|
||||||
|
|
||||||
|
|
||||||
|
def _test_master_only_dist():
|
||||||
|
|
||||||
|
@dist.master_only
|
||||||
|
def fun():
|
||||||
|
assert dist.get_rank() == 0
|
||||||
|
|
||||||
|
fun()
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_distributed_env():
|
||||||
|
_test_get_backend_non_dist()
|
||||||
|
_test_get_world_size_non_dist()
|
||||||
|
_test_get_rank_non_dist()
|
||||||
|
_test_local_size_non_dist()
|
||||||
|
_test_local_rank_non_dist()
|
||||||
|
_test_get_dist_info_non_dist()
|
||||||
|
_test_is_main_process_non_dist()
|
||||||
|
_test_master_only_non_dist()
|
||||||
|
_test_barrier_non_dist()
|
||||||
|
|
||||||
|
|
||||||
|
functions_to_test = [
|
||||||
|
_test_get_backend_dist,
|
||||||
|
_test_get_world_size_dist,
|
||||||
|
_test_get_rank_dist,
|
||||||
|
_test_local_size_dist,
|
||||||
|
_test_local_rank_dist,
|
||||||
|
_test_get_dist_info_dist,
|
||||||
|
_test_is_main_process_dist,
|
||||||
|
_test_master_only_dist,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_gloo_backend():
|
||||||
|
main(functions_to_test)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl')
|
||||||
|
def test_nccl_backend():
|
||||||
|
main(functions_to_test, backend='nccl')
|
Loading…
x
Reference in New Issue
Block a user