diff --git a/mmocr/datasets/transforms/__init__.py b/mmocr/datasets/transforms/__init__.py
index d23b0fcb..a1e51950 100644
--- a/mmocr/datasets/transforms/__init__.py
+++ b/mmocr/datasets/transforms/__init__.py
@@ -10,7 +10,7 @@ from .textdet_transforms import (BoundedScaleAspectJitter, RandomFlip,
                                  ShortScaleAspectJitter, SourceImagePad,
                                  TextDetRandomCrop, TextDetRandomCropFlip)
 from .textrecog_transforms import PadToWidth, PyramidRescale, RescaleToHeight
-from .wrappers import ImgAugWrapper, TorchVisionWrapper
+from .wrappers import ConditionApply, ImgAugWrapper, TorchVisionWrapper
 
 __all__ = [
     'LoadOCRAnnotations', 'RandomRotate', 'ImgAugWrapper', 'SourceImagePad',
@@ -20,5 +20,5 @@ __all__ = [
     'ShortScaleAspectJitter', 'RandomFlip', 'BoundedScaleAspectJitter',
     'PackKIEInputs', 'LoadKIEAnnotations', 'FixInvalidPolygon', 'MMDet2MMOCR',
     'MMOCR2MMDet', 'LoadImageFromLMDB', 'LoadImageFromFile',
-    'LoadImageFromNDArray', 'RemoveIgnored'
+    'LoadImageFromNDArray', 'RemoveIgnored', 'ConditionApply'
 ]
diff --git a/mmocr/datasets/transforms/wrappers.py b/mmocr/datasets/transforms/wrappers.py
index 90ca6975..086edb75 100644
--- a/mmocr/datasets/transforms/wrappers.py
+++ b/mmocr/datasets/transforms/wrappers.py
@@ -6,6 +6,7 @@ import imgaug
 import imgaug.augmenters as iaa
 import numpy as np
 import torchvision.transforms as torchvision_transforms
+from mmcv.transforms import Compose
 from mmcv.transforms.base import BaseTransform
 from PIL import Image
 
@@ -296,3 +297,47 @@ class TorchVisionWrapper(BaseTransform):
             repr_str += f', {k} = {v}'
         repr_str += ')'
         return repr_str
+
+
+@TRANSFORMS.register_module()
+class ConditionApply(BaseTransform):
+    """Apply transforms according to the condition. If the condition is met,
+    true_transforms will be applied, otherwise false_transforms will be
+    applied.
+
+    Args:
+        condition (str): The string that can be evaluated to a boolean value.
+        true_transforms (list[dict]): Transforms to be applied if the
+            condition is met. Defaults to [].
+        false_transforms (list[dict]): Transforms to be applied if the
+            condition is not met. Defaults to [].
+    """
+
+    def __init__(self,
+                 condition: str,
+                 true_transforms: Union[Dict, List[Dict]] = [],
+                 false_transforms: Union[Dict, List[Dict]] = []):
+        self.condition = condition
+        self.true_transforms = Compose(true_transforms)
+        self.false_transforms = Compose(false_transforms)
+
+    def transform(self, results: Dict) -> Optional[Dict]:
+        """Transform the image.
+
+        Args:
+            results (dict): Result dict containing the data to transform.
+
+        Returns:
+            dict: Transformed results.
+        """
+        if eval(self.condition):
+            return self.true_transforms(results)  # type: ignore
+        else:
+            return self.false_transforms(results)
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += f'(condition = {self.condition}, '
+        repr_str += f'true_transforms = {self.true_transforms}, '
+        repr_str += f'false_transforms = {self.false_transforms})'
+        return repr_str
diff --git a/tests/test_datasets/test_transforms/test_wrappers.py b/tests/test_datasets/test_transforms/test_wrappers.py
index 84f96264..51953640 100644
--- a/tests/test_datasets/test_transforms/test_wrappers.py
+++ b/tests/test_datasets/test_transforms/test_wrappers.py
@@ -6,7 +6,8 @@ from typing import Dict, List, Optional
 import numpy as np
 from shapely.geometry import Polygon
 
-from mmocr.datasets.transforms import ImgAugWrapper, TorchVisionWrapper
+from mmocr.datasets.transforms import (ConditionApply, ImgAugWrapper,
+                                       TorchVisionWrapper)
 
 
 class TestImgAug(unittest.TestCase):
@@ -160,3 +161,36 @@ class TestTorchVisionWrapper(unittest.TestCase):
         self.assertEqual(
             repr(f),
             'TorchVisionWrapper(op = Grayscale, num_output_channels = 3)')
+
+
+class TestConditionApply(unittest.TestCase):
+
+    def test_transform(self):
+        dummy_result = dict(img_shape=(100, 200), img=np.zeros((100, 200, 3)))
+        resize = dict(type='Resize', scale=(40, 50), keep_ratio=False)
+
+        trans = ConditionApply(
+            "results['img_shape'][0] > 80", true_transforms=resize)
+        results = trans(dummy_result)
+        self.assertEqual(results['img_shape'], (50, 40))
+        dummy_result = dict(img_shape=(100, 200), img=np.zeros((100, 200, 3)))
+        trans = ConditionApply(
+            "results['img_shape'][0] < 80", false_transforms=resize)
+        results = trans(dummy_result)
+        self.assertEqual(results['img_shape'], (50, 40))
+        dummy_result = dict(img_shape=(100, 200), img=np.zeros((100, 200, 3)))
+        trans = ConditionApply("results['img_shape'][0] < 80")
+        results = trans(dummy_result)
+        self.assertEqual(results['img_shape'], (100, 200))
+
+    def test_repr(self):
+        resize = dict(type='Resize', scale=(40, 50), keep_ratio=False)
+        trans = ConditionApply(
+            "results['img_shape'][0] < 80", true_transforms=resize)
+        self.assertEqual(
+            repr(trans),
+            "ConditionApply(condition = results['img_shape'][0] < 80, "
+            'true_transforms = Compose(\n    Resize(scale=(40, 50), '
+            'scale_factor=None, keep_ratio=False, clip_object_border=True), '
+            'backend=cv2), interpolation=bilinear)\n), '
+            'false_transforms = Compose(\n))')