# mirror of https://github.com/alibaba/EasyCV.git
# Copyright (c) Alibaba, Inc. and its affiliates.
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import os
import subprocess
import sys
from argparse import REMAINDER, ArgumentParser


def parse_args():
    """Parse the command line options for the distributed launch helper.

    Returns:
        argparse.Namespace: the parsed options — launcher settings
        (``nproc_per_node``, ``local_mode``), the training script (a file
        path in local mode, otherwise a python module name) and all
        trailing arguments destined for the training script itself.
    """
    parser = ArgumentParser(description='PyTorch distributed training launch '
                            'helper utility that will spawn up '
                            'multiple distributed processes')

    # Optional arguments for the launch helper
    parser.add_argument(
        '--nproc_per_node',
        type=int,
        default=1,
        help='The number of processes to launch on each node, '
        'for GPU training, this is recommended to be set '
        'to the number of GPUs in your system so that '
        'each process can be bound to a single GPU.')

    parser.add_argument(
        '--local_mode',
        action='store_true',
        help='If assigned, training_script should be path of python '
        'script, otherwise python module name')

    # positional; nargs='?' is required for the declared default to be
    # reachable — a plain positional is mandatory in argparse, so without
    # it 'tools/train.py' could never apply
    parser.add_argument(
        'training_script',
        type=str,
        nargs='?',
        help='The full path to the single GPU training '
        'program/script to be launched in parallel, '
        'followed by all the arguments for the '
        'training script',
        default='tools/train.py')

    # rest from the training program, passed through verbatim
    parser.add_argument('training_script_args', nargs=REMAINDER)
    return parser.parse_args()


def main():
    """Spawn one training process per local rank and wait for all of them.

    Reads ``RANK`` / ``WORLD_SIZE`` from the environment (assumed to be
    the node rank and node count supplied by the scheduler — the children
    receive the recomputed per-process values), exports the distributed
    environment variables, launches ``nproc_per_node`` copies of the
    training script and raises ``subprocess.CalledProcessError`` if any
    child exits with a non-zero status.
    """
    args = parse_args()

    # node-level rank/count as provided by the cluster scheduler
    args.node_rank = int(os.environ.get('RANK', '0'))
    args.nnodes = int(os.getenv('WORLD_SIZE', '1'))
    # world size in terms of number of processes
    dist_world_size = args.nproc_per_node * args.nnodes

    # set PyTorch distributed related environmental variables
    current_env = os.environ.copy()
    current_env['WORLD_SIZE'] = str(dist_world_size)

    processes = []

    if 'OMP_NUM_THREADS' not in os.environ and args.nproc_per_node > 1:
        # cap OpenMP threads so nproc_per_node processes don't oversubscribe
        current_env['OMP_NUM_THREADS'] = str(1)
        print('*****************************************\n'
              'Setting OMP_NUM_THREADS environment variable for each process '
              'to be {} in default, to avoid your system being overloaded, '
              'please further tune the variable for optimal performance in '
              'your application as needed. \n'
              '*****************************************'.format(
                  current_env['OMP_NUM_THREADS']))

    for local_rank in range(0, args.nproc_per_node):
        # each process's rank
        dist_rank = args.nproc_per_node * args.node_rank + local_rank
        current_env['RANK'] = str(dist_rank)
        current_env['LOCAL_RANK'] = str(local_rank)

        # spawn the processes; Popen snapshots current_env at spawn time,
        # so mutating it on the next iteration is safe
        cmd = [sys.executable, '-u']
        if not args.local_mode:
            # module mode: run the script via `python -m <module>`
            cmd.append('-m')

        cmd.append(args.training_script)

        cmd.append('--local_rank={}'.format(local_rank))

        cmd.extend(args.training_script_args)

        process = subprocess.Popen(cmd, env=current_env)
        # keep the command next to its process so a failure reports the
        # command that actually failed, not just the last one spawned
        processes.append((process, cmd))

    for process, proc_cmd in processes:
        process.wait()
        if process.returncode != 0:
            raise subprocess.CalledProcessError(
                returncode=process.returncode, cmd=proc_cmd)


# Script entry point: run the launcher only when executed directly.
if __name__ == '__main__':
    main()