mmrazor/tests/utils/set_dist_env.py

32 lines
824 B
Python
Raw Normal View History

# Copyright (c) OpenMMLab. All rights reserved.
import os
import random
import torch
import torch.distributed as dist
class SetDistEnv:
    """Context manager that creates a single-process ``torch.distributed``
    process group, for use in tests.

    On entry it sets ``MASTER_ADDR`` to ``'localhost'`` and ``MASTER_PORT``
    to the configured port, then calls ``dist.init_process_group`` with
    ``rank=0`` and ``world_size=1``. On exit it destroys the process group
    and restores both environment variables to the values they had before
    entry (removing them if they were previously unset).

    Args:
        using_cuda (bool): If True, use the ``'nccl'`` backend (CUDA must be
            available); otherwise use ``'gloo'``. Defaults to False.
        port (int, optional): Rendezvous port. If None, a random port in
            ``[10000, 20000]`` (inclusive) is chosen. Defaults to None.

    Example:
        >>> with SetDistEnv() as env:  # doctest: +SKIP
        ...     assert dist.is_initialized()
    """

    # Environment variables this context manager overwrites and restores.
    _ENV_KEYS = ('MASTER_ADDR', 'MASTER_PORT')

    def __init__(self, using_cuda=False, port=None) -> None:
        self.using_cuda = using_cuda
        if self.using_cuda:
            assert torch.cuda.is_available(), \
                'SetDistEnv(using_cuda=True) requires CUDA to be available'
        if port is None:
            port = random.randint(10000, 20000)
        self.port = port
        # Snapshot of the env vars taken on entry; None marks "was unset".
        self._saved_env = {}

    def __enter__(self):
        """Set the rendezvous env vars and initialize the process group.

        Returns:
            SetDistEnv: this instance, so ``with SetDistEnv() as env`` works.
        """
        self._saved_env = {key: os.environ.get(key) for key in self._ENV_KEYS}
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = str(self.port)

        backend = 'nccl' if self.using_cuda else 'gloo'
        try:
            dist.init_process_group(backend, rank=0, world_size=1)
        except Exception:
            # Don't leave a half-configured environment behind on failure.
            self._restore_env()
            raise
        return self

    def __exit__(self, exc_type, exc_value, tb):
        """Destroy the process group and restore the saved env vars.

        Exceptions raised inside the ``with`` body are not suppressed.
        """
        try:
            dist.destroy_process_group()
        finally:
            # Restore even if teardown fails so later tests see a clean env.
            self._restore_env()

    def _restore_env(self):
        """Put the env vars captured on entry back into ``os.environ``."""
        for key, value in self._saved_env.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value
        self._saved_env = {}