From ead4bc39bb9c9527a0440e30e59070901b7c121d Mon Sep 17 00:00:00 2001
From: Kai Chen
Date: Thu, 12 Dec 2019 22:18:40 +0800
Subject: [PATCH] Add dist utils (#157)

* add dist utils

* rename due to a typo

* bump version to 0.2.15

* fix imports
---
 mmcv/runner/__init__.py                 |  8 +--
 mmcv/runner/checkpoint.py               |  2 +-
 mmcv/runner/dist_utils.py               | 75 +++++++++++++++++++++++++
 mmcv/runner/hooks/checkpoint.py         |  2 +-
 mmcv/runner/hooks/logger/pavi.py        |  3 +-
 mmcv/runner/hooks/logger/tensorboard.py |  2 +-
 mmcv/runner/hooks/logger/wandb.py       |  2 +-
 mmcv/runner/runner.py                   |  3 +-
 mmcv/runner/utils.py                    | 32 -----------
 mmcv/version.py                         |  2 +-
 10 files changed, 88 insertions(+), 43 deletions(-)
 create mode 100644 mmcv/runner/dist_utils.py

diff --git a/mmcv/runner/__init__.py b/mmcv/runner/__init__.py
index d95ededaf..8d1215e24 100644
--- a/mmcv/runner/__init__.py
+++ b/mmcv/runner/__init__.py
@@ -8,8 +8,8 @@ from .checkpoint import (load_state_dict, load_checkpoint, weights_to_cpu,
                          save_checkpoint)
 from .parallel_test import parallel_test
 from .priority import Priority, get_priority
-from .utils import (get_host_info, get_dist_info, master_only, get_time_str,
-                    obj_from_dict)
+from .utils import get_host_info, get_time_str, obj_from_dict
+from .dist_utils import init_dist, get_dist_info, master_only
 
 __all__ = [
     'Runner', 'LogBuffer', 'Hook', 'CheckpointHook', 'ClosureHook',
@@ -17,6 +17,6 @@ __all__ = [
     'LoggerHook', 'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook',
     'WandbLoggerHook', 'load_state_dict', 'load_checkpoint', 'weights_to_cpu',
     'save_checkpoint', 'parallel_test', 'Priority', 'get_priority',
-    'get_host_info', 'get_dist_info', 'master_only', 'get_time_str',
-    'obj_from_dict'
+    'get_host_info', 'get_time_str', 'obj_from_dict', 'init_dist',
+    'get_dist_info', 'master_only'
 ]
diff --git a/mmcv/runner/checkpoint.py b/mmcv/runner/checkpoint.py
index ccf4bda38..4bc22ef78 100644
--- a/mmcv/runner/checkpoint.py
+++ b/mmcv/runner/checkpoint.py
@@ -12,7 +12,7 @@ from terminaltables import AsciiTable
 from torch.utils import model_zoo
 
 import mmcv
-from .utils import get_dist_info
+from .dist_utils import get_dist_info
 
 open_mmlab_model_urls = {
     'vgg16_caffe': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/vgg16_caffe-292e1171.pth',  # noqa: E501
diff --git a/mmcv/runner/dist_utils.py b/mmcv/runner/dist_utils.py
new file mode 100644
index 000000000..cb6558548
--- /dev/null
+++ b/mmcv/runner/dist_utils.py
@@ -0,0 +1,75 @@
+import functools
+import os
+import subprocess
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+
+def init_dist(launcher, backend='nccl', **kwargs):
+    if mp.get_start_method(allow_none=True) is None:
+        mp.set_start_method('spawn')
+    if launcher == 'pytorch':
+        _init_dist_pytorch(backend, **kwargs)
+    elif launcher == 'mpi':
+        _init_dist_mpi(backend, **kwargs)
+    elif launcher == 'slurm':
+        _init_dist_slurm(backend, **kwargs)
+    else:
+        raise ValueError('Invalid launcher type: {}'.format(launcher))
+
+
+def _init_dist_pytorch(backend, **kwargs):
+    # TODO: use local_rank instead of rank % num_gpus
+    rank = int(os.environ['RANK'])
+    num_gpus = torch.cuda.device_count()
+    torch.cuda.set_device(rank % num_gpus)
+    dist.init_process_group(backend=backend, **kwargs)
+
+
+def _init_dist_mpi(backend, **kwargs):
+    raise NotImplementedError
+
+
+def _init_dist_slurm(backend, port=29500, **kwargs):
+    proc_id = int(os.environ['SLURM_PROCID'])
+    ntasks = int(os.environ['SLURM_NTASKS'])
+    node_list = os.environ['SLURM_NODELIST']
+    num_gpus = torch.cuda.device_count()
+    torch.cuda.set_device(proc_id % num_gpus)
+    addr = subprocess.getoutput(
+        'scontrol show hostname {} | head -n1'.format(node_list))
+    os.environ['MASTER_PORT'] = str(port)
+    os.environ['MASTER_ADDR'] = addr
+    os.environ['WORLD_SIZE'] = str(ntasks)
+    os.environ['RANK'] = str(proc_id)
+    dist.init_process_group(backend=backend)
+
+
+def get_dist_info():
+    if torch.__version__ < '1.0':
+        initialized = dist._initialized
+    else:
+        if dist.is_available():
+            initialized = dist.is_initialized()
+        else:
+            initialized = False
+    if initialized:
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+    else:
+        rank = 0
+        world_size = 1
+    return rank, world_size
+
+
+def master_only(func):
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        rank, _ = get_dist_info()
+        if rank == 0:
+            return func(*args, **kwargs)
+
+    return wrapper
diff --git a/mmcv/runner/hooks/checkpoint.py b/mmcv/runner/hooks/checkpoint.py
index c27fc00c0..96d85f528 100644
--- a/mmcv/runner/hooks/checkpoint.py
+++ b/mmcv/runner/hooks/checkpoint.py
@@ -1,4 +1,4 @@
-from ..utils import master_only
+from ..dist_utils import master_only
 from .hook import Hook
 
 
diff --git a/mmcv/runner/hooks/logger/pavi.py b/mmcv/runner/hooks/logger/pavi.py
index 670b10326..1a2321f4e 100644
--- a/mmcv/runner/hooks/logger/pavi.py
+++ b/mmcv/runner/hooks/logger/pavi.py
@@ -9,7 +9,8 @@ from threading import Thread
 import requests
 from six.moves.queue import Empty, Queue
 
-from ...utils import get_host_info, master_only
+from ...dist_utils import master_only
+from ...utils import get_host_info
 from .base import LoggerHook
 
 
diff --git a/mmcv/runner/hooks/logger/tensorboard.py b/mmcv/runner/hooks/logger/tensorboard.py
index 7ed95f860..872908897 100644
--- a/mmcv/runner/hooks/logger/tensorboard.py
+++ b/mmcv/runner/hooks/logger/tensorboard.py
@@ -2,7 +2,7 @@ import os.path as osp
 
 import torch
 
-from ...utils import master_only
+from ...dist_utils import master_only
 from .base import LoggerHook
 
 
diff --git a/mmcv/runner/hooks/logger/wandb.py b/mmcv/runner/hooks/logger/wandb.py
index 176a47783..ed71ffa73 100644
--- a/mmcv/runner/hooks/logger/wandb.py
+++ b/mmcv/runner/hooks/logger/wandb.py
@@ -1,4 +1,4 @@
-from ...utils import master_only
+from ...dist_utils import master_only
 from .base import LoggerHook
 
 import numbers
diff --git a/mmcv/runner/runner.py b/mmcv/runner/runner.py
index 74ad5981a..5278a826e 100644
--- a/mmcv/runner/runner.py
+++ b/mmcv/runner/runner.py
@@ -7,11 +7,12 @@ import torch
 import mmcv
 from . import hooks
 from .checkpoint import load_checkpoint, save_checkpoint
+from .dist_utils import get_dist_info
 from .hooks import (CheckpointHook, Hook, IterTimerHook, LrUpdaterHook,
                     OptimizerHook, lr_updater)
 from .log_buffer import LogBuffer
 from .priority import get_priority
-from .utils import get_dist_info, get_host_info, get_time_str, obj_from_dict
+from .utils import get_host_info, get_time_str, obj_from_dict
 
 
 class Runner(object):
diff --git a/mmcv/runner/utils.py b/mmcv/runner/utils.py
index 771127382..6d0d987d0 100644
--- a/mmcv/runner/utils.py
+++ b/mmcv/runner/utils.py
@@ -1,12 +1,8 @@
-import functools
 import sys
 import time
 from getpass import getuser
 from socket import gethostname
 
-import torch
-import torch.distributed as dist
-
 import mmcv
 
 
@@ -14,34 +10,6 @@ def get_host_info():
     return '{}@{}'.format(getuser(), gethostname())
 
 
-def get_dist_info():
-    if torch.__version__ < '1.0':
-        initialized = dist._initialized
-    else:
-        if dist.is_available():
-            initialized = dist.is_initialized()
-        else:
-            initialized = False
-    if initialized:
-        rank = dist.get_rank()
-        world_size = dist.get_world_size()
-    else:
-        rank = 0
-        world_size = 1
-    return rank, world_size
-
-
-def master_only(func):
-
-    @functools.wraps(func)
-    def wrapper(*args, **kwargs):
-        rank, _ = get_dist_info()
-        if rank == 0:
-            return func(*args, **kwargs)
-
-    return wrapper
-
-
 def get_time_str():
     return time.strftime('%Y%m%d_%H%M%S', time.localtime())
 
diff --git a/mmcv/version.py b/mmcv/version.py
index 10d10d203..a2570ebf4 100644
--- a/mmcv/version.py
+++ b/mmcv/version.py
@@ -1 +1 @@
-__version__ = '0.2.14'
+__version__ = '0.2.15'
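
A minimal usage sketch, appended for context rather than as part of the patch: the argument parsing and the report() helper below are hypothetical stand-ins, while init_dist, get_dist_info, and master_only are imported exactly as the updated mmcv/runner/__init__.py exports them.

import argparse

from mmcv.runner import get_dist_info, init_dist, master_only


@master_only
def report(msg):
    # master_only turns this into a no-op on every rank except 0, so it
    # is safe to call unconditionally from all processes.
    print(msg)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none')
    args = parser.parse_args()

    if args.launcher != 'none':
        # 'pytorch' expects RANK in the environment; 'slurm' reads the
        # SLURM_* variables and derives MASTER_ADDR via scontrol.
        init_dist(args.launcher, backend='nccl')

    # get_dist_info() falls back to rank 0 / world size 1 when no
    # process group exists, so the same script also runs single-GPU.
    rank, world_size = get_dist_info()
    report('start training on {} process(es)'.format(world_size))


if __name__ == '__main__':
    main()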
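
With the pytorch launcher, a script like the one above would typically be started via python -m torch.distributed.launch --nproc_per_node=N train.py --launcher pytorch (train.py being a placeholder name), which sets the RANK and WORLD_SIZE variables that _init_dist_pytorch reads; under SLURM, srun already provides SLURM_PROCID, SLURM_NTASKS, and SLURM_NODELIST, so _init_dist_slurm needs nothing beyond the job allocation.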