From 73df26d74923eb20f3253ebf6c6a0f479fe861d1 Mon Sep 17 00:00:00 2001 From: Tong Gao Date: Mon, 27 Mar 2023 10:34:54 +0800 Subject: [PATCH] [Enhancement] Accepts local-rank in train.py and test.py (#1806) * [Enhancement] Accepts local-rank * add * update --- tools/test.py | 5 ++++- tools/train.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/tools/test.py b/tools/test.py index 395a89f7..15645f22 100755 --- a/tools/test.py +++ b/tools/test.py @@ -45,7 +45,10 @@ def parse_args(): help='Job launcher') parser.add_argument( '--tta', action='store_true', help='Test time augmentation') - parser.add_argument('--local_rank', type=int, default=0) + # When using PyTorch version >= 2.0.0, the `torch.distributed.launch` + # will pass the `--local-rank` parameter to `tools/test.py` instead + # of `--local_rank`. + parser.add_argument('--local_rank', '--local-rank', type=int, default=0) args = parser.parse_args() if 'LOCAL_RANK' not in os.environ: os.environ['LOCAL_RANK'] = str(args.local_rank) diff --git a/tools/train.py b/tools/train.py index aee2d41b..df691f79 100755 --- a/tools/train.py +++ b/tools/train.py @@ -41,7 +41,10 @@ def parse_args(): choices=['none', 'pytorch', 'slurm', 'mpi'], default='none', help='Job launcher') - parser.add_argument('--local_rank', type=int, default=0) + # When using PyTorch version >= 2.0.0, the `torch.distributed.launch` + # will pass the `--local-rank` parameter to `tools/train.py` instead + # of `--local_rank`. + parser.add_argument('--local_rank', '--local-rank', type=int, default=0) args = parser.parse_args() if 'LOCAL_RANK' not in os.environ: os.environ['LOCAL_RANK'] = str(args.local_rank)