From a1adbff11e3d17901e9d6ca54dff09cafa459a2a Mon Sep 17 00:00:00 2001
From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Date: Thu, 5 May 2022 19:55:54 +0800
Subject: [PATCH] [Enhancement] Get local_rank in init_dist_mpi from env
 (#212)

---
 mmengine/dist/utils.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/mmengine/dist/utils.py b/mmengine/dist/utils.py
index 0842a5d4..415b74b9 100644
--- a/mmengine/dist/utils.py
+++ b/mmengine/dist/utils.py
@@ -82,10 +82,15 @@ def _init_dist_mpi(backend, **kwargs) -> None:
             'nccl', 'gloo' and 'mpi'. Defaults to 'nccl'.
         **kwargs: keyword arguments are passed to ``init_process_group``.
     """
-    # TODO: use local_rank instead of rank % num_gpus
-    rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
-    num_gpus = torch.cuda.device_count()
-    torch.cuda.set_device(rank % num_gpus)
+    local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
+    torch.cuda.set_device(local_rank)
+    if 'MASTER_PORT' not in os.environ:
+        # 29500 is torch.distributed default port
+        os.environ['MASTER_PORT'] = '29500'
+    if 'MASTER_ADDR' not in os.environ:
+        raise KeyError('The environment variable MASTER_ADDR is not set')
+    os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
+    os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
     torch_dist.init_process_group(backend=backend, **kwargs)
 
 
@@ -99,8 +104,6 @@ def _init_dist_slurm(backend, port=None) -> None:
     Args:
         backend (str): Backend of torch.distributed.
        port (int, optional): Master port. Defaults to None.
-
-    TODO: https://github.com/open-mmlab/mmcv/pull/1682
     """
     proc_id = int(os.environ['SLURM_PROCID'])
     ntasks = int(os.environ['SLURM_NTASKS'])
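
Reviewer note (not part of the patch): a minimal sketch of how the new MPI
code path would be exercised end to end. The launch command, the train.py
script name, and the use of init_dist as the public entry point that
dispatches to _init_dist_mpi are illustrative assumptions, not taken from
this diff.

# train.py -- illustrative sketch only.
# Assumed launch (Open MPI sets OMPI_COMM_WORLD_RANK, OMPI_COMM_WORLD_SIZE
# and OMPI_COMM_WORLD_LOCAL_RANK for every spawned process):
#   MASTER_ADDR=<rank-0 hostname> mpirun -np 8 python train.py
import torch
import torch.distributed as torch_dist

from mmengine.dist import init_dist  # assumed public wrapper over _init_dist_mpi

init_dist(launcher='mpi', backend='nccl')

# Each process is now pinned to the GPU matching its node-local rank, so
# collectives address distinct devices even with several ranks per node.
t = torch.ones(1, device='cuda')
torch_dist.all_reduce(t)  # t.item() == world size on every rank
print(f'rank {torch_dist.get_rank()}: {t.item()}')

The switch from rank % num_gpus to the node-local rank matters because the
global rank modulo the device count only maps ranks to distinct GPUs when
ranks are placed contiguously and every node exposes the same number of
devices; OMPI_COMM_WORLD_LOCAL_RANK is the index Open MPI assigns within
each node, which is the value torch.cuda.set_device actually needs.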