# Mirror of https://github.com/open-mmlab/mmocr.git (Python, 109 lines, 3.6 KiB)
import mmcv
|
|
import numpy as np
|
|
|
|
from mmdet.datasets.builder import PIPELINES
|
|
from mmdet.datasets.pipelines.compose import Compose
|
|
|
|
|
|
@PIPELINES.register_module()
class MultiRotateAugOCR:
    """Test-time augmentation with multiple rotations in the case that
    img_height > img_width.

    An example configuration is as follows:

    .. code-block::

        rotate_degrees=[0, 90, 270],
        transforms=[
            dict(
                type='ResizeOCR',
                height=32,
                min_width=32,
                max_width=160,
                keep_aspect_ratio=True),
            dict(type='ToTensorOCR'),
            dict(type='NormalizeOCR', **img_norm_cfg),
            dict(
                type='Collect',
                keys=['img'],
                meta_keys=[
                    'filename', 'ori_shape', 'img_shape', 'valid_ratio'
                ]),
        ]

    After MultiRotateAugOCR with above configuration, the results are wrapped
    into lists of the same length as follows:

    .. code-block::

        dict(
            img=[...],
            img_shape=[...]
            ...
        )

    Args:
        transforms (list[dict]): Transformation applied for each augmentation.
        rotate_degrees (list[int] | int | None): Degrees of anti-clockwise
            rotation. Every value must be a multiple of 90 in ``[0, 360)``.
            A single int is wrapped into a list. ``0`` is always added if it
            is missing, and ``None`` is equivalent to ``[0]``.
        force_rotate (bool): If True, rotate image by 'rotate_degrees'
            while ignore image aspect ratio.
    """

    def __init__(self, transforms, rotate_degrees=None, force_rotate=False):
        self.transforms = Compose(transforms)
        self.force_rotate = force_rotate

        if rotate_degrees is None:
            self.rotate_degrees = [0]
            return

        # Work on a private copy: appending the implicit 0 below must not
        # leak into the caller's list (e.g. a config list that is reused).
        if isinstance(rotate_degrees, list):
            degrees = list(rotate_degrees)
        else:
            degrees = [rotate_degrees]

        assert mmcv.is_list_of(degrees, int)
        for degree in degrees:
            assert 0 <= degree < 360
            assert degree % 90 == 0

        # The unrotated image is always part of the augmentation set.
        if 0 not in degrees:
            degrees.append(0)
        self.rotate_degrees = degrees

    def __call__(self, results):
        """Call function to apply test time augment transformation to results.

        Args:
            results (dict): Result dict contains the data to be transformed.
                ``results['img_shape']`` and ``results['img']`` are read.

        Returns:
            dict[str: list]: The augmented data, where each value is wrapped
                into a list. There is one entry per distinct rotation, in
                ascending order of degree.
        """
        ori_height, ori_width = results['img_shape'][:2]

        # Rotations are only applied to "tall" images unless forced.
        if self.force_rotate or ori_height > ori_width:
            rotate_degrees = self.rotate_degrees
        else:
            rotate_degrees = [0]

        aug_data = []
        # De-duplicate and sort so that the output order is explicit.
        for degree in sorted(set(rotate_degrees)):
            # Shallow copy: only the 'img' entry is rebound for each rotation.
            _results = results.copy()
            if degree in (90, 180, 270):
                # np.rot90 rotates anti-clockwise in steps of 90 degrees.
                _results['img'] = np.rot90(_results['img'], degree // 90)
            aug_data.append(self.transforms(_results))

        # list of dict to dict of list
        aug_data_dict = {key: [] for key in aug_data[0]}
        for data in aug_data:
            for key, val in data.items():
                aug_data_dict[key].append(val)
        return aug_data_dict

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(transforms={self.transforms}, '
        repr_str += f'rotate_degrees={self.rotate_degrees})'
        return repr_str
|