mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Enhancement] Support BoolTensor and LongTensor on Ascend NPU (#1011)
This commit is contained in:
parent
8bf1ecad38
commit
1c67f9eb22
@ -1,16 +1,26 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import itertools
|
import itertools
|
||||||
from collections.abc import Sized
|
from collections.abc import Sized
|
||||||
from typing import List, Union
|
from typing import Any, List, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from mmengine.device import get_device
|
||||||
from .base_data_element import BaseDataElement
|
from .base_data_element import BaseDataElement
|
||||||
|
|
||||||
IndexType = Union[str, slice, int, list, torch.LongTensor,
|
BoolTypeTensor: Union[Any]
|
||||||
torch.cuda.LongTensor, torch.BoolTensor,
|
LongTypeTensor: Union[Any]
|
||||||
torch.cuda.BoolTensor, np.ndarray]
|
|
||||||
|
if get_device() == 'npu':
|
||||||
|
BoolTypeTensor = Union[torch.BoolTensor, torch.npu.BoolTensor]
|
||||||
|
LongTypeTensor = Union[torch.LongTensor, torch.npu.LongTensor]
|
||||||
|
else:
|
||||||
|
BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor]
|
||||||
|
LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor]
|
||||||
|
|
||||||
|
IndexType: Union[Any] = Union[str, slice, int, list, LongTypeTensor,
|
||||||
|
BoolTypeTensor, np.ndarray]
|
||||||
|
|
||||||
|
|
||||||
# Modified from
|
# Modified from
|
||||||
@ -156,6 +166,7 @@ class InstanceData(BaseDataElement):
|
|||||||
Returns:
|
Returns:
|
||||||
:obj:`InstanceData`: Corresponding values.
|
:obj:`InstanceData`: Corresponding values.
|
||||||
"""
|
"""
|
||||||
|
assert isinstance(item, IndexType.__args__)
|
||||||
if isinstance(item, list):
|
if isinstance(item, list):
|
||||||
item = np.array(item)
|
item = np.array(item)
|
||||||
if isinstance(item, np.ndarray):
|
if isinstance(item, np.ndarray):
|
||||||
@ -165,9 +176,6 @@ class InstanceData(BaseDataElement):
|
|||||||
# More details in https://github.com/numpy/numpy/issues/9464
|
# More details in https://github.com/numpy/numpy/issues/9464
|
||||||
item = item.astype(np.int64) if item.dtype == np.int32 else item
|
item = item.astype(np.int64) if item.dtype == np.int32 else item
|
||||||
item = torch.from_numpy(item)
|
item = torch.from_numpy(item)
|
||||||
assert isinstance(
|
|
||||||
item, (str, slice, int, torch.LongTensor, torch.cuda.LongTensor,
|
|
||||||
torch.BoolTensor, torch.cuda.BoolTensor))
|
|
||||||
|
|
||||||
if isinstance(item, str):
|
if isinstance(item, str):
|
||||||
return getattr(self, item)
|
return getattr(self, item)
|
||||||
@ -183,7 +191,7 @@ class InstanceData(BaseDataElement):
|
|||||||
if isinstance(item, torch.Tensor):
|
if isinstance(item, torch.Tensor):
|
||||||
assert item.dim() == 1, 'Only support to get the' \
|
assert item.dim() == 1, 'Only support to get the' \
|
||||||
' values along the first dimension.'
|
' values along the first dimension.'
|
||||||
if isinstance(item, (torch.BoolTensor, torch.cuda.BoolTensor)):
|
if isinstance(item, BoolTypeTensor.__args__):
|
||||||
assert len(item) == len(self), 'The shape of the ' \
|
assert len(item) == len(self), 'The shape of the ' \
|
||||||
'input(BoolTensor) ' \
|
'input(BoolTensor) ' \
|
||||||
f'{len(item)} ' \
|
f'{len(item)} ' \
|
||||||
@ -202,8 +210,7 @@ class InstanceData(BaseDataElement):
|
|||||||
v, (str, list, tuple)) or (hasattr(v, '__getitem__')
|
v, (str, list, tuple)) or (hasattr(v, '__getitem__')
|
||||||
and hasattr(v, 'cat')):
|
and hasattr(v, 'cat')):
|
||||||
# convert to indexes from BoolTensor
|
# convert to indexes from BoolTensor
|
||||||
if isinstance(item,
|
if isinstance(item, BoolTypeTensor.__args__):
|
||||||
(torch.BoolTensor, torch.cuda.BoolTensor)):
|
|
||||||
indexes = torch.nonzero(item).view(
|
indexes = torch.nonzero(item).view(
|
||||||
-1).cpu().numpy().tolist()
|
-1).cpu().numpy().tolist()
|
||||||
else:
|
else:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user