mmengine/tests/test_dist/test_utils.py

# Copyright (c) OpenMMLab. All rights reserved.
import os
import unittest
from unittest import TestCase

import numpy as np
import torch
import torch.distributed as torch_dist

import mmengine.dist as dist
from mmengine.testing._internal import MultiProcessTestCase


class TestUtils(TestCase):

    def test_get_backend(self):
        self.assertIsNone(dist.get_backend())

    def test_get_world_size(self):
        self.assertEqual(dist.get_world_size(), 1)

    def test_get_rank(self):
        self.assertEqual(dist.get_rank(), 0)

    def test_local_size(self):
        self.assertEqual(dist.get_local_size(), 1)

    def test_local_rank(self):
        self.assertEqual(dist.get_local_rank(), 0)

    def test_get_dist_info(self):
        self.assertEqual(dist.get_dist_info(), (0, 1))

    def test_is_main_process(self):
        self.assertTrue(dist.is_main_process())

    def test_master_only(self):

        @dist.master_only
        def fun():
            assert dist.get_rank() == 0

        fun()

    def test_barrier(self):
        dist.barrier()  # nothing is done

    def test_get_data_device(self):
        # data is a Tensor
        data = torch.tensor([0, 1])
        self.assertEqual(dist.get_data_device(data), torch.device('cpu'))

        # data is a list of Tensor
        data = [torch.tensor([0, 1]), torch.tensor([2, 3])]
        self.assertEqual(dist.get_data_device(data), torch.device('cpu'))

        # data is a list but not all items are Tensor
        data = [torch.tensor([0, 1]), 123]
        with self.assertRaises(TypeError):
            dist.get_data_device(data)

        # data is a list containing Tensor and a dict
        data = [torch.tensor([0, 1]), {'key': torch.tensor([2, 3])}]
        self.assertEqual(dist.get_data_device(data), torch.device('cpu'))

        # data is a list containing Tensor and a dict but the dict contains
        # an invalid type
        data = [torch.tensor([0, 1]), {'key': '123'}]
        with self.assertRaises(TypeError):
            dist.get_data_device(data)

        # data is an empty list
        with self.assertRaises(ValueError):
            dist.get_data_device([])

        # data is a dict
        data = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([0, 1])}
        self.assertEqual(dist.get_data_device(data), torch.device('cpu'))

        # data is a dict but not all values are Tensor
        data = {'key1': torch.tensor([0, 1]), 'key2': 123}
        with self.assertRaises(TypeError):
            dist.get_data_device(data)

        # data is a dict and one of the values is a list of Tensor
        data = {'key1': torch.tensor([0, 1]), 'key2': [torch.tensor([0, 1])]}
        self.assertEqual(dist.get_data_device(data), torch.device('cpu'))

        # data is a dict and one of the values is an invalid type
        data = {'key1': torch.tensor([0, 1]), 'key2': ['123']}
        with self.assertRaises(TypeError):
            dist.get_data_device(data)

        # data is an empty dict
        with self.assertRaises(ValueError):
            dist.get_data_device({})

        # data is not a valid type
        with self.assertRaisesRegex(
                TypeError,
                'data should be a Tensor, sequence of tensor or dict'):
            dist.get_data_device('123')

    @unittest.skipIf(
        torch.cuda.device_count() == 0, reason='need at least 1 GPU to test')
    def test_cast_data_device(self):
        expected_device = torch.device('cuda', torch.cuda.current_device())

        # data is a Tensor
        data = torch.tensor([0, 1])
        output = dist.cast_data_device(data, expected_device)
        self.assertEqual(output.device, expected_device)

        # data is a Tensor and out is also a Tensor
        data = torch.tensor([0, 1])
        out = torch.tensor([1, 2])
        output = dist.cast_data_device(data, expected_device, out=out)
        self.assertEqual(output.device, expected_device)
        self.assertTrue(torch.allclose(output.cpu(), out))

        # data is a list of Tensor
        data = [torch.tensor([0, 1]), torch.tensor([2, 3])]
        for item in dist.cast_data_device(data, expected_device):
            self.assertEqual(item.device, expected_device)

        # both data and out are lists of Tensor
        data = [torch.tensor([0, 1]), torch.tensor([2, 3])]
        out = [torch.tensor([3, 4]), torch.tensor([5, 6])]
        output = dist.cast_data_device(data, expected_device, out=out)
        for item1, item2 in zip(output, out):
            self.assertEqual(item1.device, expected_device)
            self.assertTrue(torch.allclose(item1.cpu(), item2))

        # data is a list containing a Tensor and a dict
        data = [torch.tensor([0, 1]), {'key': torch.tensor([2, 3])}]
        output = dist.cast_data_device(data, expected_device)
        self.assertEqual(output[0].device, expected_device)
        self.assertEqual(output[1]['key'].device, expected_device)

        # data is a list containing a Tensor and a dict, and so is out
        data = [torch.tensor([0, 1]), {'key': torch.tensor([2, 3])}]
        out = [torch.tensor([3, 4]), {'key': torch.tensor([5, 6])}]
        output = dist.cast_data_device(data, expected_device, out=out)
        self.assertEqual(output[0].device, expected_device)
        self.assertTrue(torch.allclose(output[0].cpu(), out[0]))
        self.assertEqual(output[1]['key'].device, expected_device)
        self.assertTrue(torch.allclose(output[1]['key'].cpu(), out[1]['key']))

        # data is an empty list
        with self.assertRaisesRegex(ValueError, 'data should not be empty'):
            dist.cast_data_device([], expected_device)

        # data is a dict
        data = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([2, 3])}
        output = dist.cast_data_device(data, expected_device)
        for k, v in output.items():
            self.assertEqual(v.device, expected_device)

        # data is a dict, and so is out
        data = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([2, 3])}
        out = {'key1': torch.tensor([3, 4]), 'key2': torch.tensor([5, 6])}
        output = dist.cast_data_device(data, expected_device, out=out)
        for k, v in output.items():
            self.assertEqual(v.device, expected_device)
            self.assertTrue(torch.allclose(v.cpu(), out[k]))

        # the lengths of data and out should be the same
        data = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([2, 3])}
        out = {'key1': torch.tensor([3, 4])}
        with self.assertRaisesRegex(ValueError,
                                    'length of data and out should be same'):
            dist.cast_data_device(data, expected_device, out=out)

        # data is an empty dict
        with self.assertRaisesRegex(ValueError, 'data should not be empty'):
            dist.cast_data_device({}, expected_device)

        # data is a dict and one of the values is a list
        data = {'key1': torch.tensor([0, 1]), 'key2': [torch.tensor([2, 3])]}
        out = {'key1': torch.tensor([3, 4]), 'key2': [torch.tensor([5, 6])]}
        output = dist.cast_data_device(data, expected_device, out=out)
        self.assertEqual(output['key1'].device, expected_device)
        self.assertTrue(torch.allclose(output['key1'].cpu(), out['key1']))
        self.assertEqual(output['key2'][0].device, expected_device)
        self.assertTrue(
            torch.allclose(output['key2'][0].cpu(), out['key2'][0]))

        # data is not a valid type
        with self.assertRaisesRegex(
                TypeError, 'data should be a Tensor, list of tensor or dict'):
            dist.cast_data_device(123, expected_device)

        with self.assertRaisesRegex(
                TypeError, 'data should be a Tensor, list of tensor or dict'):
            dist.cast_data_device('123', expected_device)

        with self.assertRaisesRegex(
                TypeError, 'data should be a Tensor, list of tensor or dict'):
            dist.cast_data_device(np.array([0, 1]), expected_device)

        # data and out are not the same type
        data = torch.tensor([0, 1])
        out = '123'
        with self.assertRaisesRegex(TypeError,
                                    'out should be the same type with data'):
            dist.cast_data_device(data, expected_device, out=out)

        data = {0, 1}
        out = {2, 3}
        with self.assertRaisesRegex(TypeError, 'out should not be a set'):
            dist.cast_data_device(data, expected_device, out=out)


class TestUtilsWithGLOOBackend(MultiProcessTestCase):

    def _init_dist_env(self, rank, world_size):
        """Initialize the distributed environment."""
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29505'
        os.environ['RANK'] = str(rank)
        torch_dist.init_process_group(
            backend='gloo', rank=rank, world_size=world_size)
        dist.init_local_group(0, world_size)

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def test_get_backend(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_backend(), torch_dist.get_backend())

    def test_get_world_size(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_world_size(), 2)

    def test_get_rank(self):
        self._init_dist_env(self.rank, self.world_size)
        if torch_dist.get_rank() == 0:
            self.assertEqual(dist.get_rank(), 0)
        else:
            self.assertEqual(dist.get_rank(), 1)

    def test_local_size(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_local_size(), 2)

    def test_local_rank(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(
            torch_dist.get_rank(dist.get_local_group()),
            dist.get_local_rank())

    def test_get_dist_info(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            self.assertEqual(dist.get_dist_info(), (0, 2))
        else:
            self.assertEqual(dist.get_dist_info(), (1, 2))

    def test_is_main_process(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            self.assertTrue(dist.is_main_process())
        else:
            self.assertFalse(dist.is_main_process())

    def test_master_only(self):
        self._init_dist_env(self.rank, self.world_size)

        @dist.master_only
        def fun():
            assert dist.get_rank() == 0

        fun()

    def test_get_data_device(self):
        self._init_dist_env(self.rank, self.world_size)

        # data is a Tensor
        data = torch.tensor([0, 1])
        self.assertEqual(dist.get_data_device(data), torch.device('cpu'))

        # data is a list of Tensor
        data = [torch.tensor([0, 1]), torch.tensor([2, 3])]
        self.assertEqual(dist.get_data_device(data), torch.device('cpu'))

        # data is a list but not all items are Tensor
        data = [torch.tensor([0, 1]), 123]
        with self.assertRaises(TypeError):
            dist.get_data_device(data)

        # data is a list containing Tensor and a dict
        data = [torch.tensor([0, 1]), {'key': torch.tensor([2, 3])}]
        self.assertEqual(dist.get_data_device(data), torch.device('cpu'))

        # data is a list containing Tensor and a dict but the dict contains
        # an invalid type
        data = [torch.tensor([0, 1]), {'key': '123'}]
        with self.assertRaises(TypeError):
            dist.get_data_device(data)

        # data is an empty list
        with self.assertRaises(ValueError):
            dist.get_data_device([])

        # data is a dict
        data = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([0, 1])}
        self.assertEqual(dist.get_data_device(data), torch.device('cpu'))

        # data is a dict but not all values are Tensor
        data = {'key1': torch.tensor([0, 1]), 'key2': 123}
        with self.assertRaises(TypeError):
            dist.get_data_device(data)

        # data is a dict and one of the values is a list of Tensor
        data = {'key1': torch.tensor([0, 1]), 'key2': [torch.tensor([0, 1])]}
        self.assertEqual(dist.get_data_device(data), torch.device('cpu'))

        # data is a dict and one of the values is an invalid type
        data = {'key1': torch.tensor([0, 1]), 'key2': ['123']}
        with self.assertRaises(TypeError):
            dist.get_data_device(data)

        # data is an empty dict
        with self.assertRaises(ValueError):
            dist.get_data_device({})

        # data is not a valid type
        with self.assertRaisesRegex(
                TypeError,
                'data should be a Tensor, sequence of tensor or dict'):
            dist.get_data_device('123')

    def test_get_comm_device(self):
        self._init_dist_env(self.rank, self.world_size)
        group = dist.get_default_group()
        assert dist.get_comm_device(group) == torch.device('cpu')


@unittest.skipIf(
    torch.cuda.device_count() < 2,
    reason='need at least 2 GPUs to test the NCCL backend')
class TestUtilsWithNCCLBackend(MultiProcessTestCase):

    def _init_dist_env(self, rank, world_size):
        """Initialize the distributed environment."""
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29505'
        os.environ['RANK'] = str(rank)
        num_gpus = torch.cuda.device_count()
        torch.cuda.set_device(rank % num_gpus)
        torch_dist.init_process_group(
            backend='nccl', rank=rank, world_size=world_size)
        dist.init_local_group(0, world_size)

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def test_get_backend(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_backend(), torch_dist.get_backend())

    def test_get_world_size(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_world_size(), 2)

    def test_get_rank(self):
        self._init_dist_env(self.rank, self.world_size)
        if torch_dist.get_rank() == 0:
            self.assertEqual(dist.get_rank(), 0)
        else:
            self.assertEqual(dist.get_rank(), 1)

    def test_local_size(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_local_size(), 2)

    def test_local_rank(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(
            torch_dist.get_rank(dist.get_local_group()),
            dist.get_local_rank())

    def test_get_dist_info(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            self.assertEqual(dist.get_dist_info(), (0, 2))
        else:
            self.assertEqual(dist.get_dist_info(), (1, 2))

    def test_is_main_process(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            self.assertTrue(dist.is_main_process())
        else:
            self.assertFalse(dist.is_main_process())

    def test_master_only(self):
        self._init_dist_env(self.rank, self.world_size)

        @dist.master_only
        def fun():
            assert dist.get_rank() == 0

        fun()

    def test_get_data_device(self):
        self._init_dist_env(self.rank, self.world_size)
        expected_device = torch.device('cuda', torch.cuda.current_device())

        # data is a Tensor
        data = torch.tensor([0, 1]).to(expected_device)
        self.assertEqual(dist.get_data_device(data), expected_device)

        # data is a list of Tensor
        data = [
            torch.tensor([0, 1]).to(expected_device),
            torch.tensor([2, 3]).to(expected_device)
        ]
        self.assertEqual(dist.get_data_device(data), expected_device)

        # data is a list but not all items are Tensor
        data = [torch.tensor([0, 1]).to(expected_device), 123]
        with self.assertRaises(TypeError):
            dist.get_data_device(data)

        # data is a list of Tensor but not all items are on the same device
        data = [
            torch.tensor([0, 1]),
            torch.tensor([2, 3]).to(expected_device)
        ]
        with self.assertRaises(ValueError):
            dist.get_data_device(data)

        # data is a list containing Tensor and a dict
        data = [
            torch.tensor([0, 1]).to(expected_device), {
                'key': torch.tensor([2, 3]).to(expected_device)
            }
        ]
        self.assertEqual(dist.get_data_device(data), expected_device)

        # data is a list containing Tensor and a dict but the dict contains
        # an invalid type
        data = [torch.tensor([0, 1]).to(expected_device), {'key': '123'}]
        with self.assertRaises(TypeError):
            dist.get_data_device(data)

        # data is an empty list
        with self.assertRaises(ValueError):
            dist.get_data_device([])

        # data is a dict
        data = {
            'key1': torch.tensor([0, 1]).to(expected_device),
            'key2': torch.tensor([0, 1]).to(expected_device)
        }
        self.assertEqual(dist.get_data_device(data), expected_device)

        # data is a dict but not all values are Tensor
        data = {'key1': torch.tensor([0, 1]).to(expected_device), 'key2': 123}
        with self.assertRaises(TypeError):
            dist.get_data_device(data)

        # data is a dict but not all values are on the same device
        data = {
            'key1': torch.tensor([0, 1]),
            'key2': torch.tensor([0, 1]).to(expected_device)
        }
        with self.assertRaises(ValueError):
            dist.get_data_device(data)

        # data is a dict and one of the values is a list of Tensor
        data = {
            'key1': torch.tensor([0, 1]).to(expected_device),
            'key2': [torch.tensor([0, 1]).to(expected_device)]
        }
        self.assertEqual(dist.get_data_device(data), expected_device)

        # data is a dict and one of the values is an invalid type
        data = {
            'key1': torch.tensor([0, 1]).to(expected_device),
            'key2': ['123']
        }
        with self.assertRaises(TypeError):
            dist.get_data_device(data)

        # data is an empty dict
        with self.assertRaises(ValueError):
            dist.get_data_device({})

        # data is not a valid type
        with self.assertRaisesRegex(
                TypeError,
                'data should be a Tensor, sequence of tensor or dict'):
            dist.get_data_device('123')

    def test_get_comm_device(self):
        self._init_dist_env(self.rank, self.world_size)
        group = dist.get_default_group()
        expected = torch.device('cuda', torch.cuda.current_device())
        self.assertEqual(dist.get_comm_device(group), expected)

    def test_cast_data_device(self):
        self._init_dist_env(self.rank, self.world_size)
        expected_device = torch.device('cuda', torch.cuda.current_device())

        # data is a Tensor
        data = torch.tensor([0, 1])
        output = dist.cast_data_device(data, expected_device)
        self.assertEqual(output.device, expected_device)

        # data is a Tensor and out is also a Tensor
        data = torch.tensor([0, 1])
        out = torch.tensor([1, 2])
        output = dist.cast_data_device(data, expected_device, out=out)
        self.assertEqual(output.device, expected_device)
        self.assertTrue(torch.allclose(output.cpu(), out))

        # data is a list of Tensor
        data = [torch.tensor([0, 1]), torch.tensor([2, 3])]
        for item in dist.cast_data_device(data, expected_device):
            self.assertEqual(item.device, expected_device)

        # both data and out are lists of Tensor
        data = [torch.tensor([0, 1]), torch.tensor([2, 3])]
        out = [torch.tensor([3, 4]), torch.tensor([5, 6])]
        output = dist.cast_data_device(data, expected_device, out=out)
        for item1, item2 in zip(output, out):
            self.assertEqual(item1.device, expected_device)
            self.assertTrue(torch.allclose(item1.cpu(), item2))

        # data is a list containing a Tensor and a dict
        data = [torch.tensor([0, 1]), {'key': torch.tensor([2, 3])}]
        output = dist.cast_data_device(data, expected_device)
        self.assertEqual(output[0].device, expected_device)
        self.assertEqual(output[1]['key'].device, expected_device)

        # data is a list containing a Tensor and a dict, and so is out
        data = [torch.tensor([0, 1]), {'key': torch.tensor([2, 3])}]
        out = [torch.tensor([3, 4]), {'key': torch.tensor([5, 6])}]
        output = dist.cast_data_device(data, expected_device, out=out)
        self.assertEqual(output[0].device, expected_device)
        self.assertTrue(torch.allclose(output[0].cpu(), out[0]))
        self.assertEqual(output[1]['key'].device, expected_device)
        self.assertTrue(torch.allclose(output[1]['key'].cpu(), out[1]['key']))

        # data is an empty list
        with self.assertRaisesRegex(ValueError, 'data should not be empty'):
            dist.cast_data_device([], expected_device)

        # data is a dict
        data = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([2, 3])}
        output = dist.cast_data_device(data, expected_device)
        for k, v in output.items():
            self.assertEqual(v.device, expected_device)

        # data is a dict, and so is out
        data = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([2, 3])}
        out = {'key1': torch.tensor([3, 4]), 'key2': torch.tensor([5, 6])}
        output = dist.cast_data_device(data, expected_device, out=out)
        for k, v in output.items():
            self.assertEqual(v.device, expected_device)
            self.assertTrue(torch.allclose(v.cpu(), out[k]))

        # the lengths of data and out should be the same
        data = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([2, 3])}
        out = {'key1': torch.tensor([3, 4])}
        with self.assertRaisesRegex(ValueError,
                                    'length of data and out should be same'):
            dist.cast_data_device(data, expected_device, out=out)

        # data is an empty dict
        with self.assertRaisesRegex(ValueError, 'data should not be empty'):
            dist.cast_data_device({}, expected_device)

        # data is a dict and one of the values is a list
        data = {'key1': torch.tensor([0, 1]), 'key2': [torch.tensor([2, 3])]}
        out = {'key1': torch.tensor([3, 4]), 'key2': [torch.tensor([5, 6])]}
        output = dist.cast_data_device(data, expected_device, out=out)
        self.assertEqual(output['key1'].device, expected_device)
        self.assertTrue(torch.allclose(output['key1'].cpu(), out['key1']))
        self.assertEqual(output['key2'][0].device, expected_device)
        self.assertTrue(
            torch.allclose(output['key2'][0].cpu(), out['key2'][0]))

        # data is not a valid type
        with self.assertRaisesRegex(
                TypeError, 'data should be a Tensor, list of tensor or dict'):
            dist.cast_data_device(123, expected_device)

        with self.assertRaisesRegex(
                TypeError, 'data should be a Tensor, list of tensor or dict'):
            dist.cast_data_device('123', expected_device)

        with self.assertRaisesRegex(
                TypeError, 'data should be a Tensor, list of tensor or dict'):
            dist.cast_data_device(np.array([0, 1]), expected_device)

        # data and out are not the same type
        data = torch.tensor([0, 1])
        out = '123'
        with self.assertRaisesRegex(TypeError,
                                    'out should be the same type with data'):
            dist.cast_data_device(data, expected_device, out=out)

        data = {0, 1}
        out = {2, 3}
        with self.assertRaisesRegex(TypeError, 'out should not be a set'):
            dist.cast_data_device(data, expected_device, out=out)