mirror of https://github.com/open-mmlab/mmocr.git
[Refactor] RandomCropFlip
parent
79186b61ec
commit
ac4eb34843
.dev_scripts
mmocr/datasets/pipelines
old_tests/test_dataset
tests/test_datasets/test_pipelines
|
@ -7,6 +7,7 @@ mmocr/models/textrecog/recognizer/base.py
|
|||
.*/__init__.py
|
||||
# It will be removed after all transforms have been refactored into processing.py
|
||||
mmocr/datasets/pipelines/transforms.py
|
||||
mmocr/datasets/pipelines/dbnet_transforms.py
|
||||
|
||||
# will be deleted
|
||||
mmocr/models/textdet/dense_heads/head_mixin.py
|
||||
|
@ -20,6 +21,3 @@ mmocr/models/textdet/detectors/text_detector_mixin.py
|
|||
|
||||
# It will be covered by tests of any det model implemented in future
|
||||
mmocr/models/textdet/detectors/single_stage_text_detector.py
|
||||
|
||||
mmocr/datasets/pipelines/transforms.py
|
||||
mmocr/datasets/pipelines/dbnet_transforms.py
|
||||
|
|
|
@ -9,12 +9,13 @@ from .ocr_seg_targets import OCRSegTargets
|
|||
from .ocr_transforms import (FancyPCA, NormalizeOCR, OnlineCropOCR,
|
||||
OpencvToPil, PilToOpencv, RandomPaddingOCR,
|
||||
RandomRotateImageBox, ResizeOCR, ToTensorOCR)
|
||||
from .processing import PyramidRescale, RandomRotate, Resize
|
||||
from .processing import (PyramidRescale, RandomRotate, Resize,
|
||||
TextDetRandomCropFlip)
|
||||
from .test_time_aug import MultiRotateAugOCR
|
||||
from .textdet_targets import (DBNetTargets, FCENetTargets, PANetTargets,
|
||||
TextSnakeTargets)
|
||||
from .transform_wrappers import OneOfWrapper, RandomWrapper, TorchVisionWrapper
|
||||
from .transforms import (ColorJitter, RandomCropFlip, RandomCropInstances,
|
||||
from .transforms import (ColorJitter, RandomCropInstances,
|
||||
RandomCropPolyInstances, RandomScaling,
|
||||
ScaleAspectJitter, SquareResizePad)
|
||||
from .wrappers import ImgAug
|
||||
|
@ -27,7 +28,7 @@ __all__ = [
|
|||
'RandomCropPolyInstances', 'RandomPaddingOCR', 'ImgAug', 'EastRandomCrop',
|
||||
'RandomRotateImageBox', 'OpencvToPil', 'PilToOpencv', 'SquareResizePad',
|
||||
'TextSnakeTargets', 'sort_vertex', 'LoadImageFromNdarray', 'sort_vertex8',
|
||||
'FCENetTargets', 'RandomScaling', 'RandomCropFlip', 'NerTransform',
|
||||
'FCENetTargets', 'RandomScaling', 'TextDetRandomCropFlip', 'NerTransform',
|
||||
'ToTensorNER', 'ResizeNoImg', 'PyramidRescale', 'OneOfWrapper',
|
||||
'RandomWrapper', 'TorchVisionWrapper', 'LoadImageFromLMDB', 'Resize'
|
||||
]
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
import math
|
||||
from typing import Dict, Tuple
|
||||
import random
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import cv2
|
||||
import mmcv
|
||||
|
@ -9,7 +11,9 @@ from mmcv.image.geometric import _scale_size
|
|||
from mmcv.transforms import Resize as MMCV_Resize
|
||||
from mmcv.transforms.base import BaseTransform
|
||||
from mmcv.transforms.utils import cache_randomness
|
||||
from shapely.geometry import Polygon as plg
|
||||
|
||||
import mmocr.core.evaluation.utils as eval_utils
|
||||
from mmocr.registry import TRANSFORMS
|
||||
from mmocr.utils import (bbox2poly, crop_polygon, poly2bbox, rescale_bboxes,
|
||||
rescale_polygon)
|
||||
|
@ -484,3 +488,251 @@ class RandomRotate(BaseTransform):
|
|||
repr_str += f', pad_value = {self.pad_value}'
|
||||
repr_str += f', use_canvas = {self.use_canvas})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class TextDetRandomCropFlip(BaseTransform):
|
||||
# TODO Rename this transformer; Refactor the redundant code.
|
||||
"""Random crop and flip a patch in the image. Only used in text detection
|
||||
task.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img
|
||||
- gt_bboxes
|
||||
- gt_polygons
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img
|
||||
- gt_bboxes
|
||||
- gt_polygons
|
||||
|
||||
Args:
|
||||
pad_ratio (float): The ratio of padding. Defaults to 0.1.
|
||||
crop_ratio (float): The ratio of cropping. Defaults to 0.5.
|
||||
iter_num (int): Number of operations. Defaults to 1.
|
||||
min_area_ratio (float): Minimal area ratio between cropped patch
|
||||
and original image. Defaults to 0.2.
|
||||
epsilon (float): The threshold of polygon IoU between cropped area
|
||||
and polygon, which is used to avoid cropping text instances.
|
||||
Defaults to 0.01.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pad_ratio: float = 0.1,
|
||||
crop_ratio: float = 0.5,
|
||||
iter_num: int = 1,
|
||||
min_area_ratio: float = 0.2,
|
||||
epsilon: float = 1e-2) -> None:
|
||||
if not isinstance(pad_ratio, float):
|
||||
raise TypeError('`pad_ratio` should be an float, '
|
||||
f'but got {type(pad_ratio)} instead')
|
||||
if not isinstance(crop_ratio, float):
|
||||
raise TypeError('`crop_ratio` should be a float, '
|
||||
f'but got {type(crop_ratio)} instead')
|
||||
if not isinstance(iter_num, int):
|
||||
raise TypeError('`iter_num` should be an integer, '
|
||||
f'but got {type(iter_num)} instead')
|
||||
if not isinstance(min_area_ratio, float):
|
||||
raise TypeError('`min_area_ratio` should be a float, '
|
||||
f'but got {type(min_area_ratio)} instead')
|
||||
if not isinstance(epsilon, float):
|
||||
raise TypeError('`epsilon` should be a float, '
|
||||
f'but got {type(epsilon)} instead')
|
||||
|
||||
self.pad_ratio = pad_ratio
|
||||
self.epsilon = epsilon
|
||||
self.crop_ratio = crop_ratio
|
||||
self.iter_num = iter_num
|
||||
self.min_area_ratio = min_area_ratio
|
||||
|
||||
@cache_randomness
|
||||
def _random_prob(self) -> float:
|
||||
"""Get the random prob to decide whether apply the transform.
|
||||
|
||||
Returns:
|
||||
float: The probability
|
||||
"""
|
||||
return random.random()
|
||||
|
||||
@cache_randomness
|
||||
def _random_flip_type(self) -> int:
|
||||
"""Get the random flip type.
|
||||
|
||||
Returns:
|
||||
int: The flip type index. (0: horizontal; 1: vertical; 2: both)
|
||||
"""
|
||||
return np.random.randint(3)
|
||||
|
||||
@cache_randomness
|
||||
def _random_choice(self, axis: np.ndarray) -> np.ndarray:
|
||||
"""Randomly select two coordinates from the axis.
|
||||
|
||||
Args:
|
||||
axis (np.ndarray): Result dict containing the data to transform
|
||||
|
||||
Returns:
|
||||
np.ndarray: The selected coordinates
|
||||
"""
|
||||
return np.random.choice(axis, size=2)
|
||||
|
||||
def transform(self, results: Dict) -> Dict:
|
||||
"""Applying random crop flip on results.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict containing the data to transform
|
||||
|
||||
Returns:
|
||||
dict: The transformed data
|
||||
"""
|
||||
assert 'img' in results, '`img` is not found in results'
|
||||
for _ in range(self.iter_num):
|
||||
results = self._random_crop_flip_polygons(results)
|
||||
# TODO Add random_crop_flip_bboxes (will be added after the poly2box
|
||||
# and box2poly have been merged)
|
||||
return results
|
||||
|
||||
def _random_crop_flip_polygons(self, results: Dict) -> Dict:
|
||||
"""Applying random crop flip on polygons.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict containing the data to transform
|
||||
|
||||
Returns:
|
||||
dict: The transformed data
|
||||
"""
|
||||
if results.get('gt_polygons', None) is None:
|
||||
return results
|
||||
|
||||
image = results['img']
|
||||
polygons = results['gt_polygons']
|
||||
if len(polygons) == 0 or self._random_prob() > self.crop_ratio:
|
||||
return results
|
||||
|
||||
h, w = results['img_shape']
|
||||
area = h * w
|
||||
pad_h = int(h * self.pad_ratio)
|
||||
pad_w = int(w * self.pad_ratio)
|
||||
h_axis, w_axis = self._generate_crop_target(image, polygons, pad_h,
|
||||
pad_w)
|
||||
if len(h_axis) == 0 or len(w_axis) == 0:
|
||||
return results
|
||||
|
||||
# At most 10 attempts
|
||||
for _ in range(10):
|
||||
polys_keep = []
|
||||
polys_new = []
|
||||
xx = self._random_choice(w_axis)
|
||||
yy = self._random_choice(h_axis)
|
||||
xmin = np.clip(np.min(xx) - pad_w, 0, w - 1)
|
||||
xmax = np.clip(np.max(xx) - pad_w, 0, w - 1)
|
||||
ymin = np.clip(np.min(yy) - pad_h, 0, h - 1)
|
||||
ymax = np.clip(np.max(yy) - pad_h, 0, h - 1)
|
||||
if (xmax - xmin) * (ymax - ymin) < area * self.min_area_ratio:
|
||||
# Skip when cropped area is too small
|
||||
continue
|
||||
|
||||
pts = np.stack([[xmin, xmax, xmax, xmin],
|
||||
[ymin, ymin, ymax, ymax]]).T.astype(np.int32)
|
||||
pp = plg(pts)
|
||||
success_flag = True
|
||||
for polygon in polygons:
|
||||
ppi = plg(polygon.reshape(-1, 2))
|
||||
# TODO Move this eval_utils to point_utils?
|
||||
ppiou = eval_utils.poly_intersection(ppi, pp)
|
||||
if np.abs(ppiou - float(ppi.area)) > self.epsilon and \
|
||||
np.abs(ppiou) > self.epsilon:
|
||||
success_flag = False
|
||||
break
|
||||
if np.abs(ppiou - float(ppi.area)) < self.epsilon:
|
||||
polys_new.append(polygon)
|
||||
else:
|
||||
polys_keep.append(polygon)
|
||||
|
||||
if success_flag:
|
||||
break
|
||||
|
||||
cropped = image[ymin:ymax, xmin:xmax, :]
|
||||
select_type = self._random_flip_type()
|
||||
print(select_type)
|
||||
if select_type == 0:
|
||||
img = np.ascontiguousarray(cropped[:, ::-1])
|
||||
elif select_type == 1:
|
||||
img = np.ascontiguousarray(cropped[::-1, :])
|
||||
else:
|
||||
img = np.ascontiguousarray(cropped[::-1, ::-1])
|
||||
image[ymin:ymax, xmin:xmax, :] = img
|
||||
results['img'] = image
|
||||
|
||||
if len(polys_new) != 0:
|
||||
height, width, _ = cropped.shape
|
||||
if select_type == 0:
|
||||
for idx, polygon in enumerate(polys_new):
|
||||
poly = polygon.reshape(-1, 2)
|
||||
poly[:, 0] = width - poly[:, 0] + 2 * xmin
|
||||
polys_new[idx] = poly.reshape(-1, )
|
||||
elif select_type == 1:
|
||||
for idx, polygon in enumerate(polys_new):
|
||||
poly = polygon.reshape(-1, 2)
|
||||
poly[:, 1] = height - poly[:, 1] + 2 * ymin
|
||||
polys_new[idx] = poly.reshape(-1, )
|
||||
else:
|
||||
for idx, polygon in enumerate(polys_new):
|
||||
poly = polygon.reshape(-1, 2)
|
||||
poly[:, 0] = width - poly[:, 0] + 2 * xmin
|
||||
poly[:, 1] = height - poly[:, 1] + 2 * ymin
|
||||
polys_new[idx] = poly.reshape(-1, )
|
||||
polygons = polys_keep + polys_new
|
||||
results['gt_polygons'] = polygons
|
||||
|
||||
return results
|
||||
|
||||
def _generate_crop_target(self, image: np.ndarray,
|
||||
all_polys: List[np.ndarray], pad_h: int,
|
||||
pad_w: int) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Generate cropping target and make sure not to crop the polygon
|
||||
instances.
|
||||
|
||||
Args:
|
||||
image (np.ndarray): The image waited to be crop.
|
||||
all_polys (list[np.ndarray]): Ground-truth polygons.
|
||||
pad_h (int): Padding length of height.
|
||||
pad_w (int): Padding length of width.
|
||||
|
||||
Returns:
|
||||
(np.ndarray, np.ndarray): Returns a tuple ``(h_axis, w_axis)``,
|
||||
where ``h_axis`` is the vertical cropping range and ``w_axis``
|
||||
is the horizontal cropping range.
|
||||
"""
|
||||
h, w, _ = image.shape
|
||||
h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
|
||||
w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
|
||||
|
||||
text_polys = []
|
||||
for polygon in all_polys:
|
||||
rect = cv2.minAreaRect(polygon.astype(np.int32).reshape(-1, 2))
|
||||
box = cv2.boxPoints(rect)
|
||||
box = np.int0(box)
|
||||
text_polys.append([box[0], box[1], box[2], box[3]])
|
||||
|
||||
polys = np.array(text_polys, dtype=np.int32)
|
||||
for poly in polys:
|
||||
poly = np.round(poly, decimals=0).astype(np.int32)
|
||||
minx, maxx = np.min(poly[:, 0]), np.max(poly[:, 0])
|
||||
miny, maxy = np.min(poly[:, 1]), np.max(poly[:, 1])
|
||||
w_array[minx + pad_w:maxx + pad_w] = 1
|
||||
h_array[miny + pad_h:maxy + pad_h] = 1
|
||||
|
||||
h_axis = np.where(h_array == 0)[0]
|
||||
w_axis = np.where(w_array == 0)[0]
|
||||
return h_axis, w_axis
|
||||
|
||||
def __repr__(self) -> str:
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(pad_ratio = {self.pad_ratio}'
|
||||
repr_str += f', crop_ratio = {self.crop_ratio}'
|
||||
repr_str += f', iter_num = {self.iter_num}'
|
||||
repr_str += f', min_area_ratio = {self.min_area_ratio}'
|
||||
repr_str += f', epsilon = {self.epsilon})'
|
||||
return repr_str
|
||||
|
|
|
@ -1,14 +1,12 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
|
||||
import cv2
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torchvision.transforms as transforms
|
||||
from mmdet.core import BitmapMasks, PolygonMasks
|
||||
from mmdet.datasets.pipelines.transforms import Resize
|
||||
from PIL import Image
|
||||
from shapely.geometry import Polygon as plg
|
||||
|
||||
import mmocr.core.evaluation.utils as eval_utils
|
||||
from mmocr.registry import TRANSFORMS
|
||||
|
@ -599,198 +597,3 @@ class RandomScaling:
|
|||
results[key] = results[key].resize(out_size)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class RandomCropFlip:
|
||||
|
||||
def __init__(self,
|
||||
pad_ratio=0.1,
|
||||
crop_ratio=0.5,
|
||||
iter_num=1,
|
||||
min_area_ratio=0.2):
|
||||
"""Random crop and flip a patch of the image.
|
||||
|
||||
Args:
|
||||
crop_ratio (float): The ratio of cropping.
|
||||
iter_num (int): Number of operations.
|
||||
min_area_ratio (float): Minimal area ratio between cropped patch
|
||||
and original image.
|
||||
"""
|
||||
assert isinstance(crop_ratio, float)
|
||||
assert isinstance(iter_num, int)
|
||||
assert isinstance(min_area_ratio, float)
|
||||
|
||||
self.pad_ratio = pad_ratio
|
||||
self.epsilon = 1e-2
|
||||
self.crop_ratio = crop_ratio
|
||||
self.iter_num = iter_num
|
||||
self.min_area_ratio = min_area_ratio
|
||||
|
||||
def __call__(self, results):
|
||||
for i in range(self.iter_num):
|
||||
results = self.random_crop_flip(results)
|
||||
return results
|
||||
|
||||
def random_crop_flip(self, results):
|
||||
image = results['img']
|
||||
polygons = results['gt_masks'].masks
|
||||
ignore_polygons = results['gt_masks_ignore'].masks
|
||||
all_polygons = polygons + ignore_polygons
|
||||
if len(polygons) == 0:
|
||||
return results
|
||||
|
||||
if np.random.random() >= self.crop_ratio:
|
||||
return results
|
||||
|
||||
h, w, _ = results['img_shape']
|
||||
area = h * w
|
||||
pad_h = int(h * self.pad_ratio)
|
||||
pad_w = int(w * self.pad_ratio)
|
||||
h_axis, w_axis = self.generate_crop_target(image, all_polygons, pad_h,
|
||||
pad_w)
|
||||
if len(h_axis) == 0 or len(w_axis) == 0:
|
||||
return results
|
||||
|
||||
attempt = 0
|
||||
while attempt < 10:
|
||||
attempt += 1
|
||||
polys_keep = []
|
||||
polys_new = []
|
||||
ign_polys_keep = []
|
||||
ign_polys_new = []
|
||||
xx = np.random.choice(w_axis, size=2)
|
||||
xmin = np.min(xx) - pad_w
|
||||
xmax = np.max(xx) - pad_w
|
||||
xmin = np.clip(xmin, 0, w - 1)
|
||||
xmax = np.clip(xmax, 0, w - 1)
|
||||
yy = np.random.choice(h_axis, size=2)
|
||||
ymin = np.min(yy) - pad_h
|
||||
ymax = np.max(yy) - pad_h
|
||||
ymin = np.clip(ymin, 0, h - 1)
|
||||
ymax = np.clip(ymax, 0, h - 1)
|
||||
if (xmax - xmin) * (ymax - ymin) < area * self.min_area_ratio:
|
||||
# area too small
|
||||
continue
|
||||
|
||||
pts = np.stack([[xmin, xmax, xmax, xmin],
|
||||
[ymin, ymin, ymax, ymax]]).T.astype(np.int32)
|
||||
pp = plg(pts)
|
||||
fail_flag = False
|
||||
for polygon in polygons:
|
||||
ppi = plg(polygon[0].reshape(-1, 2))
|
||||
ppiou = eval_utils.poly_intersection(ppi, pp)
|
||||
if np.abs(ppiou - float(ppi.area)) > self.epsilon and \
|
||||
np.abs(ppiou) > self.epsilon:
|
||||
fail_flag = True
|
||||
break
|
||||
elif np.abs(ppiou - float(ppi.area)) < self.epsilon:
|
||||
polys_new.append(polygon)
|
||||
else:
|
||||
polys_keep.append(polygon)
|
||||
|
||||
for polygon in ignore_polygons:
|
||||
ppi = plg(polygon[0].reshape(-1, 2))
|
||||
ppiou = eval_utils.poly_intersection(ppi, pp)
|
||||
if np.abs(ppiou - float(ppi.area)) > self.epsilon and \
|
||||
np.abs(ppiou) > self.epsilon:
|
||||
fail_flag = True
|
||||
break
|
||||
elif np.abs(ppiou - float(ppi.area)) < self.epsilon:
|
||||
ign_polys_new.append(polygon)
|
||||
else:
|
||||
ign_polys_keep.append(polygon)
|
||||
|
||||
if fail_flag:
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
cropped = image[ymin:ymax, xmin:xmax, :]
|
||||
select_type = np.random.randint(3)
|
||||
if select_type == 0:
|
||||
img = np.ascontiguousarray(cropped[:, ::-1])
|
||||
elif select_type == 1:
|
||||
img = np.ascontiguousarray(cropped[::-1, :])
|
||||
else:
|
||||
img = np.ascontiguousarray(cropped[::-1, ::-1])
|
||||
image[ymin:ymax, xmin:xmax, :] = img
|
||||
results['img'] = image
|
||||
|
||||
if len(polys_new) + len(ign_polys_new) != 0:
|
||||
height, width, _ = cropped.shape
|
||||
if select_type == 0:
|
||||
for idx, polygon in enumerate(polys_new):
|
||||
poly = polygon[0].reshape(-1, 2)
|
||||
poly[:, 0] = width - poly[:, 0] + 2 * xmin
|
||||
polys_new[idx] = [poly.reshape(-1, )]
|
||||
for idx, polygon in enumerate(ign_polys_new):
|
||||
poly = polygon[0].reshape(-1, 2)
|
||||
poly[:, 0] = width - poly[:, 0] + 2 * xmin
|
||||
ign_polys_new[idx] = [poly.reshape(-1, )]
|
||||
elif select_type == 1:
|
||||
for idx, polygon in enumerate(polys_new):
|
||||
poly = polygon[0].reshape(-1, 2)
|
||||
poly[:, 1] = height - poly[:, 1] + 2 * ymin
|
||||
polys_new[idx] = [poly.reshape(-1, )]
|
||||
for idx, polygon in enumerate(ign_polys_new):
|
||||
poly = polygon[0].reshape(-1, 2)
|
||||
poly[:, 1] = height - poly[:, 1] + 2 * ymin
|
||||
ign_polys_new[idx] = [poly.reshape(-1, )]
|
||||
else:
|
||||
for idx, polygon in enumerate(polys_new):
|
||||
poly = polygon[0].reshape(-1, 2)
|
||||
poly[:, 0] = width - poly[:, 0] + 2 * xmin
|
||||
poly[:, 1] = height - poly[:, 1] + 2 * ymin
|
||||
polys_new[idx] = [poly.reshape(-1, )]
|
||||
for idx, polygon in enumerate(ign_polys_new):
|
||||
poly = polygon[0].reshape(-1, 2)
|
||||
poly[:, 0] = width - poly[:, 0] + 2 * xmin
|
||||
poly[:, 1] = height - poly[:, 1] + 2 * ymin
|
||||
ign_polys_new[idx] = [poly.reshape(-1, )]
|
||||
polygons = polys_keep + polys_new
|
||||
ignore_polygons = ign_polys_keep + ign_polys_new
|
||||
results['gt_masks'] = PolygonMasks(polygons, *(image.shape[:2]))
|
||||
results['gt_masks_ignore'] = PolygonMasks(ignore_polygons,
|
||||
*(image.shape[:2]))
|
||||
|
||||
return results
|
||||
|
||||
def generate_crop_target(self, image, all_polys, pad_h, pad_w):
|
||||
"""Generate crop target and make sure not to crop the polygon
|
||||
instances.
|
||||
|
||||
Args:
|
||||
image (ndarray): The image waited to be crop.
|
||||
all_polys (list[list[ndarray]]): All polygons including ground
|
||||
truth polygons and ground truth ignored polygons.
|
||||
pad_h (int): Padding length of height.
|
||||
pad_w (int): Padding length of width.
|
||||
Returns:
|
||||
h_axis (ndarray): Vertical cropping range.
|
||||
w_axis (ndarray): Horizontal cropping range.
|
||||
"""
|
||||
h, w, _ = image.shape
|
||||
h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
|
||||
w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
|
||||
|
||||
text_polys = []
|
||||
for polygon in all_polys:
|
||||
rect = cv2.minAreaRect(polygon[0].astype(np.int32).reshape(-1, 2))
|
||||
box = cv2.boxPoints(rect)
|
||||
box = np.int0(box)
|
||||
text_polys.append([box[0], box[1], box[2], box[3]])
|
||||
|
||||
polys = np.array(text_polys, dtype=np.int32)
|
||||
for poly in polys:
|
||||
poly = np.round(poly, decimals=0).astype(np.int32)
|
||||
minx = np.min(poly[:, 0])
|
||||
maxx = np.max(poly[:, 0])
|
||||
w_array[minx + pad_w:maxx + pad_w] = 1
|
||||
miny = np.min(poly[:, 1])
|
||||
maxy = np.max(poly[:, 1])
|
||||
h_array[miny + pad_h:maxy + pad_h] = 1
|
||||
|
||||
h_axis = np.where(h_array == 0)[0]
|
||||
w_axis = np.where(w_array == 0)[0]
|
||||
return h_axis, w_axis
|
||||
|
|
|
@ -187,52 +187,6 @@ def test_random_scale():
|
|||
assert np.allclose(out_poly, gt_poly)
|
||||
|
||||
|
||||
@mock.patch('%s.transforms.np.random.randint' % __name__)
|
||||
def test_random_crop_flip(mock_randint):
|
||||
img = np.ones((10, 10, 3), dtype=np.uint8)
|
||||
img[0, 0, :] = 0
|
||||
results = {'img': img, 'img_shape': img.shape}
|
||||
|
||||
polygon = np.array([0., 0., 0., 10., 10., 10., 10., 0.])
|
||||
|
||||
results['gt_masks'] = PolygonMasks([[polygon]], *(img.shape[:2]))
|
||||
results['gt_masks_ignore'] = PolygonMasks([], *(img.shape[:2]))
|
||||
results['mask_fields'] = ['gt_masks', 'gt_masks_ignore']
|
||||
|
||||
crop_ratio = 1.1
|
||||
iter_num = 3
|
||||
random_crop_fliper = transforms.RandomCropFlip(
|
||||
crop_ratio=crop_ratio, iter_num=iter_num)
|
||||
|
||||
# test crop_target
|
||||
pad_ratio = 0.1
|
||||
h, w = img.shape[:2]
|
||||
pad_h = int(h * pad_ratio)
|
||||
pad_w = int(w * pad_ratio)
|
||||
all_polys = results['gt_masks'].masks
|
||||
h_axis, w_axis = random_crop_fliper.generate_crop_target(
|
||||
img, all_polys, pad_h, pad_w)
|
||||
|
||||
assert np.allclose(h_axis, (0, 11))
|
||||
assert np.allclose(w_axis, (0, 11))
|
||||
|
||||
# test __call__
|
||||
polygon = np.array([1., 1., 1., 9., 9., 9., 9., 1.])
|
||||
results['gt_masks'] = PolygonMasks([[polygon]], *(img.shape[:2]))
|
||||
results['gt_masks_ignore'] = PolygonMasks([[polygon]], *(img.shape[:2]))
|
||||
|
||||
mock_randint.side_effect = [0, 1, 2]
|
||||
results = random_crop_fliper(results)
|
||||
|
||||
out_img = results['img']
|
||||
out_poly = results['gt_masks'].masks[0][0]
|
||||
gt_img = img
|
||||
gt_poly = polygon
|
||||
|
||||
assert np.allclose(out_img, gt_img)
|
||||
assert np.allclose(out_poly, gt_poly)
|
||||
|
||||
|
||||
@mock.patch('%s.transforms.np.random.random_sample' % __name__)
|
||||
@mock.patch('%s.transforms.np.random.randint' % __name__)
|
||||
def test_random_crop_poly_instances(mock_randint, mock_sample):
|
||||
|
|
|
@ -5,7 +5,8 @@ import unittest.mock as mock
|
|||
|
||||
import numpy as np
|
||||
|
||||
from mmocr.datasets.pipelines import PyramidRescale, RandomRotate, Resize
|
||||
from mmocr.datasets.pipelines import (PyramidRescale, RandomRotate, Resize,
|
||||
TextDetRandomCropFlip)
|
||||
|
||||
|
||||
class TestPyramidRescale(unittest.TestCase):
|
||||
|
@ -55,6 +56,64 @@ class TestPyramidRescale(unittest.TestCase):
|
|||
'base_w = 128, base_h = 512)'))
|
||||
|
||||
|
||||
class TestTextDetRandomCropFlip(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
img = np.ones((10, 10, 3))
|
||||
img[0, 0, :] = 0
|
||||
self.data_info1 = dict(
|
||||
img=copy.deepcopy(img),
|
||||
gt_polygons=[np.array([0., 0., 0., 10., 10., 10., 10., 0.])],
|
||||
img_shape=[10, 10])
|
||||
self.data_info2 = dict(
|
||||
img=copy.deepcopy(img),
|
||||
gt_polygons=[np.array([1., 1., 1., 9., 9., 9., 9., 1.])],
|
||||
img_shape=[10, 10])
|
||||
|
||||
def test_init(self):
|
||||
# iter_num is int
|
||||
transform = TextDetRandomCropFlip(iter_num=1)
|
||||
self.assertEqual(transform.iter_num, 1)
|
||||
# iter_num is float
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'`iter_num` should be an integer'):
|
||||
transform = TextDetRandomCropFlip(iter_num=1.5)
|
||||
|
||||
@mock.patch('mmocr.datasets.pipelines.processing.np.random.randint')
|
||||
def test_transforms(self, mock_sample):
|
||||
mock_sample.side_effect = [0, 1, 2]
|
||||
transform = TextDetRandomCropFlip(crop_ratio=1.0, iter_num=3)
|
||||
results = transform(self.data_info2)
|
||||
self.assertTrue(np.allclose(results['img'], self.data_info2['img']))
|
||||
self.assertTrue(
|
||||
np.allclose(results['gt_polygons'],
|
||||
self.data_info2['gt_polygons']))
|
||||
|
||||
def test_generate_crop_target(self):
|
||||
transform = TextDetRandomCropFlip(
|
||||
crop_ratio=1.0, iter_num=3, pad_ratio=0.1)
|
||||
h, w = self.data_info1['img_shape']
|
||||
pad_h = int(h * transform.pad_ratio)
|
||||
pad_w = int(w * transform.pad_ratio)
|
||||
h_axis, w_axis = transform._generate_crop_target(
|
||||
self.data_info1['img'], self.data_info1['gt_polygons'], pad_h,
|
||||
pad_w)
|
||||
self.assertTrue(np.allclose(h_axis, (0, 11)))
|
||||
self.assertTrue(np.allclose(w_axis, (0, 11)))
|
||||
|
||||
def test_repr(self):
|
||||
transform = TextDetRandomCropFlip(
|
||||
pad_ratio=0.1,
|
||||
crop_ratio=0.5,
|
||||
iter_num=1,
|
||||
min_area_ratio=0.2,
|
||||
epsilon=1e-2)
|
||||
self.assertEqual(
|
||||
repr(transform),
|
||||
('TextDetRandomCropFlip(pad_ratio = 0.1, crop_ratio = 0.5, '
|
||||
'iter_num = 1, min_area_ratio = 0.2, epsilon = 0.01)'))
|
||||
|
||||
|
||||
class TestRandomRotate(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
|
Loading…
Reference in New Issue