mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
[Feature] TextRecogSample
This commit is contained in:
parent
c920edfb3a
commit
98bc90bd1c
@ -1,6 +1,6 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from . import evaluation
|
from . import evaluation
|
||||||
from .data_structures import TextDetDataSample
|
from .data_structures import TextDetDataSample, TextRecogDataSample
|
||||||
from .evaluation import * # NOQA
|
from .evaluation import * # NOQA
|
||||||
from .mask import extract_boundary, points2boundary, seg2boundary
|
from .mask import extract_boundary, points2boundary, seg2boundary
|
||||||
from .visualize import (det_recog_show_result, imshow_edge, imshow_node,
|
from .visualize import (det_recog_show_result, imshow_edge, imshow_node,
|
||||||
@ -12,6 +12,7 @@ __all__ = [
|
|||||||
'points2boundary', 'seg2boundary', 'extract_boundary', 'overlay_mask_img',
|
'points2boundary', 'seg2boundary', 'extract_boundary', 'overlay_mask_img',
|
||||||
'show_feature', 'show_img_boundary', 'show_pred_gt',
|
'show_feature', 'show_img_boundary', 'show_pred_gt',
|
||||||
'imshow_pred_boundary', 'imshow_text_char_boundary', 'imshow_text_label',
|
'imshow_pred_boundary', 'imshow_text_char_boundary', 'imshow_text_label',
|
||||||
'imshow_node', 'det_recog_show_result', 'imshow_edge', 'TextDetDataSample'
|
'imshow_node', 'det_recog_show_result', 'imshow_edge', 'TextDetDataSample',
|
||||||
|
'TextRecogDataSample'
|
||||||
]
|
]
|
||||||
__all__ += evaluation.__all__
|
__all__ += evaluation.__all__
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from .textdet_data_sample import TextDetDataSample
|
from .textdet_data_sample import TextDetDataSample
|
||||||
|
from .textrecog_data_element import TextRecogDataSample
|
||||||
|
|
||||||
__all__ = ['TextDetDataSample']
|
__all__ = ['TextDetDataSample', 'TextRecogDataSample']
|
||||||
|
85
mmocr/core/data_structures/textrecog_data_element.py
Normal file
85
mmocr/core/data_structures/textrecog_data_element.py
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
# TODO
|
||||||
|
from mmengine.data import BaseDataElement
|
||||||
|
from mmengine.data import BaseDataElement as LabelData
|
||||||
|
|
||||||
|
|
||||||
|
# TODO score
|
||||||
|
class TextRecogDataSample(BaseDataElement):
|
||||||
|
"""A data structure interface of MMOCR for text recognition. They are used
|
||||||
|
as interfaces between different components.
|
||||||
|
|
||||||
|
The attributes in ``TextRecogDataSample`` are divided into two parts:
|
||||||
|
|
||||||
|
- ``gt_text``(LabelData): Ground truth text.
|
||||||
|
- ``pred_text``(LabelData): predictions text.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> import torch
|
||||||
|
>>> import numpy as np
|
||||||
|
>>> from mmengine.data import LabelData
|
||||||
|
>>> from mmocr.core import TextRecogDataSample
|
||||||
|
>>> # gt_text
|
||||||
|
>>> data_sample = TextRecogDataSample()
|
||||||
|
>>> img_meta = dict(img_shape=(800, 1196, 3),
|
||||||
|
... pad_shape=(800, 1216, 3))
|
||||||
|
>>> gt_text = LabelData(metainfo=img_meta)
|
||||||
|
>>> gt_text.item = 'mmocr'
|
||||||
|
>>> data_sample.gt_text = gt_text
|
||||||
|
>>> assert 'img_shape' in data_sample.gt_text.metainfo_keys()
|
||||||
|
>>> print(data_sample)
|
||||||
|
<TextRecogDataSample(
|
||||||
|
META INFORMATION
|
||||||
|
DATA FIELDS
|
||||||
|
gt_text: <LabelData(
|
||||||
|
META INFORMATION
|
||||||
|
pad_shape: (800, 1216, 3)
|
||||||
|
img_shape: (800, 1196, 3)
|
||||||
|
DATA FIELDS
|
||||||
|
item: 'mmocr'
|
||||||
|
) at 0x7f21fb1b9190>
|
||||||
|
) at 0x7f21fb1b9880>
|
||||||
|
>>> # pred_text
|
||||||
|
>>> pred_text = LabelData(metainfo=img_meta)
|
||||||
|
>>> pred_text.item = 'mmocr'
|
||||||
|
>>> data_sample = TextRecogDataSample(pred_text=pred_text)
|
||||||
|
>>> assert 'pred_text' in data_sample
|
||||||
|
>>> data_sample = TextRecogDataSample()
|
||||||
|
>>> gt_text_data = dict(item='mmocr')
|
||||||
|
>>> gt_text = LabelData(**gt_text_data)
|
||||||
|
>>> data_sample.gt_text = gt_text
|
||||||
|
>>> assert 'gt_text' in data_sample
|
||||||
|
>>> assert 'item' in data_sample.gt_text
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def gt_text(self) -> LabelData:
|
||||||
|
"""LabelData: ground truth text.
|
||||||
|
"""
|
||||||
|
return self._gt_text
|
||||||
|
|
||||||
|
@gt_text.setter
|
||||||
|
def gt_text(self, value: LabelData) -> None:
|
||||||
|
"""gt_text setter."""
|
||||||
|
self.set_field(value, '_gt_text', dtype=LabelData)
|
||||||
|
|
||||||
|
@gt_text.deleter
|
||||||
|
def gt_text(self) -> None:
|
||||||
|
"""gt_text deleter."""
|
||||||
|
del self._gt_text
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pred_text(self) -> LabelData:
|
||||||
|
"""LabelData: prediction text.
|
||||||
|
"""
|
||||||
|
return self._pred_text
|
||||||
|
|
||||||
|
@pred_text.setter
|
||||||
|
def pred_text(self, value: LabelData) -> None:
|
||||||
|
"""pred_text setter."""
|
||||||
|
self.set_field(value, '_pred_text', dtype=LabelData)
|
||||||
|
|
||||||
|
@pred_text.deleter
|
||||||
|
def pred_text(self) -> None:
|
||||||
|
"""pred_text deleter."""
|
||||||
|
del self._pred_text
|
@ -0,0 +1,60 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from unittest import TestCase
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
# TODO
|
||||||
|
from mmengine.data import BaseDataElement as LabelData
|
||||||
|
|
||||||
|
from mmocr.core import TextRecogDataSample
|
||||||
|
|
||||||
|
|
||||||
|
class TestTextRecogDataSample(TestCase):
|
||||||
|
|
||||||
|
def test_init(self):
|
||||||
|
meta_info = dict(
|
||||||
|
img_size=[256, 256],
|
||||||
|
scale_factor=np.array([1.5, 1.5]),
|
||||||
|
img_shape=torch.rand(4))
|
||||||
|
|
||||||
|
recog_data_sample = TextRecogDataSample(metainfo=meta_info)
|
||||||
|
assert 'img_size' in recog_data_sample
|
||||||
|
|
||||||
|
self.assertListEqual(recog_data_sample.img_size, [256, 256])
|
||||||
|
self.assertListEqual(recog_data_sample.get('img_size'), [256, 256])
|
||||||
|
|
||||||
|
def test_setter(self):
|
||||||
|
recog_data_sample = TextRecogDataSample()
|
||||||
|
# test gt_text
|
||||||
|
gt_label_data = dict(item='mmocr')
|
||||||
|
gt_text = LabelData(**gt_label_data)
|
||||||
|
recog_data_sample.gt_text = gt_text
|
||||||
|
assert 'gt_text' in recog_data_sample
|
||||||
|
self.assertEqual(recog_data_sample.gt_text.item, gt_text.item)
|
||||||
|
|
||||||
|
# test pred_text
|
||||||
|
pred_label_data = dict(item='mmocr')
|
||||||
|
pred_text = LabelData(**pred_label_data)
|
||||||
|
recog_data_sample.pred_text = pred_text
|
||||||
|
assert 'pred_text' in recog_data_sample
|
||||||
|
self.assertEqual(recog_data_sample.pred_text.item, pred_text.item)
|
||||||
|
# test type error
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
recog_data_sample.gt_text = torch.rand(2, 4)
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
recog_data_sample.pred_text = torch.rand(2, 4)
|
||||||
|
|
||||||
|
def test_deleter(self):
|
||||||
|
recog_data_sample = TextRecogDataSample()
|
||||||
|
# test gt_text
|
||||||
|
gt_label_data = dict(item='mmocr')
|
||||||
|
gt_text = LabelData(**gt_label_data)
|
||||||
|
recog_data_sample.gt_text = gt_text
|
||||||
|
assert 'gt_text' in recog_data_sample
|
||||||
|
del recog_data_sample.gt_text
|
||||||
|
assert 'gt_text' not in recog_data_sample
|
||||||
|
|
||||||
|
recog_data_sample.pred_text = gt_text
|
||||||
|
assert 'pred_text' in recog_data_sample
|
||||||
|
del recog_data_sample.pred_text
|
||||||
|
assert 'pred_text' not in recog_data_sample
|
Loading…
x
Reference in New Issue
Block a user