[Transforms] SVTR transforms (#1645)

* rec transforms

* fix

* ut

* update docs

* fix

* new name

* fix
pull/1663/head
Tong Gao 2023-01-06 16:04:20 +08:00 committed by GitHub
parent d679691a02
commit b0557c2c55
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 576 additions and 7 deletions

View File

@ -34,7 +34,6 @@ TextDet Transforms
:template: classtemplate.rst
BoundedScaleAspectJitter
FixInvalidPolygon
RandomFlip
SourceImagePad
ShortScaleAspectJitter
@ -50,6 +49,10 @@ TextRecog Transforms
:nosignatures:
:template: classtemplate.rst
TextRecogGeneralAug
CropHeight
ImageContentJitter
ReversePixels
PyramidRescale
PadToWidth
RescaleToHeight
@ -66,6 +69,8 @@ OCR Transforms
RandomCrop
RandomRotate
Resize
FixInvalidPolygon
RemoveIgnored

View File

@ -34,7 +34,6 @@ TextDet Transforms
:template: classtemplate.rst
BoundedScaleAspectJitter
FixInvalidPolygon
RandomFlip
SourceImagePad
ShortScaleAspectJitter
@ -50,6 +49,10 @@ TextRecog Transforms
:nosignatures:
:template: classtemplate.rst
TextRecogGeneralAug
CropHeight
ImageContentJitter
ReversePixels
PyramidRescale
PadToWidth
RescaleToHeight
@ -66,6 +69,8 @@ OCR Transforms
RandomCrop
RandomRotate
Resize
FixInvalidPolygon
RemoveIgnored

View File

@ -9,7 +9,9 @@ from .ocr_transforms import (FixInvalidPolygon, RandomCrop, RandomRotate,
from .textdet_transforms import (BoundedScaleAspectJitter, RandomFlip,
ShortScaleAspectJitter, SourceImagePad,
TextDetRandomCrop, TextDetRandomCropFlip)
from .textrecog_transforms import PadToWidth, PyramidRescale, RescaleToHeight
from .textrecog_transforms import (CropHeight, ImageContentJitter, PadToWidth,
PyramidRescale, RescaleToHeight,
ReversePixels, TextRecogGeneralAug)
from .wrappers import ConditionApply, ImgAugWrapper, TorchVisionWrapper
__all__ = [
@ -20,5 +22,6 @@ __all__ = [
'ShortScaleAspectJitter', 'RandomFlip', 'BoundedScaleAspectJitter',
'PackKIEInputs', 'LoadKIEAnnotations', 'FixInvalidPolygon', 'MMDet2MMOCR',
'MMOCR2MMDet', 'LoadImageFromLMDB', 'LoadImageFromFile',
'LoadImageFromNDArray', 'RemoveIgnored', 'ConditionApply'
'LoadImageFromNDArray', 'CropHeight', 'TextRecogGeneralAug',
'ImageContentJitter', 'ReversePixels', 'RemoveIgnored', 'ConditionApply'
]

View File

@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Dict, Optional, Tuple
import random
from typing import Dict, List, Optional, Tuple
import cv2
import mmcv
@ -251,3 +252,473 @@ class PadToWidth(BaseTransform):
repr_str += f'(width={self.width}, '
repr_str += f'pad_cfg={self.pad_cfg})'
return repr_str
@TRANSFORMS.register_module()
class TextRecogGeneralAug(BaseTransform):
"""A general geometric augmentation tool for text images in the CVPR 2020
paper "Learn to Augment: Joint Data Augmentation and Network Optimization
for Text Recognition". It applies distortion, stretching, and perspective
transforms to an image.
This implementation is adapted from
https://github.com/RubanSeven/Text-Image-Augmentation-python/blob/master/augment.py # noqa
TODO: Split this transform into three transforms.
Required Keys:
- img
Modified Keys:
- img
- img_shape
""" # noqa
def transform(self, results: Dict) -> Dict:
"""Call function to pad images.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Updated result dict.
"""
h, w = results['img'].shape[:2]
if h >= 20 and w >= 20:
results['img'] = self.tia_distort(results['img'],
random.randint(3, 6))
results['img'] = self.tia_stretch(results['img'],
random.randint(3, 6))
h, w = results['img'].shape[:2]
if h >= 5 and w >= 5:
results['img'] = self.tia_perspective(results['img'])
results['img_shape'] = results['img'].shape[:2]
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += '()'
return repr_str
def tia_distort(self, img: np.ndarray, segment: int = 4) -> np.ndarray:
"""Image distortion.
Args:
img (np.ndarray): The image.
segment (int): The number of segments to divide the image along
the width. Defaults to 4.
"""
img_h, img_w = img.shape[:2]
cut = img_w // segment
thresh = cut // 3
src_pts = list()
dst_pts = list()
src_pts.append([0, 0])
src_pts.append([img_w, 0])
src_pts.append([img_w, img_h])
src_pts.append([0, img_h])
dst_pts.append([np.random.randint(thresh), np.random.randint(thresh)])
dst_pts.append(
[img_w - np.random.randint(thresh),
np.random.randint(thresh)])
dst_pts.append([
img_w - np.random.randint(thresh),
img_h - np.random.randint(thresh)
])
dst_pts.append(
[np.random.randint(thresh), img_h - np.random.randint(thresh)])
half_thresh = thresh * 0.5
for cut_idx in np.arange(1, segment, 1):
src_pts.append([cut * cut_idx, 0])
src_pts.append([cut * cut_idx, img_h])
dst_pts.append([
cut * cut_idx + np.random.randint(thresh) - half_thresh,
np.random.randint(thresh) - half_thresh
])
dst_pts.append([
cut * cut_idx + np.random.randint(thresh) - half_thresh,
img_h + np.random.randint(thresh) - half_thresh
])
dst = self.warp_mls(img, src_pts, dst_pts, img_w, img_h)
return dst
def tia_stretch(self, img: np.ndarray, segment: int = 4) -> np.ndarray:
"""Image stretching.
Args:
img (np.ndarray): The image.
segment (int): The number of segments to divide the image along
the width. Defaults to 4.
"""
img_h, img_w = img.shape[:2]
cut = img_w // segment
thresh = cut * 4 // 5
src_pts = list()
dst_pts = list()
src_pts.append([0, 0])
src_pts.append([img_w, 0])
src_pts.append([img_w, img_h])
src_pts.append([0, img_h])
dst_pts.append([0, 0])
dst_pts.append([img_w, 0])
dst_pts.append([img_w, img_h])
dst_pts.append([0, img_h])
half_thresh = thresh * 0.5
for cut_idx in np.arange(1, segment, 1):
move = np.random.randint(thresh) - half_thresh
src_pts.append([cut * cut_idx, 0])
src_pts.append([cut * cut_idx, img_h])
dst_pts.append([cut * cut_idx + move, 0])
dst_pts.append([cut * cut_idx + move, img_h])
dst = self.warp_mls(img, src_pts, dst_pts, img_w, img_h)
return dst
def tia_perspective(self, img: np.ndarray) -> np.ndarray:
"""Image perspective transformation.
Args:
img (np.ndarray): The image.
segment (int): The number of segments to divide the image along
the width. Defaults to 4.
"""
img_h, img_w = img.shape[:2]
thresh = img_h // 2
src_pts = list()
dst_pts = list()
src_pts.append([0, 0])
src_pts.append([img_w, 0])
src_pts.append([img_w, img_h])
src_pts.append([0, img_h])
dst_pts.append([0, np.random.randint(thresh)])
dst_pts.append([img_w, np.random.randint(thresh)])
dst_pts.append([img_w, img_h - np.random.randint(thresh)])
dst_pts.append([0, img_h - np.random.randint(thresh)])
dst = self.warp_mls(img, src_pts, dst_pts, img_w, img_h)
return dst
def warp_mls(self,
src: np.ndarray,
src_pts: List[int],
dst_pts: List[int],
dst_w: int,
dst_h: int,
trans_ratio: float = 1.) -> np.ndarray:
"""Warp the image."""
rdx, rdy = self._calc_delta(dst_w, dst_h, src_pts, dst_pts, 100)
return self._gen_img(src, rdx, rdy, dst_w, dst_h, 100, trans_ratio)
def _calc_delta(self, dst_w: int, dst_h: int, src_pts: List[int],
dst_pts: List[int],
grid_size: int) -> Tuple[np.ndarray, np.ndarray]:
"""Compute delta."""
pt_count = len(dst_pts)
rdx = np.zeros((dst_h, dst_w))
rdy = np.zeros((dst_h, dst_w))
w = np.zeros(pt_count, dtype=np.float32)
if pt_count < 2:
return
i = 0
while True:
if dst_w <= i < dst_w + grid_size - 1:
i = dst_w - 1
elif i >= dst_w:
break
j = 0
while True:
if dst_h <= j < dst_h + grid_size - 1:
j = dst_h - 1
elif j >= dst_h:
break
sw = 0
swp = np.zeros(2, dtype=np.float32)
swq = np.zeros(2, dtype=np.float32)
new_pt = np.zeros(2, dtype=np.float32)
cur_pt = np.array([i, j], dtype=np.float32)
k = 0
for k in range(pt_count):
if i == dst_pts[k][0] and j == dst_pts[k][1]:
break
w[k] = 1. / ((i - dst_pts[k][0]) * (i - dst_pts[k][0]) +
(j - dst_pts[k][1]) * (j - dst_pts[k][1]))
sw += w[k]
swp = swp + w[k] * np.array(dst_pts[k])
swq = swq + w[k] * np.array(src_pts[k])
if k == pt_count - 1:
pstar = 1 / sw * swp
qstar = 1 / sw * swq
miu_s = 0
for k in range(pt_count):
if i == dst_pts[k][0] and j == dst_pts[k][1]:
continue
pt_i = dst_pts[k] - pstar
miu_s += w[k] * np.sum(pt_i * pt_i)
cur_pt -= pstar
cur_pt_j = np.array([-cur_pt[1], cur_pt[0]])
for k in range(pt_count):
if i == dst_pts[k][0] and j == dst_pts[k][1]:
continue
pt_i = dst_pts[k] - pstar
pt_j = np.array([-pt_i[1], pt_i[0]])
tmp_pt = np.zeros(2, dtype=np.float32)
tmp_pt[0] = (
np.sum(pt_i * cur_pt) * src_pts[k][0] -
np.sum(pt_j * cur_pt) * src_pts[k][1])
tmp_pt[1] = (-np.sum(pt_i * cur_pt_j) * src_pts[k][0] +
np.sum(pt_j * cur_pt_j) * src_pts[k][1])
tmp_pt *= (w[k] / miu_s)
new_pt += tmp_pt
new_pt += qstar
else:
new_pt = src_pts[k]
rdx[j, i] = new_pt[0] - i
rdy[j, i] = new_pt[1] - j
j += grid_size
i += grid_size
return rdx, rdy
def _gen_img(self, src: np.ndarray, rdx: np.ndarray, rdy: np.ndarray,
dst_w: int, dst_h: int, grid_size: int,
trans_ratio: float) -> np.ndarray:
"""Generate the image based on delta."""
src_h, src_w = src.shape[:2]
dst = np.zeros_like(src, dtype=np.float32)
for i in np.arange(0, dst_h, grid_size):
for j in np.arange(0, dst_w, grid_size):
ni = i + grid_size
nj = j + grid_size
w = h = grid_size
if ni >= dst_h:
ni = dst_h - 1
h = ni - i + 1
if nj >= dst_w:
nj = dst_w - 1
w = nj - j + 1
di = np.reshape(np.arange(h), (-1, 1))
dj = np.reshape(np.arange(w), (1, -1))
delta_x = self._bilinear_interp(di / h, dj / w, rdx[i, j],
rdx[i, nj], rdx[ni, j],
rdx[ni, nj])
delta_y = self._bilinear_interp(di / h, dj / w, rdy[i, j],
rdy[i, nj], rdy[ni, j],
rdy[ni, nj])
nx = j + dj + delta_x * trans_ratio
ny = i + di + delta_y * trans_ratio
nx = np.clip(nx, 0, src_w - 1)
ny = np.clip(ny, 0, src_h - 1)
nxi = np.array(np.floor(nx), dtype=np.int32)
nyi = np.array(np.floor(ny), dtype=np.int32)
nxi1 = np.array(np.ceil(nx), dtype=np.int32)
nyi1 = np.array(np.ceil(ny), dtype=np.int32)
if len(src.shape) == 3:
x = np.tile(np.expand_dims(ny - nyi, axis=-1), (1, 1, 3))
y = np.tile(np.expand_dims(nx - nxi, axis=-1), (1, 1, 3))
else:
x = ny - nyi
y = nx - nxi
dst[i:i + h,
j:j + w] = self._bilinear_interp(x, y, src[nyi, nxi],
src[nyi, nxi1],
src[nyi1, nxi], src[nyi1,
nxi1])
dst = np.clip(dst, 0, 255)
dst = np.array(dst, dtype=np.uint8)
return dst
@staticmethod
def _bilinear_interp(x, y, v11, v12, v21, v22):
"""Bilinear interpolation.
TODO: Docs for args and put it into utils.
"""
return (v11 * (1 - y) + v12 * y) * (1 - x) + (v21 *
(1 - y) + v22 * y) * x
@TRANSFORMS.register_module()
class CropHeight(BaseTransform):
"""Randomly crop the image's height, either from top or bottom.
Adapted from
https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.6/ppocr/data/imaug/rec_img_aug.py # noqa
Required Keys:
- img
Modified Keys:
- img
- img_shape
Args:
crop_min (int): Minimum pixel(s) to crop. Defaults to 1.
crop_max (int): Maximum pixel(s) to crop. Defaults to 8.
"""
def __init__(
self,
min_pixels: int = 1,
max_pixels: int = 8,
) -> None:
super().__init__()
assert max_pixels >= min_pixels
self.min_pixels = min_pixels
self.max_pixels = max_pixels
@cache_randomness
def get_random_vars(self):
"""Get all the random values used in this transform."""
crop_pixels = int(random.randint(self.min_pixels, self.max_pixels))
crop_top = random.randint(0, 1)
return crop_pixels, crop_top
def transform(self, results: Dict) -> Dict:
"""Transform function to crop images.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Cropped results.
"""
h = results['img'].shape[0]
crop_pixels, crop_top = self.get_random_vars()
crop_pixels = min(crop_pixels, h - 1)
img = results['img'].copy()
if crop_top:
img = img[crop_pixels:h, :, :]
else:
img = img[0:h - crop_pixels, :, :]
results['img_shape'] = img.shape[:2]
results['img'] = img
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(min_pixels = {self.min_pixels}, '
repr_str += f'max_pixels = {self.max_pixels})'
return repr_str
@TRANSFORMS.register_module()
class ImageContentJitter(BaseTransform):
"""Jitter the image contents.
Adapted from
https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.6/ppocr/data/imaug/rec_img_aug.py # noqa
Required Keys:
- img
Modified Keys:
- img
"""
def transform(self, results: Dict, jitter_ratio: float = 0.01) -> Dict:
"""Transform function to jitter images.
Args:
results (dict): Result dict from loading pipeline.
jitter_ratio (float): Controls the strength of jittering.
Defaults to 0.01.
Returns:
dict: Jittered results.
"""
h, w = results['img'].shape[:2]
img = results['img'].copy()
if h > 10 and w > 10:
thres = min(h, w)
jitter_range = int(random.random() * thres * 0.01)
for i in range(jitter_range):
img[i:, i:, :] = img[:h - i, :w - i, :]
results['img'] = img
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += '()'
return repr_str
@TRANSFORMS.register_module()
class ReversePixels(BaseTransform):
"""Reverse image pixels.
Adapted from
https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.6/ppocr/data/imaug/rec_img_aug.py # noqa
Required Keys:
- img
Modified Keys:
- img
"""
def transform(self, results: Dict) -> Dict:
"""Transform function to reverse image pixels.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Reversed results.
"""
results['img'] = 255. - results['img'].copy()
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += '()'
return repr_str

View File

@ -3,9 +3,12 @@ import copy
import unittest
import numpy as np
from parameterized import parameterized
from mmocr.datasets.transforms import (PadToWidth, PyramidRescale,
RescaleToHeight)
from mmocr.datasets.transforms import (CropHeight, ImageContentJitter,
PadToWidth, PyramidRescale,
RescaleToHeight, ReversePixels,
TextRecogGeneralAug)
class TestPadToWidth(unittest.TestCase):
@ -125,3 +128,85 @@ class TestRescaleToHeight(unittest.TestCase):
'min_width=None, max_width=None, '
'width_divisor=1, '
"resize_cfg={'type': 'Resize', 'scale': 0})"))
class TestTextRecogGeneralAug(unittest.TestCase):
def setUp(self) -> None:
self.transform = TextRecogGeneralAug()
@parameterized.expand([(np.random.random((3, 3, 3)), ),
(np.random.random((10, 10, 3)), ),
(np.random.random((30, 30, 3)), )])
def test_transform(self, img):
data_info = dict(img=img)
results = self.transform(copy.deepcopy(data_info))
self.assertEqual(results['img'].shape[:2], results['img_shape'])
def test_repr(self):
repr_str = self.transform.__repr__()
self.assertEqual(repr_str, 'TextRecogGeneralAug()')
class TestCropHeight(unittest.TestCase):
def setUp(self) -> None:
self.data_info = dict(img=np.random.random((20, 20, 3)))
@parameterized.expand([
(3, 3),
(5, 10),
])
def test_transform(self, min_pixels, max_pixels):
self.transform = CropHeight(
min_pixels=min_pixels, max_pixels=max_pixels)
results = self.transform(copy.deepcopy(self.data_info))
self.assertEqual(results['img'].shape[:2], results['img_shape'])
h_diff = self.data_info['img'].shape[0] - results['img_shape'][0]
self.assertGreaterEqual(h_diff, min_pixels)
self.assertLessEqual(h_diff, max_pixels)
def test_invalid(self):
with self.assertRaises(AssertionError):
self.transform = CropHeight(min_pixels=10, max_pixels=9)
def test_repr(self):
transform = CropHeight(min_pixels=2, max_pixels=10)
repr_str = transform.__repr__()
self.assertEqual(repr_str, 'CropHeight(min_pixels = 2, '
'max_pixels = 10)')
class TestImageContentJitter(unittest.TestCase):
def setUp(self) -> None:
self.transform = ImageContentJitter()
@parameterized.expand([(np.random.random((3, 3, 3)), ),
(np.random.random((10, 10, 3)), ),
(np.random.random((30, 30, 3)), )])
def test_transform(self, img):
data_info = dict(img=img)
self.transform(copy.deepcopy(data_info))
def test_repr(self):
repr_str = self.transform.__repr__()
self.assertEqual(repr_str, 'ImageContentJitter()')
class TestReversePixels(unittest.TestCase):
def setUp(self) -> None:
self.transform = ReversePixels()
@parameterized.expand([(np.random.random((3, 3, 3)), ),
(np.random.random((10, 10, 3)), ),
(np.random.random((30, 30, 3)), )])
def test_transform(self, img):
data_info = dict(img=img)
results = self.transform(copy.deepcopy(data_info))
self.assertTrue(np.array_equal(results['img'], 255. - img))
def test_repr(self):
repr_str = self.transform.__repr__()
self.assertEqual(repr_str, 'ReversePixels()')