diff --git a/mmcv/runner/dist_utils.py b/mmcv/runner/dist_utils.py
index d3a1ef3fd..71010963c 100644
--- a/mmcv/runner/dist_utils.py
+++ b/mmcv/runner/dist_utils.py
@@ -1,6 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 import functools
 import os
+import socket
 import subprocess
 from collections import OrderedDict
 
@@ -11,6 +13,51 @@ from torch._utils import (_flatten_dense_tensors, _take_tensors,
                           _unflatten_dense_tensors)
 
 
+def _find_free_port():
+    """Return a TCP port number that the OS reports as currently unused.
+
+    The port is obtained by binding a temporary socket to port 0 and is
+    released again before returning, so another process may still claim it
+    before the caller binds to it.
+
+    Returns:
+        int: The port number picked by the operating system.
+    """
+    # Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501
+    # The ``with`` block closes the socket even if ``bind`` raises.
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+        # Binding to port 0 will cause the OS to find an available port for us
+        sock.bind(('', 0))
+        port = sock.getsockname()[1]
+    # NOTE: there is still a chance the port could be taken by other processes.
+    return port
+
+
+def _is_free_port(port):
+    """Check that nothing accepts TCP connections on ``port`` on this host.
+
+    Every IPv4 address that the local hostname resolves to, plus
+    ``localhost``, is probed with a connection attempt.
+
+    Args:
+        port (int): The port number to probe.
+
+    Returns:
+        bool: True if no probed address accepted a connection, False
+        otherwise.
+    """
+    ips = socket.gethostbyname_ex(socket.gethostname())[-1]
+    ips.append('localhost')
+    for ip in ips:
+        # Use a fresh socket per probe: POSIX leaves the state of a socket
+        # unspecified after a failed connect(), so it must not be reused.
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            if s.connect_ex((ip, port)) == 0:
+                # Something is listening on this address.
+                return False
+    return True
+
+
 def init_dist(launcher, backend='nccl', **kwargs):
     if mp.get_start_method(allow_none=True) is None:
         mp.set_start_method('spawn')
@@ -64,8 +111,12 @@ def _init_dist_slurm(backend, port=None):
     elif 'MASTER_PORT' in os.environ:
         pass  # use MASTER_PORT in the environment variable
     else:
-        # 29500 is torch.distributed default port
-        os.environ['MASTER_PORT'] = '29500'
+        # if torch.distributed default port(29500) is available
+        # then use it, else find a free port
+        if _is_free_port(29500):
+            os.environ['MASTER_PORT'] = '29500'
+        else:
+            os.environ['MASTER_PORT'] = str(_find_free_port())
     # use MASTER_ADDR in the environment variable if it already exists
     if 'MASTER_ADDR' not in os.environ:
         os.environ['MASTER_ADDR'] = addr