[Feature] Support MLU backend (#1075)
* support mlu device
* fix lint error
* fix lint error builder.py
* fix lint error in amp.py
* fix lint errors
* fix data type in instance_data.py
parent f22002ec08
commit 60b4c199fc
@@ -414,9 +414,13 @@ def _broadcast_object_list(object_list: List[Any],
     is_nccl_backend = group_backend == torch_dist.Backend.NCCL
     current_device = torch.device('cpu')
     is_hccl_backend = group_backend == 'hccl'
+    is_cncl_backend = group_backend == 'cncl'
     if is_hccl_backend:
         current_device = torch.npu.current_device()
         object_sizes_tensor = object_sizes_tensor.to(current_device)
+    elif is_cncl_backend:
+        current_device = torch.device('mlu', torch.mlu.current_device())
+        object_sizes_tensor = object_sizes_tensor.to(current_device)
     elif is_nccl_backend:
         # See note about using torch.cuda.current_device() here in
         # docstring. We cannot simply use my_rank since rank == device is
@@ -436,7 +440,7 @@ def _broadcast_object_list(object_list: List[Any],
             dtype=torch.uint8,
         )

-    if is_nccl_backend or is_hccl_backend:
+    if is_nccl_backend or is_hccl_backend or is_cncl_backend:
         object_tensor = object_tensor.to(current_device)
     torch_dist.broadcast(object_tensor, src=src, group=group)
     # Deserialize objects using their stored sizes.
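For context, a minimal standalone sketch of the backend-to-device dispatch that this hunk extends. The helper name ``_pick_broadcast_device`` is hypothetical (mmengine keeps this logic inline in ``_broadcast_object_list``), and the npu/mlu branches assume the torch_npu / torch_mlu extensions are installed.

import torch
import torch.distributed as torch_dist


def _pick_broadcast_device(group=None) -> torch.device:
    """Illustrative only: choose the device broadcast tensors should live on,
    mirroring the nccl -> cuda, hccl -> npu, cncl -> mlu checks in the diff."""
    backend = torch_dist.get_backend(group)
    if backend == torch_dist.Backend.NCCL:
        return torch.device('cuda', torch.cuda.current_device())
    if backend == 'hccl':  # Ascend NPU process group (requires torch_npu)
        return torch.device('npu', torch.npu.current_device())
    if backend == 'cncl':  # Cambricon MLU process group (requires torch_mlu)
        return torch.device('mlu', torch.mlu.current_device())
    return torch.device('cpu')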
@@ -216,6 +216,20 @@ class BaseModel(BaseModule):
         self._set_device(torch.device(device))
         return super().cuda(device)

+    def mlu(
+        self,
+        device: Union[int, str, torch.device, None] = None,
+    ) -> nn.Module:
+        """Overrides this method to call :meth:`BaseDataPreprocessor.mlu`
+        additionally.
+
+        Returns:
+            nn.Module: The model itself.
+        """
+        device = torch.device('mlu', torch.mlu.current_device())
+        self._set_device(device)
+        return super().mlu()
+
     def npu(
         self,
         device: Union[int, str, torch.device, None] = None,
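A hedged usage sketch of the new override: calling ``.mlu()`` on a ``BaseModel`` subclass moves the parameters and keeps ``data_preprocessor.device`` in sync, the same contract as the existing ``cuda()``/``npu()`` overrides. ``ToyModel`` is invented for illustration and the example assumes an MLU-enabled PyTorch build.

import torch.nn as nn
from mmengine.device import is_mlu_available
from mmengine.model import BaseModel


class ToyModel(BaseModel):
    """Minimal subclass, defined only for this sketch."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)

    def forward(self, inputs, data_samples=None, mode='tensor'):
        return self.linear(inputs)


model = ToyModel()
if is_mlu_available():          # only meaningful with torch_mlu installed
    model = model.mlu()         # parameters and data_preprocessor move together
    print(model.data_preprocessor.device)   # expected: an 'mlu' device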
@@ -122,6 +122,15 @@ class BaseDataPreprocessor(nn.Module):
         self._device = torch.device(torch.npu.current_device())
         return super().npu()

+    def mlu(self, *args, **kwargs) -> nn.Module:
+        """Overrides this method to set the :attr:`device`
+
+        Returns:
+            nn.Module: The model itself.
+        """
+        self._device = torch.device(torch.mlu.current_device())
+        return super().mlu()
+
     def cpu(self, *args, **kwargs) -> nn.Module:
         """Overrides this method to set the :attr:`device`

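The preprocessor override only records the target device; tensors are moved later when a batch flows through the preprocessor. A hedged usage sketch, assuming an MLU-capable environment:

from mmengine.device import is_mlu_available
from mmengine.model import BaseDataPreprocessor

preprocessor = BaseDataPreprocessor()
if is_mlu_available():
    preprocessor = preprocessor.mlu()   # records torch.mlu.current_device()
    # Subsequent batches passed through the preprocessor are cast to this
    # device; here we only check the recorded value.
    print(preprocessor.device)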
@@ -5,7 +5,8 @@ from typing import Union
 import torch
 import torch.nn as nn

-from mmengine.device import is_cuda_available, is_npu_available
+from mmengine.device import (is_cuda_available, is_mlu_available,
+                             is_npu_available)
 from mmengine.registry import OPTIM_WRAPPERS
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
@@ -13,6 +14,8 @@ from .optimizer_wrapper import OptimWrapper

 if is_npu_available():
     from torch.npu.amp import GradScaler
+elif is_mlu_available():
+    from torch.mlu.amp import GradScaler
 else:
     from torch.cuda.amp import GradScaler

@@ -65,8 +68,9 @@ class AmpOptimWrapper(OptimWrapper):
                  **kwargs):
         assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), (
             '`torch.cuda.amp` is only available when pytorch version >= 1.6')
-        assert is_cuda_available() or is_npu_available(), (
-            '``AmpOptimizerWrapper`` is only available training on gpu or npu')
+        assert is_cuda_available() or is_npu_available() or is_mlu_available(
+        ), ('``AmpOptimizerWrapper`` is only available training '
+            'on gpu, npu or mlu')
         super().__init__(**kwargs)
         self._scale_update_param = None
         if loss_scale == 'dynamic':
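Because the GradScaler class is chosen once at import time, the rest of ``AmpOptimWrapper`` stays backend-agnostic. A sketch of the same selection as a standalone helper; ``_select_grad_scaler`` is not part of mmengine, and the npu/mlu imports assume the vendor extensions are installed.

from mmengine.device import is_mlu_available, is_npu_available


def _select_grad_scaler():
    """Illustrative helper mirroring the module-level if/elif/else above."""
    if is_npu_available():
        from torch.npu.amp import GradScaler    # provided by torch_npu
    elif is_mlu_available():
        from torch.mlu.amp import GradScaler    # provided by torch_mlu
    else:
        from torch.cuda.amp import GradScaler   # stock PyTorch
    return GradScaler


# e.g. scaler = _select_grad_scaler()(init_scale=512.)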
@@ -5,7 +5,8 @@ from typing import Optional

 import torch

-from mmengine.device import get_device, is_cuda_available, is_npu_available
+from mmengine.device import (get_device, is_cuda_available, is_mlu_available,
+                             is_npu_available)
 from mmengine.logging import print_log
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
@@ -75,9 +76,11 @@ def autocast(device_type: Optional[str] = None,
           digit_version('1.10.0')):
         # If pytorch version is between 1.5.0 and 1.10.0, the default value of
         # dtype for `torch.cuda.amp.autocast` is torch.float16.
-        assert device_type == 'cuda' or device_type is None, (
-            'Pytorch version under 1.10.0 only supports running automatic '
-            'mixed training with cuda')
+        assert (
+            device_type == 'cuda' or device_type == 'mlu'
+            or device_type is None), (
+                'Pytorch version under 1.10.0 only supports running automatic '
+                'mixed training with cuda or mlu')
         if dtype is not None or cache_enabled is not None:
             print_log(
                 f'{dtype} and {device_type} will not work for '
@@ -89,6 +92,9 @@ def autocast(device_type: Optional[str] = None,
         if is_npu_available():
             with torch.npu.amp.autocast(enabled=enabled):
                 yield
+        elif is_mlu_available():
+            with torch.mlu.amp.autocast(enabled=enabled):
+                yield
         elif is_cuda_available():
             with torch.cuda.amp.autocast(enabled=enabled):
                 yield
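With the new branch, ``mmengine.runner.autocast`` resolves the backend on its own (npu, then mlu, then cuda), so call sites stay device-neutral. A hedged usage sketch; it assumes PyTorch >= 1.10 or an available accelerator, and the exact output dtype depends on the backend.

import torch
import torch.nn as nn
from mmengine.device import get_device
from mmengine.runner import autocast

device = get_device()                    # 'cuda', 'npu', 'mlu' or 'cpu'
layer = nn.Conv2d(1, 1, 1).to(device)
data = torch.randn(1, 1, 4, 4).to(device)

with autocast():
    out = layer(data)                    # eligible ops run in reduced precision
print(out.dtype)                         # e.g. torch.float16 / torch.bfloat16 under AMP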
@@ -521,6 +521,16 @@ class BaseDataElement:
                 new_data.set_data(data)
         return new_data

+    def mlu(self) -> 'BaseDataElement':
+        """Convert all tensors to MLU in data."""
+        new_data = self.new()
+        for k, v in self.items():
+            if isinstance(v, (torch.Tensor, BaseDataElement)):
+                v = v.mlu()
+                data = {k: v}
+                new_data.set_data(data)
+        return new_data
+
     # Tensor-like methods
     def detach(self) -> 'BaseDataElement':
         """Detach all tensors in data."""
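``BaseDataElement.mlu()`` walks every data field and converts tensors (and nested elements) into a new element, mirroring ``cuda()``/``npu()``. A hedged usage sketch; the field names are invented and the conversion itself requires an MLU device.

import torch
from mmengine.device import is_mlu_available
from mmengine.structures import BaseDataElement

sample = BaseDataElement()
sample.scores = torch.rand(4)                            # tensor field
sample.nested = BaseDataElement(flag=torch.tensor([1]))  # nested element

if is_mlu_available():
    mlu_sample = sample.mlu()            # returns a new BaseDataElement
    print(mlu_sample.scores.device)      # expected: an 'mlu' device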
@@ -15,6 +15,9 @@ LongTypeTensor: Union[Any]
 if get_device() == 'npu':
     BoolTypeTensor = Union[torch.BoolTensor, torch.npu.BoolTensor]
     LongTypeTensor = Union[torch.LongTensor, torch.npu.LongTensor]
+elif get_device() == 'mlu':
+    BoolTypeTensor = Union[torch.BoolTensor, torch.mlu.BoolTensor]
+    LongTypeTensor = Union[torch.LongTensor, torch.mlu.LongTensor]
 else:
     BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor]
     LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor]
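These module-level aliases let ``InstanceData`` indexing accept bool/long index tensors from whichever device backend is active. A hedged sketch of boolean-mask indexing; it runs on CPU as written, and on an MLU host the mask would be a ``torch.mlu.BoolTensor`` covered by the new branch.

import torch
from mmengine.structures import InstanceData

instances = InstanceData()
instances.bboxes = torch.rand(4, 4)
instances.scores = torch.rand(4)

keep = instances.scores > 0.5            # BoolTensor mask
selected = instances[keep]               # accepted via the BoolTypeTensor alias
print(len(selected), selected.bboxes.shape)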
@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn

 import mmengine
-from mmengine.device import get_device
+from mmengine.device import get_device, is_mlu_available
 from mmengine.runner import autocast
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
@@ -14,7 +14,22 @@ from mmengine.utils.dl_utils import TORCH_VERSION
 class TestAmp(unittest.TestCase):

     def test_autocast(self):
-        if not torch.cuda.is_available():
+        if is_mlu_available():
+            device = 'mlu'
+            with autocast(device_type=device):
+                # torch.autocast support mlu mode.
+                layer = nn.Conv2d(1, 1, 1).to(device)
+                res = layer(torch.randn(1, 1, 1, 1).to(device))
+                self.assertIn(res.dtype, (torch.bfloat16, torch.float16))
+                with autocast(enabled=False, device_type=device):
+                    res = layer(torch.randn(1, 1, 1, 1).to(device))
+                    self.assertEqual(res.dtype, torch.float32)
+            # Test with fp32_enabled
+            with autocast(enabled=False, device_type=device):
+                layer = nn.Conv2d(1, 1, 1).to(device)
+                res = layer(torch.randn(1, 1, 1, 1).to(device))
+                self.assertEqual(res.dtype, torch.float32)
+        elif not torch.cuda.is_available():
             if digit_version(TORCH_VERSION) < digit_version('1.10.0'):
                 # `torch.cuda.amp.autocast` is only support in gpu mode, if
                 # cuda is not available, it will return an empty context and