mmpretrain/tests/test_structures/test_datasample.py

125 lines
4.5 KiB
Python
Raw Normal View History

2022-05-07 18:01:08 +08:00
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import numpy as np
import torch
from mmengine.structures import LabelData
2022-05-07 18:01:08 +08:00
from mmcls.structures import ClsDataSample
2022-05-07 18:01:08 +08:00
class TestClsDataSample(TestCase):
    """Unit tests for the label/score setters and deleters of ClsDataSample.

    The ``gt_*`` and ``pred_*`` variants share one parameterized helper each,
    so both variants are always checked against exactly the same cases.
    """

    def _get_label(self, data_sample, key):
        """Assert ``key`` is present as a ``LabelData`` and return it."""
        self.assertIn(key, data_sample)
        label = getattr(data_sample, key)
        self.assertIsInstance(label, LabelData)
        return label

    def _assert_score_close(self, actual, expected):
        """Assert tensor ``actual`` matches the list ``expected``.

        ``expected`` is converted with ``actual``'s dtype, mirroring what the
        deprecated ``torch.testing.assert_allclose`` did for list inputs, so
        the comparison does not fail on a dtype mismatch.
        """
        expected = torch.tensor(expected, dtype=actual.dtype)
        self.assertEqual(actual.shape, expected.shape)
        self.assertTrue(
            torch.allclose(actual, expected),
            msg='score {} does not match {}'.format(actual, expected))

    def _test_set_label(self, key):
        """Exercise ``set_<key>`` with every supported input type.

        Args:
            key (str): Label field name, ``'gt_label'`` or ``'pred_label'``.
        """
        data_sample = ClsDataSample()
        method = getattr(data_sample, 'set_' + key)

        # Scalars given as int, 0-d tensor and 0-d array.
        for value in (1, torch.tensor(2), np.array(3)):
            with self.subTest(value=value):
                method(value)
                label = self._get_label(data_sample, key)
                self.assertIsInstance(label.label, torch.LongTensor)

        # Multi-element inputs given as tensor, array and plain Sequence.
        for value in (torch.tensor([1, 2, 3]), np.array([1, 2, 3]),
                      [1, 2, 3]):
            with self.subTest(value=value):
                method(value)
                label = self._get_label(data_sample, key)
                self.assertIsInstance(label.label, torch.Tensor)
                self.assertTrue(
                    (label.label == torch.tensor([1, 2, 3])).all())

        # Test unavailable type
        with self.assertRaisesRegex(TypeError, "<class 'str'> is not"):
            method('hi')

    def test_set_gt_label(self):
        self._test_set_label('gt_label')

    def test_set_pred_label(self):
        self._test_set_label('pred_label')

    def _test_del_label(self, key):
        """Check that ``del data_sample.<key>`` removes a previously set label.

        Args:
            key (str): Label field name, ``'gt_label'`` or ``'pred_label'``.
        """
        data_sample = ClsDataSample()
        self.assertNotIn(key, data_sample)
        getattr(data_sample, 'set_' + key)(1)
        self.assertIn(key, data_sample)
        delattr(data_sample, key)  # same as ``del data_sample.<key>``
        self.assertNotIn(key, data_sample)

    def test_del_gt_label(self):
        self._test_del_label('gt_label')

    def test_del_pred_label(self):
        self._test_del_label('pred_label')

    def _test_set_score(self, prefix):
        """Exercise ``set_<prefix>_score`` on ``<prefix>_label``.

        Every call, including the invalid ones, goes through the setter named
        by ``prefix``. This keeps the gt and pred variants from drifting, as
        happened when the pred test checked ``set_gt_score`` for invalid
        length.

        Args:
            prefix (str): ``'gt'`` or ``'pred'``.
        """
        data_sample = ClsDataSample()
        method = getattr(data_sample, 'set_' + prefix + '_score')
        key = prefix + '_label'

        method(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1]))
        self.assertIn('score', getattr(data_sample, key))
        self._assert_score_close(
            getattr(data_sample, key).score, [0.1, 0.1, 0.6, 0.1, 0.1])

        # Test set again: the existing label's score is overwritten.
        method(torch.tensor([0.2, 0.1, 0.5, 0.1, 0.1]))
        self._assert_score_close(
            getattr(data_sample, key).score, [0.2, 0.1, 0.5, 0.1, 0.1])

        # Test invalid length
        with self.assertRaisesRegex(AssertionError, 'should be equal to'):
            method([1, 2])

        # Test invalid dims
        with self.assertRaisesRegex(AssertionError, 'but got 2'):
            method(torch.tensor([[0.1, 0.1, 0.6, 0.1, 0.1]]))

    def test_set_gt_score(self):
        self._test_set_score('gt')

    def test_set_pred_score(self):
        self._test_set_score('pred')