From c47c5711c1943eb0d4112bd781fbea2d34d6cedf Mon Sep 17 00:00:00 2001 From: liukuikun <641417025@qq.com> Date: Thu, 12 May 2022 11:55:43 +0000 Subject: [PATCH] [Feature] TextDetSample --- mmocr/core/__init__.py | 3 +- mmocr/core/data_structures/__init__.py | 4 + .../data_structures/textdet_data_sample.py | 93 +++++++++++++++++++ .../test_textdet_data_sample.py | 85 +++++++++++++++++ 4 files changed, 184 insertions(+), 1 deletion(-) create mode 100644 mmocr/core/data_structures/__init__.py create mode 100644 mmocr/core/data_structures/textdet_data_sample.py create mode 100644 tests/test_core/test_data_structures/test_textdet_data_sample.py diff --git a/mmocr/core/__init__.py b/mmocr/core/__init__.py index beae1ba4..ab3b67e4 100644 --- a/mmocr/core/__init__.py +++ b/mmocr/core/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from . import evaluation +from .data_structures import TextDetDataSample from .evaluation import * # NOQA from .mask import extract_boundary, points2boundary, seg2boundary from .visualize import (det_recog_show_result, imshow_edge, imshow_node, @@ -11,6 +12,6 @@ __all__ = [ 'points2boundary', 'seg2boundary', 'extract_boundary', 'overlay_mask_img', 'show_feature', 'show_img_boundary', 'show_pred_gt', 'imshow_pred_boundary', 'imshow_text_char_boundary', 'imshow_text_label', - 'imshow_node', 'det_recog_show_result', 'imshow_edge' + 'imshow_node', 'det_recog_show_result', 'imshow_edge', 'TextDetDataSample' ] __all__ += evaluation.__all__ diff --git a/mmocr/core/data_structures/__init__.py b/mmocr/core/data_structures/__init__.py new file mode 100644 index 00000000..019154ae --- /dev/null +++ b/mmocr/core/data_structures/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .textdet_data_sample import TextDetDataSample + +__all__ = ['TextDetDataSample'] diff --git a/mmocr/core/data_structures/textdet_data_sample.py b/mmocr/core/data_structures/textdet_data_sample.py new file mode 100644 index 00000000..b99a8330 --- /dev/null +++ b/mmocr/core/data_structures/textdet_data_sample.py @@ -0,0 +1,93 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.data import BaseDataElement, InstanceData + + +class TextDetDataSample(BaseDataElement): + """A data structure interface of MMOCR. It is used as an interface between + different components. + + The attributes in ``TextDetDataSample`` are divided into two parts: + + - ``gt_instances`` (InstanceData): Ground truth of instance annotations. + - ``pred_instances`` (InstanceData): Instances of model predictions. + + Examples: + >>> import torch + >>> import numpy as np + >>> from mmengine.data import InstanceData + >>> from mmocr.core import TextDetDataSample + >>> # gt_instances + >>> data_sample = TextDetDataSample() + >>> img_meta = dict(img_shape=(800, 1196, 3), + ... pad_shape=(800, 1216, 3)) + >>> gt_instances = InstanceData(metainfo=img_meta) + >>> gt_instances.bboxes = torch.rand((5, 4)) + >>> gt_instances.labels = torch.rand((5,)) + >>> data_sample.gt_instances = gt_instances + >>> assert 'img_shape' in data_sample.gt_instances.metainfo_keys() + >>> len(data_sample.gt_instances) + 5 + >>> print(data_sample) + <TextDetDataSample( + ... + ) at 0x7f21fb1b9880> + >>> # pred_instances + >>> pred_instances = InstanceData(metainfo=img_meta) + >>> pred_instances.bboxes = torch.rand((5, 4)) + >>> pred_instances.scores = torch.rand((5,)) + >>> data_sample = TextDetDataSample(pred_instances=pred_instances) + >>> assert 'pred_instances' in data_sample + >>> data_sample = TextDetDataSample() + >>> gt_instances_data = dict( + ... bboxes=torch.rand(2, 4), + ... labels=torch.rand(2), + ... 
masks=np.random.rand(2, 2, 2)) + >>> gt_instances = InstanceData(**gt_instances_data) + >>> data_sample.gt_instances = gt_instances + >>> assert 'gt_instances' in data_sample + >>> assert 'masks' in data_sample.gt_instances + """ + + @property + def gt_instances(self) -> InstanceData: + """InstanceData: groundtruth instances.""" + return self._gt_instances + + @gt_instances.setter + def gt_instances(self, value: InstanceData): + """gt_instances setter.""" + self.set_field(value, '_gt_instances', dtype=InstanceData) + + @gt_instances.deleter + def gt_instances(self): + """gt_instances deleter.""" + del self._gt_instances + + @property + def pred_instances(self) -> InstanceData: + """InstanceData: prediction instances.""" + return self._pred_instances + + @pred_instances.setter + def pred_instances(self, value: InstanceData): + """pred_instances setter.""" + self.set_field(value, '_pred_instances', dtype=InstanceData) + + @pred_instances.deleter + def pred_instances(self): + """pred_instances deleter.""" + del self._pred_instances diff --git a/tests/test_core/test_data_structures/test_textdet_data_sample.py b/tests/test_core/test_data_structures/test_textdet_data_sample.py new file mode 100644 index 00000000..22fe0393 --- /dev/null +++ b/tests/test_core/test_data_structures/test_textdet_data_sample.py @@ -0,0 +1,85 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from unittest import TestCase + +import numpy as np +import torch +from mmengine.data import InstanceData + +from mmocr.core import TextDetDataSample + + +class TestTextDetDataSample(TestCase): + + def _equal(self, a, b): + if isinstance(a, (torch.Tensor, np.ndarray)): + return (a == b).all() + else: + return a == b + + def test_init(self): + meta_info = dict( + img_size=[256, 256], + scale_factor=np.array([1.5, 1.5]), + img_shape=torch.rand(4)) + + det_data_sample = TextDetDataSample(metainfo=meta_info) + assert 'img_size' in det_data_sample + + self.assertListEqual(det_data_sample.img_size, [256, 256]) + self.assertListEqual(det_data_sample.get('img_size'), [256, 256]) + + def test_setter(self): + det_data_sample = TextDetDataSample() + # test gt_instances + gt_instances_data = dict( + bboxes=torch.rand(4, 4), + labels=torch.rand(4), + masks=np.random.rand(4, 2, 2)) + gt_instances = InstanceData(**gt_instances_data) + det_data_sample.gt_instances = gt_instances + assert 'gt_instances' in det_data_sample + assert self._equal(det_data_sample.gt_instances.bboxes, + gt_instances_data['bboxes']) + assert self._equal(det_data_sample.gt_instances.labels, + gt_instances_data['labels']) + assert self._equal(det_data_sample.gt_instances.masks, + gt_instances_data['masks']) + + # test pred_instances + pred_instances_data = dict( + bboxes=torch.rand(2, 4), + labels=torch.rand(2), + masks=np.random.rand(2, 2, 2)) + pred_instances = InstanceData(**pred_instances_data) + det_data_sample.pred_instances = pred_instances + assert 'pred_instances' in det_data_sample + assert self._equal(det_data_sample.pred_instances.bboxes, + pred_instances_data['bboxes']) + assert self._equal(det_data_sample.pred_instances.labels, + pred_instances_data['labels']) + assert self._equal(det_data_sample.pred_instances.masks, + pred_instances_data['masks']) + + # test type error + with self.assertRaises(AssertionError): + det_data_sample.gt_instances = torch.rand(2, 4) + with 
self.assertRaises(AssertionError): + det_data_sample.pred_instances = torch.rand(2, 4) + + def test_deleter(self): + gt_instances_data = dict( + bboxes=torch.rand(4, 4), + labels=torch.rand(4), + masks=np.random.rand(4, 2, 2)) + + det_data_sample = TextDetDataSample() + gt_instances = InstanceData(data=gt_instances_data) + det_data_sample.gt_instances = gt_instances + assert 'gt_instances' in det_data_sample + del det_data_sample.gt_instances + assert 'gt_instances' not in det_data_sample + + det_data_sample.pred_instances = gt_instances + assert 'pred_instances' in det_data_sample + del det_data_sample.pred_instances + assert 'pred_instances' not in det_data_sample