mirror of https://github.com/open-mmlab/mmcv.git
share torch version (#343)
parent b87e774f66
commit c30e91db87

This commit defines the torch version string once, as TORCH_VERSION in the new module mmcv/utils/env.py, and replaces direct torch.__version__ checks with the shared constant throughout mmcv.
mmcv/parallel/distributed.py

@@ -3,6 +3,7 @@ import torch
 from torch.nn.parallel.distributed import (DistributedDataParallel,
                                            _find_tensors)
 
+from mmcv.utils import TORCH_VERSION
 from .scatter_gather import scatter_kwargs
 
 
@@ -47,7 +48,7 @@ class MMDistributedDataParallel(DistributedDataParallel):
             else:
                 self.reducer.prepare_for_backward([])
         else:
-            if torch.__version__ > '1.2':
+            if TORCH_VERSION > '1.2':
                 self.require_forward_param_sync = False
         return output
 
@@ -79,6 +80,6 @@ class MMDistributedDataParallel(DistributedDataParallel):
             else:
                 self.reducer.prepare_for_backward([])
         else:
-            if torch.__version__ > '1.2':
+            if TORCH_VERSION > '1.2':
                 self.require_forward_param_sync = False
         return output
mmcv/parallel/distributed_deprecated.py

@@ -5,6 +5,7 @@ import torch.nn as nn
 from torch._utils import (_flatten_dense_tensors, _take_tensors,
                           _unflatten_dense_tensors)
 
+from mmcv.utils import TORCH_VERSION
 from .scatter_gather import scatter_kwargs
 
 
@@ -37,7 +38,7 @@ class MMDistributedDataParallel(nn.Module):
             self._dist_broadcast_coalesced(module_states,
                                            self.broadcast_bucket_size)
         if self.broadcast_buffers:
-            if torch.__version__ < '1.0':
+            if TORCH_VERSION < '1.0':
                 buffers = [b.data for b in self.module._all_buffers()]
             else:
                 buffers = [b.data for b in self.module.buffers()]
mmcv/runner/dist_utils.py

@@ -7,6 +7,8 @@ import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 
+from mmcv.utils import TORCH_VERSION
+
 
 def init_dist(launcher, backend='nccl', **kwargs):
     if mp.get_start_method(allow_none=True) is None:
@@ -49,7 +51,7 @@ def _init_dist_slurm(backend, port=29500):
 
 
 def get_dist_info():
-    if torch.__version__ < '1.0':
+    if TORCH_VERSION < '1.0':
         initialized = dist._initialized
     else:
         if dist.is_available():
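Note: get_dist_info() returns a (rank, world_size) pair, defaulting to (0, 1) when no process group is initialized. A minimal usage sketch, assuming only that mmcv is installed and the function keeps this return shape:

    # Query rank/world size through mmcv's helper rather than torch.distributed
    # directly, so the pre-1.0 fallback above is handled transparently.
    from mmcv.runner import get_dist_info

    rank, world_size = get_dist_info()  # (0, 1) if no process group is initialized
    if rank == 0:
        print(f'rank {rank} of {world_size} process(es)')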
mmcv/runner/hooks/logger/tensorboard.py

@@ -1,8 +1,7 @@
 # Copyright (c) Open-MMLab. All rights reserved.
 import os.path as osp
 
-import torch
-
+from mmcv.utils import TORCH_VERSION
 from ...dist_utils import master_only
 from ..hook import HOOKS
 from .base import LoggerHook
@@ -22,7 +21,7 @@ class TensorboardLoggerHook(LoggerHook):
 
     @master_only
     def before_run(self, runner):
-        if torch.__version__ < '1.1' or torch.__version__ == 'parrots':
+        if TORCH_VERSION < '1.1' or TORCH_VERSION == 'parrots':
             try:
                 from tensorboardX import SummaryWriter
             except ImportError:
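The hunk above is truncated inside the ImportError branch. For context, here is a hedged sketch of the full fallback this check belongs to (the exact error messages are assumptions, not quoted from the commit): older torch (< 1.1) and parrots rely on the external tensorboardX package, while newer torch ships torch.utils.tensorboard.

    # Hedged sketch of the surrounding import fallback (messages assumed).
    if TORCH_VERSION < '1.1' or TORCH_VERSION == 'parrots':
        try:
            from tensorboardX import SummaryWriter
        except ImportError:
            raise ImportError('Please install tensorboardX to use '
                              'TensorboardLoggerHook.')
    else:
        try:
            from torch.utils.tensorboard import SummaryWriter
        except ImportError:
            raise ImportError('Please install tensorboard to use '
                              'torch.utils.tensorboard.')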
mmcv/utils/__init__.py

@@ -1,5 +1,6 @@
 # Copyright (c) Open-MMLab. All rights reserved.
 from .config import Config, ConfigDict, DictAction
+from .env import TORCH_VERSION
 from .logging import get_logger, print_log
 from .misc import (check_prerequisites, concat_list, is_list_of, is_seq_of,
                    is_str, is_tuple_of, iter_cast, list_cast,
@@ -29,5 +30,6 @@ __all__ = [
     'CUDA_HOME', 'SyncBatchNorm', '_AdaptiveAvgPoolNd', '_AdaptiveMaxPoolNd',
     '_AvgPoolNd', '_BatchNorm', '_ConvNd', '_ConvTransposeMixin',
     '_InstanceNorm', '_MaxPoolNd', 'get_build_config', 'BuildExtension',
-    'CppExtension', 'CUDAExtension', 'DataLoader', 'PoolDataLoader'
+    'CppExtension', 'CUDAExtension', 'DataLoader', 'PoolDataLoader',
+    'TORCH_VERSION'
 ]
mmcv/utils/env.py (new file)

@@ -0,0 +1,4 @@
+# This file holds environment constants shared by other files.
+import torch
+
+TORCH_VERSION = torch.__version__
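One caveat worth flagging: every version check in this commit ('> 1.2', '< 1.0', '< 1.1') compares strings lexicographically, which misorders multi-digit components ('1.10.0' sorts before '1.2'). The sketch below demonstrates the failure mode and a safer tuple-based comparison; the version_tuple helper is illustrative, not part of mmcv.

    # String comparison on versions is lexicographic and can misorder releases.
    TORCH_VERSION = '1.10.0'  # hypothetical value for illustration

    print(TORCH_VERSION > '1.2')  # False: '1' < '2' at the third character

    def version_tuple(v):
        # Parse the leading numeric components of a version string.
        parts = []
        for p in v.split('+')[0].split('.'):
            if not p.isdigit():
                break
            parts.append(int(p))
        return tuple(parts)

    print(version_tuple(TORCH_VERSION) > version_tuple('1.2'))  # True: (1, 10, 0) > (1, 2)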
mmcv/utils/parrots_wrapper.py

@@ -2,9 +2,11 @@ from functools import partial
 
 import torch
 
+from .env import TORCH_VERSION
+
 
 def _get_cuda_home():
-    if torch.__version__ == 'parrots':
+    if TORCH_VERSION == 'parrots':
         from parrots.utils.build_extension import CUDA_HOME
     else:
         from torch.utils.cpp_extension import CUDA_HOME
@@ -12,7 +14,7 @@ def _get_cuda_home():
 
 
 def get_build_config():
-    if torch.__version__ == 'parrots':
+    if TORCH_VERSION == 'parrots':
         from parrots.config import get_build_info
         return get_build_info()
     else:
@@ -20,7 +22,7 @@ def get_build_config():
 
 
 def _get_conv():
-    if torch.__version__ == 'parrots':
+    if TORCH_VERSION == 'parrots':
         from parrots.nn.modules.conv import _ConvNd, _ConvTransposeMixin
     else:
         from torch.nn.modules.conv import _ConvNd, _ConvTransposeMixin
@@ -28,7 +30,7 @@ def _get_conv():
 
 
 def _get_dataloader():
-    if torch.__version__ == 'parrots':
+    if TORCH_VERSION == 'parrots':
         from torch.utils.data import DataLoader, PoolDataLoader
     else:
         from torch.utils.data import DataLoader
@@ -37,7 +39,7 @@ def _get_dataloader():
 
 
 def _get_extension():
-    if torch.__version__ == 'parrots':
+    if TORCH_VERSION == 'parrots':
         from parrots.utils.build_extension import BuildExtension, Extension
         CppExtension = partial(Extension, cuda=False)
         CUDAExtension = partial(Extension, cuda=True)
@@ -48,7 +50,7 @@ def _get_extension():
 
 
 def _get_pool():
-    if torch.__version__ == 'parrots':
+    if TORCH_VERSION == 'parrots':
         from parrots.nn.modules.pool import (_AdaptiveAvgPoolNd,
                                              _AdaptiveMaxPoolNd, _AvgPoolNd,
                                              _MaxPoolNd)
@@ -60,7 +62,7 @@ def _get_pool():
 
 
 def _get_norm():
-    if torch.__version__ == 'parrots':
+    if TORCH_VERSION == 'parrots':
         from parrots.nn.modules.batchnorm import _BatchNorm, _InstanceNorm
         SyncBatchNorm_ = torch.nn.SyncBatchNorm2d
     else:
@@ -81,11 +83,11 @@ _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool()
 class SyncBatchNorm(SyncBatchNorm_):
 
     def _specify_ddp_gpu_num(self, gpu_size):
-        if torch.__version__ != 'parrots':
+        if TORCH_VERSION != 'parrots':
             super()._specify_ddp_gpu_num(gpu_size)
 
     def _check_input_dim(self, input):
-        if torch.__version__ == 'parrots':
+        if TORCH_VERSION == 'parrots':
             if input.dim() < 2:
                 raise ValueError(
                     f'expected at least 2D input (got {input.dim()}D input)')
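The wrapper module above repeats one pattern: each _get_* factory imports the parrots implementation when TORCH_VERSION == 'parrots' and the torch one otherwise, then binds the result to module-level names once at import time. A self-contained sketch of that pattern, with the PoolDataLoader = DataLoader fallback assumed for plain torch:

    # Backend-selection pattern used throughout parrots_wrapper.py (sketch).
    import torch

    TORCH_VERSION = torch.__version__  # reports 'parrots' under the parrots backend

    def _get_dataloader():
        if TORCH_VERSION == 'parrots':
            # The parrots build ships PoolDataLoader alongside DataLoader.
            from torch.utils.data import DataLoader, PoolDataLoader
        else:
            from torch.utils.data import DataLoader
            PoolDataLoader = DataLoader  # assumed fallback on plain torch
        return DataLoader, PoolDataLoader

    # Bound once at import time, so callers use the module-level names directly.
    DataLoader, PoolDataLoader = _get_dataloader()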