Add MMDet2MMOCR and MMOCR2MMDet adapters

This commit is contained in:
jiangqing.vendor 2022-07-13 11:17:45 +00:00 committed by gaotongxiao
parent de616ffa02
commit dae4c9ca8c
3 changed files with 264 additions and 1 deletions

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .adapters import MMDet2MMOCR, MMOCR2MMDet
from .formatting import PackKIEInputs, PackTextDetInputs, PackTextRecogInputs from .formatting import PackKIEInputs, PackTextDetInputs, PackTextRecogInputs
from .loading import LoadKIEAnnotations, LoadOCRAnnotations from .loading import LoadKIEAnnotations, LoadOCRAnnotations
from .ocr_transforms import RandomCrop, RandomRotate, Resize from .ocr_transforms import RandomCrop, RandomRotate, Resize
@ -15,5 +16,6 @@ __all__ = [
'RandomCrop', 'TextDetRandomCrop', 'RandomCrop', 'PackTextDetInputs', 'RandomCrop', 'TextDetRandomCrop', 'RandomCrop', 'PackTextDetInputs',
'PackTextRecogInputs', 'RescaleToHeight', 'PadToWidth', 'PackTextRecogInputs', 'RescaleToHeight', 'PadToWidth',
'ShortScaleAspectJitter', 'RandomFlip', 'BoundedScaleAspectJitter', 'ShortScaleAspectJitter', 'RandomFlip', 'BoundedScaleAspectJitter',
'PackKIEInputs', 'LoadKIEAnnotations', 'FixInvalidPolygon' 'PackKIEInputs', 'LoadKIEAnnotations', 'FixInvalidPolygon', 'MMDet2MMOCR',
'MMOCR2MMDet'
] ]

View File

@ -0,0 +1,122 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict
from mmcv.transforms.base import BaseTransform
from mmdet.core import PolygonMasks
from mmdet.core.mask.structures import bitmap_to_polygon
from mmocr.registry import TRANSFORMS
@TRANSFORMS.register_module()
class MMDet2MMOCR(BaseTransform):
    """Convert the annotation keys used by MMDet transforms to MMOCR's.

    Required Keys:

    - gt_masks (PolygonMasks | BitmapMasks) (optional)
    - gt_ignore_flags (np.bool) (optional)

    Added Keys:

    - gt_polygons (list[np.ndarray])
    - gt_ignored (np.ndarray)

    Note:
        The consumed keys are removed from ``results``. For bitmap masks
        every contour extracted from a mask becomes its own polygon and
        contours with fewer than three points are dropped, so the number of
        entries in ``gt_polygons`` may differ from the number of masks.
    """

    def transform(self, results: Dict) -> Dict:
        """Convert MMDet's data format to MMOCR's data format.

        Args:
            results (Dict): Result dict containing the data to transform.

        Returns:
            (Dict): The transformed data.
        """
        # gt_masks -> gt_polygons
        if 'gt_masks' in results:
            gt_masks = results.pop('gt_masks')
            # Check the container itself rather than a sliced element; this
            # also makes a separate emptiness check unnecessary.
            if isinstance(gt_masks, PolygonMasks):
                # Only the first polygon of each instance is kept.
                gt_polygons = [mask[0] for mask in gt_masks.masks]
            else:
                # Anything else is treated as bitmap masks.
                gt_polygons = []
                for mask in gt_masks.masks:
                    contours, _ = bitmap_to_polygon(mask)
                    for contour in contours:
                        polygon = contour.reshape(-1)
                        # A valid polygon needs at least 3 points, i.e. 6
                        # flattened coordinates.
                        if len(polygon) >= 6:
                            gt_polygons.append(polygon)
            results['gt_polygons'] = gt_polygons

        # gt_ignore_flags -> gt_ignored
        if 'gt_ignore_flags' in results:
            results['gt_ignored'] = results.pop('gt_ignore_flags')

        return results

    def __repr__(self) -> str:
        """Return the class name; the transform has no parameters."""
        return self.__class__.__name__
@TRANSFORMS.register_module()
class MMOCR2MMDet(BaseTransform):
    """Convert the annotation keys used by MMOCR transforms to MMDet's.

    Required Keys:

    - img_shape
    - gt_polygons (List[ndarray]) (optional)
    - gt_ignored (np.bool) (optional)

    Added Keys:

    - gt_masks (PolygonMasks | BitmapMasks) (optional)
    - gt_ignore_flags (np.bool) (optional)

    Note:
        The consumed keys (``gt_polygons``, ``gt_ignored``) are removed from
        ``results``.

    Args:
        poly2mask (bool): Whether to convert the polygon masks to bitmap
            masks. Default: False.
    """

    def __init__(self, poly2mask: bool = False) -> None:
        self.poly2mask = poly2mask

    def transform(self, results: Dict) -> Dict:
        """Convert MMOCR's data format to MMDet's data format.

        Args:
            results (Dict): Result dict containing the data to transform.

        Returns:
            (Dict): The transformed data.
        """
        # gt_polygons -> gt_masks
        if 'gt_polygons' in results:
            # PolygonMasks expects one list of polygons per instance, so
            # every polygon is wrapped in its own single-element list.
            instance_polygons = [[polygon]
                                 for polygon in results.pop('gt_polygons')]
            # Only the first two entries are used, so a shape with a
            # trailing channel dimension does not break the call.
            height, width = results['img_shape'][:2]
            gt_masks = PolygonMasks(instance_polygons, height, width)
            if self.poly2mask:
                gt_masks = gt_masks.to_bitmap()
            results['gt_masks'] = gt_masks

        # gt_ignored -> gt_ignore_flags
        if 'gt_ignored' in results:
            results['gt_ignore_flags'] = results.pop('gt_ignored')

        return results

    def __repr__(self) -> str:
        """Return the class name followed by the ``poly2mask`` setting."""
        repr_str = self.__class__.__name__
        repr_str += f'(poly2mask = {self.poly2mask})'
        return repr_str

View File

@ -0,0 +1,139 @@
# Copyright (c) OpenMMLab. All rights reserved.
import unittest
import numpy as np
from mmdet.core import PolygonMasks
from mmdet.core.mask.structures import bitmap_to_polygon
from mmocr.datasets import MMDet2MMOCR, MMOCR2MMDet, Resize
from mmocr.utils import poly2shapely
class TestMMDet2MMOCR(unittest.TestCase):
    """Tests for the MMDet2MMOCR and MMOCR2MMDet adapter transforms."""

    def setUp(self):
        """Build the same annotations in MMOCR and in both MMDet layouts."""
        img = np.zeros((15, 30, 3))
        img_shape = (15, 30)
        polygons = [
            np.array([10., 5., 20., 5., 20., 10., 10., 10.]),
            np.array([10., 5., 20., 5., 20., 10., 10., 10., 8., 7.])
        ]
        ignores = np.array([True, False])
        bboxes = np.array([[10., 5., 20., 10.], [0., 0., 10., 10.]])

        # MMOCR-style annotations.
        self.data_info_ocr = dict(
            img=img,
            gt_polygons=polygons,
            gt_bboxes=bboxes,
            img_shape=img_shape,
            gt_ignored=ignores)

        # MMDet-style annotations backed by polygon masks.
        _polygons = [[polygon] for polygon in polygons]
        masks = PolygonMasks(_polygons, *img_shape)
        self.data_info_det_polygon = dict(
            img=img,
            gt_masks=masks,
            gt_bboxes=bboxes,
            gt_ignore_flags=ignores,
            img_shape=img_shape)

        # MMDet-style annotations backed by bitmap masks.
        masks = masks.to_bitmap()
        self.data_info_det_mask = dict(
            img=img,
            gt_masks=masks,
            gt_bboxes=bboxes,
            gt_ignore_flags=ignores,
            img_shape=img_shape)

    def test_ocr2det_polygonmasks(self):
        """MMOCR -> MMDet keeps polygons when producing polygon masks."""
        transform = MMOCR2MMDet()
        results = transform(self.data_info_ocr.copy())
        expected = self.data_info_det_polygon
        self.assertEqual(results['img'].shape, (15, 30, 3))
        self.assertEqual(results['img_shape'], (15, 30))
        self.assertTrue(
            np.allclose(results['gt_masks'].masks[0][0],
                        expected['gt_masks'].masks[0][0]))
        # Check the second instance as well instead of repeating the first.
        self.assertTrue(
            np.allclose(results['gt_masks'].masks[1][0],
                        expected['gt_masks'].masks[1][0]))
        self.assertTrue(
            np.allclose(results['gt_bboxes'], expected['gt_bboxes']))
        self.assertTrue(
            np.allclose(results['gt_ignore_flags'],
                        expected['gt_ignore_flags']))

    def test_ocr2det_bitmapmasks(self):
        """MMOCR -> MMDet with ``poly2mask=True`` yields matching bitmaps."""
        transform = MMOCR2MMDet(poly2mask=True)
        results = transform(self.data_info_ocr.copy())
        expected = self.data_info_det_mask
        self.assertEqual(results['img'].shape, (15, 30, 3))
        self.assertEqual(results['img_shape'], (15, 30))
        # Compare the first contour of the first mask as shapely polygons.
        result_contour = bitmap_to_polygon(
            results['gt_masks'].masks[0])[0][0].flatten()
        expected_contour = bitmap_to_polygon(
            expected['gt_masks'].masks[0])[0][0].flatten()
        self.assertTrue(
            poly2shapely(result_contour).equals(
                poly2shapely(expected_contour)))
        self.assertTrue(
            np.allclose(results['gt_bboxes'], expected['gt_bboxes']))
        self.assertTrue(
            np.allclose(results['gt_ignore_flags'],
                        expected['gt_ignore_flags']))

    def test_det2ocr_polygonmasks(self):
        """MMDet -> MMOCR recovers the polygons from polygon masks."""
        transform = MMDet2MMOCR()
        results = transform(self.data_info_det_polygon.copy())
        self.assertEqual(results['img'].shape, (15, 30, 3))
        self.assertEqual(results['img_shape'], (15, 30))
        self.assertTrue(
            np.allclose(results['gt_polygons'][0],
                        self.data_info_ocr['gt_polygons'][0]))
        self.assertTrue(
            np.allclose(results['gt_polygons'][1],
                        self.data_info_ocr['gt_polygons'][1]))
        self.assertTrue(
            np.allclose(results['gt_bboxes'], self.data_info_ocr['gt_bboxes']))
        self.assertTrue(
            np.allclose(results['gt_ignored'],
                        self.data_info_ocr['gt_ignored']))

    def test_det2ocr_bitmapmasks(self):
        """MMDet -> MMOCR keeps boxes and ignore flags for bitmap masks."""
        transform = MMDet2MMOCR()
        results = transform(self.data_info_det_mask.copy())
        self.assertEqual(results['img'].shape, (15, 30, 3))
        self.assertEqual(results['img_shape'], (15, 30))
        self.assertTrue(
            np.allclose(results['gt_bboxes'], self.data_info_ocr['gt_bboxes']))
        self.assertTrue(
            np.allclose(results['gt_ignored'],
                        self.data_info_ocr['gt_ignored']))

    def test_ocr2det2ocr(self):
        """Round trip through an MMDet resize and back to the original size."""
        from mmdet.datasets.pipelines import Resize as MMDet_Resize
        t1 = MMOCR2MMDet()
        t2 = MMDet_Resize(scale=(60, 60))
        t3 = MMDet2MMOCR()
        t4 = Resize(scale=(30, 15))
        results = t4(t3(t2(t1(self.data_info_ocr.copy()))))
        self.assertEqual(results['img'].shape, (15, 30, 3))
        self.assertTrue(
            np.allclose(results['gt_polygons'][0],
                        self.data_info_ocr['gt_polygons'][0]))
        self.assertTrue(
            np.allclose(results['gt_polygons'][1],
                        self.data_info_ocr['gt_polygons'][1]))
        self.assertTrue(
            np.allclose(results['gt_bboxes'], self.data_info_ocr['gt_bboxes']))
        # Compare element-wise; comparing ``.all()`` reductions would pass
        # for any two arrays that each contain a False.
        self.assertTrue(
            np.array_equal(results['gt_ignored'],
                           self.data_info_ocr['gt_ignored']))

    def test_repr_det2ocr(self):
        """``repr`` of MMDet2MMOCR is the bare class name."""
        transform = MMDet2MMOCR()
        self.assertEqual(repr(transform), ('MMDet2MMOCR'))

    def test_repr_ocr2det(self):
        """``repr`` of MMOCR2MMDet includes the ``poly2mask`` setting."""
        transform = MMOCR2MMDet(poly2mask=True)
        self.assertEqual(repr(transform), ('MMOCR2MMDet(poly2mask = True)'))