[Processing]remove segocr and split processing

This commit is contained in:
liukuikun 2022-07-13 06:01:57 +00:00 committed by gaotongxiao
parent a844b497db
commit d50d2a46eb
47 changed files with 2052 additions and 3039 deletions

View File

@ -11,7 +11,7 @@ train_pipeline_r18 = [
dict(type='ColorJitter', brightness=32.0 / 255, saturation=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(
type='ImgAug',
type='ImgAugWrapper',
args=[['Fliplr', 0.5],
dict(cls='Affine', rotate=[-10, 10]), ['Resize', [0.5, 3.0]]]),
dict(type='EastRandomCrop', target_size=(640, 640)),
@ -57,7 +57,7 @@ train_pipeline_r50dcnv2 = [
dict(type='ColorJitter', brightness=32.0 / 255, saturation=0.5),
dict(type='Normalize', **img_norm_cfg_r50dcnv2),
dict(
type='ImgAug',
type='ImgAugWrapper',
args=[['Fliplr', 0.5],
dict(cls='Affine', rotate=[-10, 10]), ['Resize', [0.5, 3.0]]]),
dict(type='EastRandomCrop', target_size=(640, 640)),

View File

@ -20,7 +20,7 @@ train_pipeline_r18 = [
brightness=32.0 / 255,
saturation=0.5),
dict(
type='ImgAug',
type='ImgAugWrapper',
args=[['Fliplr', 0.5],
dict(cls='Affine', rotate=[-10, 10]), ['Resize', [0.5, 3.0]]]),
dict(type='RandomCrop', min_side_ratio=0.1),

View File

@ -22,7 +22,7 @@ train_pipeline_r50dcnv2 = [
brightness=32.0 / 255,
saturation=0.5),
dict(
type='ImgAug',
type='ImgAugWrapper',
args=[['Fliplr', 0.5],
dict(cls='Affine', rotate=[-10, 10]), ['Resize', [0.5, 3.0]]]),
dict(type='RandomCrop', min_side_ratio=0.1),

View File

@ -1,13 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .icdar_dataset import IcdarDataset
from .ocr_dataset import OCRDataset
from .ocr_seg_dataset import OCRSegDataset
from .pipelines import * # NOQA
from .recog_lmdb_dataset import RecogLMDBDataset
from .recog_text_dataset import RecogTextDataset
from .transforms import * # NOQA
from .wildreceipt_dataset import WildReceiptDataset
__all__ = [
'IcdarDataset', 'OCRDataset', 'OCRSegDataset', 'RecogLMDBDataset',
'RecogTextDataset', 'WildReceiptDataset'
'IcdarDataset', 'OCRDataset', 'RecogLMDBDataset', 'RecogTextDataset',
'WildReceiptDataset'
]

View File

@ -1,90 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmocr.utils as utils
from mmocr.datasets.ocr_dataset import OCRDataset
from mmocr.registry import DATASETS
@DATASETS.register_module()
class OCRSegDataset(OCRDataset):
def pre_pipeline(self, results):
results['img_prefix'] = self.img_prefix
def _parse_anno_info(self, annotations):
"""Parse char boxes annotations.
Args:
annotations (list[dict]): Annotations of one image, where
each dict is for one character.
Returns:
dict: A dict containing the following keys:
- chars (list[str]): List of character strings.
- char_rects (list[list[float]]): List of char box, with each
in style of rectangle: [x_min, y_min, x_max, y_max].
- char_quads (list[list[float]]): List of char box, with each
in style of quadrangle: [x1, y1, x2, y2, x3, y3, x4, y4].
"""
assert utils.is_type_list(annotations, dict)
assert 'char_box' in annotations[0]
assert 'char_text' in annotations[0]
assert len(annotations[0]['char_box']) in [4, 8]
chars, char_rects, char_quads = [], [], []
for ann in annotations:
char_box = ann['char_box']
if len(char_box) == 4:
char_box_type = ann.get('char_box_type', 'xyxy')
if char_box_type == 'xyxy':
char_rects.append(char_box)
char_quads.append([
char_box[0], char_box[1], char_box[2], char_box[1],
char_box[2], char_box[3], char_box[0], char_box[3]
])
elif char_box_type == 'xywh':
x1, y1, w, h = char_box
x2 = x1 + w
y2 = y1 + h
char_rects.append([x1, y1, x2, y2])
char_quads.append([x1, y1, x2, y1, x2, y2, x1, y2])
else:
raise ValueError(f'invalid char_box_type {char_box_type}')
elif len(char_box) == 8:
x_list, y_list = [], []
for i in range(4):
x_list.append(char_box[2 * i])
y_list.append(char_box[2 * i + 1])
x_max, x_min = max(x_list), min(x_list)
y_max, y_min = max(y_list), min(y_list)
char_rects.append([x_min, y_min, x_max, y_max])
char_quads.append(char_box)
else:
raise Exception(
f'invalid num in char box: {len(char_box)} not in (4, 8)')
chars.append(ann['char_text'])
ann = dict(chars=chars, char_rects=char_rects, char_quads=char_quads)
return ann
def prepare_train_img(self, index):
"""Get training data and annotations from pipeline.
Args:
index (int): Index of data.
Returns:
dict: Training data and annotation after pipeline with new keys
introduced by pipeline.
"""
img_ann_info = self.data_infos[index]
img_info = {
'filename': img_ann_info['file_name'],
}
ann_info = self._parse_anno_info(img_ann_info['annotations'])
results = dict(img_info=img_info, ann_info=ann_info)
self.pre_pipeline(results)
return self.pipeline(results)

View File

@ -1,25 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .formatting import PackKIEInputs, PackTextDetInputs, PackTextRecogInputs
from .loading import LoadKIEAnnotations, LoadOCRAnnotations
from .ocr_transforms import (FancyPCA, NormalizeOCR, OnlineCropOCR,
OpencvToPil, PilToOpencv, RandomPaddingOCR,
RandomRotateImageBox, ResizeOCR, ToTensorOCR)
from .processing import (BoundedScaleAspectJitter, FixInvalidPolygon,
PadToWidth, PyramidRescale, RandomCrop, RandomFlip,
RandomRotate, RescaleToHeight, Resize,
ShortScaleAspectJitter, SourceImagePad,
TextDetRandomCrop, TextDetRandomCropFlip)
from .test_time_aug import MultiRotateAugOCR
from .wrappers import ImgAug, TorchVisionWrapper
__all__ = [
'LoadOCRAnnotations', 'NormalizeOCR', 'OnlineCropOCR', 'ResizeOCR',
'ToTensorOCR', 'RandomRotate', 'MultiRotateAugOCR', 'FancyPCA',
'RandomPaddingOCR', 'ImgAug', 'RandomRotateImageBox', 'OpencvToPil',
'PilToOpencv', 'SourceImagePad', 'TextDetRandomCropFlip', 'PyramidRescale',
'TorchVisionWrapper', 'Resize', 'RandomCrop', 'TextDetRandomCrop',
'RandomCrop', 'PackTextDetInputs', 'PackTextRecogInputs',
'RescaleToHeight', 'PadToWidth', 'ShortScaleAspectJitter', 'RandomFlip',
'BoundedScaleAspectJitter', 'PackKIEInputs', 'LoadKIEAnnotations',
'FixInvalidPolygon'
]

View File

@ -1,454 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import mmcv
import numpy as np
import torch
import torchvision.transforms.functional as TF
from mmcv.runner.dist_utils import get_dist_info
from PIL import Image
from shapely.geometry import Polygon
from shapely.geometry import box as shapely_box
import mmocr.utils as utils
from mmocr.datasets.pipelines.crop import warp_img
from mmocr.registry import TRANSFORMS
@TRANSFORMS.register_module()
class ResizeOCR:
"""Image resizing and padding for OCR.
Args:
height (int | tuple(int)): Image height after resizing.
min_width (none | int | tuple(int)): Image minimum width
after resizing.
max_width (none | int | tuple(int)): Image maximum width
after resizing.
keep_aspect_ratio (bool): Keep image aspect ratio if True
during resizing, Otherwise resize to the size height *
max_width.
img_pad_value (int): Scalar to fill padding area.
width_downsample_ratio (float): Downsample ratio in horizontal
direction from input image to output feature.
backend (str | None): The image resize backend type. Options are `cv2`,
`pillow`, `None`. If backend is None, the global imread_backend
specified by ``mmcv.use_backend()`` will be used. Default: None.
"""
def __init__(self,
height,
min_width=None,
max_width=None,
keep_aspect_ratio=True,
img_pad_value=0,
width_downsample_ratio=1.0 / 16,
backend=None):
assert isinstance(height, (int, tuple))
assert utils.is_none_or_type(min_width, (int, tuple))
assert utils.is_none_or_type(max_width, (int, tuple))
if not keep_aspect_ratio:
assert max_width is not None, ('"max_width" must assigned '
'if "keep_aspect_ratio" is False')
assert isinstance(img_pad_value, int)
if isinstance(height, tuple):
assert isinstance(min_width, tuple)
assert isinstance(max_width, tuple)
assert len(height) == len(min_width) == len(max_width)
self.height = height
self.min_width = min_width
self.max_width = max_width
self.keep_aspect_ratio = keep_aspect_ratio
self.img_pad_value = img_pad_value
self.width_downsample_ratio = width_downsample_ratio
self.backend = backend
def __call__(self, results):
rank, _ = get_dist_info()
if isinstance(self.height, int):
dst_height = self.height
dst_min_width = self.min_width
dst_max_width = self.max_width
else:
# Multi-scale resize used in distributed training.
# Choose one (height, width) pair for one rank id.
idx = rank % len(self.height)
dst_height = self.height[idx]
dst_min_width = self.min_width[idx]
dst_max_width = self.max_width[idx]
img_shape = results['img_shape']
ori_height, ori_width = img_shape[:2]
valid_ratio = 1.0
resize_shape = list(img_shape)
pad_shape = list(img_shape)
if self.keep_aspect_ratio:
new_width = math.ceil(float(dst_height) / ori_height * ori_width)
width_divisor = int(1 / self.width_downsample_ratio)
# make sure new_width is an integral multiple of width_divisor.
if new_width % width_divisor != 0:
new_width = round(new_width / width_divisor) * width_divisor
if dst_min_width is not None:
new_width = max(dst_min_width, new_width)
if dst_max_width is not None:
valid_ratio = min(1.0, 1.0 * new_width / dst_max_width)
resize_width = min(dst_max_width, new_width)
img_resize = mmcv.imresize(
results['img'], (resize_width, dst_height),
backend=self.backend)
resize_shape = img_resize.shape
pad_shape = img_resize.shape
if new_width < dst_max_width:
img_resize = mmcv.impad(
img_resize,
shape=(dst_height, dst_max_width),
pad_val=self.img_pad_value)
pad_shape = img_resize.shape
else:
img_resize = mmcv.imresize(
results['img'], (new_width, dst_height),
backend=self.backend)
resize_shape = img_resize.shape
pad_shape = img_resize.shape
else:
img_resize = mmcv.imresize(
results['img'], (dst_max_width, dst_height),
backend=self.backend)
resize_shape = img_resize.shape
pad_shape = img_resize.shape
results['img'] = img_resize
results['img_shape'] = resize_shape
results['resize_shape'] = resize_shape
results['pad_shape'] = pad_shape
results['valid_ratio'] = valid_ratio
return results
@TRANSFORMS.register_module()
class ToTensorOCR:
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor."""
def __init__(self):
pass
def __call__(self, results):
results['img'] = TF.to_tensor(results['img'].copy())
return results
@TRANSFORMS.register_module()
class NormalizeOCR:
"""Normalize a tensor image with mean and standard deviation."""
def __init__(self, mean, std):
self.mean = mean
self.std = std
def __call__(self, results):
results['img'] = TF.normalize(results['img'], self.mean, self.std)
results['img_norm_cfg'] = dict(mean=self.mean, std=self.std)
return results
@TRANSFORMS.register_module()
class OnlineCropOCR:
"""Crop text areas from whole image with bounding box jitter. If no bbox is
given, return directly.
Args:
box_keys (list[str]): Keys in results which correspond to RoI bbox.
jitter_prob (float): The probability of box jitter.
max_jitter_ratio_x (float): Maximum horizontal jitter ratio
relative to height.
max_jitter_ratio_y (float): Maximum vertical jitter ratio
relative to height.
"""
def __init__(self,
box_keys=['x1', 'y1', 'x2', 'y2', 'x3', 'y3', 'x4', 'y4'],
jitter_prob=0.5,
max_jitter_ratio_x=0.05,
max_jitter_ratio_y=0.02):
assert utils.is_type_list(box_keys, str)
assert 0 <= jitter_prob <= 1
assert 0 <= max_jitter_ratio_x <= 1
assert 0 <= max_jitter_ratio_y <= 1
self.box_keys = box_keys
self.jitter_prob = jitter_prob
self.max_jitter_ratio_x = max_jitter_ratio_x
self.max_jitter_ratio_y = max_jitter_ratio_y
def __call__(self, results):
if 'img_info' not in results:
return results
crop_flag = True
box = []
for key in self.box_keys:
if key not in results['img_info']:
crop_flag = False
break
box.append(float(results['img_info'][key]))
if not crop_flag:
return results
jitter_flag = np.random.random() > self.jitter_prob
kwargs = dict(
jitter_flag=jitter_flag,
jitter_ratio_x=self.max_jitter_ratio_x,
jitter_ratio_y=self.max_jitter_ratio_y)
crop_img = warp_img(results['img'], box, **kwargs)
results['img'] = crop_img
results['img_shape'] = crop_img.shape
return results
@TRANSFORMS.register_module()
class FancyPCA:
"""Implementation of PCA based image augmentation, proposed in the paper
``Imagenet Classification With Deep Convolutional Neural Networks``.
It alters the intensities of RGB values along the principal components of
ImageNet dataset.
"""
def __init__(self, eig_vec=None, eig_val=None):
if eig_vec is None:
eig_vec = torch.Tensor([
[-0.5675, +0.7192, +0.4009],
[-0.5808, -0.0045, -0.8140],
[-0.5836, -0.6948, +0.4203],
]).t()
if eig_val is None:
eig_val = torch.Tensor([[0.2175, 0.0188, 0.0045]])
self.eig_val = eig_val # 1*3
self.eig_vec = eig_vec # 3*3
def pca(self, tensor):
assert tensor.size(0) == 3
alpha = torch.normal(mean=torch.zeros_like(self.eig_val)) * 0.1
reconst = torch.mm(self.eig_val * alpha, self.eig_vec)
tensor = tensor + reconst.view(3, 1, 1)
return tensor
def __call__(self, results):
img = results['img']
tensor = self.pca(img)
results['img'] = tensor
return results
def __repr__(self):
repr_str = self.__class__.__name__
return repr_str
@TRANSFORMS.register_module()
class RandomPaddingOCR:
"""Pad the given image on all sides, as well as modify the coordinates of
character bounding box in image.
Args:
max_ratio (list[int]): [left, top, right, bottom].
box_type (None|str): Character box type. If not none,
should be either 'char_rects' or 'char_quads', with
'char_rects' for rectangle with ``xyxy`` style and
'char_quads' for quadrangle with ``x1y1x2y2x3y3x4y4`` style.
"""
def __init__(self, max_ratio=None, box_type=None):
if max_ratio is None:
max_ratio = [0.1, 0.2, 0.1, 0.2]
else:
assert utils.is_type_list(max_ratio, float)
assert len(max_ratio) == 4
assert box_type is None or box_type in ('char_rects', 'char_quads')
self.max_ratio = max_ratio
self.box_type = box_type
def __call__(self, results):
img_shape = results['img_shape']
ori_height, ori_width = img_shape[:2]
random_padding_left = round(
np.random.uniform(0, self.max_ratio[0]) * ori_width)
random_padding_top = round(
np.random.uniform(0, self.max_ratio[1]) * ori_height)
random_padding_right = round(
np.random.uniform(0, self.max_ratio[2]) * ori_width)
random_padding_bottom = round(
np.random.uniform(0, self.max_ratio[3]) * ori_height)
padding = (random_padding_left, random_padding_top,
random_padding_right, random_padding_bottom)
img = mmcv.impad(results['img'], padding=padding, padding_mode='edge')
results['img'] = img
results['img_shape'] = img.shape
if self.box_type is not None:
num_points = 2 if self.box_type == 'char_rects' else 4
char_num = len(results['ann_info'][self.box_type])
for i in range(char_num):
for j in range(num_points):
results['ann_info'][self.box_type][i][
j * 2] += random_padding_left
results['ann_info'][self.box_type][i][
j * 2 + 1] += random_padding_top
return results
def __repr__(self):
repr_str = self.__class__.__name__
return repr_str
@TRANSFORMS.register_module()
class RandomRotateImageBox:
"""Rotate augmentation for segmentation based text recognition.
Args:
min_angle (int): Minimum rotation angle for image and box.
max_angle (int): Maximum rotation angle for image and box.
box_type (str): Character box type, should be either
'char_rects' or 'char_quads', with 'char_rects'
for rectangle with ``xyxy`` style and 'char_quads'
for quadrangle with ``x1y1x2y2x3y3x4y4`` style.
"""
def __init__(self, min_angle=-10, max_angle=10, box_type='char_quads'):
assert box_type in ('char_rects', 'char_quads')
self.min_angle = min_angle
self.max_angle = max_angle
self.box_type = box_type
def __call__(self, results):
in_img = results['img']
in_chars = results['ann_info']['chars']
in_boxes = results['ann_info'][self.box_type]
img_width, img_height = in_img.size
rotate_center = [img_width / 2., img_height / 2.]
tan_temp_max_angle = rotate_center[1] / rotate_center[0]
temp_max_angle = np.arctan(tan_temp_max_angle) * 180. / np.pi
random_angle = np.random.uniform(
max(self.min_angle, -temp_max_angle),
min(self.max_angle, temp_max_angle))
random_angle_radian = random_angle * np.pi / 180.
img_box = shapely_box(0, 0, img_width, img_height)
out_img = TF.rotate(
in_img,
random_angle,
resample=False,
expand=False,
center=rotate_center)
out_boxes, out_chars = self.rotate_bbox(in_boxes, in_chars,
random_angle_radian,
rotate_center, img_box)
results['img'] = out_img
results['ann_info']['chars'] = out_chars
results['ann_info'][self.box_type] = out_boxes
return results
@staticmethod
def rotate_bbox(boxes, chars, angle, center, img_box):
out_boxes = []
out_chars = []
for idx, bbox in enumerate(boxes):
temp_bbox = []
for i in range(len(bbox) // 2):
point = [bbox[2 * i], bbox[2 * i + 1]]
temp_bbox.append(
RandomRotateImageBox.rotate_point(point, angle, center))
poly_temp_bbox = Polygon(temp_bbox).buffer(0)
if poly_temp_bbox.is_valid:
if img_box.intersects(poly_temp_bbox) and (
not img_box.touches(poly_temp_bbox)):
temp_bbox_area = poly_temp_bbox.area
intersect_area = img_box.intersection(poly_temp_bbox).area
intersect_ratio = intersect_area / temp_bbox_area
if intersect_ratio >= 0.7:
out_box = []
for p in temp_bbox:
out_box.extend(p)
out_boxes.append(out_box)
out_chars.append(chars[idx])
return out_boxes, out_chars
@staticmethod
def rotate_point(point, angle, center):
cos_theta = math.cos(-angle)
sin_theta = math.sin(-angle)
c_x = center[0]
c_y = center[1]
new_x = (point[0] - c_x) * cos_theta - (point[1] -
c_y) * sin_theta + c_x
new_y = (point[0] - c_x) * sin_theta + (point[1] -
c_y) * cos_theta + c_y
return [new_x, new_y]
@TRANSFORMS.register_module()
class OpencvToPil:
"""Convert ``numpy.ndarray`` (bgr) to ``PIL Image`` (rgb)."""
def __init__(self, **kwargs):
pass
def __call__(self, results):
img = results['img'][..., ::-1]
img = Image.fromarray(img)
results['img'] = img
return results
def __repr__(self):
repr_str = self.__class__.__name__
return repr_str
@TRANSFORMS.register_module()
class PilToOpencv:
"""Convert ``PIL Image`` (rgb) to ``numpy.ndarray`` (bgr)."""
def __init__(self, **kwargs):
pass
def __call__(self, results):
img = np.asarray(results['img'])
img = img[..., ::-1]
results['img'] = img
return results
def __repr__(self):
repr_str = self.__class__.__name__
return repr_str

View File

@ -1,109 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
from mmcv.transforms import Compose
from mmocr.registry import TRANSFORMS
@TRANSFORMS.register_module()
class MultiRotateAugOCR:
"""Test-time augmentation with multiple rotations in the case that
img_height > img_width.
An example configuration is as follows:
.. code-block::
rotate_degrees=[0, 90, 270],
transforms=[
dict(
type='ResizeOCR',
height=32,
min_width=32,
max_width=160,
keep_aspect_ratio=True),
dict(type='ToTensorOCR'),
dict(type='NormalizeOCR', **img_norm_cfg),
dict(
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'valid_ratio'
]),
]
After MultiRotateAugOCR with above configuration, the results are wrapped
into lists of the same length as follows:
.. code-block::
dict(
img=[...],
img_shape=[...]
...
)
Args:
transforms (list[dict]): Transformation applied for each augmentation.
rotate_degrees (list[int] | None): Degrees of anti-clockwise rotation.
force_rotate (bool): If True, rotate image by 'rotate_degrees'
while ignore image aspect ratio.
"""
def __init__(self, transforms, rotate_degrees=None, force_rotate=False):
self.transforms = Compose(transforms)
self.force_rotate = force_rotate
if rotate_degrees is not None:
self.rotate_degrees = rotate_degrees if isinstance(
rotate_degrees, list) else [rotate_degrees]
assert mmcv.is_list_of(self.rotate_degrees, int)
for degree in self.rotate_degrees:
assert 0 <= degree < 360
assert degree % 90 == 0
if 0 not in self.rotate_degrees:
self.rotate_degrees.append(0)
else:
self.rotate_degrees = [0]
def __call__(self, results):
"""Call function to apply test time augment transformation to results.
Args:
results (dict): Result dict contains the data to be transformed.
Returns:
dict[str: list]: The augmented data, where each value is wrapped
into a list.
"""
img_shape = results['img_shape']
ori_height, ori_width = img_shape[:2]
if not self.force_rotate and ori_height <= ori_width:
rotate_degrees = [0]
else:
rotate_degrees = self.rotate_degrees
aug_data = []
for degree in set(rotate_degrees):
_results = results.copy()
if degree == 0:
pass
elif degree == 90:
_results['img'] = np.rot90(_results['img'], 1)
elif degree == 180:
_results['img'] = np.rot90(_results['img'], 2)
elif degree == 270:
_results['img'] = np.rot90(_results['img'], 3)
data = self.transforms(_results)
aug_data.append(data)
# list of dict to dict of list
aug_data_dict = {key: [] for key in aug_data[0]}
for data in aug_data:
for key, val in data.items():
aug_data_dict[key].append(val)
return aug_data_dict
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(transforms={self.transforms}, '
repr_str += f'rotate_degrees={self.rotate_degrees})'
return repr_str

View File

@ -0,0 +1,19 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .formatting import PackKIEInputs, PackTextDetInputs, PackTextRecogInputs
from .loading import LoadKIEAnnotations, LoadOCRAnnotations
from .ocr_transforms import RandomCrop, RandomRotate, Resize
from .textdet_transforms import (BoundedScaleAspectJitter, FixInvalidPolygon,
RandomFlip, ShortScaleAspectJitter,
SourceImagePad, TextDetRandomCrop,
TextDetRandomCropFlip)
from .textrecog_transforms import PadToWidth, PyramidRescale, RescaleToHeight
from .wrappers import ImgAugWrapper, TorchVisionWrapper
__all__ = [
'LoadOCRAnnotations', 'RandomRotate', 'ImgAugWrapper', 'SourceImagePad',
'TextDetRandomCropFlip', 'PyramidRescale', 'TorchVisionWrapper', 'Resize',
'RandomCrop', 'TextDetRandomCrop', 'RandomCrop', 'PackTextDetInputs',
'PackTextRecogInputs', 'RescaleToHeight', 'PadToWidth',
'ShortScaleAspectJitter', 'RandomFlip', 'BoundedScaleAspectJitter',
'PackKIEInputs', 'LoadKIEAnnotations', 'FixInvalidPolygon'
]

View File

@ -0,0 +1,615 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Dict, Tuple
import cv2
import mmcv
import numpy as np
from mmcv.transforms import Resize as MMCV_Resize
from mmcv.transforms.base import BaseTransform
from mmcv.transforms.utils import avoid_cache_randomness, cache_randomness
from mmocr.registry import TRANSFORMS
from mmocr.utils import (bbox2poly, crop_polygon, is_poly_inside_rect,
poly2bbox, rescale_polygon)
from .wrappers import ImgAugWrapper
@TRANSFORMS.register_module()
@avoid_cache_randomness
class RandomCrop(BaseTransform):
"""Randomly crop images and make sure to contain at least one intact
instance.
Required Keys:
- img
- gt_polygons
- gt_bboxes
- gt_bboxes_labels
- gt_ignored
- gt_texts (optional)
Modified Keys:
- img
- img_shape
- gt_polygons
- gt_bboxes
- gt_bboxes_labels
- gt_ignored
- gt_texts (optional)
Args:
min_side_ratio (float): The ratio of the shortest edge of the cropped
image to the original image size.
"""
def __init__(self, min_side_ratio: float = 0.4) -> None:
if not 0. <= min_side_ratio <= 1.:
raise ValueError('`min_side_ratio` should be in range [0, 1],')
self.min_side_ratio = min_side_ratio
def _sample_valid_start_end(self, valid_array: np.ndarray, min_len: int,
max_start_idx: int,
min_end_idx: int) -> Tuple[int, int]:
"""Sample a start and end idx on a given axis that contains at least
one polygon. There should be at least one intact polygon bounded by
max_start_idx and min_end_idx.
Args:
valid_array (ndarray): A 0-1 mask 1D array indicating valid regions
on the axis. 0 indicates text regions which are not allowed to
be sampled from.
min_len (int): Minimum distance between two start and end points.
max_start_idx (int): The maximum start index.
min_end_idx (int): The minimum end index.
Returns:
tuple(int, int): Start and end index on a given axis, where
0 <= start < max_start_idx and
min_end_idx <= end < len(valid_array).
"""
assert isinstance(min_len, int)
assert len(valid_array) > min_len
start_array = valid_array.copy()
max_start_idx = min(len(start_array) - min_len, max_start_idx)
start_array[max_start_idx:] = 0
start_array[0] = 1
diff_array = np.hstack([0, start_array]) - np.hstack([start_array, 0])
region_starts = np.where(diff_array < 0)[0]
region_ends = np.where(diff_array > 0)[0]
region_ind = np.random.randint(0, len(region_starts))
start = np.random.randint(region_starts[region_ind],
region_ends[region_ind])
end_array = valid_array.copy()
min_end_idx = max(start + min_len, min_end_idx)
end_array[:min_end_idx] = 0
end_array[-1] = 1
diff_array = np.hstack([0, end_array]) - np.hstack([end_array, 0])
region_starts = np.where(diff_array < 0)[0]
region_ends = np.where(diff_array > 0)[0]
region_ind = np.random.randint(0, len(region_starts))
# Note that end index will never be region_ends[region_ind]
# and therefore end index is always in range [0, w+1]
end = np.random.randint(region_starts[region_ind],
region_ends[region_ind])
return start, end
def _sample_crop_box(self, img_size: Tuple[int, int],
results: Dict) -> np.ndarray:
"""Generate crop box which only contains intact polygon instances with
the number >= 1.
Args:
img_size (tuple(int, int)): The image size (h, w).
results (dict): The results dict.
Returns:
ndarray: Crop area in shape (4, ).
"""
assert isinstance(img_size, tuple)
h, w = img_size[:2]
# Crop box can be represented by any integer numbers in
# range [0, w] and [0, h]
x_valid_array = np.ones(w + 1, dtype=np.int32)
y_valid_array = np.ones(h + 1, dtype=np.int32)
polygons = results['gt_polygons']
# Randomly select a polygon that must be inside
# the cropped region
kept_poly_idx = np.random.randint(0, len(polygons))
for i, polygon in enumerate(polygons):
polygon = polygon.reshape((-1, 2))
clip_x = np.clip(polygon[:, 0], 0, w)
clip_y = np.clip(polygon[:, 1], 0, h)
min_x = np.floor(np.min(clip_x)).astype(np.int32)
min_y = np.floor(np.min(clip_y)).astype(np.int32)
max_x = np.ceil(np.max(clip_x)).astype(np.int32)
max_y = np.ceil(np.max(clip_y)).astype(np.int32)
x_valid_array[min_x:max_x] = 0
y_valid_array[min_y:max_y] = 0
if i == kept_poly_idx:
max_x_start = min_x
min_x_end = max_x
max_y_start = min_y
min_y_end = max_y
min_w = int(w * self.min_side_ratio)
min_h = int(h * self.min_side_ratio)
x1, x2 = self._sample_valid_start_end(x_valid_array, min_w,
max_x_start, min_x_end)
y1, y2 = self._sample_valid_start_end(y_valid_array, min_h,
max_y_start, min_y_end)
return np.array([x1, y1, x2, y2])
def _crop_img(self, img: np.ndarray, bbox: np.ndarray) -> np.ndarray:
"""Crop image given a bbox region.
Args:
img (ndarray): Image.
bbox (ndarray): Cropping region in shape (4, )
Returns:
ndarray: Cropped image.
"""
assert img.ndim == 3
h, w, _ = img.shape
assert 0 <= bbox[1] < bbox[3] <= h
assert 0 <= bbox[0] < bbox[2] <= w
return img[bbox[1]:bbox[3], bbox[0]:bbox[2]]
def transform(self, results: Dict) -> Dict:
"""Applying random crop on results.
Args:
results (dict): Result dict contains the data to transform.
Returns:
dict: The transformed data.
"""
if len(results['gt_polygons']) < 1:
return results
crop_box = self._sample_crop_box(results['img'].shape, results)
img = self._crop_img(results['img'], crop_box)
results['img'] = img
results['img_shape'] = img.shape[:2]
crop_x = crop_box[0]
crop_y = crop_box[1]
crop_w = crop_box[2] - crop_box[0]
crop_h = crop_box[3] - crop_box[1]
labels = results['gt_bboxes_labels']
valid_labels = []
ignored = results['gt_ignored']
valid_ignored = []
if 'gt_texts' in results:
valid_texts = []
texts = results['gt_texts']
polys = results['gt_polygons']
valid_polys = []
for idx, poly in enumerate(polys):
poly = poly.reshape(-1, 2)
poly = (poly - (crop_x, crop_y)).flatten()
if is_poly_inside_rect(poly, [0, 0, crop_w, crop_h]):
valid_polys.append(poly)
valid_labels.append(labels[idx])
valid_ignored.append(ignored[idx])
if 'gt_texts' in results:
valid_texts.append(texts[idx])
results['gt_polygons'] = valid_polys
results['gt_bboxes_labels'] = np.array(valid_labels, dtype=np.int64)
results['gt_ignored'] = np.array(valid_ignored, dtype=bool)
if 'gt_texts' in results:
results['gt_texts'] = valid_texts
valid_bboxes = [poly2bbox(poly) for poly in results['gt_polygons']]
results['gt_bboxes'] = np.array(valid_bboxes).astype(
np.float32).reshape(-1, 4)
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(min_side_ratio = {self.min_side_ratio})'
return repr_str
@TRANSFORMS.register_module()
class RandomRotate(BaseTransform):
"""Randomly rotate the image, boxes, and polygons. For recognition task,
only the image will be rotated. If set ``use_canvas`` as True, the shape of
rotated image might be modified based on the rotated angle size, otherwise,
the image will keep the shape before rotation.
Required Keys:
- img
- img_shape
- gt_bboxes (optional)
- gt_polygons (optional)
Modified Keys:
- img
- img_shape (optional)
- gt_bboxes (optional)
- gt_polygons (optional)
Added Keys:
- rotated_angle
Args:
max_angle (int): The maximum rotation angle (can be bigger than 180 or
a negative). Defaults to 10.
pad_with_fixed_color (bool): The flag for whether to pad rotated
image with fixed value. Defaults to False.
pad_value (tuple[int, int, int]): The color value for padding rotated
image. Defaults to (0, 0, 0).
use_canvas (bool): Whether to create a canvas for rotated image.
Defaults to False. If set true, the image shape may be modified.
"""
def __init__(
self,
max_angle: int = 10,
pad_with_fixed_color: bool = False,
pad_value: Tuple[int, int, int] = (0, 0, 0),
use_canvas: bool = False,
) -> None:
if not isinstance(max_angle, int):
raise TypeError('`max_angle` should be an integer'
f', but got {type(max_angle)} instead')
if not isinstance(pad_with_fixed_color, bool):
raise TypeError('`pad_with_fixed_color` should be a bool, '
f'but got {type(pad_with_fixed_color)} instead')
if not isinstance(pad_value, (list, tuple)):
raise TypeError('`pad_value` should be a list or tuple, '
f'but got {type(pad_value)} instead')
if len(pad_value) != 3:
raise ValueError('`pad_value` should contain three integers')
if not isinstance(pad_value[0], int) or not isinstance(
pad_value[1], int) or not isinstance(pad_value[2], int):
raise ValueError('`pad_value` should contain three integers')
self.max_angle = max_angle
self.pad_with_fixed_color = pad_with_fixed_color
self.pad_value = pad_value
self.use_canvas = use_canvas
@cache_randomness
def _sample_angle(self, max_angle: int) -> float:
"""Sampling a random angle for rotation.
Args:
max_angle (int): Maximum rotation angle
Returns:
float: The random angle used for rotation
"""
angle = np.random.random_sample() * 2 * max_angle - max_angle
return angle
@staticmethod
def _cal_canvas_size(ori_size: Tuple[int, int],
degree: int) -> Tuple[int, int]:
"""Calculate the canvas size.
Args:
ori_size (Tuple[int, int]): The original image size (height, width)
degree (int): The rotation angle
Returns:
Tuple[int, int]: The size of the canvas
"""
assert isinstance(ori_size, tuple)
angle = degree * math.pi / 180.0
h, w = ori_size[:2]
cos = math.cos(angle)
sin = math.sin(angle)
canvas_h = int(w * math.fabs(sin) + h * math.fabs(cos))
canvas_w = int(w * math.fabs(cos) + h * math.fabs(sin))
canvas_size = (canvas_h, canvas_w)
return canvas_size
@staticmethod
def _rotate_points(center: Tuple[float, float],
points: np.array,
theta: float,
center_shift: Tuple[int, int] = (0, 0)) -> np.array:
"""Rotating a set of points according to the given theta.
Args:
center (Tuple[float, float]): The coordinate of the canvas center
points (np.array): A set of points needed to be rotated
theta (float): Rotation angle
center_shift (Tuple[int, int]): The shifting offset of the center
coordinate
Returns:
np.array: The rotated coordinates of the input points
"""
(center_x, center_y) = center
center_y = -center_y
x, y = points[::2], points[1::2]
y = -y
theta = theta / 180 * math.pi
cos = math.cos(theta)
sin = math.sin(theta)
x = (x - center_x)
y = (y - center_y)
_x = center_x + x * cos - y * sin + center_shift[0]
_y = -(center_y + x * sin + y * cos) + center_shift[1]
points[::2], points[1::2] = _x, _y
return points
def _rotate_img(self, results: Dict) -> Tuple[int, int]:
"""Rotating the input image based on the given angle.
Args:
results (dict): Result dict containing the data to transform.
Returns:
Tuple[int, int]: The shifting offset of the center point.
"""
if results.get('img', None) is not None:
h = results['img'].shape[0]
w = results['img'].shape[1]
rotation_matrix = cv2.getRotationMatrix2D(
(w / 2, h / 2), results['rotated_angle'], 1)
canvas_size = self._cal_canvas_size((h, w),
results['rotated_angle'])
center_shift = (int(
(canvas_size[1] - w) / 2), int((canvas_size[0] - h) / 2))
rotation_matrix[0, 2] += int((canvas_size[1] - w) / 2)
rotation_matrix[1, 2] += int((canvas_size[0] - h) / 2)
if self.pad_with_fixed_color:
rotated_img = cv2.warpAffine(
results['img'],
rotation_matrix, (canvas_size[1], canvas_size[0]),
flags=cv2.INTER_NEAREST,
borderValue=self.pad_value)
else:
mask = np.zeros_like(results['img'])
(h_ind, w_ind) = (np.random.randint(0, h * 7 // 8),
np.random.randint(0, w * 7 // 8))
img_cut = results['img'][h_ind:(h_ind + h // 9),
w_ind:(w_ind + w // 9)]
img_cut = mmcv.imresize(img_cut,
(canvas_size[1], canvas_size[0]))
mask = cv2.warpAffine(
mask,
rotation_matrix, (canvas_size[1], canvas_size[0]),
borderValue=[1, 1, 1])
rotated_img = cv2.warpAffine(
results['img'],
rotation_matrix, (canvas_size[1], canvas_size[0]),
borderValue=[0, 0, 0])
rotated_img = rotated_img + img_cut * mask
results['img'] = rotated_img
else:
raise ValueError('`img` is not found in results')
return center_shift
def _rotate_bboxes(self, results: Dict, center_shift: Tuple[int,
int]) -> None:
"""Rotating the bounding boxes based on the given angle.
Args:
results (dict): Result dict containing the data to transform.
center_shift (Tuple[int, int]): The shifting offset of the
center point
"""
if results.get('gt_bboxes', None) is not None:
height, width = results['img_shape']
box_list = []
for box in results['gt_bboxes']:
rotated_box = self._rotate_points((width / 2, height / 2),
bbox2poly(box),
results['rotated_angle'],
center_shift)
rotated_box = poly2bbox(rotated_box)
box_list.append(rotated_box)
results['gt_bboxes'] = np.array(
box_list, dtype=np.float32).reshape(-1, 4)
def _rotate_polygons(self, results: Dict,
center_shift: Tuple[int, int]) -> None:
"""Rotating the polygons based on the given angle.
Args:
results (dict): Result dict containing the data to transform.
center_shift (Tuple[int, int]): The shifting offset of the
center point
"""
if results.get('gt_polygons', None) is not None:
height, width = results['img_shape']
polygon_list = []
for poly in results['gt_polygons']:
rotated_poly = self._rotate_points(
(width / 2, height / 2), poly, results['rotated_angle'],
center_shift)
polygon_list.append(rotated_poly)
results['gt_polygons'] = polygon_list
def transform(self, results: Dict) -> Dict:
"""Applying random rotate on results.
Args:
results (Dict): Result dict containing the data to transform.
center_shift (Tuple[int, int]): The shifting offset of the
center point
Returns:
dict: The transformed data
"""
# TODO rotate char_quads & char_rects for SegOCR
if self.use_canvas:
results['rotated_angle'] = self._sample_angle(self.max_angle)
# rotate image
center_shift = self._rotate_img(results)
# rotate gt_bboxes
self._rotate_bboxes(results, center_shift)
# rotate gt_polygons
self._rotate_polygons(results, center_shift)
results['img_shape'] = (results['img'].shape[0],
results['img'].shape[1])
else:
args = [
dict(
cls='Affine',
rotate=[-self.max_angle, self.max_angle],
backend='cv2',
order=0) # order=0 -> cv2.INTER_NEAREST
]
imgaug_transform = ImgAugWrapper(args)
results = imgaug_transform(results)
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(max_angle = {self.max_angle}'
repr_str += f', pad_with_fixed_color = {self.pad_with_fixed_color}'
repr_str += f', pad_value = {self.pad_value}'
repr_str += f', use_canvas = {self.use_canvas})'
return repr_str
@TRANSFORMS.register_module()
class Resize(MMCV_Resize):
"""Resize image & bboxes & polygons.
This transform resizes the input image according to ``scale`` or
``scale_factor``. Bboxes and polygons are then resized with the same
scale factor. if ``scale`` and ``scale_factor`` are both set, it will use
``scale`` to resize.
Required Keys:
- img
- img_shape
- gt_bboxes
- gt_polygons
Modified Keys:
- img
- img_shape
- gt_bboxes
- gt_polygons
Added Keys:
- scale
- scale_factor
- keep_ratio
Args:
scale (int or tuple): Image scales for resizing. Defaults to None.
scale_factor (float or tuple[float, float]): Scale factors for
resizing. It's either a factor applicable to both dimensions or
in the form of (scale_w, scale_h). Defaults to None.
keep_ratio (bool): Whether to keep the aspect ratio when resizing the
image. Defaults to False.
clip_object_border (bool): Whether to clip the objects outside the
border of the image. Defaults to True.
backend (str): Image resize backend, choices are 'cv2' and 'pillow'.
These two backends generates slightly different results. Defaults
to 'cv2'.
interpolation (str): Interpolation method, accepted values are
"nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
backend, "nearest", "bilinear" for 'pillow' backend. Defaults
to 'bilinear'.
"""
def _resize_img(self, results: dict) -> None:
"""Resize images with ``results['scale']``.
If no image is provided, only resize ``results['img_shape']``.
"""
if results.get('img', None) is not None:
return super()._resize_img(results)
h, w = results['img_shape']
if self.keep_ratio:
new_w, new_h = mmcv.rescale_size((w, h),
results['scale'],
return_scale=False)
else:
new_w, new_h = results['scale']
w_scale = new_w / w
h_scale = new_h / h
results['img_shape'] = (new_h, new_w)
results['scale'] = (new_w, new_h)
results['scale_factor'] = (w_scale, h_scale)
results['keep_ratio'] = self.keep_ratio
def _resize_bboxes(self, results: dict) -> None:
"""Resize bounding boxes."""
super()._resize_bboxes(results)
if results.get('gt_bboxes', None) is not None:
results['gt_bboxes'] = results['gt_bboxes'].astype(np.float32)
def _resize_polygons(self, results: dict) -> None:
"""Resize polygons with ``results['scale_factor']``."""
if results.get('gt_polygons', None) is not None:
polygons = results['gt_polygons']
polygons_resize = []
for idx, polygon in enumerate(polygons):
polygon = rescale_polygon(polygon, results['scale_factor'])
if self.clip_object_border:
crop_bbox = np.array([
0, 0, results['img_shape'][1], results['img_shape'][0]
])
polygon = crop_polygon(polygon, crop_bbox)
if polygon is not None:
polygons_resize.append(polygon.astype(np.float32))
else:
polygons_resize.append(
np.zeros_like(polygons[idx], dtype=np.float32))
results['gt_polygons'] = polygons_resize
def transform(self, results: dict) -> dict:
"""Transform function to resize images, bounding boxes and polygons.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Resized results, 'img', 'gt_bboxes', 'gt_polygons',
'scale', 'scale_factor', 'height', 'width', and 'keep_ratio' keys
are updated in result dict.
"""
results = super().transform(results)
self._resize_polygons(results)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(scale={self.scale}, '
repr_str += f'scale_factor={self.scale_factor}, '
repr_str += f'keep_ratio={self.keep_ratio}, '
repr_str += f'clip_object_border={self.clip_object_border}), '
repr_str += f'backend={self.backend}), '
repr_str += f'interpolation={self.interpolation})'
return repr_str

View File

@ -0,0 +1,252 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Dict, Optional, Tuple
import cv2
import mmcv
import numpy as np
from mmcv.transforms.base import BaseTransform
from mmcv.transforms.utils import cache_randomness
from mmocr.registry import TRANSFORMS
@TRANSFORMS.register_module()
class PyramidRescale(BaseTransform):
"""Resize the image to the base shape, downsample it with gaussian pyramid,
and rescale it back to original size.
Adapted from https://github.com/FangShancheng/ABINet.
Required Keys:
- img (ndarray)
Modified Keys:
- img (ndarray)
Args:
factor (int): The decay factor from base size, or the number of
downsampling operations from the base layer.
base_shape (tuple[int, int]): The shape (width, height) of the base
layer of the pyramid.
randomize_factor (bool): If True, the final factor would be a random
integer in [0, factor].
"""
def __init__(self,
factor: int = 4,
base_shape: Tuple[int, int] = (128, 512),
randomize_factor: bool = True) -> None:
if not isinstance(factor, int):
raise TypeError('`factor` should be an integer, '
f'but got {type(factor)} instead')
if not isinstance(base_shape, (list, tuple)):
raise TypeError('`base_shape` should be a list or tuple, '
f'but got {type(base_shape)} instead')
if not len(base_shape) == 2:
raise ValueError('`base_shape` should contain two integers')
if not isinstance(base_shape[0], int) or not isinstance(
base_shape[1], int):
raise ValueError('`base_shape` should contain two integers')
if not isinstance(randomize_factor, bool):
raise TypeError('`randomize_factor` should be a bool, '
f'but got {type(randomize_factor)} instead')
self.factor = factor
self.randomize_factor = randomize_factor
self.base_w, self.base_h = base_shape
@cache_randomness
def get_random_factor(self) -> float:
"""Get the randomized factor.
Returns:
float: The randomized factor.
"""
return np.random.randint(0, self.factor + 1)
def transform(self, results: Dict) -> Dict:
"""Applying pyramid rescale on results.
Args:
results (dict): Result dict containing the data to transform.
Returns:
Dict: The transformed data.
"""
assert 'img' in results, '`img` is not found in results'
if self.randomize_factor:
self.factor = self.get_random_factor()
if self.factor == 0:
return results
img = results['img']
src_h, src_w = img.shape[:2]
scale_img = mmcv.imresize(img, (self.base_w, self.base_h))
for _ in range(self.factor):
scale_img = cv2.pyrDown(scale_img)
scale_img = mmcv.imresize(scale_img, (src_w, src_h))
results['img'] = scale_img
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(factor = {self.factor}'
repr_str += f', randomize_factor = {self.randomize_factor}'
repr_str += f', base_w = {self.base_w}'
repr_str += f', base_h = {self.base_h})'
return repr_str
@TRANSFORMS.register_module()
class RescaleToHeight(BaseTransform):
"""Rescale the image to the height according to setting and keep the aspect
ratio unchanged if possible. However, if any of ``min_width``,
``max_width`` or ``width_divisor`` are specified, aspect ratio may still be
changed to ensure the width meets these constraints.
Required Keys:
- img
Modified Keys:
- img
- img_shape
Added Keys:
- scale
- scale_factor
- keep_ratio
Args:
height (int): Height of rescaled image.
min_width (int, optional): Minimum width of rescaled image. Defaults
to None.
max_width (int, optional): Maximum width of rescaled image. Defaults
to None.
width_divisor (int): The divisor of width size. Defaults to 1.
resize_cfg (dict): (dict): Config to construct the Resize transform.
Refer to ``Resize`` for detail. Defaults to
``dict(type='Resize')``.
"""
def __init__(self,
height: int,
min_width: Optional[int] = None,
max_width: Optional[int] = None,
width_divisor: int = 1,
resize_cfg: dict = dict(type='Resize')) -> None:
super().__init__()
assert isinstance(height, int)
assert isinstance(width_divisor, int)
if min_width is not None:
assert isinstance(min_width, int)
if max_width is not None:
assert isinstance(max_width, int)
self.width_divisor = width_divisor
self.height = height
self.min_width = min_width
self.max_width = max_width
self.resize_cfg = resize_cfg
_resize_cfg = self.resize_cfg.copy()
_resize_cfg.update(dict(scale=0))
self.resize = TRANSFORMS.build(_resize_cfg)
def transform(self, results: Dict) -> Dict:
"""Transform function to resize images, bounding boxes and polygons.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Resized results.
"""
ori_height, ori_width = results['img'].shape[:2]
new_width = math.ceil(float(self.height) / ori_height * ori_width)
if self.min_width is not None:
new_width = max(self.min_width, new_width)
if self.max_width is not None:
new_width = min(self.max_width, new_width)
if new_width % self.width_divisor != 0:
new_width = round(
new_width / self.width_divisor) * self.width_divisor
# TODO replace up code after testing precision.
# new_width = math.ceil(
# new_width / self.width_divisor) * self.width_divisor
scale = (new_width, self.height)
self.resize.scale = scale
results = self.resize(results)
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(height={self.height}, '
repr_str += f'min_width={self.min_width}, '
repr_str += f'max_width={self.max_width}, '
repr_str += f'width_divisor={self.width_divisor}, '
repr_str += f'resize_cfg={self.resize_cfg})'
return repr_str
@TRANSFORMS.register_module()
class PadToWidth(BaseTransform):
"""Only pad the image's width.
Required Keys:
- img
Modified Keys:
- img
- img_shape
Added Keys:
- pad_shape
- pad_fixed_size
- pad_size_divisor
- valid_ratio
Args:
width (int): Target width of padded image. Defaults to None.
pad_cfg (dict): Config to construct the Resize transform. Refer to
``Pad`` for detail. Defaults to ``dict(type='Pad')``.
"""
def __init__(self, width: int, pad_cfg: dict = dict(type='Pad')) -> None:
super().__init__()
assert isinstance(width, int)
self.width = width
self.pad_cfg = pad_cfg
_pad_cfg = self.pad_cfg.copy()
_pad_cfg.update(dict(size=0))
self.pad = TRANSFORMS.build(_pad_cfg)
def transform(self, results: Dict) -> Dict:
"""Call function to pad images.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Updated result dict.
"""
ori_height, ori_width = results['img'].shape[:2]
valid_ratio = min(1.0, 1.0 * ori_width / self.width)
size = (self.width, ori_height)
self.pad.size = size
results = self.pad(results)
results['valid_ratio'] = valid_ratio
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(width={self.width}, '
repr_str += f'pad_cfg={self.pad_cfg})'
return repr_str

View File

@ -13,7 +13,7 @@ from mmocr.utils import poly2bbox
@TRANSFORMS.register_module()
class ImgAug(BaseTransform):
class ImgAugWrapper(BaseTransform):
"""A wrapper around imgaug https://github.com/aleju/imgaug.
Find available augmenters at
@ -167,7 +167,7 @@ class ImgAug(BaseTransform):
return new_polys, removed_poly_inds
def _build_augmentation(self, args, root=True):
"""Build ImgAug augmentations.
"""Build ImgAugWrapper augmentations.
Args:
args (dict): Arguments to be passed to imgaug.

View File

@ -143,10 +143,10 @@ def eval_hmean_ic13(det_boxes,
gt_point = np.array(gt_points[gt_id])
det_point = np.array(pred_points[pred_id])
norm_dist = utils.box_center_distance(
norm_dist = utils.bbox_center_distance(
det_point, gt_point)
norm_dist /= utils.box_diag(
det_point) + utils.box_diag(gt_point)
norm_dist /= utils.bbox_diag(
det_point) + utils.bbox_diag(gt_point)
norm_dist *= 2.0
if norm_dist < center_dist_thr:

View File

@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import Registry, build_from_cfg
from .bbox_utils import (bbox2poly, bezier_to_polygon, box_center_distance,
box_diag, is_on_same_line, rescale_bboxes,
from .bbox_utils import (bbox2poly, bbox_center_distance, bbox_diag,
bezier_to_polygon, is_on_same_line, rescale_bboxes,
sort_points, sort_vertex, sort_vertex8,
stitch_boxes_into_lines)
from .check_argument import (equal_len, is_2dlist, is_3dlist, is_none_or_type,
@ -13,6 +13,7 @@ from .evaluation_utils import (compute_hmean, filter_2dlist_result,
many2one_match_ic13, one2one_match_ic13,
select_top_boundary)
from .fileio import list_from_file, list_to_file
from .img_utils import crop_img, warp_img
from .logger import get_root_logger
from .mask_utils import fill_hole
from .model import revert_sync_batchnorm
@ -36,9 +37,9 @@ __all__ = [
'rescale_bboxes', 'bbox2poly', 'crop_polygon', 'is_poly_inside_rect',
'poly2bbox', 'poly_intersection', 'poly_iou', 'poly_make_valid',
'poly_union', 'poly2shapely', 'polys2shapely', 'register_all_modules',
'offset_polygon', 'sort_vertex8', 'sort_vertex', 'box_center_distance',
'box_diag', 'compute_hmean', 'filter_2dlist_result', 'many2one_match_ic13',
'one2one_match_ic13', 'select_top_boundary', 'boundary_iou',
'point_distance', 'points_center', 'fill_hole', 'LineJsonParser',
'LineStrParser', 'shapely2poly'
'offset_polygon', 'sort_vertex8', 'sort_vertex', 'bbox_center_distance',
'bbox_diag', 'compute_hmean', 'filter_2dlist_result',
'many2one_match_ic13', 'one2one_match_ic13', 'select_top_boundary',
'boundary_iou', 'point_distance', 'points_center', 'fill_hole',
'LineJsonParser', 'LineStrParser', 'shapely2poly', 'crop_img', 'warp_img'
]

View File

@ -4,6 +4,7 @@ from typing import Tuple
import numpy as np
from numpy.typing import ArrayLike
from shapely.geometry import LineString, Point
from mmocr.utils.check_argument import is_2dlist, is_type_list
from mmocr.utils.point_utils import point_distance, points_center
@ -329,16 +330,47 @@ def sort_vertex8(points):
return sorted_box
def box_center_distance(b1, b2):
def bbox_center_distance(b1, b2):
# TODO typehints & docstring & test
assert isinstance(b1, np.ndarray)
assert isinstance(b2, np.ndarray)
return point_distance(points_center(b1), points_center(b2))
def box_diag(box):
def bbox_diag(box):
# TODO typehints & docstring & test
assert isinstance(box, np.ndarray)
assert box.size == 8
return point_distance(box[0:2], box[4:6])
def bbox_jitter(points_x, points_y, jitter_ratio_x=0.5, jitter_ratio_y=0.1):
"""Jitter on the coordinates of bounding box.
Args:
points_x (list[float | int]): List of y for four vertices.
points_y (list[float | int]): List of x for four vertices.
jitter_ratio_x (float): Horizontal jitter ratio relative to the height.
jitter_ratio_y (float): Vertical jitter ratio relative to the height.
"""
assert len(points_x) == 4
assert len(points_y) == 4
assert isinstance(jitter_ratio_x, float)
assert isinstance(jitter_ratio_y, float)
assert 0 <= jitter_ratio_x < 1
assert 0 <= jitter_ratio_y < 1
points = [Point(points_x[i], points_y[i]) for i in range(4)]
line_list = [
LineString([points[i], points[i + 1 if i < 3 else 0]])
for i in range(4)
]
tmp_h = max(line_list[1].length, line_list[3].length)
for i in range(4):
jitter_pixel_x = (np.random.rand() - 0.5) * 2 * jitter_ratio_x * tmp_h
jitter_pixel_y = (np.random.rand() - 0.5) * 2 * jitter_ratio_y * tmp_h
points_x[i] += jitter_pixel_x
points_y[i] += jitter_pixel_y

View File

@ -1,64 +1,40 @@
# Copyright (c) OpenMMLab. All rights reserved.
import cv2
import numpy as np
from mmcv.utils import is_seq_of
from shapely.geometry import LineString, Point
import mmocr.utils as utils
def box_jitter(points_x, points_y, jitter_ratio_x=0.5, jitter_ratio_y=0.1):
"""Jitter on the coordinates of bounding box.
Args:
points_x (list[float | int]): List of y for four vertices.
points_y (list[float | int]): List of x for four vertices.
jitter_ratio_x (float): Horizontal jitter ratio relative to the height.
jitter_ratio_y (float): Vertical jitter ratio relative to the height.
"""
assert len(points_x) == 4
assert len(points_y) == 4
assert isinstance(jitter_ratio_x, float)
assert isinstance(jitter_ratio_y, float)
assert 0 <= jitter_ratio_x < 1
assert 0 <= jitter_ratio_y < 1
points = [Point(points_x[i], points_y[i]) for i in range(4)]
line_list = [
LineString([points[i], points[i + 1 if i < 3 else 0]])
for i in range(4)
]
tmp_h = max(line_list[1].length, line_list[3].length)
for i in range(4):
jitter_pixel_x = (np.random.rand() - 0.5) * 2 * jitter_ratio_x * tmp_h
jitter_pixel_y = (np.random.rand() - 0.5) * 2 * jitter_ratio_y * tmp_h
points_x[i] += jitter_pixel_x
points_y[i] += jitter_pixel_y
from .bbox_utils import bbox_jitter, sort_vertex
def warp_img(src_img,
box,
jitter_flag=False,
jitter=False,
jitter_ratio_x=0.5,
jitter_ratio_y=0.1):
"""Crop box area from image using opencv warpPerspective w/o box jitter.
"""Crop box area from image using opencv warpPerspective.
Args:
src_img (np.array): Image before cropping.
box (list[float | int]): Coordinates of quadrangle.
jitter (bool): Whether to jitter the box.
jitter_ratio_x (float): Horizontal jitter ratio relative to the height.
jitter_ratio_y (float): Vertical jitter ratio relative to the height.
Returns:
np.array: The warped image.
"""
assert utils.is_type_list(box, (float, int))
assert is_seq_of(box, (float, int))
assert len(box) == 8
h, w = src_img.shape[:2]
points_x = [min(max(x, 0), w) for x in box[0:8:2]]
points_y = [min(max(y, 0), h) for y in box[1:9:2]]
points_x, points_y = utils.sort_vertex(points_x, points_y)
points_x, points_y = sort_vertex(points_x, points_y)
if jitter_flag:
box_jitter(
if jitter:
bbox_jitter(
points_x,
points_y,
jitter_ratio_x=jitter_ratio_x,
@ -84,17 +60,24 @@ def warp_img(src_img,
def crop_img(src_img, box, long_edge_pad_ratio=0.4, short_edge_pad_ratio=0.2):
"""Crop text region with their bounding box.
"""Crop text region given the bounding box which might be slightly padded.
The bounding box is assumed to be a quadrangle and tightly bound the text
region.
Args:
src_img (np.array): The original image.
box (list[float | int]): Points of quadrangle.
long_edge_pad_ratio (float): Box pad ratio for long edge
corresponding to font size.
short_edge_pad_ratio (float): Box pad ratio for short edge
corresponding to font size.
long_edge_pad_ratio (float): The ratio of padding to the long edge. The
padding will be the length of the short edge * long_edge_pad_ratio.
Defaults to 0.4.
short_edge_pad_ratio (float): The ratio of padding to the short edge.
The padding will be the length of the long edge *
short_edge_pad_ratio. Defaults to 0.2.
Returns:
np.array: The cropped image.
"""
assert utils.is_type_list(box, (float, int))
assert is_seq_of(box, (float, int))
assert len(box) == 8
assert 0. <= long_edge_pad_ratio < 1.0
assert 0. <= short_edge_pad_ratio < 1.0
@ -105,14 +88,14 @@ def crop_img(src_img, box, long_edge_pad_ratio=0.4, short_edge_pad_ratio=0.2):
box_width = np.max(points_x) - np.min(points_x)
box_height = np.max(points_y) - np.min(points_y)
font_size = min(box_height, box_width)
shorter_size = min(box_height, box_width)
if box_height < box_width:
horizontal_pad = long_edge_pad_ratio * font_size
vertical_pad = short_edge_pad_ratio * font_size
horizontal_pad = long_edge_pad_ratio * shorter_size
vertical_pad = short_edge_pad_ratio * shorter_size
else:
horizontal_pad = short_edge_pad_ratio * font_size
vertical_pad = long_edge_pad_ratio * font_size
horizontal_pad = short_edge_pad_ratio * shorter_size
vertical_pad = long_edge_pad_ratio * shorter_size
left = np.clip(int(np.min(points_x) - horizontal_pad), 0, w)
top = np.clip(int(np.min(points_y) - vertical_pad), 0, h)

View File

@ -22,13 +22,13 @@ except ImportError:
from mmocr.apis import init_detector
from mmocr.apis.inference import model_inference
from mmocr.core.visualize import det_recog_show_result
from mmocr.datasets.kie_dataset import KIEDataset
from mmocr.datasets.pipelines.crop import crop_img
from mmocr.datasets import WildReceiptDataset
from mmocr.models.textdet.detectors import TextDetectorMixin
from mmocr.models.textrecog.recognizers import BaseRecognizer
from mmocr.registry import MODELS
from mmocr.utils import is_type_list, stitch_boxes_into_lines
from mmocr.utils.fileio import list_from_file
from mmocr.utils.img_utils import crop_img
from mmocr.utils.model import revert_sync_batchnorm
@ -681,7 +681,7 @@ class MMOCR:
bboxes_list = [res['boundary_result'] for res in det_result]
if kie_model:
kie_dataset = KIEDataset(
kie_dataset = WildReceiptDataset(
dict_file=kie_model.cfg.data.test.dict_file)
# For each bounding box, the image is cropped and

View File

@ -1,128 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
import math
import os.path as osp
import tempfile
import pytest
from mmocr.datasets.ocr_seg_dataset import OCRSegDataset
def _create_dummy_ann_file(ann_file):
ann_info1 = {
'file_name':
'sample1.png',
'annotations': [{
'char_text':
'F',
'char_box': [11.0, 0.0, 22.0, 0.0, 12.0, 12.0, 0.0, 12.0]
}, {
'char_text':
'r',
'char_box': [23.0, 2.0, 31.0, 1.0, 24.0, 11.0, 16.0, 11.0]
}, {
'char_text':
'o',
'char_box': [33.0, 2.0, 43.0, 2.0, 36.0, 12.0, 25.0, 12.0]
}, {
'char_text':
'm',
'char_box': [46.0, 2.0, 61.0, 2.0, 53.0, 12.0, 39.0, 12.0]
}, {
'char_text':
':',
'char_box': [61.0, 2.0, 69.0, 2.0, 63.0, 12.0, 55.0, 12.0]
}],
'text':
'From:'
}
ann_info2 = {
'file_name':
'sample2.png',
'annotations': [{
'char_text': 'o',
'char_box': [0.0, 5.0, 7.0, 5.0, 9.0, 15.0, 2.0, 15.0]
}, {
'char_text':
'u',
'char_box': [7.0, 4.0, 14.0, 4.0, 18.0, 18.0, 11.0, 18.0]
}, {
'char_text':
't',
'char_box': [13.0, 1.0, 19.0, 2.0, 24.0, 18.0, 17.0, 18.0]
}],
'text':
'out'
}
with open(ann_file, 'w') as fw:
for ann_info in [ann_info1, ann_info2]:
fw.write(json.dumps(ann_info) + '\n')
return ann_info1, ann_info2
def _create_dummy_loader():
loader = dict(
type='HardDiskLoader',
repeat=1,
parser=dict(
type='LineJsonParser', keys=['file_name', 'text', 'annotations']))
return loader
def test_ocr_seg_dataset():
tmp_dir = tempfile.TemporaryDirectory()
# create dummy data
ann_file = osp.join(tmp_dir.name, 'fake_data.txt')
ann_info1, ann_info2 = _create_dummy_ann_file(ann_file)
# test initialization
loader = _create_dummy_loader()
dataset = OCRSegDataset(ann_file, loader, pipeline=[])
tmp_dir.cleanup()
# test pre_pipeline
img_info = dataset.data_infos[0]
results = dict(img_info=img_info)
dataset.pre_pipeline(results)
assert results['img_prefix'] == dataset.img_prefix
# test _parse_anno_info
annos = ann_info1['annotations']
with pytest.raises(AssertionError):
dataset._parse_anno_info(annos[0])
annos2 = ann_info2['annotations']
with pytest.raises(AssertionError):
dataset._parse_anno_info([{'char_text': 'i'}])
with pytest.raises(AssertionError):
dataset._parse_anno_info([{'char_box': [1, 2, 3, 4, 5, 6, 7, 8]}])
annos2[0]['char_box'] = [1, 2, 3]
with pytest.raises(AssertionError):
dataset._parse_anno_info(annos2)
return_anno = dataset._parse_anno_info(annos)
assert return_anno['chars'] == ['F', 'r', 'o', 'm', ':']
assert len(return_anno['char_rects']) == 5
# test prepare_train_img
expect_results = {
'img_info': {
'filename': 'sample1.png'
},
'img_prefix': '',
'ann_info': return_anno
}
data = dataset.prepare_train_img(0)
assert data == expect_results
# test evluation
metric = 'acc'
results = [{'text': 'From:'}, {'text': 'ou'}]
eval_res = dataset.evaluate(results, metric)
assert math.isclose(eval_res['word_acc'], 0.5, abs_tol=1e-4)
assert math.isclose(eval_res['char_precision'], 1.0, abs_tol=1e-4)
assert math.isclose(eval_res['char_recall'], 0.857, abs_tol=1e-4)

View File

@ -1,94 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
import numpy as np
import pytest
from mmocr.datasets.pipelines.ocr_seg_targets import OCRSegTargets
def _create_dummy_dict_file(dict_file):
chars = list('0123456789')
with open(dict_file, 'w') as fw:
for char in chars:
fw.write(char + '\n')
def test_ocr_segm_targets():
tmp_dir = tempfile.TemporaryDirectory()
# create dummy dict file
dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
_create_dummy_dict_file(dict_file)
# dummy label convertor
label_convertor = dict(
type='SegConvertor',
dict_file=dict_file,
with_unknown=True,
lower=True)
# test init
with pytest.raises(AssertionError):
OCRSegTargets(None, 0.5, 0.5)
with pytest.raises(AssertionError):
OCRSegTargets(label_convertor, '1by2', 0.5)
with pytest.raises(AssertionError):
OCRSegTargets(label_convertor, 0.5, 2)
ocr_seg_tgt = OCRSegTargets(label_convertor, 0.5, 0.5)
# test generate kernels
img_size = (8, 8)
pad_size = (8, 10)
char_boxes = [[2, 2, 6, 6]]
char_idxs = [2]
with pytest.raises(AssertionError):
ocr_seg_tgt.generate_kernels(8, pad_size, char_boxes, char_idxs, 0.5,
True)
with pytest.raises(AssertionError):
ocr_seg_tgt.generate_kernels(img_size, pad_size, [2, 2, 6, 6],
char_idxs, 0.5, True)
with pytest.raises(AssertionError):
ocr_seg_tgt.generate_kernels(img_size, pad_size, char_boxes, 2, 0.5,
True)
attn_tgt = ocr_seg_tgt.generate_kernels(
img_size, pad_size, char_boxes, char_idxs, 0.5, binary=True)
expect_attn_tgt = [[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
[0, 0, 0, 1, 1, 1, 0, 0, 255, 255],
[0, 0, 0, 1, 1, 1, 0, 0, 255, 255],
[0, 0, 0, 1, 1, 1, 0, 0, 255, 255],
[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
[0, 0, 0, 0, 0, 0, 0, 0, 255, 255]]
assert np.allclose(attn_tgt, np.array(expect_attn_tgt, dtype=np.int32))
segm_tgt = ocr_seg_tgt.generate_kernels(
img_size, pad_size, char_boxes, char_idxs, 0.5, binary=False)
expect_segm_tgt = [[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
[0, 0, 0, 2, 2, 2, 0, 0, 255, 255],
[0, 0, 0, 2, 2, 2, 0, 0, 255, 255],
[0, 0, 0, 2, 2, 2, 0, 0, 255, 255],
[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
[0, 0, 0, 0, 0, 0, 0, 0, 255, 255]]
assert np.allclose(segm_tgt, np.array(expect_segm_tgt, dtype=np.int32))
# test __call__
results = {}
results['img_shape'] = (4, 4, 3)
results['resize_shape'] = (8, 8, 3)
results['pad_shape'] = (8, 10)
results['ann_info'] = {}
results['ann_info']['char_rects'] = [[1, 1, 3, 3]]
results['ann_info']['chars'] = ['1']
results = ocr_seg_tgt(results)
assert results['mask_fields'] == ['gt_kernels']
assert np.allclose(results['gt_kernels'].masks[0],
np.array(expect_attn_tgt, dtype=np.int32))
assert np.allclose(results['gt_kernels'].masks[1],
np.array(expect_segm_tgt, dtype=np.int32))
tmp_dir.cleanup()

View File

@ -1,141 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import unittest.mock as mock
import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image
import mmocr.datasets.pipelines.ocr_transforms as transforms
def test_resize_ocr():
input_img = np.ones((64, 256, 3), dtype=np.uint8)
rci = transforms.ResizeOCR(
32, min_width=32, max_width=160, keep_aspect_ratio=True)
results = {'img_shape': input_img.shape, 'img': input_img}
# test call
results = rci(results)
assert np.allclose([32, 160, 3], results['pad_shape'])
assert np.allclose([32, 160, 3], results['img'].shape)
assert 'valid_ratio' in results
assert math.isclose(results['valid_ratio'], 0.8)
assert math.isclose(np.sum(results['img'][:, 129:, :]), 0)
rci = transforms.ResizeOCR(
32, min_width=32, max_width=160, keep_aspect_ratio=False)
results = {'img_shape': input_img.shape, 'img': input_img}
results = rci(results)
assert math.isclose(results['valid_ratio'], 1)
def test_to_tensor():
input_img = np.ones((64, 256, 3), dtype=np.uint8)
expect_output = TF.to_tensor(input_img)
rci = transforms.ToTensorOCR()
results = {'img': input_img}
results = rci(results)
assert np.allclose(results['img'].numpy(), expect_output.numpy())
def test_normalize():
inputs = torch.zeros(3, 10, 10)
expect_output = torch.ones_like(inputs) * (-1)
rci = transforms.NormalizeOCR(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
results = {'img': inputs}
results = rci(results)
assert np.allclose(results['img'].numpy(), expect_output.numpy())
@mock.patch('%s.transforms.np.random.random' % __name__)
def test_online_crop(mock_random):
kwargs = dict(
box_keys=['x1', 'y1', 'x2', 'y2', 'x3', 'y3', 'x4', 'y4'],
jitter_prob=0.5,
max_jitter_ratio_x=0.05,
max_jitter_ratio_y=0.02)
mock_random.side_effect = [0.1, 1, 1, 1]
src_img = np.ones((100, 100, 3), dtype=np.uint8)
results = {
'img': src_img,
'img_info': {
'x1': '20',
'y1': '20',
'x2': '40',
'y2': '20',
'x3': '40',
'y3': '40',
'x4': '20',
'y4': '40'
}
}
rci = transforms.OnlineCropOCR(**kwargs)
results = rci(results)
assert np.allclose(results['img_shape'], [20, 20, 3])
# test not crop
mock_random.side_effect = [0.1, 1, 1, 1]
results['img_info'] = {}
results['img'] = src_img
results = rci(results)
assert np.allclose(results['img'].shape, [100, 100, 3])
def test_fancy_pca():
input_tensor = torch.rand(3, 32, 100)
rci = transforms.FancyPCA()
results = {'img': input_tensor}
results = rci(results)
assert results['img'].shape == torch.Size([3, 32, 100])
@mock.patch('%s.transforms.np.random.uniform' % __name__)
def test_random_padding(mock_random):
kwargs = dict(max_ratio=[0.0, 0.0, 0.0, 0.0], box_type=None)
mock_random.side_effect = [1, 1, 1, 1]
src_img = np.ones((32, 100, 3), dtype=np.uint8)
results = {'img': src_img, 'img_shape': (32, 100, 3)}
rci = transforms.RandomPaddingOCR(**kwargs)
results = rci(results)
print(results['img'].shape)
assert np.allclose(results['img_shape'], [96, 300, 3])
def test_opencv2pil():
src_img = np.ones((32, 100, 3), dtype=np.uint8)
results = {'img': src_img}
rci = transforms.OpencvToPil()
results = rci(results)
assert np.allclose(results['img'].size, (100, 32))
def test_pil2opencv():
src_img = Image.new('RGB', (100, 32), color=(255, 255, 255))
results = {'img': src_img}
rci = transforms.PilToOpencv()
results = rci(results)
assert np.allclose(results['img'].shape, (32, 100, 3))

View File

@ -1,33 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmocr.datasets.pipelines.textdet_targets.dbnet_targets import DBNetTargets
def test_invalid_polys():
dbtarget = DBNetTargets()
poly = np.array([[256.1229216, 347.17471155], [257.63126133, 347.0069367],
[257.70317729, 347.65337423],
[256.19488113, 347.82114909]])
assert dbtarget.invalid_polygon(poly)
poly = np.array([[570.34735492,
335.00214526], [570.99778839, 335.00327318],
[569.69077318, 338.47009908],
[569.04038393, 338.46894904]])
assert dbtarget.invalid_polygon(poly)
poly = np.array([[481.18343777,
305.03190065], [479.88478587, 305.10684512],
[479.90976971, 305.53968843], [480.99197962,
305.4772347]])
assert dbtarget.invalid_polygon(poly)
poly = np.array([[0, 0], [2, 0], [2, 2], [0, 2]])
assert dbtarget.invalid_polygon(poly)
poly = np.array([[0, 0], [10, 0], [10, 10], [0, 10]])
assert not dbtarget.invalid_polygon(poly)

View File

@ -5,8 +5,8 @@ from unittest import TestCase
import numpy as np
import torch
from mmocr.datasets.pipelines import (PackKIEInputs, PackTextDetInputs,
PackTextRecogInputs)
from mmocr.datasets.transforms import (PackKIEInputs, PackTextDetInputs,
PackTextRecogInputs)
class TestPackTextDetInputs(TestCase):

View File

@ -4,7 +4,7 @@ from unittest import TestCase
import numpy as np
from mmocr.datasets.pipelines import LoadKIEAnnotations, LoadOCRAnnotations
from mmocr.datasets.transforms import LoadKIEAnnotations, LoadOCRAnnotations
class TestLoadOCRAnnotations(TestCase):

View File

@ -0,0 +1,208 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import unittest
import unittest.mock as mock
import numpy as np
from mmocr.datasets.transforms import RandomCrop, RandomRotate, Resize
class TestRandomCrop(unittest.TestCase):
def setUp(self):
img = np.zeros((30, 30, 3))
gt_polygons = [
np.array([5., 5., 25., 5., 25., 10., 5., 10.]),
np.array([5., 20., 25., 20., 25., 25., 5., 25.])
]
gt_bboxes = np.array([[5, 5, 25, 10], [5, 20, 25, 25]])
labels = np.array([0, 1])
gt_ignored = np.array([True, False], dtype=bool)
texts = ['text1', 'text2']
self.data_info = dict(
img=img,
gt_polygons=gt_polygons,
gt_bboxes=gt_bboxes,
gt_bboxes_labels=labels,
gt_ignored=gt_ignored,
gt_texts=texts)
@mock.patch('mmocr.datasets.transforms.ocr_transforms.np.random.randint')
def test_sample_crop_box(self, mock_randint):
def rand_min(low, high):
return low
trans = RandomCrop(min_side_ratio=0.3)
mock_randint.side_effect = rand_min
crop_box = trans._sample_crop_box((30, 30), self.data_info.copy())
assert np.allclose(np.array(crop_box), np.array([0, 0, 25, 10]))
def rand_max(low, high):
return high - 1
mock_randint.side_effect = rand_max
crop_box = trans._sample_crop_box((30, 30), self.data_info.copy())
assert np.allclose(np.array(crop_box), np.array([4, 19, 30, 30]))
@mock.patch('mmocr.datasets.transforms.ocr_transforms.np.random.randint')
def test_transform(self, mock_randint):
def rand_min(low, high):
return low
# mock_randint.side_effect = [0, 0, 0, 0, 30, 0, 0, 0, 15]
mock_randint.side_effect = rand_min
trans = RandomCrop(min_side_ratio=0.3)
polygon_target = np.array([5., 5., 25., 5., 25., 10., 5., 10.])
bbox_target = np.array([[5., 5., 25., 10.]])
results = trans(self.data_info)
self.assertEqual(results['img'].shape, (10, 25, 3))
self.assertEqual(results['img_shape'], (10, 25))
self.assertTrue(np.allclose(results['gt_bboxes'], bbox_target))
self.assertEqual(results['gt_bboxes'].shape, (1, 4))
self.assertEqual(len(results['gt_polygons']), 1)
self.assertTrue(np.allclose(results['gt_polygons'][0], polygon_target))
self.assertEqual(results['gt_bboxes_labels'][0], 0)
self.assertEqual(results['gt_ignored'][0], True)
self.assertEqual(results['gt_texts'][0], 'text1')
def rand_max(low, high):
return high - 1
mock_randint.side_effect = rand_max
trans = RandomCrop(min_side_ratio=0.3)
polygon_target = np.array([1, 1, 21, 1, 21, 6, 1, 6])
bbox_target = np.array([[1, 1, 21, 6]])
results = trans(self.data_info)
self.assertEqual(results['img'].shape, (6, 21, 3))
self.assertEqual(results['img_shape'], (6, 21))
self.assertTrue(np.allclose(results['gt_bboxes'], bbox_target))
self.assertEqual(results['gt_bboxes'].shape, (1, 4))
self.assertEqual(len(results['gt_polygons']), 1)
self.assertTrue(np.allclose(results['gt_polygons'][0], polygon_target))
self.assertEqual(results['gt_bboxes_labels'][0], 0)
self.assertTrue(results['gt_ignored'][0])
self.assertEqual(results['gt_texts'][0], 'text1')
def test_repr(self):
transform = RandomCrop(min_side_ratio=0.4)
self.assertEqual(repr(transform), ('RandomCrop(min_side_ratio = 0.4)'))
class TestRandomRotate(unittest.TestCase):
def setUp(self):
img = np.random.random((5, 5))
self.data_info1 = dict(img=img.copy(), img_shape=img.shape[:2])
self.data_info2 = dict(
img=np.random.random((30, 30, 3)),
gt_bboxes=np.array([[10, 10, 20, 20], [5, 5, 10, 10]]),
img_shape=(30, 30))
self.data_info3 = dict(
img=np.random.random((30, 30, 3)),
gt_polygons=[np.array([10., 10., 20., 10., 20., 20., 10., 20.])],
img_shape=(30, 30))
def test_init(self):
# max angle is float
with self.assertRaisesRegex(TypeError,
'`max_angle` should be an integer'):
RandomRotate(max_angle=16.8)
# invalid pad value
with self.assertRaisesRegex(
ValueError, '`pad_value` should contain three integers'):
RandomRotate(pad_value=[16.8, 0.1])
def test_transform(self):
self._test_recog()
self._test_bboxes()
self._test_polygons()
def _test_recog(self):
# test random rotate for recognition (image only) input
transform = RandomRotate(max_angle=10)
results = transform(copy.deepcopy(self.data_info1))
self.assertTrue(np.allclose(results['img'], self.data_info1['img']))
@mock.patch(
'mmocr.datasets.transforms.ocr_transforms.np.random.random_sample')
def _test_bboxes(self, mock_sample):
# test random rotate for bboxes
# returns 1. for random_sample() in _sample_angle(), i.e., angle = 90
mock_sample.side_effect = [1.]
transform = RandomRotate(max_angle=90, use_canvas=True)
results = transform(copy.deepcopy(self.data_info2))
self.assertTrue(
np.allclose(results['gt_bboxes'][0], np.array([10, 10, 20, 20])))
self.assertTrue(
np.allclose(results['gt_bboxes'][1], np.array([5, 20, 10, 25])))
self.assertEqual(results['img'].shape, self.data_info2['img'].shape)
@mock.patch(
'mmocr.datasets.transforms.ocr_transforms.np.random.random_sample')
def _test_polygons(self, mock_sample):
# test random rotate for polygons
# returns 1. for random_sample() in _sample_angle(), i.e., angle = 90
mock_sample.side_effect = [1.]
transform = RandomRotate(max_angle=90, use_canvas=True)
results = transform(copy.deepcopy(self.data_info3))
self.assertTrue(
np.allclose(results['gt_polygons'][0],
np.array([10., 20., 10., 10., 20., 10., 20., 20.])))
self.assertEqual(results['img'].shape, self.data_info3['img'].shape)
def test_repr(self):
transform = RandomRotate(
max_angle=10,
pad_with_fixed_color=False,
pad_value=(0, 0, 0),
use_canvas=False)
self.assertEqual(
repr(transform),
('RandomRotate(max_angle = 10, '
'pad_with_fixed_color = False, pad_value = (0, 0, 0), '
'use_canvas = False)'))
class TestResize(unittest.TestCase):
def test_resize_wo_img(self):
# keep_ratio = True
dummy_result = dict(img_shape=(10, 20))
resize = Resize(scale=(40, 30), keep_ratio=True)
result = resize(dummy_result)
self.assertEqual(result['img_shape'], (20, 40))
self.assertEqual(result['scale'], (40, 20))
self.assertEqual(result['scale_factor'], (2., 2.))
self.assertEqual(result['keep_ratio'], True)
# keep_ratio = False
dummy_result = dict(img_shape=(10, 20))
resize = Resize(scale=(40, 30), keep_ratio=False)
result = resize(dummy_result)
self.assertEqual(result['img_shape'], (30, 40))
self.assertEqual(result['scale'], (40, 30))
self.assertEqual(result['scale_factor'], (
2.,
3.,
))
self.assertEqual(result['keep_ratio'], False)
def test_resize_bbox(self):
# keep_ratio = True
dummy_result = dict(
img_shape=(10, 20),
gt_bboxes=np.array([[0, 0, 1, 1]], dtype=np.float32))
resize = Resize(scale=(40, 30))
result = resize(dummy_result)
self.assertEqual(result['gt_bboxes'].dtype, np.float32)
if __name__ == '__main__':
t = TestRandomCrop()
t.test_sample_crop_box()
t.test_transform()

View File

@ -0,0 +1,127 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import unittest
import numpy as np
from mmocr.datasets.transforms import (PadToWidth, PyramidRescale,
RescaleToHeight)
class TestPadToWidth(unittest.TestCase):
def test_pad_to_width(self):
data_info = dict(img=np.random.random((16, 25, 3)))
# test size and size_divisor are both set
with self.assertRaises(AssertionError):
PadToWidth(width=10.5)
transform = PadToWidth(width=100)
results = transform(copy.deepcopy(data_info))
self.assertTupleEqual(results['img'].shape[:2], (16, 100))
self.assertEqual(results['valid_ratio'], 25 / 100)
def test_repr(self):
transform = PadToWidth(width=100)
self.assertEqual(
repr(transform),
("PadToWidth(width=100, pad_cfg={'type': 'Pad'})"))
class TestPyramidRescale(unittest.TestCase):
def setUp(self):
self.data_info = dict(img=np.random.random((128, 100, 3)))
def test_init(self):
# factor is int
transform = PyramidRescale(factor=4, randomize_factor=False)
self.assertEqual(transform.factor, 4)
# factor is float
with self.assertRaisesRegex(TypeError,
'`factor` should be an integer'):
PyramidRescale(factor=4.0)
# invalid base_shape
with self.assertRaisesRegex(TypeError,
'`base_shape` should be a list or tuple'):
PyramidRescale(base_shape=128)
with self.assertRaisesRegex(
ValueError, '`base_shape` should contain two integers'):
PyramidRescale(base_shape=(128, ))
with self.assertRaisesRegex(
ValueError, '`base_shape` should contain two integers'):
PyramidRescale(base_shape=(128.0, 2.0))
# invalid randomize_factor
with self.assertRaisesRegex(TypeError,
'`randomize_factor` should be a bool'):
PyramidRescale(randomize_factor=None)
def test_transform(self):
# test if the rescale keeps the original size
transform = PyramidRescale()
results = transform(copy.deepcopy(self.data_info))
self.assertEqual(results['img'].shape, (128, 100, 3))
# test factor = 0
transform = PyramidRescale(factor=0, randomize_factor=False)
results = transform(copy.deepcopy(self.data_info))
self.assertTrue(np.all(results['img'] == self.data_info['img']))
def test_repr(self):
transform = PyramidRescale(
factor=4, base_shape=(128, 512), randomize_factor=False)
self.assertEqual(
repr(transform),
('PyramidRescale(factor = 4, randomize_factor = False, '
'base_w = 128, base_h = 512)'))
class TestRescaleToHeight(unittest.TestCase):
def test_rescale_height(self):
data_info = dict(
img=np.random.random((16, 25, 3)),
gt_seg_map=np.random.random((16, 25, 3)),
gt_bboxes=np.array([[0, 0, 10, 10]]),
gt_keypoints=np.array([[[10, 10, 1]]]))
with self.assertRaises(AssertionError):
RescaleToHeight(height=20.9)
with self.assertRaises(AssertionError):
RescaleToHeight(height=20, min_width=20.9)
with self.assertRaises(AssertionError):
RescaleToHeight(height=20, max_width=20.9)
with self.assertRaises(AssertionError):
RescaleToHeight(height=20, width_divisor=0.5)
transform = RescaleToHeight(height=32)
results = transform(copy.deepcopy(data_info))
self.assertTupleEqual(results['img'].shape[:2], (32, 50))
self.assertTupleEqual(results['scale'], (50, 32))
self.assertTupleEqual(results['scale_factor'], (50 / 25, 32 / 16))
# test min_width
transform = RescaleToHeight(height=32, min_width=60)
results = transform(copy.deepcopy(data_info))
self.assertTupleEqual(results['img'].shape[:2], (32, 60))
self.assertTupleEqual(results['scale'], (60, 32))
self.assertTupleEqual(results['scale_factor'], (60 / 25, 32 / 16))
# test max_width
transform = RescaleToHeight(height=32, max_width=45)
results = transform(copy.deepcopy(data_info))
self.assertTupleEqual(results['img'].shape[:2], (32, 45))
self.assertTupleEqual(results['scale'], (45, 32))
self.assertTupleEqual(results['scale_factor'], (45 / 25, 32 / 16))
# test width_divisor
transform = RescaleToHeight(height=32, width_divisor=4)
results = transform(copy.deepcopy(data_info))
self.assertTupleEqual(results['img'].shape[:2], (32, 48))
self.assertTupleEqual(results['scale'], (48, 32))
self.assertTupleEqual(results['scale_factor'], (48 / 25, 32 / 16))
def test_repr(self):
transform = RescaleToHeight(height=32)
self.assertEqual(
repr(transform), ('RescaleToHeight(height=32, '
'min_width=None, max_width=None, '
'width_divisor=1, '
"resize_cfg={'type': 'Resize'})"))

View File

@ -6,16 +6,16 @@ from typing import Dict, List, Optional
import numpy as np
from shapely.geometry import Polygon
from mmocr.datasets.pipelines import ImgAug, TorchVisionWrapper
from mmocr.datasets.transforms import ImgAugWrapper, TorchVisionWrapper
class TestImgAug(unittest.TestCase):
def test_init(self):
with self.assertRaises(AssertionError):
ImgAug(args=[])
ImgAugWrapper(args=[])
with self.assertRaises(AssertionError):
ImgAug(args=['test'])
ImgAugWrapper(args=['test'])
def _create_dummy_data(self):
img = np.random.rand(50, 50, 3)
@ -61,7 +61,7 @@ class TestImgAug(unittest.TestCase):
def test_transform(self):
# Test empty transform
imgaug_transform = ImgAug()
imgaug_transform = ImgAugWrapper()
results = self._create_dummy_data()
origin_results = copy.deepcopy(results)
results = imgaug_transform(results)
@ -72,7 +72,7 @@ class TestImgAug(unittest.TestCase):
origin_results['gt_texts'])
args = [dict(cls='Affine', translate_px=dict(x=-10, y=-10))]
imgaug_transform = ImgAug(args)
imgaug_transform = ImgAugWrapper(args)
results = self._create_dummy_data()
results = imgaug_transform(results)
@ -99,7 +99,7 @@ class TestImgAug(unittest.TestCase):
label_target = np.array([0], dtype=np.int64)
ignored = np.array([False], dtype=bool)
texts = ['text1']
imgaug_transform = ImgAug(args)
imgaug_transform = ImgAugWrapper(args)
results = self._create_dummy_data()
results = imgaug_transform(results)
self.assert_result_equal(results, poly_target, box_target,
@ -111,7 +111,7 @@ class TestImgAug(unittest.TestCase):
# When some transforms result in empty polygons
args = [dict(cls='Affine', translate_px=dict(x=100, y=100))]
results = self._create_dummy_data()
invalid_transform = ImgAug(args)
invalid_transform = ImgAugWrapper(args)
results = invalid_transform(results)
self.assertIsNone(results)
@ -131,11 +131,12 @@ class TestImgAug(unittest.TestCase):
def test_repr(self):
args = [['Resize', [0.5, 3.0]], ['Fliplr', 0.5]]
transform = ImgAug(args)
transform = ImgAugWrapper(args)
print(repr(transform))
self.assertEqual(
repr(transform),
("ImgAug(args = [['Resize', [0.5, 3.0]], ['Fliplr', 0.5]])"))
("ImgAugWrapper(args = [['Resize', [0.5, 3.0]], ['Fliplr', 0.5]])"
))
class TestTorchVisionWrapper(unittest.TestCase):

View File

@ -8,8 +8,8 @@ from functools import partial
import mmcv
from mmocr.datasets.pipelines.crop import crop_img, warp_img
from mmocr.utils import list_to_file
from mmocr.utils.img_utils import crop_img, warp_img
def parse_labelme_json(json_file,

View File

@ -6,8 +6,8 @@ import os.path as osp
import mmcv
from mmocr.datasets.pipelines.crop import crop_img
from mmocr.utils.fileio import list_to_file
from mmocr.utils.img_utils import crop_img
def collect_files(img_dir, gt_dir):

View File

@ -7,8 +7,8 @@ import os.path as osp
import mmcv
import numpy as np
from mmocr.datasets.pipelines.crop import crop_img
from mmocr.utils.fileio import list_to_file
from mmocr.utils.img_utils import crop_img
def collect_files(img_dir, gt_dir):

View File

@ -7,8 +7,8 @@ import os.path as osp
import mmcv
from mmocr.datasets.pipelines.crop import crop_img
from mmocr.utils.fileio import list_to_file
from mmocr.utils.img_utils import crop_img
def collect_files(img_dir, gt_dir):

View File

@ -7,8 +7,8 @@ import xml.etree.ElementTree as ET
import mmcv
from mmocr.datasets.pipelines.crop import crop_img
from mmocr.utils.fileio import list_to_file
from mmocr.utils.img_utils import crop_img
def collect_files(img_dir, gt_dir):

View File

@ -8,8 +8,8 @@ import os.path as osp
import mmcv
import numpy as np
from mmocr.datasets.pipelines.crop import crop_img
from mmocr.utils.fileio import list_to_file
from mmocr.utils.img_utils import crop_img
def parse_args():

View File

@ -8,8 +8,8 @@ import xml.etree.ElementTree as ET
import mmcv
from mmocr.datasets.pipelines.crop import crop_img
from mmocr.utils.fileio import list_to_file
from mmocr.utils.img_utils import crop_img
def collect_files(img_dir, gt_dir, ratio):

View File

@ -9,8 +9,8 @@ import cv2
import mmcv
from PIL import Image
from mmocr.datasets.pipelines.crop import crop_img
from mmocr.utils.fileio import list_to_file
from mmocr.utils.img_utils import crop_img
def collect_files(img_dir, gt_dir, ratio):

View File

@ -6,8 +6,8 @@ import os.path as osp
import mmcv
import numpy as np
from mmocr.datasets.pipelines.crop import crop_img
from mmocr.utils.fileio import list_to_file
from mmocr.utils.img_utils import crop_img
def collect_files(img_dir, gt_dir, split_info):

View File

@ -7,8 +7,8 @@ import os.path as osp
import mmcv
from mmocr.datasets.pipelines.crop import crop_img
from mmocr.utils.fileio import list_to_file
from mmocr.utils.img_utils import crop_img
def collect_files(img_dir, gt_dir, ratio):

View File

@ -7,8 +7,8 @@ import os.path as osp
import mmcv
from mmocr.datasets.pipelines.crop import crop_img
from mmocr.utils.fileio import list_to_file
from mmocr.utils.img_utils import crop_img
def collect_files(img_dir, gt_dir, ratio):

View File

@ -7,8 +7,8 @@ import os.path as osp
import mmcv
import numpy as np
from mmocr.datasets.pipelines.crop import crop_img
from mmocr.utils.fileio import list_to_file
from mmocr.utils.img_utils import crop_img
def collect_files(img_dir, gt_dir):

View File

@ -11,8 +11,8 @@ import scipy.io as scio
import yaml
from shapely.geometry import Polygon
from mmocr.datasets.pipelines.crop import crop_img
from mmocr.utils.fileio import list_to_file
from mmocr.utils.img_utils import crop_img
def collect_files(img_dir, gt_dir):

View File

@ -6,8 +6,8 @@ import os.path as osp
import mmcv
from mmocr.datasets.pipelines.crop import crop_img
from mmocr.utils.fileio import list_to_file
from mmocr.utils.img_utils import crop_img
def collect_files(img_dir, gt_dir):