mirror of https://github.com/open-mmlab/mmocr.git
60 lines
2.1 KiB
Python
60 lines
2.1 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from unittest import TestCase
|
|
|
|
import numpy as np
|
|
import torch
|
|
from mmengine.structures import LabelData
|
|
|
|
from mmocr.structures import TextRecogDataSample
|
|
|
|
|
|
class TestTextRecogDataSample(TestCase):
    """Unit tests for ``TextRecogDataSample``.

    Covers construction from ``metainfo`` and the ``gt_text`` / ``pred_text``
    setters and deleters.

    Membership checks use ``assertIn`` / ``assertNotIn`` rather than bare
    ``assert`` statements, so they still run under ``python -O`` and report
    the offending container when they fail.
    """

    def test_init(self):
        """Metainfo passed at construction is readable afterwards."""
        meta_info = dict(
            img_size=[256, 256],
            scale_factor=np.array([1.5, 1.5]),
            img_shape=torch.rand(4))

        recog_data_sample = TextRecogDataSample(metainfo=meta_info)
        self.assertIn('img_size', recog_data_sample)

        # The same value must be reachable by attribute access and by get().
        self.assertListEqual(recog_data_sample.img_size, [256, 256])
        self.assertListEqual(recog_data_sample.get('img_size'), [256, 256])

    def test_setter(self):
        """``gt_text`` / ``pred_text`` accept ``LabelData`` and reject tensors."""
        recog_data_sample = TextRecogDataSample()

        # test gt_text
        gt_label_data = dict(item='mmocr')
        gt_text = LabelData(**gt_label_data)
        recog_data_sample.gt_text = gt_text
        self.assertIn('gt_text', recog_data_sample)
        self.assertEqual(recog_data_sample.gt_text.item, gt_text.item)

        # test pred_text
        pred_label_data = dict(item='mmocr')
        pred_text = LabelData(**pred_label_data)
        recog_data_sample.pred_text = pred_text
        self.assertIn('pred_text', recog_data_sample)
        self.assertEqual(recog_data_sample.pred_text.item, pred_text.item)

        # test type error: assigning a plain tensor is expected to raise
        # AssertionError for both fields.
        with self.assertRaises(AssertionError):
            recog_data_sample.gt_text = torch.rand(2, 4)
        with self.assertRaises(AssertionError):
            recog_data_sample.pred_text = torch.rand(2, 4)

    def test_deleter(self):
        """Deleting ``gt_text`` / ``pred_text`` removes the field."""
        recog_data_sample = TextRecogDataSample()

        # test gt_text
        gt_label_data = dict(item='mmocr')
        gt_text = LabelData(**gt_label_data)
        recog_data_sample.gt_text = gt_text
        self.assertIn('gt_text', recog_data_sample)
        del recog_data_sample.gt_text
        self.assertNotIn('gt_text', recog_data_sample)

        # test pred_text (the same LabelData instance is reused as the value)
        recog_data_sample.pred_text = gt_text
        self.assertIn('pred_text', recog_data_sample)
        del recog_data_sample.pred_text
        self.assertNotIn('pred_text', recog_data_sample)
|