[Fix]: fix RuntimeError of SyncBuffersHook (#309)
* fix RuntimeError of SyncBuffersHook
* add UT
pull/330/head
parent
e18832f046
commit
2b8a32eca0
|
@ -94,7 +94,11 @@ def all_reduce(data: Tensor,
|
|||
# it with 'sum' operation.
|
||||
if op.lower() == 'mean':
|
||||
torch_dist.all_reduce(data_on_device, _get_reduce_op('sum'), group)
|
||||
data_on_device.div_(world_size) # type: ignore
|
||||
|
||||
# When the type of `data_on_device` is int64,
|
||||
# `data_on_device.div_(world_size)` will raise a RuntimeError:
|
||||
# result type Float can't be cast to the desired output type Long.
|
||||
data_on_device = data_on_device / world_size # type: ignore
|
||||
else:
|
||||
torch_dist.all_reduce(data_on_device, _get_reduce_op(op), group)
|
||||
|
||||
|
|
|
@ -132,8 +132,9 @@ class TestDistWithGLOOBackend(MultiProcessTestCase):
|
|||
|
||||
def test_all_reduce(self):
|
||||
self._init_dist_env(self.rank, self.world_size)
|
||||
for tensor_type, reduce_op in zip([torch.int64, torch.float32],
|
||||
['sum', 'mean']):
|
||||
tensor_types = [torch.int64, torch.float32, torch.int64]
|
||||
reduce_ops = ['sum', 'mean', 'mean']
|
||||
for tensor_type, reduce_op in zip(tensor_types, reduce_ops):
|
||||
if dist.get_rank() == 0:
|
||||
data = torch.tensor([1, 2], dtype=tensor_type)
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue