# Copyright (c) OpenMMLab. All rights reserved.
import os

import pytest
import torch
import torch.distributed as torch_dist
import torch.multiprocessing as mp

import mmengine.dist as dist


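# The checks below cover the mmengine.dist utility functions in two
# settings: a plain single-process (non-distributed) environment, and a
# real process group spawned per test via torch.multiprocessing.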
def _test_get_backend_non_dist():
    assert dist.get_backend() is None


def _test_get_world_size_non_dist():
    assert dist.get_world_size() == 1


def _test_get_rank_non_dist():
    assert dist.get_rank() == 0


def _test_local_size_non_dist():
    assert dist.get_local_size() == 1


def _test_local_rank_non_dist():
    assert dist.get_local_rank() == 0


def _test_get_dist_info_non_dist():
    assert dist.get_dist_info() == (0, 1)


def _test_is_main_process_non_dist():
    assert dist.is_main_process()


def _test_master_only_non_dist():

    @dist.master_only
    def fun():
        assert dist.get_rank() == 0

    fun()


def _test_barrier_non_dist():
    dist.barrier()  # no-op in a non-distributed environment


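# The two helpers below build the multi-process environment for the *_dist
# checks: main() spawns ``world_size`` workers with torch.multiprocessing
# and each worker joins a process group in init_process() before running
# the given check functions.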
def init_process(rank, world_size, functions, backend='gloo'):
    """Initialize the distributed environment and run the given checks."""
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29501'
    os.environ['RANK'] = str(rank)

    if backend == 'nccl':
        num_gpus = torch.cuda.device_count()
        torch.cuda.set_device(rank % num_gpus)

    torch_dist.init_process_group(
        backend=backend, rank=rank, world_size=world_size)
    dist.init_local_group(0, world_size)

    for func in functions:
        func()


def main(functions, world_size=2, backend='gloo'):
    try:
        mp.spawn(
            init_process,
            args=(world_size, functions, backend),
            nprocs=world_size)
    except Exception:
        pytest.fail(f'{backend} backend distributed test failed')


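# The *_dist checks below run inside the spawned workers and therefore
# assume an initialized process group with a world size of 2.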
def _test_get_backend_dist():
    assert dist.get_backend() == torch_dist.get_backend()


def _test_get_world_size_dist():
    assert dist.get_world_size() == 2


def _test_get_rank_dist():
    if torch_dist.get_rank() == 0:
        assert dist.get_rank() == 0
    else:
        assert dist.get_rank() == 1


def _test_local_size_dist():
    assert dist.get_local_size() == 2


def _test_local_rank_dist():
    # The original comparison had no effect without ``assert``.
    assert torch_dist.get_rank(dist.get_local_group()) == dist.get_local_rank()


def _test_get_dist_info_dist():
    if dist.get_rank() == 0:
        assert dist.get_dist_info() == (0, 2)
    else:
        assert dist.get_dist_info() == (1, 2)


def _test_is_main_process_dist():
    if dist.get_rank() == 0:
        assert dist.is_main_process()
    else:
        assert not dist.is_main_process()


def _test_master_only_dist():

    @dist.master_only
    def fun():
        assert dist.get_rank() == 0

    fun()


def test_non_distributed_env():
    _test_get_backend_non_dist()
    _test_get_world_size_non_dist()
    _test_get_rank_non_dist()
    _test_local_size_non_dist()
    _test_local_rank_non_dist()
    _test_get_dist_info_non_dist()
    _test_is_main_process_non_dist()
    _test_master_only_non_dist()
    _test_barrier_non_dist()


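# Checks executed in every spawned worker, shared by the gloo and nccl tests.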
functions_to_test = [
    _test_get_backend_dist,
    _test_get_world_size_dist,
    _test_get_rank_dist,
    _test_local_size_dist,
    _test_local_rank_dist,
    _test_get_dist_info_dist,
    _test_is_main_process_dist,
    _test_master_only_dist,
]


def test_gloo_backend():
    main(functions_to_test)


@pytest.mark.skipif(
    torch.cuda.device_count() < 2, reason='need at least 2 GPUs to test nccl')
def test_nccl_backend():
    main(functions_to_test, backend='nccl')
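

# A minimal sketch of how these tests could be run locally (the file path is
# an assumption and may differ from the actual repository layout):
#
#   pytest tests/test_dist/test_utils.py -k gloo   # CPU-only, gloo backend
#   pytest tests/test_dist/test_utils.py -k nccl   # requires >= 2 GPUs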