mirror of https://github.com/open-mmlab/mmcv.git
Add dist utils (#157)
* add dist utils
* rename due to a typo
* bump version to 0.2.15
* fix imports

pull/159/head  v0.2.15
parent 2330c4f468
commit ead4bc39bb
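For context, a minimal sketch of how the helpers introduced by this commit could be wired into a training script. The script itself, the --launcher flag name, and the print call are illustrative assumptions, not part of this commit; only init_dist, get_dist_info and master_only come from the diff below.

# Illustrative sketch only (not part of this commit): driving the new
# mmcv.runner distributed helpers from a hypothetical training script.
import argparse

from mmcv.runner import get_dist_info, init_dist

parser = argparse.ArgumentParser()
parser.add_argument('--launcher', default='none',
                    choices=['none', 'pytorch', 'slurm', 'mpi'])  # assumed flag
args = parser.parse_args()

if args.launcher != 'none':
    # picks a CUDA device from the launcher's env vars and calls
    # torch.distributed.init_process_group
    init_dist(args.launcher)

rank, world_size = get_dist_info()  # falls back to (0, 1) if not initialized
print('running rank {} of {}'.format(rank, world_size))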
@@ -8,8 +8,8 @@ from .checkpoint import (load_state_dict, load_checkpoint, weights_to_cpu,
                          save_checkpoint)
 from .parallel_test import parallel_test
 from .priority import Priority, get_priority
-from .utils import (get_host_info, get_dist_info, master_only, get_time_str,
-                    obj_from_dict)
+from .utils import get_host_info, get_time_str, obj_from_dict
+from .dist_utils import init_dist, get_dist_info, master_only

 __all__ = [
     'Runner', 'LogBuffer', 'Hook', 'CheckpointHook', 'ClosureHook',
@@ -17,6 +17,6 @@ __all__ = [
     'LoggerHook', 'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook',
     'WandbLoggerHook', 'load_state_dict', 'load_checkpoint', 'weights_to_cpu',
     'save_checkpoint', 'parallel_test', 'Priority', 'get_priority',
-    'get_host_info', 'get_dist_info', 'master_only', 'get_time_str',
-    'obj_from_dict'
+    'get_host_info', 'get_time_str', 'obj_from_dict', 'init_dist',
+    'get_dist_info', 'master_only'
 ]
@@ -12,7 +12,7 @@ from terminaltables import AsciiTable
 from torch.utils import model_zoo

 import mmcv
-from .utils import get_dist_info
+from .dist_utils import get_dist_info

 open_mmlab_model_urls = {
     'vgg16_caffe': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/vgg16_caffe-292e1171.pth',  # noqa: E501
@@ -0,0 +1,75 @@
+import functools
+import os
+import subprocess
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+
+def init_dist(launcher, backend='nccl', **kwargs):
+    if mp.get_start_method(allow_none=True) is None:
+        mp.set_start_method('spawn')
+    if launcher == 'pytorch':
+        _init_dist_pytorch(backend, **kwargs)
+    elif launcher == 'mpi':
+        _init_dist_mpi(backend, **kwargs)
+    elif launcher == 'slurm':
+        _init_dist_slurm(backend, **kwargs)
+    else:
+        raise ValueError('Invalid launcher type: {}'.format(launcher))
+
+
+def _init_dist_pytorch(backend, **kwargs):
+    # TODO: use local_rank instead of rank % num_gpus
+    rank = int(os.environ['RANK'])
+    num_gpus = torch.cuda.device_count()
+    torch.cuda.set_device(rank % num_gpus)
+    dist.init_process_group(backend=backend, **kwargs)
+
+
+def _init_dist_mpi(backend, **kwargs):
+    raise NotImplementedError
+
+
+def _init_dist_slurm(backend, port=29500, **kwargs):
+    proc_id = int(os.environ['SLURM_PROCID'])
+    ntasks = int(os.environ['SLURM_NTASKS'])
+    node_list = os.environ['SLURM_NODELIST']
+    num_gpus = torch.cuda.device_count()
+    torch.cuda.set_device(proc_id % num_gpus)
+    addr = subprocess.getoutput(
+        'scontrol show hostname {} | head -n1'.format(node_list))
+    os.environ['MASTER_PORT'] = str(port)
+    os.environ['MASTER_ADDR'] = addr
+    os.environ['WORLD_SIZE'] = str(ntasks)
+    os.environ['RANK'] = str(proc_id)
+    dist.init_process_group(backend=backend)
+
+
+def get_dist_info():
+    if torch.__version__ < '1.0':
+        initialized = dist._initialized
+    else:
+        if dist.is_available():
+            initialized = dist.is_initialized()
+        else:
+            initialized = False
+    if initialized:
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+    else:
+        rank = 0
+        world_size = 1
+    return rank, world_size
+
+
+def master_only(func):
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        rank, _ = get_dist_info()
+        if rank == 0:
+            return func(*args, **kwargs)
+
+    return wrapper
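The hook files below only swap their master_only import over to this new module. As a rough illustration of what the decorator does (the decorated function here is made up, not from the diff): on rank 0 the wrapped function runs normally, on every other rank the wrapper returns None without calling it.

# Made-up example of using the master_only decorator defined above:
from mmcv.runner import master_only


@master_only
def log_summary(msg):
    # runs only where get_dist_info() reports rank 0,
    # so a multi-process job prints the message once
    print(msg)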
@@ -1,4 +1,4 @@
-from ..utils import master_only
+from ..dist_utils import master_only
 from .hook import Hook

@@ -9,7 +9,8 @@ from threading import Thread
 import requests
 from six.moves.queue import Empty, Queue

-from ...utils import get_host_info, master_only
+from ...dist_utils import master_only
+from ...utils import get_host_info
 from .base import LoggerHook

@@ -2,7 +2,7 @@ import os.path as osp

 import torch

-from ...utils import master_only
+from ...dist_utils import master_only
 from .base import LoggerHook

@@ -1,4 +1,4 @@
-from ...utils import master_only
+from ...dist_utils import master_only
 from .base import LoggerHook
 import numbers

@@ -7,11 +7,12 @@ import torch
 import mmcv
 from . import hooks
 from .checkpoint import load_checkpoint, save_checkpoint
+from .dist_utils import get_dist_info
 from .hooks import (CheckpointHook, Hook, IterTimerHook, LrUpdaterHook,
                     OptimizerHook, lr_updater)
 from .log_buffer import LogBuffer
 from .priority import get_priority
-from .utils import get_dist_info, get_host_info, get_time_str, obj_from_dict
+from .utils import get_host_info, get_time_str, obj_from_dict


 class Runner(object):
@@ -1,12 +1,8 @@
-import functools
 import sys
 import time
 from getpass import getuser
 from socket import gethostname

-import torch
-import torch.distributed as dist
-
 import mmcv

@@ -14,34 +10,6 @@ def get_host_info():
     return '{}@{}'.format(getuser(), gethostname())


-def get_dist_info():
-    if torch.__version__ < '1.0':
-        initialized = dist._initialized
-    else:
-        if dist.is_available():
-            initialized = dist.is_initialized()
-        else:
-            initialized = False
-    if initialized:
-        rank = dist.get_rank()
-        world_size = dist.get_world_size()
-    else:
-        rank = 0
-        world_size = 1
-    return rank, world_size
-
-
-def master_only(func):
-
-    @functools.wraps(func)
-    def wrapper(*args, **kwargs):
-        rank, _ = get_dist_info()
-        if rank == 0:
-            return func(*args, **kwargs)
-
-    return wrapper
-
-
 def get_time_str():
     return time.strftime('%Y%m%d_%H%M%S', time.localtime())

@@ -1 +1 @@
-__version__ = '0.2.14'
+__version__ = '0.2.15'