mirror of https://github.com/FoundationVision/GLEE
#!/usr/bin/python3
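# Example invocations ('launch.py' is only a placeholder for this file's name; flags such as
# --config-file are assumptions about the target script, which defaults to
# projects/GLEE/train_net.py, and are simply forwarded to it unchanged):
#
#   # single node, 8 GPUs
#   python3 launch.py --np 8 --nn 1 --config-file <path/to/config.yaml>
#
#   # node 0 of a 2-node job
#   python3 launch.py --np 8 --nn 2 --worker_rank 0 --master_address <master-ip>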
import os
import sys
import socket
import random
import argparse
import subprocess
import time


def _find_free_port():
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Binding to port 0 will cause the OS to find an available port for us
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()
    # NOTE: there is still a chance the port could be taken by other processes.
    return port


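# Seeding with the current hour means every node launched within the same hour derives
# the same pseudo-random port, so multi-node jobs can agree on a master port without
# any extra coordination.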
def _get_rand_port():
    hour = time.time() // 3600
    random.seed(int(hour))
    return random.randrange(40000, 60000)


def init_workdir():
    ROOT = os.path.dirname(os.path.abspath(__file__))
    os.chdir(ROOT)
    sys.path.insert(0, ROOT)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Launcher')
    parser.add_argument('--launch', type=str, default='projects/GLEE/train_net.py',
                        help='Training entry script to launch.')
    parser.add_argument('--dist', type=int, default=1,
                        help='Whether to start by torch.distributed.launch.')
    parser.add_argument('--np', type=int, default=8,
                        help='Number of (GPU) processes per node.')
    parser.add_argument('--nn', type=int, default=1,
                        help='Number of worker nodes (machines).')
    parser.add_argument('--port', type=int, default=-1,
                        help='Master port for communication; a free or random port is chosen if <= 0.')
    parser.add_argument('--worker_rank', type=int, default=0,
                        help='Rank of this worker node.')
    parser.add_argument('--master_address', type=str,
                        help='IP address of the master node (required when --nn > 1).')
    args, other_args = parser.parse_known_args()

    # change to current dir
    prj_dir = os.path.dirname(os.path.abspath(__file__))
    os.chdir(prj_dir)
    init_workdir()

    # Get training info
    master_address = args.master_address
    num_processes_per_worker = args.np
    num_workers = args.nn
    worker_rank = args.worker_rank

    # Get port
    if args.port > 0:
        master_port = args.port
    elif num_workers == 1:
        master_port = _find_free_port()
    else:  # This reduces the chance of a conflict, but port availability is not guaranteed.
        master_port = _get_rand_port()

    if args.dist >= 1:
        print(f'Start {args.launch} by torch.distributed.launch with port {master_port}!', flush=True)
        cmd = f'python3 {args.launch} --num-gpus={num_processes_per_worker}'
        if num_workers > 1:
            # multi-machine: forward the rendezvous info to the launched script
            assert master_address is not None, '--master_address is required when --nn > 1'
            dist_url = "tcp://" + str(master_address) + ":" + str(master_port)
            cmd += f" --num-machines={num_workers} --machine-rank={worker_rank} --dist-url={dist_url}"
    else:
        print(f'Start {args.launch}!', flush=True)
        cmd = f'python3 {args.launch}'
        # $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3}'
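    # For illustration only: with the defaults above (single node, 8 GPUs), the command
    # built here looks like
    #   python3 projects/GLEE/train_net.py --num-gpus=8 <any forwarded arguments>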
    # Forward any arguments this launcher does not recognise to the launched script verbatim
    for argv in other_args:
        cmd += f' {argv}'
    print("==> Run command: " + cmd)
    exit_code = subprocess.call(cmd, shell=True)
    sys.exit(exit_code)