[Transform] ScaleAspectJitter

pull/1178/head
liukuikun 2022-06-13 06:17:32 +00:00 committed by gaotongxiao
parent da175b44a4
commit dfe93dc7d2
3 changed files with 140 additions and 5 deletions

View File

@ -9,8 +9,9 @@ from .ocr_transforms import (FancyPCA, NormalizeOCR, OnlineCropOCR,
OpencvToPil, PilToOpencv, RandomPaddingOCR,
RandomRotateImageBox, ResizeOCR, ToTensorOCR)
from .processing import (PadToWidth, PyramidRescale, RandomCrop, RandomRotate,
RescaleToHeight, Resize, SourceImagePad,
TextDetRandomCrop, TextDetRandomCropFlip)
RescaleToHeight, Resize, ShortScaleAspectJitter,
SourceImagePad, TextDetRandomCrop,
TextDetRandomCropFlip)
from .test_time_aug import MultiRotateAugOCR
from .textdet_targets import (DBNetTargets, FCENetTargets, PANetTargets,
TextSnakeTargets)
@ -26,5 +27,6 @@ __all__ = [
'sort_vertex8', 'FCENetTargets', 'TextDetRandomCropFlip', 'NerTransform',
'ToTensorNER', 'ResizeNoImg', 'PyramidRescale', 'TorchVisionWrapper',
'Resize', 'RandomCrop', 'TextDetRandomCrop', 'RandomCrop',
'PackTextDetInputs', 'PackTextRecogInputs', 'RescaleToHeight', 'PadToWidth'
'PackTextDetInputs', 'PackTextRecogInputs', 'RescaleToHeight',
'PadToWidth', 'ShortScaleAspectJitter'
]

View File

@ -1357,3 +1357,100 @@ class SourceImagePad(BaseTransform):
repr_str += f'(target_scale = {self.target_scale}, '
repr_str += f'crop_ratio = {self.crop_ratio})'
return repr_str
@TRANSFORMS.register_module()
@avoid_cache_randomness
class ShortScaleAspectJitter(BaseTransform):
"""First rescale the image for its shorter side to reach the short_size and
then jitter its aspect ratio, final rescale the shape guaranteed to be
divided by scale_divisor.
Required Keys:
- img
- img_shape
- gt_bboxes (optional)
- gt_polygons (optional)
Modified Keys:
- img
- img_shape
- gt_bboxes (optional)
- gt_polygons (optional)
Added Keys:
- scale
- scale_factor
- keep_ratio
Args:
short_size (int): Target shorter size before jittering the aspect
ratio. Defaults to 736.
short_size_jitter_range (tuple(float, float)): Range of the ratio used
to jitter the target shorter size. Defaults to (0.7, 1.3).
aspect_ratio_jitter_range (tuple(float, float)): Range of the ratio
used to jitter its aspect ratio. Defaults to (0.9, 1.1).
scale_divisor (int): The scale divisor. Defaults to 1.
resize_cfg (dict): (dict): Config to construct the Resize transform.
Refer to ``Resize`` for detail. Defaults to
``dict(type='Resize')``.
"""
def __init__(self,
short_size: int = 736,
ratio_range: Tuple[float, float] = (0.7, 1.3),
aspect_ratio_range: Tuple[float, float] = (0.9, 1.1),
scale_divisor: int = 1,
resize_cfg: Dict = dict(type='Resize')) -> None:
super().__init__()
self.short_size = short_size
self.ratio_range = ratio_range
self.aspect_ratio_range = aspect_ratio_range
self.resize_cfg = resize_cfg
# create a empty Reisize object
resize_cfg.update(dict(scale=0))
self.resize = TRANSFORMS.build(resize_cfg)
self.scale_divisor = scale_divisor
def _sample_from_range(self, range: Tuple[float, float]) -> float:
"""A ratio will be randomly sampled from the range specified by
``range``.
Args:
ratio_range (tuple[float]): The minimum and maximum ratio.
Returns:
float: A ratio randomly sampled from the range.
"""
min_value, max_value = min(range), max(range)
value = np.random.random_sample() * (max_value - min_value) + min_value
return value
def transform(self, results: Dict) -> Dict:
h, w = results['img'].shape[:2]
ratio = self._sample_from_range(self.ratio_range)
scale = (ratio * self.short_size) / min(h, w)
aspect = self._sample_from_range(self.aspect_ratio_range)
h_scale = scale * math.sqrt(aspect)
w_scale = scale / math.sqrt(aspect)
new_h = round(h * h_scale)
new_w = round(w * w_scale)
new_h = math.ceil(new_h / self.scale_divisor) * self.scale_divisor
new_w = math.ceil(new_w / self.scale_divisor) * self.scale_divisor
self.resize.scale = (new_w, new_h)
return self.resize(results)
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(short_size = {self.short_size}, '
repr_str += f'ratio_range = {self.ratio_range}, '
repr_str += f'aspect_ratio_range = {self.aspect_ratio_range}, '
repr_str += f'scale_divisor = {self.scale_divisor}, '
repr_str += f'resize_cfg = {self.resize_cfg})'
return repr_str

View File

@ -8,8 +8,8 @@ from mmcv.transforms import Pad, RandomResize
from mmocr.datasets.pipelines import (PadToWidth, PyramidRescale, RandomCrop,
RandomRotate, RescaleToHeight, Resize,
SourceImagePad, TextDetRandomCrop,
TextDetRandomCropFlip)
ShortScaleAspectJitter, SourceImagePad,
TextDetRandomCrop, TextDetRandomCropFlip)
from mmocr.utils import bbox2poly, poly2shapely
@ -631,3 +631,39 @@ class TestSourceImagePad(unittest.TestCase):
repr(transform),
('SourceImagePad(target_scale = (30, 30), crop_ratio = (0.1, 0.1))'
))
class TestShortScaleAspectJitter(unittest.TestCase):
@mock.patch('mmocr.datasets.pipelines.processing.np.random.random_sample')
def test_transform(self, mock_random):
ratio_range = (0.5, 1.5)
aspect_ratio_range = (0.9, 1.1)
mock_random.side_effect = [0.5, 0.5]
img = np.zeros((15, 20, 3))
polygon = [np.array([10., 5., 20., 5., 20., 10., 10., 10.])]
bbox = np.array([[10., 5., 20., 10.]])
data_info = dict(img=img, gt_polygons=polygon, gt_bboxes=bbox)
t = ShortScaleAspectJitter(
short_size=40,
ratio_range=ratio_range,
aspect_ratio_range=aspect_ratio_range,
scale_divisor=4)
results = t(data_info)
self.assertEqual(results['img'].shape, (40, 56, 3))
self.assertEqual(results['img_shape'], (40, 56))
def test_repr(self):
transform = ShortScaleAspectJitter(
short_size=40,
ratio_range=(0.5, 1.5),
aspect_ratio_range=(0.9, 1.1),
scale_divisor=4,
resize_cfg=dict(type='Resize'))
self.assertEqual(
repr(transform), ('ShortScaleAspectJitter('
'short_size = 40, '
'ratio_range = (0.5, 1.5), '
'aspect_ratio_range = (0.9, 1.1), '
'scale_divisor = 4, '
"resize_cfg = {'type': 'Resize', 'scale': 0})"))