mirror of https://github.com/open-mmlab/mim.git
[Feature] Support downloading OpenMMLab projects on Ascend NPU (#228)
parent
bf69e00b6d
commit
8021d1b0eb
|
@ -17,7 +17,8 @@ from mim.utils import (
|
||||||
DEFAULT_MMCV_BASE_URL,
|
DEFAULT_MMCV_BASE_URL,
|
||||||
PKG2PROJECT,
|
PKG2PROJECT,
|
||||||
echo_warning,
|
echo_warning,
|
||||||
get_torch_cuda_version,
|
exit_with_error,
|
||||||
|
get_torch_device_version,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -160,16 +161,24 @@ def get_mmcv_full_find_link(mmcv_base_url: str) -> str:
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: The mmcv find links corresponding to the current torch version and
|
str: The mmcv find links corresponding to the current torch version and
|
||||||
cuda version.
|
CUDA/NPU version.
|
||||||
"""
|
"""
|
||||||
torch_v, cuda_v = get_torch_cuda_version()
|
torch_v, device, device_v = get_torch_device_version()
|
||||||
major, minor, *_ = torch_v.split('.')
|
major, minor, *_ = torch_v.split('.')
|
||||||
torch_v = '.'.join([major, minor, '0'])
|
torch_v = '.'.join([major, minor, '0'])
|
||||||
|
|
||||||
if cuda_v.isdigit():
|
if device == 'cuda' and device_v.isdigit():
|
||||||
cuda_v = f'cu{cuda_v}'
|
device_link = f'cu{device_v}'
|
||||||
|
elif device == 'ascend':
|
||||||
|
if not device_v.isdigit():
|
||||||
|
exit_with_error('Unable to install OpenMMLab projects via mim '
|
||||||
|
'on the current Ascend NPU, '
|
||||||
|
'please compile from source code to install.')
|
||||||
|
device_link = f'ascend{device_v}'
|
||||||
|
else:
|
||||||
|
device_link = 'cpu'
|
||||||
|
|
||||||
find_link = f'{mmcv_base_url}/mmcv/dist/{cuda_v}/torch{torch_v}/index.html' # noqa: E501
|
find_link = f'{mmcv_base_url}/mmcv/dist/{device_link}/torch{torch_v}/index.html' # noqa: E501
|
||||||
return find_link
|
return find_link
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -32,7 +32,7 @@ from .utils import (
|
||||||
get_package_info_from_pypi,
|
get_package_info_from_pypi,
|
||||||
get_package_version,
|
get_package_version,
|
||||||
get_release_version,
|
get_release_version,
|
||||||
get_torch_cuda_version,
|
get_torch_device_version,
|
||||||
highlighted_error,
|
highlighted_error,
|
||||||
is_installed,
|
is_installed,
|
||||||
is_version_equal,
|
is_version_equal,
|
||||||
|
@ -59,7 +59,7 @@ __all__ = [
|
||||||
'get_installed_version',
|
'get_installed_version',
|
||||||
'get_installed_path',
|
'get_installed_path',
|
||||||
'get_latest_version',
|
'get_latest_version',
|
||||||
'get_torch_cuda_version',
|
'get_torch_device_version',
|
||||||
'is_installed',
|
'is_installed',
|
||||||
'parse_url',
|
'parse_url',
|
||||||
'PKG2PROJECT',
|
'PKG2PROJECT',
|
||||||
|
|
|
@ -23,6 +23,15 @@ from requests.models import Response
|
||||||
from .default import PKG2PROJECT
|
from .default import PKG2PROJECT
|
||||||
from .progress_bars import rich_progress_bar
|
from .progress_bars import rich_progress_bar
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
import torch_npu
|
||||||
|
|
||||||
|
IS_NPU_AVAILABLE = hasattr(
|
||||||
|
torch, 'npu') and torch.npu.is_available() # type: ignore
|
||||||
|
except Exception:
|
||||||
|
IS_NPU_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
def parse_url(url: str) -> Tuple[str, str]:
|
def parse_url(url: str) -> Tuple[str, str]:
|
||||||
"""Parse username and repo from url.
|
"""Parse username and repo from url.
|
||||||
|
@ -327,12 +336,37 @@ def get_installed_path(package: str) -> str:
|
||||||
return osp.join(pkg.location, package2module(package))
|
return osp.join(pkg.location, package2module(package))
|
||||||
|
|
||||||
|
|
||||||
def get_torch_cuda_version() -> Tuple[str, str]:
|
def is_npu_available() -> bool:
|
||||||
"""Get PyTorch version and CUDA version if it is available.
|
"""Returns True if Ascend PyTorch and npu devices exist."""
|
||||||
|
return IS_NPU_AVAILABLE
|
||||||
|
|
||||||
|
|
||||||
|
def get_npu_version() -> str:
|
||||||
|
"""Returns the version of NPU when npu devices exist."""
|
||||||
|
if not is_npu_available():
|
||||||
|
return ''
|
||||||
|
ascend_home_path = os.environ.get('ASCEND_HOME_PATH', None)
|
||||||
|
if not ascend_home_path or not os.path.exists(ascend_home_path):
|
||||||
|
raise RuntimeError(
|
||||||
|
highlighted_error(
|
||||||
|
f'ASCEND_HOME_PATH:{ascend_home_path} does not exists when '
|
||||||
|
'installing OpenMMLab projects on Ascend NPU.'
|
||||||
|
"Please run 'source set_env.sh' in the CANN installation path."
|
||||||
|
))
|
||||||
|
npu_version = torch_npu.get_cann_version(ascend_home_path)
|
||||||
|
return npu_version
|
||||||
|
|
||||||
|
|
||||||
|
def get_torch_device_version() -> Tuple[str, str, str]:
|
||||||
|
"""Get PyTorch version and CUDA/NPU version if it is available.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> get_torch_cuda_version()
|
>>> get_torch_device_version()
|
||||||
'1.8.0', '102'
|
'1.8.0', 'cpu', ''
|
||||||
|
>>> get_torch_device_version()
|
||||||
|
'1.8.0', 'cuda', '102'
|
||||||
|
>>> get_torch_device_version()
|
||||||
|
'1.11.0', 'ascend', '602'
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
import torch
|
import torch
|
||||||
|
@ -344,11 +378,17 @@ def get_torch_cuda_version() -> Tuple[str, str]:
|
||||||
torch_v = torch_v.split('+')[0]
|
torch_v = torch_v.split('+')[0]
|
||||||
|
|
||||||
if torch.version.cuda is not None:
|
if torch.version.cuda is not None:
|
||||||
|
device = 'cuda'
|
||||||
# torch.version.cuda like 10.2 -> 102
|
# torch.version.cuda like 10.2 -> 102
|
||||||
cuda_v = ''.join(torch.version.cuda.split('.'))
|
device_v = ''.join(torch.version.cuda.split('.'))
|
||||||
|
elif is_npu_available():
|
||||||
|
device = 'ascend'
|
||||||
|
device_v = get_npu_version()
|
||||||
|
device_v = ''.join(device_v.split('.'))
|
||||||
else:
|
else:
|
||||||
cuda_v = 'cpu'
|
device = 'cpu'
|
||||||
return torch_v, cuda_v
|
device_v = ''
|
||||||
|
return torch_v, device, device_v
|
||||||
|
|
||||||
|
|
||||||
def cast2lowercase(input: Union[list, tuple, str]) -> Any:
|
def cast2lowercase(input: Union[list, tuple, str]) -> Any:
|
||||||
|
|
|
@ -20,4 +20,4 @@ default_section = THIRDPARTY
|
||||||
include_trailing_comma = true
|
include_trailing_comma = true
|
||||||
|
|
||||||
[codespell]
|
[codespell]
|
||||||
ignore-words-list = te
|
ignore-words-list = te, cann
|
||||||
|
|
|
@ -4,6 +4,7 @@ from click.testing import CliRunner
|
||||||
from mim.commands.install import cli as install
|
from mim.commands.install import cli as install
|
||||||
from mim.commands.uninstall import cli as uninstall
|
from mim.commands.uninstall import cli as uninstall
|
||||||
from mim.utils import get_github_url, parse_home_page
|
from mim.utils import get_github_url, parse_home_page
|
||||||
|
from mim.utils.utils import get_torch_device_version, is_npu_available
|
||||||
|
|
||||||
|
|
||||||
def setup_module():
|
def setup_module():
|
||||||
|
@ -39,6 +40,14 @@ def test_get_github_url():
|
||||||
'mmcls') == 'https://github.com/open-mmlab/mmclassification.git'
|
'mmcls') == 'https://github.com/open-mmlab/mmclassification.git'
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_torch_device_version():
|
||||||
|
torch_v, device, device_v = get_torch_device_version()
|
||||||
|
assert torch_v.replace('.', '').isdigit()
|
||||||
|
if is_npu_available():
|
||||||
|
assert device == 'ascend'
|
||||||
|
assert device_v.replace('.', '').isdigit()
|
||||||
|
|
||||||
|
|
||||||
def teardown_module():
|
def teardown_module():
|
||||||
runner = CliRunner()
|
runner = CliRunner()
|
||||||
result = runner.invoke(uninstall, ['mmcv-full', '--yes'])
|
result = runner.invoke(uninstall, ['mmcv-full', '--yes'])
|
||||||
|
|
Loading…
Reference in New Issue