[Feature] Support downloading OpenMMLab projects on Ascend NPU ()

pull/232/head
Yinlei Sun 2023-10-23 10:58:34 +08:00 committed by GitHub
parent bf69e00b6d
commit 8021d1b0eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 74 additions and 16 deletions

View File

@ -17,7 +17,8 @@ from mim.utils import (
DEFAULT_MMCV_BASE_URL,
PKG2PROJECT,
echo_warning,
get_torch_cuda_version,
exit_with_error,
get_torch_device_version,
)
@ -160,16 +161,24 @@ def get_mmcv_full_find_link(mmcv_base_url: str) -> str:
Returns:
str: The mmcv find links corresponding to the current torch version and
cuda version.
CUDA/NPU version.
"""
torch_v, cuda_v = get_torch_cuda_version()
torch_v, device, device_v = get_torch_device_version()
major, minor, *_ = torch_v.split('.')
torch_v = '.'.join([major, minor, '0'])
if cuda_v.isdigit():
cuda_v = f'cu{cuda_v}'
if device == 'cuda' and device_v.isdigit():
device_link = f'cu{device_v}'
elif device == 'ascend':
if not device_v.isdigit():
exit_with_error('Unable to install OpenMMLab projects via mim '
'on the current Ascend NPU, '
'please compile from source code to install.')
device_link = f'ascend{device_v}'
else:
device_link = 'cpu'
find_link = f'{mmcv_base_url}/mmcv/dist/{cuda_v}/torch{torch_v}/index.html' # noqa: E501
find_link = f'{mmcv_base_url}/mmcv/dist/{device_link}/torch{torch_v}/index.html' # noqa: E501
return find_link

View File

@ -32,7 +32,7 @@ from .utils import (
get_package_info_from_pypi,
get_package_version,
get_release_version,
get_torch_cuda_version,
get_torch_device_version,
highlighted_error,
is_installed,
is_version_equal,
@ -59,7 +59,7 @@ __all__ = [
'get_installed_version',
'get_installed_path',
'get_latest_version',
'get_torch_cuda_version',
'get_torch_device_version',
'is_installed',
'parse_url',
'PKG2PROJECT',

View File

@ -23,6 +23,15 @@ from requests.models import Response
from .default import PKG2PROJECT
from .progress_bars import rich_progress_bar
try:
import torch
import torch_npu
IS_NPU_AVAILABLE = hasattr(
torch, 'npu') and torch.npu.is_available() # type: ignore
except Exception:
IS_NPU_AVAILABLE = False
def parse_url(url: str) -> Tuple[str, str]:
"""Parse username and repo from url.
@ -327,12 +336,37 @@ def get_installed_path(package: str) -> str:
return osp.join(pkg.location, package2module(package))
def get_torch_cuda_version() -> Tuple[str, str]:
"""Get PyTorch version and CUDA version if it is available.
def is_npu_available() -> bool:
"""Returns True if Ascend PyTorch and npu devices exist."""
return IS_NPU_AVAILABLE
def get_npu_version() -> str:
"""Returns the version of NPU when npu devices exist."""
if not is_npu_available():
return ''
ascend_home_path = os.environ.get('ASCEND_HOME_PATH', None)
if not ascend_home_path or not os.path.exists(ascend_home_path):
raise RuntimeError(
highlighted_error(
f'ASCEND_HOME_PATH:{ascend_home_path} does not exists when '
'installing OpenMMLab projects on Ascend NPU.'
"Please run 'source set_env.sh' in the CANN installation path."
))
npu_version = torch_npu.get_cann_version(ascend_home_path)
return npu_version
def get_torch_device_version() -> Tuple[str, str, str]:
"""Get PyTorch version and CUDA/NPU version if it is available.
Example:
>>> get_torch_cuda_version()
'1.8.0', '102'
>>> get_torch_device_version()
'1.8.0', 'cpu', ''
>>> get_torch_device_version()
'1.8.0', 'cuda', '102'
>>> get_torch_device_version()
'1.11.0', 'ascend', '602'
"""
try:
import torch
@ -344,11 +378,17 @@ def get_torch_cuda_version() -> Tuple[str, str]:
torch_v = torch_v.split('+')[0]
if torch.version.cuda is not None:
device = 'cuda'
# torch.version.cuda like 10.2 -> 102
cuda_v = ''.join(torch.version.cuda.split('.'))
device_v = ''.join(torch.version.cuda.split('.'))
elif is_npu_available():
device = 'ascend'
device_v = get_npu_version()
device_v = ''.join(device_v.split('.'))
else:
cuda_v = 'cpu'
return torch_v, cuda_v
device = 'cpu'
device_v = ''
return torch_v, device, device_v
def cast2lowercase(input: Union[list, tuple, str]) -> Any:

View File

@ -20,4 +20,4 @@ default_section = THIRDPARTY
include_trailing_comma = true
[codespell]
ignore-words-list = te
ignore-words-list = te, cann

View File

@ -4,6 +4,7 @@ from click.testing import CliRunner
from mim.commands.install import cli as install
from mim.commands.uninstall import cli as uninstall
from mim.utils import get_github_url, parse_home_page
from mim.utils.utils import get_torch_device_version, is_npu_available
def setup_module():
@ -39,6 +40,14 @@ def test_get_github_url():
'mmcls') == 'https://github.com/open-mmlab/mmclassification.git'
def test_get_torch_device_version():
torch_v, device, device_v = get_torch_device_version()
assert torch_v.replace('.', '').isdigit()
if is_npu_available():
assert device == 'ascend'
assert device_v.replace('.', '').isdigit()
def teardown_module():
runner = CliRunner()
result = runner.invoke(uninstall, ['mmcv-full', '--yes'])