[Docs] Update docstring of structures (#840)

* Update docstring of `structures`

* update docs

* add `import torch` to `examples`
pull/718/merge
Xiangxu-0103 2022-12-21 20:07:18 +08:00 committed by GitHub
parent e1f61252d4
commit 591024e533
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 71 additions and 77 deletions

View File

@ -72,6 +72,7 @@ class BaseDataElement:
model predictions. Defaults to None.
Examples:
>>> import torch
>>> from mmengine.structures import BaseDataElement
>>> gt_instances = BaseDataElement()
>>> bboxes = torch.rand((5, 4))

View File

@ -22,7 +22,7 @@ class InstanceData(BaseDataElement):
should have the same length. This design refer to
https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py # noqa E501
InstanceData also support extra functions: ``index``, ``slice`` and ``cat`` for data field. The type of value
in data field can be base data structure such as `torch.tensor`, `numpy.ndarray`, `list`, `str`, `tuple`,
in data field can be base data structure such as `torch.Tensor`, `numpy.ndarray`, `list`, `str`, `tuple`,
and can be customized data structure that has ``__len__``, ``__getitem__`` and ``cat`` attributes.
Examples:
@ -34,7 +34,7 @@ class InstanceData(BaseDataElement):
... def __len__(self):
... return len(self.tmp)
... def __getitem__(self, item):
... if type(item) == int:
... if isinstance(item, int):
... if item >= len(self) or item < -len(self): # type:ignore
... raise IndexError(f'Index {item} out of range!')
... else:
@ -54,6 +54,7 @@ class InstanceData(BaseDataElement):
... return str(self.tmp)
>>> from mmengine.structures import InstanceData
>>> import numpy as np
>>> import torch
>>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
>>> instance_data = InstanceData(metainfo=img_meta)
>>> 'img_shape' in instance_data
@ -67,41 +68,39 @@ class InstanceData(BaseDataElement):
>>> print(instance_data)
<InstanceData(
META INFORMATION
pad_shape: (800, 1196, 3)
img_shape: (800, 1216, 3)
img_shape: (800, 1196, 3)
pad_shape: (800, 1216, 3)
DATA FIELDS
det_labels: tensor([2, 3])
det_scores: tensor([0.8, 0.7000])
det_scores: tensor([0.8000, 0.7000])
bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188],
[0.8101, 0.3105, 0.5123, 0.6263]])
[0.8101, 0.3105, 0.5123, 0.6263]])
polygons: [[1, 2, 3, 4], [5, 6, 7, 8]]
) at 0x7fb492de6280>
) at 0x7fb492de6280>
>>> sorted_results = instance_data[instance_data.det_scores.sort().indices]
>>> sorted_results.det_scores
tensor([0.7000, 0.8000])
>>> print(instance_data[instance_data.det_scores > 0.75])
<InstanceData(
META INFORMATION
pad_shape: (800, 1216, 3)
img_shape: (800, 1196, 3)
pad_shape: (800, 1216, 3)
DATA FIELDS
det_labels: tensor([2])
masks: [[11, 21, 31, 41]]
det_scores: tensor([0.8000])
bboxes: tensor([[0.9308, 0.4000, 0.6077, 0.5554]])
bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188]])
polygons: [[1, 2, 3, 4]]
) at 0x7f64ecf0ec40>
>>> print(instance_data[instance_data.det_scores > 1])
<InstanceData(
META INFORMATION
pad_shape: (800, 1216, 3)
img_shape: (800, 1196, 3)
pad_shape: (800, 1216, 3)
DATA FIELDS
det_labels: tensor([], dtype=torch.int64)
masks: []
det_scores: tensor([])
bboxes: tensor([], size=(0, 4))
polygons: [[]]
polygons: []
) at 0x7f660a6a7f70>
>>> print(instance_data.cat([instance_data, instance_data]))
<InstanceData(
@ -110,44 +109,39 @@ class InstanceData(BaseDataElement):
pad_shape: (800, 1216, 3)
DATA FIELDS
det_labels: tensor([2, 3, 2, 3])
bboxes: tensor([[0.7404, 0.6332, 0.1684, 0.9961],
[0.2837, 0.8112, 0.5416, 0.2810],
[0.7404, 0.6332, 0.1684, 0.9961],
[0.2837, 0.8112, 0.5416, 0.2810]])
data:
polygons: [[1, 2, 3, 4], [5, 6, 7, 8],
[1, 2, 3, 4], [5, 6, 7, 8]]
det_scores: tensor([0.8000, 0.7000, 0.8000, 0.7000])
masks: [[11, 21, 31, 41], [51, 61, 71, 81],
[11, 21, 31, 41], [51, 61, 71, 81]]
bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188],
[0.8101, 0.3105, 0.5123, 0.6263],
[0.4997, 0.7707, 0.0595, 0.4188],
[0.8101, 0.3105, 0.5123, 0.6263]])
polygons: [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [5, 6, 7, 8]]
) at 0x7f203542feb0>
"""
def __setattr__(self, name: str, value: Sized):
"""setattr is only used to set data.
the value must have the attribute of `__len__` and have the same length
of instancedata
The value must have the attribute of `__len__` and have the same length
of `InstanceData`.
"""
if name in ('_metainfo_fields', '_data_fields'):
if not hasattr(self, name):
super().__setattr__(name, value)
else:
raise AttributeError(
f'{name} has been used as a '
f'private attribute, which is immutable. ')
raise AttributeError(f'{name} has been used as a '
'private attribute, which is immutable.')
else:
assert isinstance(value,
Sized), 'value must contain `_len__` attribute'
Sized), 'value must contain `__len__` attribute'
if len(self) > 0:
assert len(value) == len(self), f'the length of ' \
assert len(value) == len(self), 'The length of ' \
f'values {len(value)} is ' \
f'not consistent with' \
f' the length of this ' \
f':obj:`InstanceData` ' \
f'{len(self)} '
'not consistent with ' \
'the length of this ' \
':obj:`InstanceData` ' \
f'{len(self)}'
super().__setattr__(name, value)
__setitem__ = __setattr__
@ -155,12 +149,12 @@ class InstanceData(BaseDataElement):
def __getitem__(self, item: IndexType) -> 'InstanceData':
"""
Args:
item (str, obj:`slice`,
obj`torch.LongTensor`, obj:`torch.BoolTensor`):
get the corresponding values according to item.
item (str, int, list, :obj:`slice`, :obj:`numpy.ndarray`,
:obj:`torch.LongTensor`, :obj:`torch.BoolTensor`):
Get the corresponding values according to item.
Returns:
obj:`InstanceData`: Corresponding values.
:obj:`InstanceData`: Corresponding values.
"""
if isinstance(item, list):
item = np.array(item)
@ -178,7 +172,7 @@ class InstanceData(BaseDataElement):
if isinstance(item, str):
return getattr(self, item)
if type(item) == int:
if isinstance(item, int):
if item >= len(self) or item < -len(self): # type:ignore
raise IndexError(f'Index {item} out of range!')
else:
@ -190,14 +184,14 @@ class InstanceData(BaseDataElement):
assert item.dim() == 1, 'Only support to get the' \
' values along the first dimension.'
if isinstance(item, (torch.BoolTensor, torch.cuda.BoolTensor)):
assert len(item) == len(self), f'The shape of the' \
f' input(BoolTensor)) ' \
assert len(item) == len(self), 'The shape of the ' \
'input(BoolTensor) ' \
f'{len(item)} ' \
f' does not match the shape ' \
f'of the indexed tensor ' \
f'in results_filed ' \
'does not match the shape ' \
'of the indexed tensor ' \
'in results_field ' \
f'{len(self)} at ' \
f'first dimension. '
'first dimension.'
for k, v in self.items():
if isinstance(v, torch.Tensor):
@ -207,7 +201,7 @@ class InstanceData(BaseDataElement):
elif isinstance(
v, (str, list, tuple)) or (hasattr(v, '__getitem__')
and hasattr(v, 'cat')):
# convert to indexes from boolTensor
# convert to indexes from BoolTensor
if isinstance(item,
(torch.BoolTensor, torch.cuda.BoolTensor)):
indexes = torch.nonzero(item).view(
@ -232,7 +226,7 @@ class InstanceData(BaseDataElement):
raise ValueError(
f'The type of `{k}` is `{type(v)}`, which has no '
'attribute of `cat`, so it does not '
f'support slice with `bool`')
'support slice with `bool`')
else:
# item is a slice
@ -252,7 +246,7 @@ class InstanceData(BaseDataElement):
of :obj:`InstanceData`.
Returns:
obj:`InstanceData`
:obj:`InstanceData`
"""
assert all(
isinstance(results, InstanceData) for results in instances_list)
@ -272,7 +266,7 @@ class InstanceData(BaseDataElement):
'cause the cat operation ' \
'to fail. Please make sure all ' \
'elements in `instances_list` ' \
'have the exact same key '
'have the exact same key.'
new_data = instances_list[0].__class__(
metainfo=instances_list[0].metainfo)
@ -297,7 +291,7 @@ class InstanceData(BaseDataElement):
return new_data # type:ignore
def __len__(self) -> int:
"""int: the length of InstanceData"""
"""int: The length of InstanceData."""
if len(self._data_fields) > 0:
return len(self.values()[0])
else:

View File

@ -16,7 +16,7 @@ class LabelData(BaseDataElement):
onehot (torch.Tensor, optional): The one-hot input. The format
of input must be one-hot.
Return:
Returns:
torch.Tensor: The converted results.
"""
assert isinstance(onehot, torch.Tensor)
@ -36,7 +36,7 @@ class LabelData(BaseDataElement):
of item must be label-format.
num_classes (int): The number of classes.
Return:
Returns:
torch.Tensor: The converted results.
"""
assert isinstance(label, torch.Tensor)

View File

@ -26,12 +26,12 @@ class PixelData(BaseDataElement):
>>> pixel_data = PixelData(metainfo=metainfo,
... image=image,
... featmap=featmap)
>>> print(pixel_data)
>>> (20, 40)
>>> print(pixel_data.shape)
(20, 40)
>>> # slice
>>> slice_data = pixel_data[10:20, 20:40]
>>> assert slice_data.shape == (10, 10)
>>> assert slice_data.shape == (10, 20)
>>> slice_data = pixel_data[10, 20]
>>> assert slice_data.shape == (1, 1)
@ -47,41 +47,40 @@ class PixelData(BaseDataElement):
"""Set attributes of ``PixelData``.
If the dimension of value is 2 and its shape meet the demand, it
will automatically expend its channel-dimension.
will automatically expand its channel-dimension.
Args:
name (str): The key to access the value, stored in `PixelData`.
value (Union[torch.Tensor, np.ndarray]): The value to store in.
The type of value must be `torch.Tensor` or `np.ndarray`,
The type of value must be `torch.Tensor` or `np.ndarray`,
and its shape must meet the requirements of `PixelData`.
"""
if name in ('_metainfo_fields', '_data_fields'):
if not hasattr(self, name):
super().__setattr__(name, value)
else:
raise AttributeError(
f'{name} has been used as a '
f'private attribute, which is immutable. ')
raise AttributeError(f'{name} has been used as a '
'private attribute, which is immutable.')
else:
assert isinstance(value, (torch.Tensor, np.ndarray)), \
f'Can set {type(value)}, only support' \
f'Can not set {type(value)}, only support' \
f' {(torch.Tensor, np.ndarray)}'
if self.shape:
assert tuple(value.shape[-2:]) == self.shape, (
f'the height and width of '
'The height and width of '
f'values {tuple(value.shape[-2:])} is '
f'not consistent with'
f' the length of this '
f':obj:`PixelData` '
f'{self.shape} ')
'not consistent with '
'the shape of this '
':obj:`PixelData` '
f'{self.shape}')
assert value.ndim in [
2, 3
], f'The dim of value must be 2 or 3, but got {value.ndim}'
if value.ndim == 2:
value = value[None]
warnings.warn(f'The shape of value will convert from '
warnings.warn('The shape of value will convert from '
f'{value.shape[-2:]} to {value.shape}')
super().__setattr__(name, value)
@ -89,17 +88,17 @@ class PixelData(BaseDataElement):
def __getitem__(self, item: Sequence[Union[int, slice]]) -> 'PixelData':
"""
Args:
item (Sequence[Union[int, slice]]): get the corresponding values
according to item.
item (Sequence[Union[int, slice]]): Get the corresponding values
according to item.
Returns:
obj:`PixelData`: Corresponding values.
:obj:`PixelData`: Corresponding values.
"""
new_data = self.__class__(metainfo=self.metainfo)
if isinstance(item, tuple):
assert len(item) == 2, 'Only support slice height and width'
assert len(item) == 2, 'Only support to slice height and width'
tmp_item: List[slice] = list()
for index, single_item in enumerate(item[::-1]):
if isinstance(single_item, int):

View File

@ -23,7 +23,7 @@ class TmpObject:
return len(self.tmp)
def __getitem__(self, item):
if type(item) == int:
if isinstance(item, int):
if item >= len(self) or item < -len(self): # type:ignore
raise IndexError(f'Index {item} out of range!')
else:
@ -58,13 +58,13 @@ class TmpObjectWithoutCat:
return len(self.tmp)
def __getitem__(self, item):
if type(item) == int:
if isinstance(item, int):
if item >= len(self) or item < -len(self): # type:ignore
raise IndexError(f'Index {item} out of range!')
else:
# keep the dimension
item = slice(item, None, len(self))
return TmpObject(self.tmp[item])
return TmpObjectWithoutCat(self.tmp[item])
def __repr__(self):
return str(self.tmp)
@ -131,18 +131,18 @@ class TestInstanceData(TestCase):
with pytest.raises(AssertionError):
instance_data[item]
# when input is a bool tensor, The shape of
# when input is a bool tensor, the shape of
# the input at index 0 should equal to
# the value length in instance_data_field
with pytest.raises(AssertionError):
instance_data[item.bool()]
# test Longtensor
# test LongTensor
long_tensor = torch.randint(5, (2, ))
long_index_instance_data = instance_data[long_tensor]
assert len(long_index_instance_data) == len(long_tensor)
# test bool tensor
# test BoolTensor
bool_tensor = torch.rand(5) > 0.5
bool_index_instance_data = instance_data[bool_tensor]
assert len(bool_index_instance_data) == bool_tensor.sum()
@ -155,7 +155,7 @@ class TestInstanceData(TestCase):
list_index_instance_data = instance_data[list_index]
assert len(list_index_instance_data) == len(list_index)
# text list bool
# test list bool
list_bool = [True, False, True, False, False]
list_bool_instance_data = instance_data[list_bool]
assert len(list_bool_instance_data) == 2