diff --git a/mmocr/datasets/transforms/__init__.py b/mmocr/datasets/transforms/__init__.py index 39389bb7..646ad325 100644 --- a/mmocr/datasets/transforms/__init__.py +++ b/mmocr/datasets/transforms/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .adapters import MMDet2MMOCR, MMOCR2MMDet from .formatting import PackKIEInputs, PackTextDetInputs, PackTextRecogInputs from .loading import LoadKIEAnnotations, LoadOCRAnnotations from .ocr_transforms import RandomCrop, RandomRotate, Resize @@ -15,5 +16,6 @@ __all__ = [ 'RandomCrop', 'TextDetRandomCrop', 'RandomCrop', 'PackTextDetInputs', 'PackTextRecogInputs', 'RescaleToHeight', 'PadToWidth', 'ShortScaleAspectJitter', 'RandomFlip', 'BoundedScaleAspectJitter', - 'PackKIEInputs', 'LoadKIEAnnotations', 'FixInvalidPolygon' + 'PackKIEInputs', 'LoadKIEAnnotations', 'FixInvalidPolygon', 'MMDet2MMOCR', + 'MMOCR2MMDet' ] diff --git a/mmocr/datasets/transforms/adapters.py b/mmocr/datasets/transforms/adapters.py new file mode 100644 index 00000000..6c1201d6 --- /dev/null +++ b/mmocr/datasets/transforms/adapters.py @@ -0,0 +1,122 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict + +from mmcv.transforms.base import BaseTransform +from mmdet.core import PolygonMasks +from mmdet.core.mask.structures import bitmap_to_polygon + +from mmocr.registry import TRANSFORMS + + +@TRANSFORMS.register_module() +class MMDet2MMOCR(BaseTransform): + """Convert transforms's data format from MMDet to MMOCR. + + Required Keys: + + - gt_masks (PolygonMasks | BitmapMasks) (optional) + - gt_ignore_flags (np.bool) (optional) + + Added Keys: + + - gt_polygons (list[np.ndarray]) + - gt_ignored (np.ndarray) + """ + + def transform(self, results: Dict) -> Dict: + """Convert MMDet's data format to MMOCR's data format. + + Args: + results (Dict): Result dict containing the data to transform. + + Returns: + (Dict): The transformed data. + """ + # gt_masks -> gt_polygons + if 'gt_masks' in results.keys(): + gt_polygons = [] + gt_masks = results.pop('gt_masks') + if len(gt_masks) > 0: + # PolygonMasks + if isinstance(gt_masks[0], PolygonMasks): + gt_polygons = [mask[0] for mask in gt_masks.masks] + # BitmapMasks + else: + polygons = [] + for mask in gt_masks.masks: + contours, _ = bitmap_to_polygon(mask) + polygons += [ + contour.reshape(-1) for contour in contours + ] + # filter invalid polygons + gt_polygons = [] + for polygon in polygons: + if len(polygon) < 6: + continue + gt_polygons.append(polygon) + + results['gt_polygons'] = gt_polygons + # gt_ignore_flags -> gt_ignored + if 'gt_ignore_flags' in results.keys(): + gt_ignored = results.pop('gt_ignore_flags') + results['gt_ignored'] = gt_ignored + + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + return repr_str + + +@TRANSFORMS.register_module() +class MMOCR2MMDet(BaseTransform): + """Convert transforms's data format from MMOCR to MMDet. + + Required Keys: + + - img_shape + - gt_polygons (List[ndarray]) (optional) + - gt_ignored (np.bool) (optional) + + Added Keys: + + - gt_masks (PolygonMasks | BitmapMasks) (optional) + - gt_ignore_flags (np.bool) (optional) + + Args: + poly2mask (bool): Whether to convert mask to bitmap. Default: True. + """ + + def __init__(self, poly2mask: bool = False) -> None: + self.poly2mask = poly2mask + + def transform(self, results: Dict) -> Dict: + """Convert MMOCR's data format to MMDet's data format. + + Args: + results (Dict): Result dict containing the data to transform. + + Returns: + (Dict): The transformed data. + """ + # gt_polygons -> gt_masks + if 'gt_polygons' in results.keys(): + gt_polygons = results.pop('gt_polygons') + gt_polygons = [[gt_polygon] for gt_polygon in gt_polygons] + gt_masks = PolygonMasks(gt_polygons, *results['img_shape']) + + if self.poly2mask: + gt_masks = gt_masks.to_bitmap() + + results['gt_masks'] = gt_masks + # gt_ignore_flags -> gt_ignored + if 'gt_ignored' in results.keys(): + gt_ignored = results.pop('gt_ignored') + results['gt_ignore_flags'] = gt_ignored + + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(poly2mask = {self.poly2mask})' + return repr_str diff --git a/tests/test_datasets/test_transforms/test_adapters.py b/tests/test_datasets/test_transforms/test_adapters.py new file mode 100644 index 00000000..e0e5f0ac --- /dev/null +++ b/tests/test_datasets/test_transforms/test_adapters.py @@ -0,0 +1,139 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import unittest + +import numpy as np +from mmdet.core import PolygonMasks +from mmdet.core.mask.structures import bitmap_to_polygon + +from mmocr.datasets import MMDet2MMOCR, MMOCR2MMDet, Resize +from mmocr.utils import poly2shapely + + +class TestMMDet2MMOCR(unittest.TestCase): + + def setUp(self): + img = np.zeros((15, 30, 3)) + img_shape = (15, 30) + polygons = [ + np.array([10., 5., 20., 5., 20., 10., 10., 10.]), + np.array([10., 5., 20., 5., 20., 10., 10., 10., 8., 7.]) + ] + ignores = np.array([True, False]) + bboxes = np.array([[10., 5., 20., 10.], [0., 0., 10., 10.]]) + self.data_info_ocr = dict( + img=img, + gt_polygons=polygons, + gt_bboxes=bboxes, + img_shape=img_shape, + gt_ignored=ignores) + + _polygons = [[polygon] for polygon in polygons] + masks = PolygonMasks(_polygons, *img_shape) + self.data_info_det_polygon = dict( + img=img, + gt_masks=masks, + gt_bboxes=bboxes, + gt_ignore_flags=ignores, + img_shape=img_shape) + + masks = masks.to_bitmap() + self.data_info_det_mask = dict( + img=img, + gt_masks=masks, + gt_bboxes=bboxes, + gt_ignore_flags=ignores, + img_shape=img_shape) + + def test_ocr2det_polygonmasks(self): + transform = MMOCR2MMDet() + results = transform(self.data_info_ocr.copy()) + self.assertEqual(results['img'].shape, (15, 30, 3)) + self.assertEqual(results['img_shape'], (15, 30)) + self.assertTrue( + np.allclose(results['gt_masks'].masks[0][0], + self.data_info_det_polygon['gt_masks'].masks[0][0])) + self.assertTrue( + np.allclose(results['gt_masks'].masks[0][0], + self.data_info_det_polygon['gt_masks'].masks[0][0])) + self.assertTrue( + np.allclose(results['gt_bboxes'], + self.data_info_det_polygon['gt_bboxes'])) + self.assertTrue( + np.allclose(results['gt_ignore_flags'], + self.data_info_det_polygon['gt_ignore_flags'])) + + def test_ocr2det_bitmapmasks(self): + transform = MMOCR2MMDet(poly2mask=True) + results = transform(self.data_info_ocr.copy()) + self.assertEqual(results['img'].shape, (15, 30, 3)) + self.assertEqual(results['img_shape'], (15, 30)) + self.assertTrue( + poly2shapely( + bitmap_to_polygon( + results['gt_masks'].masks[0])[0][0].flatten()).equals( + poly2shapely( + bitmap_to_polygon( + self.data_info_det_mask['gt_masks'].masks[0]) + [0][0].flatten()))) + + self.assertTrue( + np.allclose(results['gt_bboxes'], + self.data_info_det_mask['gt_bboxes'])) + self.assertTrue( + np.allclose(results['gt_ignore_flags'], + self.data_info_det_mask['gt_ignore_flags'])) + + def test_det2ocr_polygonmasks(self): + transform = MMDet2MMOCR() + results = transform(self.data_info_det_polygon.copy()) + self.assertEqual(results['img'].shape, (15, 30, 3)) + self.assertEqual(results['img_shape'], (15, 30)) + self.assertTrue( + np.allclose(results['gt_polygons'][0], + self.data_info_ocr['gt_polygons'][0])) + self.assertTrue( + np.allclose(results['gt_polygons'][1], + self.data_info_ocr['gt_polygons'][1])) + self.assertTrue( + np.allclose(results['gt_bboxes'], self.data_info_ocr['gt_bboxes'])) + self.assertTrue( + np.allclose(results['gt_ignored'], + self.data_info_ocr['gt_ignored'])) + + def test_det2ocr_bitmapmasks(self): + transform = MMDet2MMOCR() + results = transform(self.data_info_det_mask.copy()) + self.assertEqual(results['img'].shape, (15, 30, 3)) + self.assertEqual(results['img_shape'], (15, 30)) + self.assertTrue( + np.allclose(results['gt_bboxes'], self.data_info_ocr['gt_bboxes'])) + self.assertTrue( + np.allclose(results['gt_ignored'], + self.data_info_ocr['gt_ignored'])) + + def test_ocr2det2ocr(self): + from mmdet.datasets.pipelines import Resize as MMDet_Resize + t1 = MMOCR2MMDet() + t2 = MMDet_Resize(scale=(60, 60)) + t3 = MMDet2MMOCR() + t4 = Resize(scale=(30, 15)) + results = t4(t3(t2(t1(self.data_info_ocr.copy())))) + self.assertEqual(results['img'].shape, (15, 30, 3)) + self.assertTrue( + np.allclose(results['gt_polygons'][0], + self.data_info_ocr['gt_polygons'][0])) + self.assertTrue( + np.allclose(results['gt_polygons'][1], + self.data_info_ocr['gt_polygons'][1])) + self.assertTrue( + np.allclose(results['gt_bboxes'], self.data_info_ocr['gt_bboxes'])) + self.assertEqual(results['gt_ignored'].all(), + self.data_info_ocr['gt_ignored'].all()) + + def test_repr_det2ocr(self): + transform = MMDet2MMOCR() + self.assertEqual(repr(transform), ('MMDet2MMOCR')) + + def test_repr_ocr2det(self): + transform = MMOCR2MMDet(poly2mask=True) + self.assertEqual(repr(transform), ('MMOCR2MMDet(poly2mask = True)'))