[Feature] Support downloading OpenMMLab projects on Ascend NPU (#228)

pull/232/head
Yinlei Sun 2023-10-23 10:58:34 +08:00 committed by GitHub
parent bf69e00b6d
commit 8021d1b0eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 74 additions and 16 deletions

View File

@ -17,7 +17,8 @@ from mim.utils import (
DEFAULT_MMCV_BASE_URL, DEFAULT_MMCV_BASE_URL,
PKG2PROJECT, PKG2PROJECT,
echo_warning, echo_warning,
get_torch_cuda_version, exit_with_error,
get_torch_device_version,
) )
@ -160,16 +161,24 @@ def get_mmcv_full_find_link(mmcv_base_url: str) -> str:
Returns: Returns:
str: The mmcv find links corresponding to the current torch version and str: The mmcv find links corresponding to the current torch version and
cuda version. CUDA/NPU version.
""" """
torch_v, cuda_v = get_torch_cuda_version() torch_v, device, device_v = get_torch_device_version()
major, minor, *_ = torch_v.split('.') major, minor, *_ = torch_v.split('.')
torch_v = '.'.join([major, minor, '0']) torch_v = '.'.join([major, minor, '0'])
if cuda_v.isdigit(): if device == 'cuda' and device_v.isdigit():
cuda_v = f'cu{cuda_v}' device_link = f'cu{device_v}'
elif device == 'ascend':
if not device_v.isdigit():
exit_with_error('Unable to install OpenMMLab projects via mim '
'on the current Ascend NPU, '
'please compile from source code to install.')
device_link = f'ascend{device_v}'
else:
device_link = 'cpu'
find_link = f'{mmcv_base_url}/mmcv/dist/{cuda_v}/torch{torch_v}/index.html' # noqa: E501 find_link = f'{mmcv_base_url}/mmcv/dist/{device_link}/torch{torch_v}/index.html' # noqa: E501
return find_link return find_link

View File

@ -32,7 +32,7 @@ from .utils import (
get_package_info_from_pypi, get_package_info_from_pypi,
get_package_version, get_package_version,
get_release_version, get_release_version,
get_torch_cuda_version, get_torch_device_version,
highlighted_error, highlighted_error,
is_installed, is_installed,
is_version_equal, is_version_equal,
@ -59,7 +59,7 @@ __all__ = [
'get_installed_version', 'get_installed_version',
'get_installed_path', 'get_installed_path',
'get_latest_version', 'get_latest_version',
'get_torch_cuda_version', 'get_torch_device_version',
'is_installed', 'is_installed',
'parse_url', 'parse_url',
'PKG2PROJECT', 'PKG2PROJECT',

View File

@ -23,6 +23,15 @@ from requests.models import Response
from .default import PKG2PROJECT from .default import PKG2PROJECT
from .progress_bars import rich_progress_bar from .progress_bars import rich_progress_bar
try:
import torch
import torch_npu
IS_NPU_AVAILABLE = hasattr(
torch, 'npu') and torch.npu.is_available() # type: ignore
except Exception:
IS_NPU_AVAILABLE = False
def parse_url(url: str) -> Tuple[str, str]: def parse_url(url: str) -> Tuple[str, str]:
"""Parse username and repo from url. """Parse username and repo from url.
@ -327,12 +336,37 @@ def get_installed_path(package: str) -> str:
return osp.join(pkg.location, package2module(package)) return osp.join(pkg.location, package2module(package))
def get_torch_cuda_version() -> Tuple[str, str]: def is_npu_available() -> bool:
"""Get PyTorch version and CUDA version if it is available. """Returns True if Ascend PyTorch and npu devices exist."""
return IS_NPU_AVAILABLE
def get_npu_version() -> str:
"""Returns the version of NPU when npu devices exist."""
if not is_npu_available():
return ''
ascend_home_path = os.environ.get('ASCEND_HOME_PATH', None)
if not ascend_home_path or not os.path.exists(ascend_home_path):
raise RuntimeError(
highlighted_error(
f'ASCEND_HOME_PATH:{ascend_home_path} does not exists when '
'installing OpenMMLab projects on Ascend NPU.'
"Please run 'source set_env.sh' in the CANN installation path."
))
npu_version = torch_npu.get_cann_version(ascend_home_path)
return npu_version
def get_torch_device_version() -> Tuple[str, str, str]:
"""Get PyTorch version and CUDA/NPU version if it is available.
Example: Example:
>>> get_torch_cuda_version() >>> get_torch_device_version()
'1.8.0', '102' '1.8.0', 'cpu', ''
>>> get_torch_device_version()
'1.8.0', 'cuda', '102'
>>> get_torch_device_version()
'1.11.0', 'ascend', '602'
""" """
try: try:
import torch import torch
@ -344,11 +378,17 @@ def get_torch_cuda_version() -> Tuple[str, str]:
torch_v = torch_v.split('+')[0] torch_v = torch_v.split('+')[0]
if torch.version.cuda is not None: if torch.version.cuda is not None:
device = 'cuda'
# torch.version.cuda like 10.2 -> 102 # torch.version.cuda like 10.2 -> 102
cuda_v = ''.join(torch.version.cuda.split('.')) device_v = ''.join(torch.version.cuda.split('.'))
elif is_npu_available():
device = 'ascend'
device_v = get_npu_version()
device_v = ''.join(device_v.split('.'))
else: else:
cuda_v = 'cpu' device = 'cpu'
return torch_v, cuda_v device_v = ''
return torch_v, device, device_v
def cast2lowercase(input: Union[list, tuple, str]) -> Any: def cast2lowercase(input: Union[list, tuple, str]) -> Any:

View File

@ -20,4 +20,4 @@ default_section = THIRDPARTY
include_trailing_comma = true include_trailing_comma = true
[codespell] [codespell]
ignore-words-list = te ignore-words-list = te, cann

View File

@ -4,6 +4,7 @@ from click.testing import CliRunner
from mim.commands.install import cli as install from mim.commands.install import cli as install
from mim.commands.uninstall import cli as uninstall from mim.commands.uninstall import cli as uninstall
from mim.utils import get_github_url, parse_home_page from mim.utils import get_github_url, parse_home_page
from mim.utils.utils import get_torch_device_version, is_npu_available
def setup_module(): def setup_module():
@ -39,6 +40,14 @@ def test_get_github_url():
'mmcls') == 'https://github.com/open-mmlab/mmclassification.git' 'mmcls') == 'https://github.com/open-mmlab/mmclassification.git'
def test_get_torch_device_version():
torch_v, device, device_v = get_torch_device_version()
assert torch_v.replace('.', '').isdigit()
if is_npu_available():
assert device == 'ascend'
assert device_v.replace('.', '').isdigit()
def teardown_module(): def teardown_module():
runner = CliRunner() runner = CliRunner()
result = runner.invoke(uninstall, ['mmcv-full', '--yes']) result = runner.invoke(uninstall, ['mmcv-full', '--yes'])