[Fix] Support multi-node distributed training with NPU backend (#1459)

lanzeshun authored 2023-12-26 16:14:45 +08:00 (committed by GitHub)
parent 671f3bcdf4
commit 8e6fb12b1f


@@ -99,9 +99,10 @@ def _init_dist_pytorch(backend, init_backend='torch', **kwargs) -> None:
         **kwargs: keyword arguments are passed to ``init_process_group``.
     """
     rank = int(os.environ['RANK'])
+    # LOCAL_RANK is set by `torch.distributed.launch` since PyTorch 1.1
+    local_rank = int(os.environ['LOCAL_RANK'])
     if is_mlu_available():
         import torch_mlu  # noqa: F401
-        local_rank = int(os.environ['LOCAL_RANK'])
         torch.mlu.set_device(local_rank)
         torch_dist.init_process_group(
             backend='cncl',
@@ -110,15 +111,13 @@ def _init_dist_pytorch(backend, init_backend='torch', **kwargs) -> None:
             **kwargs)
     elif is_npu_available():
         import torch_npu  # noqa: F401
-        torch.npu.set_device(rank)
+        torch.npu.set_device(local_rank)
         torch_dist.init_process_group(
             backend='hccl',
             rank=rank,
             world_size=int(os.environ['WORLD_SIZE']),
             **kwargs)
     else:
-        # LOCAL_RANK is set by `torch.distributed.launch` since PyTorch 1.1
-        local_rank = int(os.environ['LOCAL_RANK'])
         torch.cuda.set_device(local_rank)
         if init_backend == 'torch':
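
Context for the fix: RANK is the process's global index across all nodes, while LOCAL_RANK indexes the device on the current node, so `torch.npu.set_device(rank)` over-indexes the visible NPUs on every node after the first. The diff therefore reads LOCAL_RANK once at the top and binds the NPU by it. Below is a minimal sketch of the corrected NPU branch in isolation; the helper name `init_npu_dist` is hypothetical, and only the calls already shown in the diff are assumed.

    # Minimal sketch, assuming the job is launched with `torch.distributed.launch`
    # or `torchrun`, which export RANK, LOCAL_RANK and WORLD_SIZE for every process.
    import os

    import torch
    import torch_npu  # noqa: F401  (registers the torch.npu namespace)
    import torch.distributed as torch_dist


    def init_npu_dist(**kwargs) -> None:  # hypothetical helper name
        rank = int(os.environ['RANK'])              # global rank across all nodes
        local_rank = int(os.environ['LOCAL_RANK'])  # device index on this node
        # Bind the process to its per-node device; using `rank` here would
        # exceed the device count on every node after the first.
        torch.npu.set_device(local_rank)
        torch_dist.init_process_group(
            backend='hccl',
            rank=rank,
            world_size=int(os.environ['WORLD_SIZE']),
            **kwargs)

For example, a two-node job with 8 processes per node has ranks 0-15 but local ranks 0-7 on each node, which is why the device must be selected by LOCAL_RANK rather than RANK.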