[Feature] Support finding free port in _init_dist_slurm() (#1846)

* [feat]: support finding a free port in _init_dist_slurm

* fix format

* Update mmcv/runner/dist_utils.py

Should support a port taken by a non-localhost address.

Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>

* Update dist_utils.py

Add Copyright.

* rename inner function

* Update mmcv/runner/dist_utils.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* fix dist_utils.py

change _is_port_in_use() criterion.

* Update dist_utils.py

rename _is_port_in_use to _is_free_port

* Update mmcv/runner/dist_utils.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update dist_utils.py

fix lint

* Update dist_utils.py

fix lint

Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
pull/1876/head
Alex Yang 2022-04-09 12:53:23 +08:00 committed by GitHub
parent c33f248987
commit cff3feccbe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 26 additions and 2 deletions

View File

@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import functools import functools
import os import os
import socket
import subprocess import subprocess
from collections import OrderedDict from collections import OrderedDict
@ -11,6 +13,24 @@ from torch._utils import (_flatten_dense_tensors, _take_tensors,
_unflatten_dense_tensors) _unflatten_dense_tensors)
def _find_free_port():
    """Ask the operating system for a currently unused TCP port.

    Returns:
        int: A port number that was unused at the moment of the call.
    """
    # Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501
    # The context manager guarantees the socket is closed even when
    # ``bind`` or ``getsockname`` raises, so no descriptor is leaked.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Binding to port 0 will cause the OS to find an available port for us
        sock.bind(('', 0))
        port = sock.getsockname()[1]
    # NOTE: there is still a chance the port could be taken by other processes.
    return port
def _is_free_port(port):
    """Check that nothing accepts TCP connections on ``port`` locally.

    The port is probed on every IPv4 address the local hostname resolves
    to, plus ``localhost``.

    Args:
        port (int): TCP port number to probe.

    Returns:
        bool: True if the connection attempt fails on every probed address,
        False as soon as one address accepts the connection.
    """
    ips = socket.gethostbyname_ex(socket.gethostname())[-1]
    ips.append('localhost')
    for ip in ips:
        # Use a fresh socket per address: after a failed connect() the state
        # of the socket is unspecified, so reusing it for the next probe is
        # not portable.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            if s.connect_ex((ip, port)) == 0:
                return False
    return True
def init_dist(launcher, backend='nccl', **kwargs): def init_dist(launcher, backend='nccl', **kwargs):
if mp.get_start_method(allow_none=True) is None: if mp.get_start_method(allow_none=True) is None:
mp.set_start_method('spawn') mp.set_start_method('spawn')
@ -64,8 +84,12 @@ def _init_dist_slurm(backend, port=None):
elif 'MASTER_PORT' in os.environ: elif 'MASTER_PORT' in os.environ:
pass # use MASTER_PORT in the environment variable pass # use MASTER_PORT in the environment variable
else: else:
# 29500 is torch.distributed default port # if torch.distributed default port(29500) is available
# then use it, else find a free port
if _is_free_port(29500):
os.environ['MASTER_PORT'] = '29500' os.environ['MASTER_PORT'] = '29500'
else:
os.environ['MASTER_PORT'] = str(_find_free_port())
# use MASTER_ADDR in the environment variable if it already exists # use MASTER_ADDR in the environment variable if it already exists
if 'MASTER_ADDR' not in os.environ: if 'MASTER_ADDR' not in os.environ:
os.environ['MASTER_ADDR'] = addr os.environ['MASTER_ADDR'] = addr