[Feature] Add distributed module (#59)
* [Feature] Add distributed module
* fix IS_DIST error
* all_reduce_dict does operations in-place
* support 'mean' operation
* provide local process group
* add tmpdir argument for collect_results
* add unit tests
* refactor unit tests
* simplify steps to create multiple processes
* describe the difference of *gather* in mmengine and pytorch (see the sketch below)
* add unit tests for nccl
* test nccl backend on multiple gpus
* add get_default_group function to handle different torch versions
* handle torch1.5
* minor fixes
* fix typo
* refactor unit tests
* nccl does not support gather and gather_object
* fix gather
* fix collect_results_cpu
* fix collect_results and refactor unit tests
* fix collect_results unit tests
* handle torch.cat in torch1.5
* refine docstring
* fix comments
parent 817eb89ac2
commit c6a8d72c5e
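One item above notes the difference between *gather* in mmengine and in PyTorch. Below is a minimal sketch of that difference, grounded in the unit tests later in this diff; the two-process setup and tensor values are illustrative assumptions, and an already-initialized process group is presumed.

# Illustrative sketch only: contrasts mmengine's dist.gather with
# torch.distributed.gather. Assumes an initialized 2-process group.
import torch
import torch.distributed as torch_dist

import mmengine.dist as dist


def compare_gather(device='cpu'):
    rank = dist.get_rank()
    data = torch.tensor([rank, rank + 1], device=device)

    # mmengine.dist.gather returns the gathered tensors as a list on the
    # destination rank and an empty list on every other rank.
    gathered = dist.gather(data, dst=0)
    if rank == 0:
        assert len(gathered) == dist.get_world_size()
    else:
        assert gathered == []

    # torch.distributed.gather instead fills a caller-provided gather_list
    # in place on the destination rank and returns None.
    gather_list = None
    if rank == 0:
        gather_list = [
            torch.empty_like(data) for _ in range(dist.get_world_size())
        ]
    torch_dist.gather(data, gather_list=gather_list, dst=0)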
@@ -10,7 +10,6 @@ __pycache__/
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
@@ -7,3 +7,8 @@ Data
--------
.. automodule:: mmengine.data
    :members:

Distributed
-----------
.. automodule:: mmengine.dist
    :members:
@@ -7,3 +7,8 @@ Data
--------
.. automodule:: mmengine.data
    :members:

Distributed
-----------
.. automodule:: mmengine.dist
    :members:
@@ -0,0 +1,19 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dist import (all_gather_object, all_reduce, all_gather, all_reduce_dict,
                   collect_results, gather, broadcast, gather_object,
                   sync_random_seed, broadcast_object_list,
                   collect_results_cpu, collect_results_gpu)
from .utils import (get_dist_info, init_dist, init_local_group, get_backend,
                    get_world_size, get_rank, get_local_size, get_local_rank,
                    is_main_process, master_only, barrier, get_local_group,
                    is_distributed, get_default_group)

__all__ = [
    'all_gather_object', 'all_reduce', 'all_gather', 'all_reduce_dict',
    'collect_results', 'collect_results_cpu', 'collect_results_gpu', 'gather',
    'broadcast', 'gather_object', 'sync_random_seed', 'broadcast_object_list',
    'get_dist_info', 'init_dist', 'init_local_group', 'get_backend',
    'get_world_size', 'get_rank', 'get_local_size', 'get_local_group',
    'get_local_rank', 'is_main_process', 'master_only', 'barrier',
    'is_distributed', 'get_default_group'
]
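For orientation, a brief usage sketch of the interface exported above. It relies only on the non-distributed fallback behaviour that the unit tests later in this diff verify (every collective becomes a single-process no-op); the tensor values are illustrative.

# Sketch: mmengine.dist in a plain, non-distributed Python process.
import torch

import mmengine.dist as dist

data = torch.arange(2)
dist.all_reduce(data)              # no-op: data is left unchanged
assert torch.equal(data, torch.arange(2))
assert dist.get_world_size() == 1  # single-process fallbacks
assert dist.get_rank() == 0
assert dist.is_main_process()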
File diff suppressed because it is too large
@@ -0,0 +1,335 @@
# Copyright (c) OpenMMLab. All rights reserved.
import functools
import os
import subprocess
from typing import Callable, Optional, Tuple

import torch
import torch.multiprocessing as mp
from torch import distributed as dist

_LOCAL_PROCESS_GROUP = None


def is_distributed() -> bool:
    """Return True if distributed environment has been initialized."""
    return dist.is_available() and dist.is_initialized()


def get_local_group() -> Optional[dist.ProcessGroup]:
    """Return local process group."""
    if not is_distributed():
        return None

    if _LOCAL_PROCESS_GROUP is None:
        raise RuntimeError('Local process group is not created, please use '
                           '`init_local_group` to setup local process group.')

    return _LOCAL_PROCESS_GROUP


def get_default_group() -> Optional[dist.ProcessGroup]:
    """Return default process group."""

    return dist.distributed_c10d._get_default_group()


def init_dist(launcher, backend='nccl', **kwargs) -> None:
    """Initialize distributed environment.

    Args:
        launcher (str): Way to launch multiple processes. Supported launchers
            are 'pytorch', 'mpi' and 'slurm'.
        backend (str): Communication backend. Supported backends are 'nccl',
            'gloo' and 'mpi'. Defaults to 'nccl'.
        **kwargs: Keyword arguments are passed to ``init_process_group``.
    """
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
    elif launcher == 'mpi':
        _init_dist_mpi(backend, **kwargs)
    elif launcher == 'slurm':
        _init_dist_slurm(backend, **kwargs)
    else:
        raise ValueError(f'Invalid launcher type: {launcher}')


def _init_dist_pytorch(backend, **kwargs) -> None:
    """Initialize distributed environment with PyTorch launcher.

    Args:
        backend (str): Backend of torch.distributed. Supported backends are
            'nccl', 'gloo' and 'mpi'. Defaults to 'nccl'.
        **kwargs: Keyword arguments are passed to ``init_process_group``.
    """
    # TODO: use local_rank instead of rank % num_gpus
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)


def _init_dist_mpi(backend, **kwargs) -> None:
    """Initialize distributed environment with MPI launcher.

    Args:
        backend (str): Backend of torch.distributed. Supported backends are
            'nccl', 'gloo' and 'mpi'. Defaults to 'nccl'.
        **kwargs: Keyword arguments are passed to ``init_process_group``.
    """
    # TODO: use local_rank instead of rank % num_gpus
    rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)


def _init_dist_slurm(backend, port=None) -> None:
    """Initialize slurm distributed training environment.

    If argument ``port`` is not specified, then the master port will be system
    environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
    environment variable, then a default port ``29500`` will be used.

    Args:
        backend (str): Backend of torch.distributed.
        port (int, optional): Master port. Defaults to None.

    TODO: https://github.com/open-mmlab/mmcv/pull/1682
    """
    proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(proc_id % num_gpus)
    addr = subprocess.getoutput(
        f'scontrol show hostname {node_list} | head -n1')
    # specify master port
    if port is not None:
        os.environ['MASTER_PORT'] = str(port)
    elif 'MASTER_PORT' in os.environ:
        pass  # use MASTER_PORT in the environment variable
    else:
        # 29500 is torch.distributed default port
        os.environ['MASTER_PORT'] = '29500'
    # use MASTER_ADDR in the environment variable if it already exists
    if 'MASTER_ADDR' not in os.environ:
        os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
    os.environ['RANK'] = str(proc_id)
    dist.init_process_group(backend=backend)


def init_local_group(node_rank: int, num_gpus_per_node: int):
    """Set up the local process group.

    Set up a process group which only includes processes that are on the same
    machine as the current process.

    The code is modified from
    https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py

    Args:
        node_rank (int): Rank of machines used for training.
        num_gpus_per_node (int): Number of gpus used for training in a single
            machine.
    """  # noqa: W501
    global _LOCAL_PROCESS_GROUP
    assert _LOCAL_PROCESS_GROUP is None

    ranks = list(
        range(node_rank * num_gpus_per_node,
              (node_rank + 1) * num_gpus_per_node))
    _LOCAL_PROCESS_GROUP = dist.new_group(ranks)


def get_backend(group: Optional[dist.ProcessGroup] = None) -> Optional[str]:
    """Return the backend of the given process group.

    Note:
        Calling ``get_backend`` in non-distributed environment will return
        None.

    Args:
        group (ProcessGroup, optional): The process group to work on. The
            default is the general main process group. If another specific
            group is specified, the calling process must be part of
            :attr:`group`. Defaults to None.

    Returns:
        str or None: Return the backend of the given process group as a lower
        case string if in distributed environment, otherwise None.
    """
    if is_distributed():
        # handle low versions of torch like 1.5.0 which does not support
        # passing in None for group argument
        if group is None:
            group = get_default_group()
        return dist.get_backend(group)
    else:
        return None


def get_world_size(group: Optional[dist.ProcessGroup] = None) -> int:
    """Return the number of processes in the given process group.

    Note:
        Calling ``get_world_size`` in non-distributed environment will return
        1.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.

    Returns:
        int: Return the number of processes of the given process group if in
        distributed environment, otherwise 1.
    """
    if is_distributed():
        # handle low versions of torch like 1.5.0 which does not support
        # passing in None for group argument
        if group is None:
            group = get_default_group()
        return dist.get_world_size(group)
    else:
        return 1


def get_rank(group: Optional[dist.ProcessGroup] = None) -> int:
    """Return the rank of the given process group.

    Rank is a unique identifier assigned to each process within a distributed
    process group. They are always consecutive integers ranging from 0 to
    ``world_size - 1``.

    Note:
        Calling ``get_rank`` in non-distributed environment will return 0.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.

    Returns:
        int: Return the rank of the process group if in distributed
        environment, otherwise 0.
    """

    if is_distributed():
        # handle low versions of torch like 1.5.0 which does not support
        # passing in None for group argument
        if group is None:
            group = get_default_group()
        return dist.get_rank(group)
    else:
        return 0


def get_local_size() -> int:
    """Return the number of processes in the current node.

    Returns:
        int: Return the number of processes in the current node if in
        distributed environment, otherwise 1.
    """
    if not is_distributed():
        return 1

    if _LOCAL_PROCESS_GROUP is None:
        raise RuntimeError('Local process group is not created, please use '
                           '`init_local_group` to setup local process group.')

    return dist.get_world_size(_LOCAL_PROCESS_GROUP)


def get_local_rank() -> int:
    """Return the rank of the current process in the current node.

    Returns:
        int: Return the rank of the current process in the current node if in
        distributed environment, otherwise 0.
    """
    if not is_distributed():
        return 0

    if _LOCAL_PROCESS_GROUP is None:
        raise RuntimeError('Local process group is not created, please use '
                           '`init_local_group` to setup local process group.')

    return dist.get_rank(_LOCAL_PROCESS_GROUP)


def get_dist_info(
        group: Optional[dist.ProcessGroup] = None) -> Tuple[int, int]:
    """Get distributed information of the given process group.

    Note:
        Calling ``get_dist_info`` in non-distributed environment will return
        (0, 1).

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.

    Returns:
        tuple[int, int]: Return a tuple containing the ``rank`` and
        ``world_size``.
    """
    world_size = get_world_size(group)
    rank = get_rank(group)
    return rank, world_size


def is_main_process(group: Optional[dist.ProcessGroup] = None) -> bool:
    """Whether the current rank of the given process group is equal to 0.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.

    Returns:
        bool: Return True if the current rank of the given process group is
        equal to 0, otherwise False.
    """
    return get_rank(group) == 0


def master_only(func: Callable) -> Callable:
    """Decorate those methods which should be executed in the master process.

    Args:
        func (callable): Function to be decorated.

    Returns:
        callable: Return decorated function.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if is_main_process():
            return func(*args, **kwargs)

    return wrapper


def barrier(group: Optional[dist.ProcessGroup] = None) -> None:
    """Synchronize all processes from the given process group.

    This collective blocks processes until the whole group enters this
    function.

    Note:
        Calling ``barrier`` in non-distributed environment will do nothing.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.
    """
    if is_distributed():
        # handle low versions of torch like 1.5.0 which does not support
        # passing in None for group argument
        if group is None:
            group = get_default_group()
        dist.barrier(group)
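A hedged sketch of how `init_dist` above is meant to be driven by the 'pytorch' launcher. The script name and launch command are assumptions; `_init_dist_pytorch` reads `RANK` from the environment and pins a CUDA device, while `init_process_group` additionally needs `WORLD_SIZE`, `MASTER_ADDR` and `MASTER_PORT`, all of which `torchrun` (or `torch.distributed.launch`) normally exports.

# train.py -- hypothetical entry point; launch with, for example:
#   torchrun --nproc_per_node=2 train.py
# Assumes at least one visible GPU, since _init_dist_pytorch calls
# torch.cuda.set_device(rank % num_gpus).
from mmengine.dist import get_dist_info, init_dist, master_only


@master_only
def log(msg):
    # Executed only on rank 0; on other ranks the wrapper returns None.
    print(msg)


def main():
    init_dist('pytorch', backend='nccl')
    rank, world_size = get_dist_info()
    log(f'initialized {world_size} processes; this is rank {rank}')


if __name__ == '__main__':
    main()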
@@ -0,0 +1,376 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import tempfile
from unittest.mock import patch

import pytest
import torch
import torch.multiprocessing as mp

import mmengine.dist as dist
from mmengine.dist.dist import sync_random_seed
from mmengine.utils import TORCH_VERSION, digit_version


def _test_all_reduce_non_dist():
    data = torch.arange(2, dtype=torch.int64)
    expected = torch.arange(2, dtype=torch.int64)
    dist.all_reduce(data)
    assert torch.allclose(data, expected)


def _test_all_gather_non_dist():
    data = torch.arange(2, dtype=torch.int64)
    expected = torch.arange(2, dtype=torch.int64)
    output = dist.all_gather(data)
    assert torch.allclose(output[0], expected)


def _test_gather_non_dist():
    data = torch.arange(2, dtype=torch.int64)
    expected = torch.arange(2, dtype=torch.int64)
    output = dist.gather(data)
    assert torch.allclose(output[0], expected)


def _test_broadcast_non_dist():
    data = torch.arange(2, dtype=torch.int64)
    expected = torch.arange(2, dtype=torch.int64)
    dist.broadcast(data)
    assert torch.allclose(data, expected)


@patch('numpy.random.randint', return_value=10)
def _test_sync_random_seed_no_dist(mock):
    assert sync_random_seed() == 10


def _test_broadcast_object_list_no_dist():
    with pytest.raises(AssertionError):
        # input should be list of object
        dist.broadcast_object_list('foo')

    data = ['foo', 12, {1: 2}]
    expected = ['foo', 12, {1: 2}]
    dist.broadcast_object_list(data)
    assert data == expected


def _test_all_reduce_dict_no_dist():
    with pytest.raises(AssertionError):
        # input should be dict
        dist.all_reduce_dict('foo')

    data = {
        'key1': torch.arange(2, dtype=torch.int64),
        'key2': torch.arange(3, dtype=torch.int64)
    }
    expected = {
        'key1': torch.arange(2, dtype=torch.int64),
        'key2': torch.arange(3, dtype=torch.int64)
    }
    dist.all_reduce_dict(data)
    for key in data:
        assert torch.allclose(data[key], expected[key])


def _test_all_gather_object_no_dist():
    data = 'foo'
    expected = 'foo'
    gather_objects = dist.all_gather_object(data)
    assert gather_objects[0] == expected


def _test_gather_object_no_dist():
    data = 'foo'
    expected = 'foo'
    gather_objects = dist.gather_object(data)
    assert gather_objects[0] == expected


def _test_collect_results_non_dist():
    data = ['foo', {1: 2}]
    size = 2
    expected = ['foo', {1: 2}]

    # test `device=cpu`
    output = dist.collect_results(data, size, device='cpu')
    assert output == expected

    # test `device=gpu`
    output = dist.collect_results(data, size, device='gpu')
    assert output == expected


def init_process(rank, world_size, functions, backend='gloo'):
    """Initialize the distributed environment."""
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29505'
    os.environ['RANK'] = str(rank)
    dist.init_dist('pytorch', backend, rank=rank, world_size=world_size)

    device = 'cpu' if backend == 'gloo' else 'cuda'

    for func in functions:
        func(device)


def main(functions, world_size=2, backend='gloo'):
    try:
        mp.spawn(
            init_process,
            args=(world_size, functions, backend),
            nprocs=world_size)
    except Exception:
        pytest.fail(f'{backend} failed')


def _test_all_reduce_dist(device):
    for tensor_type, reduce_op in zip([torch.int64, torch.float32],
                                      ['sum', 'mean']):
        if dist.get_rank() == 0:
            data = torch.tensor([1, 2], dtype=tensor_type).to(device)
        else:
            data = torch.tensor([3, 4], dtype=tensor_type).to(device)

        if reduce_op == 'sum':
            expected = torch.tensor([4, 6], dtype=tensor_type).to(device)
        else:
            expected = torch.tensor([2, 3], dtype=tensor_type).to(device)

        dist.all_reduce(data, reduce_op)
        assert torch.allclose(data, expected)


def _test_all_gather_dist(device):
    if dist.get_rank() == 0:
        data = torch.tensor([0, 1]).to(device)
    else:
        data = torch.tensor([1, 2]).to(device)

    expected = [
        torch.tensor([0, 1]).to(device),
        torch.tensor([1, 2]).to(device)
    ]

    output = dist.all_gather(data)
    assert torch.allclose(output[dist.get_rank()], expected[dist.get_rank()])


def _test_gather_dist(device):
    if dist.get_rank() == 0:
        data = torch.tensor([0, 1]).to(device)
    else:
        data = torch.tensor([1, 2]).to(device)

    output = dist.gather(data)

    if dist.get_rank() == 0:
        expected = [
            torch.tensor([0, 1]).to(device),
            torch.tensor([1, 2]).to(device)
        ]
        for i in range(2):
            assert torch.allclose(output[i], expected[i])
    else:
        assert output == []


def _test_broadcast_dist(device):
    if dist.get_rank() == 0:
        data = torch.tensor([0, 1]).to(device)
    else:
        data = torch.tensor([1, 2]).to(device)

    expected = torch.tensor([0, 1]).to(device)
    dist.broadcast(data, 0)
    assert torch.allclose(data, expected)


def _test_sync_random_seed_dist(device):
    with patch.object(
            torch, 'tensor',
            return_value=torch.tensor(1024).to(device)) as mock_tensor:
        output = dist.sync_random_seed()
        assert output == 1024
        mock_tensor.assert_called()


def _test_broadcast_object_list_dist(device):
    if dist.get_rank() == 0:
        data = ['foo', 12, {1: 2}]
    else:
        data = [None, None, None]

    expected = ['foo', 12, {1: 2}]

    dist.broadcast_object_list(data)

    assert data == expected


def _test_all_reduce_dict_dist(device):
    for tensor_type, reduce_op in zip([torch.int64, torch.float32],
                                      ['sum', 'mean']):
        if dist.get_rank() == 0:
            data = {
                'key1': torch.tensor([0, 1], dtype=tensor_type).to(device),
                'key2': torch.tensor([1, 2], dtype=tensor_type).to(device)
            }
        else:
            data = {
                'key1': torch.tensor([2, 3], dtype=tensor_type).to(device),
                'key2': torch.tensor([3, 4], dtype=tensor_type).to(device)
            }

        if reduce_op == 'sum':
            expected = {
                'key1': torch.tensor([2, 4], dtype=tensor_type).to(device),
                'key2': torch.tensor([4, 6], dtype=tensor_type).to(device)
            }
        else:
            expected = {
                'key1': torch.tensor([1, 2], dtype=tensor_type).to(device),
                'key2': torch.tensor([2, 3], dtype=tensor_type).to(device)
            }

        dist.all_reduce_dict(data, reduce_op)

        for key in data:
            assert torch.allclose(data[key], expected[key])

    # `torch.cat` in torch1.5 can not concatenate different types so we
    # fallback to convert them all to float type.
    if digit_version(TORCH_VERSION) == digit_version('1.5.0'):
        if dist.get_rank() == 0:
            data = {
                'key1': torch.tensor([0, 1], dtype=torch.float32).to(device),
                'key2': torch.tensor([1, 2], dtype=torch.int32).to(device)
            }
        else:
            data = {
                'key1': torch.tensor([2, 3], dtype=torch.float32).to(device),
                'key2': torch.tensor([3, 4], dtype=torch.int32).to(device)
            }

        expected = {
            'key1': torch.tensor([2, 4], dtype=torch.float32).to(device),
            'key2': torch.tensor([4, 6], dtype=torch.float32).to(device)
        }

        dist.all_reduce_dict(data, 'sum')

        for key in data:
            assert torch.allclose(data[key], expected[key])


def _test_all_gather_object_dist(device):
    if dist.get_rank() == 0:
        data = 'foo'
    else:
        data = {1: 2}

    expected = ['foo', {1: 2}]
    output = dist.all_gather_object(data)

    assert output == expected


def _test_gather_object_dist(device):
    if dist.get_rank() == 0:
        data = 'foo'
    else:
        data = {1: 2}

    output = dist.gather_object(data, dst=0)

    if dist.get_rank() == 0:
        assert output == ['foo', {1: 2}]
    else:
        assert output is None


def _test_collect_results_dist(device):
    if dist.get_rank() == 0:
        data = ['foo', {1: 2}]
    else:
        data = [24, {'a': 'b'}]

    size = 4

    expected = ['foo', 24, {1: 2}, {'a': 'b'}]

    # test `device=cpu`
    output = dist.collect_results(data, size, device='cpu')
    if dist.get_rank() == 0:
        assert output == expected
    else:
        assert output is None

    # test `device=cpu` and `tmpdir is not None`
    tmpdir = tempfile.mkdtemp()
    # broadcast tmpdir to all ranks to make it consistent
    object_list = [tmpdir]
    dist.broadcast_object_list(object_list)
    output = dist.collect_results(
        data, size, device='cpu', tmpdir=object_list[0])
    if dist.get_rank() == 0:
        assert output == expected
    else:
        assert output is None

    if dist.get_rank() == 0:
        # object_list[0] will be removed by `dist.collect_results`
        assert not osp.exists(object_list[0])

    # test `device=gpu`
    output = dist.collect_results(data, size, device='gpu')
    if dist.get_rank() == 0:
        assert output == expected
    else:
        assert output is None


def test_non_distributed_env():
    _test_all_reduce_non_dist()
    _test_all_gather_non_dist()
    _test_gather_non_dist()
    _test_broadcast_non_dist()
    _test_sync_random_seed_no_dist()
    _test_broadcast_object_list_no_dist()
    _test_all_reduce_dict_no_dist()
    _test_all_gather_object_no_dist()
    _test_gather_object_no_dist()
    _test_collect_results_non_dist()


def test_gloo_backend():
    functions_to_test = [
        _test_all_reduce_dist,
        _test_all_gather_dist,
        _test_gather_dist,
        _test_broadcast_dist,
        _test_sync_random_seed_dist,
        _test_broadcast_object_list_dist,
        _test_all_reduce_dict_dist,
        _test_all_gather_object_dist,
        _test_gather_object_dist,
    ]
    main(functions_to_test, backend='gloo')


@pytest.mark.skipif(
    torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl')
def test_nccl_backend():
    functions_to_test = [
        _test_all_reduce_dist,
        _test_all_gather_dist,
        _test_broadcast_dist,
        _test_sync_random_seed_dist,
        _test_broadcast_object_list_dist,
        _test_all_reduce_dict_dist,
        _test_all_gather_object_dist,
        _test_collect_results_dist,
    ]
    main(functions_to_test, backend='nccl')
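The `tmpdir` handling in `_test_collect_results_dist` above encodes a pattern worth isolating: every rank must use the same scratch directory, so rank 0's `tempfile.mkdtemp()` result is broadcast before collection, and `collect_results` removes the directory itself. A hedged sketch of that pattern; the function name and payloads are illustrative, and an initialized process group is assumed.

# Sketch of the collect_results + tmpdir pattern exercised by the test above.
import tempfile

import mmengine.dist as dist


def gather_predictions(per_rank_results, dataset_size):
    # Only rank 0 creates the scratch directory; broadcasting its path keeps
    # the tmpdir consistent across all ranks.
    object_list = [tempfile.mkdtemp() if dist.get_rank() == 0 else None]
    dist.broadcast_object_list(object_list)

    # collect_results cleans up the tmpdir when it finishes, so no manual
    # removal is needed here.
    results = dist.collect_results(
        per_rank_results, dataset_size, device='cpu', tmpdir=object_list[0])

    # Rank 0 receives the gathered list truncated to `dataset_size`;
    # every other rank receives None.
    return results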
@@ -0,0 +1,152 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os

import pytest
import torch
import torch.distributed as torch_dist
import torch.multiprocessing as mp

import mmengine.dist as dist


def _test_get_backend_non_dist():
    assert dist.get_backend() is None


def _test_get_world_size_non_dist():
    assert dist.get_world_size() == 1


def _test_get_rank_non_dist():
    assert dist.get_rank() == 0


def _test_local_size_non_dist():
    assert dist.get_local_size() == 1


def _test_local_rank_non_dist():
    assert dist.get_local_rank() == 0


def _test_get_dist_info_non_dist():
    assert dist.get_dist_info() == (0, 1)


def _test_is_main_process_non_dist():
    assert dist.is_main_process()


def _test_master_only_non_dist():

    @dist.master_only
    def fun():
        assert dist.get_rank() == 0

    fun()


def _test_barrier_non_dist():
    dist.barrier()  # nothing is done


def init_process(rank, world_size, functions, backend='gloo'):
    """Initialize the distributed environment."""
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29501'
    os.environ['RANK'] = str(rank)
    dist.init_dist('pytorch', backend, rank=rank, world_size=world_size)
    dist.init_local_group(0, world_size)

    for func in functions:
        func()


def main(functions, world_size=2, backend='gloo'):
    try:
        mp.spawn(
            init_process,
            args=(world_size, functions, backend),
            nprocs=world_size)
    except Exception:
        pytest.fail('error')


def _test_get_backend_dist():
    assert dist.get_backend() == torch_dist.get_backend()


def _test_get_world_size_dist():
    assert dist.get_world_size() == 2


def _test_get_rank_dist():
    if torch_dist.get_rank() == 0:
        assert dist.get_rank() == 0
    else:
        assert dist.get_rank() == 1


def _test_local_size_dist():
    assert dist.get_local_size() == 2


def _test_local_rank_dist():
    assert torch_dist.get_rank(dist.get_local_group()) == dist.get_local_rank()


def _test_get_dist_info_dist():
    if dist.get_rank() == 0:
        assert dist.get_dist_info() == (0, 2)
    else:
        assert dist.get_dist_info() == (1, 2)


def _test_is_main_process_dist():
    if dist.get_rank() == 0:
        assert dist.is_main_process()
    else:
        assert not dist.is_main_process()


def _test_master_only_dist():

    @dist.master_only
    def fun():
        assert dist.get_rank() == 0

    fun()


def test_non_distributed_env():
    _test_get_backend_non_dist()
    _test_get_world_size_non_dist()
    _test_get_rank_non_dist()
    _test_local_size_non_dist()
    _test_local_rank_non_dist()
    _test_get_dist_info_non_dist()
    _test_is_main_process_non_dist()
    _test_master_only_non_dist()
    _test_barrier_non_dist()


functions_to_test = [
    _test_get_backend_dist,
    _test_get_world_size_dist,
    _test_get_rank_dist,
    _test_local_size_dist,
    _test_local_rank_dist,
    _test_get_dist_info_dist,
    _test_is_main_process_dist,
    _test_master_only_dist,
]


def test_gloo_backend():
    main(functions_to_test)


@pytest.mark.skipif(
    torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl')
def test_nccl_backend():
    main(functions_to_test, backend='nccl')
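Finally, a small sketch of how the local-group helpers relate on a single node; it assumes the same one-node, two-process setup that `init_process` in this test file creates via `init_dist` followed by `init_local_group(0, world_size)`.

# Sketch only: mirrors the single-node setup used by init_process above.
import mmengine.dist as dist


def report_topology():
    rank = dist.get_rank()              # global rank, 0 .. world_size - 1
    local_rank = dist.get_local_rank()  # rank within this machine's group
    local_size = dist.get_local_size()  # processes running on this machine

    # With one node the local group spans every process, so the global and
    # local views coincide.
    assert local_size == dist.get_world_size()
    assert local_rank == rank
    return rank, local_rank, local_size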