From 41c1671e7b9811c8d9d8776693dd103f1e7ad488 Mon Sep 17 00:00:00 2001 From: "wangxinyu.vendor" Date: Thu, 12 May 2022 03:13:09 +0000 Subject: [PATCH] Refactor PyramidRescale --- mmocr/datasets/pipelines/__init__.py | 9 +- mmocr/datasets/pipelines/processing.py | 92 +++++++++++++++++++ mmocr/datasets/pipelines/transforms.py | 51 ---------- old_tests/test_dataset/test_transforms.py | 28 ------ .../test_pipelines/test_processing.py | 55 +++++++++++ 5 files changed, 152 insertions(+), 83 deletions(-) create mode 100644 mmocr/datasets/pipelines/processing.py create mode 100644 tests/test_datasets/test_pipelines/test_processing.py diff --git a/mmocr/datasets/pipelines/__init__.py b/mmocr/datasets/pipelines/__init__.py index 54c70262..9ac0d828 100644 --- a/mmocr/datasets/pipelines/__init__.py +++ b/mmocr/datasets/pipelines/__init__.py @@ -9,14 +9,15 @@ from .ocr_seg_targets import OCRSegTargets from .ocr_transforms import (FancyPCA, NormalizeOCR, OnlineCropOCR, OpencvToPil, PilToOpencv, RandomPaddingOCR, RandomRotateImageBox, ResizeOCR, ToTensorOCR) +from .processing import PyramidRescale from .test_time_aug import MultiRotateAugOCR from .textdet_targets import (DBNetTargets, FCENetTargets, PANetTargets, TextSnakeTargets) from .transform_wrappers import OneOfWrapper, RandomWrapper, TorchVisionWrapper -from .transforms import (ColorJitter, PyramidRescale, RandomCropFlip, - RandomCropInstances, RandomCropPolyInstances, - RandomRotatePolyInstances, RandomRotateTextDet, - RandomScaling, ScaleAspectJitter, SquareResizePad) +from .transforms import (ColorJitter, RandomCropFlip, RandomCropInstances, + RandomCropPolyInstances, RandomRotatePolyInstances, + RandomRotateTextDet, RandomScaling, ScaleAspectJitter, + SquareResizePad) __all__ = [ 'LoadTextAnnotations', 'NormalizeOCR', 'OnlineCropOCR', 'ResizeOCR', diff --git a/mmocr/datasets/pipelines/processing.py b/mmocr/datasets/pipelines/processing.py new file mode 100644 index 00000000..0d74a0a6 --- /dev/null +++ 
b/mmocr/datasets/pipelines/processing.py
@@ -0,0 +1,92 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, Tuple
+
+import cv2
+import mmcv
+import numpy as np
+from mmcv.transforms.base import BaseTransform
+from mmcv.transforms.utils import cache_randomness
+
+from mmocr.registry import TRANSFORMS
+
+
+@TRANSFORMS.register_module()
+class PyramidRescale(BaseTransform):
+    """Resize the image to the base shape, downsample it with gaussian pyramid,
+    and rescale it back to original size.
+
+    Adapted from https://github.com/FangShancheng/ABINet.
+
+    Required Keys:
+
+    - img (ndarray)
+
+    Modified Keys:
+
+    - img (ndarray)
+
+    Args:
+        factor (int): The decay factor from base size, or the number of
+            downsampling operations from the base layer.
+        base_shape (tuple(int)): The (width, height) of the base layer of
+            the pyramid.
+        randomize_factor (bool): If True, the final factor would be a random
+            integer in [0, factor].
+    """
+
+    def __init__(self,
+                 factor: int = 4,
+                 base_shape: Tuple[int, int] = (128, 512),
+                 randomize_factor: bool = True) -> None:
+        if not isinstance(factor, int):
+            raise TypeError('`factor` should be an integer, '
+                            f'but got {type(factor)} instead')
+        if not isinstance(base_shape, (list, tuple)):
+            raise TypeError('`base_shape` should be a list or tuple, '
+                            f'but got {type(base_shape)} instead')
+        if len(base_shape) != 2 or not all(
+                isinstance(size, int) for size in base_shape):
+            raise ValueError('`base_shape` should contain two integers')
+        if not isinstance(randomize_factor, bool):
+            raise TypeError('`randomize_factor` should be a bool, '
+                            f'but got {type(randomize_factor)} instead')
+
+        self.factor = factor
+        self.randomize_factor = randomize_factor
+        self.base_w, self.base_h = base_shape
+
+    @cache_randomness
+    def get_random_factor(self) -> int:
+        """Draw the number of downsampling steps, an integer in [0, factor]."""
+        return np.random.randint(0, self.factor + 1)
+
+    def transform(self, results: Dict) -> Dict:
+        """Apply pyramid rescale to ``results['img']``.
+
+        Args:
+            results (dict): Result dict contains the data to transform.
+
+        Returns:
+            dict: The transformed data.
+        """
+        assert 'img' in results, '`img` is not found in results'
+        # Keep the drawn value local: writing it back to ``self.factor`` would
+        # shrink the sampling range on every call until it is stuck at 0.
+        factor = self.factor
+        if self.randomize_factor:
+            factor = self.get_random_factor()
+        if factor == 0:
+            return results
+        img = results['img']
+        src_h, src_w = img.shape[:2]
+        scale_img = mmcv.imresize(img, (self.base_w, self.base_h))
+        for _ in range(factor):
+            scale_img = cv2.pyrDown(scale_img)
+        scale_img = mmcv.imresize(scale_img, (src_w, src_h))
+        results['img'] = scale_img
+        return results
+
+    def __repr__(self) -> str:
+        return (f'{self.__class__.__name__}(factor = {self.factor}'
+                f', randomize_factor = {self.randomize_factor}'
+                f', base_w = {self.base_w}, base_h = {self.base_h})')
diff --git a/mmocr/datasets/pipelines/transforms.py b/mmocr/datasets/pipelines/transforms.py
index 0d7805bc..f9ec9664 100644
--- a/mmocr/datasets/pipelines/transforms.py
+++ b/mmocr/datasets/pipelines/transforms.py
@@ -967,54 +967,3 @@ class RandomCropFlip:
         h_axis = np.where(h_array == 0)[0]
         w_axis = np.where(w_array == 0)[0]
         return h_axis, w_axis
-
-
-@TRANSFORMS.register_module()
-class PyramidRescale:
-    """Resize the image to the base shape, downsample it with gaussian pyramid,
-    and rescale it back to original size.
-
-    Adapted from https://github.com/FangShancheng/ABINet.
-
-    Args:
-        factor (int): The decay factor from base size, or the number of
-            downsampling operations from the base layer.
-        base_shape (tuple(int)): The shape of the base layer of the pyramid.
-        randomize_factor (bool): If True, the final factor would be a random
-            integer in [0, factor].
-
-    :Required Keys:
-        - | ``img`` (ndarray): The input image.
-
-    :Affected Keys:
-        :Modified:
-            - | ``img`` (ndarray): The modified image.
- """ - - def __init__(self, factor=4, base_shape=(128, 512), randomize_factor=True): - assert isinstance(factor, int) - assert isinstance(base_shape, list) or isinstance(base_shape, tuple) - assert len(base_shape) == 2 - assert isinstance(randomize_factor, bool) - self.factor = factor if not randomize_factor else np.random.randint( - 0, factor + 1) - self.base_w, self.base_h = base_shape - - def __call__(self, results): - assert 'img' in results - if self.factor == 0: - return results - img = results['img'] - src_h, src_w = img.shape[:2] - scale_img = mmcv.imresize(img, (self.base_w, self.base_h)) - for _ in range(self.factor): - scale_img = cv2.pyrDown(scale_img) - scale_img = mmcv.imresize(scale_img, (src_w, src_h)) - results['img'] = scale_img - return results - - def __repr__(self): - repr_str = self.__class__.__name__ - repr_str += f'(factor={self.factor}, ' - repr_str += f'basew={self.basew}, baseh={self.baseh})' - return repr_str diff --git a/old_tests/test_dataset/test_transforms.py b/old_tests/test_dataset/test_transforms.py index fc51f3d7..1d9a9c0d 100644 --- a/old_tests/test_dataset/test_transforms.py +++ b/old_tests/test_dataset/test_transforms.py @@ -1,9 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. -import copy import unittest.mock as mock import numpy as np -import pytest import torchvision.transforms as TF from mmdet.core import BitmapMasks, PolygonMasks from PIL import Image @@ -345,29 +343,3 @@ def test_square_resize_pad(mock_sample): target[1::2] *= 8. 
/ 3 assert np.allclose(output['gt_masks'].masks[0][0], target) assert output['img'].shape == (40, 40, 3) - - -def test_pyramid_rescale(): - img = np.random.randint(0, 256, size=(128, 100, 3), dtype=np.uint8) - x = {'img': copy.deepcopy(img)} - f = transforms.PyramidRescale() - results = f(x) - assert results['img'].shape == (128, 100, 3) - - # Test invalid inputs - with pytest.raises(AssertionError): - transforms.PyramidRescale(base_shape=(128)) - with pytest.raises(AssertionError): - transforms.PyramidRescale(base_shape=128) - with pytest.raises(AssertionError): - transforms.PyramidRescale(factor=[]) - with pytest.raises(AssertionError): - transforms.PyramidRescale(randomize_factor=[]) - with pytest.raises(AssertionError): - f({}) - - # Test factor = 0 - f_derandomized = transforms.PyramidRescale( - factor=0, randomize_factor=False) - results = f_derandomized({'img': copy.deepcopy(img)}) - assert np.all(results['img'] == img) diff --git a/tests/test_datasets/test_pipelines/test_processing.py b/tests/test_datasets/test_pipelines/test_processing.py new file mode 100644 index 00000000..f39dad8b --- /dev/null +++ b/tests/test_datasets/test_pipelines/test_processing.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import copy
+import unittest
+
+import numpy as np
+
+from mmocr.datasets.pipelines import PyramidRescale
+
+
+class TestPyramidRescale(unittest.TestCase):
+    """Unit tests for the ``PyramidRescale`` transform."""
+
+    def setUp(self):
+        self.data_info = dict(img=np.random.random((128, 100, 3)))
+
+    def test_init(self):
+        # factor is int
+        transform = PyramidRescale(factor=4, randomize_factor=False)
+        self.assertEqual(transform.factor, 4)
+        # factor is float
+        with self.assertRaisesRegex(TypeError,
+                                    '`factor` should be an integer'):
+            PyramidRescale(factor=4.0)
+        # invalid base_shape
+        with self.assertRaisesRegex(TypeError,
+                                    '`base_shape` should be a list or tuple'):
+            PyramidRescale(base_shape=128)
+        with self.assertRaisesRegex(
+                ValueError, '`base_shape` should contain two integers'):
+            PyramidRescale(base_shape=(128, ))
+        with self.assertRaisesRegex(
+                ValueError, '`base_shape` should contain two integers'):
+            PyramidRescale(base_shape=(128.0, 2.0))
+        # invalid randomize_factor
+        with self.assertRaisesRegex(TypeError,
+                                    '`randomize_factor` should be a bool'):
+            PyramidRescale(randomize_factor=None)
+
+    def test_transform(self):
+        # test if the rescale keeps the original size
+        transform = PyramidRescale()
+        results = transform(copy.deepcopy(self.data_info))
+        self.assertEqual(results['img'].shape, (128, 100, 3))
+        # test factor = 0
+        transform = PyramidRescale(factor=0, randomize_factor=False)
+        results = transform(copy.deepcopy(self.data_info))
+        self.assertTrue(np.all(results['img'] == self.data_info['img']))
+
+    def test_repr(self):
+        transform = PyramidRescale(
+            factor=4, base_shape=(128, 512), randomize_factor=False)
+        self.assertEqual(
+            repr(transform),
+            ('PyramidRescale(factor = 4, randomize_factor = False, '
+             'base_w = 128, base_h = 512)'))