[Enhance] Make sure Tensors to broadcast is contiguous (#948)

* Make sure Tensors to cast is contiguous

* simplify

---------

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
This commit is contained in:
Weihang Xia 2023-02-22 11:41:06 +08:00 committed by GitHub
parent e271454527
commit 8370c1e7f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -307,7 +307,8 @@ def broadcast(data: Tensor,
input_device = get_data_device(data)
backend_device = get_comm_device(group)
data_on_device = cast_data_device(data, backend_device)
# broadcast requires tensor is contiguous
data_on_device = data_on_device.contiguous() # type: ignore
torch_dist.broadcast(data_on_device, src, group)
if get_rank(group) != src: