mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
Add MMDet2MMOCR and MMOCR2MMDet
This commit is contained in:
parent
de616ffa02
commit
dae4c9ca8c
@ -1,4 +1,5 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from .adapters import MMDet2MMOCR, MMOCR2MMDet
|
||||||
from .formatting import PackKIEInputs, PackTextDetInputs, PackTextRecogInputs
|
from .formatting import PackKIEInputs, PackTextDetInputs, PackTextRecogInputs
|
||||||
from .loading import LoadKIEAnnotations, LoadOCRAnnotations
|
from .loading import LoadKIEAnnotations, LoadOCRAnnotations
|
||||||
from .ocr_transforms import RandomCrop, RandomRotate, Resize
|
from .ocr_transforms import RandomCrop, RandomRotate, Resize
|
||||||
@ -15,5 +16,6 @@ __all__ = [
|
|||||||
'RandomCrop', 'TextDetRandomCrop', 'RandomCrop', 'PackTextDetInputs',
|
'RandomCrop', 'TextDetRandomCrop', 'RandomCrop', 'PackTextDetInputs',
|
||||||
'PackTextRecogInputs', 'RescaleToHeight', 'PadToWidth',
|
'PackTextRecogInputs', 'RescaleToHeight', 'PadToWidth',
|
||||||
'ShortScaleAspectJitter', 'RandomFlip', 'BoundedScaleAspectJitter',
|
'ShortScaleAspectJitter', 'RandomFlip', 'BoundedScaleAspectJitter',
|
||||||
'PackKIEInputs', 'LoadKIEAnnotations', 'FixInvalidPolygon'
|
'PackKIEInputs', 'LoadKIEAnnotations', 'FixInvalidPolygon', 'MMDet2MMOCR',
|
||||||
|
'MMOCR2MMDet'
|
||||||
]
|
]
|
||||||
|
122
mmocr/datasets/transforms/adapters.py
Normal file
122
mmocr/datasets/transforms/adapters.py
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from mmcv.transforms.base import BaseTransform
|
||||||
|
from mmdet.core import PolygonMasks
|
||||||
|
from mmdet.core.mask.structures import bitmap_to_polygon
|
||||||
|
|
||||||
|
from mmocr.registry import TRANSFORMS
|
||||||
|
|
||||||
|
|
||||||
|
@TRANSFORMS.register_module()
class MMDet2MMOCR(BaseTransform):
    """Convert a transform's data format from MMDet to MMOCR.

    Required Keys:

    - gt_masks (PolygonMasks | BitmapMasks) (optional)
    - gt_ignore_flags (np.ndarray of bool) (optional)

    Added Keys:

    - gt_polygons (list[np.ndarray])
    - gt_ignored (np.ndarray)
    """

    def transform(self, results: Dict) -> Dict:
        """Convert MMDet's data format to MMOCR's data format.

        ``gt_masks`` is popped and replaced by ``gt_polygons``, and
        ``gt_ignore_flags`` is popped and stored as ``gt_ignored``. A key
        that is absent is simply skipped; all other keys are left as they
        are. ``results`` is modified in place and also returned.

        Args:
            results (Dict): Result dict containing the data to transform.

        Returns:
            (Dict): The transformed data.
        """
        # gt_masks -> gt_polygons
        if 'gt_masks' in results:
            gt_masks = results.pop('gt_masks')
            gt_polygons = []
            if len(gt_masks) > 0:
                if isinstance(gt_masks, PolygonMasks):
                    # Only the first polygon of every instance is kept.
                    gt_polygons = [mask[0] for mask in gt_masks.masks]
                else:
                    # BitmapMasks: every contour found in a bitmap becomes
                    # one flattened polygon.
                    for mask in gt_masks.masks:
                        contours, _ = bitmap_to_polygon(mask)
                        for contour in contours:
                            polygon = contour.reshape(-1)
                            # Fewer than 6 values means fewer than 3
                            # points, which is not a valid polygon.
                            if len(polygon) >= 6:
                                gt_polygons.append(polygon)
            results['gt_polygons'] = gt_polygons

        # gt_ignore_flags -> gt_ignored
        if 'gt_ignore_flags' in results:
            results['gt_ignored'] = results.pop('gt_ignore_flags')

        return results

    def __repr__(self) -> str:
        """Return the class name; the transform has no parameters."""
        return self.__class__.__name__
|
||||||
|
|
||||||
|
|
||||||
|
@TRANSFORMS.register_module()
class MMOCR2MMDet(BaseTransform):
    """Convert a transform's data format from MMOCR to MMDet.

    Required Keys:

    - img_shape
    - gt_polygons (List[ndarray]) (optional)
    - gt_ignored (np.ndarray of bool) (optional)

    Added Keys:

    - gt_masks (PolygonMasks | BitmapMasks) (optional)
    - gt_ignore_flags (np.ndarray of bool) (optional)

    Args:
        poly2mask (bool): Whether to convert the polygon masks to bitmap
            masks. Default: False.
    """

    def __init__(self, poly2mask: bool = False) -> None:
        self.poly2mask = poly2mask

    def transform(self, results: Dict) -> Dict:
        """Convert MMOCR's data format to MMDet's data format.

        ``gt_polygons`` is popped and replaced by ``gt_masks``, and
        ``gt_ignored`` is popped and stored as ``gt_ignore_flags``. A key
        that is absent is simply skipped. ``results`` is modified in place
        and also returned.

        Args:
            results (Dict): Result dict containing the data to transform.

        Returns:
            (Dict): The transformed data.
        """
        # gt_polygons -> gt_masks
        if 'gt_polygons' in results:
            # PolygonMasks expects one list of polygons per instance, so
            # each polygon is wrapped in its own list.
            gt_polygons = [[gt_polygon]
                           for gt_polygon in results.pop('gt_polygons')]
            # Only the leading (height, width) entries are used, so a
            # trailing channel entry in img_shape does not break the call.
            height, width = results['img_shape'][:2]
            gt_masks = PolygonMasks(gt_polygons, height, width)

            if self.poly2mask:
                gt_masks = gt_masks.to_bitmap()

            results['gt_masks'] = gt_masks

        # gt_ignored -> gt_ignore_flags
        if 'gt_ignored' in results:
            results['gt_ignore_flags'] = results.pop('gt_ignored')

        return results

    def __repr__(self) -> str:
        """Return the class name followed by the ``poly2mask`` setting."""
        repr_str = self.__class__.__name__
        repr_str += f'(poly2mask = {self.poly2mask})'
        return repr_str
|
139
tests/test_datasets/test_transforms/test_adapters.py
Normal file
139
tests/test_datasets/test_transforms/test_adapters.py
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from mmdet.core import PolygonMasks
|
||||||
|
from mmdet.core.mask.structures import bitmap_to_polygon
|
||||||
|
|
||||||
|
from mmocr.datasets import MMDet2MMOCR, MMOCR2MMDet, Resize
|
||||||
|
from mmocr.utils import poly2shapely
|
||||||
|
|
||||||
|
|
||||||
|
class TestMMDet2MMOCR(unittest.TestCase):
    """Tests for the MMDet2MMOCR and MMOCR2MMDet adapter transforms."""

    def setUp(self):
        """Build the same annotations in MMOCR, MMDet-polygon and
        MMDet-bitmap form so each test can convert one into another."""
        img = np.zeros((15, 30, 3))
        img_shape = (15, 30)
        polygons = [
            np.array([10., 5., 20., 5., 20., 10., 10., 10.]),
            np.array([10., 5., 20., 5., 20., 10., 10., 10., 8., 7.])
        ]
        ignores = np.array([True, False])
        bboxes = np.array([[10., 5., 20., 10.], [0., 0., 10., 10.]])
        self.data_info_ocr = dict(
            img=img,
            gt_polygons=polygons,
            gt_bboxes=bboxes,
            img_shape=img_shape,
            gt_ignored=ignores)

        _polygons = [[polygon] for polygon in polygons]
        masks = PolygonMasks(_polygons, *img_shape)
        self.data_info_det_polygon = dict(
            img=img,
            gt_masks=masks,
            gt_bboxes=bboxes,
            gt_ignore_flags=ignores,
            img_shape=img_shape)

        masks = masks.to_bitmap()
        self.data_info_det_mask = dict(
            img=img,
            gt_masks=masks,
            gt_bboxes=bboxes,
            gt_ignore_flags=ignores,
            img_shape=img_shape)

    def test_ocr2det_polygonmasks(self):
        """MMOCR -> MMDet with polygon masks (poly2mask left at default)."""
        transform = MMOCR2MMDet()
        results = transform(self.data_info_ocr.copy())
        self.assertEqual(results['img'].shape, (15, 30, 3))
        self.assertEqual(results['img_shape'], (15, 30))
        expected_masks = self.data_info_det_polygon['gt_masks'].masks
        # Check both instances, not just the first one.
        self.assertTrue(
            np.allclose(results['gt_masks'].masks[0][0],
                        expected_masks[0][0]))
        self.assertTrue(
            np.allclose(results['gt_masks'].masks[1][0],
                        expected_masks[1][0]))
        self.assertTrue(
            np.allclose(results['gt_bboxes'],
                        self.data_info_det_polygon['gt_bboxes']))
        self.assertTrue(
            np.allclose(results['gt_ignore_flags'],
                        self.data_info_det_polygon['gt_ignore_flags']))

    def test_ocr2det_bitmapmasks(self):
        """MMOCR -> MMDet with bitmap masks (poly2mask=True)."""
        transform = MMOCR2MMDet(poly2mask=True)
        results = transform(self.data_info_ocr.copy())
        self.assertEqual(results['img'].shape, (15, 30, 3))
        self.assertEqual(results['img_shape'], (15, 30))
        # Bitmaps are compared through the first contour extracted from
        # the first mask on each side.
        result_poly = bitmap_to_polygon(
            results['gt_masks'].masks[0])[0][0].flatten()
        expected_poly = bitmap_to_polygon(
            self.data_info_det_mask['gt_masks'].masks[0])[0][0].flatten()
        self.assertTrue(
            poly2shapely(result_poly).equals(poly2shapely(expected_poly)))

        self.assertTrue(
            np.allclose(results['gt_bboxes'],
                        self.data_info_det_mask['gt_bboxes']))
        self.assertTrue(
            np.allclose(results['gt_ignore_flags'],
                        self.data_info_det_mask['gt_ignore_flags']))

    def test_det2ocr_polygonmasks(self):
        """MMDet -> MMOCR starting from polygon masks."""
        transform = MMDet2MMOCR()
        results = transform(self.data_info_det_polygon.copy())
        self.assertEqual(results['img'].shape, (15, 30, 3))
        self.assertEqual(results['img_shape'], (15, 30))
        self.assertTrue(
            np.allclose(results['gt_polygons'][0],
                        self.data_info_ocr['gt_polygons'][0]))
        self.assertTrue(
            np.allclose(results['gt_polygons'][1],
                        self.data_info_ocr['gt_polygons'][1]))
        self.assertTrue(
            np.allclose(results['gt_bboxes'], self.data_info_ocr['gt_bboxes']))
        self.assertTrue(
            np.allclose(results['gt_ignored'],
                        self.data_info_ocr['gt_ignored']))

    def test_det2ocr_bitmapmasks(self):
        """MMDet -> MMOCR starting from bitmap masks."""
        transform = MMDet2MMOCR()
        results = transform(self.data_info_det_mask.copy())
        self.assertEqual(results['img'].shape, (15, 30, 3))
        self.assertEqual(results['img_shape'], (15, 30))
        self.assertTrue(
            np.allclose(results['gt_bboxes'], self.data_info_ocr['gt_bboxes']))
        self.assertTrue(
            np.allclose(results['gt_ignored'],
                        self.data_info_ocr['gt_ignored']))

    def test_ocr2det2ocr(self):
        """Round trip: MMOCR -> MMDet resize -> MMOCR -> MMOCR resize."""
        from mmdet.datasets.pipelines import Resize as MMDet_Resize
        t1 = MMOCR2MMDet()
        t2 = MMDet_Resize(scale=(60, 60))
        t3 = MMDet2MMOCR()
        t4 = Resize(scale=(30, 15))
        results = t4(t3(t2(t1(self.data_info_ocr.copy()))))
        self.assertEqual(results['img'].shape, (15, 30, 3))
        self.assertTrue(
            np.allclose(results['gt_polygons'][0],
                        self.data_info_ocr['gt_polygons'][0]))
        self.assertTrue(
            np.allclose(results['gt_polygons'][1],
                        self.data_info_ocr['gt_polygons'][1]))
        self.assertTrue(
            np.allclose(results['gt_bboxes'], self.data_info_ocr['gt_bboxes']))
        # Compare the flags element-wise; reducing both sides with .all()
        # would pass for almost any pair of arrays.
        self.assertTrue(
            np.array_equal(results['gt_ignored'],
                           self.data_info_ocr['gt_ignored']))

    def test_repr_det2ocr(self):
        """repr of MMDet2MMOCR is the bare class name."""
        transform = MMDet2MMOCR()
        self.assertEqual(repr(transform), ('MMDet2MMOCR'))

    def test_repr_ocr2det(self):
        """repr of MMOCR2MMDet includes the poly2mask setting."""
        transform = MMOCR2MMDet(poly2mask=True)
        self.assertEqual(repr(transform), ('MMOCR2MMDet(poly2mask = True)'))
|
Loading…
x
Reference in New Issue
Block a user