diff --git a/timm/utils/distributed.py b/timm/utils/distributed.py index 95655f2c..92b8a6b8 100644 --- a/timm/utils/distributed.py +++ b/timm/utils/distributed.py @@ -92,6 +92,7 @@ def init_distributed_device(args): args.world_size = result['world_size'] args.rank = result['global_rank'] args.local_rank = result['local_rank'] + args.distributed = args.world_size > 1 device = torch.device(args.device) return device