# Copyright (c) OpenMMLab. All rights reserved.
from multiprocessing.reduction import ForkingPickler
from typing import Union

import numpy as np
import torch
from mmengine.structures import BaseDataElement

from .utils import LABEL_TYPE, SCORE_TYPE, format_label, format_score

class DataSample(BaseDataElement):
    """A general data structure interface.

    It's used as the interface between different components.

    The following fields are conventional names in MMPretrain; they are set
    or read by data transforms, models, and metrics as needed. You can also
    add any new fields you need.

    Meta fields:
        img_shape (Tuple): The shape of the corresponding input image.
        ori_shape (Tuple): The original shape of the corresponding image.
        sample_idx (int): The index of the sample in the dataset.
        num_classes (int): The number of all categories.

    Data fields:
        gt_label (tensor): The ground truth label.
        gt_score (tensor): The ground truth score.
        pred_label (tensor): The predicted label.
        pred_score (tensor): The predicted score.
        mask (tensor): The mask used in masked image modeling.

    Examples:
        >>> import torch
        >>> from mmpretrain.structures import DataSample
        >>>
        >>> img_meta = dict(img_shape=(960, 720), num_classes=5)
        >>> data_sample = DataSample(metainfo=img_meta)
        >>> data_sample.set_gt_label(3)
        >>> print(data_sample)
        <DataSample(
            META INFORMATION
            num_classes: 5
            img_shape: (960, 720)
            DATA FIELDS
            gt_label: tensor([3])
        ) at 0x7ff64c1c1d30>
        >>>
        >>> # For multi-label data
        >>> data_sample = DataSample().set_gt_label([0, 1, 4])
        >>> print(data_sample)
        <DataSample(
            DATA FIELDS
            gt_label: tensor([0, 1, 4])
        ) at 0x7ff5b490e100>
        >>>
        >>> # Set per-class scores
        >>> data_sample = DataSample().set_pred_score([0.1, 0.1, 0.6, 0.1])
        >>> print(data_sample)
        <DataSample(
            META INFORMATION
            num_classes: 4
            DATA FIELDS
            pred_score: tensor([0.1000, 0.1000, 0.6000, 0.1000])
        ) at 0x7ff5b48ef6a0>
        >>>
        >>> # Set a custom field
        >>> data_sample = DataSample()
        >>> data_sample.my_field = [1, 2, 3]
        >>> print(data_sample)
        <DataSample(
            DATA FIELDS
            my_field: [1, 2, 3]
        ) at 0x7f8e9603d3a0>
        >>> print(data_sample.my_field)
        [1, 2, 3]
    """

    def set_gt_label(self, value: LABEL_TYPE) -> 'DataSample':
        """Set ``gt_label``."""
        self.set_field(format_label(value), 'gt_label', dtype=torch.Tensor)
        return self

    def set_gt_score(self, value: SCORE_TYPE) -> 'DataSample':
        """Set ``gt_score``."""
        score = format_score(value)
        self.set_field(score, 'gt_score', dtype=torch.Tensor)
        if hasattr(self, 'num_classes'):
            assert len(score) == self.num_classes, \
                f'The length of score {len(score)} should be '\
                f'equal to the num_classes {self.num_classes}.'
        else:
            self.set_field(
                name='num_classes', value=len(score), field_type='metainfo')
        return self

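    # Illustrative note (hedged; not a verbatim upstream example): setting a
    # score keeps the ``num_classes`` metainfo in sync with the score length,
    # e.g.
    #
    #     >>> DataSample().set_gt_score([0.1, 0.9]).num_classes
    #     2
    #
    # and the assertion above fires if a previously recorded ``num_classes``
    # disagrees with ``len(score)``.
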
    def set_pred_label(self, value: LABEL_TYPE) -> 'DataSample':
        """Set ``pred_label``."""
        self.set_field(format_label(value), 'pred_label', dtype=torch.Tensor)
        return self

    def set_pred_score(self, value: SCORE_TYPE) -> 'DataSample':
        """Set ``pred_score``."""
        score = format_score(value)
        self.set_field(score, 'pred_score', dtype=torch.Tensor)
        if hasattr(self, 'num_classes'):
            assert len(score) == self.num_classes, \
                f'The length of score {len(score)} should be '\
                f'equal to the num_classes {self.num_classes}.'
        else:
            self.set_field(
                name='num_classes', value=len(score), field_type='metainfo')
        return self

    def set_mask(self, value: Union[torch.Tensor, np.ndarray]):
        """Set ``mask``."""
        if isinstance(value, np.ndarray):
            value = torch.from_numpy(value)
        elif not isinstance(value, torch.Tensor):
            raise TypeError(f'Invalid mask type {type(value)}')
        self.set_field(value, 'mask', dtype=torch.Tensor)
        return self

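    # Illustrative usage (hedged; not a verbatim upstream example): masks may
    # be passed as ``np.ndarray`` and are converted to ``torch.Tensor``, e.g.
    #
    #     >>> import numpy as np
    #     >>> DataSample().set_mask(np.zeros((14, 14), dtype=bool)).mask.shape
    #     torch.Size([14, 14])
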
    def __repr__(self) -> str:
        """Represent the object."""

        def dump_items(items, prefix=''):
            return '\n'.join(f'{prefix}{k}: {v}' for k, v in items)

        repr_ = ''
        if len(self._metainfo_fields) > 0:
            repr_ += '\n\nMETA INFORMATION\n'
            repr_ += dump_items(self.metainfo_items(), prefix=' ' * 4)
        if len(self._data_fields) > 0:
            repr_ += '\n\nDATA FIELDS\n'
            repr_ += dump_items(self.items(), prefix=' ' * 4)

        repr_ = f'<{self.__class__.__name__}({repr_}\n\n) at {hex(id(self))}>'
        return repr_


def _reduce_datasample(data_sample):
    """Reduce DataSample."""
    attr_dict = data_sample.__dict__
    convert_keys = []
    for k, v in attr_dict.items():
        if isinstance(v, torch.Tensor):
            attr_dict[k] = v.numpy()
            convert_keys.append(k)
    return _rebuild_datasample, (attr_dict, convert_keys)


def _rebuild_datasample(attr_dict, convert_keys):
    """Rebuild DataSample."""
    data_sample = DataSample()
    for k in convert_keys:
        attr_dict[k] = torch.from_numpy(attr_dict[k])
    data_sample.__dict__ = attr_dict
    return data_sample


# Due to the multi-processing strategy of PyTorch, a DataSample may consume
# many file descriptors because it contains multiple tensors. Here we override
# the reduce function of DataSample in ForkingPickler and convert these
# tensors to np.ndarray during pickling, which may slightly affect dataloader
# performance.
ForkingPickler.register(DataSample, _reduce_datasample)
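

# Illustrative sketch (hedged; not part of the upstream module): the
# reduce/rebuild pair registered above is what ForkingPickler applies when a
# DataSample is pickled across worker processes. The round trip below checks
# that tensors survive the numpy detour; it can be run as a module, e.g.
# ``python -m mmpretrain.structures.data_sample``, assuming this file lives at
# that path.
if __name__ == '__main__':
    sample = DataSample().set_gt_label(3).set_pred_score([0.1, 0.2, 0.7])
    rebuild_fn, reduce_args = _reduce_datasample(sample)
    restored = rebuild_fn(*reduce_args)
    assert restored.gt_label.tolist() == [3]
    assert restored.pred_score.dtype.is_floating_point
    assert restored.num_classes == 3
    print(restored)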