mirror of https://github.com/open-mmlab/mmocr.git
[Feature] ConditionApply (#1646)
parent
89606a1cf1
commit
9baf440d7a
mmocr/datasets/transforms
tests/test_datasets/test_transforms
|
@ -10,7 +10,7 @@ from .textdet_transforms import (BoundedScaleAspectJitter, RandomFlip,
|
|||
ShortScaleAspectJitter, SourceImagePad,
|
||||
TextDetRandomCrop, TextDetRandomCropFlip)
|
||||
from .textrecog_transforms import PadToWidth, PyramidRescale, RescaleToHeight
|
||||
from .wrappers import ImgAugWrapper, TorchVisionWrapper
|
||||
from .wrappers import ConditionApply, ImgAugWrapper, TorchVisionWrapper
|
||||
|
||||
__all__ = [
|
||||
'LoadOCRAnnotations', 'RandomRotate', 'ImgAugWrapper', 'SourceImagePad',
|
||||
|
@ -20,5 +20,5 @@ __all__ = [
|
|||
'ShortScaleAspectJitter', 'RandomFlip', 'BoundedScaleAspectJitter',
|
||||
'PackKIEInputs', 'LoadKIEAnnotations', 'FixInvalidPolygon', 'MMDet2MMOCR',
|
||||
'MMOCR2MMDet', 'LoadImageFromLMDB', 'LoadImageFromFile',
|
||||
'LoadImageFromNDArray', 'RemoveIgnored'
|
||||
'LoadImageFromNDArray', 'RemoveIgnored', 'ConditionApply'
|
||||
]
|
||||
|
|
|
@ -6,6 +6,7 @@ import imgaug
|
|||
import imgaug.augmenters as iaa
|
||||
import numpy as np
|
||||
import torchvision.transforms as torchvision_transforms
|
||||
from mmcv.transforms import Compose
|
||||
from mmcv.transforms.base import BaseTransform
|
||||
from PIL import Image
|
||||
|
||||
|
@ -296,3 +297,47 @@ class TorchVisionWrapper(BaseTransform):
|
|||
repr_str += f', {k} = {v}'
|
||||
repr_str += ')'
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class ConditionApply(BaseTransform):
|
||||
"""Apply transforms according to the condition. If the condition is met,
|
||||
true_transforms will be applied, otherwise false_transforms will be
|
||||
applied.
|
||||
|
||||
Args:
|
||||
condition (str): The string that can be evaluated to a boolean value.
|
||||
true_transforms (list[dict]): Transforms to be applied if the condition
|
||||
is met. Defaults to [].
|
||||
false_transforms (list[dict]): Transforms to be applied if the
|
||||
condition is not met. Defaults to [].
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
condition: str,
|
||||
true_transforms: Union[Dict, List[Dict]] = [],
|
||||
false_transforms: Union[Dict, List[Dict]] = []):
|
||||
self.condition = condition
|
||||
self.true_transforms = Compose(true_transforms)
|
||||
self.false_transforms = Compose(false_transforms)
|
||||
|
||||
def transform(self, results: Dict) -> Optional[Dict]:
|
||||
"""Transform the image.
|
||||
|
||||
Args:
|
||||
results (dict):Result dict containing the data to transform.
|
||||
|
||||
Returns:
|
||||
dict: Transformed results.
|
||||
"""
|
||||
if eval(self.condition):
|
||||
return self.true_transforms(results) # type: ignore
|
||||
else:
|
||||
return self.false_transforms(results)
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(condition = {self.condition}, '
|
||||
repr_str += f'true_transforms = {self.true_transforms}, '
|
||||
repr_str += f'false_transforms = {self.false_transforms})'
|
||||
return repr_str
|
||||
|
|
|
@ -6,7 +6,8 @@ from typing import Dict, List, Optional
|
|||
import numpy as np
|
||||
from shapely.geometry import Polygon
|
||||
|
||||
from mmocr.datasets.transforms import ImgAugWrapper, TorchVisionWrapper
|
||||
from mmocr.datasets.transforms import (ConditionApply, ImgAugWrapper,
|
||||
TorchVisionWrapper)
|
||||
|
||||
|
||||
class TestImgAug(unittest.TestCase):
|
||||
|
@ -160,3 +161,36 @@ class TestTorchVisionWrapper(unittest.TestCase):
|
|||
self.assertEqual(
|
||||
repr(f),
|
||||
'TorchVisionWrapper(op = Grayscale, num_output_channels = 3)')
|
||||
|
||||
|
||||
class TestConditionApply(unittest.TestCase):
|
||||
|
||||
def test_transform(self):
|
||||
dummy_result = dict(img_shape=(100, 200), img=np.zeros((100, 200, 3)))
|
||||
resize = dict(type='Resize', scale=(40, 50), keep_ratio=False)
|
||||
|
||||
trans = ConditionApply(
|
||||
"results['img_shape'][0] > 80", true_transforms=resize)
|
||||
results = trans(dummy_result)
|
||||
self.assertEqual(results['img_shape'], (50, 40))
|
||||
dummy_result = dict(img_shape=(100, 200), img=np.zeros((100, 200, 3)))
|
||||
trans = ConditionApply(
|
||||
"results['img_shape'][0] < 80", false_transforms=resize)
|
||||
results = trans(dummy_result)
|
||||
self.assertEqual(results['img_shape'], (50, 40))
|
||||
dummy_result = dict(img_shape=(100, 200), img=np.zeros((100, 200, 3)))
|
||||
trans = ConditionApply("results['img_shape'][0] < 80")
|
||||
results = trans(dummy_result)
|
||||
self.assertEqual(results['img_shape'], (100, 200))
|
||||
|
||||
def test_repr(self):
|
||||
resize = dict(type='Resize', scale=(40, 50), keep_ratio=False)
|
||||
trans = ConditionApply(
|
||||
"results['img_shape'][0] < 80", true_transforms=resize)
|
||||
self.assertEqual(
|
||||
repr(trans),
|
||||
"ConditionApply(condition = results['img_shape'][0] < 80, "
|
||||
'true_transforms = Compose(\n Resize(scale=(40, 50), '
|
||||
'scale_factor=None, keep_ratio=False, clip_object_border=True), '
|
||||
'backend=cv2), interpolation=bilinear)\n), '
|
||||
'false_transforms = Compose(\n))')
|
||||
|
|
Loading…
Reference in New Issue