132 lines
4.9 KiB
Python
132 lines
4.9 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import warnings
|
|
from typing import List, Sequence, Union
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from .base_data_element import BaseDataElement
|
|
|
|
|
|
class PixelData(BaseDataElement):
    """Data structure for pixel-level annotations or predictions.

    All data items in ``data_fields`` of ``PixelData`` meet the following
    requirements:

    - They all have 3 dimensions in orders of channel, height, and width.
    - They should have the same height and width.

    A 2-D value ``(height, width)`` assigned to a data field is expanded to
    ``(1, height, width)``. Indexing with ``pixel_data[h, w]`` slices every
    data field along height and width while keeping all three dimensions.

    Examples:
        >>> metainfo = dict(img_id=1, img_shape=(20, 40))
        >>> image = np.random.randint(0, 255, (4, 20, 40))
        >>> featmap = torch.randint(0, 255, (10, 20, 40))
        >>> pixel_data = PixelData(metainfo=metainfo,
        ...                        image=image,
        ...                        featmap=featmap)
        >>> print(pixel_data.shape)
        (20, 40)

        >>> # slice
        >>> slice_data = pixel_data[10:20, 20:40]
        >>> assert slice_data.shape == (10, 20)
        >>> slice_data = pixel_data[10, 20]
        >>> assert slice_data.shape == (1, 1)

        >>> # set
        >>> pixel_data.map3 = torch.randint(0, 255, (20, 40))
        >>> assert tuple(pixel_data.map3.shape) == (1, 20, 40)
        >>> # The dimension must be 3 or 2
        >>> pixel_data.map2 = torch.randint(0, 255, (1, 3, 20, 40))
        Traceback (most recent call last):
        ...
        AssertionError: The dim of value must be 2 or 3, but got 4
    """

    def __setattr__(self, name: str, value: Union[torch.Tensor, np.ndarray]):
        """Set attributes of ``PixelData``.

        The private bookkeeping attributes ``_metainfo_fields`` and
        ``_data_fields`` can only be set once. Every other attribute must be
        a ``torch.Tensor`` or ``np.ndarray`` with 2 or 3 dimensions whose
        last two dimensions (height, width) match the data already stored.
        A 2-D value is expanded to ``(1, height, width)`` with a warning.

        Args:
            name (str): The key to access the value, stored in `PixelData`.
            value (Union[torch.Tensor, np.ndarray]): The value to store in.
                The type of value must be `torch.Tensor` or `np.ndarray`,
                and its shape must meet the requirements of `PixelData`.

        Raises:
            AttributeError: If an already-existing private field is assigned
                again.
            AssertionError: If ``value`` has an unsupported type, number of
                dimensions, or height/width.
        """
        if name in ('_metainfo_fields', '_data_fields'):
            if not hasattr(self, name):
                super().__setattr__(name, value)
            else:
                raise AttributeError(
                    f'{name} has been used as a '
                    f'private attribute, which is immutable. ')
        else:
            assert isinstance(value, (torch.Tensor, np.ndarray)), \
                f'Can not set {type(value)}, only support' \
                f' {(torch.Tensor, np.ndarray)}'
            # Check the rank before height/width so that e.g. a 1-D value
            # reports the real problem instead of a misleading size mismatch.
            assert value.ndim in (2, 3), \
                f'The dim of value must be 2 or 3, but got {value.ndim}'
            shape = self.shape
            if shape is not None:
                assert tuple(value.shape[-2:]) == shape, (
                    'The height and width of '
                    f'values {tuple(value.shape[-2:])} is '
                    'not consistent with '
                    'the shape of this '
                    ':obj:`PixelData` '
                    f'{shape}')
            if value.ndim == 2:
                value = value[None]
                # stacklevel=2 attributes the warning to the caller's
                # assignment rather than to this method.
                warnings.warn(
                    'The shape of value will convert from '
                    f'{value.shape[-2:]} to {value.shape}',
                    stacklevel=2)
            super().__setattr__(name, value)

    # TODO torch.Long/bool
    def __getitem__(self, item: Sequence[Union[int, slice]]) -> 'PixelData':
        """Slice every data field along height and width.

        Args:
            item (Sequence[Union[int, slice]]): A ``(height, width)`` tuple
                whose elements are integers or slices. An integer selects a
                single row/column but keeps that dimension, so every sliced
                field is still 3-D.

        Returns:
            :obj:`PixelData`: A new instance built with the same metainfo and
            every data field sliced by ``item``.

        Raises:
            TypeError: If ``item`` is not a tuple, or one of its elements is
                neither an integer nor a slice.
            IndexError: If an integer index is out of range.
            AssertionError: If ``item`` does not have exactly two elements.
        """
        if not isinstance(item, tuple):
            raise TypeError(
                f'Unsupported type {type(item)} for slicing PixelData')
        assert len(item) == 2, 'Only support slice height and width'

        shape = self.shape
        axis_names = ('height', 'width')
        # The leading channel dimension is always kept in full.
        index: List[slice] = [slice(None)]
        for axis, single_item in enumerate(item):
            if isinstance(single_item, (int, np.integer)):
                pos = int(single_item)
                # With no data fields there is nothing to slice, so the
                # bound check is only possible (and needed) when shape exists.
                if shape is not None:
                    size = shape[axis]
                    if not -size <= pos < size:
                        raise IndexError(
                            f'index {pos} is out of bounds for '
                            f'{axis_names[axis]} with size {size}')
                    # Normalize negative indices so ``pos + 1`` below cannot
                    # wrap around to 0 (e.g. -1 -> slice(-1, 0) is empty).
                    pos %= size
                # A length-1 slice keeps the dimension, unlike a bare int.
                index.append(slice(pos, pos + 1))
            elif isinstance(single_item, slice):
                index.append(single_item)
            else:
                raise TypeError(
                    'The type of element in input must be int or slice, '
                    f'but got {type(single_item)}')

        full_index = tuple(index)
        new_data = self.__class__(metainfo=self.metainfo)
        for key, value in self.items():
            setattr(new_data, key, value[full_index])
        return new_data

    @property
    def shape(self):
        """The ``(height, width)`` of the pixel data, or ``None`` when no
        data field has been set."""
        if self._data_fields:
            # ``__setattr__`` keeps height/width equal across all data
            # fields, so the first one is representative.
            return tuple(next(iter(self.values())).shape[-2:])
        return None

    # TODO padding, resize
|