[Feature] ConditionApply ()

pull/1647/head
liukuikun 2022-12-28 11:53:32 +08:00 committed by GitHub
parent 89606a1cf1
commit 9baf440d7a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 82 additions and 3 deletions
mmocr/datasets/transforms
tests/test_datasets/test_transforms

View File

@ -10,7 +10,7 @@ from .textdet_transforms import (BoundedScaleAspectJitter, RandomFlip,
ShortScaleAspectJitter, SourceImagePad,
TextDetRandomCrop, TextDetRandomCropFlip)
from .textrecog_transforms import PadToWidth, PyramidRescale, RescaleToHeight
from .wrappers import ImgAugWrapper, TorchVisionWrapper
from .wrappers import ConditionApply, ImgAugWrapper, TorchVisionWrapper
__all__ = [
'LoadOCRAnnotations', 'RandomRotate', 'ImgAugWrapper', 'SourceImagePad',
@ -20,5 +20,5 @@ __all__ = [
'ShortScaleAspectJitter', 'RandomFlip', 'BoundedScaleAspectJitter',
'PackKIEInputs', 'LoadKIEAnnotations', 'FixInvalidPolygon', 'MMDet2MMOCR',
'MMOCR2MMDet', 'LoadImageFromLMDB', 'LoadImageFromFile',
'LoadImageFromNDArray', 'RemoveIgnored'
'LoadImageFromNDArray', 'RemoveIgnored', 'ConditionApply'
]

View File

@ -6,6 +6,7 @@ import imgaug
import imgaug.augmenters as iaa
import numpy as np
import torchvision.transforms as torchvision_transforms
from mmcv.transforms import Compose
from mmcv.transforms.base import BaseTransform
from PIL import Image
@ -296,3 +297,47 @@ class TorchVisionWrapper(BaseTransform):
repr_str += f', {k} = {v}'
repr_str += ')'
return repr_str
@TRANSFORMS.register_module()
class ConditionApply(BaseTransform):
"""Apply transforms according to the condition. If the condition is met,
true_transforms will be applied, otherwise false_transforms will be
applied.
Args:
condition (str): The string that can be evaluated to a boolean value.
true_transforms (list[dict]): Transforms to be applied if the condition
is met. Defaults to [].
false_transforms (list[dict]): Transforms to be applied if the
condition is not met. Defaults to [].
"""
def __init__(self,
condition: str,
true_transforms: Union[Dict, List[Dict]] = [],
false_transforms: Union[Dict, List[Dict]] = []):
self.condition = condition
self.true_transforms = Compose(true_transforms)
self.false_transforms = Compose(false_transforms)
def transform(self, results: Dict) -> Optional[Dict]:
"""Transform the image.
Args:
results (dict):Result dict containing the data to transform.
Returns:
dict: Transformed results.
"""
if eval(self.condition):
return self.true_transforms(results) # type: ignore
else:
return self.false_transforms(results)
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(condition = {self.condition}, '
repr_str += f'true_transforms = {self.true_transforms}, '
repr_str += f'false_transforms = {self.false_transforms})'
return repr_str

View File

@ -6,7 +6,8 @@ from typing import Dict, List, Optional
import numpy as np
from shapely.geometry import Polygon
from mmocr.datasets.transforms import ImgAugWrapper, TorchVisionWrapper
from mmocr.datasets.transforms import (ConditionApply, ImgAugWrapper,
TorchVisionWrapper)
class TestImgAug(unittest.TestCase):
@ -160,3 +161,36 @@ class TestTorchVisionWrapper(unittest.TestCase):
self.assertEqual(
repr(f),
'TorchVisionWrapper(op = Grayscale, num_output_channels = 3)')
class TestConditionApply(unittest.TestCase):
def test_transform(self):
dummy_result = dict(img_shape=(100, 200), img=np.zeros((100, 200, 3)))
resize = dict(type='Resize', scale=(40, 50), keep_ratio=False)
trans = ConditionApply(
"results['img_shape'][0] > 80", true_transforms=resize)
results = trans(dummy_result)
self.assertEqual(results['img_shape'], (50, 40))
dummy_result = dict(img_shape=(100, 200), img=np.zeros((100, 200, 3)))
trans = ConditionApply(
"results['img_shape'][0] < 80", false_transforms=resize)
results = trans(dummy_result)
self.assertEqual(results['img_shape'], (50, 40))
dummy_result = dict(img_shape=(100, 200), img=np.zeros((100, 200, 3)))
trans = ConditionApply("results['img_shape'][0] < 80")
results = trans(dummy_result)
self.assertEqual(results['img_shape'], (100, 200))
def test_repr(self):
resize = dict(type='Resize', scale=(40, 50), keep_ratio=False)
trans = ConditionApply(
"results['img_shape'][0] < 80", true_transforms=resize)
self.assertEqual(
repr(trans),
"ConditionApply(condition = results['img_shape'][0] < 80, "
'true_transforms = Compose(\n Resize(scale=(40, 50), '
'scale_factor=None, keep_ratio=False, clip_object_border=True), '
'backend=cv2), interpolation=bilinear)\n), '
'false_transforms = Compose(\n))')