[Feature] Support engine with NPU backend. (#572)
* init npu
* Update mmengine/optim/optimizer/amp_optimizer_wrapper.py
  Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
* Update mmengine/dist/dist.py
  Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
* change to is_hccl_backend
* Update mmengine/optim/optimizer/amp_optimizer_wrapper.py
  Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
* add comment with AmpOptimWrapper
* Update mmengine/runner/amp.py
  Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
* Update mmengine/runner/amp.py
  Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
* add npu fn in base_model
* Update mmengine/optim/optimizer/amp_optimizer_wrapper.py
  Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
* clean lint
* Update mmengine/optim/optimizer/amp_optimizer_wrapper.py
  Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
* Update mmengine/model/base_model/base_model.py
  Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
* add is_npu_available
* try to fix
* Add comments
* Refine grammar

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Co-authored-by: HAOCHENYE <21724054@zju.edu.cn>
parent d270516fe8
commit 601db12d38
@@ -13,5 +13,6 @@ mmengine.device
    get_device
    get_max_cuda_memory
    is_cuda_available
+   is_npu_available
    is_mlu_available
    is_mps_available
@@ -13,5 +13,6 @@ mmengine.device
    get_device
    get_max_cuda_memory
    is_cuda_available
+   is_npu_available
    is_mlu_available
    is_mps_available
@@ -1,8 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .utils import (get_device, get_max_cuda_memory, is_cuda_available,
-                    is_mlu_available, is_mps_available)
+                    is_mlu_available, is_mps_available, is_npu_available)

 __all__ = [
     'get_max_cuda_memory', 'get_device', 'is_cuda_available',
-    'is_mlu_available', 'is_mps_available'
+    'is_mlu_available', 'is_mps_available', 'is_npu_available'
 ]
@@ -32,6 +32,15 @@ def is_cuda_available() -> bool:
     return torch.cuda.is_available()


+def is_npu_available() -> bool:
+    """Returns True if Ascend PyTorch and npu devices exist."""
+    try:
+        import torch_npu  # noqa: F401
+    except Exception:
+        return False
+    return hasattr(torch, 'npu') and torch.npu.is_available()
+
+
 def is_mlu_available() -> bool:
     """Returns True if Cambricon PyTorch and mlu devices exist."""
     return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()
@@ -49,9 +58,11 @@ def get_device() -> str:
     """Returns the currently existing device type.

     Returns:
-        str: cuda | mlu | mps | cpu.
+        str: cuda | npu | mlu | mps | cpu.
     """
-    if is_cuda_available():
+    if is_npu_available():
+        return 'npu'
+    elif is_cuda_available():
         return 'cuda'
     elif is_mlu_available():
         return 'mlu'
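A minimal usage sketch of the new helpers (not part of the diff; the toy layer below is purely illustrative). With this change, `get_device()` reports 'npu' ahead of 'cuda' whenever torch_npu is importable and an Ascend device is visible:

import torch
from mmengine.device import get_device, is_npu_available

device = get_device()  # 'npu' | 'cuda' | 'mlu' | 'mps' | 'cpu'
model = torch.nn.Linear(8, 2).to(device)
print(is_npu_available(), device)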
@@ -20,6 +20,7 @@ from .utils import (get_world_size, get_rank, get_backend, get_dist_info,
                     get_comm_device, cast_data_device)
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
+from mmengine.device import is_npu_available


 def _get_reduce_op(name: str) -> torch_dist.ReduceOp:
@@ -411,7 +412,11 @@ def _broadcast_object_list(object_list: List[Any],
     group_backend = get_backend(group)
     is_nccl_backend = group_backend == torch_dist.Backend.NCCL
     current_device = torch.device('cpu')
-    if is_nccl_backend:
+    is_hccl_backend = group_backend == 'hccl'
+    if is_hccl_backend:
+        current_device = torch.npu.current_device()
+        object_sizes_tensor = object_sizes_tensor.to(current_device)
+    elif is_nccl_backend:
         # See note about using torch.cuda.current_device() here in
         # docstring. We cannot simply use my_rank since rank == device is
         # not necessarily true.
@@ -430,7 +435,7 @@ def _broadcast_object_list(object_list: List[Any],
             dtype=torch.uint8,
         )

-    if is_nccl_backend:
+    if is_nccl_backend or is_hccl_backend:
         object_tensor = object_tensor.to(current_device)
     torch_dist.broadcast(object_tensor, src=src, group=group)
     # Deserialize objects using their stored sizes.
@@ -504,7 +509,8 @@ def broadcast_object_list(data: List[Any],
     if group is None:
         group = get_default_group()

-    if digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
+    if digit_version(TORCH_VERSION) >= digit_version(
+            '1.8.0') and not is_npu_available():
         torch_dist.broadcast_object_list(data, src, group)
     else:
         _broadcast_object_list(data, src, group)
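A hedged usage sketch (assumes a distributed process group is already initialized; the payload is illustrative only). On Ascend hosts the patched fallback path `_broadcast_object_list` over HCCL is taken:

from mmengine.dist import broadcast_object_list, get_rank

# Rank 0 fills the list; every other rank receives the objects in place.
data = [dict(lr=0.01), 'epoch_1.pth'] if get_rank() == 0 else [None, None]
broadcast_object_list(data, src=0)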
@@ -10,7 +10,7 @@ import torch.multiprocessing as mp
 from torch import Tensor
 from torch import distributed as torch_dist
 from torch.distributed import ProcessGroup
-from mmengine.device import is_mlu_available
+from mmengine.device import is_mlu_available, is_npu_available

 from collections.abc import Iterable, Mapping
@@ -80,6 +80,14 @@ def _init_dist_pytorch(backend, **kwargs) -> None:
             rank=rank,
             world_size=int(os.environ['WORLD_SIZE']),
             **kwargs)
+    elif is_npu_available():
+        import torch_npu  # noqa: F401
+        torch.npu.set_device(rank)
+        torch_dist.init_process_group(
+            backend='hccl',
+            rank=rank,
+            world_size=int(os.environ['WORLD_SIZE']),
+            **kwargs)
     else:
         num_gpus = torch.cuda.device_count()
         torch.cuda.set_device(rank % num_gpus)
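A hedged launch sketch (assumes the script is started with torchrun or another launcher that sets RANK and WORLD_SIZE; on an Ascend host the new branch above selects the HCCL backend automatically):

from mmengine.dist import init_dist, get_rank, get_world_size

init_dist(launcher='pytorch')  # switched to backend='hccl' on NPU hosts
print(f'rank {get_rank()} / world size {get_world_size()}')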
@@ -437,7 +445,10 @@ def get_comm_device(group: Optional[ProcessGroup] = None) -> torch.device:
         torch.device: The device of backend.
     """
     backend = get_backend(group)
-    if backend == torch_dist.Backend.NCCL:
+    if backend == 'hccl':
+        import torch_npu  # noqa: F401
+        return torch.device('npu', torch.npu.current_device())
+    elif backend == torch_dist.Backend.NCCL:
         return torch.device('cuda', torch.cuda.current_device())
     elif backend == 'cncl':
         import torch_mlu  # noqa: F401
@@ -210,6 +210,25 @@ class BaseModel(BaseModule):
         self._set_device(torch.device(device))
         return super().cuda(device)

+    def npu(
+        self,
+        device: Union[int, str, torch.device, None] = None,
+    ) -> nn.Module:
+        """Overrides this method to call :meth:`BaseDataPreprocessor.npu`
+        additionally.
+
+        Returns:
+            nn.Module: The model itself.
+
+        Note:
+            This generation of NPU (Ascend 910) does not support
+            the use of multiple cards in a single process,
+            so the index here needs to be consistent with the default device.
+        """
+        device = torch.npu.current_device()
+        self._set_device(device)
+        return super().npu()
+
     def cpu(self, *args, **kwargs) -> nn.Module:
         """Overrides this method to call :meth:`BaseDataPreprocessor.cpu`
         additionally.
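A short, hedged usage sketch (requires torch_npu and an Ascend device; ToyModel is an invented example, not part of the diff). `.npu()` mirrors `.cuda()` but always binds to the process's default NPU, matching the single-device-per-process note above:

import torch.nn as nn
from mmengine.model import BaseModel

class ToyModel(BaseModel):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, inputs, data_samples=None, mode='tensor'):
        return self.fc(inputs)

# Moves both the parameters and the data preprocessor to the default NPU.
model = ToyModel().npu()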
@@ -3,13 +3,18 @@ from contextlib import contextmanager

 import torch
 import torch.nn as nn
-from torch.cuda.amp import GradScaler

+from mmengine.device import is_cuda_available, is_npu_available
 from mmengine.registry import OPTIM_WRAPPERS
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
 from .optimizer_wrapper import OptimWrapper

+if is_npu_available():
+    from torch.npu.amp import GradScaler
+else:
+    from torch.cuda.amp import GradScaler
+

 @OPTIM_WRAPPERS.register_module()
 class AmpOptimWrapper(OptimWrapper):
@@ -44,8 +49,8 @@ class AmpOptimWrapper(OptimWrapper):
     def __init__(self, loss_scale='dynamic', **kwargs):
         assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), (
             '`torch.cuda.amp` is only available when pytorch version >= 1.6')
-        assert torch.cuda.is_available(), (
-            '``AmpOptimizerWrapper`` is only available training on gpu')
+        assert is_cuda_available() or is_npu_available(), (
+            '``AmpOptimizerWrapper`` is only available training on gpu or npu')
         super().__init__(**kwargs)
         self._scale_update_param = None
         if loss_scale == 'dynamic':
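A hedged construction sketch (requires a CUDA GPU or an Ascend NPU to pass the relaxed assertion above; the toy model and SGD settings are illustrative only):

import torch
from mmengine.optim import AmpOptimWrapper

model = torch.nn.Linear(4, 2).cuda()  # or .npu() when torch_npu is installed
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# loss_scale='dynamic' uses whichever GradScaler was selected at import time
# (torch.npu.amp on NPU hosts, torch.cuda.amp otherwise).
optim_wrapper = AmpOptimWrapper(optimizer=optimizer, loss_scale='dynamic')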
@@ -7,6 +7,7 @@ import torch
 import torch.nn as nn

 from mmengine.config import Config, ConfigDict
+from mmengine.device import is_npu_available
 from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS
 from .optimizer_wrapper import OptimWrapper
@@ -53,6 +54,13 @@ def build_optim_wrapper(model: nn.Module,
     constructor_type = optim_wrapper_cfg.pop('constructor',
                                              'DefaultOptimWrapperConstructor')
     paramwise_cfg = optim_wrapper_cfg.pop('paramwise_cfg', None)
+
+    # Since the current generation of NPU (Ascend 910) only supports
+    # mixed precision training, we turn on mixed precision by default
+    # on the NPU so that training works out of the box.
+    if is_npu_available():
+        optim_wrapper_cfg['type'] = 'AmpOptimWrapper'
+
     optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(
         dict(
             type=constructor_type,
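A hedged config sketch (runs on any backend; on an NPU host the wrapper type below is upgraded to AmpOptimWrapper automatically by the branch above):

import torch.nn as nn
from mmengine.optim import build_optim_wrapper

model = nn.Linear(4, 2)
optim_wrapper = build_optim_wrapper(
    model,
    dict(type='OptimWrapper',
         optimizer=dict(type='SGD', lr=0.01, momentum=0.9)))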
@@ -5,7 +5,7 @@ from typing import Optional

 import torch

-from mmengine.device import get_device
+from mmengine.device import get_device, is_cuda_available, is_npu_available
 from mmengine.logging import print_log
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
@@ -86,7 +86,10 @@ def autocast(device_type: Optional[str] = None,
             logger='current',
             level=logging.WARNING)

-    if torch.cuda.is_available():
+    if is_npu_available():
+        with torch.npu.amp.autocast(enabled=enabled):
+            yield
+    elif is_cuda_available():
         with torch.cuda.amp.autocast(enabled=enabled):
             yield
     else:
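A hedged usage sketch of the context manager (enabled=False keeps it runnable on a CPU-only machine; on NPU hosts the torch.npu.amp branch above is taken instead):

import torch
from mmengine.runner import autocast

layer = torch.nn.Linear(4, 2)
with autocast(enabled=False):  # flip to True on a GPU/NPU host
    out = layer(torch.randn(2, 4))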
@@ -1,11 +1,13 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmengine.device import (get_device, is_cuda_available, is_mlu_available,
-                             is_mps_available)
+from mmengine.device import (get_device, is_cuda_available, is_mlu_available,
+                             is_mps_available, is_npu_available)


 def test_get_device():
     device = get_device()
-    if is_cuda_available():
+    if is_npu_available():
+        assert device == 'npu'
+    elif is_cuda_available():
         assert device == 'cuda'
     elif is_mlu_available():
         assert device == 'mlu'