Mirror of https://github.com/open-mmlab/mmengine.git, synced 2025-06-03 21:54:44 +08:00
[Fix] Add support for Ascend device (#847)
* add npu device support
* add comment for torch.npu.set_compile_mode
This commit is contained in:
parent 925ac870e2
commit 79067e4628
@@ -36,6 +36,10 @@ def is_npu_available() -> bool:
     """Returns True if Ascend PyTorch and npu devices exist."""
     try:
         import torch_npu  # noqa: F401
+
+        # Enable operator support for dynamic shape and
+        # binary operator support on the NPU.
+        torch.npu.set_compile_mode(jit_compile=False)
     except Exception:
         return False
     return hasattr(torch, 'npu') and torch.npu.is_available()
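
For context, here is a minimal usage sketch (not part of the diff) of how the patched helper can gate device selection. It assumes is_npu_available is exported from mmengine.device as in current mmengine; the pick_device name and fallback order are illustrative, not mmengine API.

import torch

from mmengine.device import is_npu_available


def pick_device() -> torch.device:
    # Prefer an Ascend NPU when torch_npu is importable and a device
    # exists, then CUDA, then fall back to CPU.
    if is_npu_available():
        return torch.device('npu')
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')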
@@ -184,6 +184,18 @@ class BaseModel(BaseModule):
         Returns:
             nn.Module: The model itself.
         """
+
+        # Since Torch has not officially merged
+        # the npu-related fields, using the _parse_to function
+        # directly will cause the NPU to not be found.
+        # Here, the input parameters are processed to avoid errors.
+        if args and isinstance(args[0], str) and 'npu' in args[0]:
+            args = tuple(
+                [list(args)[0].replace('npu', torch.npu.native_device)])
+        if kwargs and 'npu' in str(kwargs.get('device', '')):
+            kwargs['device'] = kwargs['device'].replace(
+                'npu', torch.npu.native_device)
+
         device = torch._C._nn._parse_to(*args, **kwargs)[0]
         if device is not None:
             self._set_device(torch.device(device))
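
To see what the argument preprocessing above actually does, here is the same string rewrite isolated into a standalone sketch. NATIVE_DEVICE is a hypothetical stand-in for torch.npu.native_device, whose real value depends on the installed torch_npu build; rewrite_device_args is an illustrative name, not mmengine API.

NATIVE_DEVICE = 'xla'  # hypothetical placeholder for torch.npu.native_device


def rewrite_device_args(args: tuple, kwargs: dict):
    # Mirrors the preprocessing in BaseModel.to(): rewrite 'npu' in a
    # positional device string or in the 'device' keyword before the
    # arguments reach torch._C._nn._parse_to.
    if args and isinstance(args[0], str) and 'npu' in args[0]:
        args = tuple([args[0].replace('npu', NATIVE_DEVICE)])
    if kwargs and 'npu' in str(kwargs.get('device', '')):
        kwargs['device'] = kwargs['device'].replace('npu', NATIVE_DEVICE)
    return args, kwargs


print(rewrite_device_args(('npu:0',), {}))         # (('xla:0',), {})
print(rewrite_device_args((), {'device': 'npu'}))  # ((), {'device': 'xla'})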
@@ -87,6 +87,18 @@ class BaseDataPreprocessor(nn.Module):
         Returns:
             nn.Module: The model itself.
         """
+
+        # Since Torch has not officially merged
+        # the npu-related fields, using the _parse_to function
+        # directly will cause the NPU to not be found.
+        # Here, the input parameters are processed to avoid errors.
+        if args and isinstance(args[0], str) and 'npu' in args[0]:
+            args = tuple(
+                [list(args)[0].replace('npu', torch.npu.native_device)])
+        if kwargs and 'npu' in str(kwargs.get('device', '')):
+            kwargs['device'] = kwargs['device'].replace(
+                'npu', torch.npu.native_device)
+
         device = torch._C._nn._parse_to(*args, **kwargs)[0]
         if device is not None:
             self._device = torch.device(device)
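
The same preprocessing is duplicated here for BaseDataPreprocessor.to(). Both spellings below would exercise it; the preprocessor variable is a hypothetical instance, shown only to make the two code paths concrete.

preprocessor = preprocessor.to('npu:0')          # positional device string
preprocessor = preprocessor.to(device='npu:0')   # keyword device string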
@@ -101,6 +113,15 @@ class BaseDataPreprocessor(nn.Module):
             self._device = torch.device(torch.cuda.current_device())
         return super().cuda()
 
+    def npu(self, *args, **kwargs) -> nn.Module:
+        """Overrides this method to set the :attr:`device`
+
+        Returns:
+            nn.Module: The model itself.
+        """
+        self._device = torch.device(torch.npu.current_device())
+        return super().npu()
+
     def cpu(self, *args, **kwargs) -> nn.Module:
         """Overrides this method to set the :attr:`device`
 
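
A usage sketch for the new override, by analogy with the existing cuda()/cpu() overrides; it assumes torch_npu is installed, an Ascend device is visible, and the current mmengine.model import path.

from mmengine.model import BaseDataPreprocessor

preprocessor = BaseDataPreprocessor()
preprocessor = preprocessor.npu()  # move module, record current NPU device
print(preprocessor.device)         # e.g. npu:0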
@@ -507,6 +507,17 @@ class BaseDataElement:
         new_data.set_data(data)
         return new_data
 
+    # Tensor-like methods
+    def npu(self) -> 'BaseDataElement':
+        """Convert all tensors to NPU in data."""
+        new_data = self.new()
+        for k, v in self.items():
+            if isinstance(v, (torch.Tensor, BaseDataElement)):
+                v = v.npu()
+                data = {k: v}
+                new_data.set_data(data)
+        return new_data
+
     # Tensor-like methods
     def detach(self) -> 'BaseDataElement':
         """Detach all tensors in data."""
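
Finally, a sketch of the new tensor-like method in action, mirroring the existing cuda() semantics: it returns a new element rather than mutating in place. It assumes an available Ascend device and the current mmengine.structures import path.

import torch

from mmengine.structures import BaseDataElement

sample = BaseDataElement(scores=torch.rand(4), label=torch.tensor(1))
sample_npu = sample.npu()        # `sample` itself stays on its device
print(sample_npu.scores.device)  # expected to be an NPU device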