mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix]: fix RuntimeError of SyncBuffersHook (#309)
* fix RuntimeError of SyncBuffersHook * add UT
This commit is contained in:
parent
e18832f046
commit
2b8a32eca0
6
mmengine/dist/dist.py
vendored
6
mmengine/dist/dist.py
vendored
@ -94,7 +94,11 @@ def all_reduce(data: Tensor,
|
|||||||
# it with 'sum' operation.
|
# it with 'sum' operation.
|
||||||
if op.lower() == 'mean':
|
if op.lower() == 'mean':
|
||||||
torch_dist.all_reduce(data_on_device, _get_reduce_op('sum'), group)
|
torch_dist.all_reduce(data_on_device, _get_reduce_op('sum'), group)
|
||||||
data_on_device.div_(world_size) # type: ignore
|
|
||||||
|
# When the type of `data_on_device` is int64,
|
||||||
|
# `data_on_device.div_(world_size)` will appear RuntimeError:
|
||||||
|
# result type Float can't be cast to the desired output type Long.
|
||||||
|
data_on_device = data_on_device / world_size # type: ignore
|
||||||
else:
|
else:
|
||||||
torch_dist.all_reduce(data_on_device, _get_reduce_op(op), group)
|
torch_dist.all_reduce(data_on_device, _get_reduce_op(op), group)
|
||||||
|
|
||||||
|
@ -132,8 +132,9 @@ class TestDistWithGLOOBackend(MultiProcessTestCase):
|
|||||||
|
|
||||||
def test_all_reduce(self):
|
def test_all_reduce(self):
|
||||||
self._init_dist_env(self.rank, self.world_size)
|
self._init_dist_env(self.rank, self.world_size)
|
||||||
for tensor_type, reduce_op in zip([torch.int64, torch.float32],
|
tensor_types = [torch.int64, torch.float32, torch.int64]
|
||||||
['sum', 'mean']):
|
reduce_ops = ['sum', 'mean', 'mean']
|
||||||
|
for tensor_type, reduce_op in zip(tensor_types, reduce_ops):
|
||||||
if dist.get_rank() == 0:
|
if dist.get_rank() == 0:
|
||||||
data = torch.tensor([1, 2], dtype=tensor_type)
|
data = torch.tensor([1, 2], dtype=tensor_type)
|
||||||
else:
|
else:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user