mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Enhancement] Get local_rank in init_dist_mpi from env (#212)
This commit is contained in:
parent
0c59eeab5f
commit
a1adbff11e
15
mmengine/dist/utils.py
vendored
15
mmengine/dist/utils.py
vendored
@@ -82,10 +82,15 @@ def _init_dist_mpi(backend, **kwargs) -> None:
|
||||
'nccl', 'gloo' and 'mpi'. Defaults to 'nccl'.
|
||||
**kwargs: keyword arguments are passed to ``init_process_group``.
|
||||
"""
|
||||
- # TODO: use local_rank instead of rank % num_gpus
|
||||
- rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
|
||||
- num_gpus = torch.cuda.device_count()
|
||||
- torch.cuda.set_device(rank % num_gpus)
|
||||
+ local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
|
||||
+ torch.cuda.set_device(local_rank)
|
||||
+ if 'MASTER_PORT' not in os.environ:
|
||||
+ # 29500 is torch.distributed default port
|
||||
+ os.environ['MASTER_PORT'] = '29500'
|
||||
+ if 'MASTER_ADDR' not in os.environ:
|
||||
+ raise KeyError('The environment variable MASTER_ADDR is not set')
|
||||
+ os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
|
||||
+ os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
|
||||
torch_dist.init_process_group(backend=backend, **kwargs)
|
||||
|
||||
|
||||
@@ -99,8 +104,6 @@ def _init_dist_slurm(backend, port=None) -> None:
|
||||
Args:
|
||||
backend (str): Backend of torch.distributed.
|
||||
port (int, optional): Master port. Defaults to None.
|
||||
|
||||
- TODO: https://github.com/open-mmlab/mmcv/pull/1682
|
||||
"""
|
||||
proc_id = int(os.environ['SLURM_PROCID'])
|
||||
ntasks = int(os.environ['SLURM_NTASKS'])
|
||||
|
Loading…
x
Reference in New Issue
Block a user