[Fix]: fix RuntimeError of SyncBuffersHook (#309)

* fix RuntimeError of SyncBuffersHook

* add UT
pull/330/head
Haian Huang(深度眸) 2022-06-22 20:00:46 +08:00 committed by GitHub
parent e18832f046
commit 2b8a32eca0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 3 deletions

View File

@@ -94,7 +94,11 @@ def all_reduce(data: Tensor,
# it with 'sum' operation.
if op.lower() == 'mean':
torch_dist.all_reduce(data_on_device, _get_reduce_op('sum'), group)
data_on_device.div_(world_size) # type: ignore
# When the type of `data_on_device` is int64,
# `data_on_device.div_(world_size)` will raise a RuntimeError:
# result type Float can't be cast to the desired output type Long.
data_on_device = data_on_device / world_size # type: ignore
else:
torch_dist.all_reduce(data_on_device, _get_reduce_op(op), group)

View File

@@ -132,8 +132,9 @@ class TestDistWithGLOOBackend(MultiProcessTestCase):
def test_all_reduce(self):
self._init_dist_env(self.rank, self.world_size)
for tensor_type, reduce_op in zip([torch.int64, torch.float32],
['sum', 'mean']):
tensor_types = [torch.int64, torch.float32, torch.int64]
reduce_ops = ['sum', 'mean', 'mean']
for tensor_type, reduce_op in zip(tensor_types, reduce_ops):
if dist.get_rank() == 0:
data = torch.tensor([1, 2], dtype=tensor_type)
else: