[Fix] Fix multi card issue in PyTorch v2.1 on Ascend (#1321)

This commit is contained in:
LRJKD 2023-08-25 10:35:58 +08:00 committed by GitHub
parent e1c6079d73
commit a53c2802a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -416,7 +416,7 @@ def _broadcast_object_list(object_list: List[Any],
is_hccl_backend = group_backend == 'hccl'
is_cncl_backend = group_backend == 'cncl'
if is_hccl_backend:
current_device = torch.npu.current_device()
current_device = torch.device('npu', torch.npu.current_device())
object_sizes_tensor = object_sizes_tensor.to(current_device)
elif is_cncl_backend:
current_device = torch.device('mlu', torch.mlu.current_device())