# Copyright (c) OpenMMLab. All rights reserved.
import os
import unittest
from unittest import TestCase

import numpy as np
import torch
import torch.distributed as torch_dist

import mmengine.dist as dist
from mmengine.testing._internal import MultiProcessTestCase

class TestUtils(TestCase):
    """Test ``mmengine.dist`` utility functions without a process group.

    When no distributed environment is initialized, every helper should
    fall back to single-process defaults: no backend, rank 0, world size 1.
    """

    def test_get_backend(self):
        # No process group is initialized, so no backend is reported.
        self.assertIsNone(dist.get_backend())

    def test_get_world_size(self):
        self.assertEqual(dist.get_world_size(), 1)

    def test_get_rank(self):
        self.assertEqual(dist.get_rank(), 0)

    def test_local_size(self):
        self.assertEqual(dist.get_local_size(), 1)

    def test_local_rank(self):
        self.assertEqual(dist.get_local_rank(), 0)

    def test_get_dist_info(self):
        # (rank, world_size) tuple.
        self.assertEqual(dist.get_dist_info(), (0, 1))

    def test_is_main_process(self):
        # The single process is always the main process.
        self.assertTrue(dist.is_main_process())

    def test_master_only(self):
        """A master_only-decorated function runs on rank 0 (the only rank)."""

        @dist.master_only
        def fun():
            assert dist.get_rank() == 0

        fun()

    def test_barrier(self):
        dist.barrier()  # nothing is done

    def test_get_data_device(self):
        """get_data_device should resolve nested containers of tensors."""
        # data is a Tensor
        data = torch.tensor([0, 1])
        self.assertEqual(dist.get_data_device(data), torch.device('cpu'))

        # data is a list of Tensor
        data = [torch.tensor([0, 1]), torch.tensor([2, 3])]
        self.assertEqual(dist.get_data_device(data), torch.device('cpu'))

        # data is a list but not all items are Tensor
        data = [torch.tensor([0, 1]), 123]
        with self.assertRaises(TypeError):
            dist.get_data_device(data)

        # data is a list containing Tensor and a dict
        data = [torch.tensor([0, 1]), {'key': torch.tensor([2, 3])}]
        self.assertEqual(dist.get_data_device(data), torch.device('cpu'))

        # data is a list containing Tensor and a dict but the dict contains
        # invalid type
        data = [torch.tensor([0, 1]), {'key': '123'}]
        with self.assertRaises(TypeError):
            dist.get_data_device(data)

        # data is an empty list
        with self.assertRaises(ValueError):
            dist.get_data_device([])

        # data is a dict
        data = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([0, 1])}
        self.assertEqual(dist.get_data_device(data), torch.device('cpu'))

        # data is a dict but not all values are Tensor
        data = {'key1': torch.tensor([0, 1]), 'key2': 123}
        with self.assertRaises(TypeError):
            dist.get_data_device(data)

        # data is a dict and one of values is list of Tensor
        data = {'key1': torch.tensor([0, 1]), 'key2': [torch.tensor([0, 1])]}
        self.assertEqual(dist.get_data_device(data), torch.device('cpu'))

        # data is a dict and one of values is an invalid type
        data = {'key1': torch.tensor([0, 1]), 'key2': ['123']}
        with self.assertRaises(TypeError):
            dist.get_data_device(data)

        # data is an empty dict
        with self.assertRaises(ValueError):
            dist.get_data_device({})

        # data is not a valid type
        with self.assertRaisesRegex(
                TypeError,
                'data should be a Tensor, sequence of tensor or dict'):
            dist.get_data_device('123')

    @unittest.skipIf(
        torch.cuda.device_count() == 0, reason='at least need 1 gpu to test')
    def test_cast_data_device(self):
        """cast_data_device should move (nested) tensors to the target device
        and optionally copy the original values into ``out``."""
        expected_device = torch.device('cuda', torch.cuda.current_device())
        # data is a Tensor
        data = torch.tensor([0, 1])
        output = dist.cast_data_device(data, expected_device)
        self.assertEqual(output.device, expected_device)

        # data is a Tensor and out is also a Tensor
        data = torch.tensor([0, 1])
        out = torch.tensor([1, 2])
        output = dist.cast_data_device(data, expected_device, out=out)
        self.assertEqual(output.device, expected_device)
        self.assertTrue(torch.allclose(output.cpu(), out))

        # data is a list of Tensor
        data = [torch.tensor([0, 1]), torch.tensor([2, 3])]
        for item in dist.cast_data_device(data, expected_device):
            self.assertEqual(item.device, expected_device)

        # both data and out are list of tensor
        data = [torch.tensor([0, 1]), torch.tensor([2, 3])]
        out = [torch.tensor([3, 4]), torch.tensor([5, 6])]
        output = dist.cast_data_device(data, expected_device, out=out)
        for item1, item2 in zip(output, out):
            self.assertEqual(item1.device, expected_device)
            self.assertTrue(torch.allclose(item1.cpu(), item2))

        # data is a list containing a Tensor and a dict
        data = [torch.tensor([0, 1]), {'key': torch.tensor([2, 3])}]
        output = dist.cast_data_device(data, expected_device)
        self.assertEqual(output[0].device, expected_device)
        self.assertEqual(output[1]['key'].device, expected_device)

        # data is a list containing a Tensor and a dict, so does out
        data = [torch.tensor([0, 1]), {'key': torch.tensor([2, 3])}]
        out = [torch.tensor([3, 4]), {'key': torch.tensor([5, 6])}]
        output = dist.cast_data_device(data, expected_device, out=out)
        self.assertEqual(output[0].device, expected_device)
        self.assertTrue(torch.allclose(output[0].cpu(), out[0]))
        self.assertEqual(output[1]['key'].device, expected_device)
        self.assertTrue(torch.allclose(output[1]['key'].cpu(), out[1]['key']))

        # data is an empty list
        with self.assertRaisesRegex(ValueError, 'data should not be empty'):
            dist.cast_data_device([], expected_device)

        # data is a dict
        data = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([2, 3])}
        output = dist.cast_data_device(data, expected_device)
        for k, v in output.items():
            self.assertEqual(v.device, expected_device)

        # data is a dict, so does out
        data = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([2, 3])}
        out = {'key1': torch.tensor([3, 4]), 'key2': torch.tensor([5, 6])}
        output = dist.cast_data_device(data, expected_device, out=out)
        for k, v in output.items():
            self.assertEqual(v.device, expected_device)
            self.assertTrue(torch.allclose(v.cpu(), out[k]))

        # the length of data and out should be same
        data = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([2, 3])}
        out = {'key1': torch.tensor([3, 4])}
        with self.assertRaisesRegex(ValueError,
                                    'length of data and out should be same'):
            dist.cast_data_device(data, expected_device, out=out)

        # data is an empty dict
        with self.assertRaisesRegex(ValueError, 'data should not be empty'):
            dist.cast_data_device({}, expected_device)

        # data is a dict and one of values is list
        data = {'key1': torch.tensor([0, 1]), 'key2': [torch.tensor([2, 3])]}
        out = {'key1': torch.tensor([3, 4]), 'key2': [torch.tensor([5, 6])]}
        output = dist.cast_data_device(data, expected_device, out=out)
        self.assertEqual(output['key1'].device, expected_device)
        self.assertTrue(torch.allclose(output['key1'].cpu(), out['key1']))
        self.assertEqual(output['key2'][0].device, expected_device)
        self.assertTrue(
            torch.allclose(output['key2'][0].cpu(), out['key2'][0]))

        # data is not a valid type
        with self.assertRaisesRegex(
                TypeError, 'data should be a Tensor, list of tensor or dict'):
            dist.cast_data_device(123, expected_device)

        with self.assertRaisesRegex(
                TypeError, 'data should be a Tensor, list of tensor or dict'):
            dist.cast_data_device('123', expected_device)

        with self.assertRaisesRegex(
                TypeError, 'data should be a Tensor, list of tensor or dict'):
            dist.cast_data_device(np.array([0, 1]), expected_device)

        # data and out are not the same type
        data = torch.tensor([0, 1])
        out = '123'
        with self.assertRaisesRegex(TypeError,
                                    'out should be the same type with data'):
            dist.cast_data_device(data, expected_device, out=out)

        data = {0, 1}
        out = {2, 3}
        with self.assertRaisesRegex(TypeError, 'out should not be a set'):
            dist.cast_data_device(data, expected_device, out=out)


class TestUtilsWithGLOOBackend(MultiProcessTestCase):
    """Exercise ``mmengine.dist`` utilities across two gloo processes."""

    def _init_dist_env(self, rank, world_size):
        """Initialize the distributed environment."""
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29505'
        os.environ['RANK'] = str(rank)

        torch_dist.init_process_group(
            backend='gloo', rank=rank, world_size=world_size)
        dist.init_local_group(0, world_size)

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def test_get_backend(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_backend(), torch_dist.get_backend())

    def test_get_world_size(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_world_size(), 2)

    def test_get_rank(self):
        self._init_dist_env(self.rank, self.world_size)
        # mmengine's rank must mirror the rank reported by torch.
        expected_rank = 0 if torch_dist.get_rank() == 0 else 1
        self.assertEqual(dist.get_rank(), expected_rank)

    def test_local_size(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_local_size(), 2)

    def test_local_rank(self):
        self._init_dist_env(self.rank, self.world_size)
        local_group = dist.get_local_group()
        self.assertEqual(
            torch_dist.get_rank(local_group), dist.get_local_rank())

    def test_get_dist_info(self):
        self._init_dist_env(self.rank, self.world_size)
        expected_info = (0, 2) if dist.get_rank() == 0 else (1, 2)
        self.assertEqual(dist.get_dist_info(), expected_info)

    def test_is_main_process(self):
        self._init_dist_env(self.rank, self.world_size)
        # Only rank 0 is the main process.
        if dist.get_rank() == 0:
            self.assertTrue(dist.is_main_process())
        else:
            self.assertFalse(dist.is_main_process())

    def test_master_only(self):
        self._init_dist_env(self.rank, self.world_size)

        @dist.master_only
        def fun():
            assert dist.get_rank() == 0

        fun()

    def test_get_data_device(self):
        self._init_dist_env(self.rank, self.world_size)

        cpu = torch.device('cpu')

        # a bare Tensor resolves to its own device
        sample = torch.tensor([0, 1])
        self.assertEqual(dist.get_data_device(sample), cpu)

        # a list of Tensors on the same device
        sample = [torch.tensor([0, 1]), torch.tensor([2, 3])]
        self.assertEqual(dist.get_data_device(sample), cpu)

        # a list with a non-Tensor item is rejected
        sample = [torch.tensor([0, 1]), 123]
        with self.assertRaises(TypeError):
            dist.get_data_device(sample)

        # a list mixing a Tensor and a dict of Tensors is fine
        sample = [torch.tensor([0, 1]), {'key': torch.tensor([2, 3])}]
        self.assertEqual(dist.get_data_device(sample), cpu)

        # ... but an invalid value inside the nested dict is rejected
        sample = [torch.tensor([0, 1]), {'key': '123'}]
        with self.assertRaises(TypeError):
            dist.get_data_device(sample)

        # empty list is rejected
        with self.assertRaises(ValueError):
            dist.get_data_device([])

        # a dict of Tensors
        sample = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([0, 1])}
        self.assertEqual(dist.get_data_device(sample), cpu)

        # a dict with a non-Tensor value is rejected
        sample = {'key1': torch.tensor([0, 1]), 'key2': 123}
        with self.assertRaises(TypeError):
            dist.get_data_device(sample)

        # a dict whose value is a list of Tensors is fine
        sample = {'key1': torch.tensor([0, 1]), 'key2': [torch.tensor([0, 1])]}
        self.assertEqual(dist.get_data_device(sample), cpu)

        # ... but an invalid item inside that list is rejected
        sample = {'key1': torch.tensor([0, 1]), 'key2': ['123']}
        with self.assertRaises(TypeError):
            dist.get_data_device(sample)

        # empty dict is rejected
        with self.assertRaises(ValueError):
            dist.get_data_device({})

        # completely unsupported type
        with self.assertRaisesRegex(
                TypeError,
                'data should be a Tensor, sequence of tensor or dict'):
            dist.get_data_device('123')

    def test_get_comm_device(self):
        self._init_dist_env(self.rank, self.world_size)
        default_group = dist.get_default_group()
        assert dist.get_comm_device(default_group) == torch.device('cpu')


@unittest.skipIf(
    torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl')
class TestUtilsWithNCCLBackend(MultiProcessTestCase):
    """Exercise ``mmengine.dist`` utilities across two nccl processes."""

    def _init_dist_env(self, rank, world_size):
        """Initialize the distributed environment."""
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29505'
        os.environ['RANK'] = str(rank)

        # Pin each process to its own GPU before creating the group.
        num_gpus = torch.cuda.device_count()
        torch.cuda.set_device(rank % num_gpus)
        torch_dist.init_process_group(
            backend='nccl', rank=rank, world_size=world_size)
        dist.init_local_group(0, world_size)

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def test_get_backend(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_backend(), torch_dist.get_backend())

    def test_get_world_size(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_world_size(), 2)

    def test_get_rank(self):
        self._init_dist_env(self.rank, self.world_size)
        # mmengine's rank must mirror the rank reported by torch.
        expected_rank = 0 if torch_dist.get_rank() == 0 else 1
        self.assertEqual(dist.get_rank(), expected_rank)

    def test_local_size(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_local_size(), 2)

    def test_local_rank(self):
        self._init_dist_env(self.rank, self.world_size)
        local_group = dist.get_local_group()
        self.assertEqual(
            torch_dist.get_rank(local_group), dist.get_local_rank())

    def test_get_dist_info(self):
        self._init_dist_env(self.rank, self.world_size)
        expected_info = (0, 2) if dist.get_rank() == 0 else (1, 2)
        self.assertEqual(dist.get_dist_info(), expected_info)

    def test_is_main_process(self):
        self._init_dist_env(self.rank, self.world_size)
        # Only rank 0 is the main process.
        if dist.get_rank() == 0:
            self.assertTrue(dist.is_main_process())
        else:
            self.assertFalse(dist.is_main_process())

    def test_master_only(self):
        self._init_dist_env(self.rank, self.world_size)

        @dist.master_only
        def fun():
            assert dist.get_rank() == 0

        fun()

    def test_get_data_device(self):
        self._init_dist_env(self.rank, self.world_size)

        device = torch.device('cuda', torch.cuda.current_device())

        # a bare Tensor resolves to its own (cuda) device
        sample = torch.tensor([0, 1]).to(device)
        self.assertEqual(dist.get_data_device(sample), device)

        # a list of Tensors, all on the current cuda device
        sample = [
            torch.tensor([0, 1]).to(device),
            torch.tensor([2, 3]).to(device)
        ]
        self.assertEqual(dist.get_data_device(sample), device)

        # a list with a non-Tensor item is rejected
        sample = [torch.tensor([0, 1]).to(device), 123]
        with self.assertRaises(TypeError):
            dist.get_data_device(sample)

        # items living on different devices are rejected
        sample = [torch.tensor([0, 1]), torch.tensor([2, 3]).to(device)]
        with self.assertRaises(ValueError):
            dist.get_data_device(sample)

        # a list mixing a Tensor and a dict of Tensors is fine
        sample = [
            torch.tensor([0, 1]).to(device), {
                'key': torch.tensor([2, 3]).to(device)
            }
        ]
        self.assertEqual(dist.get_data_device(sample), device)

        # ... but an invalid value inside the nested dict is rejected
        sample = [torch.tensor([0, 1]).to(device), {'key': '123'}]
        with self.assertRaises(TypeError):
            dist.get_data_device(sample)

        # empty list is rejected
        with self.assertRaises(ValueError):
            dist.get_data_device([])

        # a dict of Tensors
        sample = {
            'key1': torch.tensor([0, 1]).to(device),
            'key2': torch.tensor([0, 1]).to(device)
        }
        self.assertEqual(dist.get_data_device(sample), device)

        # a dict with a non-Tensor value is rejected
        sample = {'key1': torch.tensor([0, 1]).to(device), 'key2': 123}
        with self.assertRaises(TypeError):
            dist.get_data_device(sample)

        # values living on different devices are rejected
        sample = {
            'key1': torch.tensor([0, 1]),
            'key2': torch.tensor([0, 1]).to(device)
        }
        with self.assertRaises(ValueError):
            dist.get_data_device(sample)

        # a dict whose value is a list of Tensors is fine
        sample = {
            'key1': torch.tensor([0, 1]).to(device),
            'key2': [torch.tensor([0, 1]).to(device)]
        }
        self.assertEqual(dist.get_data_device(sample), device)

        # ... but an invalid item inside that list is rejected
        sample = {
            'key1': torch.tensor([0, 1]).to(device),
            'key2': ['123']
        }
        with self.assertRaises(TypeError):
            dist.get_data_device(sample)

        # empty dict is rejected
        with self.assertRaises(ValueError):
            dist.get_data_device({})

        # completely unsupported type
        with self.assertRaisesRegex(
                TypeError,
                'data should be a Tensor, sequence of tensor or dict'):
            dist.get_data_device('123')

    def test_get_comm_device(self):
        self._init_dist_env(self.rank, self.world_size)
        default_group = dist.get_default_group()
        expected = torch.device('cuda', torch.cuda.current_device())
        self.assertEqual(dist.get_comm_device(default_group), expected)

    def test_cast_data_device(self):
        self._init_dist_env(self.rank, self.world_size)

        device = torch.device('cuda', torch.cuda.current_device())

        # a bare Tensor is moved to the target device
        sample = torch.tensor([0, 1])
        result = dist.cast_data_device(sample, device)
        self.assertEqual(result.device, device)

        # with out=..., the original values are copied into ``buf``
        sample = torch.tensor([0, 1])
        buf = torch.tensor([1, 2])
        result = dist.cast_data_device(sample, device, out=buf)
        self.assertEqual(result.device, device)
        self.assertTrue(torch.allclose(result.cpu(), buf))

        # every item of a list is moved
        sample = [torch.tensor([0, 1]), torch.tensor([2, 3])]
        for item in dist.cast_data_device(sample, device):
            self.assertEqual(item.device, device)

        # list data with a matching list ``out``
        sample = [torch.tensor([0, 1]), torch.tensor([2, 3])]
        buf = [torch.tensor([3, 4]), torch.tensor([5, 6])]
        result = dist.cast_data_device(sample, device, out=buf)
        for moved, copied in zip(result, buf):
            self.assertEqual(moved.device, device)
            self.assertTrue(torch.allclose(moved.cpu(), copied))

        # a list mixing a Tensor and a dict is handled recursively
        sample = [torch.tensor([0, 1]), {'key': torch.tensor([2, 3])}]
        result = dist.cast_data_device(sample, device)
        self.assertEqual(result[0].device, device)
        self.assertEqual(result[1]['key'].device, device)

        # same nested structure for both data and out
        sample = [torch.tensor([0, 1]), {'key': torch.tensor([2, 3])}]
        buf = [torch.tensor([3, 4]), {'key': torch.tensor([5, 6])}]
        result = dist.cast_data_device(sample, device, out=buf)
        self.assertEqual(result[0].device, device)
        self.assertTrue(torch.allclose(result[0].cpu(), buf[0]))
        self.assertEqual(result[1]['key'].device, device)
        self.assertTrue(torch.allclose(result[1]['key'].cpu(), buf[1]['key']))

        # empty list is rejected
        with self.assertRaisesRegex(ValueError, 'data should not be empty'):
            dist.cast_data_device([], device)

        # every value of a dict is moved
        sample = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([2, 3])}
        result = dist.cast_data_device(sample, device)
        for k, v in result.items():
            self.assertEqual(v.device, device)

        # dict data with a matching dict ``out``
        sample = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([2, 3])}
        buf = {'key1': torch.tensor([3, 4]), 'key2': torch.tensor([5, 6])}
        result = dist.cast_data_device(sample, device, out=buf)
        for k, v in result.items():
            self.assertEqual(v.device, device)
            self.assertTrue(torch.allclose(v.cpu(), buf[k]))

        # data and out must have the same number of entries
        sample = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([2, 3])}
        buf = {'key1': torch.tensor([3, 4])}
        with self.assertRaisesRegex(ValueError,
                                    'length of data and out should be same'):
            dist.cast_data_device(sample, device, out=buf)

        # empty dict is rejected
        with self.assertRaisesRegex(ValueError, 'data should not be empty'):
            dist.cast_data_device({}, device)

        # a dict whose value is a list is handled recursively
        sample = {'key1': torch.tensor([0, 1]), 'key2': [torch.tensor([2, 3])]}
        buf = {'key1': torch.tensor([3, 4]), 'key2': [torch.tensor([5, 6])]}
        result = dist.cast_data_device(sample, device, out=buf)
        self.assertEqual(result['key1'].device, device)
        self.assertTrue(torch.allclose(result['key1'].cpu(), buf['key1']))
        self.assertEqual(result['key2'][0].device, device)
        self.assertTrue(
            torch.allclose(result['key2'][0].cpu(), buf['key2'][0]))

        # unsupported data types are rejected
        with self.assertRaisesRegex(
                TypeError, 'data should be a Tensor, list of tensor or dict'):
            dist.cast_data_device(123, device)

        with self.assertRaisesRegex(
                TypeError, 'data should be a Tensor, list of tensor or dict'):
            dist.cast_data_device('123', device)

        with self.assertRaisesRegex(
                TypeError, 'data should be a Tensor, list of tensor or dict'):
            dist.cast_data_device(np.array([0, 1]), device)

        # data and out must be the same container type
        sample = torch.tensor([0, 1])
        buf = '123'
        with self.assertRaisesRegex(TypeError,
                                    'out should be the same type with data'):
            dist.cast_data_device(sample, device, out=buf)

        # sets are explicitly unsupported as ``out``
        sample = {0, 1}
        buf = {2, 3}
        with self.assertRaisesRegex(TypeError, 'out should not be a set'):
            dist.cast_data_device(sample, device, out=buf)