2022-05-07 09:20:56 +00:00
|
|
|
# Copyright (c) OpenMMLab. All rights reserved.
|
|
|
|
from unittest import TestCase
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
# TODO: will use real PixelData once it is added in mmengine
|
2022-08-30 11:34:04 +08:00
|
|
|
from mmengine.structures import InstanceData, LabelData
|
2022-05-07 09:20:56 +00:00
|
|
|
|
2022-07-30 16:36:48 +08:00
|
|
|
from mmselfsup.structures import SelfSupDataSample
|
2022-05-07 09:20:56 +00:00
|
|
|
|
|
|
|
|
|
|
|
def _equal(a, b):
    """Check whether ``a`` and ``b`` hold the same values.

    ``torch.Tensor`` and ``np.ndarray`` inputs are compared element-wise and
    must also have the same shape. Without the shape check, broadcasting
    would let e.g. ``tensor([1, 1])`` compare equal to ``tensor(1)``. Any
    other object falls back to ``a == b``.

    Args:
        a: The value under test.
        b: The expected value. When ``a`` is a tensor and ``b`` is not, ``b``
            is converted with :func:`torch.as_tensor` on ``a``'s device.

    Returns:
        bool: ``True`` if the values (and, for tensors/arrays, the shapes)
        match. For other objects the result of ``a == b`` is returned as is.
    """
    if isinstance(a, torch.Tensor):
        if not isinstance(b, torch.Tensor):
            b = torch.as_tensor(b, device=a.device)
        # Compare shapes first so broadcasting cannot hide a mismatch.
        return a.shape == b.shape and bool((a == b).all())
    if isinstance(a, np.ndarray):
        # array_equal checks both shape and values and returns a plain bool.
        return np.array_equal(a, b)
    return a == b
|
|
|
|
|
|
|
|
|
|
|
|
class TestSelfSupDataSample(TestCase):
    """Tests for the metainfo and data fields of ``SelfSupDataSample``."""

    @staticmethod
    def _field_cases():
        """Return ``(field name, element class, raw value)`` triples.

        Every element is built as ``element_cls(value=value)``, so the setter
        and deleter tests store the raw value the same way. Previously the
        deleter test wrapped the whole ``dict(value=...)`` as the value.
        """
        return [
            ('gt_label', LabelData, [1]),
            ('sample_idx', InstanceData, [1]),
            ('mask', InstanceData, torch.rand(4, 4)),
            ('pred_label', LabelData, [1]),
            ('pseudo_label', InstanceData, [1]),
        ]

    def test_init(self):
        # Metainfo passed to the constructor is readable as an attribute.
        meta_info = dict(img_size=[256, 256])
        selfsup_data_sample = SelfSupDataSample(metainfo=meta_info)
        assert 'img_size' in selfsup_data_sample
        assert selfsup_data_sample.img_size == [256, 256]

    def test_setter(self):
        # All fields are set on one shared sample, one after another.
        selfsup_data_sample = SelfSupDataSample()
        for name, element_cls, value in self._field_cases():
            with self.subTest(field=name):
                # setattr goes through the same property setter as
                # ``selfsup_data_sample.<name> = ...``.
                setattr(selfsup_data_sample, name, element_cls(value=value))
                assert name in selfsup_data_sample
                assert _equal(getattr(selfsup_data_sample, name).value, value)

    def test_deleter(self):
        # Each field is checked on a fresh sample, so no other field is
        # present when the deletion is verified.
        for name, element_cls, value in self._field_cases():
            with self.subTest(field=name):
                selfsup_data_sample = SelfSupDataSample()
                setattr(selfsup_data_sample, name, element_cls(value=value))
                assert name in selfsup_data_sample
                # delattr is equivalent to ``del selfsup_data_sample.<name>``.
                delattr(selfsup_data_sample, name)
                assert name not in selfsup_data_sample
|