mirror of https://github.com/open-mmlab/mmocr.git
Refactor PyramidRescale
parent
23458f8a47
commit
41c1671e7b
|
@ -9,14 +9,15 @@ from .ocr_seg_targets import OCRSegTargets
|
|||
from .ocr_transforms import (FancyPCA, NormalizeOCR, OnlineCropOCR,
|
||||
OpencvToPil, PilToOpencv, RandomPaddingOCR,
|
||||
RandomRotateImageBox, ResizeOCR, ToTensorOCR)
|
||||
from .processing import PyramidRescale
|
||||
from .test_time_aug import MultiRotateAugOCR
|
||||
from .textdet_targets import (DBNetTargets, FCENetTargets, PANetTargets,
|
||||
TextSnakeTargets)
|
||||
from .transform_wrappers import OneOfWrapper, RandomWrapper, TorchVisionWrapper
|
||||
from .transforms import (ColorJitter, PyramidRescale, RandomCropFlip,
|
||||
RandomCropInstances, RandomCropPolyInstances,
|
||||
RandomRotatePolyInstances, RandomRotateTextDet,
|
||||
RandomScaling, ScaleAspectJitter, SquareResizePad)
|
||||
from .transforms import (ColorJitter, RandomCropFlip, RandomCropInstances,
|
||||
RandomCropPolyInstances, RandomRotatePolyInstances,
|
||||
RandomRotateTextDet, RandomScaling, ScaleAspectJitter,
|
||||
SquareResizePad)
|
||||
|
||||
__all__ = [
|
||||
'LoadTextAnnotations', 'NormalizeOCR', 'OnlineCropOCR', 'ResizeOCR',
|
||||
|
|
|
@ -0,0 +1,92 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import cv2
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from mmcv.transforms.base import BaseTransform
|
||||
from mmcv.transforms.utils import cache_randomness
|
||||
|
||||
from mmocr.registry import TRANSFORMS
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class PyramidRescale(BaseTransform):
|
||||
"""Resize the image to the base shape, downsample it with gaussian pyramid,
|
||||
and rescale it back to original size.
|
||||
|
||||
Adapted from https://github.com/FangShancheng/ABINet.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img (ndarray)
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img (ndarray)
|
||||
|
||||
Args:
|
||||
factor (int): The decay factor from base size, or the number of
|
||||
downsampling operations from the base layer.
|
||||
base_shape (tuple(int)): The shape of the base layer of the pyramid.
|
||||
randomize_factor (bool): If True, the final factor would be a random
|
||||
integer in [0, factor].
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
factor: int = 4,
|
||||
base_shape: Tuple[int, int] = (128, 512),
|
||||
randomize_factor: bool = True) -> None:
|
||||
if not isinstance(factor, int):
|
||||
raise TypeError('`factor` should be an integer, '
|
||||
f'but got {type(factor)} instead')
|
||||
if not isinstance(base_shape, (list, tuple)):
|
||||
raise TypeError('`base_shape` should be a list or tuple, '
|
||||
f'but got {type(base_shape)} instead')
|
||||
if not len(base_shape) == 2:
|
||||
raise ValueError('`base_shape` should contain two integers')
|
||||
if not isinstance(base_shape[0], int) or not isinstance(
|
||||
base_shape[1], int):
|
||||
raise ValueError('`base_shape` should contain two integers')
|
||||
if not isinstance(randomize_factor, bool):
|
||||
raise TypeError('`randomize_factor` should be a bool, '
|
||||
f'but got {type(randomize_factor)} instead')
|
||||
|
||||
self.factor = factor
|
||||
self.randomize_factor = randomize_factor
|
||||
self.base_w, self.base_h = base_shape
|
||||
|
||||
@cache_randomness
|
||||
def get_random_factor(self):
|
||||
return np.random.randint(0, self.factor + 1)
|
||||
|
||||
def transform(self, results: Dict) -> Dict:
|
||||
"""Applying pyramid rescale on results.
|
||||
Args:
|
||||
results (dict): Result dict contains the data to transform.
|
||||
|
||||
Returns:
|
||||
dict: The transformed data
|
||||
"""
|
||||
|
||||
assert 'img' in results, '`img` is not found in results'
|
||||
if self.randomize_factor:
|
||||
self.factor = self.get_random_factor()
|
||||
if self.factor == 0:
|
||||
return results
|
||||
img = results['img']
|
||||
src_h, src_w = img.shape[:2]
|
||||
scale_img = mmcv.imresize(img, (self.base_w, self.base_h))
|
||||
for _ in range(self.factor):
|
||||
scale_img = cv2.pyrDown(scale_img)
|
||||
scale_img = mmcv.imresize(scale_img, (src_w, src_h))
|
||||
results['img'] = scale_img
|
||||
return results
|
||||
|
||||
def __repr__(self) -> str:
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(factor = {self.factor}'
|
||||
repr_str += f', randomize_factor = {self.randomize_factor}'
|
||||
repr_str += f', base_w = {self.base_w}'
|
||||
repr_str += f', base_h = {self.base_h})'
|
||||
return repr_str
|
|
@ -967,54 +967,3 @@ class RandomCropFlip:
|
|||
h_axis = np.where(h_array == 0)[0]
|
||||
w_axis = np.where(w_array == 0)[0]
|
||||
return h_axis, w_axis
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class PyramidRescale:
|
||||
"""Resize the image to the base shape, downsample it with gaussian pyramid,
|
||||
and rescale it back to original size.
|
||||
|
||||
Adapted from https://github.com/FangShancheng/ABINet.
|
||||
|
||||
Args:
|
||||
factor (int): The decay factor from base size, or the number of
|
||||
downsampling operations from the base layer.
|
||||
base_shape (tuple(int)): The shape of the base layer of the pyramid.
|
||||
randomize_factor (bool): If True, the final factor would be a random
|
||||
integer in [0, factor].
|
||||
|
||||
:Required Keys:
|
||||
- | ``img`` (ndarray): The input image.
|
||||
|
||||
:Affected Keys:
|
||||
:Modified:
|
||||
- | ``img`` (ndarray): The modified image.
|
||||
"""
|
||||
|
||||
def __init__(self, factor=4, base_shape=(128, 512), randomize_factor=True):
|
||||
assert isinstance(factor, int)
|
||||
assert isinstance(base_shape, list) or isinstance(base_shape, tuple)
|
||||
assert len(base_shape) == 2
|
||||
assert isinstance(randomize_factor, bool)
|
||||
self.factor = factor if not randomize_factor else np.random.randint(
|
||||
0, factor + 1)
|
||||
self.base_w, self.base_h = base_shape
|
||||
|
||||
def __call__(self, results):
|
||||
assert 'img' in results
|
||||
if self.factor == 0:
|
||||
return results
|
||||
img = results['img']
|
||||
src_h, src_w = img.shape[:2]
|
||||
scale_img = mmcv.imresize(img, (self.base_w, self.base_h))
|
||||
for _ in range(self.factor):
|
||||
scale_img = cv2.pyrDown(scale_img)
|
||||
scale_img = mmcv.imresize(scale_img, (src_w, src_h))
|
||||
results['img'] = scale_img
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(factor={self.factor}, '
|
||||
repr_str += f'basew={self.basew}, baseh={self.baseh})'
|
||||
return repr_str
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import unittest.mock as mock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torchvision.transforms as TF
|
||||
from mmdet.core import BitmapMasks, PolygonMasks
|
||||
from PIL import Image
|
||||
|
@ -345,29 +343,3 @@ def test_square_resize_pad(mock_sample):
|
|||
target[1::2] *= 8. / 3
|
||||
assert np.allclose(output['gt_masks'].masks[0][0], target)
|
||||
assert output['img'].shape == (40, 40, 3)
|
||||
|
||||
|
||||
def test_pyramid_rescale():
|
||||
img = np.random.randint(0, 256, size=(128, 100, 3), dtype=np.uint8)
|
||||
x = {'img': copy.deepcopy(img)}
|
||||
f = transforms.PyramidRescale()
|
||||
results = f(x)
|
||||
assert results['img'].shape == (128, 100, 3)
|
||||
|
||||
# Test invalid inputs
|
||||
with pytest.raises(AssertionError):
|
||||
transforms.PyramidRescale(base_shape=(128))
|
||||
with pytest.raises(AssertionError):
|
||||
transforms.PyramidRescale(base_shape=128)
|
||||
with pytest.raises(AssertionError):
|
||||
transforms.PyramidRescale(factor=[])
|
||||
with pytest.raises(AssertionError):
|
||||
transforms.PyramidRescale(randomize_factor=[])
|
||||
with pytest.raises(AssertionError):
|
||||
f({})
|
||||
|
||||
# Test factor = 0
|
||||
f_derandomized = transforms.PyramidRescale(
|
||||
factor=0, randomize_factor=False)
|
||||
results = f_derandomized({'img': copy.deepcopy(img)})
|
||||
assert np.all(results['img'] == img)
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mmocr.datasets.pipelines import PyramidRescale
|
||||
|
||||
|
||||
class TestPyramidRescale(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.data_info = dict(img=np.random.random((128, 100, 3)))
|
||||
|
||||
def test_init(self):
|
||||
# factor is int
|
||||
transform = PyramidRescale(factor=4, randomize_factor=False)
|
||||
self.assertEqual(transform.factor, 4)
|
||||
# factor is float
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'`factor` should be an integer'):
|
||||
PyramidRescale(factor=4.0)
|
||||
# invalid base_shape
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'`base_shape` should be a list or tuple'):
|
||||
PyramidRescale(base_shape=128)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, '`base_shape` should contain two integers'):
|
||||
PyramidRescale(base_shape=(128, ))
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, '`base_shape` should contain two integers'):
|
||||
PyramidRescale(base_shape=(128.0, 2.0))
|
||||
# invalid randomize_factor
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'`randomize_factor` should be a bool'):
|
||||
PyramidRescale(randomize_factor=None)
|
||||
|
||||
def test_transform(self):
|
||||
# test if the rescale keeps the original size
|
||||
transform = PyramidRescale()
|
||||
results = transform(copy.deepcopy(self.data_info))
|
||||
self.assertEqual(results['img'].shape, (128, 100, 3))
|
||||
# test factor = 0
|
||||
transform = PyramidRescale(factor=0, randomize_factor=False)
|
||||
results = transform(copy.deepcopy(self.data_info))
|
||||
self.assertTrue(np.all(results['img'] == self.data_info['img']))
|
||||
|
||||
def test_repr(self):
|
||||
transform = PyramidRescale(
|
||||
factor=4, base_shape=(128, 512), randomize_factor=False)
|
||||
print(repr(transform))
|
||||
self.assertEqual(
|
||||
repr(transform),
|
||||
('PyramidRescale(factor = 4, randomize_factor = False, '
|
||||
'base_w = 128, base_h = 512)'))
|
Loading…
Reference in New Issue