Mirror of https://github.com/open-mmlab/mmengine.git (synced 2025-06-03 21:54:44 +08:00)
[Enhancement] Refactor the unit tests of dist module with MultiProcessTestCase (#138)
* [Enhancement] Provide MultiProcessTestCase to test distributed related modules
* remove debugging info
* add timeout property
* [Enhancement] Refactor the unit tests of dist module with MultiProcessTestCase
* minor refinement
* minor fix
This commit is contained in:
parent 2d80367893
commit 50650e0b7a
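The refactor replaces module-level `_test_*` functions driven by `mp.spawn` with test classes built on `mmengine.testing._internal.MultiProcessTestCase`: `setUp` spawns one worker process per rank, every test method is re-executed inside each worker, and each test initializes the process group itself from environment variables. Below is a minimal sketch of that pattern; it is not part of the commit, and the class name, port, and test body are illustrative only.

# Illustrative sketch of the MultiProcessTestCase pattern used in this commit.
# The class name and port are assumptions, not code from the diff.
import os

import torch.distributed as torch_dist

import mmengine.dist as dist
from mmengine.testing._internal import MultiProcessTestCase


class ExampleGlooTestCase(MultiProcessTestCase):

    def setUp(self):
        super().setUp()
        # Spawn one worker process per rank; each test method below is then
        # run inside every worker.
        self._spawn_processes()

    def _init_dist_env(self, rank, world_size):
        """Initialize the distributed environment."""
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29510'  # illustrative free port
        os.environ['RANK'] = str(rank)
        torch_dist.init_process_group(
            backend='gloo', rank=rank, world_size=world_size)

    def test_world_size(self):
        # Every worker sets up the group, then checks the collective state.
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_world_size(), self.world_size)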
First changed file (tests for the dist communication ops). The module-level `_test_*_non_dist` / `_test_*_dist` functions and the `init_process` / `main` `mp.spawn` drivers, together with the old `test_non_distributed_env`, `test_gloo_backend`, and `test_nccl_backend` entry points, are removed; the same checks are regrouped into `TestDist` (non-distributed), `TestDistWithGLOOBackend`, and `TestDistWithNCCLBackend`. Pytest-style `assert`s largely become `unittest` assertions, the gloo tests keep tensors on CPU instead of calling `.to(device)`, and the `device='gpu'` case in `TestDist.test_collect_results` now really passes `device='gpu'`. The post-change side of each hunk is reconstructed below (indentation and blank lines approximated from the side-by-side diff).

@@ -2,64 +2,62 @@
import os
import os.path as osp
import tempfile
import unittest
from unittest import TestCase
from unittest.mock import patch

import torch
import torch.distributed as torch_dist

import mmengine.dist as dist
from mmengine.dist.dist import sync_random_seed
from mmengine.testing._internal import MultiProcessTestCase
from mmengine.utils import TORCH_VERSION, digit_version


class TestDist(TestCase):
    """Test dist module in non-distributed environment."""

    def test_all_reduce(self):
        data = torch.arange(2, dtype=torch.int64)
        expected = torch.arange(2, dtype=torch.int64)
        dist.all_reduce(data)
        self.assertTrue(torch.allclose(data, expected))

    def test_all_gather(self):
        data = torch.arange(2, dtype=torch.int64)
        expected = torch.arange(2, dtype=torch.int64)
        output = dist.all_gather(data)
        self.assertTrue(torch.allclose(output[0], expected))

    def test_gather(self):
        data = torch.arange(2, dtype=torch.int64)
        expected = torch.arange(2, dtype=torch.int64)
        output = dist.gather(data)
        self.assertTrue(torch.allclose(output[0], expected))

    def test_broadcast(self):
        data = torch.arange(2, dtype=torch.int64)
        expected = torch.arange(2, dtype=torch.int64)
        dist.broadcast(data)
        self.assertTrue(torch.allclose(data, expected))

    @patch('numpy.random.randint', return_value=10)
    def test_sync_random_seed(self, mock):
        self.assertEqual(sync_random_seed(), 10)

    def test_broadcast_object_list(self):
        with self.assertRaises(AssertionError):
            # input should be list of object
            dist.broadcast_object_list('foo')

        data = ['foo', 12, {1: 2}]
        expected = ['foo', 12, {1: 2}]
        dist.broadcast_object_list(data)
        self.assertEqual(data, expected)

    def test_all_reduce_dict(self):
        with self.assertRaises(AssertionError):
            # input should be dict
            dist.all_reduce_dict('foo')

@@ -73,173 +71,149 @@ def _test_all_reduce_dict_no_dist():
        }
        dist.all_reduce_dict(data)
        for key in data:
            self.assertTrue(torch.allclose(data[key], expected[key]))

    def test_all_gather_object(self):
        data = 'foo'
        expected = 'foo'
        gather_objects = dist.all_gather_object(data)
        self.assertEqual(gather_objects[0], expected)

    def test_gather_object(self):
        data = 'foo'
        expected = 'foo'
        gather_objects = dist.gather_object(data)
        self.assertEqual(gather_objects[0], expected)

    def test_collect_results(self):
        data = ['foo', {1: 2}]
        size = 2
        expected = ['foo', {1: 2}]

        # test `device=cpu`
        output = dist.collect_results(data, size, device='cpu')
        self.assertEqual(output, expected)

        # test `device=gpu`
        output = dist.collect_results(data, size, device='gpu')
        self.assertEqual(output, expected)


class TestDistWithGLOOBackend(MultiProcessTestCase):

    def _init_dist_env(self, rank, world_size):
        """Initialize the distributed environment."""
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29505'
        os.environ['RANK'] = str(rank)

        torch_dist.init_process_group(
            backend='gloo', rank=rank, world_size=world_size)

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def test_all_reduce(self):
        self._init_dist_env(self.rank, self.world_size)
        for tensor_type, reduce_op in zip([torch.int64, torch.float32],
                                          ['sum', 'mean']):
            if dist.get_rank() == 0:
                data = torch.tensor([1, 2], dtype=tensor_type)
            else:
                data = torch.tensor([3, 4], dtype=tensor_type)

            if reduce_op == 'sum':
                expected = torch.tensor([4, 6], dtype=tensor_type)
            else:
                expected = torch.tensor([2, 3], dtype=tensor_type)

            dist.all_reduce(data, reduce_op)
            self.assertTrue(torch.allclose(data, expected))

    def test_all_gather(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            data = torch.tensor([0, 1])
        else:
            data = torch.tensor([1, 2])

        expected = [torch.tensor([0, 1]), torch.tensor([1, 2])]

        output = dist.all_gather(data)
        self.assertTrue(
            torch.allclose(output[dist.get_rank()], expected[dist.get_rank()]))

    def test_gather(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            data = torch.tensor([0, 1])
        else:
            data = torch.tensor([1, 2])

        output = dist.gather(data)

        if dist.get_rank() == 0:
            expected = [torch.tensor([0, 1]), torch.tensor([1, 2])]
            for i in range(2):
                assert torch.allclose(output[i], expected[i])
        else:
            assert output == []

    def test_broadcast_dist(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            data = torch.tensor([0, 1])
        else:
            data = torch.tensor([1, 2])

        expected = torch.tensor([0, 1])
        dist.broadcast(data, 0)
        assert torch.allclose(data, expected)

    def test_sync_random_seed(self):
        self._init_dist_env(self.rank, self.world_size)
        with patch.object(
                torch, 'tensor',
                return_value=torch.tensor(1024)) as mock_tensor:
            output = dist.sync_random_seed()
            assert output == 1024
            mock_tensor.assert_called()

    def test_broadcast_object_list(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            data = ['foo', 12, {1: 2}]
        else:
            data = [None, None, None]

        expected = ['foo', 12, {1: 2}]

        dist.broadcast_object_list(data)
        self.assertEqual(data, expected)

    def test_all_reduce_dict(self):
        self._init_dist_env(self.rank, self.world_size)
        for tensor_type, reduce_op in zip([torch.int64, torch.float32],
                                          ['sum', 'mean']):
            if dist.get_rank() == 0:
                data = {
                    'key1': torch.tensor([0, 1], dtype=tensor_type),
                    'key2': torch.tensor([1, 2], dtype=tensor_type),
                }
            else:
                data = {
                    'key1': torch.tensor([2, 3], dtype=tensor_type),
                    'key2': torch.tensor([3, 4], dtype=tensor_type),
                }

            if reduce_op == 'sum':
                expected = {
                    'key1': torch.tensor([2, 4], dtype=tensor_type),
                    'key2': torch.tensor([4, 6], dtype=tensor_type),
                }
            else:
                expected = {
                    'key1': torch.tensor([1, 2], dtype=tensor_type),
                    'key2': torch.tensor([2, 3], dtype=tensor_type),
                }

            dist.all_reduce_dict(data, reduce_op)

@@ -252,18 +226,18 @@ def _test_all_reduce_dict_dist(device):
        if digit_version(TORCH_VERSION) == digit_version('1.5.0'):
            if dist.get_rank() == 0:
                data = {
                    'key1': torch.tensor([0, 1], dtype=torch.float32),
                    'key2': torch.tensor([1, 2], dtype=torch.int32)
                }
            else:
                data = {
                    'key1': torch.tensor([2, 3], dtype=torch.float32),
                    'key2': torch.tensor([3, 4], dtype=torch.int32),
                }

            expected = {
                'key1': torch.tensor([2, 4], dtype=torch.float32),
                'key2': torch.tensor([4, 6], dtype=torch.float32),
            }

            dist.all_reduce_dict(data, 'sum')

@@ -271,8 +245,8 @@ def _test_all_reduce_dict_dist(device):
            for key in data:
                assert torch.allclose(data[key], expected[key])

    def test_all_gather_object(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            data = 'foo'
        else:

@@ -281,10 +255,10 @@ def _test_all_gather_object_dist(device):
        expected = ['foo', {1: 2}]
        output = dist.all_gather_object(data)

        self.assertEqual(output, expected)

    def test_gather_object(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            data = 'foo'
        else:

@@ -293,12 +267,160 @@ def _test_gather_object_dist(device):
        output = dist.gather_object(data, dst=0)

        if dist.get_rank() == 0:
            self.assertEqual(output, ['foo', {1: 2}])
        else:
            self.assertIsNone(output)


@unittest.skipIf(
    torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl')
class TestDistWithNCCLBackend(MultiProcessTestCase):

    def _init_dist_env(self, rank, world_size):
        """Initialize the distributed environment."""
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29505'
        os.environ['RANK'] = str(rank)

        num_gpus = torch.cuda.device_count()
        torch.cuda.set_device(rank % num_gpus)
        torch_dist.init_process_group(
            backend='nccl', rank=rank, world_size=world_size)

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def test_all_reduce(self):
        self._init_dist_env(self.rank, self.world_size)
        for tensor_type, reduce_op in zip([torch.int64, torch.float32],
                                          ['sum', 'mean']):
            if dist.get_rank() == 0:
                data = torch.tensor([1, 2], dtype=tensor_type).cuda()
            else:
                data = torch.tensor([3, 4], dtype=tensor_type).cuda()

            if reduce_op == 'sum':
                expected = torch.tensor([4, 6], dtype=tensor_type).cuda()
            else:
                expected = torch.tensor([2, 3], dtype=tensor_type).cuda()

            dist.all_reduce(data, reduce_op)
            self.assertTrue(torch.allclose(data, expected))

    def test_all_gather(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            data = torch.tensor([0, 1]).cuda()
        else:
            data = torch.tensor([1, 2]).cuda()

        expected = [torch.tensor([0, 1]).cuda(), torch.tensor([1, 2]).cuda()]

        output = dist.all_gather(data)
        self.assertTrue(
            torch.allclose(output[dist.get_rank()], expected[dist.get_rank()]))

    def test_broadcast_dist(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            data = torch.tensor([0, 1]).cuda()
        else:
            data = torch.tensor([1, 2]).cuda()

        expected = torch.tensor([0, 1]).cuda()
        dist.broadcast(data, 0)
        assert torch.allclose(data, expected)

    def test_sync_random_seed(self):
        self._init_dist_env(self.rank, self.world_size)
        with patch.object(
                torch, 'tensor',
                return_value=torch.tensor(1024)) as mock_tensor:
            output = dist.sync_random_seed()
            assert output == 1024
            mock_tensor.assert_called()

    def test_broadcast_object_list(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            data = ['foo', 12, {1: 2}]
        else:
            data = [None, None, None]

        expected = ['foo', 12, {1: 2}]
        dist.broadcast_object_list(data)
        self.assertEqual(data, expected)

    def test_all_reduce_dict(self):
        self._init_dist_env(self.rank, self.world_size)
        for tensor_type, reduce_op in zip([torch.int64, torch.float32],
                                          ['sum', 'mean']):
            if dist.get_rank() == 0:
                data = {
                    'key1': torch.tensor([0, 1], dtype=tensor_type).cuda(),
                    'key2': torch.tensor([1, 2], dtype=tensor_type).cuda(),
                }
            else:
                data = {
                    'key1': torch.tensor([2, 3], dtype=tensor_type).cuda(),
                    'key2': torch.tensor([3, 4], dtype=tensor_type).cuda(),
                }

            if reduce_op == 'sum':
                expected = {
                    'key1': torch.tensor([2, 4], dtype=tensor_type).cuda(),
                    'key2': torch.tensor([4, 6], dtype=tensor_type).cuda(),
                }
            else:
                expected = {
                    'key1': torch.tensor([1, 2], dtype=tensor_type).cuda(),
                    'key2': torch.tensor([2, 3], dtype=tensor_type).cuda(),
                }

            dist.all_reduce_dict(data, reduce_op)

            for key in data:
                assert torch.allclose(data[key], expected[key])

        # `torch.cat` in torch1.5 can not concatenate different types so we
        # fallback to convert them all to float type.
        if digit_version(TORCH_VERSION) == digit_version('1.5.0'):
            if dist.get_rank() == 0:
                data = {
                    'key1': torch.tensor([0, 1], dtype=torch.float32).cuda(),
                    'key2': torch.tensor([1, 2], dtype=torch.int32).cuda(),
                }
            else:
                data = {
                    'key1': torch.tensor([2, 3], dtype=torch.float32).cuda(),
                    'key2': torch.tensor([3, 4], dtype=torch.int32).cuda(),
                }

            expected = {
                'key1': torch.tensor([2, 4], dtype=torch.float32).cuda(),
                'key2': torch.tensor([4, 6], dtype=torch.float32).cuda(),
            }

            dist.all_reduce_dict(data, 'sum')

            for key in data:
                assert torch.allclose(data[key], expected[key])

    def test_all_gather_object(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            data = 'foo'
        else:
            data = {1: 2}

        expected = ['foo', {1: 2}]
        output = dist.all_gather_object(data)

        self.assertEqual(output, expected)

    def test_collect_results(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            data = ['foo', {1: 2}]
        else:

@@ -311,9 +433,9 @@ def _test_collect_results_dist(device):
        # test `device=cpu`
        output = dist.collect_results(data, size, device='cpu')
        if dist.get_rank() == 0:
            self.assertEqual(output, expected)
        else:
            self.assertIsNone(output)

        # test `device=cpu` and `tmpdir is not None`
        tmpdir = tempfile.mkdtemp()

@@ -323,61 +445,17 @@ def _test_collect_results_dist(device):
        output = dist.collect_results(
            data, size, device='cpu', tmpdir=object_list[0])
        if dist.get_rank() == 0:
            self.assertEqual(output, expected)
        else:
            self.assertIsNone(output)

        if dist.get_rank() == 0:
            # object_list[0] will be removed by `dist.collect_results`
            self.assertFalse(osp.exists(object_list[0]))

        # test `device=gpu`
        output = dist.collect_results(data, size, device='gpu')
        if dist.get_rank() == 0:
            self.assertEqual(output, expected)
        else:
            self.assertIsNone(output)
Second changed file (tests for the dist utility functions). The same refactor applies: the module-level `_test_*` functions and the `init_process` / `main` `mp.spawn` drivers are replaced by `TestUtils` (non-distributed), `TestUtilsWithGLOOBackend`, and `TestUtilsWithNCCLBackend`, and `MASTER_PORT` changes from `29501` to `29505`. The post-change side of each hunk is reconstructed below.

@@ -1,43 +1,39 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import unittest
from unittest import TestCase

import torch
import torch.distributed as torch_dist

import mmengine.dist as dist
from mmengine.testing._internal import MultiProcessTestCase


class TestUtils(TestCase):

    def test_get_backend(self):
        self.assertIsNone(dist.get_backend())

    def test_get_world_size(self):
        self.assertEqual(dist.get_world_size(), 1)

    def test_get_rank(self):
        self.assertEqual(dist.get_rank(), 0)

    def test_local_size(self):
        self.assertEqual(dist.get_local_size(), 1)

    def test_local_rank(self):
        self.assertEqual(dist.get_local_rank(), 0)

    def test_get_dist_info(self):
        self.assertEqual(dist.get_dist_info(), (0, 1))

    def test_is_main_process(self):
        self.assertTrue(dist.is_main_process())

    def test_master_only(self):

        @dist.master_only
        def fun():

@@ -45,77 +41,66 @@ def _test_master_only_non_dist():

        fun()

    def test_barrier(self):
        dist.barrier()  # nothing is done


class TestUtilsWithGLOOBackend(MultiProcessTestCase):

    def _init_dist_env(self, rank, world_size):
        """Initialize the distributed environment."""
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29505'
        os.environ['RANK'] = str(rank)

        torch_dist.init_process_group(
            backend='gloo', rank=rank, world_size=world_size)
        dist.init_local_group(0, world_size)

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def test_get_backend(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_backend(), torch_dist.get_backend())

    def test_get_world_size(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_world_size(), 2)

    def test_get_rank(self):
        self._init_dist_env(self.rank, self.world_size)
        if torch_dist.get_rank() == 0:
            self.assertEqual(dist.get_rank(), 0)
        else:
            self.assertEqual(dist.get_rank(), 1)

    def test_local_size(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_local_size(), 2)

    def test_local_rank(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(
            torch_dist.get_rank(dist.get_local_group()), dist.get_local_rank())

    def test_get_dist_info(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            self.assertEqual(dist.get_dist_info(), (0, 2))
        else:
            self.assertEqual(dist.get_dist_info(), (1, 2))

    def test_is_main_process(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            self.assertTrue(dist.is_main_process())
        else:
            self.assertFalse(dist.is_main_process())

    def test_master_only(self):
        self._init_dist_env(self.rank, self.world_size)

        @dist.master_only
        def fun():

@@ -124,35 +109,69 @@ def _test_master_only_dist():
        fun()


@unittest.skipIf(
    torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl')
class TestUtilsWithNCCLBackend(MultiProcessTestCase):

    def _init_dist_env(self, rank, world_size):
        """Initialize the distributed environment."""
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29505'
        os.environ['RANK'] = str(rank)

        num_gpus = torch.cuda.device_count()
        torch.cuda.set_device(rank % num_gpus)
        torch_dist.init_process_group(
            backend='nccl', rank=rank, world_size=world_size)
        dist.init_local_group(0, world_size)

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def test_get_backend(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_backend(), torch_dist.get_backend())

    def test_get_world_size(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_world_size(), 2)

    def test_get_rank(self):
        self._init_dist_env(self.rank, self.world_size)
        if torch_dist.get_rank() == 0:
            self.assertEqual(dist.get_rank(), 0)
        else:
            self.assertEqual(dist.get_rank(), 1)

    def test_local_size(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_local_size(), 2)

    def test_local_rank(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(
            torch_dist.get_rank(dist.get_local_group()), dist.get_local_rank())

    def test_get_dist_info(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            self.assertEqual(dist.get_dist_info(), (0, 2))
        else:
            self.assertEqual(dist.get_dist_info(), (1, 2))

    def test_is_main_process(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            self.assertTrue(dist.is_main_process())
        else:
            self.assertFalse(dist.is_main_process())

    def test_master_only(self):
        self._init_dist_env(self.rank, self.world_size)

        @dist.master_only
        def fun():
            assert dist.get_rank() == 0

        fun()
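Since the refactored suites are plain `unittest` test cases, a single backend's tests can be run on their own. A minimal sketch, assuming the second file above is importable as `test_dist_utils` (a hypothetical module name; this page does not show the real path):

# Hypothetical runner; `test_dist_utils` stands in for the real module name.
import unittest

import test_dist_utils

if __name__ == '__main__':
    # Load only the gloo-backed utility tests and run them with verbose output.
    suite = unittest.TestLoader().loadTestsFromTestCase(
        test_dist_utils.TestUtilsWithGLOOBackend)
    unittest.TextTestRunner(verbosity=2).run(suite)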