[Feature] TextDetSample

pull/1178/head
liukuikun 2022-05-12 11:55:43 +00:00 committed by gaotongxiao
parent f7cea9d40f
commit c47c5711c1
4 changed files with 184 additions and 1 deletions

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import evaluation
from .data_structures import TextDetDataSample
from .evaluation import * # NOQA
from .mask import extract_boundary, points2boundary, seg2boundary
from .visualize import (det_recog_show_result, imshow_edge, imshow_node,
@ -11,6 +12,6 @@ __all__ = [
'points2boundary', 'seg2boundary', 'extract_boundary', 'overlay_mask_img',
'show_feature', 'show_img_boundary', 'show_pred_gt',
'imshow_pred_boundary', 'imshow_text_char_boundary', 'imshow_text_label',
'imshow_node', 'det_recog_show_result', 'imshow_edge'
'imshow_node', 'det_recog_show_result', 'imshow_edge', 'TextDetDataSample'
]
__all__ += evaluation.__all__

View File

@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .textdet_data_sample import TextDetDataSample
__all__ = ['TextDetDataSample']

View File

@ -0,0 +1,93 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.data import BaseDataElement, InstanceData
class TextDetDataSample(BaseDataElement):
"""A data structure interface of MMOCR. They are used as interfaces between
different components.
The attributes in ``TextDetDataSample`` are divided into two parts:
- ``gt_instances``(InstanceData): Ground truth of instance annotations.
- ``pred_instances``(InstanceData): Instances of model predictions.
Examples:
>>> import torch
>>> import numpy as np
>>> from mmengine.data import InstanceData
>>> from mmocr.core import TextDetDataSample
>>> # gt_instances
>>> data_sample = TextDetDataSample()
>>> img_meta = dict(img_shape=(800, 1196, 3),
... pad_shape=(800, 1216, 3))
>>> gt_instances = InstanceData(metainfo=img_meta)
>>> gt_instances.bboxes = torch.rand((5, 4))
>>> gt_instances.labels = torch.rand((5,))
>>> data_sample.gt_instances = gt_instances
>>> assert 'img_shape' in data_sample.gt_instances.metainfo_keys()
>>> len(data_sample.gt_instances)
5
>>> print(data_sample)
<TextDetDataSample(
META INFORMATION
DATA FIELDS
gt_instances: <InstanceData(
META INFORMATION
pad_shape: (800, 1216, 3)
img_shape: (800, 1196, 3)
DATA FIELDS
labels: tensor([0.8533, 0.1550, 0.5433, 0.7294, 0.5098])
bboxes:
tensor([[9.7725e-01, 5.8417e-01, 1.7269e-01, 6.5694e-01],
[1.7894e-01, 5.1780e-01, 7.0590e-01, 4.8589e-01],
[7.0392e-01, 6.6770e-01, 1.7520e-01, 1.4267e-01],
[2.2411e-01, 5.1962e-01, 9.6953e-01, 6.6994e-01],
[4.1338e-01, 2.1165e-01, 2.7239e-04, 6.8477e-01]])
) at 0x7f21fb1b9190>
) at 0x7f21fb1b9880>
>>> # pred_instances
>>> pred_instances = InstanceData(metainfo=img_meta)
>>> pred_instances.bboxes = torch.rand((5, 4))
>>> pred_instances.scores = torch.rand((5,))
>>> data_sample = TextDetDataSample(pred_instances=pred_instances)
>>> assert 'pred_instances' in data_sample
>>> data_sample = TextDetDataSample()
>>> gt_instances_data = dict(
... bboxes=torch.rand(2, 4),
... labels=torch.rand(2),
... masks=np.random.rand(2, 2, 2))
>>> gt_instances = InstanceData(**gt_instances_data)
>>> data_sample.gt_instances = gt_instances
>>> assert 'gt_instances' in data_sample
>>> assert 'masks' in data_sample.gt_instances
"""
@property
def gt_instances(self) -> InstanceData:
"""InstanceData: groundtruth instances."""
return self._gt_instances
@gt_instances.setter
def gt_instances(self, value: InstanceData):
"""gt_instances setter."""
self.set_field(value, '_gt_instances', dtype=InstanceData)
@gt_instances.deleter
def gt_instances(self):
"""gt_instances deleter."""
del self._gt_instances
@property
def pred_instances(self) -> InstanceData:
"""InstanceData: prediction instances."""
return self._pred_instances
@pred_instances.setter
def pred_instances(self, value: InstanceData):
"""pred_instances setter."""
self.set_field(value, '_pred_instances', dtype=InstanceData)
@pred_instances.deleter
def pred_instances(self):
"""pred_instances deleter."""
del self._pred_instances

View File

@ -0,0 +1,85 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import numpy as np
import torch
from mmengine.data import InstanceData
from mmocr.core import TextDetDataSample
class TestTextDetDataSample(TestCase):
def _equal(self, a, b):
if isinstance(a, (torch.Tensor, np.ndarray)):
return (a == b).all()
else:
return a == b
def test_init(self):
meta_info = dict(
img_size=[256, 256],
scale_factor=np.array([1.5, 1.5]),
img_shape=torch.rand(4))
det_data_sample = TextDetDataSample(metainfo=meta_info)
assert 'img_size' in det_data_sample
self.assertListEqual(det_data_sample.img_size, [256, 256])
self.assertListEqual(det_data_sample.get('img_size'), [256, 256])
def test_setter(self):
det_data_sample = TextDetDataSample()
# test gt_instances
gt_instances_data = dict(
bboxes=torch.rand(4, 4),
labels=torch.rand(4),
masks=np.random.rand(4, 2, 2))
gt_instances = InstanceData(**gt_instances_data)
det_data_sample.gt_instances = gt_instances
assert 'gt_instances' in det_data_sample
assert self._equal(det_data_sample.gt_instances.bboxes,
gt_instances_data['bboxes'])
assert self._equal(det_data_sample.gt_instances.labels,
gt_instances_data['labels'])
assert self._equal(det_data_sample.gt_instances.masks,
gt_instances_data['masks'])
# test pred_instances
pred_instances_data = dict(
bboxes=torch.rand(2, 4),
labels=torch.rand(2),
masks=np.random.rand(2, 2, 2))
pred_instances = InstanceData(**pred_instances_data)
det_data_sample.pred_instances = pred_instances
assert 'pred_instances' in det_data_sample
assert self._equal(det_data_sample.pred_instances.bboxes,
pred_instances_data['bboxes'])
assert self._equal(det_data_sample.pred_instances.labels,
pred_instances_data['labels'])
assert self._equal(det_data_sample.pred_instances.masks,
pred_instances_data['masks'])
# test type error
with self.assertRaises(AssertionError):
det_data_sample.gt_instances = torch.rand(2, 4)
with self.assertRaises(AssertionError):
det_data_sample.pred_instances = torch.rand(2, 4)
def test_deleter(self):
gt_instances_data = dict(
bboxes=torch.rand(4, 4),
labels=torch.rand(4),
masks=np.random.rand(4, 2, 2))
det_data_sample = TextDetDataSample()
gt_instances = InstanceData(data=gt_instances_data)
det_data_sample.gt_instances = gt_instances
assert 'gt_instances' in det_data_sample
del det_data_sample.gt_instances
assert 'gt_instances' not in det_data_sample
det_data_sample.pred_instances = gt_instances
assert 'pred_instances' in det_data_sample
del det_data_sample.pred_instances
assert 'pred_instances' not in det_data_sample