mirror of https://github.com/open-mmlab/mmocr.git
[Transforms] SVTR transforms (#1645)
* rec transforms * fix * ut * update docs * fix * new name * fix
pull/1663/head
parent
d679691a02
commit
b0557c2c55
|
@ -34,7 +34,6 @@ TextDet Transforms
|
|||
:template: classtemplate.rst
|
||||
|
||||
BoundedScaleAspectJitter
|
||||
FixInvalidPolygon
|
||||
RandomFlip
|
||||
SourceImagePad
|
||||
ShortScaleAspectJitter
|
||||
|
@ -50,6 +49,10 @@ TextRecog Transforms
|
|||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
TextRecogGeneralAug
|
||||
CropHeight
|
||||
ImageContentJitter
|
||||
ReversePixels
|
||||
PyramidRescale
|
||||
PadToWidth
|
||||
RescaleToHeight
|
||||
|
@ -66,6 +69,8 @@ OCR Transforms
|
|||
RandomCrop
|
||||
RandomRotate
|
||||
Resize
|
||||
FixInvalidPolygon
|
||||
RemoveIgnored
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -34,7 +34,6 @@ TextDet Transforms
|
|||
:template: classtemplate.rst
|
||||
|
||||
BoundedScaleAspectJitter
|
||||
FixInvalidPolygon
|
||||
RandomFlip
|
||||
SourceImagePad
|
||||
ShortScaleAspectJitter
|
||||
|
@ -50,6 +49,10 @@ TextRecog Transforms
|
|||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
TextRecogGeneralAug
|
||||
CropHeight
|
||||
ImageContentJitter
|
||||
ReversePixels
|
||||
PyramidRescale
|
||||
PadToWidth
|
||||
RescaleToHeight
|
||||
|
@ -66,6 +69,8 @@ OCR Transforms
|
|||
RandomCrop
|
||||
RandomRotate
|
||||
Resize
|
||||
FixInvalidPolygon
|
||||
RemoveIgnored
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -9,7 +9,9 @@ from .ocr_transforms import (FixInvalidPolygon, RandomCrop, RandomRotate,
|
|||
from .textdet_transforms import (BoundedScaleAspectJitter, RandomFlip,
|
||||
ShortScaleAspectJitter, SourceImagePad,
|
||||
TextDetRandomCrop, TextDetRandomCropFlip)
|
||||
from .textrecog_transforms import PadToWidth, PyramidRescale, RescaleToHeight
|
||||
from .textrecog_transforms import (CropHeight, ImageContentJitter, PadToWidth,
|
||||
PyramidRescale, RescaleToHeight,
|
||||
ReversePixels, TextRecogGeneralAug)
|
||||
from .wrappers import ConditionApply, ImgAugWrapper, TorchVisionWrapper
|
||||
|
||||
__all__ = [
|
||||
|
@ -20,5 +22,6 @@ __all__ = [
|
|||
'ShortScaleAspectJitter', 'RandomFlip', 'BoundedScaleAspectJitter',
|
||||
'PackKIEInputs', 'LoadKIEAnnotations', 'FixInvalidPolygon', 'MMDet2MMOCR',
|
||||
'MMOCR2MMDet', 'LoadImageFromLMDB', 'LoadImageFromFile',
|
||||
'LoadImageFromNDArray', 'RemoveIgnored', 'ConditionApply'
|
||||
'LoadImageFromNDArray', 'CropHeight', 'TextRecogGeneralAug',
|
||||
'ImageContentJitter', 'ReversePixels', 'RemoveIgnored', 'ConditionApply'
|
||||
]
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
from typing import Dict, Optional, Tuple
|
||||
import random
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import cv2
|
||||
import mmcv
|
||||
|
@ -251,3 +252,473 @@ class PadToWidth(BaseTransform):
|
|||
repr_str += f'(width={self.width}, '
|
||||
repr_str += f'pad_cfg={self.pad_cfg})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
class TextRecogGeneralAug(BaseTransform):
    """A general geometric augmentation tool for text images in the CVPR 2020
    paper "Learn to Augment: Joint Data Augmentation and Network Optimization
    for Text Recognition". It applies distortion, stretching, and perspective
    transforms to an image.

    This implementation is adapted from
    https://github.com/RubanSeven/Text-Image-Augmentation-python/blob/master/augment.py  # noqa

    TODO: Split this transform into three transforms.

    Required Keys:

    - img

    Modified Keys:

    - img
    - img_shape
    """  # noqa

    def transform(self, results: Dict) -> Dict:
        """Apply distortion, stretching and perspective warps to the image.

        Distortion and stretching are only applied when both image sides are
        at least 20 pixels; the perspective warp is only applied when both
        sides are at least 5 pixels. Smaller images pass through unchanged
        (``img_shape`` is still written).

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Updated result dict.
        """
        h, w = results['img'].shape[:2]
        if h >= 20 and w >= 20:
            results['img'] = self.tia_distort(results['img'],
                                              random.randint(3, 6))
            results['img'] = self.tia_stretch(results['img'],
                                              random.randint(3, 6))
        h, w = results['img'].shape[:2]
        if h >= 5 and w >= 5:
            results['img'] = self.tia_perspective(results['img'])
        results['img_shape'] = results['img'].shape[:2]
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += '()'
        return repr_str

    def tia_distort(self, img: np.ndarray, segment: int = 4) -> np.ndarray:
        """Image distortion.

        The four corners and the top/bottom points of every vertical cut are
        randomly displaced, then the image is warped accordingly.

        Args:
            img (np.ndarray): The image.
            segment (int): The number of segments to divide the image along
                the width. Defaults to 4.

        Returns:
            np.ndarray: The distorted image, same height and width as
            ``img``.
        """
        img_h, img_w = img.shape[:2]

        cut = img_w // segment
        # ``np.random.randint(thresh)`` needs thresh >= 1, i.e.
        # img_w >= 3 * segment.
        thresh = cut // 3

        src_pts = list()
        dst_pts = list()

        src_pts.append([0, 0])
        src_pts.append([img_w, 0])
        src_pts.append([img_w, img_h])
        src_pts.append([0, img_h])

        # Corners are pushed inwards by a random offset.
        dst_pts.append([np.random.randint(thresh), np.random.randint(thresh)])
        dst_pts.append(
            [img_w - np.random.randint(thresh),
             np.random.randint(thresh)])
        dst_pts.append([
            img_w - np.random.randint(thresh),
            img_h - np.random.randint(thresh)
        ])
        dst_pts.append(
            [np.random.randint(thresh), img_h - np.random.randint(thresh)])

        half_thresh = thresh * 0.5

        # Interior cut points are jittered symmetrically around their
        # original position.
        for cut_idx in np.arange(1, segment, 1):
            src_pts.append([cut * cut_idx, 0])
            src_pts.append([cut * cut_idx, img_h])
            dst_pts.append([
                cut * cut_idx + np.random.randint(thresh) - half_thresh,
                np.random.randint(thresh) - half_thresh
            ])
            dst_pts.append([
                cut * cut_idx + np.random.randint(thresh) - half_thresh,
                img_h + np.random.randint(thresh) - half_thresh
            ])

        dst = self.warp_mls(img, src_pts, dst_pts, img_w, img_h)

        return dst

    def tia_stretch(self, img: np.ndarray, segment: int = 4) -> np.ndarray:
        """Image stretching.

        The corners stay fixed while every vertical cut is shifted
        horizontally by a random amount.

        Args:
            img (np.ndarray): The image.
            segment (int): The number of segments to divide the image along
                the width. Defaults to 4.

        Returns:
            np.ndarray: The stretched image, same height and width as
            ``img``.
        """
        img_h, img_w = img.shape[:2]

        cut = img_w // segment
        # ``np.random.randint(thresh)`` needs thresh >= 1, i.e. cut >= 2.
        thresh = cut * 4 // 5

        src_pts = list()
        dst_pts = list()

        src_pts.append([0, 0])
        src_pts.append([img_w, 0])
        src_pts.append([img_w, img_h])
        src_pts.append([0, img_h])

        dst_pts.append([0, 0])
        dst_pts.append([img_w, 0])
        dst_pts.append([img_w, img_h])
        dst_pts.append([0, img_h])

        half_thresh = thresh * 0.5

        for cut_idx in np.arange(1, segment, 1):
            # Top and bottom of one cut share the same horizontal move.
            move = np.random.randint(thresh) - half_thresh
            src_pts.append([cut * cut_idx, 0])
            src_pts.append([cut * cut_idx, img_h])
            dst_pts.append([cut * cut_idx + move, 0])
            dst_pts.append([cut * cut_idx + move, img_h])

        dst = self.warp_mls(img, src_pts, dst_pts, img_w, img_h)

        return dst

    def tia_perspective(self, img: np.ndarray) -> np.ndarray:
        """Image perspective transformation.

        Each corner is moved vertically towards the image centre by a random
        offset smaller than half of the image height.

        Args:
            img (np.ndarray): The image.

        Returns:
            np.ndarray: The warped image, same height and width as ``img``.
        """
        img_h, img_w = img.shape[:2]

        thresh = img_h // 2

        src_pts = list()
        dst_pts = list()

        src_pts.append([0, 0])
        src_pts.append([img_w, 0])
        src_pts.append([img_w, img_h])
        src_pts.append([0, img_h])

        dst_pts.append([0, np.random.randint(thresh)])
        dst_pts.append([img_w, np.random.randint(thresh)])
        dst_pts.append([img_w, img_h - np.random.randint(thresh)])
        dst_pts.append([0, img_h - np.random.randint(thresh)])

        dst = self.warp_mls(img, src_pts, dst_pts, img_w, img_h)

        return dst

    def warp_mls(self,
                 src: np.ndarray,
                 src_pts: List[List[float]],
                 dst_pts: List[List[float]],
                 dst_w: int,
                 dst_h: int,
                 trans_ratio: float = 1.) -> np.ndarray:
        """Warp the image.

        Args:
            src (np.ndarray): The image to warp.
            src_pts (list[list]): ``[x, y]`` control points in ``src``.
            dst_pts (list[list]): ``[x, y]`` positions the control points
                are moved to; same length as ``src_pts``.
            dst_w (int): Width of the output grid.
            dst_h (int): Height of the output grid.
            trans_ratio (float): Scale applied to the computed displacement.
                Defaults to 1.

        Returns:
            np.ndarray: The warped ``uint8`` image.
        """
        # The displacement field is evaluated on a fixed 100-pixel grid.
        rdx, rdy = self._calc_delta(dst_w, dst_h, src_pts, dst_pts, 100)
        return self._gen_img(src, rdx, rdy, dst_w, dst_h, 100, trans_ratio)

    def _calc_delta(self, dst_w: int, dst_h: int,
                    src_pts: List[List[float]], dst_pts: List[List[float]],
                    grid_size: int) -> Tuple[np.ndarray, np.ndarray]:
        """Compute delta.

        Evaluates, at every grid node (multiples of ``grid_size`` plus the
        last column/row), the displacement from the node to its source
        position. Entries between grid nodes stay zero.

        Args:
            dst_w (int): Width of the displacement field.
            dst_h (int): Height of the displacement field.
            src_pts (list[list]): ``[x, y]`` source control points.
            dst_pts (list[list]): ``[x, y]`` destination control points.
            grid_size (int): Spacing of the grid nodes.

        Returns:
            tuple(np.ndarray, np.ndarray): ``(rdx, rdy)``, both of shape
            ``(dst_h, dst_w)``.
        """

        pt_count = len(dst_pts)
        rdx = np.zeros((dst_h, dst_w))
        rdy = np.zeros((dst_h, dst_w))
        w = np.zeros(pt_count, dtype=np.float32)

        if pt_count < 2:
            # Too few control points to define a deformation: return the
            # zero field so that callers can still unpack the result.
            return rdx, rdy

        i = 0
        while True:
            # Snap the first overshooting step onto the last column so the
            # right border always gets a grid node.
            if dst_w <= i < dst_w + grid_size - 1:
                i = dst_w - 1
            elif i >= dst_w:
                break

            j = 0
            while True:
                # Same snapping for the last row.
                if dst_h <= j < dst_h + grid_size - 1:
                    j = dst_h - 1
                elif j >= dst_h:
                    break

                sw = 0
                swp = np.zeros(2, dtype=np.float32)
                swq = np.zeros(2, dtype=np.float32)
                new_pt = np.zeros(2, dtype=np.float32)
                cur_pt = np.array([i, j], dtype=np.float32)

                # Index of the control point sitting exactly on this node,
                # if any. Tracked explicitly so that a hit on the LAST
                # control point is recognised as well.
                hit = None
                for k in range(pt_count):
                    if i == dst_pts[k][0] and j == dst_pts[k][1]:
                        hit = k
                        break

                    # Inverse squared distance weight.
                    w[k] = 1. / ((i - dst_pts[k][0]) * (i - dst_pts[k][0]) +
                                 (j - dst_pts[k][1]) * (j - dst_pts[k][1]))

                    sw += w[k]
                    swp = swp + w[k] * np.array(dst_pts[k])
                    swq = swq + w[k] * np.array(src_pts[k])

                if hit is None:
                    # Weighted centroids of destination / source points.
                    pstar = 1 / sw * swp
                    qstar = 1 / sw * swq

                    miu_s = 0
                    for k in range(pt_count):
                        pt_i = dst_pts[k] - pstar
                        miu_s += w[k] * np.sum(pt_i * pt_i)

                    cur_pt -= pstar
                    # Perpendicular of the centred node position.
                    cur_pt_j = np.array([-cur_pt[1], cur_pt[0]])

                    for k in range(pt_count):
                        pt_i = dst_pts[k] - pstar
                        pt_j = np.array([-pt_i[1], pt_i[0]])

                        tmp_pt = np.zeros(2, dtype=np.float32)
                        tmp_pt[0] = (
                            np.sum(pt_i * cur_pt) * src_pts[k][0] -
                            np.sum(pt_j * cur_pt) * src_pts[k][1])
                        tmp_pt[1] = (-np.sum(pt_i * cur_pt_j) * src_pts[k][0] +
                                     np.sum(pt_j * cur_pt_j) * src_pts[k][1])
                        tmp_pt *= (w[k] / miu_s)
                        new_pt += tmp_pt

                    new_pt += qstar
                else:
                    # The node is a control point: it maps exactly onto the
                    # matching source point.
                    new_pt = src_pts[hit]

                rdx[j, i] = new_pt[0] - i
                rdy[j, i] = new_pt[1] - j

                j += grid_size
            i += grid_size
        return rdx, rdy

    def _gen_img(self, src: np.ndarray, rdx: np.ndarray, rdy: np.ndarray,
                 dst_w: int, dst_h: int, grid_size: int,
                 trans_ratio: float) -> np.ndarray:
        """Generate the image based on delta.

        For every grid cell the displacements at its four corners are
        bilinearly interpolated, and the source image is sampled (again
        bilinearly) at the displaced positions.

        Args:
            src (np.ndarray): Source image, 2-D or 3-D with any number of
                channels.
            rdx (np.ndarray): Horizontal displacements at the grid nodes.
            rdy (np.ndarray): Vertical displacements at the grid nodes.
            dst_w (int): Output width.
            dst_h (int): Output height.
            grid_size (int): Spacing of the grid nodes.
            trans_ratio (float): Scale applied to the displacements.

        Returns:
            np.ndarray: The warped image, clipped to [0, 255] and cast to
            ``uint8``.
        """

        src_h, src_w = src.shape[:2]
        dst = np.zeros_like(src, dtype=np.float32)

        for i in np.arange(0, dst_h, grid_size):
            for j in np.arange(0, dst_w, grid_size):
                ni = i + grid_size
                nj = j + grid_size
                w = h = grid_size
                # Border cells are truncated to the image extent.
                if ni >= dst_h:
                    ni = dst_h - 1
                    h = ni - i + 1
                if nj >= dst_w:
                    nj = dst_w - 1
                    w = nj - j + 1

                di = np.reshape(np.arange(h), (-1, 1))
                dj = np.reshape(np.arange(w), (1, -1))
                delta_x = self._bilinear_interp(di / h, dj / w, rdx[i, j],
                                                rdx[i, nj], rdx[ni, j],
                                                rdx[ni, nj])
                delta_y = self._bilinear_interp(di / h, dj / w, rdy[i, j],
                                                rdy[i, nj], rdy[ni, j],
                                                rdy[ni, nj])
                nx = j + dj + delta_x * trans_ratio
                ny = i + di + delta_y * trans_ratio
                nx = np.clip(nx, 0, src_w - 1)
                ny = np.clip(ny, 0, src_h - 1)
                nxi = np.array(np.floor(nx), dtype=np.int32)
                nyi = np.array(np.floor(ny), dtype=np.int32)
                nxi1 = np.array(np.ceil(nx), dtype=np.int32)
                nyi1 = np.array(np.ceil(ny), dtype=np.int32)

                x = ny - nyi
                y = nx - nxi
                if src.ndim == 3:
                    # A trailing unit axis broadcasts over any number of
                    # channels, not just three.
                    x = np.expand_dims(x, axis=-1)
                    y = np.expand_dims(y, axis=-1)
                dst[i:i + h,
                    j:j + w] = self._bilinear_interp(x, y, src[nyi, nxi],
                                                     src[nyi, nxi1],
                                                     src[nyi1, nxi],
                                                     src[nyi1, nxi1])

        dst = np.clip(dst, 0, 255)
        dst = np.array(dst, dtype=np.uint8)

        return dst

    @staticmethod
    def _bilinear_interp(x, y, v11, v12, v21, v22):
        """Bilinear interpolation.

        TODO: Docs for args and put it into utils.

        ``y`` blends ``v11 -> v12`` and ``v21 -> v22``; ``x`` then blends
        those two intermediate results. All arguments may be scalars or
        broadcastable arrays.
        """
        return (v11 * (1 - y) + v12 * y) * (1 - x) + (v21 *
                                                      (1 - y) + v22 * y) * x
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
class CropHeight(BaseTransform):
    """Randomly crop the image's height, either from top or bottom.

    Adapted from
    https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.6/ppocr/data/imaug/rec_img_aug.py  # noqa

    Required Keys:

    - img

    Modified Keys:

    - img
    - img_shape

    Args:
        min_pixels (int): Minimum pixel(s) to crop. Defaults to 1.
        max_pixels (int): Maximum pixel(s) to crop. Defaults to 8.
    """

    def __init__(
        self,
        min_pixels: int = 1,
        max_pixels: int = 8,
    ) -> None:
        super().__init__()
        assert max_pixels >= min_pixels
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels

    @cache_randomness
    def get_random_vars(self):
        """Get all the random values used in this transform.

        Returns:
            tuple(int, int): The number of rows to crop, drawn from
            ``[min_pixels, max_pixels]``, and a 0/1 flag that is truthy
            when the rows are removed from the top.
        """
        crop_pixels = int(random.randint(self.min_pixels, self.max_pixels))
        crop_top = random.randint(0, 1)
        return crop_pixels, crop_top

    def transform(self, results: Dict) -> Dict:
        """Transform function to crop images.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Cropped results.
        """
        h = results['img'].shape[0]
        crop_pixels, crop_top = self.get_random_vars()
        # Always keep at least one row of the image.
        crop_pixels = min(crop_pixels, h - 1)
        # Slice along the first axis only so that 2-D (grayscale) images
        # are handled as well, and copy just the rows that are kept.
        if crop_top:
            img = results['img'][crop_pixels:h].copy()
        else:
            img = results['img'][0:h - crop_pixels].copy()
        results['img_shape'] = img.shape[:2]
        results['img'] = img
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(min_pixels = {self.min_pixels}, '
        repr_str += f'max_pixels = {self.max_pixels})'
        return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
class ImageContentJitter(BaseTransform):
    """Jitter the image contents.

    Adapted from
    https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.6/ppocr/data/imaug/rec_img_aug.py  # noqa

    Required Keys:

    - img

    Modified Keys:

    - img
    """

    def transform(self, results: Dict, jitter_ratio: float = 0.01) -> Dict:
        """Transform function to jitter images.

        The image content is repeatedly shifted towards the bottom-right
        corner. Images whose height or width is 10 pixels or less are
        returned as an unmodified copy.

        Args:
            results (dict): Result dict from loading pipeline.
            jitter_ratio (float): Controls the strength of jittering.
                Defaults to 0.01.

        Returns:
            dict: Jittered results.
        """
        h, w = results['img'].shape[:2]
        img = results['img'].copy()
        if h > 10 and w > 10:
            thres = min(h, w)
            # Honour ``jitter_ratio`` instead of a hard-coded 0.01, and cap
            # the number of shifts so that the slices below never run past
            # the image.
            jitter_range = min(
                int(random.random() * thres * jitter_ratio), thres)
            for i in range(jitter_range):
                # Trailing axes are left implicit so that 2-D images work
                # too.
                img[i:, i:] = img[:h - i, :w - i]
        results['img'] = img
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += '()'
        return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
class ReversePixels(BaseTransform):
    """Reverse image pixels.

    Adapted from
    https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.6/ppocr/data/imaug/rec_img_aug.py  # noqa

    Required Keys:

    - img

    Modified Keys:

    - img
    """

    def transform(self, results: Dict) -> Dict:
        """Transform function to reverse image pixels.

        Every pixel value ``v`` is replaced by ``255. - v``.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Reversed results.
        """
        # The subtraction already allocates a new array, so the input image
        # is left untouched without an explicit copy.
        results['img'] = 255. - results['img']
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += '()'
        return repr_str
|
||||
|
|
|
@ -3,9 +3,12 @@ import copy
|
|||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from parameterized import parameterized
|
||||
|
||||
from mmocr.datasets.transforms import (PadToWidth, PyramidRescale,
|
||||
RescaleToHeight)
|
||||
from mmocr.datasets.transforms import (CropHeight, ImageContentJitter,
|
||||
PadToWidth, PyramidRescale,
|
||||
RescaleToHeight, ReversePixels,
|
||||
TextRecogGeneralAug)
|
||||
|
||||
|
||||
class TestPadToWidth(unittest.TestCase):
|
||||
|
@ -125,3 +128,85 @@ class TestRescaleToHeight(unittest.TestCase):
|
|||
'min_width=None, max_width=None, '
|
||||
'width_divisor=1, '
|
||||
"resize_cfg={'type': 'Resize', 'scale': 0})"))
|
||||
|
||||
|
||||
class TestTextRecogGeneralAug(unittest.TestCase):
    """Smoke tests for ``TextRecogGeneralAug``."""

    def setUp(self) -> None:
        self.transform = TextRecogGeneralAug()

    @parameterized.expand([
        (np.random.random((3, 3, 3)), ),
        (np.random.random((10, 10, 3)), ),
        (np.random.random((30, 30, 3)), ),
    ])
    def test_transform(self, img):
        # ``img_shape`` must describe the image that is actually returned.
        sample = {'img': img}
        output = self.transform(copy.deepcopy(sample))
        out_hw = output['img'].shape[:2]
        self.assertEqual(out_hw, output['img_shape'])

    def test_repr(self):
        self.assertEqual(repr(self.transform), 'TextRecogGeneralAug()')
|
||||
|
||||
|
||||
class TestCropHeight(unittest.TestCase):
    """Tests for ``CropHeight``."""

    def setUp(self) -> None:
        self.data_info = {'img': np.random.random((20, 20, 3))}

    @parameterized.expand([
        (3, 3),
        (5, 10),
    ])
    def test_transform(self, min_pixels, max_pixels):
        self.transform = CropHeight(
            min_pixels=min_pixels, max_pixels=max_pixels)
        output = self.transform(copy.deepcopy(self.data_info))
        self.assertEqual(output['img'].shape[:2], output['img_shape'])
        # The number of removed rows has to stay inside the configured
        # range.
        original_h = self.data_info['img'].shape[0]
        removed = original_h - output['img_shape'][0]
        self.assertGreaterEqual(removed, min_pixels)
        self.assertLessEqual(removed, max_pixels)

    def test_invalid(self):
        # ``max_pixels`` smaller than ``min_pixels`` is rejected.
        with self.assertRaises(AssertionError):
            self.transform = CropHeight(min_pixels=10, max_pixels=9)

    def test_repr(self):
        expected = 'CropHeight(min_pixels = 2, max_pixels = 10)'
        cropper = CropHeight(min_pixels=2, max_pixels=10)
        self.assertEqual(repr(cropper), expected)
|
||||
|
||||
|
||||
class TestImageContentJitter(unittest.TestCase):
    """Tests for ``ImageContentJitter``."""

    def setUp(self) -> None:
        self.transform = ImageContentJitter()

    @parameterized.expand([(np.random.random((3, 3, 3)), ),
                           (np.random.random((10, 10, 3)), ),
                           (np.random.random((30, 30, 3)), )])
    def test_transform(self, img):
        data_info = dict(img=img)
        results = self.transform(copy.deepcopy(data_info))
        # Jittering shifts content in place: it must neither resize the
        # image nor change its dtype.
        self.assertEqual(results['img'].shape, img.shape)
        self.assertEqual(results['img'].dtype, img.dtype)

    def test_repr(self):
        repr_str = self.transform.__repr__()
        self.assertEqual(repr_str, 'ImageContentJitter()')
|
||||
|
||||
|
||||
class TestReversePixels(unittest.TestCase):
    """Tests for ``ReversePixels``."""

    def setUp(self) -> None:
        self.transform = ReversePixels()

    @parameterized.expand([
        (np.random.random((3, 3, 3)), ),
        (np.random.random((10, 10, 3)), ),
        (np.random.random((30, 30, 3)), ),
    ])
    def test_transform(self, img):
        # Every pixel value v is expected to become 255 - v.
        expected = 255. - img
        output = self.transform(copy.deepcopy({'img': img}))
        self.assertTrue(np.array_equal(output['img'], expected))

    def test_repr(self):
        self.assertEqual(repr(self.transform), 'ReversePixels()')
|
||||
|
|
Loading…
Reference in New Issue