mirror of https://github.com/open-mmlab/mmcv.git
[Enhancement] Speed up mmcv import (#1249)
* [Enhancement] Speed import mmcv * fix missing parse_version * fix circle dependency * renamepull/1254/head
parent
94a677de3f
commit
9fa5de8b9b
|
@ -40,10 +40,10 @@ else:
|
|||
from .logging import get_logger, print_log
|
||||
from .parrots_jit import jit, skip_no_elena
|
||||
from .parrots_wrapper import (
|
||||
CUDA_HOME, TORCH_VERSION, BuildExtension, CppExtension, CUDAExtension,
|
||||
DataLoader, PoolDataLoader, SyncBatchNorm, _AdaptiveAvgPoolNd,
|
||||
_AdaptiveMaxPoolNd, _AvgPoolNd, _BatchNorm, _ConvNd,
|
||||
_ConvTransposeMixin, _InstanceNorm, _MaxPoolNd, get_build_config)
|
||||
TORCH_VERSION, BuildExtension, CppExtension, CUDAExtension, DataLoader,
|
||||
PoolDataLoader, SyncBatchNorm, _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd,
|
||||
_AvgPoolNd, _BatchNorm, _ConvNd, _ConvTransposeMixin, _InstanceNorm,
|
||||
_MaxPoolNd, get_build_config, is_rocm_pytorch, _get_cuda_home)
|
||||
from .registry import Registry, build_from_cfg
|
||||
from .trace import is_jit_tracing
|
||||
__all__ = [
|
||||
|
@ -54,15 +54,16 @@ else:
|
|||
'is_filepath', 'fopen', 'check_file_exist', 'mkdir_or_exist',
|
||||
'symlink', 'scandir', 'ProgressBar', 'track_progress',
|
||||
'track_iter_progress', 'track_parallel_progress', 'Registry',
|
||||
'build_from_cfg', 'Timer', 'TimerError', 'check_time', 'CUDA_HOME',
|
||||
'SyncBatchNorm', '_AdaptiveAvgPoolNd', '_AdaptiveMaxPoolNd',
|
||||
'_AvgPoolNd', '_BatchNorm', '_ConvNd', '_ConvTransposeMixin',
|
||||
'_InstanceNorm', '_MaxPoolNd', 'get_build_config', 'BuildExtension',
|
||||
'CppExtension', 'CUDAExtension', 'DataLoader', 'PoolDataLoader',
|
||||
'TORCH_VERSION', 'deprecated_api_warning', 'digit_version',
|
||||
'get_git_hash', 'import_modules_from_strings', 'jit', 'skip_no_elena',
|
||||
'build_from_cfg', 'Timer', 'TimerError', 'check_time', 'SyncBatchNorm',
|
||||
'_AdaptiveAvgPoolNd', '_AdaptiveMaxPoolNd', '_AvgPoolNd', '_BatchNorm',
|
||||
'_ConvNd', '_ConvTransposeMixin', '_InstanceNorm', '_MaxPoolNd',
|
||||
'get_build_config', 'BuildExtension', 'CppExtension', 'CUDAExtension',
|
||||
'DataLoader', 'PoolDataLoader', 'TORCH_VERSION',
|
||||
'deprecated_api_warning', 'digit_version', 'get_git_hash',
|
||||
'import_modules_from_strings', 'jit', 'skip_no_elena',
|
||||
'assert_dict_contains_subset', 'assert_attrs_equal',
|
||||
'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer',
|
||||
'assert_params_all_zeros', 'check_python_script',
|
||||
'is_method_overridden', 'is_jit_tracing'
|
||||
'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch',
|
||||
'_get_cuda_home'
|
||||
]
|
||||
|
|
|
@ -49,7 +49,8 @@ def collect_env():
|
|||
for name, device_ids in devices.items():
|
||||
env_info['GPU ' + ','.join(device_ids)] = name
|
||||
|
||||
from mmcv.utils.parrots_wrapper import CUDA_HOME
|
||||
from mmcv.utils.parrots_wrapper import _get_cuda_home
|
||||
CUDA_HOME = _get_cuda_home()
|
||||
env_info['CUDA_HOME'] = CUDA_HOME
|
||||
|
||||
if CUDA_HOME is not None and osp.isdir(CUDA_HOME):
|
||||
|
|
|
@ -3,23 +3,26 @@ from functools import partial
|
|||
|
||||
import torch
|
||||
|
||||
from mmcv.utils import digit_version
|
||||
|
||||
TORCH_VERSION = torch.__version__
|
||||
|
||||
is_rocm_pytorch = False
|
||||
if (TORCH_VERSION != 'parrots'
|
||||
and digit_version(TORCH_VERSION) >= digit_version('1.5')):
|
||||
from torch.utils.cpp_extension import ROCM_HOME
|
||||
is_rocm_pytorch = True if ((torch.version.hip is not None) and
|
||||
|
||||
def is_rocm_pytorch() -> bool:
|
||||
is_rocm = False
|
||||
if TORCH_VERSION != 'parrots':
|
||||
try:
|
||||
from torch.utils.cpp_extension import ROCM_HOME
|
||||
is_rocm = True if ((torch.version.hip is not None) and
|
||||
(ROCM_HOME is not None)) else False
|
||||
except ImportError:
|
||||
pass
|
||||
return is_rocm
|
||||
|
||||
|
||||
def _get_cuda_home():
|
||||
if TORCH_VERSION == 'parrots':
|
||||
from parrots.utils.build_extension import CUDA_HOME
|
||||
else:
|
||||
if is_rocm_pytorch:
|
||||
if is_rocm_pytorch():
|
||||
from torch.utils.cpp_extension import ROCM_HOME
|
||||
CUDA_HOME = ROCM_HOME
|
||||
else:
|
||||
|
@ -86,7 +89,6 @@ def _get_norm():
|
|||
return _BatchNorm, _InstanceNorm, SyncBatchNorm_
|
||||
|
||||
|
||||
CUDA_HOME = _get_cuda_home()
|
||||
_ConvNd, _ConvTransposeMixin = _get_conv()
|
||||
DataLoader, PoolDataLoader = _get_dataloader()
|
||||
BuildExtension, CppExtension, CUDAExtension = _get_extension()
|
||||
|
|
Loading…
Reference in New Issue