[Docs] Update docstring of structures (#840)
* Update docstring of `structures` * update docs * add `import torch` to `examples`pull/718/merge
parent
e1f61252d4
commit
591024e533
|
@ -72,6 +72,7 @@ class BaseDataElement:
|
|||
model predictions. Defaults to None.
|
||||
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>> from mmengine.structures import BaseDataElement
|
||||
>>> gt_instances = BaseDataElement()
|
||||
>>> bboxes = torch.rand((5, 4))
|
||||
|
|
|
@ -22,7 +22,7 @@ class InstanceData(BaseDataElement):
|
|||
should have the same length. This design refer to
|
||||
https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py # noqa E501
|
||||
InstanceData also support extra functions: ``index``, ``slice`` and ``cat`` for data field. The type of value
|
||||
in data field can be base data structure such as `torch.tensor`, `numpy.ndarray`, `list`, `str`, `tuple`,
|
||||
in data field can be base data structure such as `torch.Tensor`, `numpy.ndarray`, `list`, `str`, `tuple`,
|
||||
and can be customized data structure that has ``__len__``, ``__getitem__`` and ``cat`` attributes.
|
||||
|
||||
Examples:
|
||||
|
@ -34,7 +34,7 @@ class InstanceData(BaseDataElement):
|
|||
... def __len__(self):
|
||||
... return len(self.tmp)
|
||||
... def __getitem__(self, item):
|
||||
... if type(item) == int:
|
||||
... if isinstance(item, int):
|
||||
... if item >= len(self) or item < -len(self): # type:ignore
|
||||
... raise IndexError(f'Index {item} out of range!')
|
||||
... else:
|
||||
|
@ -54,6 +54,7 @@ class InstanceData(BaseDataElement):
|
|||
... return str(self.tmp)
|
||||
>>> from mmengine.structures import InstanceData
|
||||
>>> import numpy as np
|
||||
>>> import torch
|
||||
>>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
|
||||
>>> instance_data = InstanceData(metainfo=img_meta)
|
||||
>>> 'img_shape' in instance_data
|
||||
|
@ -67,41 +68,39 @@ class InstanceData(BaseDataElement):
|
|||
>>> print(instance_data)
|
||||
<InstanceData(
|
||||
META INFORMATION
|
||||
pad_shape: (800, 1196, 3)
|
||||
img_shape: (800, 1216, 3)
|
||||
img_shape: (800, 1196, 3)
|
||||
pad_shape: (800, 1216, 3)
|
||||
DATA FIELDS
|
||||
det_labels: tensor([2, 3])
|
||||
det_scores: tensor([0.8, 0.7000])
|
||||
det_scores: tensor([0.8000, 0.7000])
|
||||
bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188],
|
||||
[0.8101, 0.3105, 0.5123, 0.6263]])
|
||||
[0.8101, 0.3105, 0.5123, 0.6263]])
|
||||
polygons: [[1, 2, 3, 4], [5, 6, 7, 8]]
|
||||
) at 0x7fb492de6280>
|
||||
) at 0x7fb492de6280>
|
||||
>>> sorted_results = instance_data[instance_data.det_scores.sort().indices]
|
||||
>>> sorted_results.det_scores
|
||||
tensor([0.7000, 0.8000])
|
||||
>>> print(instance_data[instance_data.det_scores > 0.75])
|
||||
<InstanceData(
|
||||
META INFORMATION
|
||||
pad_shape: (800, 1216, 3)
|
||||
img_shape: (800, 1196, 3)
|
||||
pad_shape: (800, 1216, 3)
|
||||
DATA FIELDS
|
||||
det_labels: tensor([2])
|
||||
masks: [[11, 21, 31, 41]]
|
||||
det_scores: tensor([0.8000])
|
||||
bboxes: tensor([[0.9308, 0.4000, 0.6077, 0.5554]])
|
||||
bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188]])
|
||||
polygons: [[1, 2, 3, 4]]
|
||||
) at 0x7f64ecf0ec40>
|
||||
>>> print(instance_data[instance_data.det_scores > 1])
|
||||
<InstanceData(
|
||||
META INFORMATION
|
||||
pad_shape: (800, 1216, 3)
|
||||
img_shape: (800, 1196, 3)
|
||||
pad_shape: (800, 1216, 3)
|
||||
DATA FIELDS
|
||||
det_labels: tensor([], dtype=torch.int64)
|
||||
masks: []
|
||||
det_scores: tensor([])
|
||||
bboxes: tensor([], size=(0, 4))
|
||||
polygons: [[]]
|
||||
polygons: []
|
||||
) at 0x7f660a6a7f70>
|
||||
>>> print(instance_data.cat([instance_data, instance_data]))
|
||||
<InstanceData(
|
||||
|
@ -110,44 +109,39 @@ class InstanceData(BaseDataElement):
|
|||
pad_shape: (800, 1216, 3)
|
||||
DATA FIELDS
|
||||
det_labels: tensor([2, 3, 2, 3])
|
||||
bboxes: tensor([[0.7404, 0.6332, 0.1684, 0.9961],
|
||||
[0.2837, 0.8112, 0.5416, 0.2810],
|
||||
[0.7404, 0.6332, 0.1684, 0.9961],
|
||||
[0.2837, 0.8112, 0.5416, 0.2810]])
|
||||
data:
|
||||
polygons: [[1, 2, 3, 4], [5, 6, 7, 8],
|
||||
[1, 2, 3, 4], [5, 6, 7, 8]]
|
||||
det_scores: tensor([0.8000, 0.7000, 0.8000, 0.7000])
|
||||
masks: [[11, 21, 31, 41], [51, 61, 71, 81],
|
||||
[11, 21, 31, 41], [51, 61, 71, 81]]
|
||||
bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188],
|
||||
[0.8101, 0.3105, 0.5123, 0.6263],
|
||||
[0.4997, 0.7707, 0.0595, 0.4188],
|
||||
[0.8101, 0.3105, 0.5123, 0.6263]])
|
||||
polygons: [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [5, 6, 7, 8]]
|
||||
) at 0x7f203542feb0>
|
||||
"""
|
||||
|
||||
def __setattr__(self, name: str, value: Sized):
|
||||
"""setattr is only used to set data.
|
||||
|
||||
the value must have the attribute of `__len__` and have the same length
|
||||
of instancedata
|
||||
The value must have the attribute of `__len__` and have the same length
|
||||
of `InstanceData`.
|
||||
"""
|
||||
if name in ('_metainfo_fields', '_data_fields'):
|
||||
if not hasattr(self, name):
|
||||
super().__setattr__(name, value)
|
||||
else:
|
||||
raise AttributeError(
|
||||
f'{name} has been used as a '
|
||||
f'private attribute, which is immutable. ')
|
||||
raise AttributeError(f'{name} has been used as a '
|
||||
'private attribute, which is immutable.')
|
||||
|
||||
else:
|
||||
assert isinstance(value,
|
||||
Sized), 'value must contain `_len__` attribute'
|
||||
Sized), 'value must contain `__len__` attribute'
|
||||
|
||||
if len(self) > 0:
|
||||
assert len(value) == len(self), f'the length of ' \
|
||||
assert len(value) == len(self), 'The length of ' \
|
||||
f'values {len(value)} is ' \
|
||||
f'not consistent with' \
|
||||
f' the length of this ' \
|
||||
f':obj:`InstanceData` ' \
|
||||
f'{len(self)} '
|
||||
'not consistent with ' \
|
||||
'the length of this ' \
|
||||
':obj:`InstanceData` ' \
|
||||
f'{len(self)}'
|
||||
super().__setattr__(name, value)
|
||||
|
||||
__setitem__ = __setattr__
|
||||
|
@ -155,12 +149,12 @@ class InstanceData(BaseDataElement):
|
|||
def __getitem__(self, item: IndexType) -> 'InstanceData':
|
||||
"""
|
||||
Args:
|
||||
item (str, obj:`slice`,
|
||||
obj`torch.LongTensor`, obj:`torch.BoolTensor`):
|
||||
get the corresponding values according to item.
|
||||
item (str, int, list, :obj:`slice`, :obj:`numpy.ndarray`,
|
||||
:obj:`torch.LongTensor`, :obj:`torch.BoolTensor`):
|
||||
Get the corresponding values according to item.
|
||||
|
||||
Returns:
|
||||
obj:`InstanceData`: Corresponding values.
|
||||
:obj:`InstanceData`: Corresponding values.
|
||||
"""
|
||||
if isinstance(item, list):
|
||||
item = np.array(item)
|
||||
|
@ -178,7 +172,7 @@ class InstanceData(BaseDataElement):
|
|||
if isinstance(item, str):
|
||||
return getattr(self, item)
|
||||
|
||||
if type(item) == int:
|
||||
if isinstance(item, int):
|
||||
if item >= len(self) or item < -len(self): # type:ignore
|
||||
raise IndexError(f'Index {item} out of range!')
|
||||
else:
|
||||
|
@ -190,14 +184,14 @@ class InstanceData(BaseDataElement):
|
|||
assert item.dim() == 1, 'Only support to get the' \
|
||||
' values along the first dimension.'
|
||||
if isinstance(item, (torch.BoolTensor, torch.cuda.BoolTensor)):
|
||||
assert len(item) == len(self), f'The shape of the' \
|
||||
f' input(BoolTensor)) ' \
|
||||
assert len(item) == len(self), 'The shape of the ' \
|
||||
'input(BoolTensor) ' \
|
||||
f'{len(item)} ' \
|
||||
f' does not match the shape ' \
|
||||
f'of the indexed tensor ' \
|
||||
f'in results_filed ' \
|
||||
'does not match the shape ' \
|
||||
'of the indexed tensor ' \
|
||||
'in results_field ' \
|
||||
f'{len(self)} at ' \
|
||||
f'first dimension. '
|
||||
'first dimension.'
|
||||
|
||||
for k, v in self.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
|
@ -207,7 +201,7 @@ class InstanceData(BaseDataElement):
|
|||
elif isinstance(
|
||||
v, (str, list, tuple)) or (hasattr(v, '__getitem__')
|
||||
and hasattr(v, 'cat')):
|
||||
# convert to indexes from boolTensor
|
||||
# convert to indexes from BoolTensor
|
||||
if isinstance(item,
|
||||
(torch.BoolTensor, torch.cuda.BoolTensor)):
|
||||
indexes = torch.nonzero(item).view(
|
||||
|
@ -232,7 +226,7 @@ class InstanceData(BaseDataElement):
|
|||
raise ValueError(
|
||||
f'The type of `{k}` is `{type(v)}`, which has no '
|
||||
'attribute of `cat`, so it does not '
|
||||
f'support slice with `bool`')
|
||||
'support slice with `bool`')
|
||||
|
||||
else:
|
||||
# item is a slice
|
||||
|
@ -252,7 +246,7 @@ class InstanceData(BaseDataElement):
|
|||
of :obj:`InstanceData`.
|
||||
|
||||
Returns:
|
||||
obj:`InstanceData`
|
||||
:obj:`InstanceData`
|
||||
"""
|
||||
assert all(
|
||||
isinstance(results, InstanceData) for results in instances_list)
|
||||
|
@ -272,7 +266,7 @@ class InstanceData(BaseDataElement):
|
|||
'cause the cat operation ' \
|
||||
'to fail. Please make sure all ' \
|
||||
'elements in `instances_list` ' \
|
||||
'have the exact same key '
|
||||
'have the exact same key.'
|
||||
|
||||
new_data = instances_list[0].__class__(
|
||||
metainfo=instances_list[0].metainfo)
|
||||
|
@ -297,7 +291,7 @@ class InstanceData(BaseDataElement):
|
|||
return new_data # type:ignore
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""int: the length of InstanceData"""
|
||||
"""int: The length of InstanceData."""
|
||||
if len(self._data_fields) > 0:
|
||||
return len(self.values()[0])
|
||||
else:
|
||||
|
|
|
@ -16,7 +16,7 @@ class LabelData(BaseDataElement):
|
|||
onehot (torch.Tensor, optional): The one-hot input. The format
|
||||
of input must be one-hot.
|
||||
|
||||
Return:
|
||||
Returns:
|
||||
torch.Tensor: The converted results.
|
||||
"""
|
||||
assert isinstance(onehot, torch.Tensor)
|
||||
|
@ -36,7 +36,7 @@ class LabelData(BaseDataElement):
|
|||
of item must be label-format.
|
||||
num_classes (int): The number of classes.
|
||||
|
||||
Return:
|
||||
Returns:
|
||||
torch.Tensor: The converted results.
|
||||
"""
|
||||
assert isinstance(label, torch.Tensor)
|
||||
|
|
|
@ -26,12 +26,12 @@ class PixelData(BaseDataElement):
|
|||
>>> pixel_data = PixelData(metainfo=metainfo,
|
||||
... image=image,
|
||||
... featmap=featmap)
|
||||
>>> print(pixel_data)
|
||||
>>> (20, 40)
|
||||
>>> print(pixel_data.shape)
|
||||
(20, 40)
|
||||
|
||||
>>> # slice
|
||||
>>> slice_data = pixel_data[10:20, 20:40]
|
||||
>>> assert slice_data.shape == (10, 10)
|
||||
>>> assert slice_data.shape == (10, 20)
|
||||
>>> slice_data = pixel_data[10, 20]
|
||||
>>> assert slice_data.shape == (1, 1)
|
||||
|
||||
|
@ -47,41 +47,40 @@ class PixelData(BaseDataElement):
|
|||
"""Set attributes of ``PixelData``.
|
||||
|
||||
If the dimension of value is 2 and its shape meet the demand, it
|
||||
will automatically expend its channel-dimension.
|
||||
will automatically expand its channel-dimension.
|
||||
|
||||
Args:
|
||||
name (str): The key to access the value, stored in `PixelData`.
|
||||
value (Union[torch.Tensor, np.ndarray]): The value to store in.
|
||||
The type of value must be `torch.Tensor` or `np.ndarray`,
|
||||
The type of value must be `torch.Tensor` or `np.ndarray`,
|
||||
and its shape must meet the requirements of `PixelData`.
|
||||
"""
|
||||
if name in ('_metainfo_fields', '_data_fields'):
|
||||
if not hasattr(self, name):
|
||||
super().__setattr__(name, value)
|
||||
else:
|
||||
raise AttributeError(
|
||||
f'{name} has been used as a '
|
||||
f'private attribute, which is immutable. ')
|
||||
raise AttributeError(f'{name} has been used as a '
|
||||
'private attribute, which is immutable.')
|
||||
|
||||
else:
|
||||
assert isinstance(value, (torch.Tensor, np.ndarray)), \
|
||||
f'Can set {type(value)}, only support' \
|
||||
f'Can not set {type(value)}, only support' \
|
||||
f' {(torch.Tensor, np.ndarray)}'
|
||||
|
||||
if self.shape:
|
||||
assert tuple(value.shape[-2:]) == self.shape, (
|
||||
f'the height and width of '
|
||||
'The height and width of '
|
||||
f'values {tuple(value.shape[-2:])} is '
|
||||
f'not consistent with'
|
||||
f' the length of this '
|
||||
f':obj:`PixelData` '
|
||||
f'{self.shape} ')
|
||||
'not consistent with '
|
||||
'the shape of this '
|
||||
':obj:`PixelData` '
|
||||
f'{self.shape}')
|
||||
assert value.ndim in [
|
||||
2, 3
|
||||
], f'The dim of value must be 2 or 3, but got {value.ndim}'
|
||||
if value.ndim == 2:
|
||||
value = value[None]
|
||||
warnings.warn(f'The shape of value will convert from '
|
||||
warnings.warn('The shape of value will convert from '
|
||||
f'{value.shape[-2:]} to {value.shape}')
|
||||
super().__setattr__(name, value)
|
||||
|
||||
|
@ -89,17 +88,17 @@ class PixelData(BaseDataElement):
|
|||
def __getitem__(self, item: Sequence[Union[int, slice]]) -> 'PixelData':
|
||||
"""
|
||||
Args:
|
||||
item (Sequence[Union[int, slice]]): get the corresponding values
|
||||
according to item.
|
||||
item (Sequence[Union[int, slice]]): Get the corresponding values
|
||||
according to item.
|
||||
|
||||
Returns:
|
||||
obj:`PixelData`: Corresponding values.
|
||||
:obj:`PixelData`: Corresponding values.
|
||||
"""
|
||||
|
||||
new_data = self.__class__(metainfo=self.metainfo)
|
||||
if isinstance(item, tuple):
|
||||
|
||||
assert len(item) == 2, 'Only support slice height and width'
|
||||
assert len(item) == 2, 'Only support to slice height and width'
|
||||
tmp_item: List[slice] = list()
|
||||
for index, single_item in enumerate(item[::-1]):
|
||||
if isinstance(single_item, int):
|
||||
|
|
|
@ -23,7 +23,7 @@ class TmpObject:
|
|||
return len(self.tmp)
|
||||
|
||||
def __getitem__(self, item):
|
||||
if type(item) == int:
|
||||
if isinstance(item, int):
|
||||
if item >= len(self) or item < -len(self): # type:ignore
|
||||
raise IndexError(f'Index {item} out of range!')
|
||||
else:
|
||||
|
@ -58,13 +58,13 @@ class TmpObjectWithoutCat:
|
|||
return len(self.tmp)
|
||||
|
||||
def __getitem__(self, item):
|
||||
if type(item) == int:
|
||||
if isinstance(item, int):
|
||||
if item >= len(self) or item < -len(self): # type:ignore
|
||||
raise IndexError(f'Index {item} out of range!')
|
||||
else:
|
||||
# keep the dimension
|
||||
item = slice(item, None, len(self))
|
||||
return TmpObject(self.tmp[item])
|
||||
return TmpObjectWithoutCat(self.tmp[item])
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.tmp)
|
||||
|
@ -131,18 +131,18 @@ class TestInstanceData(TestCase):
|
|||
with pytest.raises(AssertionError):
|
||||
instance_data[item]
|
||||
|
||||
# when input is a bool tensor, The shape of
|
||||
# when input is a bool tensor, the shape of
|
||||
# the input at index 0 should equal to
|
||||
# the value length in instance_data_field
|
||||
with pytest.raises(AssertionError):
|
||||
instance_data[item.bool()]
|
||||
|
||||
# test Longtensor
|
||||
# test LongTensor
|
||||
long_tensor = torch.randint(5, (2, ))
|
||||
long_index_instance_data = instance_data[long_tensor]
|
||||
assert len(long_index_instance_data) == len(long_tensor)
|
||||
|
||||
# test bool tensor
|
||||
# test BoolTensor
|
||||
bool_tensor = torch.rand(5) > 0.5
|
||||
bool_index_instance_data = instance_data[bool_tensor]
|
||||
assert len(bool_index_instance_data) == bool_tensor.sum()
|
||||
|
@ -155,7 +155,7 @@ class TestInstanceData(TestCase):
|
|||
list_index_instance_data = instance_data[list_index]
|
||||
assert len(list_index_instance_data) == len(list_index)
|
||||
|
||||
# text list bool
|
||||
# test list bool
|
||||
list_bool = [True, False, True, False, False]
|
||||
list_bool_instance_data = instance_data[list_bool]
|
||||
assert len(list_bool_instance_data) == 2
|
||||
|
|
Loading…
Reference in New Issue