mirror of https://github.com/open-mmlab/mim.git
[Feature] Support downloading OpenMMLab projects on Ascend NPU (#228)
parent bf69e00b6d
commit 8021d1b0eb
@@ -17,7 +17,8 @@ from mim.utils import (
     DEFAULT_MMCV_BASE_URL,
     PKG2PROJECT,
     echo_warning,
-    get_torch_cuda_version,
+    exit_with_error,
+    get_torch_device_version,
 )
@@ -160,16 +161,24 @@ def get_mmcv_full_find_link(mmcv_base_url: str) -> str:
     Returns:
         str: The mmcv find links corresponding to the current torch version and
-            cuda version.
+            CUDA/NPU version.
     """
-    torch_v, cuda_v = get_torch_cuda_version()
+    torch_v, device, device_v = get_torch_device_version()
     major, minor, *_ = torch_v.split('.')
     torch_v = '.'.join([major, minor, '0'])
 
-    if cuda_v.isdigit():
-        cuda_v = f'cu{cuda_v}'
+    if device == 'cuda' and device_v.isdigit():
+        device_link = f'cu{device_v}'
+    elif device == 'ascend':
+        if not device_v.isdigit():
+            exit_with_error('Unable to install OpenMMLab projects via mim '
+                            'on the current Ascend NPU, '
+                            'please compile from source code to install.')
+        device_link = f'ascend{device_v}'
+    else:
+        device_link = 'cpu'
 
-    find_link = f'{mmcv_base_url}/mmcv/dist/{cuda_v}/torch{torch_v}/index.html'  # noqa: E501
+    find_link = f'{mmcv_base_url}/mmcv/dist/{device_link}/torch{torch_v}/index.html'  # noqa: E501
     return find_link
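For reference, the rewritten logic yields one find-link shape per device family. A minimal sketch of the three URL patterns; the base value here is an assumed stand-in for DEFAULT_MMCV_BASE_URL, not read from mim itself:

    # Illustration of the three find-link shapes from the new device_link logic.
    base = 'https://download.openmmlab.com'  # assumed stand-in for DEFAULT_MMCV_BASE_URL
    examples = [
        ('cu102', '1.8.0'),       # CUDA host: device='cuda', device_v='102'
        ('ascend602', '1.11.0'),  # Ascend host: device='ascend', device_v='602'
        ('cpu', '1.8.0'),         # everything else falls back to CPU wheels
    ]
    for device_link, torch_v in examples:
        print(f'{base}/mmcv/dist/{device_link}/torch{torch_v}/index.html')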
@@ -32,7 +32,7 @@ from .utils import (
     get_package_info_from_pypi,
     get_package_version,
     get_release_version,
-    get_torch_cuda_version,
+    get_torch_device_version,
     highlighted_error,
     is_installed,
     is_version_equal,
@@ -59,7 +59,7 @@ __all__ = [
     'get_installed_version',
     'get_installed_path',
     'get_latest_version',
-    'get_torch_cuda_version',
+    'get_torch_device_version',
     'is_installed',
     'parse_url',
     'PKG2PROJECT',
@@ -23,6 +23,15 @@ from requests.models import Response
 from .default import PKG2PROJECT
 from .progress_bars import rich_progress_bar
 
+try:
+    import torch
+    import torch_npu
+
+    IS_NPU_AVAILABLE = hasattr(
+        torch, 'npu') and torch.npu.is_available()  # type: ignore
+except Exception:
+    IS_NPU_AVAILABLE = False
+
 
 def parse_url(url: str) -> Tuple[str, str]:
     """Parse username and repo from url.
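The module-level probe deliberately swallows every failure mode (missing torch, missing torch_npu, or a runtime error from the device check), so plain CPU/CUDA installs are untouched. The same guarded-probe pattern in isolation, using a hypothetical module name:

    # 'some_npu_runtime' is hypothetical; on hosts without it, the
    # ImportError is caught and the flag simply stays False.
    try:
        import some_npu_runtime
        HAS_NPU = some_npu_runtime.is_available()
    except Exception:
        HAS_NPU = False

    print(HAS_NPU)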
@@ -327,12 +336,37 @@ def get_installed_path(package: str) -> str:
     return osp.join(pkg.location, package2module(package))
 
 
-def get_torch_cuda_version() -> Tuple[str, str]:
-    """Get PyTorch version and CUDA version if it is available.
+def is_npu_available() -> bool:
+    """Returns True if Ascend PyTorch and npu devices exist."""
+    return IS_NPU_AVAILABLE
+
+
+def get_npu_version() -> str:
+    """Returns the version of NPU when npu devices exist."""
+    if not is_npu_available():
+        return ''
+    ascend_home_path = os.environ.get('ASCEND_HOME_PATH', None)
+    if not ascend_home_path or not os.path.exists(ascend_home_path):
+        raise RuntimeError(
+            highlighted_error(
+                f'ASCEND_HOME_PATH: {ascend_home_path} does not exist when '
+                'installing OpenMMLab projects on Ascend NPU. '
+                "Please run 'source set_env.sh' in the CANN installation path."
+            ))
+    npu_version = torch_npu.get_cann_version(ascend_home_path)
+    return npu_version
+
+
+def get_torch_device_version() -> Tuple[str, str, str]:
+    """Get PyTorch version and CUDA/NPU version if it is available.
 
     Example:
-        >>> get_torch_cuda_version()
-        '1.8.0', '102'
+        >>> get_torch_device_version()
+        '1.8.0', 'cpu', ''
+        >>> get_torch_device_version()
+        '1.8.0', 'cuda', '102'
+        >>> get_torch_device_version()
+        '1.11.0', 'ascend', '602'
     """
     try:
         import torch
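get_npu_version assumes the CANN environment has already been sourced. A small sketch of the precondition it enforces; the example path is a typical CANN install location and is assumed, not taken from the diff:

    import os

    # After 'source set_env.sh', ASCEND_HOME_PATH typically points at
    # something like /usr/local/Ascend/ascend-toolkit/latest (assumed).
    ascend_home_path = os.environ.get('ASCEND_HOME_PATH', None)
    if not ascend_home_path or not os.path.exists(ascend_home_path):
        print("CANN not configured; run 'source set_env.sh' first.")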
@@ -344,11 +378,17 @@ def get_torch_cuda_version() -> Tuple[str, str]:
         torch_v = torch_v.split('+')[0]
 
     if torch.version.cuda is not None:
+        device = 'cuda'
         # torch.version.cuda like 10.2 -> 102
-        cuda_v = ''.join(torch.version.cuda.split('.'))
+        device_v = ''.join(torch.version.cuda.split('.'))
+    elif is_npu_available():
+        device = 'ascend'
+        device_v = get_npu_version()
+        device_v = ''.join(device_v.split('.'))
     else:
-        cuda_v = 'cpu'
-    return torch_v, cuda_v
+        device = 'cpu'
+        device_v = ''
+    return torch_v, device, device_v
 
 
 def cast2lowercase(input: Union[list, tuple, str]) -> Any:
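Both branches normalize the dotted version string the same way, by deleting the dots:

    # ''.join(v.split('.')) squashes a dotted version to bare digits,
    # e.g. CUDA '10.2' -> '102' and CANN '6.0.2' -> '602'.
    for v in ('10.2', '6.0.2'):
        print(''.join(v.split('.')))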
@@ -20,4 +20,4 @@ default_section = THIRDPARTY
 include_trailing_comma = true
 
 [codespell]
-ignore-words-list = te
+ignore-words-list = te, cann
@@ -4,6 +4,7 @@ from click.testing import CliRunner
 from mim.commands.install import cli as install
 from mim.commands.uninstall import cli as uninstall
 from mim.utils import get_github_url, parse_home_page
+from mim.utils.utils import get_torch_device_version, is_npu_available
 
 
 def setup_module():
@@ -39,6 +40,14 @@ def test_get_github_url():
         'mmcls') == 'https://github.com/open-mmlab/mmclassification.git'
 
 
+def test_get_torch_device_version():
+    torch_v, device, device_v = get_torch_device_version()
+    assert torch_v.replace('.', '').isdigit()
+    if is_npu_available():
+        assert device == 'ascend'
+        assert device_v.replace('.', '').isdigit()
+
+
 def teardown_module():
     runner = CliRunner()
     result = runner.invoke(uninstall, ['mmcv-full', '--yes'])