From 98bc90bd1c0a8d38d5d564f851a106648fba693a Mon Sep 17 00:00:00 2001 From: liukuikun <641417025@qq.com> Date: Thu, 12 May 2022 22:24:42 +0800 Subject: [PATCH] [Feature] TextRecogSample --- mmocr/core/__init__.py | 5 +- mmocr/core/data_structures/__init__.py | 3 +- .../data_structures/textrecog_data_element.py | 85 +++++++++++++++++++ .../test_textrecog_data_sample.py | 60 +++++++++++++ 4 files changed, 150 insertions(+), 3 deletions(-) create mode 100644 mmocr/core/data_structures/textrecog_data_element.py create mode 100644 tests/test_core/test_data_structures/test_textrecog_data_sample.py diff --git a/mmocr/core/__init__.py b/mmocr/core/__init__.py index ab3b67e4..aa66ea99 100644 --- a/mmocr/core/__init__.py +++ b/mmocr/core/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from . import evaluation -from .data_structures import TextDetDataSample +from .data_structures import TextDetDataSample, TextRecogDataSample from .evaluation import * # NOQA from .mask import extract_boundary, points2boundary, seg2boundary from .visualize import (det_recog_show_result, imshow_edge, imshow_node, @@ -12,6 +12,7 @@ __all__ = [ 'points2boundary', 'seg2boundary', 'extract_boundary', 'overlay_mask_img', 'show_feature', 'show_img_boundary', 'show_pred_gt', 'imshow_pred_boundary', 'imshow_text_char_boundary', 'imshow_text_label', - 'imshow_node', 'det_recog_show_result', 'imshow_edge', 'TextDetDataSample' + 'imshow_node', 'det_recog_show_result', 'imshow_edge', 'TextDetDataSample', + 'TextRecogDataSample' ] __all__ += evaluation.__all__ diff --git a/mmocr/core/data_structures/__init__.py b/mmocr/core/data_structures/__init__.py index 019154ae..aa6c5006 100644 --- a/mmocr/core/data_structures/__init__.py +++ b/mmocr/core/data_structures/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. from .textdet_data_sample import TextDetDataSample +from .textrecog_data_element import TextRecogDataSample -__all__ = ['TextDetDataSample'] +__all__ = ['TextDetDataSample', 'TextRecogDataSample'] diff --git a/mmocr/core/data_structures/textrecog_data_element.py b/mmocr/core/data_structures/textrecog_data_element.py new file mode 100644 index 00000000..6af8efaf --- /dev/null +++ b/mmocr/core/data_structures/textrecog_data_element.py @@ -0,0 +1,85 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# TODO +from mmengine.data import BaseDataElement +from mmengine.data import BaseDataElement as LabelData + + +# TODO score +class TextRecogDataSample(BaseDataElement): + """A data structure interface of MMOCR for text recognition. They are used + as interfaces between different components. + + The attributes in ``TextRecogDataSample`` are divided into two parts: + + - ``gt_text``(LabelData): Ground truth text. + - ``pred_text``(LabelData): predictions text. + + Examples: + >>> import torch + >>> import numpy as np + >>> from mmengine.data import LabelData + >>> from mmocr.core import TextRecogDataSample + >>> # gt_text + >>> data_sample = TextRecogDataSample() + >>> img_meta = dict(img_shape=(800, 1196, 3), + ... pad_shape=(800, 1216, 3)) + >>> gt_text = LabelData(metainfo=img_meta) + >>> gt_text.item = 'mmocr' + >>> data_sample.gt_text = gt_text + >>> assert 'img_shape' in data_sample.gt_text.metainfo_keys() + >>> print(data_sample) + + ) at 0x7f21fb1b9880> + >>> # pred_text + >>> pred_text = LabelData(metainfo=img_meta) + >>> pred_text.item = 'mmocr' + >>> data_sample = TextRecogDataSample(pred_text=pred_text) + >>> assert 'pred_text' in data_sample + >>> data_sample = TextRecogDataSample() + >>> gt_text_data = dict(item='mmocr') + >>> gt_text = LabelData(**gt_text_data) + >>> data_sample.gt_text = gt_text + >>> assert 'gt_text' in data_sample + >>> assert 'item' in data_sample.gt_text + """ + + @property + def gt_text(self) -> LabelData: + """LabelData: ground truth text. + """ + return self._gt_text + + @gt_text.setter + def gt_text(self, value: LabelData) -> None: + """gt_text setter.""" + self.set_field(value, '_gt_text', dtype=LabelData) + + @gt_text.deleter + def gt_text(self) -> None: + """gt_text deleter.""" + del self._gt_text + + @property + def pred_text(self) -> LabelData: + """LabelData: prediction text. + """ + return self._pred_text + + @pred_text.setter + def pred_text(self, value: LabelData) -> None: + """pred_text setter.""" + self.set_field(value, '_pred_text', dtype=LabelData) + + @pred_text.deleter + def pred_text(self) -> None: + """pred_text deleter.""" + del self._pred_text diff --git a/tests/test_core/test_data_structures/test_textrecog_data_sample.py b/tests/test_core/test_data_structures/test_textrecog_data_sample.py new file mode 100644 index 00000000..756aa217 --- /dev/null +++ b/tests/test_core/test_data_structures/test_textrecog_data_sample.py @@ -0,0 +1,60 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import numpy as np +import torch +# TODO +from mmengine.data import BaseDataElement as LabelData + +from mmocr.core import TextRecogDataSample + + +class TestTextRecogDataSample(TestCase): + + def test_init(self): + meta_info = dict( + img_size=[256, 256], + scale_factor=np.array([1.5, 1.5]), + img_shape=torch.rand(4)) + + recog_data_sample = TextRecogDataSample(metainfo=meta_info) + assert 'img_size' in recog_data_sample + + self.assertListEqual(recog_data_sample.img_size, [256, 256]) + self.assertListEqual(recog_data_sample.get('img_size'), [256, 256]) + + def test_setter(self): + recog_data_sample = TextRecogDataSample() + # test gt_text + gt_label_data = dict(item='mmocr') + gt_text = LabelData(**gt_label_data) + recog_data_sample.gt_text = gt_text + assert 'gt_text' in recog_data_sample + self.assertEqual(recog_data_sample.gt_text.item, gt_text.item) + + # test pred_text + pred_label_data = dict(item='mmocr') + pred_text = LabelData(**pred_label_data) + recog_data_sample.pred_text = pred_text + assert 'pred_text' in recog_data_sample + self.assertEqual(recog_data_sample.pred_text.item, pred_text.item) + # test type error + with self.assertRaises(AssertionError): + recog_data_sample.gt_text = torch.rand(2, 4) + with self.assertRaises(AssertionError): + recog_data_sample.pred_text = torch.rand(2, 4) + + def test_deleter(self): + recog_data_sample = TextRecogDataSample() + # test gt_text + gt_label_data = dict(item='mmocr') + gt_text = LabelData(**gt_label_data) + recog_data_sample.gt_text = gt_text + assert 'gt_text' in recog_data_sample + del recog_data_sample.gt_text + assert 'gt_text' not in recog_data_sample + + recog_data_sample.pred_text = gt_text + assert 'pred_text' in recog_data_sample + del recog_data_sample.pred_text + assert 'pred_text' not in recog_data_sample