From 9fa5de8b9b393340c5a1fe5166725dfcb229dfe0 Mon Sep 17 00:00:00 2001 From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Date: Tue, 10 Aug 2021 15:08:00 +0800 Subject: [PATCH] [Enhancement] Speed up mmcv import (#1249) * [Enhancement] Speed import mmcv * fix missing parse_version * fix circle dependency * rename --- mmcv/utils/__init__.py | 25 +++++++++++++------------ mmcv/utils/env.py | 3 ++- mmcv/utils/parrots_wrapper.py | 20 +++++++++++--------- 3 files changed, 26 insertions(+), 22 deletions(-) diff --git a/mmcv/utils/__init__.py b/mmcv/utils/__init__.py index 6e5ca6c11..baf8109f0 100644 --- a/mmcv/utils/__init__.py +++ b/mmcv/utils/__init__.py @@ -40,10 +40,10 @@ else: from .logging import get_logger, print_log from .parrots_jit import jit, skip_no_elena from .parrots_wrapper import ( - CUDA_HOME, TORCH_VERSION, BuildExtension, CppExtension, CUDAExtension, - DataLoader, PoolDataLoader, SyncBatchNorm, _AdaptiveAvgPoolNd, - _AdaptiveMaxPoolNd, _AvgPoolNd, _BatchNorm, _ConvNd, - _ConvTransposeMixin, _InstanceNorm, _MaxPoolNd, get_build_config) + TORCH_VERSION, BuildExtension, CppExtension, CUDAExtension, DataLoader, + PoolDataLoader, SyncBatchNorm, _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, + _AvgPoolNd, _BatchNorm, _ConvNd, _ConvTransposeMixin, _InstanceNorm, + _MaxPoolNd, get_build_config, is_rocm_pytorch, _get_cuda_home) from .registry import Registry, build_from_cfg from .trace import is_jit_tracing __all__ = [ @@ -54,15 +54,16 @@ else: 'is_filepath', 'fopen', 'check_file_exist', 'mkdir_or_exist', 'symlink', 'scandir', 'ProgressBar', 'track_progress', 'track_iter_progress', 'track_parallel_progress', 'Registry', - 'build_from_cfg', 'Timer', 'TimerError', 'check_time', 'CUDA_HOME', - 'SyncBatchNorm', '_AdaptiveAvgPoolNd', '_AdaptiveMaxPoolNd', - '_AvgPoolNd', '_BatchNorm', '_ConvNd', '_ConvTransposeMixin', - '_InstanceNorm', '_MaxPoolNd', 'get_build_config', 'BuildExtension', - 'CppExtension', 'CUDAExtension', 'DataLoader', 'PoolDataLoader', - 'TORCH_VERSION', 'deprecated_api_warning', 'digit_version', - 'get_git_hash', 'import_modules_from_strings', 'jit', 'skip_no_elena', + 'build_from_cfg', 'Timer', 'TimerError', 'check_time', 'SyncBatchNorm', + '_AdaptiveAvgPoolNd', '_AdaptiveMaxPoolNd', '_AvgPoolNd', '_BatchNorm', + '_ConvNd', '_ConvTransposeMixin', '_InstanceNorm', '_MaxPoolNd', + 'get_build_config', 'BuildExtension', 'CppExtension', 'CUDAExtension', + 'DataLoader', 'PoolDataLoader', 'TORCH_VERSION', + 'deprecated_api_warning', 'digit_version', 'get_git_hash', + 'import_modules_from_strings', 'jit', 'skip_no_elena', 'assert_dict_contains_subset', 'assert_attrs_equal', 'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer', 'assert_params_all_zeros', 'check_python_script', - 'is_method_overridden', 'is_jit_tracing' + 'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch', + '_get_cuda_home' ] diff --git a/mmcv/utils/env.py b/mmcv/utils/env.py index e3ef5e873..e46a1094f 100644 --- a/mmcv/utils/env.py +++ b/mmcv/utils/env.py @@ -49,7 +49,8 @@ def collect_env(): for name, device_ids in devices.items(): env_info['GPU ' + ','.join(device_ids)] = name - from mmcv.utils.parrots_wrapper import CUDA_HOME + from mmcv.utils.parrots_wrapper import _get_cuda_home + CUDA_HOME = _get_cuda_home() env_info['CUDA_HOME'] = CUDA_HOME if CUDA_HOME is not None and osp.isdir(CUDA_HOME): diff --git a/mmcv/utils/parrots_wrapper.py b/mmcv/utils/parrots_wrapper.py index 6a23f530a..93c97640d 100644 --- a/mmcv/utils/parrots_wrapper.py +++ b/mmcv/utils/parrots_wrapper.py @@ -3,23 +3,26 @@ from functools import partial import torch -from mmcv.utils import digit_version - TORCH_VERSION = torch.__version__ -is_rocm_pytorch = False -if (TORCH_VERSION != 'parrots' - and digit_version(TORCH_VERSION) >= digit_version('1.5')): - from torch.utils.cpp_extension import ROCM_HOME - is_rocm_pytorch = True if ((torch.version.hip is not None) and + +def is_rocm_pytorch() -> bool: + is_rocm = False + if TORCH_VERSION != 'parrots': + try: + from torch.utils.cpp_extension import ROCM_HOME + is_rocm = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False + except ImportError: + pass + return is_rocm def _get_cuda_home(): if TORCH_VERSION == 'parrots': from parrots.utils.build_extension import CUDA_HOME else: - if is_rocm_pytorch: + if is_rocm_pytorch(): from torch.utils.cpp_extension import ROCM_HOME CUDA_HOME = ROCM_HOME else: @@ -86,7 +89,6 @@ def _get_norm(): return _BatchNorm, _InstanceNorm, SyncBatchNorm_ -CUDA_HOME = _get_cuda_home() _ConvNd, _ConvTransposeMixin = _get_conv() DataLoader, PoolDataLoader = _get_dataloader() BuildExtension, CppExtension, CUDAExtension = _get_extension()