[Enhancement] Refactor the unit tests of the dist module with MultiProcessTestCase (#138)

* [Enhancement] Provide MultiProcessTestCase to test distributed-related modules

* remove debugging info

* add timeout property

* [Enhancement] Refactor the unit tests of the dist module with MultiProcessTestCase

* minor refinement

* minor fix
Zaida Zhou, 2022-04-08 15:58:03 +08:00, committed by GitHub
parent 2d80367893
commit 50650e0b7a
2 changed files with 564 additions and 467 deletions
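
The refactor replaces the ad-hoc `mp.spawn` helpers with `MultiProcessTestCase` from `mmengine.testing._internal`. The sketch below is a minimal, illustrative example of that pattern (the class name, test name, and port are placeholders; the two-process world size matches what the tests below assert):

import os

import torch.distributed as torch_dist

import mmengine.dist as dist
from mmengine.testing._internal import MultiProcessTestCase


class ToyDistTest(MultiProcessTestCase):

    def setUp(self):
        super().setUp()
        # Spawn `self.world_size` worker processes; every test method
        # below then runs once in each worker.
        self._spawn_processes()

    def _init_dist_env(self, rank, world_size):
        # Same bootstrap as the tests in this commit: rendezvous via
        # environment variables on an arbitrary free port.
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29510'
        os.environ['RANK'] = str(rank)
        torch_dist.init_process_group(
            backend='gloo', rank=rank, world_size=world_size)

    def test_rank_and_world_size(self):
        self._init_dist_env(self.rank, self.world_size)
        rank, world_size = dist.get_dist_info()
        assert rank == self.rank
        assert world_size == self.world_size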


@@ -2,382 +2,460 @@
import os
import os.path as osp
import tempfile
import unittest
from unittest import TestCase
from unittest.mock import patch

import torch
import torch.distributed as torch_dist

import mmengine.dist as dist
from mmengine.dist.dist import sync_random_seed
from mmengine.testing._internal import MultiProcessTestCase
from mmengine.utils import TORCH_VERSION, digit_version


class TestDist(TestCase):
    """Test dist module in non-distributed environment."""

    def test_all_reduce(self):
        data = torch.arange(2, dtype=torch.int64)
        expected = torch.arange(2, dtype=torch.int64)
        dist.all_reduce(data)
        self.assertTrue(torch.allclose(data, expected))

    def test_all_gather(self):
        data = torch.arange(2, dtype=torch.int64)
        expected = torch.arange(2, dtype=torch.int64)
        output = dist.all_gather(data)
        self.assertTrue(torch.allclose(output[0], expected))

    def test_gather(self):
        data = torch.arange(2, dtype=torch.int64)
        expected = torch.arange(2, dtype=torch.int64)
        output = dist.gather(data)
        self.assertTrue(torch.allclose(output[0], expected))

    def test_broadcast(self):
        data = torch.arange(2, dtype=torch.int64)
        expected = torch.arange(2, dtype=torch.int64)
        dist.broadcast(data)
        self.assertTrue(torch.allclose(data, expected))

    @patch('numpy.random.randint', return_value=10)
    def test_sync_random_seed(self, mock):
        self.assertEqual(sync_random_seed(), 10)

    def test_broadcast_object_list(self):
        with self.assertRaises(AssertionError):
            # input should be list of object
            dist.broadcast_object_list('foo')

        data = ['foo', 12, {1: 2}]
        expected = ['foo', 12, {1: 2}]
        dist.broadcast_object_list(data)
        self.assertEqual(data, expected)

    def test_all_reduce_dict(self):
        with self.assertRaises(AssertionError):
            # input should be dict
            dist.all_reduce_dict('foo')

        data = {
            'key1': torch.arange(2, dtype=torch.int64),
            'key2': torch.arange(3, dtype=torch.int64)
        }
        expected = {
            'key1': torch.arange(2, dtype=torch.int64),
            'key2': torch.arange(3, dtype=torch.int64)
        }
        dist.all_reduce_dict(data)
        for key in data:
            self.assertTrue(torch.allclose(data[key], expected[key]))

    def test_all_gather_object(self):
        data = 'foo'
        expected = 'foo'
        gather_objects = dist.all_gather_object(data)
        self.assertEqual(gather_objects[0], expected)

    def test_gather_object(self):
        data = 'foo'
        expected = 'foo'
        gather_objects = dist.gather_object(data)
        self.assertEqual(gather_objects[0], expected)

    def test_collect_results(self):
        data = ['foo', {1: 2}]
        size = 2
        expected = ['foo', {1: 2}]

        # test `device=cpu`
        output = dist.collect_results(data, size, device='cpu')
        self.assertEqual(output, expected)

        # test `device=gpu`
        output = dist.collect_results(data, size, device='gpu')
        self.assertEqual(output, expected)
class TestDistWithGLOOBackend(MultiProcessTestCase):

    def _init_dist_env(self, rank, world_size):
        """Initialize the distributed environment."""
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29505'
        os.environ['RANK'] = str(rank)
        torch_dist.init_process_group(
            backend='gloo', rank=rank, world_size=world_size)

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def test_all_reduce(self):
        self._init_dist_env(self.rank, self.world_size)
        for tensor_type, reduce_op in zip([torch.int64, torch.float32],
                                          ['sum', 'mean']):
            if dist.get_rank() == 0:
                data = torch.tensor([1, 2], dtype=tensor_type)
            else:
                data = torch.tensor([3, 4], dtype=tensor_type)

            if reduce_op == 'sum':
                expected = torch.tensor([4, 6], dtype=tensor_type)
            else:
                expected = torch.tensor([2, 3], dtype=tensor_type)

            dist.all_reduce(data, reduce_op)
            self.assertTrue(torch.allclose(data, expected))

    def test_all_gather(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            data = torch.tensor([0, 1])
        else:
            data = torch.tensor([1, 2])

        expected = [torch.tensor([0, 1]), torch.tensor([1, 2])]

        output = dist.all_gather(data)
        self.assertTrue(
            torch.allclose(output[dist.get_rank()], expected[dist.get_rank()]))

    def test_gather(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            data = torch.tensor([0, 1])
        else:
            data = torch.tensor([1, 2])

        output = dist.gather(data)

        if dist.get_rank() == 0:
            expected = [torch.tensor([0, 1]), torch.tensor([1, 2])]
            for i in range(2):
                assert torch.allclose(output[i], expected[i])
        else:
            assert output == []

    def test_broadcast_dist(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            data = torch.tensor([0, 1])
        else:
            data = torch.tensor([1, 2])

        expected = torch.tensor([0, 1])

        dist.broadcast(data, 0)
        assert torch.allclose(data, expected)

    def test_sync_random_seed(self):
        self._init_dist_env(self.rank, self.world_size)
        with patch.object(
                torch, 'tensor',
                return_value=torch.tensor(1024)) as mock_tensor:
            output = dist.sync_random_seed()
            assert output == 1024
            mock_tensor.assert_called()

    def test_broadcast_object_list(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            data = ['foo', 12, {1: 2}]
        else:
            data = [None, None, None]

        expected = ['foo', 12, {1: 2}]

        dist.broadcast_object_list(data)
        self.assertEqual(data, expected)

    def test_all_reduce_dict(self):
        self._init_dist_env(self.rank, self.world_size)
        for tensor_type, reduce_op in zip([torch.int64, torch.float32],
                                          ['sum', 'mean']):
            if dist.get_rank() == 0:
                data = {
                    'key1': torch.tensor([0, 1], dtype=tensor_type),
                    'key2': torch.tensor([1, 2], dtype=tensor_type),
                }
            else:
                data = {
                    'key1': torch.tensor([2, 3], dtype=tensor_type),
                    'key2': torch.tensor([3, 4], dtype=tensor_type),
                }

            if reduce_op == 'sum':
                expected = {
                    'key1': torch.tensor([2, 4], dtype=tensor_type),
                    'key2': torch.tensor([4, 6], dtype=tensor_type),
                }
            else:
                expected = {
                    'key1': torch.tensor([1, 2], dtype=tensor_type),
                    'key2': torch.tensor([2, 3], dtype=tensor_type),
                }

            dist.all_reduce_dict(data, reduce_op)

            for key in data:
                assert torch.allclose(data[key], expected[key])

        # `torch.cat` in torch1.5 can not concatenate different types so we
        # fallback to convert them all to float type.
        if digit_version(TORCH_VERSION) == digit_version('1.5.0'):
            if dist.get_rank() == 0:
                data = {
                    'key1': torch.tensor([0, 1], dtype=torch.float32),
                    'key2': torch.tensor([1, 2], dtype=torch.int32)
                }
            else:
                data = {
                    'key1': torch.tensor([2, 3], dtype=torch.float32),
                    'key2': torch.tensor([3, 4], dtype=torch.int32),
                }

            expected = {
                'key1': torch.tensor([2, 4], dtype=torch.float32),
                'key2': torch.tensor([4, 6], dtype=torch.float32),
            }

            dist.all_reduce_dict(data, 'sum')

            for key in data:
                assert torch.allclose(data[key], expected[key])

    def test_all_gather_object(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            data = 'foo'
        else:
            data = {1: 2}

        expected = ['foo', {1: 2}]

        output = dist.all_gather_object(data)

        self.assertEqual(output, expected)

    def test_gather_object(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            data = 'foo'
        else:
            data = {1: 2}

        output = dist.gather_object(data, dst=0)

        if dist.get_rank() == 0:
            self.assertEqual(output, ['foo', {1: 2}])
        else:
            self.assertIsNone(output)
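

# The NCCL-backed tests below mirror the GLOO ones above but move tensors to
# the current CUDA device; the whole class is skipped when fewer than two
# GPUs are available.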
@unittest.skipIf(
    torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl')
class TestDistWithNCCLBackend(MultiProcessTestCase):

    def _init_dist_env(self, rank, world_size):
        """Initialize the distributed environment."""
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29505'
        os.environ['RANK'] = str(rank)

        num_gpus = torch.cuda.device_count()
        torch.cuda.set_device(rank % num_gpus)
        torch_dist.init_process_group(
            backend='nccl', rank=rank, world_size=world_size)

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def test_all_reduce(self):
        self._init_dist_env(self.rank, self.world_size)
        for tensor_type, reduce_op in zip([torch.int64, torch.float32],
                                          ['sum', 'mean']):
            if dist.get_rank() == 0:
                data = torch.tensor([1, 2], dtype=tensor_type).cuda()
            else:
                data = torch.tensor([3, 4], dtype=tensor_type).cuda()

            if reduce_op == 'sum':
                expected = torch.tensor([4, 6], dtype=tensor_type).cuda()
            else:
                expected = torch.tensor([2, 3], dtype=tensor_type).cuda()

            dist.all_reduce(data, reduce_op)
            self.assertTrue(torch.allclose(data, expected))

    def test_all_gather(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            data = torch.tensor([0, 1]).cuda()
        else:
            data = torch.tensor([1, 2]).cuda()

        expected = [torch.tensor([0, 1]).cuda(), torch.tensor([1, 2]).cuda()]

        output = dist.all_gather(data)
        self.assertTrue(
            torch.allclose(output[dist.get_rank()], expected[dist.get_rank()]))

    def test_broadcast_dist(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            data = torch.tensor([0, 1]).cuda()
        else:
            data = torch.tensor([1, 2]).cuda()

        expected = torch.tensor([0, 1]).cuda()

        dist.broadcast(data, 0)
        assert torch.allclose(data, expected)

    def test_sync_random_seed(self):
        self._init_dist_env(self.rank, self.world_size)
        with patch.object(
                torch, 'tensor',
                return_value=torch.tensor(1024)) as mock_tensor:
            output = dist.sync_random_seed()
            assert output == 1024
            mock_tensor.assert_called()

    def test_broadcast_object_list(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            data = ['foo', 12, {1: 2}]
        else:
            data = [None, None, None]

        expected = ['foo', 12, {1: 2}]

        dist.broadcast_object_list(data)
        self.assertEqual(data, expected)

    def test_all_reduce_dict(self):
        self._init_dist_env(self.rank, self.world_size)
        for tensor_type, reduce_op in zip([torch.int64, torch.float32],
                                          ['sum', 'mean']):
            if dist.get_rank() == 0:
                data = {
                    'key1': torch.tensor([0, 1], dtype=tensor_type).cuda(),
                    'key2': torch.tensor([1, 2], dtype=tensor_type).cuda(),
                }
            else:
                data = {
                    'key1': torch.tensor([2, 3], dtype=tensor_type).cuda(),
                    'key2': torch.tensor([3, 4], dtype=tensor_type).cuda(),
                }

            if reduce_op == 'sum':
                expected = {
                    'key1': torch.tensor([2, 4], dtype=tensor_type).cuda(),
                    'key2': torch.tensor([4, 6], dtype=tensor_type).cuda(),
                }
            else:
                expected = {
                    'key1': torch.tensor([1, 2], dtype=tensor_type).cuda(),
                    'key2': torch.tensor([2, 3], dtype=tensor_type).cuda(),
                }

            dist.all_reduce_dict(data, reduce_op)

            for key in data:
                assert torch.allclose(data[key], expected[key])

        # `torch.cat` in torch1.5 can not concatenate different types so we
        # fallback to convert them all to float type.
        if digit_version(TORCH_VERSION) == digit_version('1.5.0'):
            if dist.get_rank() == 0:
                data = {
                    'key1': torch.tensor([0, 1], dtype=torch.float32).cuda(),
                    'key2': torch.tensor([1, 2], dtype=torch.int32).cuda(),
                }
            else:
                data = {
                    'key1': torch.tensor([2, 3], dtype=torch.float32).cuda(),
                    'key2': torch.tensor([3, 4], dtype=torch.int32).cuda(),
                }

            expected = {
                'key1': torch.tensor([2, 4], dtype=torch.float32).cuda(),
                'key2': torch.tensor([4, 6], dtype=torch.float32).cuda(),
            }

            dist.all_reduce_dict(data, 'sum')

            for key in data:
                assert torch.allclose(data[key], expected[key])

    def test_all_gather_object(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            data = 'foo'
        else:
            data = {1: 2}

        expected = ['foo', {1: 2}]

        output = dist.all_gather_object(data)

        self.assertEqual(output, expected)

    def test_collect_results(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            data = ['foo', {1: 2}]
        else:
            data = [24, {'a': 'b'}]

        size = 4

        expected = ['foo', 24, {1: 2}, {'a': 'b'}]

        # test `device=cpu`
        output = dist.collect_results(data, size, device='cpu')
        if dist.get_rank() == 0:
            self.assertEqual(output, expected)
        else:
            self.assertIsNone(output)

        # test `device=cpu` and `tmpdir is not None`
        tmpdir = tempfile.mkdtemp()
        # broadcast tmpdir to all ranks to make it consistent
        object_list = [tmpdir]
        dist.broadcast_object_list(object_list)
        output = dist.collect_results(
            data, size, device='cpu', tmpdir=object_list[0])
        if dist.get_rank() == 0:
            self.assertEqual(output, expected)
        else:
            self.assertIsNone(output)

        if dist.get_rank() == 0:
            # object_list[0] will be removed by `dist.collect_results`
            self.assertFalse(osp.exists(object_list[0]))

        # test `device=gpu`
        output = dist.collect_results(data, size, device='gpu')
        if dist.get_rank() == 0:
            self.assertEqual(output, expected)
        else:
            self.assertIsNone(output)


@@ -1,158 +1,177 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import unittest
from unittest import TestCase

import torch
import torch.distributed as torch_dist

import mmengine.dist as dist
from mmengine.testing._internal import MultiProcessTestCase


class TestUtils(TestCase):

    def test_get_backend(self):
        self.assertIsNone(dist.get_backend())

    def test_get_world_size(self):
        self.assertEqual(dist.get_world_size(), 1)

    def test_get_rank(self):
        self.assertEqual(dist.get_rank(), 0)

    def test_local_size(self):
        self.assertEqual(dist.get_local_size(), 1)

    def test_local_rank(self):
        self.assertEqual(dist.get_local_rank(), 0)

    def test_get_dist_info(self):
        self.assertEqual(dist.get_dist_info(), (0, 1))

    def test_is_main_process(self):
        self.assertTrue(dist.is_main_process())

    def test_master_only(self):

        @dist.master_only
        def fun():
            assert dist.get_rank() == 0

        fun()

    def test_barrier(self):
        dist.barrier()  # nothing is done
class TestUtilsWithGLOOBackend(MultiProcessTestCase):

    def _init_dist_env(self, rank, world_size):
        """Initialize the distributed environment."""
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29505'
        os.environ['RANK'] = str(rank)
        torch_dist.init_process_group(
            backend='gloo', rank=rank, world_size=world_size)
        dist.init_local_group(0, world_size)

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def test_get_backend(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_backend(), torch_dist.get_backend())

    def test_get_world_size(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_world_size(), 2)

    def test_get_rank(self):
        self._init_dist_env(self.rank, self.world_size)
        if torch_dist.get_rank() == 0:
            self.assertEqual(dist.get_rank(), 0)
        else:
            self.assertEqual(dist.get_rank(), 1)

    def test_local_size(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_local_size(), 2)

    def test_local_rank(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(
            torch_dist.get_rank(dist.get_local_group()), dist.get_local_rank())

    def test_get_dist_info(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            self.assertEqual(dist.get_dist_info(), (0, 2))
        else:
            self.assertEqual(dist.get_dist_info(), (1, 2))

    def test_is_main_process(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            self.assertTrue(dist.is_main_process())
        else:
            self.assertFalse(dist.is_main_process())

    def test_master_only(self):
        self._init_dist_env(self.rank, self.world_size)

        @dist.master_only
        def fun():
            assert dist.get_rank() == 0

        fun()
@unittest.skipIf(
    torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl')
class TestUtilsWithNCCLBackend(MultiProcessTestCase):

    def _init_dist_env(self, rank, world_size):
        """Initialize the distributed environment."""
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29505'
        os.environ['RANK'] = str(rank)

        num_gpus = torch.cuda.device_count()
        torch.cuda.set_device(rank % num_gpus)
        torch_dist.init_process_group(
            backend='nccl', rank=rank, world_size=world_size)
        dist.init_local_group(0, world_size)

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def test_get_backend(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_backend(), torch_dist.get_backend())

    def test_get_world_size(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_world_size(), 2)

    def test_get_rank(self):
        self._init_dist_env(self.rank, self.world_size)
        if torch_dist.get_rank() == 0:
            self.assertEqual(dist.get_rank(), 0)
        else:
            self.assertEqual(dist.get_rank(), 1)

    def test_local_size(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_local_size(), 2)

    def test_local_rank(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(
            torch_dist.get_rank(dist.get_local_group()), dist.get_local_rank())

    def test_get_dist_info(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            self.assertEqual(dist.get_dist_info(), (0, 2))
        else:
            self.assertEqual(dist.get_dist_info(), (1, 2))

    def test_is_main_process(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            self.assertTrue(dist.is_main_process())
        else:
            self.assertFalse(dist.is_main_process())

    def test_master_only(self):
        self._init_dist_env(self.rank, self.world_size)

        @dist.master_only
        def fun():
            assert dist.get_rank() == 0

        fun()