mirror of https://github.com/open-mmlab/mmocr.git
[Feature] TextRecogSample
parent
c920edfb3a
commit
98bc90bd1c
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from . import evaluation
|
||||
from .data_structures import TextDetDataSample
|
||||
from .data_structures import TextDetDataSample, TextRecogDataSample
|
||||
from .evaluation import * # NOQA
|
||||
from .mask import extract_boundary, points2boundary, seg2boundary
|
||||
from .visualize import (det_recog_show_result, imshow_edge, imshow_node,
|
||||
|
@ -12,6 +12,7 @@ __all__ = [
|
|||
'points2boundary', 'seg2boundary', 'extract_boundary', 'overlay_mask_img',
|
||||
'show_feature', 'show_img_boundary', 'show_pred_gt',
|
||||
'imshow_pred_boundary', 'imshow_text_char_boundary', 'imshow_text_label',
|
||||
'imshow_node', 'det_recog_show_result', 'imshow_edge', 'TextDetDataSample'
|
||||
'imshow_node', 'det_recog_show_result', 'imshow_edge', 'TextDetDataSample',
|
||||
'TextRecogDataSample'
|
||||
]
|
||||
__all__ += evaluation.__all__
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .textdet_data_sample import TextDetDataSample
|
||||
from .textrecog_data_element import TextRecogDataSample
|
||||
|
||||
__all__ = ['TextDetDataSample']
|
||||
__all__ = ['TextDetDataSample', 'TextRecogDataSample']
|
||||
|
|
|
@ -0,0 +1,85 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# TODO
|
||||
from mmengine.data import BaseDataElement
|
||||
from mmengine.data import BaseDataElement as LabelData
|
||||
|
||||
|
||||
# TODO score
|
||||
class TextRecogDataSample(BaseDataElement):
|
||||
"""A data structure interface of MMOCR for text recognition. They are used
|
||||
as interfaces between different components.
|
||||
|
||||
The attributes in ``TextRecogDataSample`` are divided into two parts:
|
||||
|
||||
- ``gt_text``(LabelData): Ground truth text.
|
||||
- ``pred_text``(LabelData): predictions text.
|
||||
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>> import numpy as np
|
||||
>>> from mmengine.data import LabelData
|
||||
>>> from mmocr.core import TextRecogDataSample
|
||||
>>> # gt_text
|
||||
>>> data_sample = TextRecogDataSample()
|
||||
>>> img_meta = dict(img_shape=(800, 1196, 3),
|
||||
... pad_shape=(800, 1216, 3))
|
||||
>>> gt_text = LabelData(metainfo=img_meta)
|
||||
>>> gt_text.item = 'mmocr'
|
||||
>>> data_sample.gt_text = gt_text
|
||||
>>> assert 'img_shape' in data_sample.gt_text.metainfo_keys()
|
||||
>>> print(data_sample)
|
||||
<TextRecogDataSample(
|
||||
META INFORMATION
|
||||
DATA FIELDS
|
||||
gt_text: <LabelData(
|
||||
META INFORMATION
|
||||
pad_shape: (800, 1216, 3)
|
||||
img_shape: (800, 1196, 3)
|
||||
DATA FIELDS
|
||||
item: 'mmocr'
|
||||
) at 0x7f21fb1b9190>
|
||||
) at 0x7f21fb1b9880>
|
||||
>>> # pred_text
|
||||
>>> pred_text = LabelData(metainfo=img_meta)
|
||||
>>> pred_text.item = 'mmocr'
|
||||
>>> data_sample = TextRecogDataSample(pred_text=pred_text)
|
||||
>>> assert 'pred_text' in data_sample
|
||||
>>> data_sample = TextRecogDataSample()
|
||||
>>> gt_text_data = dict(item='mmocr')
|
||||
>>> gt_text = LabelData(**gt_text_data)
|
||||
>>> data_sample.gt_text = gt_text
|
||||
>>> assert 'gt_text' in data_sample
|
||||
>>> assert 'item' in data_sample.gt_text
|
||||
"""
|
||||
|
||||
@property
|
||||
def gt_text(self) -> LabelData:
|
||||
"""LabelData: ground truth text.
|
||||
"""
|
||||
return self._gt_text
|
||||
|
||||
@gt_text.setter
|
||||
def gt_text(self, value: LabelData) -> None:
|
||||
"""gt_text setter."""
|
||||
self.set_field(value, '_gt_text', dtype=LabelData)
|
||||
|
||||
@gt_text.deleter
|
||||
def gt_text(self) -> None:
|
||||
"""gt_text deleter."""
|
||||
del self._gt_text
|
||||
|
||||
@property
|
||||
def pred_text(self) -> LabelData:
|
||||
"""LabelData: prediction text.
|
||||
"""
|
||||
return self._pred_text
|
||||
|
||||
@pred_text.setter
|
||||
def pred_text(self, value: LabelData) -> None:
|
||||
"""pred_text setter."""
|
||||
self.set_field(value, '_pred_text', dtype=LabelData)
|
||||
|
||||
@pred_text.deleter
|
||||
def pred_text(self) -> None:
|
||||
"""pred_text deleter."""
|
||||
del self._pred_text
|
|
@ -0,0 +1,60 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest import TestCase
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
# TODO
|
||||
from mmengine.data import BaseDataElement as LabelData
|
||||
|
||||
from mmocr.core import TextRecogDataSample
|
||||
|
||||
|
||||
class TestTextRecogDataSample(TestCase):
|
||||
|
||||
def test_init(self):
|
||||
meta_info = dict(
|
||||
img_size=[256, 256],
|
||||
scale_factor=np.array([1.5, 1.5]),
|
||||
img_shape=torch.rand(4))
|
||||
|
||||
recog_data_sample = TextRecogDataSample(metainfo=meta_info)
|
||||
assert 'img_size' in recog_data_sample
|
||||
|
||||
self.assertListEqual(recog_data_sample.img_size, [256, 256])
|
||||
self.assertListEqual(recog_data_sample.get('img_size'), [256, 256])
|
||||
|
||||
def test_setter(self):
|
||||
recog_data_sample = TextRecogDataSample()
|
||||
# test gt_text
|
||||
gt_label_data = dict(item='mmocr')
|
||||
gt_text = LabelData(**gt_label_data)
|
||||
recog_data_sample.gt_text = gt_text
|
||||
assert 'gt_text' in recog_data_sample
|
||||
self.assertEqual(recog_data_sample.gt_text.item, gt_text.item)
|
||||
|
||||
# test pred_text
|
||||
pred_label_data = dict(item='mmocr')
|
||||
pred_text = LabelData(**pred_label_data)
|
||||
recog_data_sample.pred_text = pred_text
|
||||
assert 'pred_text' in recog_data_sample
|
||||
self.assertEqual(recog_data_sample.pred_text.item, pred_text.item)
|
||||
# test type error
|
||||
with self.assertRaises(AssertionError):
|
||||
recog_data_sample.gt_text = torch.rand(2, 4)
|
||||
with self.assertRaises(AssertionError):
|
||||
recog_data_sample.pred_text = torch.rand(2, 4)
|
||||
|
||||
def test_deleter(self):
|
||||
recog_data_sample = TextRecogDataSample()
|
||||
# test gt_text
|
||||
gt_label_data = dict(item='mmocr')
|
||||
gt_text = LabelData(**gt_label_data)
|
||||
recog_data_sample.gt_text = gt_text
|
||||
assert 'gt_text' in recog_data_sample
|
||||
del recog_data_sample.gt_text
|
||||
assert 'gt_text' not in recog_data_sample
|
||||
|
||||
recog_data_sample.pred_text = gt_text
|
||||
assert 'pred_text' in recog_data_sample
|
||||
del recog_data_sample.pred_text
|
||||
assert 'pred_text' not in recog_data_sample
|
Loading…
Reference in New Issue