Add dist utils (#157)

* add dist utils

* rename due to a typo

* bump version to 0.2.15

* fix imports
Kai Chen 2019-12-12 22:18:40 +08:00 committed by GitHub
parent 2330c4f468
commit ead4bc39bb
10 changed files with 88 additions and 43 deletions

@@ -8,8 +8,8 @@ from .checkpoint import (load_state_dict, load_checkpoint, weights_to_cpu,
                          save_checkpoint)
 from .parallel_test import parallel_test
 from .priority import Priority, get_priority
-from .utils import (get_host_info, get_dist_info, master_only, get_time_str,
-                    obj_from_dict)
+from .utils import get_host_info, get_time_str, obj_from_dict
+from .dist_utils import init_dist, get_dist_info, master_only
 
 __all__ = [
     'Runner', 'LogBuffer', 'Hook', 'CheckpointHook', 'ClosureHook',
@@ -17,6 +17,6 @@ __all__ = [
     'LoggerHook', 'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook',
     'WandbLoggerHook', 'load_state_dict', 'load_checkpoint', 'weights_to_cpu',
     'save_checkpoint', 'parallel_test', 'Priority', 'get_priority',
-    'get_host_info', 'get_dist_info', 'master_only', 'get_time_str',
-    'obj_from_dict'
+    'get_host_info', 'get_time_str', 'obj_from_dict', 'init_dist',
+    'get_dist_info', 'master_only'
 ]
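
With the re-export above, the distributed helpers stay importable from the package root. A minimal, illustrative sketch of the import a downstream script might use (names taken from the __all__ list above; not part of this commit):

# Illustrative downstream usage: the helpers now live in
# mmcv.runner.dist_utils but remain importable from mmcv.runner.
from mmcv.runner import init_dist, get_dist_info, master_only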

@@ -12,7 +12,7 @@ from terminaltables import AsciiTable
 from torch.utils import model_zoo
 
 import mmcv
-from .utils import get_dist_info
+from .dist_utils import get_dist_info
 
 open_mmlab_model_urls = {
     'vgg16_caffe': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/vgg16_caffe-292e1171.pth',  # noqa: E501

@@ -0,0 +1,75 @@
+import functools
+import os
+import subprocess
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+
+def init_dist(launcher, backend='nccl', **kwargs):
+    if mp.get_start_method(allow_none=True) is None:
+        mp.set_start_method('spawn')
+    if launcher == 'pytorch':
+        _init_dist_pytorch(backend, **kwargs)
+    elif launcher == 'mpi':
+        _init_dist_mpi(backend, **kwargs)
+    elif launcher == 'slurm':
+        _init_dist_slurm(backend, **kwargs)
+    else:
+        raise ValueError('Invalid launcher type: {}'.format(launcher))
+
+
+def _init_dist_pytorch(backend, **kwargs):
+    # TODO: use local_rank instead of rank % num_gpus
+    rank = int(os.environ['RANK'])
+    num_gpus = torch.cuda.device_count()
+    torch.cuda.set_device(rank % num_gpus)
+    dist.init_process_group(backend=backend, **kwargs)
+
+
+def _init_dist_mpi(backend, **kwargs):
+    raise NotImplementedError
+
+
+def _init_dist_slurm(backend, port=29500, **kwargs):
+    proc_id = int(os.environ['SLURM_PROCID'])
+    ntasks = int(os.environ['SLURM_NTASKS'])
+    node_list = os.environ['SLURM_NODELIST']
+    num_gpus = torch.cuda.device_count()
+    torch.cuda.set_device(proc_id % num_gpus)
+    addr = subprocess.getoutput(
+        'scontrol show hostname {} | head -n1'.format(node_list))
+    os.environ['MASTER_PORT'] = str(port)
+    os.environ['MASTER_ADDR'] = addr
+    os.environ['WORLD_SIZE'] = str(ntasks)
+    os.environ['RANK'] = str(proc_id)
+    dist.init_process_group(backend=backend)
+
+
+def get_dist_info():
+    if torch.__version__ < '1.0':
+        initialized = dist._initialized
+    else:
+        if dist.is_available():
+            initialized = dist.is_initialized()
+        else:
+            initialized = False
+    if initialized:
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+    else:
+        rank = 0
+        world_size = 1
+    return rank, world_size
+
+
+def master_only(func):
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        rank, _ = get_dist_info()
+        if rank == 0:
+            return func(*args, **kwargs)
+
+    return wrapper
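
The new module above is the substance of this commit. A minimal, illustrative sketch of how the helpers might be used in a training script launched with python -m torch.distributed.launch (the script and function names are hypothetical, not part of this change):

# Hypothetical train.py, launched e.g. with:
#   python -m torch.distributed.launch --nproc_per_node=8 train.py
# The launcher sets RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT, which
# _init_dist_pytorch and init_process_group read from the environment.
from mmcv.runner import init_dist, get_dist_info, master_only


@master_only
def log_on_master(msg):
    # Runs only on rank 0; on other ranks the wrapper returns None.
    print(msg)


def main():
    init_dist('pytorch', backend='nccl')
    rank, world_size = get_dist_info()
    log_on_master('initialized {} processes'.format(world_size))


if __name__ == '__main__':
    main()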

@@ -1,4 +1,4 @@
-from ..utils import master_only
+from ..dist_utils import master_only
 from .hook import Hook

@@ -9,7 +9,8 @@ from threading import Thread
 import requests
 from six.moves.queue import Empty, Queue
 
-from ...utils import get_host_info, master_only
+from ...dist_utils import master_only
+from ...utils import get_host_info
 from .base import LoggerHook

@@ -2,7 +2,7 @@ import os.path as osp
 
 import torch
 
-from ...utils import master_only
+from ...dist_utils import master_only
 from .base import LoggerHook

@@ -1,4 +1,4 @@
-from ...utils import master_only
+from ...dist_utils import master_only
 from .base import LoggerHook
 import numbers

@@ -7,11 +7,12 @@ import torch
 import mmcv
 from . import hooks
 from .checkpoint import load_checkpoint, save_checkpoint
+from .dist_utils import get_dist_info
 from .hooks import (CheckpointHook, Hook, IterTimerHook, LrUpdaterHook,
                     OptimizerHook, lr_updater)
 from .log_buffer import LogBuffer
 from .priority import get_priority
-from .utils import get_dist_info, get_host_info, get_time_str, obj_from_dict
+from .utils import get_host_info, get_time_str, obj_from_dict
 
 
 class Runner(object):

@@ -1,12 +1,8 @@
-import functools
 import sys
 import time
 from getpass import getuser
 from socket import gethostname
 
-import torch
-import torch.distributed as dist
-
 import mmcv
@@ -14,34 +10,6 @@ def get_host_info():
     return '{}@{}'.format(getuser(), gethostname())
 
 
-def get_dist_info():
-    if torch.__version__ < '1.0':
-        initialized = dist._initialized
-    else:
-        if dist.is_available():
-            initialized = dist.is_initialized()
-        else:
-            initialized = False
-    if initialized:
-        rank = dist.get_rank()
-        world_size = dist.get_world_size()
-    else:
-        rank = 0
-        world_size = 1
-    return rank, world_size
-
-
-def master_only(func):
-
-    @functools.wraps(func)
-    def wrapper(*args, **kwargs):
-        rank, _ = get_dist_info()
-        if rank == 0:
-            return func(*args, **kwargs)
-
-    return wrapper
-
-
 def get_time_str():
     return time.strftime('%Y%m%d_%H%M%S', time.localtime())
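
Since get_dist_info and master_only are removed from the runner utils module above, any code that imported them from the old location needs the new path. An illustrative before/after (the package-level re-export from mmcv.runner works as well):

# Old path, valid up to mmcv 0.2.14 (removed by this commit):
#   from mmcv.runner.utils import get_dist_info, master_only
# New path as of this commit:
from mmcv.runner.dist_utils import get_dist_info, master_only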

@@ -1 +1 @@
-__version__ = '0.2.14'
+__version__ = '0.2.15'