mmsegmentation/tests/test_data/test_seg_data_sample.py

78 lines
2.5 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import numpy as np
import pytest
import torch
from mmengine.data import PixelData
from mmseg.structures import SegDataSample
def _equal(a, b):
if isinstance(a, (torch.Tensor, np.ndarray)):
return (a == b).all()
else:
return a == b
class TestSegDataSample(TestCase):
def test_init(self):
meta_info = dict(
img_size=[256, 256],
scale_factor=np.array([1.5, 1.5]),
img_shape=torch.rand(4))
seg_data_sample = SegDataSample(metainfo=meta_info)
assert 'img_size' in seg_data_sample
assert seg_data_sample.img_size == [256, 256]
assert seg_data_sample.get('img_size') == [256, 256]
def test_setter(self):
seg_data_sample = SegDataSample()
# test gt_sem_seg
gt_sem_seg_data = dict(sem_seg=torch.rand(5, 4, 2))
gt_sem_seg = PixelData(**gt_sem_seg_data)
seg_data_sample.gt_sem_seg = gt_sem_seg
assert 'gt_sem_seg' in seg_data_sample
assert _equal(seg_data_sample.gt_sem_seg.sem_seg,
gt_sem_seg_data['sem_seg'])
# test pred_sem_seg
pred_sem_seg_data = dict(sem_seg=torch.rand(5, 4, 2))
pred_sem_seg = PixelData(**pred_sem_seg_data)
seg_data_sample.pred_sem_seg = pred_sem_seg
assert 'pred_sem_seg' in seg_data_sample
assert _equal(seg_data_sample.pred_sem_seg.sem_seg,
pred_sem_seg_data['sem_seg'])
# test seg_logits
seg_logits_data = dict(sem_seg=torch.rand(5, 4, 2))
seg_logits = PixelData(**seg_logits_data)
seg_data_sample.seg_logits = seg_logits
assert 'seg_logits' in seg_data_sample
assert _equal(seg_data_sample.seg_logits.sem_seg,
seg_logits_data['sem_seg'])
# test type error
with pytest.raises(AssertionError):
seg_data_sample.gt_sem_seg = torch.rand(2, 4)
with pytest.raises(AssertionError):
seg_data_sample.pred_sem_seg = torch.rand(2, 4)
with pytest.raises(AssertionError):
seg_data_sample.seg_logits = torch.rand(2, 4)
def test_deleter(self):
seg_data_sample = SegDataSample()
pred_sem_seg_data = dict(sem_seg=torch.rand(5, 4, 2))
pred_sem_seg = PixelData(**pred_sem_seg_data)
seg_data_sample.pred_sem_seg = pred_sem_seg
assert 'pred_sem_seg' in seg_data_sample
del seg_data_sample.pred_sem_seg
assert 'pred_sem_seg' not in seg_data_sample