[Enchance] cat empty instancedata, support torch.bool for more type (#209)
* refactor instancedata * fix docs * fix commentpull/207/head^2
parent
16058fdb18
commit
5c5c03e648
|
@ -7,10 +7,11 @@ import torch
|
|||
|
||||
|
||||
class BaseDataElement:
|
||||
"""A base data structure interface of OpenMMlab.
|
||||
"""A base data interface that supports Tensor-like and dict-like
|
||||
operations.
|
||||
|
||||
Data elements refer to predicted results or ground truth labels on a
|
||||
task, such as predicted bboxes, instance masks, semantic
|
||||
A typical data elements refer to predicted results or ground truth labels
|
||||
on a task, such as predicted bboxes, instance masks, semantic
|
||||
segmentation masks, etc. Because groundtruth labels and predicted results
|
||||
often have similar properties (for example, the predicted bboxes and the
|
||||
groundtruth bboxes), MMEngine uses the same abstract data interface to
|
||||
|
@ -23,7 +24,23 @@ class BaseDataElement:
|
|||
``BaseDataElement``, and implement ``InstanceData``, ``PixelData``, and
|
||||
``LabelData`` inheriting from ``BaseDataElement`` to represent different
|
||||
types of ground truth labels or predictions.
|
||||
They are used as interfaces between different commopenets.
|
||||
|
||||
Another common data element is sample data. A sample data consists of input
|
||||
data (such as an image) and its annotations and predictions. In general,
|
||||
an image can have multiple types of annotations and/or predictions at the
|
||||
same time (for example, both pixel-level semantic segmentation annotations
|
||||
and instance-level detection bboxes annotations). All labels and
|
||||
predictions of a training sample are often passed between Dataset, Model,
|
||||
Visualizer, and Evaluator components. In order to simplify the interface
|
||||
between components, we can treat them as a large data element and
|
||||
encapsulate them. Such data elements are generally called XXDataSample in
|
||||
the OpenMMLab. Therefore, Similar to `nn.Module`, the `BaseDataElement`
|
||||
allows `BaseDataElement` as its attribute. Such a class generally
|
||||
encapsulates all the data of a sample in the algorithm library, and its
|
||||
attributes generally are various types of data elements. For example,
|
||||
MMDetection is assigned by the BaseDataElement to encapsulate all the data
|
||||
elements of the sample labeling and prediction of a sample in the
|
||||
algorithm library.
|
||||
|
||||
The attributes in ``BaseDataElement`` are divided into two parts,
|
||||
the ``metainfo`` and the ``data`` respectively.
|
||||
|
@ -70,8 +87,8 @@ class BaseDataElement:
|
|||
>>> # new
|
||||
>>> gt_instances1 = gt_instance.new(
|
||||
... metainfo=dict(img_id=1, img_shape=(640, 640)),
|
||||
... data=dict(bboxes=torch.rand((5, 4)),
|
||||
... scores=torch.rand((5,))))
|
||||
... bboxes=torch.rand((5, 4)),
|
||||
... scores=torch.rand((5,)))
|
||||
>>> gt_instances2 = gt_instances1.new()
|
||||
|
||||
>>> # add and process property
|
||||
|
@ -241,8 +258,9 @@ class BaseDataElement:
|
|||
self.set_data(dict(instance.items()))
|
||||
|
||||
def new(self,
|
||||
metainfo: dict = None,
|
||||
data: dict = None) -> 'BaseDataElement':
|
||||
*,
|
||||
metainfo: Optional[dict] = None,
|
||||
**kwargs) -> 'BaseDataElement':
|
||||
"""Return a new data element with same type. If ``metainfo`` and
|
||||
``data`` are None, the new data element will have same metainfo and
|
||||
data. If metainfo or data is not None, the new result will overwrite it
|
||||
|
@ -252,8 +270,9 @@ class BaseDataElement:
|
|||
metainfo (dict, optional): A dict contains the meta information
|
||||
of image, such as ``img_shape``, ``scale_factor``, etc.
|
||||
Defaults to None.
|
||||
data (dict, optional): A dict contains annotations of image or
|
||||
model predictions. Defaults to None.
|
||||
kwargs (dict): A dict contains annotations of image or
|
||||
model predictions.
|
||||
|
||||
Returns:
|
||||
BaseDataElement: a new data element with same type.
|
||||
"""
|
||||
|
@ -263,8 +282,8 @@ class BaseDataElement:
|
|||
new_data.set_metainfo(metainfo)
|
||||
else:
|
||||
new_data.set_metainfo(dict(self.metainfo_items()))
|
||||
if data is not None:
|
||||
new_data.set_data(data)
|
||||
if kwargs:
|
||||
new_data.set_data(kwargs)
|
||||
else:
|
||||
new_data.set_data(dict(self.items()))
|
||||
return new_data
|
||||
|
@ -388,7 +407,6 @@ class BaseDataElement:
|
|||
self._data_fields.remove(item)
|
||||
|
||||
# dict-like methods
|
||||
__setitem__ = __setattr__
|
||||
__delitem__ = __delattr__
|
||||
|
||||
def get(self, key, default=None) -> Any:
|
||||
|
@ -519,6 +537,7 @@ class BaseDataElement:
|
|||
}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""represent the object."""
|
||||
|
||||
def _addindent(s_: str, num_spaces: int) -> str:
|
||||
"""This func is modified from `pytorch` https://github.com/pytorch/
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import itertools
|
||||
from collections.abc import Sized
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
|
@ -7,8 +8,9 @@ import torch
|
|||
|
||||
from .base_data_element import BaseDataElement
|
||||
|
||||
IndexType = Union[str, slice, int, torch.LongTensor, torch.cuda.LongTensor,
|
||||
torch.BoolTensor, torch.cuda.BoolTensor, np.long, np.bool]
|
||||
IndexType = Union[str, slice, int, list, torch.LongTensor,
|
||||
torch.cuda.LongTensor, torch.BoolTensor,
|
||||
torch.cuda.BoolTensor, np.ndarray]
|
||||
|
||||
|
||||
# Modified from
|
||||
|
@ -19,8 +21,37 @@ class InstanceData(BaseDataElement):
|
|||
Subclass of :class:`BaseDataElement`. All value in `data_fields`
|
||||
should have the same length. This design refer to
|
||||
https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py # noqa E501
|
||||
InstanceData also support extra functions: ``index``, ``slice`` and ``cat`` for data field. The type of value
|
||||
in data field can be base data structure such as `torch.tensor`, `numpy.dnarray`, `list`, `str`, `tuple`,
|
||||
and can be customized data structure that has ``__len__``, ``__getitem__`` and ``cat`` attributes.
|
||||
|
||||
Examples:
|
||||
>>> # custom data structure
|
||||
>>> class TmpObject:
|
||||
... def __init__(self, tmp) -> None:
|
||||
... assert isinstance(tmp, list)
|
||||
... self.tmp = tmp
|
||||
... def __len__(self):
|
||||
... return len(self.tmp)
|
||||
... def __getitem__(self, item):
|
||||
... if type(item) == int:
|
||||
... if item >= len(self) or item < -len(self): # type:ignore
|
||||
... raise IndexError(f'Index {item} out of range!')
|
||||
... else:
|
||||
... # keep the dimension
|
||||
... item = slice(item, None, len(self))
|
||||
... return TmpObject(self.tmp[item])
|
||||
... @staticmethod
|
||||
... def cat(tmp_objs):
|
||||
... assert all(isinstance(results, TmpObject) for results in tmp_objs)
|
||||
... if len(tmp_objs) == 1:
|
||||
... return tmp_objs[0]
|
||||
... tmp_list = [tmp_obj.tmp for tmp_obj in tmp_objs]
|
||||
... tmp_list = list(itertools.chain(*tmp_list))
|
||||
... new_data = TmpObject(tmp_list)
|
||||
... return new_data
|
||||
... def __repr__(self):
|
||||
... return str(self.tmp)
|
||||
>>> from mmengine.data import InstanceData
|
||||
>>> import numpy as np
|
||||
>>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
|
||||
|
@ -30,44 +61,69 @@ class InstanceData(BaseDataElement):
|
|||
>>> instance_data.det_labels = torch.LongTensor([2, 3])
|
||||
>>> instance_data["det_scores"] = torch.Tensor([0.8, 0.7])
|
||||
>>> instance_data.bboxes = torch.rand((2, 4))
|
||||
>>> instance_data.polygons = TmpObject([[1, 2, 3, 4], [5, 6, 7, 8]])
|
||||
>>> len(instance_data)
|
||||
4
|
||||
2
|
||||
>>> print(instance_data)
|
||||
<InstanceData(
|
||||
|
||||
META INFORMATION
|
||||
pad_shape: (800, 1196, 3)
|
||||
img_shape: (800, 1216, 3)
|
||||
|
||||
DATA FIELDS
|
||||
det_labels: tensor([2, 3])
|
||||
det_scores: tensor([0.8, 0.7000])
|
||||
bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188],
|
||||
[0.8101, 0.3105, 0.5123, 0.6263]])
|
||||
polygons: [[1, 2, 3, 4], [5, 6, 7, 8]]
|
||||
) at 0x7fb492de6280>
|
||||
>>> sorted_results = instance_data[instance_data.det_scores.sort().indices]
|
||||
>>> sorted_results.det_scores
|
||||
tensor([0.7000, 0.8000])
|
||||
>>> print(instance_data[instance_data.det_scores > 0.75])
|
||||
<InstanceData(
|
||||
|
||||
META INFORMATION
|
||||
pad_shape: (800, 1216, 3)
|
||||
img_shape: (800, 1196, 3)
|
||||
|
||||
DATA FIELDS
|
||||
det_labels: tensor([0])
|
||||
bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188]])
|
||||
det_labels: tensor([2])
|
||||
masks: [[11, 21, 31, 41]]
|
||||
det_scores: tensor([0.8000])
|
||||
) at 0x7fb5cf6e2790>
|
||||
>>> instance_data[instance_data.det_scores > 0.75].det_labels
|
||||
tensor([0])
|
||||
>>> instance_data[instance_data.det_scores > 0.75].det_scores
|
||||
tensor([0.8000])
|
||||
bboxes: tensor([[0.9308, 0.4000, 0.6077, 0.5554]])
|
||||
polygons: [[1, 2, 3, 4]]
|
||||
) at 0x7f64ecf0ec40>
|
||||
>>> print(instance_data[instance_data.det_scores > 1])
|
||||
<InstanceData(
|
||||
META INFORMATION
|
||||
pad_shape: (800, 1216, 3)
|
||||
img_shape: (800, 1196, 3)
|
||||
DATA FIELDS
|
||||
det_labels: tensor([], dtype=torch.int64)
|
||||
masks: []
|
||||
det_scores: tensor([])
|
||||
bboxes: tensor([], size=(0, 4))
|
||||
polygons: [[]]
|
||||
) at 0x7f660a6a7f70>
|
||||
>>> print(instance_data.cat([instance_data, instance_data]))
|
||||
<InstanceData(
|
||||
META INFORMATION
|
||||
img_shape: (800, 1196, 3)
|
||||
pad_shape: (800, 1216, 3)
|
||||
DATA FIELDS
|
||||
det_labels: tensor([2, 3, 2, 3])
|
||||
bboxes: tensor([[0.7404, 0.6332, 0.1684, 0.9961],
|
||||
[0.2837, 0.8112, 0.5416, 0.2810],
|
||||
[0.7404, 0.6332, 0.1684, 0.9961],
|
||||
[0.2837, 0.8112, 0.5416, 0.2810]])
|
||||
data:
|
||||
polygons: [[1, 2, 3, 4], [5, 6, 7, 8],
|
||||
[1, 2, 3, 4], [5, 6, 7, 8]]
|
||||
det_scores: tensor([0.8000, 0.7000, 0.8000, 0.7000])
|
||||
masks: [[11, 21, 31, 41], [51, 61, 71, 81],
|
||||
[11, 21, 31, 41], [51, 61, 71, 81]]
|
||||
) at 0x7f203542feb0>
|
||||
"""
|
||||
|
||||
def __setattr__(self, name: str, value: Union[torch.Tensor, np.ndarray,
|
||||
list]):
|
||||
def __setattr__(self, name: str, value: Sized):
|
||||
"""setattr is only used to set data.
|
||||
|
||||
the value must have the attribute of `__len__` and have the same length
|
||||
|
@ -82,9 +138,8 @@ class InstanceData(BaseDataElement):
|
|||
f'private attribute, which is immutable. ')
|
||||
|
||||
else:
|
||||
assert isinstance(value, (torch.Tensor, np.ndarray, list)), \
|
||||
f'Can set {type(value)}, only support' \
|
||||
f' {(torch.Tensor, np.ndarray, list)}'
|
||||
assert isinstance(value,
|
||||
Sized), 'value must contain `_len__` attribute'
|
||||
|
||||
if len(self) > 0:
|
||||
assert len(value) == len(self), f'the length of ' \
|
||||
|
@ -95,6 +150,8 @@ class InstanceData(BaseDataElement):
|
|||
f'{len(self)} '
|
||||
super().__setattr__(name, value)
|
||||
|
||||
__setitem__ = __setattr__
|
||||
|
||||
def __getitem__(self, item: IndexType) -> 'InstanceData':
|
||||
"""
|
||||
Args:
|
||||
|
@ -105,11 +162,13 @@ class InstanceData(BaseDataElement):
|
|||
Returns:
|
||||
obj:`InstanceData`: Corresponding values.
|
||||
"""
|
||||
assert len(self) > 0, ' This is a empty instance'
|
||||
|
||||
if isinstance(item, list):
|
||||
item = np.array(item)
|
||||
if isinstance(item, np.ndarray):
|
||||
item = torch.from_numpy(item)
|
||||
assert isinstance(
|
||||
item, (str, slice, int, torch.LongTensor, torch.cuda.LongTensor,
|
||||
torch.BoolTensor, torch.cuda.BoolTensor, np.bool, np.long))
|
||||
torch.BoolTensor, torch.cuda.BoolTensor))
|
||||
|
||||
if isinstance(item, str):
|
||||
return getattr(self, item)
|
||||
|
@ -121,7 +180,7 @@ class InstanceData(BaseDataElement):
|
|||
# keep the dimension
|
||||
item = slice(item, None, len(self))
|
||||
|
||||
new_data = self.new(data={})
|
||||
new_data = self.__class__(metainfo=self.metainfo)
|
||||
if isinstance(item, torch.Tensor):
|
||||
assert item.dim() == 1, 'Only support to get the' \
|
||||
' values along the first dimension.'
|
||||
|
@ -140,17 +199,36 @@ class InstanceData(BaseDataElement):
|
|||
new_data[k] = v[item]
|
||||
elif isinstance(v, np.ndarray):
|
||||
new_data[k] = v[item.cpu().numpy()]
|
||||
elif isinstance(v, list):
|
||||
r_list = []
|
||||
elif isinstance(
|
||||
v, (str, list, tuple)) or (hasattr(v, '__getitem__')
|
||||
and hasattr(v, 'cat')):
|
||||
# convert to indexes from boolTensor
|
||||
if isinstance(item,
|
||||
(torch.BoolTensor, torch.cuda.BoolTensor)):
|
||||
indexes = torch.nonzero(item).view(-1)
|
||||
indexes = torch.nonzero(item).view(
|
||||
-1).cpu().numpy().tolist()
|
||||
else:
|
||||
indexes = item
|
||||
for index in indexes:
|
||||
r_list.append(v[index])
|
||||
new_data[k] = r_list
|
||||
indexes = item.cpu().numpy().tolist()
|
||||
slice_list = []
|
||||
if indexes:
|
||||
for index in indexes:
|
||||
slice_list.append(slice(index, None, len(v)))
|
||||
else:
|
||||
slice_list.append(slice(None, 0, None))
|
||||
r_list = [v[s] for s in slice_list]
|
||||
if isinstance(v, (str, list, tuple)):
|
||||
new_value = r_list[0]
|
||||
for r in r_list[1:]:
|
||||
new_value = new_value + r
|
||||
else:
|
||||
new_value = v.cat(r_list)
|
||||
new_data[k] = new_value
|
||||
else:
|
||||
raise ValueError(
|
||||
f'The type of `{k}` is `{type(v)}`, which has no '
|
||||
'attribute of `cat`, so it does not '
|
||||
f'support slice with `bool`')
|
||||
|
||||
else:
|
||||
# item is a slice
|
||||
for k, v in self.items():
|
||||
|
@ -191,24 +269,30 @@ class InstanceData(BaseDataElement):
|
|||
'elements in `instances_list` ' \
|
||||
'have the exact same key '
|
||||
|
||||
new_data = instances_list[0].new(data={})
|
||||
new_data = instances_list[0].__class__(
|
||||
metainfo=instances_list[0].metainfo)
|
||||
for k in instances_list[0].keys():
|
||||
values = [results[k] for results in instances_list]
|
||||
v0 = values[0]
|
||||
if isinstance(v0, torch.Tensor):
|
||||
values = torch.cat(values, dim=0)
|
||||
new_values = torch.cat(values, dim=0)
|
||||
elif isinstance(v0, np.ndarray):
|
||||
values = np.concatenate(values, axis=0)
|
||||
elif isinstance(v0, list):
|
||||
values = list(itertools.chain(*values))
|
||||
new_values = np.concatenate(values, axis=0)
|
||||
elif isinstance(v0, (str, list, tuple)):
|
||||
new_values = v0[:]
|
||||
for v in values[1:]:
|
||||
new_values += v
|
||||
elif hasattr(v0, 'cat'):
|
||||
new_values = v0.cat(values)
|
||||
else:
|
||||
raise ValueError(
|
||||
f'Can not concat the {k} which is a {type(v0)}')
|
||||
new_data[k] = values
|
||||
f'The type of `{k}` is `{type(v0)}` which has no '
|
||||
'attribute of `cat`')
|
||||
new_data[k] = new_values
|
||||
return new_data # type:ignore
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""The length of instance data."""
|
||||
"""int: the length of InstanceData"""
|
||||
if len(self._data_fields) > 0:
|
||||
return len(self.values()[0])
|
||||
else:
|
||||
|
|
|
@ -112,7 +112,7 @@ class TestBaseDataElement(TestCase):
|
|||
|
||||
# test new() with arguments
|
||||
metainfo, data = self.setup_data()
|
||||
new_instances = instances.new(metainfo=metainfo, data=data)
|
||||
new_instances = instances.new(metainfo=metainfo, **data)
|
||||
assert type(new_instances) == type(instances)
|
||||
assert id(new_instances.gt_instances) != id(instances.gt_instances)
|
||||
_, new_data = self.setup_data()
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import itertools
|
||||
import random
|
||||
from unittest import TestCase
|
||||
|
||||
|
@ -9,6 +10,66 @@ import torch
|
|||
from mmengine.data import BaseDataElement, InstanceData
|
||||
|
||||
|
||||
class TmpObject:
|
||||
|
||||
def __init__(self, tmp) -> None:
|
||||
assert isinstance(tmp, list)
|
||||
if len(tmp) > 0:
|
||||
for t in tmp:
|
||||
assert isinstance(t, list)
|
||||
self.tmp = tmp
|
||||
|
||||
def __len__(self):
|
||||
return len(self.tmp)
|
||||
|
||||
def __getitem__(self, item):
|
||||
if type(item) == int:
|
||||
if item >= len(self) or item < -len(self): # type:ignore
|
||||
raise IndexError(f'Index {item} out of range!')
|
||||
else:
|
||||
# keep the dimension
|
||||
item = slice(item, None, len(self))
|
||||
return TmpObject(self.tmp[item])
|
||||
|
||||
@staticmethod
|
||||
def cat(tmp_objs):
|
||||
assert all(isinstance(results, TmpObject) for results in tmp_objs)
|
||||
if len(tmp_objs) == 1:
|
||||
return tmp_objs[0]
|
||||
tmp_list = [tmp_obj.tmp for tmp_obj in tmp_objs]
|
||||
tmp_list = list(itertools.chain(*tmp_list))
|
||||
new_data = TmpObject(tmp_list)
|
||||
return new_data
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.tmp)
|
||||
|
||||
|
||||
class TmpObjectWithoutCat:
|
||||
|
||||
def __init__(self, tmp) -> None:
|
||||
assert isinstance(tmp, list)
|
||||
if len(tmp) > 0:
|
||||
for t in tmp:
|
||||
assert isinstance(t, list)
|
||||
self.tmp = tmp
|
||||
|
||||
def __len__(self):
|
||||
return len(self.tmp)
|
||||
|
||||
def __getitem__(self, item):
|
||||
if type(item) == int:
|
||||
if item >= len(self) or item < -len(self): # type:ignore
|
||||
raise IndexError(f'Index {item} out of range!')
|
||||
else:
|
||||
# keep the dimension
|
||||
item = slice(item, None, len(self))
|
||||
return TmpObject(self.tmp[item])
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.tmp)
|
||||
|
||||
|
||||
class TestInstanceData(TestCase):
|
||||
|
||||
def setup_data(self):
|
||||
|
@ -18,10 +79,18 @@ class TestInstanceData(TestCase):
|
|||
instances_infos = [1] * 5
|
||||
bboxes = torch.rand((5, 4))
|
||||
labels = np.random.rand(5)
|
||||
kps = [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]]
|
||||
ids = (1, 2, 3, 4, 5)
|
||||
name_ids = '12345'
|
||||
polygons = TmpObject(np.arange(25).reshape((5, -1)).tolist())
|
||||
instance_data = InstanceData(
|
||||
metainfo=metainfo,
|
||||
bboxes=bboxes,
|
||||
labels=labels,
|
||||
polygons=polygons,
|
||||
kps=kps,
|
||||
ids=ids,
|
||||
name_ids=name_ids,
|
||||
instances_infos=instances_infos)
|
||||
return instance_data
|
||||
|
||||
|
@ -34,10 +103,6 @@ class TestInstanceData(TestCase):
|
|||
with self.assertRaises(AttributeError):
|
||||
instance_data._data_fields = 1
|
||||
|
||||
# value only supports (torch.Tensor, np.ndarray, list)
|
||||
with self.assertRaises(AssertionError):
|
||||
instance_data.v = 'value'
|
||||
|
||||
# The data length in InstanceData must be the same
|
||||
with self.assertRaises(AssertionError):
|
||||
instance_data.keypoints = torch.rand((17, 2))
|
||||
|
@ -48,14 +113,15 @@ class TestInstanceData(TestCase):
|
|||
def test_getitem(self):
|
||||
instance_data = InstanceData()
|
||||
# length must be greater than 0
|
||||
with self.assertRaises(AssertionError):
|
||||
with self.assertRaises(IndexError):
|
||||
instance_data[1]
|
||||
|
||||
instance_data = self.setup_data()
|
||||
assert len(instance_data) == 5
|
||||
slice_instance_data = instance_data[:2]
|
||||
assert len(slice_instance_data) == 2
|
||||
|
||||
slice_instance_data = instance_data[1]
|
||||
assert len(slice_instance_data) == 1
|
||||
# assert the index should in 0 ~ len(instance_data) -1
|
||||
with pytest.raises(IndexError):
|
||||
instance_data[5]
|
||||
|
@ -80,6 +146,40 @@ class TestInstanceData(TestCase):
|
|||
bool_tensor = torch.rand(5) > 0.5
|
||||
bool_index_instance_data = instance_data[bool_tensor]
|
||||
assert len(bool_index_instance_data) == bool_tensor.sum()
|
||||
bool_tensor = torch.rand(5) > 1
|
||||
empty_instance_data = instance_data[bool_tensor]
|
||||
assert len(empty_instance_data) == bool_tensor.sum()
|
||||
|
||||
# test list index
|
||||
list_index = [1, 2]
|
||||
list_index_instance_data = instance_data[list_index]
|
||||
assert len(list_index_instance_data) == len(list_index)
|
||||
|
||||
# text list bool
|
||||
list_bool = [True, False, True, False, False]
|
||||
list_bool_instance_data = instance_data[list_bool]
|
||||
assert len(list_bool_instance_data) == 2
|
||||
|
||||
# test numpy
|
||||
long_numpy = np.random.randint(5, size=2)
|
||||
long_numpy_instance_data = instance_data[long_numpy]
|
||||
assert len(long_numpy_instance_data) == len(long_numpy)
|
||||
|
||||
bool_numpy = np.random.rand(5) > 0.5
|
||||
bool_numpy_instance_data = instance_data[bool_numpy]
|
||||
assert len(bool_numpy_instance_data) == bool_numpy.sum()
|
||||
|
||||
# without cat
|
||||
instance_data.polygons = TmpObjectWithoutCat(
|
||||
np.arange(25).reshape((5, -1)).tolist())
|
||||
bool_numpy = np.random.rand(5) > 0.5
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=('The type of `polygons` is '
|
||||
f'`{type(instance_data.polygons)}`, '
|
||||
'which has no attribute of `cat`, so it does not '
|
||||
f'support slice with `bool`')):
|
||||
bool_numpy_instance_data = instance_data[bool_numpy]
|
||||
|
||||
def test_cat(self):
|
||||
instance_data_1 = self.setup_data()
|
||||
|
@ -97,6 +197,24 @@ class TestInstanceData(TestCase):
|
|||
# Input List length must be greater than 0
|
||||
with self.assertRaises(AssertionError):
|
||||
InstanceData.cat([])
|
||||
instance_data_2 = instance_data_1.clone()
|
||||
instance_data_2 = instance_data_2[torch.zeros(5) > 0.5]
|
||||
cat_instance_data = InstanceData.cat(
|
||||
[instance_data_1, instance_data_2])
|
||||
cat_instance_data = InstanceData.cat([instance_data_1])
|
||||
assert len(cat_instance_data) == 5
|
||||
|
||||
# test custom data cat
|
||||
instance_data_1.polygons = TmpObjectWithoutCat(
|
||||
np.arange(25).reshape((5, -1)).tolist())
|
||||
instance_data_2 = instance_data_1.clone()
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=('The type of `polygons` is '
|
||||
f'`{type(instance_data_1.polygons)}` '
|
||||
'which has no attribute of `cat`')):
|
||||
cat_instance_data = InstanceData.cat(
|
||||
[instance_data_1, instance_data_2])
|
||||
|
||||
def test_len(self):
|
||||
instance_data = self.setup_data()
|
||||
|
|
Loading…
Reference in New Issue