Add dist utils (#157)

* add dist utils

* rename due to a typo

* bump version to 0.2.15

* fix imports
pull/159/head v0.2.15
Kai Chen 2019-12-12 22:18:40 +08:00 committed by GitHub
parent 2330c4f468
commit ead4bc39bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 88 additions and 43 deletions

View File

@ -8,8 +8,8 @@ from .checkpoint import (load_state_dict, load_checkpoint, weights_to_cpu,
save_checkpoint) save_checkpoint)
from .parallel_test import parallel_test from .parallel_test import parallel_test
from .priority import Priority, get_priority from .priority import Priority, get_priority
from .utils import (get_host_info, get_dist_info, master_only, get_time_str, from .utils import get_host_info, get_time_str, obj_from_dict
obj_from_dict) from .dist_utils import init_dist, get_dist_info, master_only
__all__ = [ __all__ = [
'Runner', 'LogBuffer', 'Hook', 'CheckpointHook', 'ClosureHook', 'Runner', 'LogBuffer', 'Hook', 'CheckpointHook', 'ClosureHook',
@ -17,6 +17,6 @@ __all__ = [
'LoggerHook', 'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook', 'LoggerHook', 'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook',
'WandbLoggerHook', 'load_state_dict', 'load_checkpoint', 'weights_to_cpu', 'WandbLoggerHook', 'load_state_dict', 'load_checkpoint', 'weights_to_cpu',
'save_checkpoint', 'parallel_test', 'Priority', 'get_priority', 'save_checkpoint', 'parallel_test', 'Priority', 'get_priority',
'get_host_info', 'get_dist_info', 'master_only', 'get_time_str', 'get_host_info', 'get_time_str', 'obj_from_dict', 'init_dist',
'obj_from_dict' 'get_dist_info', 'master_only'
] ]

View File

@ -12,7 +12,7 @@ from terminaltables import AsciiTable
from torch.utils import model_zoo from torch.utils import model_zoo
import mmcv import mmcv
from .utils import get_dist_info from .dist_utils import get_dist_info
open_mmlab_model_urls = { open_mmlab_model_urls = {
'vgg16_caffe': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/vgg16_caffe-292e1171.pth', # noqa: E501 'vgg16_caffe': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/vgg16_caffe-292e1171.pth', # noqa: E501

View File

@ -0,0 +1,75 @@
import functools
import os
import subprocess
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
def init_dist(launcher, backend='nccl', **kwargs):
if mp.get_start_method(allow_none=True) is None:
mp.set_start_method('spawn')
if launcher == 'pytorch':
_init_dist_pytorch(backend, **kwargs)
elif launcher == 'mpi':
_init_dist_mpi(backend, **kwargs)
elif launcher == 'slurm':
_init_dist_slurm(backend, **kwargs)
else:
raise ValueError('Invalid launcher type: {}'.format(launcher))
def _init_dist_pytorch(backend, **kwargs):
# TODO: use local_rank instead of rank % num_gpus
rank = int(os.environ['RANK'])
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(rank % num_gpus)
dist.init_process_group(backend=backend, **kwargs)
def _init_dist_mpi(backend, **kwargs):
raise NotImplementedError
def _init_dist_slurm(backend, port=29500, **kwargs):
proc_id = int(os.environ['SLURM_PROCID'])
ntasks = int(os.environ['SLURM_NTASKS'])
node_list = os.environ['SLURM_NODELIST']
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(proc_id % num_gpus)
addr = subprocess.getoutput(
'scontrol show hostname {} | head -n1'.format(node_list))
os.environ['MASTER_PORT'] = str(port)
os.environ['MASTER_ADDR'] = addr
os.environ['WORLD_SIZE'] = str(ntasks)
os.environ['RANK'] = str(proc_id)
dist.init_process_group(backend=backend)
def get_dist_info():
if torch.__version__ < '1.0':
initialized = dist._initialized
else:
if dist.is_available():
initialized = dist.is_initialized()
else:
initialized = False
if initialized:
rank = dist.get_rank()
world_size = dist.get_world_size()
else:
rank = 0
world_size = 1
return rank, world_size
def master_only(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
rank, _ = get_dist_info()
if rank == 0:
return func(*args, **kwargs)
return wrapper

View File

@ -1,4 +1,4 @@
from ..utils import master_only from ..dist_utils import master_only
from .hook import Hook from .hook import Hook

View File

@ -9,7 +9,8 @@ from threading import Thread
import requests import requests
from six.moves.queue import Empty, Queue from six.moves.queue import Empty, Queue
from ...utils import get_host_info, master_only from ...dist_utils import master_only
from ...utils import get_host_info
from .base import LoggerHook from .base import LoggerHook

View File

@ -2,7 +2,7 @@ import os.path as osp
import torch import torch
from ...utils import master_only from ...dist_utils import master_only
from .base import LoggerHook from .base import LoggerHook

View File

@ -1,4 +1,4 @@
from ...utils import master_only from ...dist_utils import master_only
from .base import LoggerHook from .base import LoggerHook
import numbers import numbers

View File

@ -7,11 +7,12 @@ import torch
import mmcv import mmcv
from . import hooks from . import hooks
from .checkpoint import load_checkpoint, save_checkpoint from .checkpoint import load_checkpoint, save_checkpoint
from .dist_utils import get_dist_info
from .hooks import (CheckpointHook, Hook, IterTimerHook, LrUpdaterHook, from .hooks import (CheckpointHook, Hook, IterTimerHook, LrUpdaterHook,
OptimizerHook, lr_updater) OptimizerHook, lr_updater)
from .log_buffer import LogBuffer from .log_buffer import LogBuffer
from .priority import get_priority from .priority import get_priority
from .utils import get_dist_info, get_host_info, get_time_str, obj_from_dict from .utils import get_host_info, get_time_str, obj_from_dict
class Runner(object): class Runner(object):

View File

@ -1,12 +1,8 @@
import functools
import sys import sys
import time import time
from getpass import getuser from getpass import getuser
from socket import gethostname from socket import gethostname
import torch
import torch.distributed as dist
import mmcv import mmcv
@ -14,34 +10,6 @@ def get_host_info():
return '{}@{}'.format(getuser(), gethostname()) return '{}@{}'.format(getuser(), gethostname())
def get_dist_info():
if torch.__version__ < '1.0':
initialized = dist._initialized
else:
if dist.is_available():
initialized = dist.is_initialized()
else:
initialized = False
if initialized:
rank = dist.get_rank()
world_size = dist.get_world_size()
else:
rank = 0
world_size = 1
return rank, world_size
def master_only(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
rank, _ = get_dist_info()
if rank == 0:
return func(*args, **kwargs)
return wrapper
def get_time_str(): def get_time_str():
return time.strftime('%Y%m%d_%H%M%S', time.localtime()) return time.strftime('%Y%m%d_%H%M%S', time.localtime())

View File

@ -1 +1 @@
__version__ = '0.2.14' __version__ = '0.2.15'