# Mirror of https://github.com/UX-Decoder/DINOv.git
import os
import json
import time

import torch
import torch.distributed as dist


def get_world_size():
    # Number of processes in the default process group; 1 when
    # torch.distributed has not been initialized (single-GPU / CPU runs).
    if torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    return 1


def all_gather(x):
    # Gather `x` from every rank and stack the results along a new leading
    # dimension. The local slot is overwritten with the original (non-detached)
    # tensor so gradients still flow through this rank's contribution.
    if get_world_size() > 1:
        all_x = [torch.zeros_like(x) for _ in range(get_world_size())]
        torch.distributed.all_gather(all_x, x.detach())
        all_x[torch.distributed.get_rank()] = x
        x = torch.stack(all_x, dim=0)
    return x
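

# Example (an assumption, not part of the upstream file): a hypothetical helper
# showing how all_gather is typically used to average a per-rank scalar metric
# such as a loss. With a single process it simply averages the local tensor.
def _gather_and_mean(value):
    gathered = all_gather(value)  # [world_size, *value.shape] when distributed
    return gathered.float().mean()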


def init_distributed_mode(args):
    # Case 1: launched with torchrun / torch.distributed.launch, which exports
    # WORLD_SIZE (and also RANK and LOCAL_RANK) into the environment.
    if 'WORLD_SIZE' in os.environ and os.environ['WORLD_SIZE'] != '':
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = args.local_rank = int(os.environ['LOCAL_RANK'])

        print('world size: {}, rank: {}, local rank: {}'.format(
            args.world_size, args.rank, args.local_rank))
        print(json.dumps(dict(os.environ), indent=2))
    # Case 2: launched under SLURM, which exposes the process layout through
    # SLURM_* environment variables instead.
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.local_rank = int(os.environ['SLURM_LOCALID'])
        args.world_size = int(os.environ['SLURM_NPROCS'])

        if os.environ.get('HAND_DEFINE_DIST_URL', '0') == '1':
            # The caller supplies args.dist_url manually.
            pass
        else:
            # Derive a rendezvous address from the SLURM node list: the first
            # node acts as the master, and the port is offset by the lowest
            # node id (plus an optional fixed offset) to avoid collisions.
            import util.hostlist as uh
            nodenames = uh.parse_nodelist(os.environ['SLURM_JOB_NODELIST'])
            gpu_ids = [int(node[3:]) for node in nodenames]
            fixid = int(os.environ.get('FIX_DISTRIBUTED_PORT_NUMBER', 0))
            # fixid += random.randint(0, 300)
            port = str(3137 + int(min(gpu_ids)) + fixid)
            args.dist_url = "tcp://{ip}:{port}".format(
                ip=uh.nodename_to_ip(nodenames[0]), port=port)

        print('world size: {}, world rank: {}, local rank: {}, device_count: {}'.format(
            args.world_size, args.rank, args.local_rank, torch.cuda.device_count()))
    # Case 3: no launcher detected; fall back to single-process mode.
    else:
        print('Not using distributed mode')
        args.distributed = False
        args.world_size = 1
        args.rank = 0
        args.local_rank = 0
        return

    print("world_size:{} rank:{} local_rank:{}".format(
        args.world_size, args.rank, args.local_rank))
    args.distributed = True
    torch.cuda.set_device(args.local_rank)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url), flush=True)

    torch.distributed.init_process_group(
        backend=args.dist_backend,
        world_size=args.world_size,
        rank=args.rank,
        init_method=args.dist_url,
    )

    print("Before torch.distributed.barrier()")
    torch.distributed.barrier()
    print("End torch.distributed.barrier()")