mmselfsup/tests/test_structures/test_selfsup_data_sample.py

109 lines
4.0 KiB
Python
Raw Normal View History

2022-05-07 09:20:56 +00:00
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import numpy as np
import torch
# TODO: will use real PixelData once it is added in mmengine
from mmengine.structures import InstanceData, LabelData
2022-05-07 09:20:56 +00:00
from mmselfsup.structures import SelfSupDataSample
2022-05-07 09:20:56 +00:00
def _equal(a, b):
if isinstance(a, (torch.Tensor, np.ndarray)):
return (a == b).all()
else:
return a == b
class TestSelfSupDataSample(TestCase):
def test_init(self):
meta_info = dict(img_size=[256, 256])
det_data_sample = SelfSupDataSample(metainfo=meta_info)
assert 'img_size' in det_data_sample
assert det_data_sample.img_size == [256, 256]
def test_setter(self):
selfsup_data_sample = SelfSupDataSample()
# test gt_label
gt_label_data = dict(value=[1])
2022-05-16 11:11:25 +08:00
gt_label = LabelData(**gt_label_data)
2022-05-07 09:20:56 +00:00
selfsup_data_sample.gt_label = gt_label
assert 'gt_label' in selfsup_data_sample
assert _equal(selfsup_data_sample.gt_label.value,
gt_label_data['value'])
# test idx
idx_data = dict(value=[1])
idx_instances = InstanceData(**idx_data)
selfsup_data_sample.sample_idx = idx_instances
assert 'sample_idx' in selfsup_data_sample
assert _equal(selfsup_data_sample.sample_idx.value, idx_data['value'])
2022-05-07 09:20:56 +00:00
# test mask
mask_data = dict(value=torch.rand(4, 4))
mask = InstanceData(**mask_data)
selfsup_data_sample.mask = mask
assert 'mask' in selfsup_data_sample
assert _equal(selfsup_data_sample.mask.value, mask_data['value'])
# test pred_label
pred_label_data = dict(value=[1])
2022-05-16 11:11:25 +08:00
pred_label_instances = LabelData(**pred_label_data)
2022-05-07 09:20:56 +00:00
selfsup_data_sample.pred_label = pred_label_instances
assert 'pred_label' in selfsup_data_sample
assert _equal(selfsup_data_sample.pred_label.value,
pred_label_data['value'])
# test pseudo_label
pseudo_label_data = dict(value=[1])
pseudo_label_instances = InstanceData(**pseudo_label_data)
selfsup_data_sample.pseudo_label = pseudo_label_instances
assert 'pseudo_label' in selfsup_data_sample
assert _equal(selfsup_data_sample.pseudo_label.value,
pseudo_label_data['value'])
2022-05-07 09:20:56 +00:00
def test_deleter(self):
gt_label_data = dict(value=[1])
selfsup_data_sample = SelfSupDataSample()
2022-05-16 11:11:25 +08:00
gt_label = LabelData(value=gt_label_data)
2022-05-07 09:20:56 +00:00
selfsup_data_sample.gt_label = gt_label
assert 'gt_label' in selfsup_data_sample
del selfsup_data_sample.gt_label
assert 'gt_label' not in selfsup_data_sample
idx_data = dict(value=[1])
selfsup_data_sample = SelfSupDataSample()
sample_idx = InstanceData(value=idx_data)
selfsup_data_sample.sample_idx = sample_idx
assert 'sample_idx' in selfsup_data_sample
del selfsup_data_sample.sample_idx
assert 'sample_idx' not in selfsup_data_sample
2022-05-07 09:20:56 +00:00
mask_data = dict(value=torch.rand(4, 4))
selfsup_data_sample = SelfSupDataSample()
mask = InstanceData(**mask_data)
2022-05-07 09:20:56 +00:00
selfsup_data_sample.mask = mask
assert 'mask' in selfsup_data_sample
del selfsup_data_sample.mask
assert 'mask' not in selfsup_data_sample
pred_label_data = dict(value=[1])
selfsup_data_sample = SelfSupDataSample()
2022-05-16 11:11:25 +08:00
pred_label = LabelData(value=pred_label_data)
2022-05-07 09:20:56 +00:00
selfsup_data_sample.pred_label = pred_label
assert 'pred_label' in selfsup_data_sample
del selfsup_data_sample.pred_label
assert 'pred_label' not in selfsup_data_sample
pseudo_label_data = dict(value=[1])
selfsup_data_sample = SelfSupDataSample()
pseudo_label = InstanceData(value=pseudo_label_data)
selfsup_data_sample.pseudo_label = pseudo_label
assert 'pseudo_label' in selfsup_data_sample
del selfsup_data_sample.pseudo_label
assert 'pseudo_label' not in selfsup_data_sample