# mmengine/tests/test_dist/test_utils.py
#
# 159 lines
# 3.3 KiB
# Python

# Copyright (c) OpenMMLab. All rights reserved.
import os
import pytest
import torch
import torch.distributed as torch_dist
import torch.multiprocessing as mp
import mmengine.dist as dist
def _test_get_backend_non_dist():
    """Check that ``dist.get_backend`` yields ``None`` without a group."""
    backend = dist.get_backend()
    assert backend is None
def _test_get_world_size_non_dist():
    """Check that ``dist.get_world_size`` is 1 in the non-dist case."""
    world_size = dist.get_world_size()
    assert world_size == 1
def _test_get_rank_non_dist():
    """Check that ``dist.get_rank`` is 0 in the non-dist case."""
    rank = dist.get_rank()
    assert rank == 0
def _test_local_size_non_dist():
    """Check that ``dist.get_local_size`` is 1 in the non-dist case."""
    local_size = dist.get_local_size()
    assert local_size == 1
def _test_local_rank_non_dist():
    """Check that ``dist.get_local_rank`` is 0 in the non-dist case."""
    local_rank = dist.get_local_rank()
    assert local_rank == 0
def _test_get_dist_info_non_dist():
    """Check that ``dist.get_dist_info`` equals ``(0, 1)`` when non-dist."""
    # Presumably the pair is (rank, world_size) -- confirm in mmengine.dist.
    info = dist.get_dist_info()
    assert info == (0, 1)
def _test_is_main_process_non_dist():
    """Check that ``dist.is_main_process`` is truthy in the non-dist case."""
    is_main = dist.is_main_process()
    assert is_main
def _test_master_only_non_dist():
    """Call a ``dist.master_only``-decorated function that asserts rank 0."""

    @dist.master_only
    def rank_zero_check():
        assert dist.get_rank() == 0

    rank_zero_check()
def _test_barrier_non_dist():
    """Call ``dist.barrier`` in the non-dist case; it only has to return."""
    # No assertion on purpose: nothing is done by the barrier here.
    dist.barrier()
def init_process(rank, world_size, functions, backend='gloo'):
    """Set up one worker process and run the given checks in it.

    Args:
        rank: Rank assigned to this process; exported as ``RANK`` and
            passed to ``torch_dist.init_process_group``.
        world_size: Total number of processes in the group.
        functions: Iterable of zero-argument callables, invoked in order
            once the process groups exist.
        backend: Backend name for ``torch_dist.init_process_group``.
            Defaults to 'gloo'.
    """
    # Rendezvous settings; every worker uses the same address and port.
    os.environ.update({
        'MASTER_ADDR': '127.0.0.1',
        'MASTER_PORT': '29501',
        'RANK': str(rank),
    })

    if backend == 'nccl':
        # Select a CUDA device by rank, wrapping around the device count.
        device_index = rank % torch.cuda.device_count()
        torch.cuda.set_device(device_index)

    torch_dist.init_process_group(
        backend=backend, rank=rank, world_size=world_size)
    # NOTE(review): argument meaning of init_local_group is not visible
    # here -- confirm against mmengine.dist.
    dist.init_local_group(0, world_size)

    for check in functions:
        check()
def main(functions, world_size=2, backend='gloo'):
    """Run ``functions`` inside ``world_size`` spawned worker processes.

    Each worker is started through ``mp.spawn`` with ``init_process`` as its
    entry point. Any exception raised by the spawn is turned into a pytest
    failure.

    Args:
        functions: Callables forwarded to ``init_process``, which invokes
            them one after another in every worker.
        world_size: Number of processes to spawn. Defaults to 2.
        backend: Backend name forwarded to ``init_process``.
            Defaults to 'gloo'.
    """
    try:
        mp.spawn(
            init_process,
            args=(world_size, functions, backend),
            nprocs=world_size)
    except Exception as exc:
        # Keep the underlying error in the failure message; a bare 'error'
        # gives no hint about which check failed in which worker.
        pytest.fail(f'error: {exc}')
def _test_get_backend_dist():
    """Check that ``dist.get_backend`` matches ``torch_dist.get_backend``."""
    wrapped_backend = dist.get_backend()
    native_backend = torch_dist.get_backend()
    assert wrapped_backend == native_backend
def _test_get_world_size_dist():
    """Check that ``dist.get_world_size`` is 2 inside a worker."""
    world_size = dist.get_world_size()
    assert world_size == 2
def _test_get_rank_dist():
    """Check ``dist.get_rank`` against the rank torch reports (0 or 1)."""
    # Any torch rank other than 0 is expected to be rank 1.
    expected_rank = 0 if torch_dist.get_rank() == 0 else 1
    assert dist.get_rank() == expected_rank
def _test_local_size_dist():
    """Check that ``dist.get_local_size`` is 2 inside a worker."""
    local_size = dist.get_local_size()
    assert local_size == 2
def _test_local_rank_dist():
    """Check ``dist.get_local_rank`` against torch's rank in the local group.

    The reference value is ``torch_dist.get_rank`` evaluated on the group
    returned by ``dist.get_local_group``.
    """
    # The comparison must be asserted; as a bare expression its result is
    # discarded and the check could never fail.
    expected_local_rank = torch_dist.get_rank(dist.get_local_group())
    assert expected_local_rank == dist.get_local_rank()
def _test_get_dist_info_dist():
    """Check ``dist.get_dist_info``: ``(0, 2)`` on rank 0, else ``(1, 2)``."""
    expected_info = (0, 2) if dist.get_rank() == 0 else (1, 2)
    assert dist.get_dist_info() == expected_info
def _test_is_main_process_dist():
    """Check that ``dist.is_main_process`` is truthy only on rank 0."""
    on_rank_zero = dist.get_rank() == 0
    is_main = dist.is_main_process()
    if on_rank_zero:
        assert is_main
        return
    assert not is_main
def _test_master_only_dist():
    """Call a ``dist.master_only``-decorated function that asserts rank 0."""

    # Presumably the decorator skips the body on non-zero ranks, otherwise
    # this assert would fail on rank 1 -- confirm against mmengine.dist.
    @dist.master_only
    def rank_zero_check():
        assert dist.get_rank() == 0

    rank_zero_check()
def test_non_distributed_env():
    """Run every ``*_non_dist`` check in the current process, in order."""
    non_dist_checks = (
        _test_get_backend_non_dist,
        _test_get_world_size_non_dist,
        _test_get_rank_non_dist,
        _test_local_size_non_dist,
        _test_local_rank_non_dist,
        _test_get_dist_info_non_dist,
        _test_is_main_process_non_dist,
        _test_master_only_non_dist,
        _test_barrier_non_dist,
    )
    for check in non_dist_checks:
        check()
# Checks executed, in this order, inside every spawned worker process by the
# backend tests below (passed to ``main`` and from there to ``init_process``).
functions_to_test = [
    _test_get_backend_dist,
    _test_get_world_size_dist,
    _test_get_rank_dist,
    _test_local_size_dist,
    _test_local_rank_dist,
    _test_get_dist_info_dist,
    _test_is_main_process_dist,
    _test_master_only_dist,
]
def test_gloo_backend():
    """Run the distributed checks through ``main`` with its defaults."""
    main(functions=functions_to_test)
@pytest.mark.skipif(
    torch.cuda.device_count() < 2,
    reason='need 2 gpu to test nccl',
)
def test_nccl_backend():
    """Run the distributed checks on the 'nccl' backend (needs 2 GPUs)."""
    main(functions=functions_to_test, backend='nccl')