Structures
Like other OpenMMLab repositories, MMSelfSup defines a data structure, called SelfSupDataSample, which is used to receive and pass data during the whole training/testing process. SelfSupDataSample inherits from BaseDataElement implemented in MMEngine. We recommend users refer to the BaseDataElement documentation for a more in-depth introduction to its basics. In this tutorial, we mainly discuss some customized features of SelfSupDataSample.
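As a quick warm-up, here is a minimal sketch of the BaseDataElement basics (the field names score and img_shape are made up for illustration): ordinary data fields are set as attributes, while meta information is attached separately with set_metainfo.

import torch
from mmengine.data import BaseDataElement

data_element = BaseDataElement()
# ordinary data fields are set as normal attributes
data_element.score = torch.tensor([0.1, 0.9])
# meta information, e.g. the image shape, is attached via set_metainfo
data_element.set_metainfo(dict(img_shape=(224, 224)))

print(data_element.score)      # tensor([0.1000, 0.9000])
print(data_element.img_shape)  # (224, 224)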
Customized attributes in SelfSupDataSample
In MMSelfSup, SelfSupDataSample wraps all information required by models apart from the images themselves, e.g. the mask requested by masked image modeling (MIM) and the pseudo_label used in pretext tasks. In addition to providing information to models, it can also accept information generated by them, such as the prediction score. To fulfill the functionalities described above, SelfSupDataSample defines five customized attributes:
- gt_label (LabelData), containing the ground-truth label of the image.
- sample_idx (InstanceData), containing the index of the current image in the data list, initialized by the dataset at the beginning.
- mask (BaseDataElement), containing the mask used in MIM, e.g. SimMIM and CAE.
- pred_label (LabelData), containing the label predicted by the model.
- pseudo_label (BaseDataElement), containing the pseudo label used in pretext tasks, such as the location in Relative Location.
To help users capture the basic idea of SelfSupDataSample, here is a toy example showing how to create a SelfSupDataSample instance and set these attributes.
import torch
from mmselfsup.core import SelfSupDataSample
from mmengine.data import LabelData, InstanceData, BaseDataElement
selfsup_data_sample = SelfSupDataSample()
# set the gt_label in selfsup_data_sample
# gt_label should be of type LabelData
selfsup_data_sample.gt_label = LabelData(value=torch.tensor([1]))
# assigning a value that is not LabelData to gt_label will raise an error
selfsup_data_sample.gt_label = torch.tensor([1])
# AssertionError: tensor([1]) should be a <class 'mmengine.data.label_data.LabelData'> but got <class 'torch.Tensor'>
# set the sample_idx in selfsup_data_sample
# likewise, the value assigned to sample_idx should be of type InstanceData
selfsup_data_sample.sample_idx = InstanceData(value=torch.tensor([1]))
# setting the mask in selfsup_data_sample
selfsup_data_sample.mask = BaseDataElement(value=torch.ones((3, 3)))
# setting the pseudo_label in selfsup_data_sample
selfsup_data_sample.pseudo_label = InstanceData(location=torch.tensor([1, 2, 3]))
# After setting these attributes, you can easily fetch their values
print(selfsup_data_sample.gt_label.value)
# tensor([1])
print(selfsup_data_sample.mask.value.shape)
# torch.Size([3, 3])
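Since SelfSupDataSample inherits from BaseDataElement, meta information about the image can also be attached with set_metainfo. The keys below (img_path, img_shape) are only illustrative; the actual meta keys are decided by the dataset and the data pipeline.

# meta info is attached via set_metainfo, inherited from BaseDataElement
selfsup_data_sample.set_metainfo(
    dict(img_path='demo.jpg', img_shape=(224, 224)))
print(selfsup_data_sample.img_shape)
# (224, 224)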
Pack data into SelfSupDataSample in MMSelfSup
Before feeding data into the model, MMSelfSup packs data into SelfSupDataSample in the data pipeline.
If you are not familiar with the data pipeline, you can consult data transform. To pack data, we implement a data transform called PackSelfSupInputs:
class PackSelfSupInputs(BaseTransform):
    """Pack data into the format compatible with the inputs of algorithm.

    Required Keys:

    - img

    Added Keys:

    - data_sample
    - inputs

    Args:
        key (str): The key of image inputted into the model. Defaults to 'img'.
        algorithm_keys (List[str]): Keys of elements related
            to algorithms, e.g. mask. Defaults to [].
        pseudo_label_keys (List[str]): Keys set to be the attributes of
            pseudo_label. Defaults to [].
        meta_keys (List[str]): The keys of meta info of an image.
            Defaults to [].
    """

    def __init__(self,
                 key: Optional[str] = 'img',
                 algorithm_keys: Optional[List[str]] = [],
                 pseudo_label_keys: Optional[List[str]] = [],
                 meta_keys: Optional[List[str]] = []) -> None:
        assert isinstance(key, str), f'key should be the type of str, instead \
            of {type(key)}.'

        self.key = key
        self.algorithm_keys = algorithm_keys
        self.pseudo_label_keys = pseudo_label_keys
        self.meta_keys = meta_keys

    def transform(self,
                  results: Dict) -> Dict[torch.Tensor, SelfSupDataSample]:
        """Method to pack the data.

        Args:
            results (Dict): Result dict from the data pipeline.

        Returns:
            Dict:

            - 'inputs' (List[torch.Tensor]): The forward data of models.
            - 'data_sample' (SelfSupDataSample): The annotation info of the
              forward data.
        """
        packed_results = dict()
        if self.key in results:
            img = results[self.key]
            # if img is not a list, convert it to a list
            if not isinstance(img, List):
                img = [img]
            for i, img_ in enumerate(img):
                if len(img_.shape) < 3:
                    img_ = np.expand_dims(img_, -1)
                img_ = np.ascontiguousarray(img_.transpose(2, 0, 1))
                img[i] = to_tensor(img_)
            packed_results['inputs'] = img

        data_sample = SelfSupDataSample()
        if len(self.pseudo_label_keys) > 0:
            pseudo_label = InstanceData()
            data_sample.pseudo_label = pseudo_label

        # gt_label, sample_idx, mask, pred_label will be set here
        for key in self.algorithm_keys:
            self.set_algorithm_keys(data_sample, key, results)

        # keys, except for gt_label, sample_idx, mask, pred_label, will be
        # set as the attributes of pseudo_label
        for key in self.pseudo_label_keys:
            # convert data to torch.Tensor
            value = to_tensor(results[key])
            setattr(data_sample.pseudo_label, key, value)

        img_meta = {}
        for key in self.meta_keys:
            img_meta[key] = results[key]
        data_sample.set_metainfo(img_meta)
        packed_results['data_sample'] = data_sample

        return packed_results

    @classmethod
    def set_algorithm_keys(cls, data_sample: SelfSupDataSample, key: str,
                           results: Dict) -> None:
        """Set the algorithm keys of SelfSupDataSample."""
        value = to_tensor(results[key])
        if key == 'sample_idx':
            sample_idx = InstanceData(value=value)
            setattr(data_sample, 'sample_idx', sample_idx)
        elif key == 'mask':
            mask = InstanceData(value=value)
            setattr(data_sample, 'mask', mask)
        elif key == 'gt_label':
            gt_label = LabelData(value=value)
            setattr(data_sample, 'gt_label', gt_label)
        elif key == 'pred_label':
            pred_label = LabelData(value=value)
            setattr(data_sample, 'pred_label', pred_label)
        else:
            raise AttributeError(f'{key} is not an attribute of \
                SelfSupDataSample')
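Once packed, these fields can be read back from the data samples by downstream code. The helper below is only an illustrative sketch (collect_masks is our own name, not part of MMSelfSup), showing how the MIM masks of a batch could be gathered from a list of packed samples.

from typing import List

import torch

from mmselfsup.core import SelfSupDataSample


def collect_masks(data_samples: List[SelfSupDataSample]) -> torch.Tensor:
    # stack the per-sample masks packed under data_sample.mask.value
    return torch.stack(
        [data_sample.mask.value for data_sample in data_samples])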
algorithm_keys refer to the attributes of SelfSupDataSample other than pseudo_label, while pseudo_label_keys are the sub-keys stored under the pseudo_label attribute of SelfSupDataSample.
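In a pipeline config, PackSelfSupInputs is typically placed at the end of the transforms. The snippet below is only a hedged sketch: the preceding transforms and the chosen keys are illustrative and depend on the algorithm (here, a MIM-style setup that packs a mask and some meta info).

train_pipeline = [
    dict(type='LoadImageFromFile'),
    # ... algorithm-specific transforms that, e.g., generate a 'mask' key ...
    dict(
        type='PackSelfSupInputs',
        algorithm_keys=['mask'],
        meta_keys=['img_path'])
]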
Thank you for reading the whole tutorial. If you have any problems, please raise an issue on GitHub, and we will get back to you as soon as possible.