Refactor PyramidRescale

pull/1178/head
wangxinyu.vendor 2022-05-12 03:13:09 +00:00 committed by gaotongxiao
parent 23458f8a47
commit 41c1671e7b
5 changed files with 152 additions and 83 deletions

View File

@ -9,14 +9,15 @@ from .ocr_seg_targets import OCRSegTargets
from .ocr_transforms import (FancyPCA, NormalizeOCR, OnlineCropOCR,
OpencvToPil, PilToOpencv, RandomPaddingOCR,
RandomRotateImageBox, ResizeOCR, ToTensorOCR)
from .processing import PyramidRescale
from .test_time_aug import MultiRotateAugOCR
from .textdet_targets import (DBNetTargets, FCENetTargets, PANetTargets,
TextSnakeTargets)
from .transform_wrappers import OneOfWrapper, RandomWrapper, TorchVisionWrapper
from .transforms import (ColorJitter, PyramidRescale, RandomCropFlip,
RandomCropInstances, RandomCropPolyInstances,
RandomRotatePolyInstances, RandomRotateTextDet,
RandomScaling, ScaleAspectJitter, SquareResizePad)
from .transforms import (ColorJitter, RandomCropFlip, RandomCropInstances,
RandomCropPolyInstances, RandomRotatePolyInstances,
RandomRotateTextDet, RandomScaling, ScaleAspectJitter,
SquareResizePad)
__all__ = [
'LoadTextAnnotations', 'NormalizeOCR', 'OnlineCropOCR', 'ResizeOCR',

View File

@ -0,0 +1,92 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Tuple
import cv2
import mmcv
import numpy as np
from mmcv.transforms.base import BaseTransform
from mmcv.transforms.utils import cache_randomness
from mmocr.registry import TRANSFORMS
@TRANSFORMS.register_module()
class PyramidRescale(BaseTransform):
"""Resize the image to the base shape, downsample it with gaussian pyramid,
and rescale it back to original size.
Adapted from https://github.com/FangShancheng/ABINet.
Required Keys:
- img (ndarray)
Modified Keys:
- img (ndarray)
Args:
factor (int): The decay factor from base size, or the number of
downsampling operations from the base layer.
base_shape (tuple(int)): The shape of the base layer of the pyramid.
randomize_factor (bool): If True, the final factor would be a random
integer in [0, factor].
"""
def __init__(self,
factor: int = 4,
base_shape: Tuple[int, int] = (128, 512),
randomize_factor: bool = True) -> None:
if not isinstance(factor, int):
raise TypeError('`factor` should be an integer, '
f'but got {type(factor)} instead')
if not isinstance(base_shape, (list, tuple)):
raise TypeError('`base_shape` should be a list or tuple, '
f'but got {type(base_shape)} instead')
if not len(base_shape) == 2:
raise ValueError('`base_shape` should contain two integers')
if not isinstance(base_shape[0], int) or not isinstance(
base_shape[1], int):
raise ValueError('`base_shape` should contain two integers')
if not isinstance(randomize_factor, bool):
raise TypeError('`randomize_factor` should be a bool, '
f'but got {type(randomize_factor)} instead')
self.factor = factor
self.randomize_factor = randomize_factor
self.base_w, self.base_h = base_shape
@cache_randomness
def get_random_factor(self):
return np.random.randint(0, self.factor + 1)
def transform(self, results: Dict) -> Dict:
"""Applying pyramid rescale on results.
Args:
results (dict): Result dict contains the data to transform.
Returns:
dict: The transformed data
"""
assert 'img' in results, '`img` is not found in results'
if self.randomize_factor:
self.factor = self.get_random_factor()
if self.factor == 0:
return results
img = results['img']
src_h, src_w = img.shape[:2]
scale_img = mmcv.imresize(img, (self.base_w, self.base_h))
for _ in range(self.factor):
scale_img = cv2.pyrDown(scale_img)
scale_img = mmcv.imresize(scale_img, (src_w, src_h))
results['img'] = scale_img
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(factor = {self.factor}'
repr_str += f', randomize_factor = {self.randomize_factor}'
repr_str += f', base_w = {self.base_w}'
repr_str += f', base_h = {self.base_h})'
return repr_str

View File

@ -967,54 +967,3 @@ class RandomCropFlip:
h_axis = np.where(h_array == 0)[0]
w_axis = np.where(w_array == 0)[0]
return h_axis, w_axis
@TRANSFORMS.register_module()
class PyramidRescale:
"""Resize the image to the base shape, downsample it with gaussian pyramid,
and rescale it back to original size.
Adapted from https://github.com/FangShancheng/ABINet.
Args:
factor (int): The decay factor from base size, or the number of
downsampling operations from the base layer.
base_shape (tuple(int)): The shape of the base layer of the pyramid.
randomize_factor (bool): If True, the final factor would be a random
integer in [0, factor].
:Required Keys:
- | ``img`` (ndarray): The input image.
:Affected Keys:
:Modified:
- | ``img`` (ndarray): The modified image.
"""
def __init__(self, factor=4, base_shape=(128, 512), randomize_factor=True):
assert isinstance(factor, int)
assert isinstance(base_shape, list) or isinstance(base_shape, tuple)
assert len(base_shape) == 2
assert isinstance(randomize_factor, bool)
self.factor = factor if not randomize_factor else np.random.randint(
0, factor + 1)
self.base_w, self.base_h = base_shape
def __call__(self, results):
assert 'img' in results
if self.factor == 0:
return results
img = results['img']
src_h, src_w = img.shape[:2]
scale_img = mmcv.imresize(img, (self.base_w, self.base_h))
for _ in range(self.factor):
scale_img = cv2.pyrDown(scale_img)
scale_img = mmcv.imresize(scale_img, (src_w, src_h))
results['img'] = scale_img
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(factor={self.factor}, '
repr_str += f'basew={self.basew}, baseh={self.baseh})'
return repr_str

View File

@ -1,9 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import unittest.mock as mock
import numpy as np
import pytest
import torchvision.transforms as TF
from mmdet.core import BitmapMasks, PolygonMasks
from PIL import Image
@ -345,29 +343,3 @@ def test_square_resize_pad(mock_sample):
target[1::2] *= 8. / 3
assert np.allclose(output['gt_masks'].masks[0][0], target)
assert output['img'].shape == (40, 40, 3)
def test_pyramid_rescale():
img = np.random.randint(0, 256, size=(128, 100, 3), dtype=np.uint8)
x = {'img': copy.deepcopy(img)}
f = transforms.PyramidRescale()
results = f(x)
assert results['img'].shape == (128, 100, 3)
# Test invalid inputs
with pytest.raises(AssertionError):
transforms.PyramidRescale(base_shape=(128))
with pytest.raises(AssertionError):
transforms.PyramidRescale(base_shape=128)
with pytest.raises(AssertionError):
transforms.PyramidRescale(factor=[])
with pytest.raises(AssertionError):
transforms.PyramidRescale(randomize_factor=[])
with pytest.raises(AssertionError):
f({})
# Test factor = 0
f_derandomized = transforms.PyramidRescale(
factor=0, randomize_factor=False)
results = f_derandomized({'img': copy.deepcopy(img)})
assert np.all(results['img'] == img)

View File

@ -0,0 +1,55 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import unittest
import numpy as np
from mmocr.datasets.pipelines import PyramidRescale
class TestPyramidRescale(unittest.TestCase):
def setUp(self):
self.data_info = dict(img=np.random.random((128, 100, 3)))
def test_init(self):
# factor is int
transform = PyramidRescale(factor=4, randomize_factor=False)
self.assertEqual(transform.factor, 4)
# factor is float
with self.assertRaisesRegex(TypeError,
'`factor` should be an integer'):
PyramidRescale(factor=4.0)
# invalid base_shape
with self.assertRaisesRegex(TypeError,
'`base_shape` should be a list or tuple'):
PyramidRescale(base_shape=128)
with self.assertRaisesRegex(
ValueError, '`base_shape` should contain two integers'):
PyramidRescale(base_shape=(128, ))
with self.assertRaisesRegex(
ValueError, '`base_shape` should contain two integers'):
PyramidRescale(base_shape=(128.0, 2.0))
# invalid randomize_factor
with self.assertRaisesRegex(TypeError,
'`randomize_factor` should be a bool'):
PyramidRescale(randomize_factor=None)
def test_transform(self):
# test if the rescale keeps the original size
transform = PyramidRescale()
results = transform(copy.deepcopy(self.data_info))
self.assertEqual(results['img'].shape, (128, 100, 3))
# test factor = 0
transform = PyramidRescale(factor=0, randomize_factor=False)
results = transform(copy.deepcopy(self.data_info))
self.assertTrue(np.all(results['img'] == self.data_info['img']))
def test_repr(self):
transform = PyramidRescale(
factor=4, base_shape=(128, 512), randomize_factor=False)
print(repr(transform))
self.assertEqual(
repr(transform),
('PyramidRescale(factor = 4, randomize_factor = False, '
'base_w = 128, base_h = 512)'))