[Enhancement] Speed up mmcv import (#1249)

* [Enhancement] Speed import mmcv

* fix missing parse_version

* fix circle dependency

* rename
pull/1254/head
Zaida Zhou 2021-08-10 15:08:00 +08:00 committed by GitHub
parent 94a677de3f
commit 9fa5de8b9b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 26 additions and 22 deletions

View File

@ -40,10 +40,10 @@ else:
from .logging import get_logger, print_log
from .parrots_jit import jit, skip_no_elena
from .parrots_wrapper import (
CUDA_HOME, TORCH_VERSION, BuildExtension, CppExtension, CUDAExtension,
DataLoader, PoolDataLoader, SyncBatchNorm, _AdaptiveAvgPoolNd,
_AdaptiveMaxPoolNd, _AvgPoolNd, _BatchNorm, _ConvNd,
_ConvTransposeMixin, _InstanceNorm, _MaxPoolNd, get_build_config)
TORCH_VERSION, BuildExtension, CppExtension, CUDAExtension, DataLoader,
PoolDataLoader, SyncBatchNorm, _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd,
_AvgPoolNd, _BatchNorm, _ConvNd, _ConvTransposeMixin, _InstanceNorm,
_MaxPoolNd, get_build_config, is_rocm_pytorch, _get_cuda_home)
from .registry import Registry, build_from_cfg
from .trace import is_jit_tracing
__all__ = [
@ -54,15 +54,16 @@ else:
'is_filepath', 'fopen', 'check_file_exist', 'mkdir_or_exist',
'symlink', 'scandir', 'ProgressBar', 'track_progress',
'track_iter_progress', 'track_parallel_progress', 'Registry',
'build_from_cfg', 'Timer', 'TimerError', 'check_time', 'CUDA_HOME',
'SyncBatchNorm', '_AdaptiveAvgPoolNd', '_AdaptiveMaxPoolNd',
'_AvgPoolNd', '_BatchNorm', '_ConvNd', '_ConvTransposeMixin',
'_InstanceNorm', '_MaxPoolNd', 'get_build_config', 'BuildExtension',
'CppExtension', 'CUDAExtension', 'DataLoader', 'PoolDataLoader',
'TORCH_VERSION', 'deprecated_api_warning', 'digit_version',
'get_git_hash', 'import_modules_from_strings', 'jit', 'skip_no_elena',
'build_from_cfg', 'Timer', 'TimerError', 'check_time', 'SyncBatchNorm',
'_AdaptiveAvgPoolNd', '_AdaptiveMaxPoolNd', '_AvgPoolNd', '_BatchNorm',
'_ConvNd', '_ConvTransposeMixin', '_InstanceNorm', '_MaxPoolNd',
'get_build_config', 'BuildExtension', 'CppExtension', 'CUDAExtension',
'DataLoader', 'PoolDataLoader', 'TORCH_VERSION',
'deprecated_api_warning', 'digit_version', 'get_git_hash',
'import_modules_from_strings', 'jit', 'skip_no_elena',
'assert_dict_contains_subset', 'assert_attrs_equal',
'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer',
'assert_params_all_zeros', 'check_python_script',
'is_method_overridden', 'is_jit_tracing'
'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch',
'_get_cuda_home'
]

View File

@ -49,7 +49,8 @@ def collect_env():
for name, device_ids in devices.items():
env_info['GPU ' + ','.join(device_ids)] = name
from mmcv.utils.parrots_wrapper import CUDA_HOME
from mmcv.utils.parrots_wrapper import _get_cuda_home
CUDA_HOME = _get_cuda_home()
env_info['CUDA_HOME'] = CUDA_HOME
if CUDA_HOME is not None and osp.isdir(CUDA_HOME):

View File

@ -3,23 +3,26 @@ from functools import partial
import torch
from mmcv.utils import digit_version
TORCH_VERSION = torch.__version__
is_rocm_pytorch = False
if (TORCH_VERSION != 'parrots'
and digit_version(TORCH_VERSION) >= digit_version('1.5')):
from torch.utils.cpp_extension import ROCM_HOME
is_rocm_pytorch = True if ((torch.version.hip is not None) and
def is_rocm_pytorch() -> bool:
is_rocm = False
if TORCH_VERSION != 'parrots':
try:
from torch.utils.cpp_extension import ROCM_HOME
is_rocm = True if ((torch.version.hip is not None) and
(ROCM_HOME is not None)) else False
except ImportError:
pass
return is_rocm
def _get_cuda_home():
if TORCH_VERSION == 'parrots':
from parrots.utils.build_extension import CUDA_HOME
else:
if is_rocm_pytorch:
if is_rocm_pytorch():
from torch.utils.cpp_extension import ROCM_HOME
CUDA_HOME = ROCM_HOME
else:
@ -86,7 +89,6 @@ def _get_norm():
return _BatchNorm, _InstanceNorm, SyncBatchNorm_
CUDA_HOME = _get_cuda_home()
_ConvNd, _ConvTransposeMixin = _get_conv()
DataLoader, PoolDataLoader = _get_dataloader()
BuildExtension, CppExtension, CUDAExtension = _get_extension()