Add MMDet2MMOCR MMOCR2MMdet

This commit is contained in:
jiangqing.vendor 2022-07-13 11:17:45 +00:00 committed by gaotongxiao
parent de616ffa02
commit dae4c9ca8c
3 changed files with 264 additions and 1 deletions

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .adapters import MMDet2MMOCR, MMOCR2MMDet
from .formatting import PackKIEInputs, PackTextDetInputs, PackTextRecogInputs
from .loading import LoadKIEAnnotations, LoadOCRAnnotations
from .ocr_transforms import RandomCrop, RandomRotate, Resize
@ -15,5 +16,6 @@ __all__ = [
'RandomCrop', 'TextDetRandomCrop', 'RandomCrop', 'PackTextDetInputs',
'PackTextRecogInputs', 'RescaleToHeight', 'PadToWidth',
'ShortScaleAspectJitter', 'RandomFlip', 'BoundedScaleAspectJitter',
'PackKIEInputs', 'LoadKIEAnnotations', 'FixInvalidPolygon'
'PackKIEInputs', 'LoadKIEAnnotations', 'FixInvalidPolygon', 'MMDet2MMOCR',
'MMOCR2MMDet'
]

View File

@ -0,0 +1,122 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict
from mmcv.transforms.base import BaseTransform
from mmdet.core import PolygonMasks
from mmdet.core.mask.structures import bitmap_to_polygon
from mmocr.registry import TRANSFORMS
@TRANSFORMS.register_module()
class MMDet2MMOCR(BaseTransform):
"""Convert transforms's data format from MMDet to MMOCR.
Required Keys:
- gt_masks (PolygonMasks | BitmapMasks) (optional)
- gt_ignore_flags (np.bool) (optional)
Added Keys:
- gt_polygons (list[np.ndarray])
- gt_ignored (np.ndarray)
"""
def transform(self, results: Dict) -> Dict:
"""Convert MMDet's data format to MMOCR's data format.
Args:
results (Dict): Result dict containing the data to transform.
Returns:
(Dict): The transformed data.
"""
# gt_masks -> gt_polygons
if 'gt_masks' in results.keys():
gt_polygons = []
gt_masks = results.pop('gt_masks')
if len(gt_masks) > 0:
# PolygonMasks
if isinstance(gt_masks[0], PolygonMasks):
gt_polygons = [mask[0] for mask in gt_masks.masks]
# BitmapMasks
else:
polygons = []
for mask in gt_masks.masks:
contours, _ = bitmap_to_polygon(mask)
polygons += [
contour.reshape(-1) for contour in contours
]
# filter invalid polygons
gt_polygons = []
for polygon in polygons:
if len(polygon) < 6:
continue
gt_polygons.append(polygon)
results['gt_polygons'] = gt_polygons
# gt_ignore_flags -> gt_ignored
if 'gt_ignore_flags' in results.keys():
gt_ignored = results.pop('gt_ignore_flags')
results['gt_ignored'] = gt_ignored
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
return repr_str
@TRANSFORMS.register_module()
class MMOCR2MMDet(BaseTransform):
"""Convert transforms's data format from MMOCR to MMDet.
Required Keys:
- img_shape
- gt_polygons (List[ndarray]) (optional)
- gt_ignored (np.bool) (optional)
Added Keys:
- gt_masks (PolygonMasks | BitmapMasks) (optional)
- gt_ignore_flags (np.bool) (optional)
Args:
poly2mask (bool): Whether to convert mask to bitmap. Default: True.
"""
def __init__(self, poly2mask: bool = False) -> None:
self.poly2mask = poly2mask
def transform(self, results: Dict) -> Dict:
"""Convert MMOCR's data format to MMDet's data format.
Args:
results (Dict): Result dict containing the data to transform.
Returns:
(Dict): The transformed data.
"""
# gt_polygons -> gt_masks
if 'gt_polygons' in results.keys():
gt_polygons = results.pop('gt_polygons')
gt_polygons = [[gt_polygon] for gt_polygon in gt_polygons]
gt_masks = PolygonMasks(gt_polygons, *results['img_shape'])
if self.poly2mask:
gt_masks = gt_masks.to_bitmap()
results['gt_masks'] = gt_masks
# gt_ignore_flags -> gt_ignored
if 'gt_ignored' in results.keys():
gt_ignored = results.pop('gt_ignored')
results['gt_ignore_flags'] = gt_ignored
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(poly2mask = {self.poly2mask})'
return repr_str

View File

@ -0,0 +1,139 @@
# Copyright (c) OpenMMLab. All rights reserved.
import unittest
import numpy as np
from mmdet.core import PolygonMasks
from mmdet.core.mask.structures import bitmap_to_polygon
from mmocr.datasets import MMDet2MMOCR, MMOCR2MMDet, Resize
from mmocr.utils import poly2shapely
class TestMMDet2MMOCR(unittest.TestCase):
def setUp(self):
img = np.zeros((15, 30, 3))
img_shape = (15, 30)
polygons = [
np.array([10., 5., 20., 5., 20., 10., 10., 10.]),
np.array([10., 5., 20., 5., 20., 10., 10., 10., 8., 7.])
]
ignores = np.array([True, False])
bboxes = np.array([[10., 5., 20., 10.], [0., 0., 10., 10.]])
self.data_info_ocr = dict(
img=img,
gt_polygons=polygons,
gt_bboxes=bboxes,
img_shape=img_shape,
gt_ignored=ignores)
_polygons = [[polygon] for polygon in polygons]
masks = PolygonMasks(_polygons, *img_shape)
self.data_info_det_polygon = dict(
img=img,
gt_masks=masks,
gt_bboxes=bboxes,
gt_ignore_flags=ignores,
img_shape=img_shape)
masks = masks.to_bitmap()
self.data_info_det_mask = dict(
img=img,
gt_masks=masks,
gt_bboxes=bboxes,
gt_ignore_flags=ignores,
img_shape=img_shape)
def test_ocr2det_polygonmasks(self):
transform = MMOCR2MMDet()
results = transform(self.data_info_ocr.copy())
self.assertEqual(results['img'].shape, (15, 30, 3))
self.assertEqual(results['img_shape'], (15, 30))
self.assertTrue(
np.allclose(results['gt_masks'].masks[0][0],
self.data_info_det_polygon['gt_masks'].masks[0][0]))
self.assertTrue(
np.allclose(results['gt_masks'].masks[0][0],
self.data_info_det_polygon['gt_masks'].masks[0][0]))
self.assertTrue(
np.allclose(results['gt_bboxes'],
self.data_info_det_polygon['gt_bboxes']))
self.assertTrue(
np.allclose(results['gt_ignore_flags'],
self.data_info_det_polygon['gt_ignore_flags']))
def test_ocr2det_bitmapmasks(self):
transform = MMOCR2MMDet(poly2mask=True)
results = transform(self.data_info_ocr.copy())
self.assertEqual(results['img'].shape, (15, 30, 3))
self.assertEqual(results['img_shape'], (15, 30))
self.assertTrue(
poly2shapely(
bitmap_to_polygon(
results['gt_masks'].masks[0])[0][0].flatten()).equals(
poly2shapely(
bitmap_to_polygon(
self.data_info_det_mask['gt_masks'].masks[0])
[0][0].flatten())))
self.assertTrue(
np.allclose(results['gt_bboxes'],
self.data_info_det_mask['gt_bboxes']))
self.assertTrue(
np.allclose(results['gt_ignore_flags'],
self.data_info_det_mask['gt_ignore_flags']))
def test_det2ocr_polygonmasks(self):
transform = MMDet2MMOCR()
results = transform(self.data_info_det_polygon.copy())
self.assertEqual(results['img'].shape, (15, 30, 3))
self.assertEqual(results['img_shape'], (15, 30))
self.assertTrue(
np.allclose(results['gt_polygons'][0],
self.data_info_ocr['gt_polygons'][0]))
self.assertTrue(
np.allclose(results['gt_polygons'][1],
self.data_info_ocr['gt_polygons'][1]))
self.assertTrue(
np.allclose(results['gt_bboxes'], self.data_info_ocr['gt_bboxes']))
self.assertTrue(
np.allclose(results['gt_ignored'],
self.data_info_ocr['gt_ignored']))
def test_det2ocr_bitmapmasks(self):
transform = MMDet2MMOCR()
results = transform(self.data_info_det_mask.copy())
self.assertEqual(results['img'].shape, (15, 30, 3))
self.assertEqual(results['img_shape'], (15, 30))
self.assertTrue(
np.allclose(results['gt_bboxes'], self.data_info_ocr['gt_bboxes']))
self.assertTrue(
np.allclose(results['gt_ignored'],
self.data_info_ocr['gt_ignored']))
def test_ocr2det2ocr(self):
from mmdet.datasets.pipelines import Resize as MMDet_Resize
t1 = MMOCR2MMDet()
t2 = MMDet_Resize(scale=(60, 60))
t3 = MMDet2MMOCR()
t4 = Resize(scale=(30, 15))
results = t4(t3(t2(t1(self.data_info_ocr.copy()))))
self.assertEqual(results['img'].shape, (15, 30, 3))
self.assertTrue(
np.allclose(results['gt_polygons'][0],
self.data_info_ocr['gt_polygons'][0]))
self.assertTrue(
np.allclose(results['gt_polygons'][1],
self.data_info_ocr['gt_polygons'][1]))
self.assertTrue(
np.allclose(results['gt_bboxes'], self.data_info_ocr['gt_bboxes']))
self.assertEqual(results['gt_ignored'].all(),
self.data_info_ocr['gt_ignored'].all())
def test_repr_det2ocr(self):
transform = MMDet2MMOCR()
self.assertEqual(repr(transform), ('MMDet2MMOCR'))
def test_repr_ocr2det(self):
transform = MMOCR2MMDet(poly2mask=True)
self.assertEqual(repr(transform), ('MMOCR2MMDet(poly2mask = True)'))