mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
Add MMDet2MMOCR MMOCR2MMdet
This commit is contained in:
parent
de616ffa02
commit
dae4c9ca8c
@ -1,4 +1,5 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .adapters import MMDet2MMOCR, MMOCR2MMDet
|
||||
from .formatting import PackKIEInputs, PackTextDetInputs, PackTextRecogInputs
|
||||
from .loading import LoadKIEAnnotations, LoadOCRAnnotations
|
||||
from .ocr_transforms import RandomCrop, RandomRotate, Resize
|
||||
@ -15,5 +16,6 @@ __all__ = [
|
||||
'RandomCrop', 'TextDetRandomCrop', 'RandomCrop', 'PackTextDetInputs',
|
||||
'PackTextRecogInputs', 'RescaleToHeight', 'PadToWidth',
|
||||
'ShortScaleAspectJitter', 'RandomFlip', 'BoundedScaleAspectJitter',
|
||||
'PackKIEInputs', 'LoadKIEAnnotations', 'FixInvalidPolygon'
|
||||
'PackKIEInputs', 'LoadKIEAnnotations', 'FixInvalidPolygon', 'MMDet2MMOCR',
|
||||
'MMOCR2MMDet'
|
||||
]
|
||||
|
122
mmocr/datasets/transforms/adapters.py
Normal file
122
mmocr/datasets/transforms/adapters.py
Normal file
@ -0,0 +1,122 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict
|
||||
|
||||
from mmcv.transforms.base import BaseTransform
|
||||
from mmdet.core import PolygonMasks
|
||||
from mmdet.core.mask.structures import bitmap_to_polygon
|
||||
|
||||
from mmocr.registry import TRANSFORMS
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class MMDet2MMOCR(BaseTransform):
|
||||
"""Convert transforms's data format from MMDet to MMOCR.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- gt_masks (PolygonMasks | BitmapMasks) (optional)
|
||||
- gt_ignore_flags (np.bool) (optional)
|
||||
|
||||
Added Keys:
|
||||
|
||||
- gt_polygons (list[np.ndarray])
|
||||
- gt_ignored (np.ndarray)
|
||||
"""
|
||||
|
||||
def transform(self, results: Dict) -> Dict:
|
||||
"""Convert MMDet's data format to MMOCR's data format.
|
||||
|
||||
Args:
|
||||
results (Dict): Result dict containing the data to transform.
|
||||
|
||||
Returns:
|
||||
(Dict): The transformed data.
|
||||
"""
|
||||
# gt_masks -> gt_polygons
|
||||
if 'gt_masks' in results.keys():
|
||||
gt_polygons = []
|
||||
gt_masks = results.pop('gt_masks')
|
||||
if len(gt_masks) > 0:
|
||||
# PolygonMasks
|
||||
if isinstance(gt_masks[0], PolygonMasks):
|
||||
gt_polygons = [mask[0] for mask in gt_masks.masks]
|
||||
# BitmapMasks
|
||||
else:
|
||||
polygons = []
|
||||
for mask in gt_masks.masks:
|
||||
contours, _ = bitmap_to_polygon(mask)
|
||||
polygons += [
|
||||
contour.reshape(-1) for contour in contours
|
||||
]
|
||||
# filter invalid polygons
|
||||
gt_polygons = []
|
||||
for polygon in polygons:
|
||||
if len(polygon) < 6:
|
||||
continue
|
||||
gt_polygons.append(polygon)
|
||||
|
||||
results['gt_polygons'] = gt_polygons
|
||||
# gt_ignore_flags -> gt_ignored
|
||||
if 'gt_ignore_flags' in results.keys():
|
||||
gt_ignored = results.pop('gt_ignore_flags')
|
||||
results['gt_ignored'] = gt_ignored
|
||||
|
||||
return results
|
||||
|
||||
def __repr__(self) -> str:
|
||||
repr_str = self.__class__.__name__
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class MMOCR2MMDet(BaseTransform):
|
||||
"""Convert transforms's data format from MMOCR to MMDet.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img_shape
|
||||
- gt_polygons (List[ndarray]) (optional)
|
||||
- gt_ignored (np.bool) (optional)
|
||||
|
||||
Added Keys:
|
||||
|
||||
- gt_masks (PolygonMasks | BitmapMasks) (optional)
|
||||
- gt_ignore_flags (np.bool) (optional)
|
||||
|
||||
Args:
|
||||
poly2mask (bool): Whether to convert mask to bitmap. Default: True.
|
||||
"""
|
||||
|
||||
def __init__(self, poly2mask: bool = False) -> None:
|
||||
self.poly2mask = poly2mask
|
||||
|
||||
def transform(self, results: Dict) -> Dict:
|
||||
"""Convert MMOCR's data format to MMDet's data format.
|
||||
|
||||
Args:
|
||||
results (Dict): Result dict containing the data to transform.
|
||||
|
||||
Returns:
|
||||
(Dict): The transformed data.
|
||||
"""
|
||||
# gt_polygons -> gt_masks
|
||||
if 'gt_polygons' in results.keys():
|
||||
gt_polygons = results.pop('gt_polygons')
|
||||
gt_polygons = [[gt_polygon] for gt_polygon in gt_polygons]
|
||||
gt_masks = PolygonMasks(gt_polygons, *results['img_shape'])
|
||||
|
||||
if self.poly2mask:
|
||||
gt_masks = gt_masks.to_bitmap()
|
||||
|
||||
results['gt_masks'] = gt_masks
|
||||
# gt_ignore_flags -> gt_ignored
|
||||
if 'gt_ignored' in results.keys():
|
||||
gt_ignored = results.pop('gt_ignored')
|
||||
results['gt_ignore_flags'] = gt_ignored
|
||||
|
||||
return results
|
||||
|
||||
def __repr__(self) -> str:
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(poly2mask = {self.poly2mask})'
|
||||
return repr_str
|
139
tests/test_datasets/test_transforms/test_adapters.py
Normal file
139
tests/test_datasets/test_transforms/test_adapters.py
Normal file
@ -0,0 +1,139 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from mmdet.core import PolygonMasks
|
||||
from mmdet.core.mask.structures import bitmap_to_polygon
|
||||
|
||||
from mmocr.datasets import MMDet2MMOCR, MMOCR2MMDet, Resize
|
||||
from mmocr.utils import poly2shapely
|
||||
|
||||
|
||||
class TestMMDet2MMOCR(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
img = np.zeros((15, 30, 3))
|
||||
img_shape = (15, 30)
|
||||
polygons = [
|
||||
np.array([10., 5., 20., 5., 20., 10., 10., 10.]),
|
||||
np.array([10., 5., 20., 5., 20., 10., 10., 10., 8., 7.])
|
||||
]
|
||||
ignores = np.array([True, False])
|
||||
bboxes = np.array([[10., 5., 20., 10.], [0., 0., 10., 10.]])
|
||||
self.data_info_ocr = dict(
|
||||
img=img,
|
||||
gt_polygons=polygons,
|
||||
gt_bboxes=bboxes,
|
||||
img_shape=img_shape,
|
||||
gt_ignored=ignores)
|
||||
|
||||
_polygons = [[polygon] for polygon in polygons]
|
||||
masks = PolygonMasks(_polygons, *img_shape)
|
||||
self.data_info_det_polygon = dict(
|
||||
img=img,
|
||||
gt_masks=masks,
|
||||
gt_bboxes=bboxes,
|
||||
gt_ignore_flags=ignores,
|
||||
img_shape=img_shape)
|
||||
|
||||
masks = masks.to_bitmap()
|
||||
self.data_info_det_mask = dict(
|
||||
img=img,
|
||||
gt_masks=masks,
|
||||
gt_bboxes=bboxes,
|
||||
gt_ignore_flags=ignores,
|
||||
img_shape=img_shape)
|
||||
|
||||
def test_ocr2det_polygonmasks(self):
|
||||
transform = MMOCR2MMDet()
|
||||
results = transform(self.data_info_ocr.copy())
|
||||
self.assertEqual(results['img'].shape, (15, 30, 3))
|
||||
self.assertEqual(results['img_shape'], (15, 30))
|
||||
self.assertTrue(
|
||||
np.allclose(results['gt_masks'].masks[0][0],
|
||||
self.data_info_det_polygon['gt_masks'].masks[0][0]))
|
||||
self.assertTrue(
|
||||
np.allclose(results['gt_masks'].masks[0][0],
|
||||
self.data_info_det_polygon['gt_masks'].masks[0][0]))
|
||||
self.assertTrue(
|
||||
np.allclose(results['gt_bboxes'],
|
||||
self.data_info_det_polygon['gt_bboxes']))
|
||||
self.assertTrue(
|
||||
np.allclose(results['gt_ignore_flags'],
|
||||
self.data_info_det_polygon['gt_ignore_flags']))
|
||||
|
||||
def test_ocr2det_bitmapmasks(self):
|
||||
transform = MMOCR2MMDet(poly2mask=True)
|
||||
results = transform(self.data_info_ocr.copy())
|
||||
self.assertEqual(results['img'].shape, (15, 30, 3))
|
||||
self.assertEqual(results['img_shape'], (15, 30))
|
||||
self.assertTrue(
|
||||
poly2shapely(
|
||||
bitmap_to_polygon(
|
||||
results['gt_masks'].masks[0])[0][0].flatten()).equals(
|
||||
poly2shapely(
|
||||
bitmap_to_polygon(
|
||||
self.data_info_det_mask['gt_masks'].masks[0])
|
||||
[0][0].flatten())))
|
||||
|
||||
self.assertTrue(
|
||||
np.allclose(results['gt_bboxes'],
|
||||
self.data_info_det_mask['gt_bboxes']))
|
||||
self.assertTrue(
|
||||
np.allclose(results['gt_ignore_flags'],
|
||||
self.data_info_det_mask['gt_ignore_flags']))
|
||||
|
||||
def test_det2ocr_polygonmasks(self):
|
||||
transform = MMDet2MMOCR()
|
||||
results = transform(self.data_info_det_polygon.copy())
|
||||
self.assertEqual(results['img'].shape, (15, 30, 3))
|
||||
self.assertEqual(results['img_shape'], (15, 30))
|
||||
self.assertTrue(
|
||||
np.allclose(results['gt_polygons'][0],
|
||||
self.data_info_ocr['gt_polygons'][0]))
|
||||
self.assertTrue(
|
||||
np.allclose(results['gt_polygons'][1],
|
||||
self.data_info_ocr['gt_polygons'][1]))
|
||||
self.assertTrue(
|
||||
np.allclose(results['gt_bboxes'], self.data_info_ocr['gt_bboxes']))
|
||||
self.assertTrue(
|
||||
np.allclose(results['gt_ignored'],
|
||||
self.data_info_ocr['gt_ignored']))
|
||||
|
||||
def test_det2ocr_bitmapmasks(self):
|
||||
transform = MMDet2MMOCR()
|
||||
results = transform(self.data_info_det_mask.copy())
|
||||
self.assertEqual(results['img'].shape, (15, 30, 3))
|
||||
self.assertEqual(results['img_shape'], (15, 30))
|
||||
self.assertTrue(
|
||||
np.allclose(results['gt_bboxes'], self.data_info_ocr['gt_bboxes']))
|
||||
self.assertTrue(
|
||||
np.allclose(results['gt_ignored'],
|
||||
self.data_info_ocr['gt_ignored']))
|
||||
|
||||
def test_ocr2det2ocr(self):
|
||||
from mmdet.datasets.pipelines import Resize as MMDet_Resize
|
||||
t1 = MMOCR2MMDet()
|
||||
t2 = MMDet_Resize(scale=(60, 60))
|
||||
t3 = MMDet2MMOCR()
|
||||
t4 = Resize(scale=(30, 15))
|
||||
results = t4(t3(t2(t1(self.data_info_ocr.copy()))))
|
||||
self.assertEqual(results['img'].shape, (15, 30, 3))
|
||||
self.assertTrue(
|
||||
np.allclose(results['gt_polygons'][0],
|
||||
self.data_info_ocr['gt_polygons'][0]))
|
||||
self.assertTrue(
|
||||
np.allclose(results['gt_polygons'][1],
|
||||
self.data_info_ocr['gt_polygons'][1]))
|
||||
self.assertTrue(
|
||||
np.allclose(results['gt_bboxes'], self.data_info_ocr['gt_bboxes']))
|
||||
self.assertEqual(results['gt_ignored'].all(),
|
||||
self.data_info_ocr['gt_ignored'].all())
|
||||
|
||||
def test_repr_det2ocr(self):
|
||||
transform = MMDet2MMOCR()
|
||||
self.assertEqual(repr(transform), ('MMDet2MMOCR'))
|
||||
|
||||
def test_repr_ocr2det(self):
|
||||
transform = MMOCR2MMDet(poly2mask=True)
|
||||
self.assertEqual(repr(transform), ('MMOCR2MMDet(poly2mask = True)'))
|
Loading…
x
Reference in New Issue
Block a user