mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge pull request #1336 from xwang233/add-local-rank
Make train.py compatible with torchrun
This commit is contained in:
commit
2456223052
2
train.py
2
train.py
@ -355,6 +355,8 @@ def main():
|
||||
args.world_size = 1
|
||||
args.rank = 0 # global rank
|
||||
if args.distributed:
|
||||
if 'LOCAL_RANK' in os.environ:
|
||||
args.local_rank = int(os.getenv('LOCAL_RANK'))
|
||||
args.device = 'cuda:%d' % args.local_rank
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
torch.distributed.init_process_group(backend='nccl', init_method='env://')
|
||||
|
Loading…
x
Reference in New Issue
Block a user