mirror of https://github.com/open-mmlab/mmcv.git

Add dist utils (#157)

* add dist utils
* rename due to a typo
* bump version to 0.2.15
* fix imports

pull/159/head v0.2.15
parent 2330c4f468
commit ead4bc39bb
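In short: this commit moves `get_dist_info` and `master_only` out of `mmcv/runner/utils.py` into a new `mmcv/runner/dist_utils.py`, adds an `init_dist` entry point with `pytorch`, `mpi` (not yet implemented), and `slurm` launchers, and updates every call site plus the package re-exports to match.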
mmcv/runner/__init__.py
@@ -8,8 +8,8 @@ from .checkpoint import (load_state_dict, load_checkpoint, weights_to_cpu,
                           save_checkpoint)
 from .parallel_test import parallel_test
 from .priority import Priority, get_priority
-from .utils import (get_host_info, get_dist_info, master_only, get_time_str,
-                    obj_from_dict)
+from .utils import get_host_info, get_time_str, obj_from_dict
+from .dist_utils import init_dist, get_dist_info, master_only
 
 __all__ = [
     'Runner', 'LogBuffer', 'Hook', 'CheckpointHook', 'ClosureHook',
@@ -17,6 +17,6 @@ __all__ = [
     'LoggerHook', 'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook',
     'WandbLoggerHook', 'load_state_dict', 'load_checkpoint', 'weights_to_cpu',
     'save_checkpoint', 'parallel_test', 'Priority', 'get_priority',
-    'get_host_info', 'get_dist_info', 'master_only', 'get_time_str',
-    'obj_from_dict'
+    'get_host_info', 'get_time_str', 'obj_from_dict', 'init_dist',
+    'get_dist_info', 'master_only'
 ]
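The net effect on the public API: every name stays importable from `mmcv.runner`, only the defining module changes. A quick sanity check, assuming the package is installed and no process group has been initialized:

>>> from mmcv.runner import init_dist, get_dist_info, master_only
>>> get_dist_info()
(0, 1)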
mmcv/runner/checkpoint.py
@@ -12,7 +12,7 @@ from terminaltables import AsciiTable
 from torch.utils import model_zoo
 
 import mmcv
-from .utils import get_dist_info
+from .dist_utils import get_dist_info
 
 open_mmlab_model_urls = {
     'vgg16_caffe': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/vgg16_caffe-292e1171.pth',  # noqa: E501
mmcv/runner/dist_utils.py (new file)
@@ -0,0 +1,75 @@
+import functools
+import os
+import subprocess
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+
+def init_dist(launcher, backend='nccl', **kwargs):
+    if mp.get_start_method(allow_none=True) is None:
+        mp.set_start_method('spawn')
+    if launcher == 'pytorch':
+        _init_dist_pytorch(backend, **kwargs)
+    elif launcher == 'mpi':
+        _init_dist_mpi(backend, **kwargs)
+    elif launcher == 'slurm':
+        _init_dist_slurm(backend, **kwargs)
+    else:
+        raise ValueError('Invalid launcher type: {}'.format(launcher))
+
+
+def _init_dist_pytorch(backend, **kwargs):
+    # TODO: use local_rank instead of rank % num_gpus
+    rank = int(os.environ['RANK'])
+    num_gpus = torch.cuda.device_count()
+    torch.cuda.set_device(rank % num_gpus)
+    dist.init_process_group(backend=backend, **kwargs)
+
+
+def _init_dist_mpi(backend, **kwargs):
+    raise NotImplementedError
+
+
+def _init_dist_slurm(backend, port=29500, **kwargs):
+    proc_id = int(os.environ['SLURM_PROCID'])
+    ntasks = int(os.environ['SLURM_NTASKS'])
+    node_list = os.environ['SLURM_NODELIST']
+    num_gpus = torch.cuda.device_count()
+    torch.cuda.set_device(proc_id % num_gpus)
+    addr = subprocess.getoutput(
+        'scontrol show hostname {} | head -n1'.format(node_list))
+    os.environ['MASTER_PORT'] = str(port)
+    os.environ['MASTER_ADDR'] = addr
+    os.environ['WORLD_SIZE'] = str(ntasks)
+    os.environ['RANK'] = str(proc_id)
+    dist.init_process_group(backend=backend)
+
+
+def get_dist_info():
+    if torch.__version__ < '1.0':
+        initialized = dist._initialized
+    else:
+        if dist.is_available():
+            initialized = dist.is_initialized()
+        else:
+            initialized = False
+    if initialized:
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+    else:
+        rank = 0
+        world_size = 1
+    return rank, world_size
+
+
+def master_only(func):
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        rank, _ = get_dist_info()
+        if rank == 0:
+            return func(*args, **kwargs)
+
+    return wrapper
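For orientation, a minimal sketch of how this module is driven from a training script. Only the `mmcv.runner` names (`init_dist`, `get_dist_info`, `master_only`) and the environment-variable handshake come from the diff above; the script name, argument parsing, and launch commands are illustrative assumptions following standard PyTorch/SLURM conventions. Note also that `get_dist_info` branches on `torch.__version__ < '1.0'` because `dist.is_available()`/`dist.is_initialized()` only exist from PyTorch 1.0 on; older releases exposed the private `dist._initialized` flag instead.

# train.py -- hypothetical entry point, not part of this commit.
import argparse

from mmcv.runner import get_dist_info, init_dist, master_only


@master_only
def log_banner(msg):
    # Executes on rank 0 only; on every other rank the wrapper
    # silently returns None.
    print(msg)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'],
                        default='none')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()

    if args.launcher != 'none':
        # 'pytorch': RANK/WORLD_SIZE/MASTER_* are set by
        # torch.distributed.launch; 'slurm': init_dist derives them from
        # SLURM_PROCID/SLURM_NTASKS/SLURM_NODELIST via scontrol.
        init_dist(args.launcher)

    rank, world_size = get_dist_info()  # (0, 1) when not distributed
    log_banner('running on {} process(es)'.format(world_size))


if __name__ == '__main__':
    main()

Typical invocations would then be `python -m torch.distributed.launch --nproc_per_node=8 train.py --launcher pytorch` on a single node, or `srun -N 2 --ntasks-per-node=8 python train.py --launcher slurm` under SLURM; both commands reflect standard usage, not something pinned down by this diff.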
mmcv/runner/hooks/checkpoint.py
@@ -1,4 +1,4 @@
-from ..utils import master_only
+from ..dist_utils import master_only
 from .hook import Hook
 
 
mmcv/runner/hooks/logger/pavi.py
@@ -9,7 +9,8 @@ from threading import Thread
 import requests
 from six.moves.queue import Empty, Queue
 
-from ...utils import get_host_info, master_only
+from ...dist_utils import master_only
+from ...utils import get_host_info
 from .base import LoggerHook
 
 
mmcv/runner/hooks/logger/tensorboard.py
@@ -2,7 +2,7 @@ import os.path as osp
 
 import torch
 
-from ...utils import master_only
+from ...dist_utils import master_only
 from .base import LoggerHook
 
 
mmcv/runner/hooks/logger/text.py
@@ -1,4 +1,4 @@
-from ...utils import master_only
+from ...dist_utils import master_only
 from .base import LoggerHook
 import numbers
 
mmcv/runner/runner.py
@@ -7,11 +7,12 @@ import torch
 import mmcv
 from . import hooks
 from .checkpoint import load_checkpoint, save_checkpoint
+from .dist_utils import get_dist_info
 from .hooks import (CheckpointHook, Hook, IterTimerHook, LrUpdaterHook,
                     OptimizerHook, lr_updater)
 from .log_buffer import LogBuffer
 from .priority import get_priority
-from .utils import get_dist_info, get_host_info, get_time_str, obj_from_dict
+from .utils import get_host_info, get_time_str, obj_from_dict
 
 
 class Runner(object):
mmcv/runner/utils.py
@@ -1,12 +1,8 @@
-import functools
 import sys
 import time
 from getpass import getuser
 from socket import gethostname
 
-import torch
-import torch.distributed as dist
-
 import mmcv
 
 
@@ -14,34 +10,6 @@ def get_host_info():
     return '{}@{}'.format(getuser(), gethostname())
 
 
-def get_dist_info():
-    if torch.__version__ < '1.0':
-        initialized = dist._initialized
-    else:
-        if dist.is_available():
-            initialized = dist.is_initialized()
-        else:
-            initialized = False
-    if initialized:
-        rank = dist.get_rank()
-        world_size = dist.get_world_size()
-    else:
-        rank = 0
-        world_size = 1
-    return rank, world_size
-
-
-def master_only(func):
-
-    @functools.wraps(func)
-    def wrapper(*args, **kwargs):
-        rank, _ = get_dist_info()
-        if rank == 0:
-            return func(*args, **kwargs)
-
-    return wrapper
-
-
 def get_time_str():
     return time.strftime('%Y%m%d_%H%M%S', time.localtime())
 
mmcv/version.py
@@ -1 +1 @@
-__version__ = '0.2.14'
+__version__ = '0.2.15'