[Transforms] SVTR transforms (#1645)

* rec transforms

* fix

* ut

* update docs

* fix

* new name

* fix
pull/1663/head
Tong Gao 2023-01-06 16:04:20 +08:00 committed by GitHub
parent d679691a02
commit b0557c2c55
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 576 additions and 7 deletions

View File

@ -34,7 +34,6 @@ TextDet Transforms
:template: classtemplate.rst :template: classtemplate.rst
BoundedScaleAspectJitter BoundedScaleAspectJitter
FixInvalidPolygon
RandomFlip RandomFlip
SourceImagePad SourceImagePad
ShortScaleAspectJitter ShortScaleAspectJitter
@ -50,6 +49,10 @@ TextRecog Transforms
:nosignatures: :nosignatures:
:template: classtemplate.rst :template: classtemplate.rst
TextRecogGeneralAug
CropHeight
ImageContentJitter
ReversePixels
PyramidRescale PyramidRescale
PadToWidth PadToWidth
RescaleToHeight RescaleToHeight
@ -66,6 +69,8 @@ OCR Transforms
RandomCrop RandomCrop
RandomRotate RandomRotate
Resize Resize
FixInvalidPolygon
RemoveIgnored

View File

@ -34,7 +34,6 @@ TextDet Transforms
:template: classtemplate.rst :template: classtemplate.rst
BoundedScaleAspectJitter BoundedScaleAspectJitter
FixInvalidPolygon
RandomFlip RandomFlip
SourceImagePad SourceImagePad
ShortScaleAspectJitter ShortScaleAspectJitter
@ -50,6 +49,10 @@ TextRecog Transforms
:nosignatures: :nosignatures:
:template: classtemplate.rst :template: classtemplate.rst
TextRecogGeneralAug
CropHeight
ImageContentJitter
ReversePixels
PyramidRescale PyramidRescale
PadToWidth PadToWidth
RescaleToHeight RescaleToHeight
@ -66,6 +69,8 @@ OCR Transforms
RandomCrop RandomCrop
RandomRotate RandomRotate
Resize Resize
FixInvalidPolygon
RemoveIgnored

View File

@ -9,7 +9,9 @@ from .ocr_transforms import (FixInvalidPolygon, RandomCrop, RandomRotate,
from .textdet_transforms import (BoundedScaleAspectJitter, RandomFlip, from .textdet_transforms import (BoundedScaleAspectJitter, RandomFlip,
ShortScaleAspectJitter, SourceImagePad, ShortScaleAspectJitter, SourceImagePad,
TextDetRandomCrop, TextDetRandomCropFlip) TextDetRandomCrop, TextDetRandomCropFlip)
from .textrecog_transforms import PadToWidth, PyramidRescale, RescaleToHeight from .textrecog_transforms import (CropHeight, ImageContentJitter, PadToWidth,
PyramidRescale, RescaleToHeight,
ReversePixels, TextRecogGeneralAug)
from .wrappers import ConditionApply, ImgAugWrapper, TorchVisionWrapper from .wrappers import ConditionApply, ImgAugWrapper, TorchVisionWrapper
__all__ = [ __all__ = [
@ -20,5 +22,6 @@ __all__ = [
'ShortScaleAspectJitter', 'RandomFlip', 'BoundedScaleAspectJitter', 'ShortScaleAspectJitter', 'RandomFlip', 'BoundedScaleAspectJitter',
'PackKIEInputs', 'LoadKIEAnnotations', 'FixInvalidPolygon', 'MMDet2MMOCR', 'PackKIEInputs', 'LoadKIEAnnotations', 'FixInvalidPolygon', 'MMDet2MMOCR',
'MMOCR2MMDet', 'LoadImageFromLMDB', 'LoadImageFromFile', 'MMOCR2MMDet', 'LoadImageFromLMDB', 'LoadImageFromFile',
'LoadImageFromNDArray', 'RemoveIgnored', 'ConditionApply' 'LoadImageFromNDArray', 'CropHeight', 'TextRecogGeneralAug',
'ImageContentJitter', 'ReversePixels', 'RemoveIgnored', 'ConditionApply'
] ]

View File

@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import math import math
from typing import Dict, Optional, Tuple import random
from typing import Dict, List, Optional, Tuple
import cv2 import cv2
import mmcv import mmcv
@ -251,3 +252,473 @@ class PadToWidth(BaseTransform):
repr_str += f'(width={self.width}, ' repr_str += f'(width={self.width}, '
repr_str += f'pad_cfg={self.pad_cfg})' repr_str += f'pad_cfg={self.pad_cfg})'
return repr_str return repr_str
@TRANSFORMS.register_module()
class TextRecogGeneralAug(BaseTransform):
"""A general geometric augmentation tool for text images in the CVPR 2020
paper "Learn to Augment: Joint Data Augmentation and Network Optimization
for Text Recognition". It applies distortion, stretching, and perspective
transforms to an image.
This implementation is adapted from
https://github.com/RubanSeven/Text-Image-Augmentation-python/blob/master/augment.py # noqa
TODO: Split this transform into three transforms.
Required Keys:
- img
Modified Keys:
- img
- img_shape
""" # noqa
def transform(self, results: Dict) -> Dict:
"""Call function to pad images.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Updated result dict.
"""
h, w = results['img'].shape[:2]
if h >= 20 and w >= 20:
results['img'] = self.tia_distort(results['img'],
random.randint(3, 6))
results['img'] = self.tia_stretch(results['img'],
random.randint(3, 6))
h, w = results['img'].shape[:2]
if h >= 5 and w >= 5:
results['img'] = self.tia_perspective(results['img'])
results['img_shape'] = results['img'].shape[:2]
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += '()'
return repr_str
def tia_distort(self, img: np.ndarray, segment: int = 4) -> np.ndarray:
"""Image distortion.
Args:
img (np.ndarray): The image.
segment (int): The number of segments to divide the image along
the width. Defaults to 4.
"""
img_h, img_w = img.shape[:2]
cut = img_w // segment
thresh = cut // 3
src_pts = list()
dst_pts = list()
src_pts.append([0, 0])
src_pts.append([img_w, 0])
src_pts.append([img_w, img_h])
src_pts.append([0, img_h])
dst_pts.append([np.random.randint(thresh), np.random.randint(thresh)])
dst_pts.append(
[img_w - np.random.randint(thresh),
np.random.randint(thresh)])
dst_pts.append([
img_w - np.random.randint(thresh),
img_h - np.random.randint(thresh)
])
dst_pts.append(
[np.random.randint(thresh), img_h - np.random.randint(thresh)])
half_thresh = thresh * 0.5
for cut_idx in np.arange(1, segment, 1):
src_pts.append([cut * cut_idx, 0])
src_pts.append([cut * cut_idx, img_h])
dst_pts.append([
cut * cut_idx + np.random.randint(thresh) - half_thresh,
np.random.randint(thresh) - half_thresh
])
dst_pts.append([
cut * cut_idx + np.random.randint(thresh) - half_thresh,
img_h + np.random.randint(thresh) - half_thresh
])
dst = self.warp_mls(img, src_pts, dst_pts, img_w, img_h)
return dst
def tia_stretch(self, img: np.ndarray, segment: int = 4) -> np.ndarray:
"""Image stretching.
Args:
img (np.ndarray): The image.
segment (int): The number of segments to divide the image along
the width. Defaults to 4.
"""
img_h, img_w = img.shape[:2]
cut = img_w // segment
thresh = cut * 4 // 5
src_pts = list()
dst_pts = list()
src_pts.append([0, 0])
src_pts.append([img_w, 0])
src_pts.append([img_w, img_h])
src_pts.append([0, img_h])
dst_pts.append([0, 0])
dst_pts.append([img_w, 0])
dst_pts.append([img_w, img_h])
dst_pts.append([0, img_h])
half_thresh = thresh * 0.5
for cut_idx in np.arange(1, segment, 1):
move = np.random.randint(thresh) - half_thresh
src_pts.append([cut * cut_idx, 0])
src_pts.append([cut * cut_idx, img_h])
dst_pts.append([cut * cut_idx + move, 0])
dst_pts.append([cut * cut_idx + move, img_h])
dst = self.warp_mls(img, src_pts, dst_pts, img_w, img_h)
return dst
def tia_perspective(self, img: np.ndarray) -> np.ndarray:
"""Image perspective transformation.
Args:
img (np.ndarray): The image.
segment (int): The number of segments to divide the image along
the width. Defaults to 4.
"""
img_h, img_w = img.shape[:2]
thresh = img_h // 2
src_pts = list()
dst_pts = list()
src_pts.append([0, 0])
src_pts.append([img_w, 0])
src_pts.append([img_w, img_h])
src_pts.append([0, img_h])
dst_pts.append([0, np.random.randint(thresh)])
dst_pts.append([img_w, np.random.randint(thresh)])
dst_pts.append([img_w, img_h - np.random.randint(thresh)])
dst_pts.append([0, img_h - np.random.randint(thresh)])
dst = self.warp_mls(img, src_pts, dst_pts, img_w, img_h)
return dst
def warp_mls(self,
src: np.ndarray,
src_pts: List[int],
dst_pts: List[int],
dst_w: int,
dst_h: int,
trans_ratio: float = 1.) -> np.ndarray:
"""Warp the image."""
rdx, rdy = self._calc_delta(dst_w, dst_h, src_pts, dst_pts, 100)
return self._gen_img(src, rdx, rdy, dst_w, dst_h, 100, trans_ratio)
def _calc_delta(self, dst_w: int, dst_h: int, src_pts: List[int],
dst_pts: List[int],
grid_size: int) -> Tuple[np.ndarray, np.ndarray]:
"""Compute delta."""
pt_count = len(dst_pts)
rdx = np.zeros((dst_h, dst_w))
rdy = np.zeros((dst_h, dst_w))
w = np.zeros(pt_count, dtype=np.float32)
if pt_count < 2:
return
i = 0
while True:
if dst_w <= i < dst_w + grid_size - 1:
i = dst_w - 1
elif i >= dst_w:
break
j = 0
while True:
if dst_h <= j < dst_h + grid_size - 1:
j = dst_h - 1
elif j >= dst_h:
break
sw = 0
swp = np.zeros(2, dtype=np.float32)
swq = np.zeros(2, dtype=np.float32)
new_pt = np.zeros(2, dtype=np.float32)
cur_pt = np.array([i, j], dtype=np.float32)
k = 0
for k in range(pt_count):
if i == dst_pts[k][0] and j == dst_pts[k][1]:
break
w[k] = 1. / ((i - dst_pts[k][0]) * (i - dst_pts[k][0]) +
(j - dst_pts[k][1]) * (j - dst_pts[k][1]))
sw += w[k]
swp = swp + w[k] * np.array(dst_pts[k])
swq = swq + w[k] * np.array(src_pts[k])
if k == pt_count - 1:
pstar = 1 / sw * swp
qstar = 1 / sw * swq
miu_s = 0
for k in range(pt_count):
if i == dst_pts[k][0] and j == dst_pts[k][1]:
continue
pt_i = dst_pts[k] - pstar
miu_s += w[k] * np.sum(pt_i * pt_i)
cur_pt -= pstar
cur_pt_j = np.array([-cur_pt[1], cur_pt[0]])
for k in range(pt_count):
if i == dst_pts[k][0] and j == dst_pts[k][1]:
continue
pt_i = dst_pts[k] - pstar
pt_j = np.array([-pt_i[1], pt_i[0]])
tmp_pt = np.zeros(2, dtype=np.float32)
tmp_pt[0] = (
np.sum(pt_i * cur_pt) * src_pts[k][0] -
np.sum(pt_j * cur_pt) * src_pts[k][1])
tmp_pt[1] = (-np.sum(pt_i * cur_pt_j) * src_pts[k][0] +
np.sum(pt_j * cur_pt_j) * src_pts[k][1])
tmp_pt *= (w[k] / miu_s)
new_pt += tmp_pt
new_pt += qstar
else:
new_pt = src_pts[k]
rdx[j, i] = new_pt[0] - i
rdy[j, i] = new_pt[1] - j
j += grid_size
i += grid_size
return rdx, rdy
def _gen_img(self, src: np.ndarray, rdx: np.ndarray, rdy: np.ndarray,
dst_w: int, dst_h: int, grid_size: int,
trans_ratio: float) -> np.ndarray:
"""Generate the image based on delta."""
src_h, src_w = src.shape[:2]
dst = np.zeros_like(src, dtype=np.float32)
for i in np.arange(0, dst_h, grid_size):
for j in np.arange(0, dst_w, grid_size):
ni = i + grid_size
nj = j + grid_size
w = h = grid_size
if ni >= dst_h:
ni = dst_h - 1
h = ni - i + 1
if nj >= dst_w:
nj = dst_w - 1
w = nj - j + 1
di = np.reshape(np.arange(h), (-1, 1))
dj = np.reshape(np.arange(w), (1, -1))
delta_x = self._bilinear_interp(di / h, dj / w, rdx[i, j],
rdx[i, nj], rdx[ni, j],
rdx[ni, nj])
delta_y = self._bilinear_interp(di / h, dj / w, rdy[i, j],
rdy[i, nj], rdy[ni, j],
rdy[ni, nj])
nx = j + dj + delta_x * trans_ratio
ny = i + di + delta_y * trans_ratio
nx = np.clip(nx, 0, src_w - 1)
ny = np.clip(ny, 0, src_h - 1)
nxi = np.array(np.floor(nx), dtype=np.int32)
nyi = np.array(np.floor(ny), dtype=np.int32)
nxi1 = np.array(np.ceil(nx), dtype=np.int32)
nyi1 = np.array(np.ceil(ny), dtype=np.int32)
if len(src.shape) == 3:
x = np.tile(np.expand_dims(ny - nyi, axis=-1), (1, 1, 3))
y = np.tile(np.expand_dims(nx - nxi, axis=-1), (1, 1, 3))
else:
x = ny - nyi
y = nx - nxi
dst[i:i + h,
j:j + w] = self._bilinear_interp(x, y, src[nyi, nxi],
src[nyi, nxi1],
src[nyi1, nxi], src[nyi1,
nxi1])
dst = np.clip(dst, 0, 255)
dst = np.array(dst, dtype=np.uint8)
return dst
@staticmethod
def _bilinear_interp(x, y, v11, v12, v21, v22):
"""Bilinear interpolation.
TODO: Docs for args and put it into utils.
"""
return (v11 * (1 - y) + v12 * y) * (1 - x) + (v21 *
(1 - y) + v22 * y) * x
@TRANSFORMS.register_module()
class CropHeight(BaseTransform):
"""Randomly crop the image's height, either from top or bottom.
Adapted from
https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.6/ppocr/data/imaug/rec_img_aug.py # noqa
Required Keys:
- img
Modified Keys:
- img
- img_shape
Args:
crop_min (int): Minimum pixel(s) to crop. Defaults to 1.
crop_max (int): Maximum pixel(s) to crop. Defaults to 8.
"""
def __init__(
self,
min_pixels: int = 1,
max_pixels: int = 8,
) -> None:
super().__init__()
assert max_pixels >= min_pixels
self.min_pixels = min_pixels
self.max_pixels = max_pixels
@cache_randomness
def get_random_vars(self):
"""Get all the random values used in this transform."""
crop_pixels = int(random.randint(self.min_pixels, self.max_pixels))
crop_top = random.randint(0, 1)
return crop_pixels, crop_top
def transform(self, results: Dict) -> Dict:
"""Transform function to crop images.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Cropped results.
"""
h = results['img'].shape[0]
crop_pixels, crop_top = self.get_random_vars()
crop_pixels = min(crop_pixels, h - 1)
img = results['img'].copy()
if crop_top:
img = img[crop_pixels:h, :, :]
else:
img = img[0:h - crop_pixels, :, :]
results['img_shape'] = img.shape[:2]
results['img'] = img
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(min_pixels = {self.min_pixels}, '
repr_str += f'max_pixels = {self.max_pixels})'
return repr_str
@TRANSFORMS.register_module()
class ImageContentJitter(BaseTransform):
"""Jitter the image contents.
Adapted from
https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.6/ppocr/data/imaug/rec_img_aug.py # noqa
Required Keys:
- img
Modified Keys:
- img
"""
def transform(self, results: Dict, jitter_ratio: float = 0.01) -> Dict:
"""Transform function to jitter images.
Args:
results (dict): Result dict from loading pipeline.
jitter_ratio (float): Controls the strength of jittering.
Defaults to 0.01.
Returns:
dict: Jittered results.
"""
h, w = results['img'].shape[:2]
img = results['img'].copy()
if h > 10 and w > 10:
thres = min(h, w)
jitter_range = int(random.random() * thres * 0.01)
for i in range(jitter_range):
img[i:, i:, :] = img[:h - i, :w - i, :]
results['img'] = img
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += '()'
return repr_str
@TRANSFORMS.register_module()
class ReversePixels(BaseTransform):
"""Reverse image pixels.
Adapted from
https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.6/ppocr/data/imaug/rec_img_aug.py # noqa
Required Keys:
- img
Modified Keys:
- img
"""
def transform(self, results: Dict) -> Dict:
"""Transform function to reverse image pixels.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Reversed results.
"""
results['img'] = 255. - results['img'].copy()
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += '()'
return repr_str

View File

@ -3,9 +3,12 @@ import copy
import unittest import unittest
import numpy as np import numpy as np
from parameterized import parameterized
from mmocr.datasets.transforms import (PadToWidth, PyramidRescale, from mmocr.datasets.transforms import (CropHeight, ImageContentJitter,
RescaleToHeight) PadToWidth, PyramidRescale,
RescaleToHeight, ReversePixels,
TextRecogGeneralAug)
class TestPadToWidth(unittest.TestCase): class TestPadToWidth(unittest.TestCase):
@ -125,3 +128,85 @@ class TestRescaleToHeight(unittest.TestCase):
'min_width=None, max_width=None, ' 'min_width=None, max_width=None, '
'width_divisor=1, ' 'width_divisor=1, '
"resize_cfg={'type': 'Resize', 'scale': 0})")) "resize_cfg={'type': 'Resize', 'scale': 0})"))
class TestTextRecogGeneralAug(unittest.TestCase):
def setUp(self) -> None:
self.transform = TextRecogGeneralAug()
@parameterized.expand([(np.random.random((3, 3, 3)), ),
(np.random.random((10, 10, 3)), ),
(np.random.random((30, 30, 3)), )])
def test_transform(self, img):
data_info = dict(img=img)
results = self.transform(copy.deepcopy(data_info))
self.assertEqual(results['img'].shape[:2], results['img_shape'])
def test_repr(self):
repr_str = self.transform.__repr__()
self.assertEqual(repr_str, 'TextRecogGeneralAug()')
class TestCropHeight(unittest.TestCase):
def setUp(self) -> None:
self.data_info = dict(img=np.random.random((20, 20, 3)))
@parameterized.expand([
(3, 3),
(5, 10),
])
def test_transform(self, min_pixels, max_pixels):
self.transform = CropHeight(
min_pixels=min_pixels, max_pixels=max_pixels)
results = self.transform(copy.deepcopy(self.data_info))
self.assertEqual(results['img'].shape[:2], results['img_shape'])
h_diff = self.data_info['img'].shape[0] - results['img_shape'][0]
self.assertGreaterEqual(h_diff, min_pixels)
self.assertLessEqual(h_diff, max_pixels)
def test_invalid(self):
with self.assertRaises(AssertionError):
self.transform = CropHeight(min_pixels=10, max_pixels=9)
def test_repr(self):
transform = CropHeight(min_pixels=2, max_pixels=10)
repr_str = transform.__repr__()
self.assertEqual(repr_str, 'CropHeight(min_pixels = 2, '
'max_pixels = 10)')
class TestImageContentJitter(unittest.TestCase):
def setUp(self) -> None:
self.transform = ImageContentJitter()
@parameterized.expand([(np.random.random((3, 3, 3)), ),
(np.random.random((10, 10, 3)), ),
(np.random.random((30, 30, 3)), )])
def test_transform(self, img):
data_info = dict(img=img)
self.transform(copy.deepcopy(data_info))
def test_repr(self):
repr_str = self.transform.__repr__()
self.assertEqual(repr_str, 'ImageContentJitter()')
class TestReversePixels(unittest.TestCase):
def setUp(self) -> None:
self.transform = ReversePixels()
@parameterized.expand([(np.random.random((3, 3, 3)), ),
(np.random.random((10, 10, 3)), ),
(np.random.random((30, 30, 3)), )])
def test_transform(self, img):
data_info = dict(img=img)
results = self.transform(copy.deepcopy(data_info))
self.assertTrue(np.array_equal(results['img'], 255. - img))
def test_repr(self):
repr_str = self.transform.__repr__()
self.assertEqual(repr_str, 'ReversePixels()')