[Feature] TextRecogSample

pull/1178/head
liukuikun 2022-05-12 22:24:42 +08:00 committed by gaotongxiao
parent c920edfb3a
commit 98bc90bd1c
4 changed files with 150 additions and 3 deletions

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import evaluation
from .data_structures import TextDetDataSample
from .data_structures import TextDetDataSample, TextRecogDataSample
from .evaluation import * # NOQA
from .mask import extract_boundary, points2boundary, seg2boundary
from .visualize import (det_recog_show_result, imshow_edge, imshow_node,
@ -12,6 +12,7 @@ __all__ = [
'points2boundary', 'seg2boundary', 'extract_boundary', 'overlay_mask_img',
'show_feature', 'show_img_boundary', 'show_pred_gt',
'imshow_pred_boundary', 'imshow_text_char_boundary', 'imshow_text_label',
'imshow_node', 'det_recog_show_result', 'imshow_edge', 'TextDetDataSample'
'imshow_node', 'det_recog_show_result', 'imshow_edge', 'TextDetDataSample',
'TextRecogDataSample'
]
__all__ += evaluation.__all__

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .textdet_data_sample import TextDetDataSample
from .textrecog_data_element import TextRecogDataSample
__all__ = ['TextDetDataSample']
__all__ = ['TextDetDataSample', 'TextRecogDataSample']

View File

@ -0,0 +1,85 @@
# Copyright (c) OpenMMLab. All rights reserved.
# TODO
from mmengine.data import BaseDataElement
from mmengine.data import BaseDataElement as LabelData
# TODO score
class TextRecogDataSample(BaseDataElement):
"""A data structure interface of MMOCR for text recognition. They are used
as interfaces between different components.
The attributes in ``TextRecogDataSample`` are divided into two parts:
- ``gt_text``(LabelData): Ground truth text.
- ``pred_text``(LabelData): predictions text.
Examples:
>>> import torch
>>> import numpy as np
>>> from mmengine.data import LabelData
>>> from mmocr.core import TextRecogDataSample
>>> # gt_text
>>> data_sample = TextRecogDataSample()
>>> img_meta = dict(img_shape=(800, 1196, 3),
... pad_shape=(800, 1216, 3))
>>> gt_text = LabelData(metainfo=img_meta)
>>> gt_text.item = 'mmocr'
>>> data_sample.gt_text = gt_text
>>> assert 'img_shape' in data_sample.gt_text.metainfo_keys()
>>> print(data_sample)
<TextRecogDataSample(
META INFORMATION
DATA FIELDS
gt_text: <LabelData(
META INFORMATION
pad_shape: (800, 1216, 3)
img_shape: (800, 1196, 3)
DATA FIELDS
item: 'mmocr'
) at 0x7f21fb1b9190>
) at 0x7f21fb1b9880>
>>> # pred_text
>>> pred_text = LabelData(metainfo=img_meta)
>>> pred_text.item = 'mmocr'
>>> data_sample = TextRecogDataSample(pred_text=pred_text)
>>> assert 'pred_text' in data_sample
>>> data_sample = TextRecogDataSample()
>>> gt_text_data = dict(item='mmocr')
>>> gt_text = LabelData(**gt_text_data)
>>> data_sample.gt_text = gt_text
>>> assert 'gt_text' in data_sample
>>> assert 'item' in data_sample.gt_text
"""
@property
def gt_text(self) -> LabelData:
"""LabelData: ground truth text.
"""
return self._gt_text
@gt_text.setter
def gt_text(self, value: LabelData) -> None:
"""gt_text setter."""
self.set_field(value, '_gt_text', dtype=LabelData)
@gt_text.deleter
def gt_text(self) -> None:
"""gt_text deleter."""
del self._gt_text
@property
def pred_text(self) -> LabelData:
"""LabelData: prediction text.
"""
return self._pred_text
@pred_text.setter
def pred_text(self, value: LabelData) -> None:
"""pred_text setter."""
self.set_field(value, '_pred_text', dtype=LabelData)
@pred_text.deleter
def pred_text(self) -> None:
"""pred_text deleter."""
del self._pred_text

View File

@ -0,0 +1,60 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import numpy as np
import torch
# TODO
from mmengine.data import BaseDataElement as LabelData
from mmocr.core import TextRecogDataSample
class TestTextRecogDataSample(TestCase):
def test_init(self):
meta_info = dict(
img_size=[256, 256],
scale_factor=np.array([1.5, 1.5]),
img_shape=torch.rand(4))
recog_data_sample = TextRecogDataSample(metainfo=meta_info)
assert 'img_size' in recog_data_sample
self.assertListEqual(recog_data_sample.img_size, [256, 256])
self.assertListEqual(recog_data_sample.get('img_size'), [256, 256])
def test_setter(self):
recog_data_sample = TextRecogDataSample()
# test gt_text
gt_label_data = dict(item='mmocr')
gt_text = LabelData(**gt_label_data)
recog_data_sample.gt_text = gt_text
assert 'gt_text' in recog_data_sample
self.assertEqual(recog_data_sample.gt_text.item, gt_text.item)
# test pred_text
pred_label_data = dict(item='mmocr')
pred_text = LabelData(**pred_label_data)
recog_data_sample.pred_text = pred_text
assert 'pred_text' in recog_data_sample
self.assertEqual(recog_data_sample.pred_text.item, pred_text.item)
# test type error
with self.assertRaises(AssertionError):
recog_data_sample.gt_text = torch.rand(2, 4)
with self.assertRaises(AssertionError):
recog_data_sample.pred_text = torch.rand(2, 4)
def test_deleter(self):
recog_data_sample = TextRecogDataSample()
# test gt_text
gt_label_data = dict(item='mmocr')
gt_text = LabelData(**gt_label_data)
recog_data_sample.gt_text = gt_text
assert 'gt_text' in recog_data_sample
del recog_data_sample.gt_text
assert 'gt_text' not in recog_data_sample
recog_data_sample.pred_text = gt_text
assert 'pred_text' in recog_data_sample
del recog_data_sample.pred_text
assert 'pred_text' not in recog_data_sample