# DINOv/utils/dist.py

import os
import json
import time

import torch
import torch.distributed as dist


def get_world_size():
    """Return the number of processes in the default process group (1 if not distributed)."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1


def all_gather(x):
    """All-gather tensor ``x`` across ranks and stack the results along a new dim 0."""
    if get_world_size() > 1:
        all_x = [torch.zeros_like(x) for _ in range(get_world_size())]
        dist.all_gather(all_x, x.detach())
        # all_gather returns tensors without autograd history; put the local tensor
        # back into its slot so gradients still flow through this rank's contribution.
        all_x[dist.get_rank()] = x
        x = torch.stack(all_x, dim=0)
    return x
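
# Example (sketch) of how all_gather is typically used, e.g. to collect per-rank
# features before a loss that needs the full batch. The names below (`model`,
# `images`, `local_feats`) are illustrative and not part of this module.
#
#   local_feats = model(images)            # [B, D] on this rank
#   gathered = all_gather(local_feats)     # [world_size, B, D], or [B, D] if world_size == 1
#   if gathered.dim() == 3:
#       gathered = gathered.flatten(0, 1)  # [world_size * B, D]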


def init_distributed_mode(args):
    """Initialize torch.distributed from torchrun-style env variables or from SLURM."""
    if 'WORLD_SIZE' in os.environ and os.environ['WORLD_SIZE'] != '':  # 'RANK' in os.environ and
        # Launched with torchrun / torch.distributed.launch: read the standard env variables.
        # args.dist_url is expected to come from the caller (commonly 'env://' in this setup).
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = args.local_rank = int(os.environ['LOCAL_RANK'])
        print('world size: {}, rank: {}, local rank: {}'.format(
            args.world_size, args.rank, args.local_rank))
        print(json.dumps(dict(os.environ), indent=2))
    elif 'SLURM_PROCID' in os.environ:
        # Launched with srun: derive rank / world size from the SLURM environment.
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.local_rank = int(os.environ['SLURM_LOCALID'])
        args.world_size = int(os.environ['SLURM_NPROCS'])
        if os.environ.get('HAND_DEFINE_DIST_URL', '0') == '1':
            # args.dist_url is assumed to be provided by the caller.
            pass
        else:
            # Build the rendezvous URL from the first node in the SLURM node list.
            import util.hostlist as uh
            nodenames = uh.parse_nodelist(os.environ['SLURM_JOB_NODELIST'])
            gpu_ids = [int(node[3:]) for node in nodenames]  # assumes node names like 'gpuNNN'
            fixid = int(os.environ.get('FIX_DISTRIBUTED_PORT_NUMBER', 0))
            # fixid += random.randint(0, 300)
            port = str(3137 + int(min(gpu_ids)) + fixid)
            args.dist_url = "tcp://{ip}:{port}".format(ip=uh.nodename_to_ip(nodenames[0]), port=port)
        print('world size: {}, world rank: {}, local rank: {}, device_count: {}'.format(
            args.world_size, args.rank, args.local_rank, torch.cuda.device_count()))
    else:
        # Single-process fallback: no distributed environment detected.
        print('Not using distributed mode')
        args.distributed = False
        args.world_size = 1
        args.rank = 0
        args.local_rank = 0
        return

    print("world_size:{} rank:{} local_rank:{}".format(args.world_size, args.rank, args.local_rank))
    args.distributed = True
    torch.cuda.set_device(args.local_rank)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url), flush=True)
    dist.init_process_group(
        backend=args.dist_backend,
        world_size=args.world_size,
        rank=args.rank,
        init_method=args.dist_url,
    )
    print("Before torch.distributed.barrier()")
    dist.barrier()
    print("End torch.distributed.barrier()")