[Enhancement] Support BoolTensor and LongTensor on Ascend NPU (#1011)

This commit is contained in:
Yinlei Sun 2023-04-10 16:31:31 +08:00 committed by GitHub
parent 8bf1ecad38
commit 1c67f9eb22
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,16 +1,26 @@
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from collections.abc import Sized
from typing import List, Union
from typing import Any, List, Union
import numpy as np
import torch
from mmengine.device import get_device
from .base_data_element import BaseDataElement
IndexType = Union[str, slice, int, list, torch.LongTensor,
torch.cuda.LongTensor, torch.BoolTensor,
torch.cuda.BoolTensor, np.ndarray]
BoolTypeTensor: Union[Any]
LongTypeTensor: Union[Any]
if get_device() == 'npu':
BoolTypeTensor = Union[torch.BoolTensor, torch.npu.BoolTensor]
LongTypeTensor = Union[torch.LongTensor, torch.npu.LongTensor]
else:
BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor]
LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor]
IndexType: Union[Any] = Union[str, slice, int, list, LongTypeTensor,
BoolTypeTensor, np.ndarray]
# Modified from
@ -156,6 +166,7 @@ class InstanceData(BaseDataElement):
Returns:
:obj:`InstanceData`: Corresponding values.
"""
assert isinstance(item, IndexType.__args__)
if isinstance(item, list):
item = np.array(item)
if isinstance(item, np.ndarray):
@ -165,9 +176,6 @@ class InstanceData(BaseDataElement):
# More details in https://github.com/numpy/numpy/issues/9464
item = item.astype(np.int64) if item.dtype == np.int32 else item
item = torch.from_numpy(item)
assert isinstance(
item, (str, slice, int, torch.LongTensor, torch.cuda.LongTensor,
torch.BoolTensor, torch.cuda.BoolTensor))
if isinstance(item, str):
return getattr(self, item)
@ -183,7 +191,7 @@ class InstanceData(BaseDataElement):
if isinstance(item, torch.Tensor):
assert item.dim() == 1, 'Only support to get the' \
' values along the first dimension.'
if isinstance(item, (torch.BoolTensor, torch.cuda.BoolTensor)):
if isinstance(item, BoolTypeTensor.__args__):
assert len(item) == len(self), 'The shape of the ' \
'input(BoolTensor) ' \
f'{len(item)} ' \
@ -202,8 +210,7 @@ class InstanceData(BaseDataElement):
v, (str, list, tuple)) or (hasattr(v, '__getitem__')
and hasattr(v, 'cat')):
# convert to indexes from BoolTensor
if isinstance(item,
(torch.BoolTensor, torch.cuda.BoolTensor)):
if isinstance(item, BoolTypeTensor.__args__):
indexes = torch.nonzero(item).view(
-1).cpu().numpy().tolist()
else: