@TRANSFORMS.register_module()
@avoid_cache_randomness
class ShortScaleAspectJitter(BaseTransform):
    """Randomly rescale an image by its shorter side, then jitter its aspect
    ratio.

    The shorter side is first rescaled to ``short_size`` multiplied by a ratio
    sampled from ``ratio_range``. The aspect ratio is then jittered by a
    factor sampled from ``aspect_ratio_range``, and finally both sides are
    rounded up to the nearest multiple of ``scale_divisor``. The resizing
    itself is delegated to the transform built from ``resize_cfg``.

    Required Keys:

    - img
    - img_shape
    - gt_bboxes (optional)
    - gt_polygons (optional)


    Modified Keys:

    - img
    - img_shape
    - gt_bboxes (optional)
    - gt_polygons (optional)

    Added Keys:

    - scale
    - scale_factor
    - keep_ratio

    Args:
        short_size (int): Target shorter size before jittering the aspect
            ratio. Defaults to 736.
        ratio_range (tuple(float, float)): Range of the ratio used to jitter
            the target shorter size. Defaults to (0.7, 1.3).
        aspect_ratio_range (tuple(float, float)): Range of the ratio used to
            jitter the aspect ratio. Defaults to (0.9, 1.1).
        scale_divisor (int): The scale divisor. Both output sides are rounded
            up to a multiple of it. Defaults to 1.
        resize_cfg (dict): Config to construct the Resize transform. It is
            copied, never modified in place, and any ``scale`` entry in it is
            overridden. Refer to ``Resize`` for detail. Defaults to
            ``dict(type='Resize')``.
    """

    def __init__(self,
                 short_size: int = 736,
                 ratio_range: Tuple[float, float] = (0.7, 1.3),
                 aspect_ratio_range: Tuple[float, float] = (0.9, 1.1),
                 scale_divisor: int = 1,
                 resize_cfg: Dict = dict(type='Resize')) -> None:
        super().__init__()
        self.short_size = short_size
        self.ratio_range = ratio_range
        self.aspect_ratio_range = aspect_ratio_range
        self.scale_divisor = scale_divisor
        # Work on a copy: updating ``resize_cfg`` in place would mutate the
        # shared default dict as well as any config dict owned by the caller.
        # ``scale=0`` is only a placeholder so that the resize transform can
        # be built; the real target scale is assigned on every ``transform``
        # call.
        self.resize_cfg = {**resize_cfg, 'scale': 0}
        self.resize = TRANSFORMS.build(self.resize_cfg)

    def _sample_from_range(self, range: Tuple[float, float]) -> float:
        """Randomly sample a value from the interval described by ``range``.

        Args:
            range (tuple(float, float)): The two endpoints of the interval,
                given in either order.

        Returns:
            float: A value randomly sampled between ``min(range)`` and
            ``max(range)``.
        """
        # NOTE: the parameter name shadows the builtin ``range``; it is kept
        # unchanged so that existing keyword callers keep working.
        min_value, max_value = min(range), max(range)
        return np.random.random_sample() * (max_value - min_value) + min_value

    def transform(self, results: Dict) -> Dict:
        """Resize ``results`` to a randomly jittered scale.

        Args:
            results (dict): Result dict containing at least ``img``.

        Returns:
            dict: The output of the wrapped resize transform.
        """
        h, w = results['img'].shape[:2]
        # Draw order matters for reproducibility: size ratio first, then
        # aspect ratio.
        ratio = self._sample_from_range(self.ratio_range)
        scale = (ratio * self.short_size) / min(h, w)

        aspect = self._sample_from_range(self.aspect_ratio_range)
        # Splitting the jitter as sqrt(aspect) on the height and
        # 1 / sqrt(aspect) on the width changes h / w by ``aspect`` while
        # leaving the product of the two scale factors equal to ``scale ** 2``.
        sqrt_aspect = math.sqrt(aspect)
        h_scale = scale * sqrt_aspect
        w_scale = scale / sqrt_aspect
        new_h = round(h * h_scale)
        new_w = round(w * w_scale)

        # Round each side up to the nearest multiple of ``scale_divisor``.
        new_h = math.ceil(new_h / self.scale_divisor) * self.scale_divisor
        new_w = math.ceil(new_w / self.scale_divisor) * self.scale_divisor
        # The wrapped resize transform takes its scale as (width, height).
        self.resize.scale = (new_w, new_h)
        return self.resize(results)

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(short_size = {self.short_size}, '
        repr_str += f'ratio_range = {self.ratio_range}, '
        repr_str += f'aspect_ratio_range = {self.aspect_ratio_range}, '
        repr_str += f'scale_divisor = {self.scale_divisor}, '
        repr_str += f'resize_cfg = {self.resize_cfg})'
        return repr_str
class TestShortScaleAspectJitter(unittest.TestCase):
    """Unit tests for the ``ShortScaleAspectJitter`` transform."""

    @mock.patch('mmocr.datasets.pipelines.processing.np.random.random_sample')
    def test_transform(self, mock_random):
        """A 15x20 image is resized to the expected jittered shape."""
        # Both random draws return 0.5, i.e. the midpoint of each range:
        # a size ratio of 1.0 and an aspect ratio of 1.0.
        mock_random.side_effect = [0.5, 0.5]
        data_info = dict(
            img=np.zeros((15, 20, 3)),
            gt_polygons=[np.array([10., 5., 20., 5., 20., 10., 10., 10.])],
            gt_bboxes=np.array([[10., 5., 20., 10.]]))
        jitter = ShortScaleAspectJitter(
            short_size=40,
            ratio_range=(0.5, 1.5),
            aspect_ratio_range=(0.9, 1.1),
            scale_divisor=4)
        results = jitter(data_info)
        # Short side 15 -> 40; long side 20 -> 53, rounded up to 56 so that
        # it is a multiple of the scale divisor 4.
        expected_hw = (40, 56)
        self.assertEqual(results['img'].shape, expected_hw + (3, ))
        self.assertEqual(results['img_shape'], expected_hw)

    def test_repr(self):
        """``repr`` reports every constructor argument."""
        jitter = ShortScaleAspectJitter(
            short_size=40,
            ratio_range=(0.5, 1.5),
            aspect_ratio_range=(0.9, 1.1),
            scale_divisor=4,
            resize_cfg=dict(type='Resize'))
        # The reported resize config carries the placeholder ``scale`` entry
        # that the transform adds when building its resize op.
        expected_parts = [
            'ShortScaleAspectJitter(',
            'short_size = 40, ',
            'ratio_range = (0.5, 1.5), ',
            'aspect_ratio_range = (0.9, 1.1), ',
            'scale_divisor = 4, ',
            "resize_cfg = {'type': 'Resize', 'scale': 0})",
        ]
        self.assertEqual(repr(jitter), ''.join(expected_parts))