mirror of https://github.com/open-mmlab/mmocr.git
[Enhancement] Accepts local-rank in train.py and test.py (#1806)
* [Enhancement] Accepts local-rank * add * updatepull/1811/head
parent
f47cff5199
commit
73df26d749
|
@ -45,7 +45,10 @@ def parse_args():
|
|||
help='Job launcher')
|
||||
parser.add_argument(
|
||||
'--tta', action='store_true', help='Test time augmentation')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
# When using PyTorch version >= 2.0.0, the `torch.distributed.launch`
|
||||
# will pass the `--local-rank` parameter to `tools/test.py` instead
|
||||
# of `--local_rank`.
|
||||
parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
if 'LOCAL_RANK' not in os.environ:
|
||||
os.environ['LOCAL_RANK'] = str(args.local_rank)
|
||||
|
|
|
@ -41,7 +41,10 @@ def parse_args():
|
|||
choices=['none', 'pytorch', 'slurm', 'mpi'],
|
||||
default='none',
|
||||
help='Job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
# When using PyTorch version >= 2.0.0, the `torch.distributed.launch`
|
||||
# will pass the `--local-rank` parameter to `tools/train.py` instead
|
||||
# of `--local_rank`.
|
||||
parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
if 'LOCAL_RANK' not in os.environ:
|
||||
os.environ['LOCAL_RANK'] = str(args.local_rank)
|
||||
|
|
Loading…
Reference in New Issue