mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
[Processing]remove segocr and split processing
This commit is contained in:
parent
a844b497db
commit
d50d2a46eb
@ -11,7 +11,7 @@ train_pipeline_r18 = [
|
||||
dict(type='ColorJitter', brightness=32.0 / 255, saturation=0.5),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(
|
||||
type='ImgAug',
|
||||
type='ImgAugWrapper',
|
||||
args=[['Fliplr', 0.5],
|
||||
dict(cls='Affine', rotate=[-10, 10]), ['Resize', [0.5, 3.0]]]),
|
||||
dict(type='EastRandomCrop', target_size=(640, 640)),
|
||||
@ -57,7 +57,7 @@ train_pipeline_r50dcnv2 = [
|
||||
dict(type='ColorJitter', brightness=32.0 / 255, saturation=0.5),
|
||||
dict(type='Normalize', **img_norm_cfg_r50dcnv2),
|
||||
dict(
|
||||
type='ImgAug',
|
||||
type='ImgAugWrapper',
|
||||
args=[['Fliplr', 0.5],
|
||||
dict(cls='Affine', rotate=[-10, 10]), ['Resize', [0.5, 3.0]]]),
|
||||
dict(type='EastRandomCrop', target_size=(640, 640)),
|
||||
|
@ -20,7 +20,7 @@ train_pipeline_r18 = [
|
||||
brightness=32.0 / 255,
|
||||
saturation=0.5),
|
||||
dict(
|
||||
type='ImgAug',
|
||||
type='ImgAugWrapper',
|
||||
args=[['Fliplr', 0.5],
|
||||
dict(cls='Affine', rotate=[-10, 10]), ['Resize', [0.5, 3.0]]]),
|
||||
dict(type='RandomCrop', min_side_ratio=0.1),
|
||||
|
@ -22,7 +22,7 @@ train_pipeline_r50dcnv2 = [
|
||||
brightness=32.0 / 255,
|
||||
saturation=0.5),
|
||||
dict(
|
||||
type='ImgAug',
|
||||
type='ImgAugWrapper',
|
||||
args=[['Fliplr', 0.5],
|
||||
dict(cls='Affine', rotate=[-10, 10]), ['Resize', [0.5, 3.0]]]),
|
||||
dict(type='RandomCrop', min_side_ratio=0.1),
|
||||
|
@ -1,13 +1,12 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .icdar_dataset import IcdarDataset
|
||||
from .ocr_dataset import OCRDataset
|
||||
from .ocr_seg_dataset import OCRSegDataset
|
||||
from .pipelines import * # NOQA
|
||||
from .recog_lmdb_dataset import RecogLMDBDataset
|
||||
from .recog_text_dataset import RecogTextDataset
|
||||
from .transforms import * # NOQA
|
||||
from .wildreceipt_dataset import WildReceiptDataset
|
||||
|
||||
__all__ = [
|
||||
'IcdarDataset', 'OCRDataset', 'OCRSegDataset', 'RecogLMDBDataset',
|
||||
'RecogTextDataset', 'WildReceiptDataset'
|
||||
'IcdarDataset', 'OCRDataset', 'RecogLMDBDataset', 'RecogTextDataset',
|
||||
'WildReceiptDataset'
|
||||
]
|
||||
|
@ -1,90 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import mmocr.utils as utils
|
||||
from mmocr.datasets.ocr_dataset import OCRDataset
|
||||
from mmocr.registry import DATASETS
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class OCRSegDataset(OCRDataset):
|
||||
|
||||
def pre_pipeline(self, results):
|
||||
results['img_prefix'] = self.img_prefix
|
||||
|
||||
def _parse_anno_info(self, annotations):
|
||||
"""Parse char boxes annotations.
|
||||
Args:
|
||||
annotations (list[dict]): Annotations of one image, where
|
||||
each dict is for one character.
|
||||
|
||||
Returns:
|
||||
dict: A dict containing the following keys:
|
||||
|
||||
- chars (list[str]): List of character strings.
|
||||
- char_rects (list[list[float]]): List of char box, with each
|
||||
in style of rectangle: [x_min, y_min, x_max, y_max].
|
||||
- char_quads (list[list[float]]): List of char box, with each
|
||||
in style of quadrangle: [x1, y1, x2, y2, x3, y3, x4, y4].
|
||||
"""
|
||||
|
||||
assert utils.is_type_list(annotations, dict)
|
||||
assert 'char_box' in annotations[0]
|
||||
assert 'char_text' in annotations[0]
|
||||
assert len(annotations[0]['char_box']) in [4, 8]
|
||||
|
||||
chars, char_rects, char_quads = [], [], []
|
||||
for ann in annotations:
|
||||
char_box = ann['char_box']
|
||||
if len(char_box) == 4:
|
||||
char_box_type = ann.get('char_box_type', 'xyxy')
|
||||
if char_box_type == 'xyxy':
|
||||
char_rects.append(char_box)
|
||||
char_quads.append([
|
||||
char_box[0], char_box[1], char_box[2], char_box[1],
|
||||
char_box[2], char_box[3], char_box[0], char_box[3]
|
||||
])
|
||||
elif char_box_type == 'xywh':
|
||||
x1, y1, w, h = char_box
|
||||
x2 = x1 + w
|
||||
y2 = y1 + h
|
||||
char_rects.append([x1, y1, x2, y2])
|
||||
char_quads.append([x1, y1, x2, y1, x2, y2, x1, y2])
|
||||
else:
|
||||
raise ValueError(f'invalid char_box_type {char_box_type}')
|
||||
elif len(char_box) == 8:
|
||||
x_list, y_list = [], []
|
||||
for i in range(4):
|
||||
x_list.append(char_box[2 * i])
|
||||
y_list.append(char_box[2 * i + 1])
|
||||
x_max, x_min = max(x_list), min(x_list)
|
||||
y_max, y_min = max(y_list), min(y_list)
|
||||
char_rects.append([x_min, y_min, x_max, y_max])
|
||||
char_quads.append(char_box)
|
||||
else:
|
||||
raise Exception(
|
||||
f'invalid num in char box: {len(char_box)} not in (4, 8)')
|
||||
chars.append(ann['char_text'])
|
||||
|
||||
ann = dict(chars=chars, char_rects=char_rects, char_quads=char_quads)
|
||||
|
||||
return ann
|
||||
|
||||
def prepare_train_img(self, index):
|
||||
"""Get training data and annotations from pipeline.
|
||||
|
||||
Args:
|
||||
index (int): Index of data.
|
||||
|
||||
Returns:
|
||||
dict: Training data and annotation after pipeline with new keys
|
||||
introduced by pipeline.
|
||||
"""
|
||||
img_ann_info = self.data_infos[index]
|
||||
img_info = {
|
||||
'filename': img_ann_info['file_name'],
|
||||
}
|
||||
ann_info = self._parse_anno_info(img_ann_info['annotations'])
|
||||
results = dict(img_info=img_info, ann_info=ann_info)
|
||||
|
||||
self.pre_pipeline(results)
|
||||
|
||||
return self.pipeline(results)
|
@ -1,25 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .formatting import PackKIEInputs, PackTextDetInputs, PackTextRecogInputs
|
||||
from .loading import LoadKIEAnnotations, LoadOCRAnnotations
|
||||
from .ocr_transforms import (FancyPCA, NormalizeOCR, OnlineCropOCR,
|
||||
OpencvToPil, PilToOpencv, RandomPaddingOCR,
|
||||
RandomRotateImageBox, ResizeOCR, ToTensorOCR)
|
||||
from .processing import (BoundedScaleAspectJitter, FixInvalidPolygon,
|
||||
PadToWidth, PyramidRescale, RandomCrop, RandomFlip,
|
||||
RandomRotate, RescaleToHeight, Resize,
|
||||
ShortScaleAspectJitter, SourceImagePad,
|
||||
TextDetRandomCrop, TextDetRandomCropFlip)
|
||||
from .test_time_aug import MultiRotateAugOCR
|
||||
from .wrappers import ImgAug, TorchVisionWrapper
|
||||
|
||||
__all__ = [
|
||||
'LoadOCRAnnotations', 'NormalizeOCR', 'OnlineCropOCR', 'ResizeOCR',
|
||||
'ToTensorOCR', 'RandomRotate', 'MultiRotateAugOCR', 'FancyPCA',
|
||||
'RandomPaddingOCR', 'ImgAug', 'RandomRotateImageBox', 'OpencvToPil',
|
||||
'PilToOpencv', 'SourceImagePad', 'TextDetRandomCropFlip', 'PyramidRescale',
|
||||
'TorchVisionWrapper', 'Resize', 'RandomCrop', 'TextDetRandomCrop',
|
||||
'RandomCrop', 'PackTextDetInputs', 'PackTextRecogInputs',
|
||||
'RescaleToHeight', 'PadToWidth', 'ShortScaleAspectJitter', 'RandomFlip',
|
||||
'BoundedScaleAspectJitter', 'PackKIEInputs', 'LoadKIEAnnotations',
|
||||
'FixInvalidPolygon'
|
||||
]
|
@ -1,454 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.functional as TF
|
||||
from mmcv.runner.dist_utils import get_dist_info
|
||||
from PIL import Image
|
||||
from shapely.geometry import Polygon
|
||||
from shapely.geometry import box as shapely_box
|
||||
|
||||
import mmocr.utils as utils
|
||||
from mmocr.datasets.pipelines.crop import warp_img
|
||||
from mmocr.registry import TRANSFORMS
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class ResizeOCR:
|
||||
"""Image resizing and padding for OCR.
|
||||
|
||||
Args:
|
||||
height (int | tuple(int)): Image height after resizing.
|
||||
min_width (none | int | tuple(int)): Image minimum width
|
||||
after resizing.
|
||||
max_width (none | int | tuple(int)): Image maximum width
|
||||
after resizing.
|
||||
keep_aspect_ratio (bool): Keep image aspect ratio if True
|
||||
during resizing, Otherwise resize to the size height *
|
||||
max_width.
|
||||
img_pad_value (int): Scalar to fill padding area.
|
||||
width_downsample_ratio (float): Downsample ratio in horizontal
|
||||
direction from input image to output feature.
|
||||
backend (str | None): The image resize backend type. Options are `cv2`,
|
||||
`pillow`, `None`. If backend is None, the global imread_backend
|
||||
specified by ``mmcv.use_backend()`` will be used. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
height,
|
||||
min_width=None,
|
||||
max_width=None,
|
||||
keep_aspect_ratio=True,
|
||||
img_pad_value=0,
|
||||
width_downsample_ratio=1.0 / 16,
|
||||
backend=None):
|
||||
assert isinstance(height, (int, tuple))
|
||||
assert utils.is_none_or_type(min_width, (int, tuple))
|
||||
assert utils.is_none_or_type(max_width, (int, tuple))
|
||||
if not keep_aspect_ratio:
|
||||
assert max_width is not None, ('"max_width" must assigned '
|
||||
'if "keep_aspect_ratio" is False')
|
||||
assert isinstance(img_pad_value, int)
|
||||
if isinstance(height, tuple):
|
||||
assert isinstance(min_width, tuple)
|
||||
assert isinstance(max_width, tuple)
|
||||
assert len(height) == len(min_width) == len(max_width)
|
||||
|
||||
self.height = height
|
||||
self.min_width = min_width
|
||||
self.max_width = max_width
|
||||
self.keep_aspect_ratio = keep_aspect_ratio
|
||||
self.img_pad_value = img_pad_value
|
||||
self.width_downsample_ratio = width_downsample_ratio
|
||||
self.backend = backend
|
||||
|
||||
def __call__(self, results):
|
||||
rank, _ = get_dist_info()
|
||||
if isinstance(self.height, int):
|
||||
dst_height = self.height
|
||||
dst_min_width = self.min_width
|
||||
dst_max_width = self.max_width
|
||||
else:
|
||||
# Multi-scale resize used in distributed training.
|
||||
# Choose one (height, width) pair for one rank id.
|
||||
|
||||
idx = rank % len(self.height)
|
||||
dst_height = self.height[idx]
|
||||
dst_min_width = self.min_width[idx]
|
||||
dst_max_width = self.max_width[idx]
|
||||
|
||||
img_shape = results['img_shape']
|
||||
ori_height, ori_width = img_shape[:2]
|
||||
valid_ratio = 1.0
|
||||
resize_shape = list(img_shape)
|
||||
pad_shape = list(img_shape)
|
||||
|
||||
if self.keep_aspect_ratio:
|
||||
new_width = math.ceil(float(dst_height) / ori_height * ori_width)
|
||||
width_divisor = int(1 / self.width_downsample_ratio)
|
||||
# make sure new_width is an integral multiple of width_divisor.
|
||||
if new_width % width_divisor != 0:
|
||||
new_width = round(new_width / width_divisor) * width_divisor
|
||||
if dst_min_width is not None:
|
||||
new_width = max(dst_min_width, new_width)
|
||||
if dst_max_width is not None:
|
||||
valid_ratio = min(1.0, 1.0 * new_width / dst_max_width)
|
||||
resize_width = min(dst_max_width, new_width)
|
||||
img_resize = mmcv.imresize(
|
||||
results['img'], (resize_width, dst_height),
|
||||
backend=self.backend)
|
||||
resize_shape = img_resize.shape
|
||||
pad_shape = img_resize.shape
|
||||
if new_width < dst_max_width:
|
||||
img_resize = mmcv.impad(
|
||||
img_resize,
|
||||
shape=(dst_height, dst_max_width),
|
||||
pad_val=self.img_pad_value)
|
||||
pad_shape = img_resize.shape
|
||||
else:
|
||||
img_resize = mmcv.imresize(
|
||||
results['img'], (new_width, dst_height),
|
||||
backend=self.backend)
|
||||
resize_shape = img_resize.shape
|
||||
pad_shape = img_resize.shape
|
||||
else:
|
||||
img_resize = mmcv.imresize(
|
||||
results['img'], (dst_max_width, dst_height),
|
||||
backend=self.backend)
|
||||
resize_shape = img_resize.shape
|
||||
pad_shape = img_resize.shape
|
||||
|
||||
results['img'] = img_resize
|
||||
results['img_shape'] = resize_shape
|
||||
results['resize_shape'] = resize_shape
|
||||
results['pad_shape'] = pad_shape
|
||||
results['valid_ratio'] = valid_ratio
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class ToTensorOCR:
|
||||
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor."""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, results):
|
||||
results['img'] = TF.to_tensor(results['img'].copy())
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class NormalizeOCR:
|
||||
"""Normalize a tensor image with mean and standard deviation."""
|
||||
|
||||
def __init__(self, mean, std):
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
|
||||
def __call__(self, results):
|
||||
results['img'] = TF.normalize(results['img'], self.mean, self.std)
|
||||
results['img_norm_cfg'] = dict(mean=self.mean, std=self.std)
|
||||
return results
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class OnlineCropOCR:
|
||||
"""Crop text areas from whole image with bounding box jitter. If no bbox is
|
||||
given, return directly.
|
||||
|
||||
Args:
|
||||
box_keys (list[str]): Keys in results which correspond to RoI bbox.
|
||||
jitter_prob (float): The probability of box jitter.
|
||||
max_jitter_ratio_x (float): Maximum horizontal jitter ratio
|
||||
relative to height.
|
||||
max_jitter_ratio_y (float): Maximum vertical jitter ratio
|
||||
relative to height.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
box_keys=['x1', 'y1', 'x2', 'y2', 'x3', 'y3', 'x4', 'y4'],
|
||||
jitter_prob=0.5,
|
||||
max_jitter_ratio_x=0.05,
|
||||
max_jitter_ratio_y=0.02):
|
||||
assert utils.is_type_list(box_keys, str)
|
||||
assert 0 <= jitter_prob <= 1
|
||||
assert 0 <= max_jitter_ratio_x <= 1
|
||||
assert 0 <= max_jitter_ratio_y <= 1
|
||||
|
||||
self.box_keys = box_keys
|
||||
self.jitter_prob = jitter_prob
|
||||
self.max_jitter_ratio_x = max_jitter_ratio_x
|
||||
self.max_jitter_ratio_y = max_jitter_ratio_y
|
||||
|
||||
def __call__(self, results):
|
||||
|
||||
if 'img_info' not in results:
|
||||
return results
|
||||
|
||||
crop_flag = True
|
||||
box = []
|
||||
for key in self.box_keys:
|
||||
if key not in results['img_info']:
|
||||
crop_flag = False
|
||||
break
|
||||
|
||||
box.append(float(results['img_info'][key]))
|
||||
|
||||
if not crop_flag:
|
||||
return results
|
||||
|
||||
jitter_flag = np.random.random() > self.jitter_prob
|
||||
|
||||
kwargs = dict(
|
||||
jitter_flag=jitter_flag,
|
||||
jitter_ratio_x=self.max_jitter_ratio_x,
|
||||
jitter_ratio_y=self.max_jitter_ratio_y)
|
||||
crop_img = warp_img(results['img'], box, **kwargs)
|
||||
|
||||
results['img'] = crop_img
|
||||
results['img_shape'] = crop_img.shape
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class FancyPCA:
|
||||
"""Implementation of PCA based image augmentation, proposed in the paper
|
||||
``Imagenet Classification With Deep Convolutional Neural Networks``.
|
||||
|
||||
It alters the intensities of RGB values along the principal components of
|
||||
ImageNet dataset.
|
||||
"""
|
||||
|
||||
def __init__(self, eig_vec=None, eig_val=None):
|
||||
if eig_vec is None:
|
||||
eig_vec = torch.Tensor([
|
||||
[-0.5675, +0.7192, +0.4009],
|
||||
[-0.5808, -0.0045, -0.8140],
|
||||
[-0.5836, -0.6948, +0.4203],
|
||||
]).t()
|
||||
if eig_val is None:
|
||||
eig_val = torch.Tensor([[0.2175, 0.0188, 0.0045]])
|
||||
self.eig_val = eig_val # 1*3
|
||||
self.eig_vec = eig_vec # 3*3
|
||||
|
||||
def pca(self, tensor):
|
||||
assert tensor.size(0) == 3
|
||||
alpha = torch.normal(mean=torch.zeros_like(self.eig_val)) * 0.1
|
||||
reconst = torch.mm(self.eig_val * alpha, self.eig_vec)
|
||||
tensor = tensor + reconst.view(3, 1, 1)
|
||||
|
||||
return tensor
|
||||
|
||||
def __call__(self, results):
|
||||
img = results['img']
|
||||
tensor = self.pca(img)
|
||||
results['img'] = tensor
|
||||
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class RandomPaddingOCR:
|
||||
"""Pad the given image on all sides, as well as modify the coordinates of
|
||||
character bounding box in image.
|
||||
|
||||
Args:
|
||||
max_ratio (list[int]): [left, top, right, bottom].
|
||||
box_type (None|str): Character box type. If not none,
|
||||
should be either 'char_rects' or 'char_quads', with
|
||||
'char_rects' for rectangle with ``xyxy`` style and
|
||||
'char_quads' for quadrangle with ``x1y1x2y2x3y3x4y4`` style.
|
||||
"""
|
||||
|
||||
def __init__(self, max_ratio=None, box_type=None):
|
||||
if max_ratio is None:
|
||||
max_ratio = [0.1, 0.2, 0.1, 0.2]
|
||||
else:
|
||||
assert utils.is_type_list(max_ratio, float)
|
||||
assert len(max_ratio) == 4
|
||||
assert box_type is None or box_type in ('char_rects', 'char_quads')
|
||||
|
||||
self.max_ratio = max_ratio
|
||||
self.box_type = box_type
|
||||
|
||||
def __call__(self, results):
|
||||
|
||||
img_shape = results['img_shape']
|
||||
ori_height, ori_width = img_shape[:2]
|
||||
|
||||
random_padding_left = round(
|
||||
np.random.uniform(0, self.max_ratio[0]) * ori_width)
|
||||
random_padding_top = round(
|
||||
np.random.uniform(0, self.max_ratio[1]) * ori_height)
|
||||
random_padding_right = round(
|
||||
np.random.uniform(0, self.max_ratio[2]) * ori_width)
|
||||
random_padding_bottom = round(
|
||||
np.random.uniform(0, self.max_ratio[3]) * ori_height)
|
||||
|
||||
padding = (random_padding_left, random_padding_top,
|
||||
random_padding_right, random_padding_bottom)
|
||||
img = mmcv.impad(results['img'], padding=padding, padding_mode='edge')
|
||||
|
||||
results['img'] = img
|
||||
results['img_shape'] = img.shape
|
||||
|
||||
if self.box_type is not None:
|
||||
num_points = 2 if self.box_type == 'char_rects' else 4
|
||||
char_num = len(results['ann_info'][self.box_type])
|
||||
for i in range(char_num):
|
||||
for j in range(num_points):
|
||||
results['ann_info'][self.box_type][i][
|
||||
j * 2] += random_padding_left
|
||||
results['ann_info'][self.box_type][i][
|
||||
j * 2 + 1] += random_padding_top
|
||||
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class RandomRotateImageBox:
|
||||
"""Rotate augmentation for segmentation based text recognition.
|
||||
|
||||
Args:
|
||||
min_angle (int): Minimum rotation angle for image and box.
|
||||
max_angle (int): Maximum rotation angle for image and box.
|
||||
box_type (str): Character box type, should be either
|
||||
'char_rects' or 'char_quads', with 'char_rects'
|
||||
for rectangle with ``xyxy`` style and 'char_quads'
|
||||
for quadrangle with ``x1y1x2y2x3y3x4y4`` style.
|
||||
"""
|
||||
|
||||
def __init__(self, min_angle=-10, max_angle=10, box_type='char_quads'):
|
||||
assert box_type in ('char_rects', 'char_quads')
|
||||
|
||||
self.min_angle = min_angle
|
||||
self.max_angle = max_angle
|
||||
self.box_type = box_type
|
||||
|
||||
def __call__(self, results):
|
||||
in_img = results['img']
|
||||
in_chars = results['ann_info']['chars']
|
||||
in_boxes = results['ann_info'][self.box_type]
|
||||
|
||||
img_width, img_height = in_img.size
|
||||
rotate_center = [img_width / 2., img_height / 2.]
|
||||
|
||||
tan_temp_max_angle = rotate_center[1] / rotate_center[0]
|
||||
temp_max_angle = np.arctan(tan_temp_max_angle) * 180. / np.pi
|
||||
|
||||
random_angle = np.random.uniform(
|
||||
max(self.min_angle, -temp_max_angle),
|
||||
min(self.max_angle, temp_max_angle))
|
||||
random_angle_radian = random_angle * np.pi / 180.
|
||||
|
||||
img_box = shapely_box(0, 0, img_width, img_height)
|
||||
|
||||
out_img = TF.rotate(
|
||||
in_img,
|
||||
random_angle,
|
||||
resample=False,
|
||||
expand=False,
|
||||
center=rotate_center)
|
||||
|
||||
out_boxes, out_chars = self.rotate_bbox(in_boxes, in_chars,
|
||||
random_angle_radian,
|
||||
rotate_center, img_box)
|
||||
|
||||
results['img'] = out_img
|
||||
results['ann_info']['chars'] = out_chars
|
||||
results['ann_info'][self.box_type] = out_boxes
|
||||
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def rotate_bbox(boxes, chars, angle, center, img_box):
|
||||
out_boxes = []
|
||||
out_chars = []
|
||||
for idx, bbox in enumerate(boxes):
|
||||
temp_bbox = []
|
||||
for i in range(len(bbox) // 2):
|
||||
point = [bbox[2 * i], bbox[2 * i + 1]]
|
||||
temp_bbox.append(
|
||||
RandomRotateImageBox.rotate_point(point, angle, center))
|
||||
poly_temp_bbox = Polygon(temp_bbox).buffer(0)
|
||||
if poly_temp_bbox.is_valid:
|
||||
if img_box.intersects(poly_temp_bbox) and (
|
||||
not img_box.touches(poly_temp_bbox)):
|
||||
temp_bbox_area = poly_temp_bbox.area
|
||||
|
||||
intersect_area = img_box.intersection(poly_temp_bbox).area
|
||||
intersect_ratio = intersect_area / temp_bbox_area
|
||||
|
||||
if intersect_ratio >= 0.7:
|
||||
out_box = []
|
||||
for p in temp_bbox:
|
||||
out_box.extend(p)
|
||||
out_boxes.append(out_box)
|
||||
out_chars.append(chars[idx])
|
||||
|
||||
return out_boxes, out_chars
|
||||
|
||||
@staticmethod
|
||||
def rotate_point(point, angle, center):
|
||||
cos_theta = math.cos(-angle)
|
||||
sin_theta = math.sin(-angle)
|
||||
c_x = center[0]
|
||||
c_y = center[1]
|
||||
new_x = (point[0] - c_x) * cos_theta - (point[1] -
|
||||
c_y) * sin_theta + c_x
|
||||
new_y = (point[0] - c_x) * sin_theta + (point[1] -
|
||||
c_y) * cos_theta + c_y
|
||||
|
||||
return [new_x, new_y]
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class OpencvToPil:
|
||||
"""Convert ``numpy.ndarray`` (bgr) to ``PIL Image`` (rgb)."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
||||
def __call__(self, results):
|
||||
img = results['img'][..., ::-1]
|
||||
img = Image.fromarray(img)
|
||||
results['img'] = img
|
||||
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class PilToOpencv:
|
||||
"""Convert ``PIL Image`` (rgb) to ``numpy.ndarray`` (bgr)."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
||||
def __call__(self, results):
|
||||
img = np.asarray(results['img'])
|
||||
img = img[..., ::-1]
|
||||
results['img'] = img
|
||||
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
return repr_str
|
@ -1,109 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from mmcv.transforms import Compose
|
||||
|
||||
from mmocr.registry import TRANSFORMS
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class MultiRotateAugOCR:
|
||||
"""Test-time augmentation with multiple rotations in the case that
|
||||
img_height > img_width.
|
||||
|
||||
An example configuration is as follows:
|
||||
|
||||
.. code-block::
|
||||
|
||||
rotate_degrees=[0, 90, 270],
|
||||
transforms=[
|
||||
dict(
|
||||
type='ResizeOCR',
|
||||
height=32,
|
||||
min_width=32,
|
||||
max_width=160,
|
||||
keep_aspect_ratio=True),
|
||||
dict(type='ToTensorOCR'),
|
||||
dict(type='NormalizeOCR', **img_norm_cfg),
|
||||
dict(
|
||||
type='Collect',
|
||||
keys=['img'],
|
||||
meta_keys=[
|
||||
'filename', 'ori_shape', 'img_shape', 'valid_ratio'
|
||||
]),
|
||||
]
|
||||
|
||||
After MultiRotateAugOCR with above configuration, the results are wrapped
|
||||
into lists of the same length as follows:
|
||||
|
||||
.. code-block::
|
||||
|
||||
dict(
|
||||
img=[...],
|
||||
img_shape=[...]
|
||||
...
|
||||
)
|
||||
|
||||
Args:
|
||||
transforms (list[dict]): Transformation applied for each augmentation.
|
||||
rotate_degrees (list[int] | None): Degrees of anti-clockwise rotation.
|
||||
force_rotate (bool): If True, rotate image by 'rotate_degrees'
|
||||
while ignore image aspect ratio.
|
||||
"""
|
||||
|
||||
def __init__(self, transforms, rotate_degrees=None, force_rotate=False):
|
||||
self.transforms = Compose(transforms)
|
||||
self.force_rotate = force_rotate
|
||||
if rotate_degrees is not None:
|
||||
self.rotate_degrees = rotate_degrees if isinstance(
|
||||
rotate_degrees, list) else [rotate_degrees]
|
||||
assert mmcv.is_list_of(self.rotate_degrees, int)
|
||||
for degree in self.rotate_degrees:
|
||||
assert 0 <= degree < 360
|
||||
assert degree % 90 == 0
|
||||
if 0 not in self.rotate_degrees:
|
||||
self.rotate_degrees.append(0)
|
||||
else:
|
||||
self.rotate_degrees = [0]
|
||||
|
||||
def __call__(self, results):
|
||||
"""Call function to apply test time augment transformation to results.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict contains the data to be transformed.
|
||||
|
||||
Returns:
|
||||
dict[str: list]: The augmented data, where each value is wrapped
|
||||
into a list.
|
||||
"""
|
||||
img_shape = results['img_shape']
|
||||
ori_height, ori_width = img_shape[:2]
|
||||
if not self.force_rotate and ori_height <= ori_width:
|
||||
rotate_degrees = [0]
|
||||
else:
|
||||
rotate_degrees = self.rotate_degrees
|
||||
aug_data = []
|
||||
for degree in set(rotate_degrees):
|
||||
_results = results.copy()
|
||||
if degree == 0:
|
||||
pass
|
||||
elif degree == 90:
|
||||
_results['img'] = np.rot90(_results['img'], 1)
|
||||
elif degree == 180:
|
||||
_results['img'] = np.rot90(_results['img'], 2)
|
||||
elif degree == 270:
|
||||
_results['img'] = np.rot90(_results['img'], 3)
|
||||
data = self.transforms(_results)
|
||||
aug_data.append(data)
|
||||
# list of dict to dict of list
|
||||
aug_data_dict = {key: [] for key in aug_data[0]}
|
||||
for data in aug_data:
|
||||
for key, val in data.items():
|
||||
aug_data_dict[key].append(val)
|
||||
return aug_data_dict
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(transforms={self.transforms}, '
|
||||
repr_str += f'rotate_degrees={self.rotate_degrees})'
|
||||
return repr_str
|
19
mmocr/datasets/transforms/__init__.py
Normal file
19
mmocr/datasets/transforms/__init__.py
Normal file
@ -0,0 +1,19 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .formatting import PackKIEInputs, PackTextDetInputs, PackTextRecogInputs
|
||||
from .loading import LoadKIEAnnotations, LoadOCRAnnotations
|
||||
from .ocr_transforms import RandomCrop, RandomRotate, Resize
|
||||
from .textdet_transforms import (BoundedScaleAspectJitter, FixInvalidPolygon,
|
||||
RandomFlip, ShortScaleAspectJitter,
|
||||
SourceImagePad, TextDetRandomCrop,
|
||||
TextDetRandomCropFlip)
|
||||
from .textrecog_transforms import PadToWidth, PyramidRescale, RescaleToHeight
|
||||
from .wrappers import ImgAugWrapper, TorchVisionWrapper
|
||||
|
||||
__all__ = [
|
||||
'LoadOCRAnnotations', 'RandomRotate', 'ImgAugWrapper', 'SourceImagePad',
|
||||
'TextDetRandomCropFlip', 'PyramidRescale', 'TorchVisionWrapper', 'Resize',
|
||||
'RandomCrop', 'TextDetRandomCrop', 'RandomCrop', 'PackTextDetInputs',
|
||||
'PackTextRecogInputs', 'RescaleToHeight', 'PadToWidth',
|
||||
'ShortScaleAspectJitter', 'RandomFlip', 'BoundedScaleAspectJitter',
|
||||
'PackKIEInputs', 'LoadKIEAnnotations', 'FixInvalidPolygon'
|
||||
]
|
615
mmocr/datasets/transforms/ocr_transforms.py
Normal file
615
mmocr/datasets/transforms/ocr_transforms.py
Normal file
@ -0,0 +1,615 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import cv2
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from mmcv.transforms import Resize as MMCV_Resize
|
||||
from mmcv.transforms.base import BaseTransform
|
||||
from mmcv.transforms.utils import avoid_cache_randomness, cache_randomness
|
||||
|
||||
from mmocr.registry import TRANSFORMS
|
||||
from mmocr.utils import (bbox2poly, crop_polygon, is_poly_inside_rect,
|
||||
poly2bbox, rescale_polygon)
|
||||
from .wrappers import ImgAugWrapper
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
@avoid_cache_randomness
|
||||
class RandomCrop(BaseTransform):
|
||||
"""Randomly crop images and make sure to contain at least one intact
|
||||
instance.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img
|
||||
- gt_polygons
|
||||
- gt_bboxes
|
||||
- gt_bboxes_labels
|
||||
- gt_ignored
|
||||
- gt_texts (optional)
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img
|
||||
- img_shape
|
||||
- gt_polygons
|
||||
- gt_bboxes
|
||||
- gt_bboxes_labels
|
||||
- gt_ignored
|
||||
- gt_texts (optional)
|
||||
|
||||
Args:
|
||||
min_side_ratio (float): The ratio of the shortest edge of the cropped
|
||||
image to the original image size.
|
||||
"""
|
||||
|
||||
def __init__(self, min_side_ratio: float = 0.4) -> None:
|
||||
if not 0. <= min_side_ratio <= 1.:
|
||||
raise ValueError('`min_side_ratio` should be in range [0, 1],')
|
||||
self.min_side_ratio = min_side_ratio
|
||||
|
||||
def _sample_valid_start_end(self, valid_array: np.ndarray, min_len: int,
|
||||
max_start_idx: int,
|
||||
min_end_idx: int) -> Tuple[int, int]:
|
||||
"""Sample a start and end idx on a given axis that contains at least
|
||||
one polygon. There should be at least one intact polygon bounded by
|
||||
max_start_idx and min_end_idx.
|
||||
|
||||
Args:
|
||||
valid_array (ndarray): A 0-1 mask 1D array indicating valid regions
|
||||
on the axis. 0 indicates text regions which are not allowed to
|
||||
be sampled from.
|
||||
min_len (int): Minimum distance between two start and end points.
|
||||
max_start_idx (int): The maximum start index.
|
||||
min_end_idx (int): The minimum end index.
|
||||
|
||||
Returns:
|
||||
tuple(int, int): Start and end index on a given axis, where
|
||||
0 <= start < max_start_idx and
|
||||
min_end_idx <= end < len(valid_array).
|
||||
"""
|
||||
assert isinstance(min_len, int)
|
||||
assert len(valid_array) > min_len
|
||||
|
||||
start_array = valid_array.copy()
|
||||
max_start_idx = min(len(start_array) - min_len, max_start_idx)
|
||||
start_array[max_start_idx:] = 0
|
||||
start_array[0] = 1
|
||||
diff_array = np.hstack([0, start_array]) - np.hstack([start_array, 0])
|
||||
region_starts = np.where(diff_array < 0)[0]
|
||||
region_ends = np.where(diff_array > 0)[0]
|
||||
region_ind = np.random.randint(0, len(region_starts))
|
||||
start = np.random.randint(region_starts[region_ind],
|
||||
region_ends[region_ind])
|
||||
|
||||
end_array = valid_array.copy()
|
||||
min_end_idx = max(start + min_len, min_end_idx)
|
||||
end_array[:min_end_idx] = 0
|
||||
end_array[-1] = 1
|
||||
diff_array = np.hstack([0, end_array]) - np.hstack([end_array, 0])
|
||||
region_starts = np.where(diff_array < 0)[0]
|
||||
region_ends = np.where(diff_array > 0)[0]
|
||||
region_ind = np.random.randint(0, len(region_starts))
|
||||
# Note that end index will never be region_ends[region_ind]
|
||||
# and therefore end index is always in range [0, w+1]
|
||||
end = np.random.randint(region_starts[region_ind],
|
||||
region_ends[region_ind])
|
||||
return start, end
|
||||
|
||||
def _sample_crop_box(self, img_size: Tuple[int, int],
|
||||
results: Dict) -> np.ndarray:
|
||||
"""Generate crop box which only contains intact polygon instances with
|
||||
the number >= 1.
|
||||
|
||||
Args:
|
||||
img_size (tuple(int, int)): The image size (h, w).
|
||||
results (dict): The results dict.
|
||||
|
||||
Returns:
|
||||
ndarray: Crop area in shape (4, ).
|
||||
"""
|
||||
assert isinstance(img_size, tuple)
|
||||
h, w = img_size[:2]
|
||||
|
||||
# Crop box can be represented by any integer numbers in
|
||||
# range [0, w] and [0, h]
|
||||
x_valid_array = np.ones(w + 1, dtype=np.int32)
|
||||
y_valid_array = np.ones(h + 1, dtype=np.int32)
|
||||
|
||||
polygons = results['gt_polygons']
|
||||
|
||||
# Randomly select a polygon that must be inside
|
||||
# the cropped region
|
||||
kept_poly_idx = np.random.randint(0, len(polygons))
|
||||
for i, polygon in enumerate(polygons):
|
||||
polygon = polygon.reshape((-1, 2))
|
||||
|
||||
clip_x = np.clip(polygon[:, 0], 0, w)
|
||||
clip_y = np.clip(polygon[:, 1], 0, h)
|
||||
min_x = np.floor(np.min(clip_x)).astype(np.int32)
|
||||
min_y = np.floor(np.min(clip_y)).astype(np.int32)
|
||||
max_x = np.ceil(np.max(clip_x)).astype(np.int32)
|
||||
max_y = np.ceil(np.max(clip_y)).astype(np.int32)
|
||||
|
||||
x_valid_array[min_x:max_x] = 0
|
||||
y_valid_array[min_y:max_y] = 0
|
||||
|
||||
if i == kept_poly_idx:
|
||||
max_x_start = min_x
|
||||
min_x_end = max_x
|
||||
max_y_start = min_y
|
||||
min_y_end = max_y
|
||||
|
||||
min_w = int(w * self.min_side_ratio)
|
||||
min_h = int(h * self.min_side_ratio)
|
||||
|
||||
x1, x2 = self._sample_valid_start_end(x_valid_array, min_w,
|
||||
max_x_start, min_x_end)
|
||||
y1, y2 = self._sample_valid_start_end(y_valid_array, min_h,
|
||||
max_y_start, min_y_end)
|
||||
|
||||
return np.array([x1, y1, x2, y2])
|
||||
|
||||
def _crop_img(self, img: np.ndarray, bbox: np.ndarray) -> np.ndarray:
|
||||
"""Crop image given a bbox region.
|
||||
Args:
|
||||
img (ndarray): Image.
|
||||
bbox (ndarray): Cropping region in shape (4, )
|
||||
|
||||
Returns:
|
||||
ndarray: Cropped image.
|
||||
"""
|
||||
assert img.ndim == 3
|
||||
h, w, _ = img.shape
|
||||
assert 0 <= bbox[1] < bbox[3] <= h
|
||||
assert 0 <= bbox[0] < bbox[2] <= w
|
||||
return img[bbox[1]:bbox[3], bbox[0]:bbox[2]]
|
||||
|
||||
def transform(self, results: Dict) -> Dict:
|
||||
"""Applying random crop on results.
|
||||
Args:
|
||||
results (dict): Result dict contains the data to transform.
|
||||
|
||||
Returns:
|
||||
dict: The transformed data.
|
||||
"""
|
||||
if len(results['gt_polygons']) < 1:
|
||||
return results
|
||||
|
||||
crop_box = self._sample_crop_box(results['img'].shape, results)
|
||||
img = self._crop_img(results['img'], crop_box)
|
||||
results['img'] = img
|
||||
results['img_shape'] = img.shape[:2]
|
||||
crop_x = crop_box[0]
|
||||
crop_y = crop_box[1]
|
||||
crop_w = crop_box[2] - crop_box[0]
|
||||
crop_h = crop_box[3] - crop_box[1]
|
||||
|
||||
labels = results['gt_bboxes_labels']
|
||||
valid_labels = []
|
||||
ignored = results['gt_ignored']
|
||||
valid_ignored = []
|
||||
if 'gt_texts' in results:
|
||||
valid_texts = []
|
||||
texts = results['gt_texts']
|
||||
|
||||
polys = results['gt_polygons']
|
||||
valid_polys = []
|
||||
for idx, poly in enumerate(polys):
|
||||
poly = poly.reshape(-1, 2)
|
||||
poly = (poly - (crop_x, crop_y)).flatten()
|
||||
if is_poly_inside_rect(poly, [0, 0, crop_w, crop_h]):
|
||||
valid_polys.append(poly)
|
||||
valid_labels.append(labels[idx])
|
||||
valid_ignored.append(ignored[idx])
|
||||
if 'gt_texts' in results:
|
||||
valid_texts.append(texts[idx])
|
||||
results['gt_polygons'] = valid_polys
|
||||
results['gt_bboxes_labels'] = np.array(valid_labels, dtype=np.int64)
|
||||
results['gt_ignored'] = np.array(valid_ignored, dtype=bool)
|
||||
if 'gt_texts' in results:
|
||||
results['gt_texts'] = valid_texts
|
||||
valid_bboxes = [poly2bbox(poly) for poly in results['gt_polygons']]
|
||||
results['gt_bboxes'] = np.array(valid_bboxes).astype(
|
||||
np.float32).reshape(-1, 4)
|
||||
|
||||
return results
|
||||
|
||||
def __repr__(self) -> str:
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(min_side_ratio = {self.min_side_ratio})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class RandomRotate(BaseTransform):
|
||||
"""Randomly rotate the image, boxes, and polygons. For recognition task,
|
||||
only the image will be rotated. If set ``use_canvas`` as True, the shape of
|
||||
rotated image might be modified based on the rotated angle size, otherwise,
|
||||
the image will keep the shape before rotation.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img
|
||||
- img_shape
|
||||
- gt_bboxes (optional)
|
||||
- gt_polygons (optional)
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img
|
||||
- img_shape (optional)
|
||||
- gt_bboxes (optional)
|
||||
- gt_polygons (optional)
|
||||
|
||||
Added Keys:
|
||||
|
||||
- rotated_angle
|
||||
|
||||
Args:
|
||||
max_angle (int): The maximum rotation angle (can be bigger than 180 or
|
||||
a negative). Defaults to 10.
|
||||
pad_with_fixed_color (bool): The flag for whether to pad rotated
|
||||
image with fixed value. Defaults to False.
|
||||
pad_value (tuple[int, int, int]): The color value for padding rotated
|
||||
image. Defaults to (0, 0, 0).
|
||||
use_canvas (bool): Whether to create a canvas for rotated image.
|
||||
Defaults to False. If set true, the image shape may be modified.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_angle: int = 10,
|
||||
pad_with_fixed_color: bool = False,
|
||||
pad_value: Tuple[int, int, int] = (0, 0, 0),
|
||||
use_canvas: bool = False,
|
||||
) -> None:
|
||||
if not isinstance(max_angle, int):
|
||||
raise TypeError('`max_angle` should be an integer'
|
||||
f', but got {type(max_angle)} instead')
|
||||
if not isinstance(pad_with_fixed_color, bool):
|
||||
raise TypeError('`pad_with_fixed_color` should be a bool, '
|
||||
f'but got {type(pad_with_fixed_color)} instead')
|
||||
if not isinstance(pad_value, (list, tuple)):
|
||||
raise TypeError('`pad_value` should be a list or tuple, '
|
||||
f'but got {type(pad_value)} instead')
|
||||
if len(pad_value) != 3:
|
||||
raise ValueError('`pad_value` should contain three integers')
|
||||
if not isinstance(pad_value[0], int) or not isinstance(
|
||||
pad_value[1], int) or not isinstance(pad_value[2], int):
|
||||
raise ValueError('`pad_value` should contain three integers')
|
||||
|
||||
self.max_angle = max_angle
|
||||
self.pad_with_fixed_color = pad_with_fixed_color
|
||||
self.pad_value = pad_value
|
||||
self.use_canvas = use_canvas
|
||||
|
||||
@cache_randomness
|
||||
def _sample_angle(self, max_angle: int) -> float:
|
||||
"""Sampling a random angle for rotation.
|
||||
|
||||
Args:
|
||||
max_angle (int): Maximum rotation angle
|
||||
|
||||
Returns:
|
||||
float: The random angle used for rotation
|
||||
"""
|
||||
angle = np.random.random_sample() * 2 * max_angle - max_angle
|
||||
return angle
|
||||
|
||||
@staticmethod
|
||||
def _cal_canvas_size(ori_size: Tuple[int, int],
|
||||
degree: int) -> Tuple[int, int]:
|
||||
"""Calculate the canvas size.
|
||||
|
||||
Args:
|
||||
ori_size (Tuple[int, int]): The original image size (height, width)
|
||||
degree (int): The rotation angle
|
||||
|
||||
Returns:
|
||||
Tuple[int, int]: The size of the canvas
|
||||
"""
|
||||
assert isinstance(ori_size, tuple)
|
||||
angle = degree * math.pi / 180.0
|
||||
h, w = ori_size[:2]
|
||||
|
||||
cos = math.cos(angle)
|
||||
sin = math.sin(angle)
|
||||
canvas_h = int(w * math.fabs(sin) + h * math.fabs(cos))
|
||||
canvas_w = int(w * math.fabs(cos) + h * math.fabs(sin))
|
||||
|
||||
canvas_size = (canvas_h, canvas_w)
|
||||
return canvas_size
|
||||
|
||||
@staticmethod
|
||||
def _rotate_points(center: Tuple[float, float],
|
||||
points: np.array,
|
||||
theta: float,
|
||||
center_shift: Tuple[int, int] = (0, 0)) -> np.array:
|
||||
"""Rotating a set of points according to the given theta.
|
||||
|
||||
Args:
|
||||
center (Tuple[float, float]): The coordinate of the canvas center
|
||||
points (np.array): A set of points needed to be rotated
|
||||
theta (float): Rotation angle
|
||||
center_shift (Tuple[int, int]): The shifting offset of the center
|
||||
coordinate
|
||||
|
||||
Returns:
|
||||
np.array: The rotated coordinates of the input points
|
||||
"""
|
||||
(center_x, center_y) = center
|
||||
center_y = -center_y
|
||||
x, y = points[::2], points[1::2]
|
||||
y = -y
|
||||
|
||||
theta = theta / 180 * math.pi
|
||||
cos = math.cos(theta)
|
||||
sin = math.sin(theta)
|
||||
|
||||
x = (x - center_x)
|
||||
y = (y - center_y)
|
||||
|
||||
_x = center_x + x * cos - y * sin + center_shift[0]
|
||||
_y = -(center_y + x * sin + y * cos) + center_shift[1]
|
||||
|
||||
points[::2], points[1::2] = _x, _y
|
||||
return points
|
||||
|
||||
def _rotate_img(self, results: Dict) -> Tuple[int, int]:
|
||||
"""Rotating the input image based on the given angle.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict containing the data to transform.
|
||||
|
||||
Returns:
|
||||
Tuple[int, int]: The shifting offset of the center point.
|
||||
"""
|
||||
if results.get('img', None) is not None:
|
||||
h = results['img'].shape[0]
|
||||
w = results['img'].shape[1]
|
||||
rotation_matrix = cv2.getRotationMatrix2D(
|
||||
(w / 2, h / 2), results['rotated_angle'], 1)
|
||||
|
||||
canvas_size = self._cal_canvas_size((h, w),
|
||||
results['rotated_angle'])
|
||||
center_shift = (int(
|
||||
(canvas_size[1] - w) / 2), int((canvas_size[0] - h) / 2))
|
||||
rotation_matrix[0, 2] += int((canvas_size[1] - w) / 2)
|
||||
rotation_matrix[1, 2] += int((canvas_size[0] - h) / 2)
|
||||
if self.pad_with_fixed_color:
|
||||
rotated_img = cv2.warpAffine(
|
||||
results['img'],
|
||||
rotation_matrix, (canvas_size[1], canvas_size[0]),
|
||||
flags=cv2.INTER_NEAREST,
|
||||
borderValue=self.pad_value)
|
||||
else:
|
||||
mask = np.zeros_like(results['img'])
|
||||
(h_ind, w_ind) = (np.random.randint(0, h * 7 // 8),
|
||||
np.random.randint(0, w * 7 // 8))
|
||||
img_cut = results['img'][h_ind:(h_ind + h // 9),
|
||||
w_ind:(w_ind + w // 9)]
|
||||
img_cut = mmcv.imresize(img_cut,
|
||||
(canvas_size[1], canvas_size[0]))
|
||||
mask = cv2.warpAffine(
|
||||
mask,
|
||||
rotation_matrix, (canvas_size[1], canvas_size[0]),
|
||||
borderValue=[1, 1, 1])
|
||||
rotated_img = cv2.warpAffine(
|
||||
results['img'],
|
||||
rotation_matrix, (canvas_size[1], canvas_size[0]),
|
||||
borderValue=[0, 0, 0])
|
||||
rotated_img = rotated_img + img_cut * mask
|
||||
|
||||
results['img'] = rotated_img
|
||||
else:
|
||||
raise ValueError('`img` is not found in results')
|
||||
|
||||
return center_shift
|
||||
|
||||
def _rotate_bboxes(self, results: Dict, center_shift: Tuple[int,
|
||||
int]) -> None:
|
||||
"""Rotating the bounding boxes based on the given angle.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict containing the data to transform.
|
||||
center_shift (Tuple[int, int]): The shifting offset of the
|
||||
center point
|
||||
"""
|
||||
if results.get('gt_bboxes', None) is not None:
|
||||
height, width = results['img_shape']
|
||||
box_list = []
|
||||
for box in results['gt_bboxes']:
|
||||
rotated_box = self._rotate_points((width / 2, height / 2),
|
||||
bbox2poly(box),
|
||||
results['rotated_angle'],
|
||||
center_shift)
|
||||
rotated_box = poly2bbox(rotated_box)
|
||||
box_list.append(rotated_box)
|
||||
|
||||
results['gt_bboxes'] = np.array(
|
||||
box_list, dtype=np.float32).reshape(-1, 4)
|
||||
|
||||
def _rotate_polygons(self, results: Dict,
|
||||
center_shift: Tuple[int, int]) -> None:
|
||||
"""Rotating the polygons based on the given angle.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict containing the data to transform.
|
||||
center_shift (Tuple[int, int]): The shifting offset of the
|
||||
center point
|
||||
"""
|
||||
if results.get('gt_polygons', None) is not None:
|
||||
height, width = results['img_shape']
|
||||
polygon_list = []
|
||||
for poly in results['gt_polygons']:
|
||||
rotated_poly = self._rotate_points(
|
||||
(width / 2, height / 2), poly, results['rotated_angle'],
|
||||
center_shift)
|
||||
polygon_list.append(rotated_poly)
|
||||
results['gt_polygons'] = polygon_list
|
||||
|
||||
def transform(self, results: Dict) -> Dict:
|
||||
"""Applying random rotate on results.
|
||||
|
||||
Args:
|
||||
results (Dict): Result dict containing the data to transform.
|
||||
center_shift (Tuple[int, int]): The shifting offset of the
|
||||
center point
|
||||
|
||||
Returns:
|
||||
dict: The transformed data
|
||||
"""
|
||||
# TODO rotate char_quads & char_rects for SegOCR
|
||||
if self.use_canvas:
|
||||
results['rotated_angle'] = self._sample_angle(self.max_angle)
|
||||
# rotate image
|
||||
center_shift = self._rotate_img(results)
|
||||
# rotate gt_bboxes
|
||||
self._rotate_bboxes(results, center_shift)
|
||||
# rotate gt_polygons
|
||||
self._rotate_polygons(results, center_shift)
|
||||
|
||||
results['img_shape'] = (results['img'].shape[0],
|
||||
results['img'].shape[1])
|
||||
else:
|
||||
args = [
|
||||
dict(
|
||||
cls='Affine',
|
||||
rotate=[-self.max_angle, self.max_angle],
|
||||
backend='cv2',
|
||||
order=0) # order=0 -> cv2.INTER_NEAREST
|
||||
]
|
||||
imgaug_transform = ImgAugWrapper(args)
|
||||
results = imgaug_transform(results)
|
||||
return results
|
||||
|
||||
def __repr__(self) -> str:
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(max_angle = {self.max_angle}'
|
||||
repr_str += f', pad_with_fixed_color = {self.pad_with_fixed_color}'
|
||||
repr_str += f', pad_value = {self.pad_value}'
|
||||
repr_str += f', use_canvas = {self.use_canvas})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class Resize(MMCV_Resize):
|
||||
"""Resize image & bboxes & polygons.
|
||||
|
||||
This transform resizes the input image according to ``scale`` or
|
||||
``scale_factor``. Bboxes and polygons are then resized with the same
|
||||
scale factor. if ``scale`` and ``scale_factor`` are both set, it will use
|
||||
``scale`` to resize.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img
|
||||
- img_shape
|
||||
- gt_bboxes
|
||||
- gt_polygons
|
||||
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img
|
||||
- img_shape
|
||||
- gt_bboxes
|
||||
- gt_polygons
|
||||
|
||||
Added Keys:
|
||||
|
||||
- scale
|
||||
- scale_factor
|
||||
- keep_ratio
|
||||
|
||||
Args:
|
||||
scale (int or tuple): Image scales for resizing. Defaults to None.
|
||||
scale_factor (float or tuple[float, float]): Scale factors for
|
||||
resizing. It's either a factor applicable to both dimensions or
|
||||
in the form of (scale_w, scale_h). Defaults to None.
|
||||
keep_ratio (bool): Whether to keep the aspect ratio when resizing the
|
||||
image. Defaults to False.
|
||||
clip_object_border (bool): Whether to clip the objects outside the
|
||||
border of the image. Defaults to True.
|
||||
backend (str): Image resize backend, choices are 'cv2' and 'pillow'.
|
||||
These two backends generates slightly different results. Defaults
|
||||
to 'cv2'.
|
||||
interpolation (str): Interpolation method, accepted values are
|
||||
"nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
|
||||
backend, "nearest", "bilinear" for 'pillow' backend. Defaults
|
||||
to 'bilinear'.
|
||||
"""
|
||||
|
||||
def _resize_img(self, results: dict) -> None:
|
||||
"""Resize images with ``results['scale']``.
|
||||
|
||||
If no image is provided, only resize ``results['img_shape']``.
|
||||
"""
|
||||
if results.get('img', None) is not None:
|
||||
return super()._resize_img(results)
|
||||
h, w = results['img_shape']
|
||||
if self.keep_ratio:
|
||||
new_w, new_h = mmcv.rescale_size((w, h),
|
||||
results['scale'],
|
||||
return_scale=False)
|
||||
else:
|
||||
new_w, new_h = results['scale']
|
||||
w_scale = new_w / w
|
||||
h_scale = new_h / h
|
||||
results['img_shape'] = (new_h, new_w)
|
||||
results['scale'] = (new_w, new_h)
|
||||
results['scale_factor'] = (w_scale, h_scale)
|
||||
results['keep_ratio'] = self.keep_ratio
|
||||
|
||||
def _resize_bboxes(self, results: dict) -> None:
|
||||
"""Resize bounding boxes."""
|
||||
super()._resize_bboxes(results)
|
||||
if results.get('gt_bboxes', None) is not None:
|
||||
results['gt_bboxes'] = results['gt_bboxes'].astype(np.float32)
|
||||
|
||||
def _resize_polygons(self, results: dict) -> None:
|
||||
"""Resize polygons with ``results['scale_factor']``."""
|
||||
if results.get('gt_polygons', None) is not None:
|
||||
polygons = results['gt_polygons']
|
||||
polygons_resize = []
|
||||
for idx, polygon in enumerate(polygons):
|
||||
polygon = rescale_polygon(polygon, results['scale_factor'])
|
||||
if self.clip_object_border:
|
||||
crop_bbox = np.array([
|
||||
0, 0, results['img_shape'][1], results['img_shape'][0]
|
||||
])
|
||||
polygon = crop_polygon(polygon, crop_bbox)
|
||||
if polygon is not None:
|
||||
polygons_resize.append(polygon.astype(np.float32))
|
||||
else:
|
||||
polygons_resize.append(
|
||||
np.zeros_like(polygons[idx], dtype=np.float32))
|
||||
results['gt_polygons'] = polygons_resize
|
||||
|
||||
def transform(self, results: dict) -> dict:
|
||||
"""Transform function to resize images, bounding boxes and polygons.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from loading pipeline.
|
||||
|
||||
Returns:
|
||||
dict: Resized results, 'img', 'gt_bboxes', 'gt_polygons',
|
||||
'scale', 'scale_factor', 'height', 'width', and 'keep_ratio' keys
|
||||
are updated in result dict.
|
||||
"""
|
||||
results = super().transform(results)
|
||||
self._resize_polygons(results)
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(scale={self.scale}, '
|
||||
repr_str += f'scale_factor={self.scale_factor}, '
|
||||
repr_str += f'keep_ratio={self.keep_ratio}, '
|
||||
repr_str += f'clip_object_border={self.clip_object_border}), '
|
||||
repr_str += f'backend={self.backend}), '
|
||||
repr_str += f'interpolation={self.interpolation})'
|
||||
return repr_str
|
File diff suppressed because it is too large
Load Diff
252
mmocr/datasets/transforms/textrecog_transforms.py
Normal file
252
mmocr/datasets/transforms/textrecog_transforms.py
Normal file
@ -0,0 +1,252 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import cv2
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from mmcv.transforms.base import BaseTransform
|
||||
from mmcv.transforms.utils import cache_randomness
|
||||
|
||||
from mmocr.registry import TRANSFORMS
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class PyramidRescale(BaseTransform):
|
||||
"""Resize the image to the base shape, downsample it with gaussian pyramid,
|
||||
and rescale it back to original size.
|
||||
|
||||
Adapted from https://github.com/FangShancheng/ABINet.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img (ndarray)
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img (ndarray)
|
||||
|
||||
Args:
|
||||
factor (int): The decay factor from base size, or the number of
|
||||
downsampling operations from the base layer.
|
||||
base_shape (tuple[int, int]): The shape (width, height) of the base
|
||||
layer of the pyramid.
|
||||
randomize_factor (bool): If True, the final factor would be a random
|
||||
integer in [0, factor].
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
factor: int = 4,
|
||||
base_shape: Tuple[int, int] = (128, 512),
|
||||
randomize_factor: bool = True) -> None:
|
||||
if not isinstance(factor, int):
|
||||
raise TypeError('`factor` should be an integer, '
|
||||
f'but got {type(factor)} instead')
|
||||
if not isinstance(base_shape, (list, tuple)):
|
||||
raise TypeError('`base_shape` should be a list or tuple, '
|
||||
f'but got {type(base_shape)} instead')
|
||||
if not len(base_shape) == 2:
|
||||
raise ValueError('`base_shape` should contain two integers')
|
||||
if not isinstance(base_shape[0], int) or not isinstance(
|
||||
base_shape[1], int):
|
||||
raise ValueError('`base_shape` should contain two integers')
|
||||
if not isinstance(randomize_factor, bool):
|
||||
raise TypeError('`randomize_factor` should be a bool, '
|
||||
f'but got {type(randomize_factor)} instead')
|
||||
|
||||
self.factor = factor
|
||||
self.randomize_factor = randomize_factor
|
||||
self.base_w, self.base_h = base_shape
|
||||
|
||||
@cache_randomness
|
||||
def get_random_factor(self) -> float:
|
||||
"""Get the randomized factor.
|
||||
|
||||
Returns:
|
||||
float: The randomized factor.
|
||||
"""
|
||||
return np.random.randint(0, self.factor + 1)
|
||||
|
||||
def transform(self, results: Dict) -> Dict:
|
||||
"""Applying pyramid rescale on results.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict containing the data to transform.
|
||||
|
||||
Returns:
|
||||
Dict: The transformed data.
|
||||
"""
|
||||
|
||||
assert 'img' in results, '`img` is not found in results'
|
||||
if self.randomize_factor:
|
||||
self.factor = self.get_random_factor()
|
||||
if self.factor == 0:
|
||||
return results
|
||||
img = results['img']
|
||||
src_h, src_w = img.shape[:2]
|
||||
scale_img = mmcv.imresize(img, (self.base_w, self.base_h))
|
||||
for _ in range(self.factor):
|
||||
scale_img = cv2.pyrDown(scale_img)
|
||||
scale_img = mmcv.imresize(scale_img, (src_w, src_h))
|
||||
results['img'] = scale_img
|
||||
return results
|
||||
|
||||
def __repr__(self) -> str:
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(factor = {self.factor}'
|
||||
repr_str += f', randomize_factor = {self.randomize_factor}'
|
||||
repr_str += f', base_w = {self.base_w}'
|
||||
repr_str += f', base_h = {self.base_h})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class RescaleToHeight(BaseTransform):
|
||||
"""Rescale the image to the height according to setting and keep the aspect
|
||||
ratio unchanged if possible. However, if any of ``min_width``,
|
||||
``max_width`` or ``width_divisor`` are specified, aspect ratio may still be
|
||||
changed to ensure the width meets these constraints.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img
|
||||
- img_shape
|
||||
|
||||
Added Keys:
|
||||
|
||||
- scale
|
||||
- scale_factor
|
||||
- keep_ratio
|
||||
|
||||
Args:
|
||||
height (int): Height of rescaled image.
|
||||
min_width (int, optional): Minimum width of rescaled image. Defaults
|
||||
to None.
|
||||
max_width (int, optional): Maximum width of rescaled image. Defaults
|
||||
to None.
|
||||
width_divisor (int): The divisor of width size. Defaults to 1.
|
||||
resize_cfg (dict): (dict): Config to construct the Resize transform.
|
||||
Refer to ``Resize`` for detail. Defaults to
|
||||
``dict(type='Resize')``.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
height: int,
|
||||
min_width: Optional[int] = None,
|
||||
max_width: Optional[int] = None,
|
||||
width_divisor: int = 1,
|
||||
resize_cfg: dict = dict(type='Resize')) -> None:
|
||||
super().__init__()
|
||||
assert isinstance(height, int)
|
||||
assert isinstance(width_divisor, int)
|
||||
if min_width is not None:
|
||||
assert isinstance(min_width, int)
|
||||
if max_width is not None:
|
||||
assert isinstance(max_width, int)
|
||||
self.width_divisor = width_divisor
|
||||
self.height = height
|
||||
self.min_width = min_width
|
||||
self.max_width = max_width
|
||||
self.resize_cfg = resize_cfg
|
||||
_resize_cfg = self.resize_cfg.copy()
|
||||
_resize_cfg.update(dict(scale=0))
|
||||
self.resize = TRANSFORMS.build(_resize_cfg)
|
||||
|
||||
def transform(self, results: Dict) -> Dict:
|
||||
"""Transform function to resize images, bounding boxes and polygons.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from loading pipeline.
|
||||
|
||||
Returns:
|
||||
dict: Resized results.
|
||||
"""
|
||||
ori_height, ori_width = results['img'].shape[:2]
|
||||
new_width = math.ceil(float(self.height) / ori_height * ori_width)
|
||||
if self.min_width is not None:
|
||||
new_width = max(self.min_width, new_width)
|
||||
if self.max_width is not None:
|
||||
new_width = min(self.max_width, new_width)
|
||||
|
||||
if new_width % self.width_divisor != 0:
|
||||
new_width = round(
|
||||
new_width / self.width_divisor) * self.width_divisor
|
||||
# TODO replace up code after testing precision.
|
||||
# new_width = math.ceil(
|
||||
# new_width / self.width_divisor) * self.width_divisor
|
||||
scale = (new_width, self.height)
|
||||
self.resize.scale = scale
|
||||
results = self.resize(results)
|
||||
return results
|
||||
|
||||
def __repr__(self) -> str:
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(height={self.height}, '
|
||||
repr_str += f'min_width={self.min_width}, '
|
||||
repr_str += f'max_width={self.max_width}, '
|
||||
repr_str += f'width_divisor={self.width_divisor}, '
|
||||
repr_str += f'resize_cfg={self.resize_cfg})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class PadToWidth(BaseTransform):
|
||||
"""Only pad the image's width.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img
|
||||
- img_shape
|
||||
|
||||
Added Keys:
|
||||
|
||||
- pad_shape
|
||||
- pad_fixed_size
|
||||
- pad_size_divisor
|
||||
- valid_ratio
|
||||
|
||||
Args:
|
||||
width (int): Target width of padded image. Defaults to None.
|
||||
pad_cfg (dict): Config to construct the Resize transform. Refer to
|
||||
``Pad`` for detail. Defaults to ``dict(type='Pad')``.
|
||||
"""
|
||||
|
||||
def __init__(self, width: int, pad_cfg: dict = dict(type='Pad')) -> None:
|
||||
super().__init__()
|
||||
assert isinstance(width, int)
|
||||
self.width = width
|
||||
self.pad_cfg = pad_cfg
|
||||
_pad_cfg = self.pad_cfg.copy()
|
||||
_pad_cfg.update(dict(size=0))
|
||||
self.pad = TRANSFORMS.build(_pad_cfg)
|
||||
|
||||
def transform(self, results: Dict) -> Dict:
|
||||
"""Call function to pad images.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from loading pipeline.
|
||||
|
||||
Returns:
|
||||
dict: Updated result dict.
|
||||
"""
|
||||
ori_height, ori_width = results['img'].shape[:2]
|
||||
valid_ratio = min(1.0, 1.0 * ori_width / self.width)
|
||||
size = (self.width, ori_height)
|
||||
self.pad.size = size
|
||||
results = self.pad(results)
|
||||
results['valid_ratio'] = valid_ratio
|
||||
return results
|
||||
|
||||
def __repr__(self) -> str:
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(width={self.width}, '
|
||||
repr_str += f'pad_cfg={self.pad_cfg})'
|
||||
return repr_str
|
@ -13,7 +13,7 @@ from mmocr.utils import poly2bbox
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class ImgAug(BaseTransform):
|
||||
class ImgAugWrapper(BaseTransform):
|
||||
"""A wrapper around imgaug https://github.com/aleju/imgaug.
|
||||
|
||||
Find available augmenters at
|
||||
@ -167,7 +167,7 @@ class ImgAug(BaseTransform):
|
||||
return new_polys, removed_poly_inds
|
||||
|
||||
def _build_augmentation(self, args, root=True):
|
||||
"""Build ImgAug augmentations.
|
||||
"""Build ImgAugWrapper augmentations.
|
||||
|
||||
Args:
|
||||
args (dict): Arguments to be passed to imgaug.
|
@ -143,10 +143,10 @@ def eval_hmean_ic13(det_boxes,
|
||||
gt_point = np.array(gt_points[gt_id])
|
||||
det_point = np.array(pred_points[pred_id])
|
||||
|
||||
norm_dist = utils.box_center_distance(
|
||||
norm_dist = utils.bbox_center_distance(
|
||||
det_point, gt_point)
|
||||
norm_dist /= utils.box_diag(
|
||||
det_point) + utils.box_diag(gt_point)
|
||||
norm_dist /= utils.bbox_diag(
|
||||
det_point) + utils.bbox_diag(gt_point)
|
||||
norm_dist *= 2.0
|
||||
|
||||
if norm_dist < center_dist_thr:
|
||||
|
@ -1,8 +1,8 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmcv.utils import Registry, build_from_cfg
|
||||
|
||||
from .bbox_utils import (bbox2poly, bezier_to_polygon, box_center_distance,
|
||||
box_diag, is_on_same_line, rescale_bboxes,
|
||||
from .bbox_utils import (bbox2poly, bbox_center_distance, bbox_diag,
|
||||
bezier_to_polygon, is_on_same_line, rescale_bboxes,
|
||||
sort_points, sort_vertex, sort_vertex8,
|
||||
stitch_boxes_into_lines)
|
||||
from .check_argument import (equal_len, is_2dlist, is_3dlist, is_none_or_type,
|
||||
@ -13,6 +13,7 @@ from .evaluation_utils import (compute_hmean, filter_2dlist_result,
|
||||
many2one_match_ic13, one2one_match_ic13,
|
||||
select_top_boundary)
|
||||
from .fileio import list_from_file, list_to_file
|
||||
from .img_utils import crop_img, warp_img
|
||||
from .logger import get_root_logger
|
||||
from .mask_utils import fill_hole
|
||||
from .model import revert_sync_batchnorm
|
||||
@ -36,9 +37,9 @@ __all__ = [
|
||||
'rescale_bboxes', 'bbox2poly', 'crop_polygon', 'is_poly_inside_rect',
|
||||
'poly2bbox', 'poly_intersection', 'poly_iou', 'poly_make_valid',
|
||||
'poly_union', 'poly2shapely', 'polys2shapely', 'register_all_modules',
|
||||
'offset_polygon', 'sort_vertex8', 'sort_vertex', 'box_center_distance',
|
||||
'box_diag', 'compute_hmean', 'filter_2dlist_result', 'many2one_match_ic13',
|
||||
'one2one_match_ic13', 'select_top_boundary', 'boundary_iou',
|
||||
'point_distance', 'points_center', 'fill_hole', 'LineJsonParser',
|
||||
'LineStrParser', 'shapely2poly'
|
||||
'offset_polygon', 'sort_vertex8', 'sort_vertex', 'bbox_center_distance',
|
||||
'bbox_diag', 'compute_hmean', 'filter_2dlist_result',
|
||||
'many2one_match_ic13', 'one2one_match_ic13', 'select_top_boundary',
|
||||
'boundary_iou', 'point_distance', 'points_center', 'fill_hole',
|
||||
'LineJsonParser', 'LineStrParser', 'shapely2poly', 'crop_img', 'warp_img'
|
||||
]
|
||||
|
@ -4,6 +4,7 @@ from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import ArrayLike
|
||||
from shapely.geometry import LineString, Point
|
||||
|
||||
from mmocr.utils.check_argument import is_2dlist, is_type_list
|
||||
from mmocr.utils.point_utils import point_distance, points_center
|
||||
@ -329,16 +330,47 @@ def sort_vertex8(points):
|
||||
return sorted_box
|
||||
|
||||
|
||||
def box_center_distance(b1, b2):
|
||||
def bbox_center_distance(b1, b2):
|
||||
# TODO typehints & docstring & test
|
||||
assert isinstance(b1, np.ndarray)
|
||||
assert isinstance(b2, np.ndarray)
|
||||
return point_distance(points_center(b1), points_center(b2))
|
||||
|
||||
|
||||
def box_diag(box):
|
||||
def bbox_diag(box):
|
||||
# TODO typehints & docstring & test
|
||||
assert isinstance(box, np.ndarray)
|
||||
assert box.size == 8
|
||||
|
||||
return point_distance(box[0:2], box[4:6])
|
||||
|
||||
|
||||
def bbox_jitter(points_x, points_y, jitter_ratio_x=0.5, jitter_ratio_y=0.1):
|
||||
"""Jitter on the coordinates of bounding box.
|
||||
|
||||
Args:
|
||||
points_x (list[float | int]): List of y for four vertices.
|
||||
points_y (list[float | int]): List of x for four vertices.
|
||||
jitter_ratio_x (float): Horizontal jitter ratio relative to the height.
|
||||
jitter_ratio_y (float): Vertical jitter ratio relative to the height.
|
||||
"""
|
||||
assert len(points_x) == 4
|
||||
assert len(points_y) == 4
|
||||
assert isinstance(jitter_ratio_x, float)
|
||||
assert isinstance(jitter_ratio_y, float)
|
||||
assert 0 <= jitter_ratio_x < 1
|
||||
assert 0 <= jitter_ratio_y < 1
|
||||
|
||||
points = [Point(points_x[i], points_y[i]) for i in range(4)]
|
||||
line_list = [
|
||||
LineString([points[i], points[i + 1 if i < 3 else 0]])
|
||||
for i in range(4)
|
||||
]
|
||||
|
||||
tmp_h = max(line_list[1].length, line_list[3].length)
|
||||
|
||||
for i in range(4):
|
||||
jitter_pixel_x = (np.random.rand() - 0.5) * 2 * jitter_ratio_x * tmp_h
|
||||
jitter_pixel_y = (np.random.rand() - 0.5) * 2 * jitter_ratio_y * tmp_h
|
||||
points_x[i] += jitter_pixel_x
|
||||
points_y[i] += jitter_pixel_y
|
||||
|
@ -1,64 +1,40 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import cv2
|
||||
import numpy as np
|
||||
from mmcv.utils import is_seq_of
|
||||
from shapely.geometry import LineString, Point
|
||||
|
||||
import mmocr.utils as utils
|
||||
|
||||
|
||||
def box_jitter(points_x, points_y, jitter_ratio_x=0.5, jitter_ratio_y=0.1):
|
||||
"""Jitter on the coordinates of bounding box.
|
||||
|
||||
Args:
|
||||
points_x (list[float | int]): List of y for four vertices.
|
||||
points_y (list[float | int]): List of x for four vertices.
|
||||
jitter_ratio_x (float): Horizontal jitter ratio relative to the height.
|
||||
jitter_ratio_y (float): Vertical jitter ratio relative to the height.
|
||||
"""
|
||||
assert len(points_x) == 4
|
||||
assert len(points_y) == 4
|
||||
assert isinstance(jitter_ratio_x, float)
|
||||
assert isinstance(jitter_ratio_y, float)
|
||||
assert 0 <= jitter_ratio_x < 1
|
||||
assert 0 <= jitter_ratio_y < 1
|
||||
|
||||
points = [Point(points_x[i], points_y[i]) for i in range(4)]
|
||||
line_list = [
|
||||
LineString([points[i], points[i + 1 if i < 3 else 0]])
|
||||
for i in range(4)
|
||||
]
|
||||
|
||||
tmp_h = max(line_list[1].length, line_list[3].length)
|
||||
|
||||
for i in range(4):
|
||||
jitter_pixel_x = (np.random.rand() - 0.5) * 2 * jitter_ratio_x * tmp_h
|
||||
jitter_pixel_y = (np.random.rand() - 0.5) * 2 * jitter_ratio_y * tmp_h
|
||||
points_x[i] += jitter_pixel_x
|
||||
points_y[i] += jitter_pixel_y
|
||||
from .bbox_utils import bbox_jitter, sort_vertex
|
||||
|
||||
|
||||
def warp_img(src_img,
|
||||
box,
|
||||
jitter_flag=False,
|
||||
jitter=False,
|
||||
jitter_ratio_x=0.5,
|
||||
jitter_ratio_y=0.1):
|
||||
"""Crop box area from image using opencv warpPerspective w/o box jitter.
|
||||
"""Crop box area from image using opencv warpPerspective.
|
||||
|
||||
Args:
|
||||
src_img (np.array): Image before cropping.
|
||||
box (list[float | int]): Coordinates of quadrangle.
|
||||
jitter (bool): Whether to jitter the box.
|
||||
jitter_ratio_x (float): Horizontal jitter ratio relative to the height.
|
||||
jitter_ratio_y (float): Vertical jitter ratio relative to the height.
|
||||
|
||||
Returns:
|
||||
np.array: The warped image.
|
||||
"""
|
||||
assert utils.is_type_list(box, (float, int))
|
||||
assert is_seq_of(box, (float, int))
|
||||
assert len(box) == 8
|
||||
|
||||
h, w = src_img.shape[:2]
|
||||
points_x = [min(max(x, 0), w) for x in box[0:8:2]]
|
||||
points_y = [min(max(y, 0), h) for y in box[1:9:2]]
|
||||
|
||||
points_x, points_y = utils.sort_vertex(points_x, points_y)
|
||||
points_x, points_y = sort_vertex(points_x, points_y)
|
||||
|
||||
if jitter_flag:
|
||||
box_jitter(
|
||||
if jitter:
|
||||
bbox_jitter(
|
||||
points_x,
|
||||
points_y,
|
||||
jitter_ratio_x=jitter_ratio_x,
|
||||
@ -84,17 +60,24 @@ def warp_img(src_img,
|
||||
|
||||
|
||||
def crop_img(src_img, box, long_edge_pad_ratio=0.4, short_edge_pad_ratio=0.2):
|
||||
"""Crop text region with their bounding box.
|
||||
"""Crop text region given the bounding box which might be slightly padded.
|
||||
The bounding box is assumed to be a quadrangle and tightly bound the text
|
||||
region.
|
||||
|
||||
Args:
|
||||
src_img (np.array): The original image.
|
||||
box (list[float | int]): Points of quadrangle.
|
||||
long_edge_pad_ratio (float): Box pad ratio for long edge
|
||||
corresponding to font size.
|
||||
short_edge_pad_ratio (float): Box pad ratio for short edge
|
||||
corresponding to font size.
|
||||
long_edge_pad_ratio (float): The ratio of padding to the long edge. The
|
||||
padding will be the length of the short edge * long_edge_pad_ratio.
|
||||
Defaults to 0.4.
|
||||
short_edge_pad_ratio (float): The ratio of padding to the short edge.
|
||||
The padding will be the length of the long edge *
|
||||
short_edge_pad_ratio. Defaults to 0.2.
|
||||
|
||||
Returns:
|
||||
np.array: The cropped image.
|
||||
"""
|
||||
assert utils.is_type_list(box, (float, int))
|
||||
assert is_seq_of(box, (float, int))
|
||||
assert len(box) == 8
|
||||
assert 0. <= long_edge_pad_ratio < 1.0
|
||||
assert 0. <= short_edge_pad_ratio < 1.0
|
||||
@ -105,14 +88,14 @@ def crop_img(src_img, box, long_edge_pad_ratio=0.4, short_edge_pad_ratio=0.2):
|
||||
|
||||
box_width = np.max(points_x) - np.min(points_x)
|
||||
box_height = np.max(points_y) - np.min(points_y)
|
||||
font_size = min(box_height, box_width)
|
||||
shorter_size = min(box_height, box_width)
|
||||
|
||||
if box_height < box_width:
|
||||
horizontal_pad = long_edge_pad_ratio * font_size
|
||||
vertical_pad = short_edge_pad_ratio * font_size
|
||||
horizontal_pad = long_edge_pad_ratio * shorter_size
|
||||
vertical_pad = short_edge_pad_ratio * shorter_size
|
||||
else:
|
||||
horizontal_pad = short_edge_pad_ratio * font_size
|
||||
vertical_pad = long_edge_pad_ratio * font_size
|
||||
horizontal_pad = short_edge_pad_ratio * shorter_size
|
||||
vertical_pad = long_edge_pad_ratio * shorter_size
|
||||
|
||||
left = np.clip(int(np.min(points_x) - horizontal_pad), 0, w)
|
||||
top = np.clip(int(np.min(points_y) - vertical_pad), 0, h)
|
@ -22,13 +22,13 @@ except ImportError:
|
||||
from mmocr.apis import init_detector
|
||||
from mmocr.apis.inference import model_inference
|
||||
from mmocr.core.visualize import det_recog_show_result
|
||||
from mmocr.datasets.kie_dataset import KIEDataset
|
||||
from mmocr.datasets.pipelines.crop import crop_img
|
||||
from mmocr.datasets import WildReceiptDataset
|
||||
from mmocr.models.textdet.detectors import TextDetectorMixin
|
||||
from mmocr.models.textrecog.recognizers import BaseRecognizer
|
||||
from mmocr.registry import MODELS
|
||||
from mmocr.utils import is_type_list, stitch_boxes_into_lines
|
||||
from mmocr.utils.fileio import list_from_file
|
||||
from mmocr.utils.img_utils import crop_img
|
||||
from mmocr.utils.model import revert_sync_batchnorm
|
||||
|
||||
|
||||
@ -681,7 +681,7 @@ class MMOCR:
|
||||
bboxes_list = [res['boundary_result'] for res in det_result]
|
||||
|
||||
if kie_model:
|
||||
kie_dataset = KIEDataset(
|
||||
kie_dataset = WildReceiptDataset(
|
||||
dict_file=kie_model.cfg.data.test.dict_file)
|
||||
|
||||
# For each bounding box, the image is cropped and
|
||||
|
@ -1,128 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import json
|
||||
import math
|
||||
import os.path as osp
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from mmocr.datasets.ocr_seg_dataset import OCRSegDataset
|
||||
|
||||
|
||||
def _create_dummy_ann_file(ann_file):
|
||||
ann_info1 = {
|
||||
'file_name':
|
||||
'sample1.png',
|
||||
'annotations': [{
|
||||
'char_text':
|
||||
'F',
|
||||
'char_box': [11.0, 0.0, 22.0, 0.0, 12.0, 12.0, 0.0, 12.0]
|
||||
}, {
|
||||
'char_text':
|
||||
'r',
|
||||
'char_box': [23.0, 2.0, 31.0, 1.0, 24.0, 11.0, 16.0, 11.0]
|
||||
}, {
|
||||
'char_text':
|
||||
'o',
|
||||
'char_box': [33.0, 2.0, 43.0, 2.0, 36.0, 12.0, 25.0, 12.0]
|
||||
}, {
|
||||
'char_text':
|
||||
'm',
|
||||
'char_box': [46.0, 2.0, 61.0, 2.0, 53.0, 12.0, 39.0, 12.0]
|
||||
}, {
|
||||
'char_text':
|
||||
':',
|
||||
'char_box': [61.0, 2.0, 69.0, 2.0, 63.0, 12.0, 55.0, 12.0]
|
||||
}],
|
||||
'text':
|
||||
'From:'
|
||||
}
|
||||
ann_info2 = {
|
||||
'file_name':
|
||||
'sample2.png',
|
||||
'annotations': [{
|
||||
'char_text': 'o',
|
||||
'char_box': [0.0, 5.0, 7.0, 5.0, 9.0, 15.0, 2.0, 15.0]
|
||||
}, {
|
||||
'char_text':
|
||||
'u',
|
||||
'char_box': [7.0, 4.0, 14.0, 4.0, 18.0, 18.0, 11.0, 18.0]
|
||||
}, {
|
||||
'char_text':
|
||||
't',
|
||||
'char_box': [13.0, 1.0, 19.0, 2.0, 24.0, 18.0, 17.0, 18.0]
|
||||
}],
|
||||
'text':
|
||||
'out'
|
||||
}
|
||||
|
||||
with open(ann_file, 'w') as fw:
|
||||
for ann_info in [ann_info1, ann_info2]:
|
||||
fw.write(json.dumps(ann_info) + '\n')
|
||||
|
||||
return ann_info1, ann_info2
|
||||
|
||||
|
||||
def _create_dummy_loader():
|
||||
loader = dict(
|
||||
type='HardDiskLoader',
|
||||
repeat=1,
|
||||
parser=dict(
|
||||
type='LineJsonParser', keys=['file_name', 'text', 'annotations']))
|
||||
return loader
|
||||
|
||||
|
||||
def test_ocr_seg_dataset():
|
||||
tmp_dir = tempfile.TemporaryDirectory()
|
||||
# create dummy data
|
||||
ann_file = osp.join(tmp_dir.name, 'fake_data.txt')
|
||||
ann_info1, ann_info2 = _create_dummy_ann_file(ann_file)
|
||||
|
||||
# test initialization
|
||||
loader = _create_dummy_loader()
|
||||
dataset = OCRSegDataset(ann_file, loader, pipeline=[])
|
||||
|
||||
tmp_dir.cleanup()
|
||||
|
||||
# test pre_pipeline
|
||||
img_info = dataset.data_infos[0]
|
||||
results = dict(img_info=img_info)
|
||||
dataset.pre_pipeline(results)
|
||||
assert results['img_prefix'] == dataset.img_prefix
|
||||
|
||||
# test _parse_anno_info
|
||||
annos = ann_info1['annotations']
|
||||
with pytest.raises(AssertionError):
|
||||
dataset._parse_anno_info(annos[0])
|
||||
annos2 = ann_info2['annotations']
|
||||
with pytest.raises(AssertionError):
|
||||
dataset._parse_anno_info([{'char_text': 'i'}])
|
||||
with pytest.raises(AssertionError):
|
||||
dataset._parse_anno_info([{'char_box': [1, 2, 3, 4, 5, 6, 7, 8]}])
|
||||
annos2[0]['char_box'] = [1, 2, 3]
|
||||
with pytest.raises(AssertionError):
|
||||
dataset._parse_anno_info(annos2)
|
||||
|
||||
return_anno = dataset._parse_anno_info(annos)
|
||||
assert return_anno['chars'] == ['F', 'r', 'o', 'm', ':']
|
||||
assert len(return_anno['char_rects']) == 5
|
||||
|
||||
# test prepare_train_img
|
||||
expect_results = {
|
||||
'img_info': {
|
||||
'filename': 'sample1.png'
|
||||
},
|
||||
'img_prefix': '',
|
||||
'ann_info': return_anno
|
||||
}
|
||||
data = dataset.prepare_train_img(0)
|
||||
assert data == expect_results
|
||||
|
||||
# test evluation
|
||||
metric = 'acc'
|
||||
results = [{'text': 'From:'}, {'text': 'ou'}]
|
||||
eval_res = dataset.evaluate(results, metric)
|
||||
|
||||
assert math.isclose(eval_res['word_acc'], 0.5, abs_tol=1e-4)
|
||||
assert math.isclose(eval_res['char_precision'], 1.0, abs_tol=1e-4)
|
||||
assert math.isclose(eval_res['char_recall'], 0.857, abs_tol=1e-4)
|
@ -1,94 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mmocr.datasets.pipelines.ocr_seg_targets import OCRSegTargets
|
||||
|
||||
|
||||
def _create_dummy_dict_file(dict_file):
|
||||
chars = list('0123456789')
|
||||
with open(dict_file, 'w') as fw:
|
||||
for char in chars:
|
||||
fw.write(char + '\n')
|
||||
|
||||
|
||||
def test_ocr_segm_targets():
|
||||
tmp_dir = tempfile.TemporaryDirectory()
|
||||
# create dummy dict file
|
||||
dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
|
||||
_create_dummy_dict_file(dict_file)
|
||||
# dummy label convertor
|
||||
label_convertor = dict(
|
||||
type='SegConvertor',
|
||||
dict_file=dict_file,
|
||||
with_unknown=True,
|
||||
lower=True)
|
||||
# test init
|
||||
with pytest.raises(AssertionError):
|
||||
OCRSegTargets(None, 0.5, 0.5)
|
||||
with pytest.raises(AssertionError):
|
||||
OCRSegTargets(label_convertor, '1by2', 0.5)
|
||||
with pytest.raises(AssertionError):
|
||||
OCRSegTargets(label_convertor, 0.5, 2)
|
||||
|
||||
ocr_seg_tgt = OCRSegTargets(label_convertor, 0.5, 0.5)
|
||||
# test generate kernels
|
||||
img_size = (8, 8)
|
||||
pad_size = (8, 10)
|
||||
char_boxes = [[2, 2, 6, 6]]
|
||||
char_idxs = [2]
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
ocr_seg_tgt.generate_kernels(8, pad_size, char_boxes, char_idxs, 0.5,
|
||||
True)
|
||||
with pytest.raises(AssertionError):
|
||||
ocr_seg_tgt.generate_kernels(img_size, pad_size, [2, 2, 6, 6],
|
||||
char_idxs, 0.5, True)
|
||||
with pytest.raises(AssertionError):
|
||||
ocr_seg_tgt.generate_kernels(img_size, pad_size, char_boxes, 2, 0.5,
|
||||
True)
|
||||
|
||||
attn_tgt = ocr_seg_tgt.generate_kernels(
|
||||
img_size, pad_size, char_boxes, char_idxs, 0.5, binary=True)
|
||||
expect_attn_tgt = [[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
|
||||
[0, 0, 0, 1, 1, 1, 0, 0, 255, 255],
|
||||
[0, 0, 0, 1, 1, 1, 0, 0, 255, 255],
|
||||
[0, 0, 0, 1, 1, 1, 0, 0, 255, 255],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 255, 255]]
|
||||
assert np.allclose(attn_tgt, np.array(expect_attn_tgt, dtype=np.int32))
|
||||
|
||||
segm_tgt = ocr_seg_tgt.generate_kernels(
|
||||
img_size, pad_size, char_boxes, char_idxs, 0.5, binary=False)
|
||||
expect_segm_tgt = [[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
|
||||
[0, 0, 0, 2, 2, 2, 0, 0, 255, 255],
|
||||
[0, 0, 0, 2, 2, 2, 0, 0, 255, 255],
|
||||
[0, 0, 0, 2, 2, 2, 0, 0, 255, 255],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 255, 255]]
|
||||
assert np.allclose(segm_tgt, np.array(expect_segm_tgt, dtype=np.int32))
|
||||
|
||||
# test __call__
|
||||
results = {}
|
||||
results['img_shape'] = (4, 4, 3)
|
||||
results['resize_shape'] = (8, 8, 3)
|
||||
results['pad_shape'] = (8, 10)
|
||||
results['ann_info'] = {}
|
||||
results['ann_info']['char_rects'] = [[1, 1, 3, 3]]
|
||||
results['ann_info']['chars'] = ['1']
|
||||
|
||||
results = ocr_seg_tgt(results)
|
||||
assert results['mask_fields'] == ['gt_kernels']
|
||||
assert np.allclose(results['gt_kernels'].masks[0],
|
||||
np.array(expect_attn_tgt, dtype=np.int32))
|
||||
assert np.allclose(results['gt_kernels'].masks[1],
|
||||
np.array(expect_segm_tgt, dtype=np.int32))
|
||||
|
||||
tmp_dir.cleanup()
|
@ -1,141 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
import unittest.mock as mock
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.functional as TF
|
||||
from PIL import Image
|
||||
|
||||
import mmocr.datasets.pipelines.ocr_transforms as transforms
|
||||
|
||||
|
||||
def test_resize_ocr():
|
||||
input_img = np.ones((64, 256, 3), dtype=np.uint8)
|
||||
|
||||
rci = transforms.ResizeOCR(
|
||||
32, min_width=32, max_width=160, keep_aspect_ratio=True)
|
||||
results = {'img_shape': input_img.shape, 'img': input_img}
|
||||
|
||||
# test call
|
||||
results = rci(results)
|
||||
assert np.allclose([32, 160, 3], results['pad_shape'])
|
||||
assert np.allclose([32, 160, 3], results['img'].shape)
|
||||
assert 'valid_ratio' in results
|
||||
assert math.isclose(results['valid_ratio'], 0.8)
|
||||
assert math.isclose(np.sum(results['img'][:, 129:, :]), 0)
|
||||
|
||||
rci = transforms.ResizeOCR(
|
||||
32, min_width=32, max_width=160, keep_aspect_ratio=False)
|
||||
results = {'img_shape': input_img.shape, 'img': input_img}
|
||||
results = rci(results)
|
||||
assert math.isclose(results['valid_ratio'], 1)
|
||||
|
||||
|
||||
def test_to_tensor():
|
||||
input_img = np.ones((64, 256, 3), dtype=np.uint8)
|
||||
|
||||
expect_output = TF.to_tensor(input_img)
|
||||
rci = transforms.ToTensorOCR()
|
||||
|
||||
results = {'img': input_img}
|
||||
results = rci(results)
|
||||
|
||||
assert np.allclose(results['img'].numpy(), expect_output.numpy())
|
||||
|
||||
|
||||
def test_normalize():
|
||||
inputs = torch.zeros(3, 10, 10)
|
||||
|
||||
expect_output = torch.ones_like(inputs) * (-1)
|
||||
rci = transforms.NormalizeOCR(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||
|
||||
results = {'img': inputs}
|
||||
results = rci(results)
|
||||
|
||||
assert np.allclose(results['img'].numpy(), expect_output.numpy())
|
||||
|
||||
|
||||
@mock.patch('%s.transforms.np.random.random' % __name__)
|
||||
def test_online_crop(mock_random):
|
||||
kwargs = dict(
|
||||
box_keys=['x1', 'y1', 'x2', 'y2', 'x3', 'y3', 'x4', 'y4'],
|
||||
jitter_prob=0.5,
|
||||
max_jitter_ratio_x=0.05,
|
||||
max_jitter_ratio_y=0.02)
|
||||
|
||||
mock_random.side_effect = [0.1, 1, 1, 1]
|
||||
|
||||
src_img = np.ones((100, 100, 3), dtype=np.uint8)
|
||||
results = {
|
||||
'img': src_img,
|
||||
'img_info': {
|
||||
'x1': '20',
|
||||
'y1': '20',
|
||||
'x2': '40',
|
||||
'y2': '20',
|
||||
'x3': '40',
|
||||
'y3': '40',
|
||||
'x4': '20',
|
||||
'y4': '40'
|
||||
}
|
||||
}
|
||||
|
||||
rci = transforms.OnlineCropOCR(**kwargs)
|
||||
|
||||
results = rci(results)
|
||||
|
||||
assert np.allclose(results['img_shape'], [20, 20, 3])
|
||||
|
||||
# test not crop
|
||||
mock_random.side_effect = [0.1, 1, 1, 1]
|
||||
results['img_info'] = {}
|
||||
results['img'] = src_img
|
||||
|
||||
results = rci(results)
|
||||
assert np.allclose(results['img'].shape, [100, 100, 3])
|
||||
|
||||
|
||||
def test_fancy_pca():
|
||||
input_tensor = torch.rand(3, 32, 100)
|
||||
|
||||
rci = transforms.FancyPCA()
|
||||
|
||||
results = {'img': input_tensor}
|
||||
results = rci(results)
|
||||
|
||||
assert results['img'].shape == torch.Size([3, 32, 100])
|
||||
|
||||
|
||||
@mock.patch('%s.transforms.np.random.uniform' % __name__)
|
||||
def test_random_padding(mock_random):
|
||||
kwargs = dict(max_ratio=[0.0, 0.0, 0.0, 0.0], box_type=None)
|
||||
|
||||
mock_random.side_effect = [1, 1, 1, 1]
|
||||
|
||||
src_img = np.ones((32, 100, 3), dtype=np.uint8)
|
||||
results = {'img': src_img, 'img_shape': (32, 100, 3)}
|
||||
|
||||
rci = transforms.RandomPaddingOCR(**kwargs)
|
||||
|
||||
results = rci(results)
|
||||
print(results['img'].shape)
|
||||
assert np.allclose(results['img_shape'], [96, 300, 3])
|
||||
|
||||
|
||||
def test_opencv2pil():
|
||||
src_img = np.ones((32, 100, 3), dtype=np.uint8)
|
||||
results = {'img': src_img}
|
||||
rci = transforms.OpencvToPil()
|
||||
|
||||
results = rci(results)
|
||||
assert np.allclose(results['img'].size, (100, 32))
|
||||
|
||||
|
||||
def test_pil2opencv():
|
||||
src_img = Image.new('RGB', (100, 32), color=(255, 255, 255))
|
||||
results = {'img': src_img}
|
||||
rci = transforms.PilToOpencv()
|
||||
|
||||
results = rci(results)
|
||||
assert np.allclose(results['img'].shape, (32, 100, 3))
|
@ -1,33 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import numpy as np
|
||||
|
||||
from mmocr.datasets.pipelines.textdet_targets.dbnet_targets import DBNetTargets
|
||||
|
||||
|
||||
def test_invalid_polys():
|
||||
|
||||
dbtarget = DBNetTargets()
|
||||
|
||||
poly = np.array([[256.1229216, 347.17471155], [257.63126133, 347.0069367],
|
||||
[257.70317729, 347.65337423],
|
||||
[256.19488113, 347.82114909]])
|
||||
|
||||
assert dbtarget.invalid_polygon(poly)
|
||||
|
||||
poly = np.array([[570.34735492,
|
||||
335.00214526], [570.99778839, 335.00327318],
|
||||
[569.69077318, 338.47009908],
|
||||
[569.04038393, 338.46894904]])
|
||||
assert dbtarget.invalid_polygon(poly)
|
||||
|
||||
poly = np.array([[481.18343777,
|
||||
305.03190065], [479.88478587, 305.10684512],
|
||||
[479.90976971, 305.53968843], [480.99197962,
|
||||
305.4772347]])
|
||||
assert dbtarget.invalid_polygon(poly)
|
||||
|
||||
poly = np.array([[0, 0], [2, 0], [2, 2], [0, 2]])
|
||||
assert dbtarget.invalid_polygon(poly)
|
||||
|
||||
poly = np.array([[0, 0], [10, 0], [10, 10], [0, 10]])
|
||||
assert not dbtarget.invalid_polygon(poly)
|
@ -5,8 +5,8 @@ from unittest import TestCase
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from mmocr.datasets.pipelines import (PackKIEInputs, PackTextDetInputs,
|
||||
PackTextRecogInputs)
|
||||
from mmocr.datasets.transforms import (PackKIEInputs, PackTextDetInputs,
|
||||
PackTextRecogInputs)
|
||||
|
||||
|
||||
class TestPackTextDetInputs(TestCase):
|
@ -4,7 +4,7 @@ from unittest import TestCase
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mmocr.datasets.pipelines import LoadKIEAnnotations, LoadOCRAnnotations
|
||||
from mmocr.datasets.transforms import LoadKIEAnnotations, LoadOCRAnnotations
|
||||
|
||||
|
||||
class TestLoadOCRAnnotations(TestCase):
|
208
tests/test_datasets/test_transforms/test_ocr_transforms.py
Normal file
208
tests/test_datasets/test_transforms/test_ocr_transforms.py
Normal file
@ -0,0 +1,208 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mmocr.datasets.transforms import RandomCrop, RandomRotate, Resize
|
||||
|
||||
|
||||
class TestRandomCrop(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
img = np.zeros((30, 30, 3))
|
||||
gt_polygons = [
|
||||
np.array([5., 5., 25., 5., 25., 10., 5., 10.]),
|
||||
np.array([5., 20., 25., 20., 25., 25., 5., 25.])
|
||||
]
|
||||
gt_bboxes = np.array([[5, 5, 25, 10], [5, 20, 25, 25]])
|
||||
labels = np.array([0, 1])
|
||||
gt_ignored = np.array([True, False], dtype=bool)
|
||||
texts = ['text1', 'text2']
|
||||
self.data_info = dict(
|
||||
img=img,
|
||||
gt_polygons=gt_polygons,
|
||||
gt_bboxes=gt_bboxes,
|
||||
gt_bboxes_labels=labels,
|
||||
gt_ignored=gt_ignored,
|
||||
gt_texts=texts)
|
||||
|
||||
@mock.patch('mmocr.datasets.transforms.ocr_transforms.np.random.randint')
|
||||
def test_sample_crop_box(self, mock_randint):
|
||||
|
||||
def rand_min(low, high):
|
||||
return low
|
||||
|
||||
trans = RandomCrop(min_side_ratio=0.3)
|
||||
mock_randint.side_effect = rand_min
|
||||
crop_box = trans._sample_crop_box((30, 30), self.data_info.copy())
|
||||
assert np.allclose(np.array(crop_box), np.array([0, 0, 25, 10]))
|
||||
|
||||
def rand_max(low, high):
|
||||
return high - 1
|
||||
|
||||
mock_randint.side_effect = rand_max
|
||||
crop_box = trans._sample_crop_box((30, 30), self.data_info.copy())
|
||||
assert np.allclose(np.array(crop_box), np.array([4, 19, 30, 30]))
|
||||
|
||||
@mock.patch('mmocr.datasets.transforms.ocr_transforms.np.random.randint')
|
||||
def test_transform(self, mock_randint):
|
||||
|
||||
def rand_min(low, high):
|
||||
return low
|
||||
|
||||
# mock_randint.side_effect = [0, 0, 0, 0, 30, 0, 0, 0, 15]
|
||||
mock_randint.side_effect = rand_min
|
||||
trans = RandomCrop(min_side_ratio=0.3)
|
||||
polygon_target = np.array([5., 5., 25., 5., 25., 10., 5., 10.])
|
||||
bbox_target = np.array([[5., 5., 25., 10.]])
|
||||
results = trans(self.data_info)
|
||||
|
||||
self.assertEqual(results['img'].shape, (10, 25, 3))
|
||||
self.assertEqual(results['img_shape'], (10, 25))
|
||||
self.assertTrue(np.allclose(results['gt_bboxes'], bbox_target))
|
||||
self.assertEqual(results['gt_bboxes'].shape, (1, 4))
|
||||
self.assertEqual(len(results['gt_polygons']), 1)
|
||||
self.assertTrue(np.allclose(results['gt_polygons'][0], polygon_target))
|
||||
self.assertEqual(results['gt_bboxes_labels'][0], 0)
|
||||
self.assertEqual(results['gt_ignored'][0], True)
|
||||
self.assertEqual(results['gt_texts'][0], 'text1')
|
||||
|
||||
def rand_max(low, high):
|
||||
return high - 1
|
||||
|
||||
mock_randint.side_effect = rand_max
|
||||
trans = RandomCrop(min_side_ratio=0.3)
|
||||
polygon_target = np.array([1, 1, 21, 1, 21, 6, 1, 6])
|
||||
bbox_target = np.array([[1, 1, 21, 6]])
|
||||
results = trans(self.data_info)
|
||||
|
||||
self.assertEqual(results['img'].shape, (6, 21, 3))
|
||||
self.assertEqual(results['img_shape'], (6, 21))
|
||||
self.assertTrue(np.allclose(results['gt_bboxes'], bbox_target))
|
||||
self.assertEqual(results['gt_bboxes'].shape, (1, 4))
|
||||
self.assertEqual(len(results['gt_polygons']), 1)
|
||||
self.assertTrue(np.allclose(results['gt_polygons'][0], polygon_target))
|
||||
self.assertEqual(results['gt_bboxes_labels'][0], 0)
|
||||
self.assertTrue(results['gt_ignored'][0])
|
||||
self.assertEqual(results['gt_texts'][0], 'text1')
|
||||
|
||||
def test_repr(self):
|
||||
transform = RandomCrop(min_side_ratio=0.4)
|
||||
self.assertEqual(repr(transform), ('RandomCrop(min_side_ratio = 0.4)'))
|
||||
|
||||
|
||||
class TestRandomRotate(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
img = np.random.random((5, 5))
|
||||
self.data_info1 = dict(img=img.copy(), img_shape=img.shape[:2])
|
||||
self.data_info2 = dict(
|
||||
img=np.random.random((30, 30, 3)),
|
||||
gt_bboxes=np.array([[10, 10, 20, 20], [5, 5, 10, 10]]),
|
||||
img_shape=(30, 30))
|
||||
self.data_info3 = dict(
|
||||
img=np.random.random((30, 30, 3)),
|
||||
gt_polygons=[np.array([10., 10., 20., 10., 20., 20., 10., 20.])],
|
||||
img_shape=(30, 30))
|
||||
|
||||
def test_init(self):
|
||||
# max angle is float
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'`max_angle` should be an integer'):
|
||||
RandomRotate(max_angle=16.8)
|
||||
# invalid pad value
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, '`pad_value` should contain three integers'):
|
||||
RandomRotate(pad_value=[16.8, 0.1])
|
||||
|
||||
def test_transform(self):
|
||||
self._test_recog()
|
||||
self._test_bboxes()
|
||||
self._test_polygons()
|
||||
|
||||
def _test_recog(self):
|
||||
# test random rotate for recognition (image only) input
|
||||
transform = RandomRotate(max_angle=10)
|
||||
results = transform(copy.deepcopy(self.data_info1))
|
||||
self.assertTrue(np.allclose(results['img'], self.data_info1['img']))
|
||||
|
||||
@mock.patch(
|
||||
'mmocr.datasets.transforms.ocr_transforms.np.random.random_sample')
|
||||
def _test_bboxes(self, mock_sample):
|
||||
# test random rotate for bboxes
|
||||
# returns 1. for random_sample() in _sample_angle(), i.e., angle = 90
|
||||
mock_sample.side_effect = [1.]
|
||||
transform = RandomRotate(max_angle=90, use_canvas=True)
|
||||
results = transform(copy.deepcopy(self.data_info2))
|
||||
self.assertTrue(
|
||||
np.allclose(results['gt_bboxes'][0], np.array([10, 10, 20, 20])))
|
||||
self.assertTrue(
|
||||
np.allclose(results['gt_bboxes'][1], np.array([5, 20, 10, 25])))
|
||||
self.assertEqual(results['img'].shape, self.data_info2['img'].shape)
|
||||
|
||||
@mock.patch(
|
||||
'mmocr.datasets.transforms.ocr_transforms.np.random.random_sample')
|
||||
def _test_polygons(self, mock_sample):
|
||||
# test random rotate for polygons
|
||||
# returns 1. for random_sample() in _sample_angle(), i.e., angle = 90
|
||||
mock_sample.side_effect = [1.]
|
||||
transform = RandomRotate(max_angle=90, use_canvas=True)
|
||||
results = transform(copy.deepcopy(self.data_info3))
|
||||
self.assertTrue(
|
||||
np.allclose(results['gt_polygons'][0],
|
||||
np.array([10., 20., 10., 10., 20., 10., 20., 20.])))
|
||||
self.assertEqual(results['img'].shape, self.data_info3['img'].shape)
|
||||
|
||||
def test_repr(self):
|
||||
transform = RandomRotate(
|
||||
max_angle=10,
|
||||
pad_with_fixed_color=False,
|
||||
pad_value=(0, 0, 0),
|
||||
use_canvas=False)
|
||||
self.assertEqual(
|
||||
repr(transform),
|
||||
('RandomRotate(max_angle = 10, '
|
||||
'pad_with_fixed_color = False, pad_value = (0, 0, 0), '
|
||||
'use_canvas = False)'))
|
||||
|
||||
|
||||
class TestResize(unittest.TestCase):
|
||||
|
||||
def test_resize_wo_img(self):
|
||||
# keep_ratio = True
|
||||
dummy_result = dict(img_shape=(10, 20))
|
||||
resize = Resize(scale=(40, 30), keep_ratio=True)
|
||||
result = resize(dummy_result)
|
||||
self.assertEqual(result['img_shape'], (20, 40))
|
||||
self.assertEqual(result['scale'], (40, 20))
|
||||
self.assertEqual(result['scale_factor'], (2., 2.))
|
||||
self.assertEqual(result['keep_ratio'], True)
|
||||
|
||||
# keep_ratio = False
|
||||
dummy_result = dict(img_shape=(10, 20))
|
||||
resize = Resize(scale=(40, 30), keep_ratio=False)
|
||||
result = resize(dummy_result)
|
||||
self.assertEqual(result['img_shape'], (30, 40))
|
||||
self.assertEqual(result['scale'], (40, 30))
|
||||
self.assertEqual(result['scale_factor'], (
|
||||
2.,
|
||||
3.,
|
||||
))
|
||||
self.assertEqual(result['keep_ratio'], False)
|
||||
|
||||
def test_resize_bbox(self):
|
||||
# keep_ratio = True
|
||||
dummy_result = dict(
|
||||
img_shape=(10, 20),
|
||||
gt_bboxes=np.array([[0, 0, 1, 1]], dtype=np.float32))
|
||||
resize = Resize(scale=(40, 30))
|
||||
result = resize(dummy_result)
|
||||
self.assertEqual(result['gt_bboxes'].dtype, np.float32)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
t = TestRandomCrop()
|
||||
t.test_sample_crop_box()
|
||||
t.test_transform()
|
File diff suppressed because it is too large
Load Diff
127
tests/test_datasets/test_transforms/test_textrecog_transforms.py
Normal file
127
tests/test_datasets/test_transforms/test_textrecog_transforms.py
Normal file
@ -0,0 +1,127 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mmocr.datasets.transforms import (PadToWidth, PyramidRescale,
|
||||
RescaleToHeight)
|
||||
|
||||
|
||||
class TestPadToWidth(unittest.TestCase):
|
||||
|
||||
def test_pad_to_width(self):
|
||||
data_info = dict(img=np.random.random((16, 25, 3)))
|
||||
# test size and size_divisor are both set
|
||||
with self.assertRaises(AssertionError):
|
||||
PadToWidth(width=10.5)
|
||||
|
||||
transform = PadToWidth(width=100)
|
||||
results = transform(copy.deepcopy(data_info))
|
||||
self.assertTupleEqual(results['img'].shape[:2], (16, 100))
|
||||
self.assertEqual(results['valid_ratio'], 25 / 100)
|
||||
|
||||
def test_repr(self):
|
||||
transform = PadToWidth(width=100)
|
||||
self.assertEqual(
|
||||
repr(transform),
|
||||
("PadToWidth(width=100, pad_cfg={'type': 'Pad'})"))
|
||||
|
||||
|
||||
class TestPyramidRescale(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.data_info = dict(img=np.random.random((128, 100, 3)))
|
||||
|
||||
def test_init(self):
|
||||
# factor is int
|
||||
transform = PyramidRescale(factor=4, randomize_factor=False)
|
||||
self.assertEqual(transform.factor, 4)
|
||||
# factor is float
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'`factor` should be an integer'):
|
||||
PyramidRescale(factor=4.0)
|
||||
# invalid base_shape
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'`base_shape` should be a list or tuple'):
|
||||
PyramidRescale(base_shape=128)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, '`base_shape` should contain two integers'):
|
||||
PyramidRescale(base_shape=(128, ))
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, '`base_shape` should contain two integers'):
|
||||
PyramidRescale(base_shape=(128.0, 2.0))
|
||||
# invalid randomize_factor
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'`randomize_factor` should be a bool'):
|
||||
PyramidRescale(randomize_factor=None)
|
||||
|
||||
def test_transform(self):
|
||||
# test if the rescale keeps the original size
|
||||
transform = PyramidRescale()
|
||||
results = transform(copy.deepcopy(self.data_info))
|
||||
self.assertEqual(results['img'].shape, (128, 100, 3))
|
||||
# test factor = 0
|
||||
transform = PyramidRescale(factor=0, randomize_factor=False)
|
||||
results = transform(copy.deepcopy(self.data_info))
|
||||
self.assertTrue(np.all(results['img'] == self.data_info['img']))
|
||||
|
||||
def test_repr(self):
|
||||
transform = PyramidRescale(
|
||||
factor=4, base_shape=(128, 512), randomize_factor=False)
|
||||
self.assertEqual(
|
||||
repr(transform),
|
||||
('PyramidRescale(factor = 4, randomize_factor = False, '
|
||||
'base_w = 128, base_h = 512)'))
|
||||
|
||||
|
||||
class TestRescaleToHeight(unittest.TestCase):
|
||||
|
||||
def test_rescale_height(self):
|
||||
data_info = dict(
|
||||
img=np.random.random((16, 25, 3)),
|
||||
gt_seg_map=np.random.random((16, 25, 3)),
|
||||
gt_bboxes=np.array([[0, 0, 10, 10]]),
|
||||
gt_keypoints=np.array([[[10, 10, 1]]]))
|
||||
with self.assertRaises(AssertionError):
|
||||
RescaleToHeight(height=20.9)
|
||||
with self.assertRaises(AssertionError):
|
||||
RescaleToHeight(height=20, min_width=20.9)
|
||||
with self.assertRaises(AssertionError):
|
||||
RescaleToHeight(height=20, max_width=20.9)
|
||||
with self.assertRaises(AssertionError):
|
||||
RescaleToHeight(height=20, width_divisor=0.5)
|
||||
transform = RescaleToHeight(height=32)
|
||||
results = transform(copy.deepcopy(data_info))
|
||||
self.assertTupleEqual(results['img'].shape[:2], (32, 50))
|
||||
self.assertTupleEqual(results['scale'], (50, 32))
|
||||
self.assertTupleEqual(results['scale_factor'], (50 / 25, 32 / 16))
|
||||
|
||||
# test min_width
|
||||
transform = RescaleToHeight(height=32, min_width=60)
|
||||
results = transform(copy.deepcopy(data_info))
|
||||
self.assertTupleEqual(results['img'].shape[:2], (32, 60))
|
||||
self.assertTupleEqual(results['scale'], (60, 32))
|
||||
self.assertTupleEqual(results['scale_factor'], (60 / 25, 32 / 16))
|
||||
|
||||
# test max_width
|
||||
transform = RescaleToHeight(height=32, max_width=45)
|
||||
results = transform(copy.deepcopy(data_info))
|
||||
self.assertTupleEqual(results['img'].shape[:2], (32, 45))
|
||||
self.assertTupleEqual(results['scale'], (45, 32))
|
||||
self.assertTupleEqual(results['scale_factor'], (45 / 25, 32 / 16))
|
||||
|
||||
# test width_divisor
|
||||
transform = RescaleToHeight(height=32, width_divisor=4)
|
||||
results = transform(copy.deepcopy(data_info))
|
||||
self.assertTupleEqual(results['img'].shape[:2], (32, 48))
|
||||
self.assertTupleEqual(results['scale'], (48, 32))
|
||||
self.assertTupleEqual(results['scale_factor'], (48 / 25, 32 / 16))
|
||||
|
||||
def test_repr(self):
|
||||
transform = RescaleToHeight(height=32)
|
||||
self.assertEqual(
|
||||
repr(transform), ('RescaleToHeight(height=32, '
|
||||
'min_width=None, max_width=None, '
|
||||
'width_divisor=1, '
|
||||
"resize_cfg={'type': 'Resize'})"))
|
@ -6,16 +6,16 @@ from typing import Dict, List, Optional
|
||||
import numpy as np
|
||||
from shapely.geometry import Polygon
|
||||
|
||||
from mmocr.datasets.pipelines import ImgAug, TorchVisionWrapper
|
||||
from mmocr.datasets.transforms import ImgAugWrapper, TorchVisionWrapper
|
||||
|
||||
|
||||
class TestImgAug(unittest.TestCase):
|
||||
|
||||
def test_init(self):
|
||||
with self.assertRaises(AssertionError):
|
||||
ImgAug(args=[])
|
||||
ImgAugWrapper(args=[])
|
||||
with self.assertRaises(AssertionError):
|
||||
ImgAug(args=['test'])
|
||||
ImgAugWrapper(args=['test'])
|
||||
|
||||
def _create_dummy_data(self):
|
||||
img = np.random.rand(50, 50, 3)
|
||||
@ -61,7 +61,7 @@ class TestImgAug(unittest.TestCase):
|
||||
def test_transform(self):
|
||||
|
||||
# Test empty transform
|
||||
imgaug_transform = ImgAug()
|
||||
imgaug_transform = ImgAugWrapper()
|
||||
results = self._create_dummy_data()
|
||||
origin_results = copy.deepcopy(results)
|
||||
results = imgaug_transform(results)
|
||||
@ -72,7 +72,7 @@ class TestImgAug(unittest.TestCase):
|
||||
origin_results['gt_texts'])
|
||||
|
||||
args = [dict(cls='Affine', translate_px=dict(x=-10, y=-10))]
|
||||
imgaug_transform = ImgAug(args)
|
||||
imgaug_transform = ImgAugWrapper(args)
|
||||
results = self._create_dummy_data()
|
||||
results = imgaug_transform(results)
|
||||
|
||||
@ -99,7 +99,7 @@ class TestImgAug(unittest.TestCase):
|
||||
label_target = np.array([0], dtype=np.int64)
|
||||
ignored = np.array([False], dtype=bool)
|
||||
texts = ['text1']
|
||||
imgaug_transform = ImgAug(args)
|
||||
imgaug_transform = ImgAugWrapper(args)
|
||||
results = self._create_dummy_data()
|
||||
results = imgaug_transform(results)
|
||||
self.assert_result_equal(results, poly_target, box_target,
|
||||
@ -111,7 +111,7 @@ class TestImgAug(unittest.TestCase):
|
||||
# When some transforms result in empty polygons
|
||||
args = [dict(cls='Affine', translate_px=dict(x=100, y=100))]
|
||||
results = self._create_dummy_data()
|
||||
invalid_transform = ImgAug(args)
|
||||
invalid_transform = ImgAugWrapper(args)
|
||||
results = invalid_transform(results)
|
||||
self.assertIsNone(results)
|
||||
|
||||
@ -131,11 +131,12 @@ class TestImgAug(unittest.TestCase):
|
||||
|
||||
def test_repr(self):
|
||||
args = [['Resize', [0.5, 3.0]], ['Fliplr', 0.5]]
|
||||
transform = ImgAug(args)
|
||||
transform = ImgAugWrapper(args)
|
||||
print(repr(transform))
|
||||
self.assertEqual(
|
||||
repr(transform),
|
||||
("ImgAug(args = [['Resize', [0.5, 3.0]], ['Fliplr', 0.5]])"))
|
||||
("ImgAugWrapper(args = [['Resize', [0.5, 3.0]], ['Fliplr', 0.5]])"
|
||||
))
|
||||
|
||||
|
||||
class TestTorchVisionWrapper(unittest.TestCase):
|
@ -8,8 +8,8 @@ from functools import partial
|
||||
|
||||
import mmcv
|
||||
|
||||
from mmocr.datasets.pipelines.crop import crop_img, warp_img
|
||||
from mmocr.utils import list_to_file
|
||||
from mmocr.utils.img_utils import crop_img, warp_img
|
||||
|
||||
|
||||
def parse_labelme_json(json_file,
|
||||
|
@ -6,8 +6,8 @@ import os.path as osp
|
||||
|
||||
import mmcv
|
||||
|
||||
from mmocr.datasets.pipelines.crop import crop_img
|
||||
from mmocr.utils.fileio import list_to_file
|
||||
from mmocr.utils.img_utils import crop_img
|
||||
|
||||
|
||||
def collect_files(img_dir, gt_dir):
|
||||
|
@ -7,8 +7,8 @@ import os.path as osp
|
||||
import mmcv
|
||||
import numpy as np
|
||||
|
||||
from mmocr.datasets.pipelines.crop import crop_img
|
||||
from mmocr.utils.fileio import list_to_file
|
||||
from mmocr.utils.img_utils import crop_img
|
||||
|
||||
|
||||
def collect_files(img_dir, gt_dir):
|
||||
|
@ -7,8 +7,8 @@ import os.path as osp
|
||||
|
||||
import mmcv
|
||||
|
||||
from mmocr.datasets.pipelines.crop import crop_img
|
||||
from mmocr.utils.fileio import list_to_file
|
||||
from mmocr.utils.img_utils import crop_img
|
||||
|
||||
|
||||
def collect_files(img_dir, gt_dir):
|
||||
|
@ -7,8 +7,8 @@ import xml.etree.ElementTree as ET
|
||||
|
||||
import mmcv
|
||||
|
||||
from mmocr.datasets.pipelines.crop import crop_img
|
||||
from mmocr.utils.fileio import list_to_file
|
||||
from mmocr.utils.img_utils import crop_img
|
||||
|
||||
|
||||
def collect_files(img_dir, gt_dir):
|
||||
|
@ -8,8 +8,8 @@ import os.path as osp
|
||||
import mmcv
|
||||
import numpy as np
|
||||
|
||||
from mmocr.datasets.pipelines.crop import crop_img
|
||||
from mmocr.utils.fileio import list_to_file
|
||||
from mmocr.utils.img_utils import crop_img
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
@ -8,8 +8,8 @@ import xml.etree.ElementTree as ET
|
||||
|
||||
import mmcv
|
||||
|
||||
from mmocr.datasets.pipelines.crop import crop_img
|
||||
from mmocr.utils.fileio import list_to_file
|
||||
from mmocr.utils.img_utils import crop_img
|
||||
|
||||
|
||||
def collect_files(img_dir, gt_dir, ratio):
|
||||
|
@ -9,8 +9,8 @@ import cv2
|
||||
import mmcv
|
||||
from PIL import Image
|
||||
|
||||
from mmocr.datasets.pipelines.crop import crop_img
|
||||
from mmocr.utils.fileio import list_to_file
|
||||
from mmocr.utils.img_utils import crop_img
|
||||
|
||||
|
||||
def collect_files(img_dir, gt_dir, ratio):
|
||||
|
@ -6,8 +6,8 @@ import os.path as osp
|
||||
import mmcv
|
||||
import numpy as np
|
||||
|
||||
from mmocr.datasets.pipelines.crop import crop_img
|
||||
from mmocr.utils.fileio import list_to_file
|
||||
from mmocr.utils.img_utils import crop_img
|
||||
|
||||
|
||||
def collect_files(img_dir, gt_dir, split_info):
|
||||
|
@ -7,8 +7,8 @@ import os.path as osp
|
||||
|
||||
import mmcv
|
||||
|
||||
from mmocr.datasets.pipelines.crop import crop_img
|
||||
from mmocr.utils.fileio import list_to_file
|
||||
from mmocr.utils.img_utils import crop_img
|
||||
|
||||
|
||||
def collect_files(img_dir, gt_dir, ratio):
|
||||
|
@ -7,8 +7,8 @@ import os.path as osp
|
||||
|
||||
import mmcv
|
||||
|
||||
from mmocr.datasets.pipelines.crop import crop_img
|
||||
from mmocr.utils.fileio import list_to_file
|
||||
from mmocr.utils.img_utils import crop_img
|
||||
|
||||
|
||||
def collect_files(img_dir, gt_dir, ratio):
|
||||
|
@ -7,8 +7,8 @@ import os.path as osp
|
||||
import mmcv
|
||||
import numpy as np
|
||||
|
||||
from mmocr.datasets.pipelines.crop import crop_img
|
||||
from mmocr.utils.fileio import list_to_file
|
||||
from mmocr.utils.img_utils import crop_img
|
||||
|
||||
|
||||
def collect_files(img_dir, gt_dir):
|
||||
|
@ -11,8 +11,8 @@ import scipy.io as scio
|
||||
import yaml
|
||||
from shapely.geometry import Polygon
|
||||
|
||||
from mmocr.datasets.pipelines.crop import crop_img
|
||||
from mmocr.utils.fileio import list_to_file
|
||||
from mmocr.utils.img_utils import crop_img
|
||||
|
||||
|
||||
def collect_files(img_dir, gt_dir):
|
||||
|
@ -6,8 +6,8 @@ import os.path as osp
|
||||
|
||||
import mmcv
|
||||
|
||||
from mmocr.datasets.pipelines.crop import crop_img
|
||||
from mmocr.utils.fileio import list_to_file
|
||||
from mmocr.utils.img_utils import crop_img
|
||||
|
||||
|
||||
def collect_files(img_dir, gt_dir):
|
||||
|
Loading…
x
Reference in New Issue
Block a user