From 40b58076f7e6dd7f9e28838f82e9b33dc1f1d703 Mon Sep 17 00:00:00 2001 From: "liuyuan1.vendor" Date: Sat, 7 May 2022 09:20:56 +0000 Subject: [PATCH] [Feature]: Add self data sample --- mmselfsup/core/__init__.py | 1 + mmselfsup/core/data_structures/__init__.py | 4 + .../data_structures/selfsup_data_sample.py | 138 ++++++++++++++++++ .../test_selfsup_data_sample.py | 92 ++++++++++++ 4 files changed, 235 insertions(+) create mode 100644 mmselfsup/core/data_structures/__init__.py create mode 100644 mmselfsup/core/data_structures/selfsup_data_sample.py create mode 100644 tests/test_core/test_data_structures/test_selfsup_data_sample.py diff --git a/mmselfsup/core/__init__.py b/mmselfsup/core/__init__.py index 85a65cb8..8b057d3d 100644 --- a/mmselfsup/core/__init__.py +++ b/mmselfsup/core/__init__.py @@ -1,3 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .data_structures import * # noqa: F401, F403 from .hooks import * # noqa: F401,F403 from .optimizer import * # noqa: F401, F403 diff --git a/mmselfsup/core/data_structures/__init__.py b/mmselfsup/core/data_structures/__init__.py new file mode 100644 index 00000000..2c0c5b07 --- /dev/null +++ b/mmselfsup/core/data_structures/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .selfsup_data_sample import SelfSupDataSample + +__all__ = ['SelfSupDataSample'] diff --git a/mmselfsup/core/data_structures/selfsup_data_sample.py b/mmselfsup/core/data_structures/selfsup_data_sample.py new file mode 100644 index 00000000..1a9ea2aa --- /dev/null +++ b/mmselfsup/core/data_structures/selfsup_data_sample.py @@ -0,0 +1,138 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# TODO: will use real PixelData once it is added in mmengine +from mmengine.data import BaseDataElement +from mmengine.data import BaseDataElement as PixelData +from mmengine.data import InstanceData + + +class SelfSupDataSample(BaseDataElement): + """A data structure interface of MMSelfSup. They are used as interfaces + between different components. + + The attributes in ``SelfSupDataSample`` are divided into several parts: + + - ``gt_label``(InstanceData): The ground truth label of an image. + - ``idx``(InstanceData): The idx of an image in the dataset. + - ``mask``(PixelData): Mask used in masks image modeling. + - ``pred_label``(InstanceData): Label used in pretext task, + e.g. Relative Location. + + Examples: + >>> import torch + >>> import numpy as np + >>> from mmengine.data import InstanceData + >>> from mmselfsup.core import SelfSupDataSample + + >>> data_sample = SelfSupDataSample() + >>> gt_label = InstanceData() + >>> gt_label.value = [1] + >>> data_sample.gt_label = gt_label + >>> len(data_sample.gt_label) + 1 + >>> print(data_sample) + + _gt_label: + ) at 0x7f15c077ef10> + >>> idx = InstanceData() + >>> idx.value = [0] + >>> data_sample = SelfSupDataSample(idx=idx) + >>> assert 'idx' in data_sample + + >>> data_sample = SelfSupDataSample() + >>> mask = dict(value=np.random.rand(48, 48)) + >>> mask = PixelData(**mask) + >>> data_sample.mask = mask + >>> assert 'mask' in data_sample + >>> assert 'value' in data_sample.mask + + >>> data_sample = SelfSupDataSample() + >>> pred_label = dict(pred_label=[3]) + >>> pred_label = InstanceData(**pred_label) + >>> data_sample.pred_label = pred_label + >>> print(data_sample) + + pred_label: + ) at 0x7f15c07b8bd0> + """ + + @property + def gt_label(self) -> InstanceData: + return self._gt_label + + @gt_label.setter + def gt_label(self, value: InstanceData): + self.set_field(value, '_gt_label', dtype=InstanceData) + + @gt_label.deleter + def gt_label(self): + del self._gt_label + + @property + def idx(self) -> InstanceData: + return self._idx + + @idx.setter + def idx(self, value: InstanceData): + self.set_field(value, '_idx', dtype=InstanceData) + + @idx.deleter + def idx(self): + del self._idx + + @property + def mask(self) -> PixelData: + return self._mask + + @mask.setter + def mask(self, value: PixelData): + self.set_field(value, '_mask', dtype=PixelData) + + @mask.deleter + def mask(self): + del self._mask + + @property + def pred_label(self) -> InstanceData: + return self._pred_label + + @pred_label.setter + def pred_label(self, value: InstanceData): + self.set_field(value, '_pred_label', dtype=InstanceData) + + @pred_label.deleter + def pred_label(self): + del self._pred_label diff --git a/tests/test_core/test_data_structures/test_selfsup_data_sample.py b/tests/test_core/test_data_structures/test_selfsup_data_sample.py new file mode 100644 index 00000000..f532140a --- /dev/null +++ b/tests/test_core/test_data_structures/test_selfsup_data_sample.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import numpy as np +import torch +# TODO: will use real PixelData once it is added in mmengine +from mmengine.data import InstanceData + +from mmselfsup.core import SelfSupDataSample + + +def _equal(a, b): + if isinstance(a, (torch.Tensor, np.ndarray)): + return (a == b).all() + else: + return a == b + + +class TestSelfSupDataSample(TestCase): + + def test_init(self): + meta_info = dict(img_size=[256, 256]) + + det_data_sample = SelfSupDataSample(metainfo=meta_info) + assert 'img_size' in det_data_sample + assert det_data_sample.img_size == [256, 256] + + def test_setter(self): + selfsup_data_sample = SelfSupDataSample() + # test gt_label + gt_label_data = dict(value=[1]) + gt_label = InstanceData(**gt_label_data) + selfsup_data_sample.gt_label = gt_label + assert 'gt_label' in selfsup_data_sample + assert _equal(selfsup_data_sample.gt_label.value, + gt_label_data['value']) + + # test idx + idx_data = dict(value=[1]) + idx_instances = InstanceData(**idx_data) + selfsup_data_sample.idx = idx_instances + assert 'idx' in selfsup_data_sample + assert _equal(selfsup_data_sample.idx.value, idx_data['value']) + + # test mask + mask_data = dict(value=torch.rand(4, 4)) + mask = InstanceData(**mask_data) + selfsup_data_sample.mask = mask + assert 'mask' in selfsup_data_sample + assert _equal(selfsup_data_sample.mask.value, mask_data['value']) + + # test pred_label + pred_label_data = dict(value=[1]) + pred_label_instances = InstanceData(**pred_label_data) + selfsup_data_sample.pred_label = pred_label_instances + assert 'pred_label' in selfsup_data_sample + assert _equal(selfsup_data_sample.pred_label.value, + pred_label_data['value']) + + def test_deleter(self): + + gt_label_data = dict(value=[1]) + selfsup_data_sample = SelfSupDataSample() + gt_label = InstanceData(value=gt_label_data) + selfsup_data_sample.gt_label = gt_label + assert 'gt_label' in selfsup_data_sample + del selfsup_data_sample.gt_label + assert 'gt_label' not in selfsup_data_sample + + idx_data = dict(value=[1]) + selfsup_data_sample = SelfSupDataSample() + idx = InstanceData(value=idx_data) + selfsup_data_sample.idx = idx + assert 'idx' in selfsup_data_sample + del selfsup_data_sample.idx + assert 'idx' not in selfsup_data_sample + + mask_data = dict(value=torch.rand(4, 4)) + selfsup_data_sample = SelfSupDataSample() + mask = InstanceData(value=mask_data) + selfsup_data_sample.mask = mask + assert 'mask' in selfsup_data_sample + del selfsup_data_sample.mask + assert 'mask' not in selfsup_data_sample + + pred_label_data = dict(value=[1]) + selfsup_data_sample = SelfSupDataSample() + pred_label = InstanceData(value=pred_label_data) + selfsup_data_sample.pred_label = pred_label + assert 'pred_label' in selfsup_data_sample + del selfsup_data_sample.pred_label + assert 'pred_label' not in selfsup_data_sample