# Copyright (c) OpenMMLab. All rights reserved.
import os

import pytest
import torch
import torch.distributed as torch_dist
import torch.multiprocessing as mp

import mmengine.dist as dist

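
# The `_test_*_non_dist` helpers below check the fallback values that
# `mmengine.dist` returns when no distributed process group has been
# initialized; they are all driven by `test_non_distributed_env` further down.
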

def _test_get_backend_non_dist():
    assert dist.get_backend() is None


def _test_get_world_size_non_dist():
    assert dist.get_world_size() == 1


def _test_get_rank_non_dist():
    assert dist.get_rank() == 0


def _test_local_size_non_dist():
    assert dist.get_local_size() == 1


def _test_local_rank_non_dist():
    assert dist.get_local_rank() == 0


def _test_get_dist_info_non_dist():
    assert dist.get_dist_info() == (0, 1)


def _test_is_main_process_non_dist():
    assert dist.is_main_process()


def _test_master_only_non_dist():

    @dist.master_only
    def fun():
        assert dist.get_rank() == 0

    fun()


def _test_barrier_non_dist():
    dist.barrier()  # nothing is done

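
# Multi-process harness: `main` spawns `world_size` worker processes with
# `torch.multiprocessing.spawn`, and each worker calls `init_process` to set
# up the process group before running the given test functions.
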

def init_process(rank, world_size, functions, backend='gloo'):
    """Initialize the distributed environment and run the given functions."""
    # Address, port and rank used to set up the process group.
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29501'
    os.environ['RANK'] = str(rank)
    dist.init_dist('pytorch', backend, rank=rank, world_size=world_size)
    dist.init_local_group(0, world_size)

    for func in functions:
        func()


def main(functions, world_size=2, backend='gloo'):
    try:
        # `mp.spawn` prepends the worker rank to `args` when calling
        # `init_process`.
        mp.spawn(
            init_process,
            args=(world_size, functions, backend),
            nprocs=world_size)
    except Exception:
        pytest.fail('error')

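
# The `_test_*_dist` helpers below run inside the spawned workers, where the
# default process group has a world size of 2.
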

def _test_get_backend_dist():
    assert dist.get_backend() == torch_dist.get_backend()


def _test_get_world_size_dist():
    assert dist.get_world_size() == 2


def _test_get_rank_dist():
    if torch_dist.get_rank() == 0:
        assert dist.get_rank() == 0
    else:
        assert dist.get_rank() == 1


def _test_local_size_dist():
    assert dist.get_local_size() == 2


def _test_local_rank_dist():
    assert torch_dist.get_rank(dist.get_local_group()) == dist.get_local_rank()


def _test_get_dist_info_dist():
    if dist.get_rank() == 0:
        assert dist.get_dist_info() == (0, 2)
    else:
        assert dist.get_dist_info() == (1, 2)


def _test_is_main_process_dist():
    if dist.get_rank() == 0:
        assert dist.is_main_process()
    else:
        assert not dist.is_main_process()


def _test_master_only_dist():

    @dist.master_only
    def fun():
        assert dist.get_rank() == 0

    fun()


def test_non_distributed_env():
    _test_get_backend_non_dist()
    _test_get_world_size_non_dist()
    _test_get_rank_non_dist()
    _test_local_size_non_dist()
    _test_local_rank_non_dist()
    _test_get_dist_info_non_dist()
    _test_is_main_process_non_dist()
    _test_master_only_non_dist()
    _test_barrier_non_dist()

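
# Test functions executed in every spawned worker by `test_gloo_backend` and
# `test_nccl_backend`.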
functions_to_test = [
    _test_get_backend_dist,
    _test_get_world_size_dist,
    _test_get_rank_dist,
    _test_local_size_dist,
    _test_local_rank_dist,
    _test_get_dist_info_dist,
    _test_is_main_process_dist,
    _test_master_only_dist,
]


def test_gloo_backend():
    main(functions_to_test)


@pytest.mark.skipif(
    torch.cuda.device_count() < 2,
    reason='need at least 2 GPUs to test the nccl backend')
def test_nccl_backend():
    main(functions_to_test, backend='nccl')