From 83e7cc24ea8ec112a4ed0f2b32e36e6f91eea096 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E6=98=95=E8=BE=B0?= Date: Thu, 30 Mar 2023 16:34:38 +0800 Subject: [PATCH] [Fix] Fix accepting an unexpected argument local-rank in PyTorch 2.0 (#2813) --- tools/test.py | 5 ++++- tools/train.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/tools/test.py b/tools/test.py index a643b08be..7454fa60b 100644 --- a/tools/test.py +++ b/tools/test.py @@ -97,7 +97,10 @@ def parse_args(): type=float, default=0.5, help='Opacity of painted segmentation map. In (0, 1] range.') - parser.add_argument('--local_rank', type=int, default=0) + # When using PyTorch version >= 2.0.0, the `torch.distributed.launch` + # will pass the `--local-rank` parameter to `tools/train.py` instead + # of `--local_rank`. + parser.add_argument('--local_rank', '--local-rank', type=int, default=0) args = parser.parse_args() if 'LOCAL_RANK' not in os.environ: os.environ['LOCAL_RANK'] = str(args.local_rank) diff --git a/tools/train.py b/tools/train.py index c4219b04b..5cc0b15bc 100644 --- a/tools/train.py +++ b/tools/train.py @@ -86,7 +86,10 @@ def parse_args(): choices=['none', 'pytorch', 'slurm', 'mpi'], default='none', help='job launcher') - parser.add_argument('--local_rank', type=int, default=0) + # When using PyTorch version >= 2.0.0, the `torch.distributed.launch` + # will pass the `--local-rank` parameter to `tools/train.py` instead + # of `--local_rank`. + parser.add_argument('--local_rank', '--local-rank', type=int, default=0) parser.add_argument( '--auto-resume', action='store_true',