[Enhancement] Refactor the unit tests of the dist module with MultiProcessTestCase (#138)

* [Enhancement] Provide MultiProcessTestCase to test distributed-related modules

* remove debugging info

* add timeout property

* [Enhancement] Refactor the unit tests of the dist module with MultiProcessTestCase

* minor refinement

* minor fix
Zaida Zhou 2022-04-08 15:58:03 +08:00 committed by GitHub
parent 2d80367893
commit 50650e0b7a
2 changed files with 564 additions and 467 deletions
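
The refactor replaces the old `mp.spawn`-based helper functions with test classes built on `MultiProcessTestCase`: each class spawns one process per rank in `setUp`, and every test method first joins a process group through `_init_dist_env` before exercising `mmengine.dist`. A minimal sketch of that pattern, assuming a gloo backend; the class name, test name, and port below are hypothetical and not part of this commit:

import os

import torch.distributed as torch_dist

import mmengine.dist as dist
from mmengine.testing._internal import MultiProcessTestCase


class ExampleDistTest(MultiProcessTestCase):  # hypothetical name

    def setUp(self):
        super().setUp()
        # spawn one worker process per rank; each test method below is
        # then executed inside every spawned process
        self._spawn_processes()

    def _init_dist_env(self, rank, world_size):
        # every process joins the same gloo process group
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29510'  # assumed free port
        os.environ['RANK'] = str(rank)
        torch_dist.init_process_group(
            backend='gloo', rank=rank, world_size=world_size)

    def test_world_size(self):
        # self.rank and self.world_size are provided by MultiProcessTestCase
        self._init_dist_env(self.rank, self.world_size)
        assert dist.get_world_size() == self.world_size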


@@ -2,64 +2,62 @@
import os
import os.path as osp
import tempfile
import unittest
from unittest import TestCase
from unittest.mock import patch

import torch
import torch.distributed as torch_dist

import mmengine.dist as dist
from mmengine.dist.dist import sync_random_seed
from mmengine.testing._internal import MultiProcessTestCase
from mmengine.utils import TORCH_VERSION, digit_version


class TestDist(TestCase):
    """Test dist module in non-distributed environment."""

    def test_all_reduce(self):
        data = torch.arange(2, dtype=torch.int64)
        expected = torch.arange(2, dtype=torch.int64)
        dist.all_reduce(data)
        self.assertTrue(torch.allclose(data, expected))

    def test_all_gather(self):
        data = torch.arange(2, dtype=torch.int64)
        expected = torch.arange(2, dtype=torch.int64)
        output = dist.all_gather(data)
        self.assertTrue(torch.allclose(output[0], expected))

    def test_gather(self):
        data = torch.arange(2, dtype=torch.int64)
        expected = torch.arange(2, dtype=torch.int64)
        output = dist.gather(data)
        self.assertTrue(torch.allclose(output[0], expected))

    def test_broadcast(self):
        data = torch.arange(2, dtype=torch.int64)
        expected = torch.arange(2, dtype=torch.int64)
        dist.broadcast(data)
        self.assertTrue(torch.allclose(data, expected))

    @patch('numpy.random.randint', return_value=10)
    def test_sync_random_seed(self, mock):
        self.assertEqual(sync_random_seed(), 10)

    def test_broadcast_object_list(self):
        with self.assertRaises(AssertionError):
            # input should be list of object
            dist.broadcast_object_list('foo')

        data = ['foo', 12, {1: 2}]
        expected = ['foo', 12, {1: 2}]
        dist.broadcast_object_list(data)
        self.assertEqual(data, expected)

    def test_all_reduce_dict(self):
        with self.assertRaises(AssertionError):
            # input should be dict
            dist.all_reduce_dict('foo')
@@ -73,173 +71,149 @@
        }
        dist.all_reduce_dict(data)
        for key in data:
            self.assertTrue(torch.allclose(data[key], expected[key]))

    def test_all_gather_object(self):
        data = 'foo'
        expected = 'foo'
        gather_objects = dist.all_gather_object(data)
        self.assertEqual(gather_objects[0], expected)

    def test_gather_object(self):
        data = 'foo'
        expected = 'foo'
        gather_objects = dist.gather_object(data)
        self.assertEqual(gather_objects[0], expected)

    def test_collect_results(self):
        data = ['foo', {1: 2}]
        size = 2
        expected = ['foo', {1: 2}]

        # test `device=cpu`
        output = dist.collect_results(data, size, device='cpu')
        self.assertEqual(output, expected)

        # test `device=gpu`
        output = dist.collect_results(data, size, device='gpu')
        self.assertEqual(output, expected)


class TestDistWithGLOOBackend(MultiProcessTestCase):

    def _init_dist_env(self, rank, world_size):
        """Initialize the distributed environment."""
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29505'
        os.environ['RANK'] = str(rank)
        torch_dist.init_process_group(
            backend='gloo', rank=rank, world_size=world_size)

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def test_all_reduce(self):
        self._init_dist_env(self.rank, self.world_size)
        for tensor_type, reduce_op in zip([torch.int64, torch.float32],
                                          ['sum', 'mean']):
            if dist.get_rank() == 0:
                data = torch.tensor([1, 2], dtype=tensor_type)
            else:
                data = torch.tensor([3, 4], dtype=tensor_type)

            if reduce_op == 'sum':
                expected = torch.tensor([4, 6], dtype=tensor_type)
            else:
                expected = torch.tensor([2, 3], dtype=tensor_type)

            dist.all_reduce(data, reduce_op)
            self.assertTrue(torch.allclose(data, expected))

    def test_all_gather(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            data = torch.tensor([0, 1])
        else:
            data = torch.tensor([1, 2])

        expected = [torch.tensor([0, 1]), torch.tensor([1, 2])]

        output = dist.all_gather(data)
        self.assertTrue(
            torch.allclose(output[dist.get_rank()], expected[dist.get_rank()]))

    def test_gather(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            data = torch.tensor([0, 1])
        else:
            data = torch.tensor([1, 2])

        output = dist.gather(data)

        if dist.get_rank() == 0:
            expected = [torch.tensor([0, 1]), torch.tensor([1, 2])]
            for i in range(2):
                assert torch.allclose(output[i], expected[i])
        else:
            assert output == []

    def test_broadcast_dist(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            data = torch.tensor([0, 1])
        else:
            data = torch.tensor([1, 2])

        expected = torch.tensor([0, 1])
        dist.broadcast(data, 0)
        assert torch.allclose(data, expected)

    def test_sync_random_seed(self):
        self._init_dist_env(self.rank, self.world_size)
        with patch.object(
                torch, 'tensor',
                return_value=torch.tensor(1024)) as mock_tensor:
            output = dist.sync_random_seed()
            assert output == 1024
            mock_tensor.assert_called()

    def test_broadcast_object_list(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            data = ['foo', 12, {1: 2}]
        else:
            data = [None, None, None]

        expected = ['foo', 12, {1: 2}]

        dist.broadcast_object_list(data)
        self.assertEqual(data, expected)

    def test_all_reduce_dict(self):
        self._init_dist_env(self.rank, self.world_size)
        for tensor_type, reduce_op in zip([torch.int64, torch.float32],
                                          ['sum', 'mean']):
            if dist.get_rank() == 0:
                data = {
                    'key1': torch.tensor([0, 1], dtype=tensor_type),
                    'key2': torch.tensor([1, 2], dtype=tensor_type),
                }
            else:
                data = {
                    'key1': torch.tensor([2, 3], dtype=tensor_type),
                    'key2': torch.tensor([3, 4], dtype=tensor_type),
                }

            if reduce_op == 'sum':
                expected = {
                    'key1': torch.tensor([2, 4], dtype=tensor_type),
                    'key2': torch.tensor([4, 6], dtype=tensor_type),
                }
            else:
                expected = {
                    'key1': torch.tensor([1, 2], dtype=tensor_type),
                    'key2': torch.tensor([2, 3], dtype=tensor_type),
                }

            dist.all_reduce_dict(data, reduce_op)
@@ -252,18 +226,18 @@
        if digit_version(TORCH_VERSION) == digit_version('1.5.0'):
            if dist.get_rank() == 0:
                data = {
                    'key1': torch.tensor([0, 1], dtype=torch.float32),
                    'key2': torch.tensor([1, 2], dtype=torch.int32)
                }
            else:
                data = {
                    'key1': torch.tensor([2, 3], dtype=torch.float32),
                    'key2': torch.tensor([3, 4], dtype=torch.int32),
                }

            expected = {
                'key1': torch.tensor([2, 4], dtype=torch.float32),
                'key2': torch.tensor([4, 6], dtype=torch.float32),
            }

            dist.all_reduce_dict(data, 'sum')
@@ -271,8 +245,8 @@
            for key in data:
                assert torch.allclose(data[key], expected[key])

    def test_all_gather_object(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            data = 'foo'
        else:
@@ -281,10 +255,10 @@
        expected = ['foo', {1: 2}]

        output = dist.all_gather_object(data)
        self.assertEqual(output, expected)

    def test_gather_object(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            data = 'foo'
        else:
@@ -293,12 +267,160 @@
        output = dist.gather_object(data, dst=0)

        if dist.get_rank() == 0:
            self.assertEqual(output, ['foo', {1: 2}])
        else:
            self.assertIsNone(output)


@unittest.skipIf(
    torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl')
class TestDistWithNCCLBackend(MultiProcessTestCase):

    def _init_dist_env(self, rank, world_size):
        """Initialize the distributed environment."""
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29505'
        os.environ['RANK'] = str(rank)

        num_gpus = torch.cuda.device_count()
        torch.cuda.set_device(rank % num_gpus)
        torch_dist.init_process_group(
            backend='nccl', rank=rank, world_size=world_size)

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def test_all_reduce(self):
        self._init_dist_env(self.rank, self.world_size)
        for tensor_type, reduce_op in zip([torch.int64, torch.float32],
                                          ['sum', 'mean']):
            if dist.get_rank() == 0:
                data = torch.tensor([1, 2], dtype=tensor_type).cuda()
            else:
                data = torch.tensor([3, 4], dtype=tensor_type).cuda()

            if reduce_op == 'sum':
                expected = torch.tensor([4, 6], dtype=tensor_type).cuda()
            else:
                expected = torch.tensor([2, 3], dtype=tensor_type).cuda()

            dist.all_reduce(data, reduce_op)
            self.assertTrue(torch.allclose(data, expected))

    def test_all_gather(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            data = torch.tensor([0, 1]).cuda()
        else:
            data = torch.tensor([1, 2]).cuda()

        expected = [torch.tensor([0, 1]).cuda(), torch.tensor([1, 2]).cuda()]

        output = dist.all_gather(data)
        self.assertTrue(
            torch.allclose(output[dist.get_rank()], expected[dist.get_rank()]))

    def test_broadcast_dist(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            data = torch.tensor([0, 1]).cuda()
        else:
            data = torch.tensor([1, 2]).cuda()

        expected = torch.tensor([0, 1]).cuda()
        dist.broadcast(data, 0)
        assert torch.allclose(data, expected)

    def test_sync_random_seed(self):
        self._init_dist_env(self.rank, self.world_size)
        with patch.object(
                torch, 'tensor',
                return_value=torch.tensor(1024)) as mock_tensor:
            output = dist.sync_random_seed()
            assert output == 1024
            mock_tensor.assert_called()

    def test_broadcast_object_list(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            data = ['foo', 12, {1: 2}]
        else:
            data = [None, None, None]

        expected = ['foo', 12, {1: 2}]

        dist.broadcast_object_list(data)
        self.assertEqual(data, expected)

    def test_all_reduce_dict(self):
        self._init_dist_env(self.rank, self.world_size)
        for tensor_type, reduce_op in zip([torch.int64, torch.float32],
                                          ['sum', 'mean']):
            if dist.get_rank() == 0:
                data = {
                    'key1': torch.tensor([0, 1], dtype=tensor_type).cuda(),
                    'key2': torch.tensor([1, 2], dtype=tensor_type).cuda(),
                }
            else:
                data = {
                    'key1': torch.tensor([2, 3], dtype=tensor_type).cuda(),
                    'key2': torch.tensor([3, 4], dtype=tensor_type).cuda(),
                }

            if reduce_op == 'sum':
                expected = {
                    'key1': torch.tensor([2, 4], dtype=tensor_type).cuda(),
                    'key2': torch.tensor([4, 6], dtype=tensor_type).cuda(),
                }
            else:
                expected = {
                    'key1': torch.tensor([1, 2], dtype=tensor_type).cuda(),
                    'key2': torch.tensor([2, 3], dtype=tensor_type).cuda(),
                }

            dist.all_reduce_dict(data, reduce_op)

            for key in data:
                assert torch.allclose(data[key], expected[key])

        # `torch.cat` in torch1.5 can not concatenate different types so we
        # fallback to convert them all to float type.
        if digit_version(TORCH_VERSION) == digit_version('1.5.0'):
            if dist.get_rank() == 0:
                data = {
                    'key1': torch.tensor([0, 1], dtype=torch.float32).cuda(),
                    'key2': torch.tensor([1, 2], dtype=torch.int32).cuda(),
                }
            else:
                data = {
                    'key1': torch.tensor([2, 3], dtype=torch.float32).cuda(),
                    'key2': torch.tensor([3, 4], dtype=torch.int32).cuda(),
                }

            expected = {
                'key1': torch.tensor([2, 4], dtype=torch.float32).cuda(),
                'key2': torch.tensor([4, 6], dtype=torch.float32).cuda(),
            }

            dist.all_reduce_dict(data, 'sum')

            for key in data:
                assert torch.allclose(data[key], expected[key])

    def test_all_gather_object(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            data = 'foo'
        else:
            data = {1: 2}

        expected = ['foo', {1: 2}]

        output = dist.all_gather_object(data)
        self.assertEqual(output, expected)

    def test_collect_results(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            data = ['foo', {1: 2}]
        else:
@@ -311,9 +433,9 @@
        # test `device=cpu`
        output = dist.collect_results(data, size, device='cpu')
        if dist.get_rank() == 0:
            self.assertEqual(output, expected)
        else:
            self.assertIsNone(output)

        # test `device=cpu` and `tmpdir is not None`
        tmpdir = tempfile.mkdtemp()
@@ -323,61 +445,17 @@
        output = dist.collect_results(
            data, size, device='cpu', tmpdir=object_list[0])
        if dist.get_rank() == 0:
            self.assertEqual(output, expected)
        else:
            self.assertIsNone(output)

        if dist.get_rank() == 0:
            # object_list[0] will be removed by `dist.collect_results`
            self.assertFalse(osp.exists(object_list[0]))

        # test `device=gpu`
        output = dist.collect_results(data, size, device='gpu')
        if dist.get_rank() == 0:
            self.assertEqual(output, expected)
        else:
            self.assertIsNone(output)


@@ -1,43 +1,39 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import unittest
from unittest import TestCase

import torch
import torch.distributed as torch_dist

import mmengine.dist as dist
from mmengine.testing._internal import MultiProcessTestCase


class TestUtils(TestCase):

    def test_get_backend(self):
        self.assertIsNone(dist.get_backend())

    def test_get_world_size(self):
        self.assertEqual(dist.get_world_size(), 1)

    def test_get_rank(self):
        self.assertEqual(dist.get_rank(), 0)

    def test_local_size(self):
        self.assertEqual(dist.get_local_size(), 1)

    def test_local_rank(self):
        self.assertEqual(dist.get_local_rank(), 0)

    def test_get_dist_info(self):
        self.assertEqual(dist.get_dist_info(), (0, 1))

    def test_is_main_process(self):
        self.assertTrue(dist.is_main_process())

    def test_master_only(self):

        @dist.master_only
        def fun():
@@ -45,77 +41,66 @@
        fun()

    def test_barrier(self):
        dist.barrier()  # nothing is done


class TestUtilsWithGLOOBackend(MultiProcessTestCase):

    def _init_dist_env(self, rank, world_size):
        """Initialize the distributed environment."""
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29505'
        os.environ['RANK'] = str(rank)
        torch_dist.init_process_group(
            backend='gloo', rank=rank, world_size=world_size)
        dist.init_local_group(0, world_size)

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def test_get_backend(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_backend(), torch_dist.get_backend())

    def test_get_world_size(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_world_size(), 2)

    def test_get_rank(self):
        self._init_dist_env(self.rank, self.world_size)
        if torch_dist.get_rank() == 0:
            self.assertEqual(dist.get_rank(), 0)
        else:
            self.assertEqual(dist.get_rank(), 1)

    def test_local_size(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_local_size(), 2)

    def test_local_rank(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(
            torch_dist.get_rank(dist.get_local_group()), dist.get_local_rank())

    def test_get_dist_info(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            self.assertEqual(dist.get_dist_info(), (0, 2))
        else:
            self.assertEqual(dist.get_dist_info(), (1, 2))

    def test_is_main_process(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            self.assertTrue(dist.is_main_process())
        else:
            self.assertFalse(dist.is_main_process())

    def test_master_only(self):
        self._init_dist_env(self.rank, self.world_size)

        @dist.master_only
        def fun():
@@ -124,35 +109,69 @@
        fun()


@unittest.skipIf(
    torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl')
class TestUtilsWithNCCLBackend(MultiProcessTestCase):

    def _init_dist_env(self, rank, world_size):
        """Initialize the distributed environment."""
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29505'
        os.environ['RANK'] = str(rank)

        num_gpus = torch.cuda.device_count()
        torch.cuda.set_device(rank % num_gpus)
        torch_dist.init_process_group(
            backend='nccl', rank=rank, world_size=world_size)
        dist.init_local_group(0, world_size)

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def test_get_backend(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_backend(), torch_dist.get_backend())

    def test_get_world_size(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_world_size(), 2)

    def test_get_rank(self):
        self._init_dist_env(self.rank, self.world_size)
        if torch_dist.get_rank() == 0:
            self.assertEqual(dist.get_rank(), 0)
        else:
            self.assertEqual(dist.get_rank(), 1)

    def test_local_size(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_local_size(), 2)

    def test_local_rank(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(
            torch_dist.get_rank(dist.get_local_group()), dist.get_local_rank())

    def test_get_dist_info(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            self.assertEqual(dist.get_dist_info(), (0, 2))
        else:
            self.assertEqual(dist.get_dist_info(), (1, 2))

    def test_is_main_process(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            self.assertTrue(dist.is_main_process())
        else:
            self.assertFalse(dist.is_main_process())

    def test_master_only(self):
        self._init_dist_env(self.rank, self.world_size)

        @dist.master_only
        def fun():
            assert dist.get_rank() == 0

        fun()