mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Support finding free port in _init_dist_slurm() (#1846)
* [feat]:support find free port in _init_dist_slurm * fix format * Update mmcv/runner/dist_utils.py should support port taken by a non-localhost address. Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> * Update dist_utils.py Add Copyright. * rename inner function * Update mmcv/runner/dist_utils.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * fix dist_utils.py change _is_port_in_use() criterion. * Update dist_utils.py rename _is_port_in_use to _is_free_port * Update mmcv/runner/dist_utils.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update dist_utils.py fix lint * Update dist_utils.py fix lint Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>pull/1876/head
parent
c33f248987
commit
cff3feccbe
|
@ -1,6 +1,8 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import functools
|
||||
import os
|
||||
import socket
|
||||
import subprocess
|
||||
from collections import OrderedDict
|
||||
|
||||
|
@ -11,6 +13,24 @@ from torch._utils import (_flatten_dense_tensors, _take_tensors,
|
|||
_unflatten_dense_tensors)
|
||||
|
||||
|
||||
def _find_free_port():
|
||||
# Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
# Binding to port 0 will cause the OS to find an available port for us
|
||||
sock.bind(('', 0))
|
||||
port = sock.getsockname()[1]
|
||||
sock.close()
|
||||
# NOTE: there is still a chance the port could be taken by other processes.
|
||||
return port
|
||||
|
||||
|
||||
def _is_free_port(port):
|
||||
ips = socket.gethostbyname_ex(socket.gethostname())[-1]
|
||||
ips.append('localhost')
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
return all(s.connect_ex((ip, port)) != 0 for ip in ips)
|
||||
|
||||
|
||||
def init_dist(launcher, backend='nccl', **kwargs):
|
||||
if mp.get_start_method(allow_none=True) is None:
|
||||
mp.set_start_method('spawn')
|
||||
|
@ -64,8 +84,12 @@ def _init_dist_slurm(backend, port=None):
|
|||
elif 'MASTER_PORT' in os.environ:
|
||||
pass # use MASTER_PORT in the environment variable
|
||||
else:
|
||||
# 29500 is torch.distributed default port
|
||||
os.environ['MASTER_PORT'] = '29500'
|
||||
# if torch.distributed default port(29500) is available
|
||||
# then use it, else find a free port
|
||||
if _is_free_port(29500):
|
||||
os.environ['MASTER_PORT'] = '29500'
|
||||
else:
|
||||
os.environ['MASTER_PORT'] = str(_find_free_port())
|
||||
# use MASTER_ADDR in the environment variable if it already exists
|
||||
if 'MASTER_ADDR' not in os.environ:
|
||||
os.environ['MASTER_ADDR'] = addr
|
||||
|
|
Loading…
Reference in New Issue