[Enchance] cat empty instancedata, support torch.bool for more type (#209)

* refactor instancedata

* fix docs

* fix comment
pull/207/head^2
liukuikun 2022-05-06 14:00:51 +08:00 committed by GitHub
parent 16058fdb18
commit 5c5c03e648
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 279 additions and 58 deletions

View File

@ -7,10 +7,11 @@ import torch
class BaseDataElement:
"""A base data structure interface of OpenMMlab.
"""A base data interface that supports Tensor-like and dict-like
operations.
Data elements refer to predicted results or ground truth labels on a
task, such as predicted bboxes, instance masks, semantic
A typical data elements refer to predicted results or ground truth labels
on a task, such as predicted bboxes, instance masks, semantic
segmentation masks, etc. Because groundtruth labels and predicted results
often have similar properties (for example, the predicted bboxes and the
groundtruth bboxes), MMEngine uses the same abstract data interface to
@ -23,7 +24,23 @@ class BaseDataElement:
``BaseDataElement``, and implement ``InstanceData``, ``PixelData``, and
``LabelData`` inheriting from ``BaseDataElement`` to represent different
types of ground truth labels or predictions.
They are used as interfaces between different commopenets.
Another common data element is sample data. A sample data consists of input
data (such as an image) and its annotations and predictions. In general,
an image can have multiple types of annotations and/or predictions at the
same time (for example, both pixel-level semantic segmentation annotations
and instance-level detection bboxes annotations). All labels and
predictions of a training sample are often passed between Dataset, Model,
Visualizer, and Evaluator components. In order to simplify the interface
between components, we can treat them as a large data element and
encapsulate them. Such data elements are generally called XXDataSample in
the OpenMMLab. Therefore, Similar to `nn.Module`, the `BaseDataElement`
allows `BaseDataElement` as its attribute. Such a class generally
encapsulates all the data of a sample in the algorithm library, and its
attributes generally are various types of data elements. For example,
MMDetection is assigned by the BaseDataElement to encapsulate all the data
elements of the sample labeling and prediction of a sample in the
algorithm library.
The attributes in ``BaseDataElement`` are divided into two parts,
the ``metainfo`` and the ``data`` respectively.
@ -70,8 +87,8 @@ class BaseDataElement:
>>> # new
>>> gt_instances1 = gt_instance.new(
... metainfo=dict(img_id=1, img_shape=(640, 640)),
... data=dict(bboxes=torch.rand((5, 4)),
... scores=torch.rand((5,))))
... bboxes=torch.rand((5, 4)),
... scores=torch.rand((5,)))
>>> gt_instances2 = gt_instances1.new()
>>> # add and process property
@ -241,8 +258,9 @@ class BaseDataElement:
self.set_data(dict(instance.items()))
def new(self,
metainfo: dict = None,
data: dict = None) -> 'BaseDataElement':
*,
metainfo: Optional[dict] = None,
**kwargs) -> 'BaseDataElement':
"""Return a new data element with same type. If ``metainfo`` and
``data`` are None, the new data element will have same metainfo and
data. If metainfo or data is not None, the new result will overwrite it
@ -252,8 +270,9 @@ class BaseDataElement:
metainfo (dict, optional): A dict contains the meta information
of image, such as ``img_shape``, ``scale_factor``, etc.
Defaults to None.
data (dict, optional): A dict contains annotations of image or
model predictions. Defaults to None.
kwargs (dict): A dict contains annotations of image or
model predictions.
Returns:
BaseDataElement: a new data element with same type.
"""
@ -263,8 +282,8 @@ class BaseDataElement:
new_data.set_metainfo(metainfo)
else:
new_data.set_metainfo(dict(self.metainfo_items()))
if data is not None:
new_data.set_data(data)
if kwargs:
new_data.set_data(kwargs)
else:
new_data.set_data(dict(self.items()))
return new_data
@ -388,7 +407,6 @@ class BaseDataElement:
self._data_fields.remove(item)
# dict-like methods
__setitem__ = __setattr__
__delitem__ = __delattr__
def get(self, key, default=None) -> Any:
@ -519,6 +537,7 @@ class BaseDataElement:
}
def __repr__(self) -> str:
"""represent the object."""
def _addindent(s_: str, num_spaces: int) -> str:
"""This func is modified from `pytorch` https://github.com/pytorch/

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from collections.abc import Sized
from typing import List, Union
import numpy as np
@ -7,8 +8,9 @@ import torch
from .base_data_element import BaseDataElement
IndexType = Union[str, slice, int, torch.LongTensor, torch.cuda.LongTensor,
torch.BoolTensor, torch.cuda.BoolTensor, np.long, np.bool]
IndexType = Union[str, slice, int, list, torch.LongTensor,
torch.cuda.LongTensor, torch.BoolTensor,
torch.cuda.BoolTensor, np.ndarray]
# Modified from
@ -19,8 +21,37 @@ class InstanceData(BaseDataElement):
Subclass of :class:`BaseDataElement`. All value in `data_fields`
should have the same length. This design refer to
https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py # noqa E501
InstanceData also support extra functions: ``index``, ``slice`` and ``cat`` for data field. The type of value
in data field can be base data structure such as `torch.tensor`, `numpy.dnarray`, `list`, `str`, `tuple`,
and can be customized data structure that has ``__len__``, ``__getitem__`` and ``cat`` attributes.
Examples:
>>> # custom data structure
>>> class TmpObject:
... def __init__(self, tmp) -> None:
... assert isinstance(tmp, list)
... self.tmp = tmp
... def __len__(self):
... return len(self.tmp)
... def __getitem__(self, item):
... if type(item) == int:
... if item >= len(self) or item < -len(self): # type:ignore
... raise IndexError(f'Index {item} out of range!')
... else:
... # keep the dimension
... item = slice(item, None, len(self))
... return TmpObject(self.tmp[item])
... @staticmethod
... def cat(tmp_objs):
... assert all(isinstance(results, TmpObject) for results in tmp_objs)
... if len(tmp_objs) == 1:
... return tmp_objs[0]
... tmp_list = [tmp_obj.tmp for tmp_obj in tmp_objs]
... tmp_list = list(itertools.chain(*tmp_list))
... new_data = TmpObject(tmp_list)
... return new_data
... def __repr__(self):
... return str(self.tmp)
>>> from mmengine.data import InstanceData
>>> import numpy as np
>>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
@ -30,44 +61,69 @@ class InstanceData(BaseDataElement):
>>> instance_data.det_labels = torch.LongTensor([2, 3])
>>> instance_data["det_scores"] = torch.Tensor([0.8, 0.7])
>>> instance_data.bboxes = torch.rand((2, 4))
>>> instance_data.polygons = TmpObject([[1, 2, 3, 4], [5, 6, 7, 8]])
>>> len(instance_data)
4
2
>>> print(instance_data)
<InstanceData(
META INFORMATION
pad_shape: (800, 1196, 3)
img_shape: (800, 1216, 3)
DATA FIELDS
det_labels: tensor([2, 3])
det_scores: tensor([0.8, 0.7000])
bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188],
[0.8101, 0.3105, 0.5123, 0.6263]])
polygons: [[1, 2, 3, 4], [5, 6, 7, 8]]
) at 0x7fb492de6280>
>>> sorted_results = instance_data[instance_data.det_scores.sort().indices]
>>> sorted_results.det_scores
tensor([0.7000, 0.8000])
>>> print(instance_data[instance_data.det_scores > 0.75])
<InstanceData(
META INFORMATION
pad_shape: (800, 1216, 3)
img_shape: (800, 1196, 3)
DATA FIELDS
det_labels: tensor([0])
bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188]])
det_labels: tensor([2])
masks: [[11, 21, 31, 41]]
det_scores: tensor([0.8000])
) at 0x7fb5cf6e2790>
>>> instance_data[instance_data.det_scores > 0.75].det_labels
tensor([0])
>>> instance_data[instance_data.det_scores > 0.75].det_scores
tensor([0.8000])
bboxes: tensor([[0.9308, 0.4000, 0.6077, 0.5554]])
polygons: [[1, 2, 3, 4]]
) at 0x7f64ecf0ec40>
>>> print(instance_data[instance_data.det_scores > 1])
<InstanceData(
META INFORMATION
pad_shape: (800, 1216, 3)
img_shape: (800, 1196, 3)
DATA FIELDS
det_labels: tensor([], dtype=torch.int64)
masks: []
det_scores: tensor([])
bboxes: tensor([], size=(0, 4))
polygons: [[]]
) at 0x7f660a6a7f70>
>>> print(instance_data.cat([instance_data, instance_data]))
<InstanceData(
META INFORMATION
img_shape: (800, 1196, 3)
pad_shape: (800, 1216, 3)
DATA FIELDS
det_labels: tensor([2, 3, 2, 3])
bboxes: tensor([[0.7404, 0.6332, 0.1684, 0.9961],
[0.2837, 0.8112, 0.5416, 0.2810],
[0.7404, 0.6332, 0.1684, 0.9961],
[0.2837, 0.8112, 0.5416, 0.2810]])
data:
polygons: [[1, 2, 3, 4], [5, 6, 7, 8],
[1, 2, 3, 4], [5, 6, 7, 8]]
det_scores: tensor([0.8000, 0.7000, 0.8000, 0.7000])
masks: [[11, 21, 31, 41], [51, 61, 71, 81],
[11, 21, 31, 41], [51, 61, 71, 81]]
) at 0x7f203542feb0>
"""
def __setattr__(self, name: str, value: Union[torch.Tensor, np.ndarray,
list]):
def __setattr__(self, name: str, value: Sized):
"""setattr is only used to set data.
the value must have the attribute of `__len__` and have the same length
@ -82,9 +138,8 @@ class InstanceData(BaseDataElement):
f'private attribute, which is immutable. ')
else:
assert isinstance(value, (torch.Tensor, np.ndarray, list)), \
f'Can set {type(value)}, only support' \
f' {(torch.Tensor, np.ndarray, list)}'
assert isinstance(value,
Sized), 'value must contain `_len__` attribute'
if len(self) > 0:
assert len(value) == len(self), f'the length of ' \
@ -95,6 +150,8 @@ class InstanceData(BaseDataElement):
f'{len(self)} '
super().__setattr__(name, value)
__setitem__ = __setattr__
def __getitem__(self, item: IndexType) -> 'InstanceData':
"""
Args:
@ -105,11 +162,13 @@ class InstanceData(BaseDataElement):
Returns:
obj:`InstanceData`: Corresponding values.
"""
assert len(self) > 0, ' This is a empty instance'
if isinstance(item, list):
item = np.array(item)
if isinstance(item, np.ndarray):
item = torch.from_numpy(item)
assert isinstance(
item, (str, slice, int, torch.LongTensor, torch.cuda.LongTensor,
torch.BoolTensor, torch.cuda.BoolTensor, np.bool, np.long))
torch.BoolTensor, torch.cuda.BoolTensor))
if isinstance(item, str):
return getattr(self, item)
@ -121,7 +180,7 @@ class InstanceData(BaseDataElement):
# keep the dimension
item = slice(item, None, len(self))
new_data = self.new(data={})
new_data = self.__class__(metainfo=self.metainfo)
if isinstance(item, torch.Tensor):
assert item.dim() == 1, 'Only support to get the' \
' values along the first dimension.'
@ -140,17 +199,36 @@ class InstanceData(BaseDataElement):
new_data[k] = v[item]
elif isinstance(v, np.ndarray):
new_data[k] = v[item.cpu().numpy()]
elif isinstance(v, list):
r_list = []
elif isinstance(
v, (str, list, tuple)) or (hasattr(v, '__getitem__')
and hasattr(v, 'cat')):
# convert to indexes from boolTensor
if isinstance(item,
(torch.BoolTensor, torch.cuda.BoolTensor)):
indexes = torch.nonzero(item).view(-1)
indexes = torch.nonzero(item).view(
-1).cpu().numpy().tolist()
else:
indexes = item
for index in indexes:
r_list.append(v[index])
new_data[k] = r_list
indexes = item.cpu().numpy().tolist()
slice_list = []
if indexes:
for index in indexes:
slice_list.append(slice(index, None, len(v)))
else:
slice_list.append(slice(None, 0, None))
r_list = [v[s] for s in slice_list]
if isinstance(v, (str, list, tuple)):
new_value = r_list[0]
for r in r_list[1:]:
new_value = new_value + r
else:
new_value = v.cat(r_list)
new_data[k] = new_value
else:
raise ValueError(
f'The type of `{k}` is `{type(v)}`, which has no '
'attribute of `cat`, so it does not '
f'support slice with `bool`')
else:
# item is a slice
for k, v in self.items():
@ -191,24 +269,30 @@ class InstanceData(BaseDataElement):
'elements in `instances_list` ' \
'have the exact same key '
new_data = instances_list[0].new(data={})
new_data = instances_list[0].__class__(
metainfo=instances_list[0].metainfo)
for k in instances_list[0].keys():
values = [results[k] for results in instances_list]
v0 = values[0]
if isinstance(v0, torch.Tensor):
values = torch.cat(values, dim=0)
new_values = torch.cat(values, dim=0)
elif isinstance(v0, np.ndarray):
values = np.concatenate(values, axis=0)
elif isinstance(v0, list):
values = list(itertools.chain(*values))
new_values = np.concatenate(values, axis=0)
elif isinstance(v0, (str, list, tuple)):
new_values = v0[:]
for v in values[1:]:
new_values += v
elif hasattr(v0, 'cat'):
new_values = v0.cat(values)
else:
raise ValueError(
f'Can not concat the {k} which is a {type(v0)}')
new_data[k] = values
f'The type of `{k}` is `{type(v0)}` which has no '
'attribute of `cat`')
new_data[k] = new_values
return new_data # type:ignore
def __len__(self) -> int:
"""The length of instance data."""
"""int: the length of InstanceData"""
if len(self._data_fields) > 0:
return len(self.values()[0])
else:

View File

@ -112,7 +112,7 @@ class TestBaseDataElement(TestCase):
# test new() with arguments
metainfo, data = self.setup_data()
new_instances = instances.new(metainfo=metainfo, data=data)
new_instances = instances.new(metainfo=metainfo, **data)
assert type(new_instances) == type(instances)
assert id(new_instances.gt_instances) != id(instances.gt_instances)
_, new_data = self.setup_data()

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
import random
from unittest import TestCase
@ -9,6 +10,66 @@ import torch
from mmengine.data import BaseDataElement, InstanceData
class TmpObject:
def __init__(self, tmp) -> None:
assert isinstance(tmp, list)
if len(tmp) > 0:
for t in tmp:
assert isinstance(t, list)
self.tmp = tmp
def __len__(self):
return len(self.tmp)
def __getitem__(self, item):
if type(item) == int:
if item >= len(self) or item < -len(self): # type:ignore
raise IndexError(f'Index {item} out of range!')
else:
# keep the dimension
item = slice(item, None, len(self))
return TmpObject(self.tmp[item])
@staticmethod
def cat(tmp_objs):
assert all(isinstance(results, TmpObject) for results in tmp_objs)
if len(tmp_objs) == 1:
return tmp_objs[0]
tmp_list = [tmp_obj.tmp for tmp_obj in tmp_objs]
tmp_list = list(itertools.chain(*tmp_list))
new_data = TmpObject(tmp_list)
return new_data
def __repr__(self):
return str(self.tmp)
class TmpObjectWithoutCat:
def __init__(self, tmp) -> None:
assert isinstance(tmp, list)
if len(tmp) > 0:
for t in tmp:
assert isinstance(t, list)
self.tmp = tmp
def __len__(self):
return len(self.tmp)
def __getitem__(self, item):
if type(item) == int:
if item >= len(self) or item < -len(self): # type:ignore
raise IndexError(f'Index {item} out of range!')
else:
# keep the dimension
item = slice(item, None, len(self))
return TmpObject(self.tmp[item])
def __repr__(self):
return str(self.tmp)
class TestInstanceData(TestCase):
def setup_data(self):
@ -18,10 +79,18 @@ class TestInstanceData(TestCase):
instances_infos = [1] * 5
bboxes = torch.rand((5, 4))
labels = np.random.rand(5)
kps = [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]]
ids = (1, 2, 3, 4, 5)
name_ids = '12345'
polygons = TmpObject(np.arange(25).reshape((5, -1)).tolist())
instance_data = InstanceData(
metainfo=metainfo,
bboxes=bboxes,
labels=labels,
polygons=polygons,
kps=kps,
ids=ids,
name_ids=name_ids,
instances_infos=instances_infos)
return instance_data
@ -34,10 +103,6 @@ class TestInstanceData(TestCase):
with self.assertRaises(AttributeError):
instance_data._data_fields = 1
# value only supports (torch.Tensor, np.ndarray, list)
with self.assertRaises(AssertionError):
instance_data.v = 'value'
# The data length in InstanceData must be the same
with self.assertRaises(AssertionError):
instance_data.keypoints = torch.rand((17, 2))
@ -48,14 +113,15 @@ class TestInstanceData(TestCase):
def test_getitem(self):
instance_data = InstanceData()
# length must be greater than 0
with self.assertRaises(AssertionError):
with self.assertRaises(IndexError):
instance_data[1]
instance_data = self.setup_data()
assert len(instance_data) == 5
slice_instance_data = instance_data[:2]
assert len(slice_instance_data) == 2
slice_instance_data = instance_data[1]
assert len(slice_instance_data) == 1
# assert the index should in 0 ~ len(instance_data) -1
with pytest.raises(IndexError):
instance_data[5]
@ -80,6 +146,40 @@ class TestInstanceData(TestCase):
bool_tensor = torch.rand(5) > 0.5
bool_index_instance_data = instance_data[bool_tensor]
assert len(bool_index_instance_data) == bool_tensor.sum()
bool_tensor = torch.rand(5) > 1
empty_instance_data = instance_data[bool_tensor]
assert len(empty_instance_data) == bool_tensor.sum()
# test list index
list_index = [1, 2]
list_index_instance_data = instance_data[list_index]
assert len(list_index_instance_data) == len(list_index)
# text list bool
list_bool = [True, False, True, False, False]
list_bool_instance_data = instance_data[list_bool]
assert len(list_bool_instance_data) == 2
# test numpy
long_numpy = np.random.randint(5, size=2)
long_numpy_instance_data = instance_data[long_numpy]
assert len(long_numpy_instance_data) == len(long_numpy)
bool_numpy = np.random.rand(5) > 0.5
bool_numpy_instance_data = instance_data[bool_numpy]
assert len(bool_numpy_instance_data) == bool_numpy.sum()
# without cat
instance_data.polygons = TmpObjectWithoutCat(
np.arange(25).reshape((5, -1)).tolist())
bool_numpy = np.random.rand(5) > 0.5
with pytest.raises(
ValueError,
match=('The type of `polygons` is '
f'`{type(instance_data.polygons)}`, '
'which has no attribute of `cat`, so it does not '
f'support slice with `bool`')):
bool_numpy_instance_data = instance_data[bool_numpy]
def test_cat(self):
instance_data_1 = self.setup_data()
@ -97,6 +197,24 @@ class TestInstanceData(TestCase):
# Input List length must be greater than 0
with self.assertRaises(AssertionError):
InstanceData.cat([])
instance_data_2 = instance_data_1.clone()
instance_data_2 = instance_data_2[torch.zeros(5) > 0.5]
cat_instance_data = InstanceData.cat(
[instance_data_1, instance_data_2])
cat_instance_data = InstanceData.cat([instance_data_1])
assert len(cat_instance_data) == 5
# test custom data cat
instance_data_1.polygons = TmpObjectWithoutCat(
np.arange(25).reshape((5, -1)).tolist())
instance_data_2 = instance_data_1.clone()
with pytest.raises(
ValueError,
match=('The type of `polygons` is '
f'`{type(instance_data_1.polygons)}` '
'which has no attribute of `cat`')):
cat_instance_data = InstanceData.cat(
[instance_data_1, instance_data_2])
def test_len(self):
instance_data = self.setup_data()