From c30e91db87bed113533dd6d0249b1f8cf09eb4cc Mon Sep 17 00:00:00 2001
From: Cao Yuhang
Date: Sun, 14 Jun 2020 21:09:48 +0800
Subject: [PATCH] share torch version (#343)

---
 mmcv/parallel/distributed.py            |  5 +++--
 mmcv/parallel/distributed_deprecated.py |  3 ++-
 mmcv/runner/dist_utils.py               |  4 +++-
 mmcv/runner/hooks/logger/tensorboard.py |  5 ++---
 mmcv/utils/__init__.py                  |  4 +++-
 mmcv/utils/env.py                       |  4 ++++
 mmcv/utils/parrots_wrapper.py           | 20 +++++++++++---------
 7 files changed, 28 insertions(+), 17 deletions(-)
 create mode 100644 mmcv/utils/env.py

diff --git a/mmcv/parallel/distributed.py b/mmcv/parallel/distributed.py
index 54da245d0..07771a38a 100644
--- a/mmcv/parallel/distributed.py
+++ b/mmcv/parallel/distributed.py
@@ -3,6 +3,7 @@ import torch
 from torch.nn.parallel.distributed import (DistributedDataParallel,
                                            _find_tensors)
 
+from mmcv.utils import TORCH_VERSION
 from .scatter_gather import scatter_kwargs
 
 
@@ -47,7 +48,7 @@ class MMDistributedDataParallel(DistributedDataParallel):
             else:
                 self.reducer.prepare_for_backward([])
         else:
-            if torch.__version__ > '1.2':
+            if TORCH_VERSION > '1.2':
                 self.require_forward_param_sync = False
         return output
 
@@ -79,6 +80,6 @@ class MMDistributedDataParallel(DistributedDataParallel):
             else:
                 self.reducer.prepare_for_backward([])
         else:
-            if torch.__version__ > '1.2':
+            if TORCH_VERSION > '1.2':
                 self.require_forward_param_sync = False
         return output
diff --git a/mmcv/parallel/distributed_deprecated.py b/mmcv/parallel/distributed_deprecated.py
index c1f6fd8d4..591e6a361 100644
--- a/mmcv/parallel/distributed_deprecated.py
+++ b/mmcv/parallel/distributed_deprecated.py
@@ -5,6 +5,7 @@ import torch.nn as nn
 from torch._utils import (_flatten_dense_tensors, _take_tensors,
                           _unflatten_dense_tensors)
 
+from mmcv.utils import TORCH_VERSION
 from .scatter_gather import scatter_kwargs
 
 
@@ -37,7 +38,7 @@ class MMDistributedDataParallel(nn.Module):
         self._dist_broadcast_coalesced(module_states,
                                        self.broadcast_bucket_size)
         if self.broadcast_buffers:
-            if torch.__version__ < '1.0':
+            if TORCH_VERSION < '1.0':
                 buffers = [b.data for b in self.module._all_buffers()]
             else:
                 buffers = [b.data for b in self.module.buffers()]
diff --git a/mmcv/runner/dist_utils.py b/mmcv/runner/dist_utils.py
index 67572fc95..1aa9ce1c9 100644
--- a/mmcv/runner/dist_utils.py
+++ b/mmcv/runner/dist_utils.py
@@ -7,6 +7,8 @@ import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 
+from mmcv.utils import TORCH_VERSION
+
 
 def init_dist(launcher, backend='nccl', **kwargs):
     if mp.get_start_method(allow_none=True) is None:
@@ -49,7 +51,7 @@ def _init_dist_slurm(backend, port=29500):
 
 
 def get_dist_info():
-    if torch.__version__ < '1.0':
+    if TORCH_VERSION < '1.0':
         initialized = dist._initialized
     else:
         if dist.is_available():
diff --git a/mmcv/runner/hooks/logger/tensorboard.py b/mmcv/runner/hooks/logger/tensorboard.py
index df4478a87..fd9ac198e 100644
--- a/mmcv/runner/hooks/logger/tensorboard.py
+++ b/mmcv/runner/hooks/logger/tensorboard.py
@@ -1,8 +1,7 @@
 # Copyright (c) Open-MMLab. All rights reserved.
 import os.path as osp
 
-import torch
-
+from mmcv.utils import TORCH_VERSION
 from ...dist_utils import master_only
 from ..hook import HOOKS
 from .base import LoggerHook
@@ -22,7 +21,7 @@ class TensorboardLoggerHook(LoggerHook):
 
     @master_only
     def before_run(self, runner):
-        if torch.__version__ < '1.1' or torch.__version__ == 'parrots':
+        if TORCH_VERSION < '1.1' or TORCH_VERSION == 'parrots':
             try:
                 from tensorboardX import SummaryWriter
             except ImportError:
diff --git a/mmcv/utils/__init__.py b/mmcv/utils/__init__.py
index 870d2a740..96cc0be16 100644
--- a/mmcv/utils/__init__.py
+++ b/mmcv/utils/__init__.py
@@ -1,5 +1,6 @@
 # Copyright (c) Open-MMLab. All rights reserved.
 from .config import Config, ConfigDict, DictAction
+from .env import TORCH_VERSION
 from .logging import get_logger, print_log
 from .misc import (check_prerequisites, concat_list, is_list_of,
                    is_seq_of, is_str, is_tuple_of, iter_cast, list_cast,
@@ -29,5 +30,6 @@ __all__ = [
     'CUDA_HOME', 'SyncBatchNorm', '_AdaptiveAvgPoolNd', '_AdaptiveMaxPoolNd',
     '_AvgPoolNd', '_BatchNorm', '_ConvNd', '_ConvTransposeMixin',
     '_InstanceNorm', '_MaxPoolNd', 'get_build_config', 'BuildExtension',
-    'CppExtension', 'CUDAExtension', 'DataLoader', 'PoolDataLoader'
+    'CppExtension', 'CUDAExtension', 'DataLoader', 'PoolDataLoader',
+    'TORCH_VERSION'
 ]
diff --git a/mmcv/utils/env.py b/mmcv/utils/env.py
new file mode 100644
index 000000000..2f31f7194
--- /dev/null
+++ b/mmcv/utils/env.py
@@ -0,0 +1,4 @@
+# This file holds some environment constants shared by other files.
+import torch
+
+TORCH_VERSION = torch.__version__
diff --git a/mmcv/utils/parrots_wrapper.py b/mmcv/utils/parrots_wrapper.py
index bc35f3415..e83bbd6e2 100644
--- a/mmcv/utils/parrots_wrapper.py
+++ b/mmcv/utils/parrots_wrapper.py
@@ -2,9 +2,11 @@ from functools import partial
 
 import torch
 
+from .env import TORCH_VERSION
+
 
 def _get_cuda_home():
-    if torch.__version__ == 'parrots':
+    if TORCH_VERSION == 'parrots':
         from parrots.utils.build_extension import CUDA_HOME
     else:
         from torch.utils.cpp_extension import CUDA_HOME
@@ -12,7 +14,7 @@
 
 
 def get_build_config():
-    if torch.__version__ == 'parrots':
+    if TORCH_VERSION == 'parrots':
         from parrots.config import get_build_info
         return get_build_info()
     else:
@@ -20,7 +22,7 @@
 
 
 def _get_conv():
-    if torch.__version__ == 'parrots':
+    if TORCH_VERSION == 'parrots':
         from parrots.nn.modules.conv import _ConvNd, _ConvTransposeMixin
     else:
         from torch.nn.modules.conv import _ConvNd, _ConvTransposeMixin
@@ -28,7 +30,7 @@
 
 
 def _get_dataloader():
-    if torch.__version__ == 'parrots':
+    if TORCH_VERSION == 'parrots':
         from torch.utils.data import DataLoader, PoolDataLoader
     else:
         from torch.utils.data import DataLoader
@@ -37,7 +39,7 @@
 
 
 def _get_extension():
-    if torch.__version__ == 'parrots':
+    if TORCH_VERSION == 'parrots':
         from parrots.utils.build_extension import BuildExtension, Extension
         CppExtension = partial(Extension, cuda=False)
         CUDAExtension = partial(Extension, cuda=True)
@@ -48,7 +50,7 @@
 
 
 def _get_pool():
-    if torch.__version__ == 'parrots':
+    if TORCH_VERSION == 'parrots':
         from parrots.nn.modules.pool import (_AdaptiveAvgPoolNd,
                                              _AdaptiveMaxPoolNd, _AvgPoolNd,
                                              _MaxPoolNd)
@@ -60,7 +62,7 @@
 
 
 def _get_norm():
-    if torch.__version__ == 'parrots':
+    if TORCH_VERSION == 'parrots':
         from parrots.nn.modules.batchnorm import _BatchNorm, _InstanceNorm
         SyncBatchNorm_ = torch.nn.SyncBatchNorm2d
     else:
@@ -81,11 +83,11 @@ _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool()
 class SyncBatchNorm(SyncBatchNorm_):
 
     def _specify_ddp_gpu_num(self, gpu_size):
-        if torch.__version__ != 'parrots':
+        if TORCH_VERSION != 'parrots':
             super()._specify_ddp_gpu_num(gpu_size)
 
     def _check_input_dim(self, input):
-        if torch.__version__ == 'parrots':
+        if TORCH_VERSION == 'parrots':
             if input.dim() < 2:
                 raise ValueError(
                     f'expected at least 2D input (got {input.dim()}D input)')
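
Note on the version gates above: TORCH_VERSION is a plain string, so checks
such as TORCH_VERSION > '1.2' compare lexicographically. That is correct for
the single-digit minor releases this patch targets, but '1.10.0' compares
smaller than '1.2' as a string. Below is a minimal sketch of a numeric
comparison built on the shared constant; the digit_version helper is
illustrative only, not something this patch adds.

# Illustrative sketch, not part of the patch.
from mmcv.utils import TORCH_VERSION


def digit_version(version_str):
    """Turn '1.5.0' (or '1.6.0a0+9907a3e') into a tuple of leading ints."""
    digits = []
    for field in version_str.split('.'):
        if not field.isdigit():
            break  # stop at suffixes such as '0a0+9907a3e' or 'parrots'
        digits.append(int(field))
    return tuple(digits)


# Lexicographic vs. numeric ordering:
assert not '1.10.0' > '1.2'  # string comparison misorders versions
assert digit_version('1.10.0') > digit_version('1.2')

# Version-gated branch, guarding the 'parrots' build first:
if TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) > (1, 2):
    pass  # e.g. the require_forward_param_sync reset for torch > 1.2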