Fix distributed flag bug with flexible device handling

This commit is contained in:
Ross Wightman 2024-02-03 16:26:15 -08:00
parent bee0471f91
commit a08b57e801

View File

@ -92,6 +92,7 @@ def init_distributed_device(args):
args.world_size = result['world_size']
args.rank = result['global_rank']
args.local_rank = result['local_rank']
args.distributed = args.world_size > 1
device = torch.device(args.device)
return device