[Feature] ConditionApply (#1646)

pull/1647/head
liukuikun 2022-12-28 11:53:32 +08:00 committed by GitHub
parent 89606a1cf1
commit 9baf440d7a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 82 additions and 3 deletions

View File

@ -10,7 +10,7 @@ from .textdet_transforms import (BoundedScaleAspectJitter, RandomFlip,
ShortScaleAspectJitter, SourceImagePad, ShortScaleAspectJitter, SourceImagePad,
TextDetRandomCrop, TextDetRandomCropFlip) TextDetRandomCrop, TextDetRandomCropFlip)
from .textrecog_transforms import PadToWidth, PyramidRescale, RescaleToHeight from .textrecog_transforms import PadToWidth, PyramidRescale, RescaleToHeight
from .wrappers import ImgAugWrapper, TorchVisionWrapper from .wrappers import ConditionApply, ImgAugWrapper, TorchVisionWrapper
__all__ = [ __all__ = [
'LoadOCRAnnotations', 'RandomRotate', 'ImgAugWrapper', 'SourceImagePad', 'LoadOCRAnnotations', 'RandomRotate', 'ImgAugWrapper', 'SourceImagePad',
@ -20,5 +20,5 @@ __all__ = [
'ShortScaleAspectJitter', 'RandomFlip', 'BoundedScaleAspectJitter', 'ShortScaleAspectJitter', 'RandomFlip', 'BoundedScaleAspectJitter',
'PackKIEInputs', 'LoadKIEAnnotations', 'FixInvalidPolygon', 'MMDet2MMOCR', 'PackKIEInputs', 'LoadKIEAnnotations', 'FixInvalidPolygon', 'MMDet2MMOCR',
'MMOCR2MMDet', 'LoadImageFromLMDB', 'LoadImageFromFile', 'MMOCR2MMDet', 'LoadImageFromLMDB', 'LoadImageFromFile',
'LoadImageFromNDArray', 'RemoveIgnored' 'LoadImageFromNDArray', 'RemoveIgnored', 'ConditionApply'
] ]

View File

@ -6,6 +6,7 @@ import imgaug
import imgaug.augmenters as iaa import imgaug.augmenters as iaa
import numpy as np import numpy as np
import torchvision.transforms as torchvision_transforms import torchvision.transforms as torchvision_transforms
from mmcv.transforms import Compose
from mmcv.transforms.base import BaseTransform from mmcv.transforms.base import BaseTransform
from PIL import Image from PIL import Image
@ -296,3 +297,47 @@ class TorchVisionWrapper(BaseTransform):
repr_str += f', {k} = {v}' repr_str += f', {k} = {v}'
repr_str += ')' repr_str += ')'
return repr_str return repr_str
@TRANSFORMS.register_module()
class ConditionApply(BaseTransform):
    """Apply one of two transform pipelines depending on a condition.

    The condition string is evaluated for every call to :meth:`transform`.
    If it evaluates to a truthy value, ``true_transforms`` is applied,
    otherwise ``false_transforms`` is applied.

    Args:
        condition (str): A Python expression that is evaluated with
            ``eval`` inside :meth:`transform`. It can refer to the input
            dict through the name ``results``, e.g.
            ``"results['img_shape'][0] > 80"``.
        true_transforms (dict or list[dict], optional): Transform config(s)
            wrapped in ``Compose`` and applied if the condition is met.
            Defaults to None, which is treated as an empty list.
        false_transforms (dict or list[dict], optional): Transform
            config(s) wrapped in ``Compose`` and applied if the condition
            is not met. Defaults to None, which is treated as an empty
            list.
    """

    def __init__(
            self,
            condition: str,
            true_transforms: Optional[Union[Dict, List[Dict]]] = None,
            false_transforms: Optional[Union[Dict, List[Dict]]] = None):
        self.condition = condition
        # ``None`` stands in for "no transforms" so that no mutable list
        # object is shared between calls as a default argument.
        self.true_transforms = Compose(
            [] if true_transforms is None else true_transforms)
        self.false_transforms = Compose(
            [] if false_transforms is None else false_transforms)

    def transform(self, results: Dict) -> Optional[Dict]:
        """Run the pipeline selected by the condition.

        Args:
            results (dict): Result dict containing the data to transform.

        Returns:
            dict: Whatever the selected ``Compose`` pipeline returns for
            ``results``.
        """
        # NOTE(review): ``eval`` executes arbitrary code, so ``condition``
        # must only come from trusted configs. The expression is resolved
        # in this method's scope, which is why the parameter has to stay
        # named ``results``.
        if eval(self.condition):
            return self.true_transforms(results)  # type: ignore
        return self.false_transforms(results)

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(condition = {self.condition}, '
        repr_str += f'true_transforms = {self.true_transforms}, '
        repr_str += f'false_transforms = {self.false_transforms})'
        return repr_str

View File

@ -6,7 +6,8 @@ from typing import Dict, List, Optional
import numpy as np import numpy as np
from shapely.geometry import Polygon from shapely.geometry import Polygon
from mmocr.datasets.transforms import ImgAugWrapper, TorchVisionWrapper from mmocr.datasets.transforms import (ConditionApply, ImgAugWrapper,
TorchVisionWrapper)
class TestImgAug(unittest.TestCase): class TestImgAug(unittest.TestCase):
@ -160,3 +161,36 @@ class TestTorchVisionWrapper(unittest.TestCase):
self.assertEqual( self.assertEqual(
repr(f), repr(f),
'TorchVisionWrapper(op = Grayscale, num_output_channels = 3)') 'TorchVisionWrapper(op = Grayscale, num_output_channels = 3)')
class TestConditionApply(unittest.TestCase):
    """Unit tests for the ``ConditionApply`` transform wrapper."""

    @staticmethod
    def _new_sample():
        """Build a fresh input dict holding a 100x200 three-channel image."""
        return {'img_shape': (100, 200), 'img': np.zeros((100, 200, 3))}

    def test_transform(self):
        """The pipeline matching the condition outcome must be applied."""
        resize_cfg = dict(type='Resize', scale=(40, 50), keep_ratio=False)

        # Condition is met: the resize given as true_transforms runs.
        wrapper = ConditionApply(
            "results['img_shape'][0] > 80", true_transforms=resize_cfg)
        output = wrapper(self._new_sample())
        self.assertEqual(output['img_shape'], (50, 40))

        # Condition is not met: the resize given as false_transforms runs.
        wrapper = ConditionApply(
            "results['img_shape'][0] < 80", false_transforms=resize_cfg)
        output = wrapper(self._new_sample())
        self.assertEqual(output['img_shape'], (50, 40))

        # No transforms configured at all: the shape stays as it was.
        wrapper = ConditionApply("results['img_shape'][0] < 80")
        output = wrapper(self._new_sample())
        self.assertEqual(output['img_shape'], (100, 200))

    def test_repr(self):
        """``repr`` lists the condition and both composed pipelines."""
        resize_cfg = dict(type='Resize', scale=(40, 50), keep_ratio=False)
        wrapper = ConditionApply(
            "results['img_shape'][0] < 80", true_transforms=resize_cfg)
        # NOTE(review): literal copied verbatim; confirm the whitespace
        # after each '\n' against the actual Compose repr.
        expected = (
            "ConditionApply(condition = results['img_shape'][0] < 80, "
            'true_transforms = Compose(\n Resize(scale=(40, 50), '
            'scale_factor=None, keep_ratio=False, clip_object_border=True), '
            'backend=cv2), interpolation=bilinear)\n), '
            'false_transforms = Compose(\n))')
        self.assertEqual(repr(wrapper), expected)