diff --git a/mmcv/runner/dist_utils.py b/mmcv/runner/dist_utils.py index 45f73c9b0..c061b3c11 100644 --- a/mmcv/runner/dist_utils.py +++ b/mmcv/runner/dist_utils.py @@ -60,7 +60,8 @@ def _init_dist_pytorch(backend: str, **kwargs) -> None: **kwargs) elif IS_NPU_AVAILABLE: import torch_npu # noqa: F401 - torch.npu.set_device(rank) + num_npus = torch.npu.device_count() + torch.npu.set_device(rank % num_npus) dist.init_process_group( backend='hccl', rank=rank,