diff --git a/configs/textdet/fcenet/README.MD b/configs/textdet/fcenet/README.MD
new file mode 100644
index 00000000..5e111a33
--- /dev/null
+++ b/configs/textdet/fcenet/README.MD
@@ -0,0 +1,32 @@
+# Fourier Contour Embedding for Arbitrary-Shaped Text Detection
+
+## Introduction
+
+[ALGORITHM]
+
+```bibtex
+@InProceedings{zhu2021fourier,
+  title={Fourier Contour Embedding for Arbitrary-Shaped Text Detection},
+  author={Yiqin Zhu and Jianyong Chen and Lingyu Liang and Zhanghui Kuang and Lianwen Jin and Wayne Zhang},
+  year={2021},
+  booktitle={CVPR}
+}
+```
+
+## Results and models
+
+### CTW1500
+
+| Method | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download |
+| :----: | :--------------: | :----------: | :------: | :-----: | :-------: | :----: | :-------: | :---: | :------: |
+| [FCENet](/configs/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py) | ImageNet | CTW1500 Train | CTW1500 Test | 1500 | (736, 1080) | 0.828 | 0.875 | 0.851 | [model](https://download.openmmlab.com/mmocr/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500-05d740bb.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/fcenet/20210511_181328.log.json) |
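+
+## Usage
+
+A minimal sketch of training and testing with the stock MMOCR tool scripts
+(assumed entry points; adjust the checkpoint path to your setup):
+
+```shell
+python tools/train.py configs/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py
+python tools/test.py configs/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py ${CHECKPOINT} --eval hmean-iou
+```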
diff --git a/configs/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py b/configs/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py
new file mode 100644
index 00000000..5c5e048d
--- /dev/null
+++ b/configs/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py
@@ -0,0 +1,134 @@
+fourier_degree = 5
+model = dict(
+    type='FCENet',
+    pretrained='torchvision://resnet50',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(1, 2, 3),
+        frozen_stages=-1,
+        norm_cfg=dict(type='BN', requires_grad=True),
+        norm_eval=True,
+        style='pytorch',
+        dcn=dict(type='DCNv2', deform_groups=2, fallback_on_stride=False),
+        stage_with_dcn=(False, True, True, True)),
+    neck=dict(
+        type='FPN',
+        in_channels=[512, 1024, 2048],
+        out_channels=256,
+        add_extra_convs=True,
+        extra_convs_on_inputs=False,  # use P5
+        num_outs=3,
+        relu_before_extra_convs=True,
+        act_cfg=None),
+    bbox_head=dict(
+        type='FCEHead',
+        in_channels=256,
+        scales=(8, 16, 32),
+        loss=dict(type='FCELoss'),
+        fourier_degree=fourier_degree,
+    ))
+
+train_cfg = None
+test_cfg = None
+
+dataset_type = 'IcdarDataset'
+data_root = 'data/ctw1500/'
+
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='LoadTextAnnotations',
+        with_bbox=True,
+        with_mask=True,
+        poly2mask=False),
+    dict(
+        type='ColorJitter',
+        brightness=32.0 / 255,
+        saturation=0.5,
+        contrast=0.5),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='RandomScaling', size=800, scale=(3. / 4, 5. / 2)),
+    dict(
+        type='RandomCropFlip', crop_ratio=0.5, iter_num=1, min_area_ratio=0.2),
+    dict(
+        type='RandomCropPolyInstances',
+        instance_key='gt_masks',
+        crop_ratio=0.8,
+        min_side_ratio=0.3),
+    dict(
+        type='RandomRotatePolyInstances',
+        rotate_ratio=0.5,
+        max_angle=30,
+        pad_with_fixed_color=False),
+    dict(type='SquareResizePad', target_size=800, pad_ratio=0.6),
+    dict(type='RandomFlip', flip_ratio=0.5, direction='horizontal'),
+    dict(type='Pad', size_divisor=32),
+    dict(
+        type='FCENetTargets',
+        fourier_degree=fourier_degree,
+        level_proportion_range=((0, 0.25), (0.2, 0.65), (0.55, 1.0))),
+    dict(
+        type='CustomFormatBundle',
+        keys=['p3_maps', 'p4_maps', 'p5_maps'],
+        visualize=dict(flag=False, boundary_key=None)),
+    dict(type='Collect', keys=['img', 'p3_maps', 'p4_maps', 'p5_maps'])
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(1080, 736),
+        flip=False,
+        transforms=[
+            dict(type='Resize', img_scale=(1280, 800), keep_ratio=True),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='Pad', size_divisor=32),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    samples_per_gpu=6,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + '/instances_training.json',
+        img_prefix=data_root + '/imgs',
+        pipeline=train_pipeline),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + '/instances_test.json',
+        img_prefix=data_root + '/imgs',
+        pipeline=test_pipeline),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + '/instances_test.json',
+        img_prefix=data_root + '/imgs',
+        pipeline=test_pipeline))
+evaluation = dict(interval=5, metric='hmean-iou')
+
+# optimizer
+optimizer = dict(type='SGD', lr=1e-3, momentum=0.90, weight_decay=5e-4)
+optimizer_config = dict(grad_clip=None)
+lr_config = dict(policy='poly', power=0.9, min_lr=1e-7, by_epoch=True)
+total_epochs = 1500
+
+checkpoint_config = dict(interval=5)
+# yapf:disable
+log_config = dict(
+    interval=20,
+    hooks=[
+        dict(type='TextLoggerHook')
+
+    ])
+# yapf:enable
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/mmocr/datasets/__init__.py b/mmocr/datasets/__init__.py
index 7e3fa16b..87baba26 100644
--- a/mmocr/datasets/__init__.py
+++ b/mmocr/datasets/__init__.py
@@ -5,7 +5,7 @@ from .icdar_dataset import IcdarDataset
 from .kie_dataset import KIEDataset
 from .ocr_dataset import OCRDataset
 from .ocr_seg_dataset import OCRSegDataset
-from .pipelines import CustomFormatBundle, DBNetTargets
+from .pipelines import CustomFormatBundle, DBNetTargets, FCENetTargets
 from .text_det_dataset import TextDetDataset
 
 from .utils import *  # NOQA
@@ -13,7 +13,7 @@ from .utils import *  # NOQA
 __all__ = [
     'DATASETS', 'IcdarDataset', 'build_dataloader', 'build_dataset',
     'BaseDataset', 'OCRDataset', 'TextDetDataset', 'CustomFormatBundle',
-    'DBNetTargets', 'OCRSegDataset', 'KIEDataset'
+    'DBNetTargets', 'OCRSegDataset', 'KIEDataset', 'FCENetTargets'
 ]
 
 __all__ += utils.__all__
diff --git a/mmocr/datasets/pipelines/__init__.py b/mmocr/datasets/pipelines/__init__.py
index cabdd94e..7fa4c924 100644
--- a/mmocr/datasets/pipelines/__init__.py
+++ b/mmocr/datasets/pipelines/__init__.py
@@ -8,10 +8,11 @@ from .ocr_transforms import (FancyPCA, NormalizeOCR, OnlineCropOCR,
                              OpencvToPil, PilToOpencv, RandomPaddingOCR,
                              RandomRotateImageBox, ResizeOCR, ToTensorOCR)
 from .test_time_aug import MultiRotateAugOCR
-from .textdet_targets import DBNetTargets, PANetTargets, TextSnakeTargets
-from .transforms import (ColorJitter, RandomCropInstances,
+from .textdet_targets import (DBNetTargets, FCENetTargets, PANetTargets,
+                              TextSnakeTargets)
+from .transforms import (ColorJitter, RandomCropFlip, RandomCropInstances,
                          RandomCropPolyInstances, RandomRotatePolyInstances,
-                         RandomRotateTextDet, ScaleAspectJitter,
+                         RandomRotateTextDet, RandomScaling, ScaleAspectJitter,
                          SquareResizePad)
 
 __all__ = [
@@ -22,5 +23,6 @@ __all__ = [
     'RandomCropPolyInstances', 'RandomRotatePolyInstances', 'RandomPaddingOCR',
     'ImgAug', 'EastRandomCrop', 'RandomRotateImageBox', 'OpencvToPil',
     'PilToOpencv', 'KIEFormatBundle', 'SquareResizePad', 'TextSnakeTargets',
-    'sort_vertex', 'LoadImageFromNdarray', 'sort_vertex8'
+    'sort_vertex', 'LoadImageFromNdarray', 'sort_vertex8', 'FCENetTargets',
+    'RandomScaling', 'RandomCropFlip'
 ]
diff --git a/mmocr/datasets/pipelines/textdet_targets/__init__.py b/mmocr/datasets/pipelines/textdet_targets/__init__.py
index 1565e924..b5313ffb 100644
--- a/mmocr/datasets/pipelines/textdet_targets/__init__.py
+++ b/mmocr/datasets/pipelines/textdet_targets/__init__.py
@@ -1,10 +1,11 @@
 from .base_textdet_targets import BaseTextDetTargets
 from .dbnet_targets import DBNetTargets
+from .fcenet_targets import FCENetTargets
 from .panet_targets import PANetTargets
 from .psenet_targets import PSENetTargets
 from .textsnake_targets import TextSnakeTargets
 
 __all__ = [
     'BaseTextDetTargets', 'PANetTargets', 'PSENetTargets', 'DBNetTargets',
-    'TextSnakeTargets'
+    'FCENetTargets', 'TextSnakeTargets'
 ]
diff --git a/mmocr/datasets/pipelines/textdet_targets/fcenet_targets.py b/mmocr/datasets/pipelines/textdet_targets/fcenet_targets.py
new file mode 100644
index 00000000..2feea45d
--- /dev/null
+++ b/mmocr/datasets/pipelines/textdet_targets/fcenet_targets.py
@@ -0,0 +1,374 @@
+import cv2
+import numpy as np
+from numpy.linalg import norm
+
+import mmocr.utils.check_argument as check_argument
+from mmdet.datasets.builder import PIPELINES
+from .textsnake_targets import TextSnakeTargets
+
+
+@PIPELINES.register_module()
+class FCENetTargets(TextSnakeTargets):
+    """Generate the ground truth targets of FCENet: Fourier Contour Embedding
+    for Arbitrary-Shaped Text Detection.
+
+    [https://arxiv.org/abs/2104.10442]
+
+    Args:
+        fourier_degree (int): The maximum Fourier transform degree k.
+        resample_step (float): The step size for resampling the text center
+            line (TCL); it should not exceed half of the minimum text width.
+        center_region_shrink_ratio (float): The shrink ratio of text center
+            region.
+        level_size_divisors (tuple(int)): The downsample ratio on each level.
+        level_proportion_range (tuple(tuple(int))): The range of text sizes
+            assigned to each level.
+    """
+
+    def __init__(self,
+                 fourier_degree=5,
+                 resample_step=4.0,
+                 center_region_shrink_ratio=0.3,
+                 level_size_divisors=(8, 16, 32),
+                 level_proportion_range=((0, 0.4), (0.3, 0.7), (0.6, 1.0))):
+
+        super().__init__()
+        assert isinstance(level_size_divisors, tuple)
+        assert isinstance(level_proportion_range, tuple)
+        assert len(level_size_divisors) == len(level_proportion_range)
+        self.fourier_degree = fourier_degree
+        self.resample_step = resample_step
+        self.center_region_shrink_ratio = center_region_shrink_ratio
+        self.level_size_divisors = level_size_divisors
+        self.level_proportion_range = level_proportion_range
+
+    def generate_center_region_mask(self, img_size, text_polys):
+        """Generate text center region mask.
+
+        Args:
+            img_size (tuple): The image size of (height, width).
+            text_polys (list[list[ndarray]]): The list of text polygons.
+
+        Returns:
+            center_region_mask (ndarray): The text center region mask.
+        """
+
+        assert isinstance(img_size, tuple)
+        assert check_argument.is_2dlist(text_polys)
+
+        h, w = img_size
+
+        center_region_mask = np.zeros((h, w), np.uint8)
+
+        center_region_boxes = []
+        for poly in text_polys:
+            assert len(poly) == 1
+            polygon_points = poly[0].reshape(-1, 2)
+            _, _, top_line, bot_line = self.reorder_poly_edge(polygon_points)
+            resampled_top_line, resampled_bot_line = self.resample_sidelines(
+                top_line, bot_line, self.resample_step)
+            resampled_bot_line = resampled_bot_line[::-1]
+            center_line = (resampled_top_line + resampled_bot_line) / 2
+
+            line_head_shrink_len = norm(resampled_top_line[0] -
+                                        resampled_bot_line[0]) / 4.0
+            line_tail_shrink_len = norm(resampled_top_line[-1] -
+                                        resampled_bot_line[-1]) / 4.0
+            head_shrink_num = int(line_head_shrink_len // self.resample_step)
+            tail_shrink_num = int(line_tail_shrink_len // self.resample_step)
+            if len(center_line) > head_shrink_num + tail_shrink_num + 2:
+                center_line = center_line[head_shrink_num:len(center_line) -
+                                          tail_shrink_num]
+                resampled_top_line = resampled_top_line[
+                    head_shrink_num:len(resampled_top_line) - tail_shrink_num]
+                resampled_bot_line = resampled_bot_line[
+                    head_shrink_num:len(resampled_bot_line) - tail_shrink_num]
+
+            for i in range(0, len(center_line) - 1):
+                tl = center_line[i] + (resampled_top_line[i] - center_line[i]
+                                       ) * self.center_region_shrink_ratio
+                tr = center_line[i + 1] + (
+                    resampled_top_line[i + 1] -
+                    center_line[i + 1]) * self.center_region_shrink_ratio
+                br = center_line[i + 1] + (
+                    resampled_bot_line[i + 1] -
+                    center_line[i + 1]) * self.center_region_shrink_ratio
+                bl = center_line[i] + (resampled_bot_line[i] - center_line[i]
+                                       ) * self.center_region_shrink_ratio
+                current_center_box = np.vstack([tl, tr, br,
+                                                bl]).astype(np.int32)
+                center_region_boxes.append(current_center_box)
+
+        cv2.fillPoly(center_region_mask, center_region_boxes, 1)
+        return center_region_mask
+
+    def resample_polygon(self, polygon, n=400):
+        """Resample one polygon with n points on its boundary.
+
+        Args:
+            polygon (list[float]): The input polygon.
+            n (int): The number of resampled points.
+        Returns:
+            resampled_polygon (list[float]): The resampled polygon.
+        """
+        length = []
+
+        for i in range(len(polygon)):
+            p1 = polygon[i]
+            if i == len(polygon) - 1:
+                p2 = polygon[0]
+            else:
+                p2 = polygon[i + 1]
+            length.append(((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2)**0.5)
+
+        total_length = sum(length)
+        n_on_each_line = (np.array(length) / (total_length + 1e-8)) * n
+        n_on_each_line = n_on_each_line.astype(np.int32)
+        new_polygon = []
+
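+        # Spread the n points over the edges in proportion to edge length;
+        # the int cast above means the result may hold slightly fewer than n.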
+        for i in range(len(polygon)):
+            num = n_on_each_line[i]
+            p1 = polygon[i]
+            if i == len(polygon) - 1:
+                p2 = polygon[0]
+            else:
+                p2 = polygon[i + 1]
+
+            if num == 0:
+                continue
+
+            dxdy = (p2 - p1) / num
+            for j in range(num):
+                point = p1 + dxdy * j
+                new_polygon.append(point)
+
+        return np.array(new_polygon)
+
+    def normalize_polygon(self, polygon):
+        """Normalize one polygon so that its start point is the rightmost.
+
+        Args:
+            polygon (list[float]): The origin polygon.
+        Returns:
+            new_polygon (list[float]): The polygon starting at the rightmost.
+        """
+        temp_polygon = polygon - polygon.mean(axis=0)
+        x = np.abs(temp_polygon[:, 0])
+        y = temp_polygon[:, 1]
+        index_x = np.argsort(x)
+        index_y = np.argmin(y[index_x[:8]])
+        index = index_x[index_y]
+        new_polygon = np.concatenate([polygon[index:], polygon[:index]])
+        return new_polygon
+
+    def fourier_transform(self, polygon, fourier_degree):
+        """Perform Fourier transformation to generate Fourier coefficients ck
+        from polygon.
+
+        Args:
+            polygon (ndarray): An input polygon.
+            fourier_degree (int): The maximum Fourier degree K.
+        Returns:
+            c (ndarray(complex)): Fourier coefficients.
+        """
+        points = polygon[:, 0] + polygon[:, 1] * 1j
+        n = len(points)
+        t = np.multiply([i / n for i in range(n)], -2 * np.pi * 1j)
+
+        e = complex(np.e)
+        c = np.zeros((2 * fourier_degree + 1, ), dtype='complex')
+
+        for i in range(-fourier_degree, fourier_degree + 1):
+            c[i + fourier_degree] = np.sum(points * np.power(e, i * t)) / n
+
+        return c
+
+    def clockwise(self, c, fourier_degree):
+        """Make sure the polygon reconstructed from Fourier coefficients c is
+        in the clockwise direction.
+
+        Args:
+            c (ndarray(complex)): The Fourier coefficients.
+        Returns:
+            c (ndarray(complex)): Coefficients in clockwise point order.
+        """
+        if np.abs(c[fourier_degree + 1]) > np.abs(c[fourier_degree - 1]):
+            return c
+        elif np.abs(c[fourier_degree + 1]) < np.abs(c[fourier_degree - 1]):
+            return c[::-1]
+        else:
+            if np.abs(c[fourier_degree + 2]) > np.abs(c[fourier_degree - 2]):
+                return c
+            else:
+                return c[::-1]
+
+    def cal_fourier_signature(self, polygon, fourier_degree):
+        """Calculate Fourier signature from input polygon.
+
+        Args:
+            polygon (ndarray): The input polygon.
+            fourier_degree (int): The maximum Fourier degree K.
+        Returns:
+            fourier_signature (ndarray): An array shaped (2k+1, 2) containing
+                the real and imaginary parts of 2k+1 Fourier coefficients.
+        """
+        resampled_polygon = self.resample_polygon(polygon)
+        resampled_polygon = self.normalize_polygon(resampled_polygon)
+
+        fourier_coeff = self.fourier_transform(resampled_polygon,
+                                               fourier_degree)
+        fourier_coeff = self.clockwise(fourier_coeff, fourier_degree)
+
+        real_part = np.real(fourier_coeff).reshape((-1, 1))
+        image_part = np.imag(fourier_coeff).reshape((-1, 1))
+        fourier_signature = np.hstack([real_part, image_part])
+
+        return fourier_signature
+
+    def generate_fourier_maps(self, img_size, text_polys):
+        """Generate Fourier coefficient maps.
+
+        Args:
+            img_size (tuple): The image size of (height, width).
+            text_polys (list[list[ndarray]]): The list of text polygons.
+
+        Returns:
+            fourier_real_map (ndarray): The Fourier coefficient real part maps.
+            fourier_image_map (ndarray): The Fourier coefficient imaginary
+                part maps.
+        """
+
+        assert isinstance(img_size, tuple)
+        assert check_argument.is_2dlist(text_polys)
+
+        h, w = img_size
+        k = self.fourier_degree
+        real_map = np.zeros((k * 2 + 1, h, w), dtype=np.float32)
+        imag_map = np.zeros((k * 2 + 1, h, w), dtype=np.float32)
+
+        for poly in text_polys:
+            assert len(poly) == 1
+            text_instance = [[poly[0][i], poly[0][i + 1]]
+                             for i in range(0, len(poly[0]), 2)]
+            mask = np.zeros((h, w), dtype=np.uint8)
+            polygon = np.array(text_instance).reshape((1, -1, 2))
+            cv2.fillPoly(mask, polygon.astype(np.int32), 1)
+            fourier_coeff = self.cal_fourier_signature(polygon[0], k)
+            for i in range(-k, k + 1):
+                if i != 0:
+                    real_map[i + k, :, :] = mask * fourier_coeff[i + k, 0] + (
+                        1 - mask) * real_map[i + k, :, :]
+                    imag_map[i + k, :, :] = mask * fourier_coeff[i + k, 1] + (
+                        1 - mask) * imag_map[i + k, :, :]
+                else:
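+                    # For k = 0, store the coefficient minus the pixel's own
+                    # coordinate: the offset from each pixel to the center.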
+                    yx = np.argwhere(mask > 0.5)
+                    k_ind = np.ones((len(yx)), dtype=np.int64) * k
+                    y, x = yx[:, 0], yx[:, 1]
+                    real_map[k_ind, y, x] = fourier_coeff[k, 0] - x
+                    imag_map[k_ind, y, x] = fourier_coeff[k, 1] - y
+
+        return real_map, imag_map
+
+    def generate_level_targets(self, img_size, text_polys, ignore_polys):
+        """Generate ground truth target on each level.
+
+        Args:
+            img_size (list[int]): Shape of input image.
+            text_polys (list[list[ndarray]]): A list of ground truth polygons.
+            ignore_polys (list[list[ndarray]]): A list of ignored polygons.
+        Returns:
+            level_maps (list(ndarray)): A list of ground truth targets on each level.
+        """
+        h, w = img_size
+        lv_size_divs = self.level_size_divisors
+        lv_proportion_range = self.level_proportion_range
+        lv_text_polys = [[] for i in range(len(lv_size_divs))]
+        lv_ignore_polys = [[] for i in range(len(lv_size_divs))]
+        level_maps = []
+        for poly in text_polys:
+            assert len(poly) == 1
+            text_instance = [[poly[0][i], poly[0][i + 1]]
+                             for i in range(0, len(poly[0]), 2)]
+            polygon = np.array(text_instance, dtype=np.int).reshape((1, -1, 2))
+            _, _, box_w, box_h = cv2.boundingRect(polygon)
+            proportion = max(box_h, box_w) / (h + 1e-8)
+
+            for ind, proportion_range in enumerate(lv_proportion_range):
+                if proportion_range[0] < proportion < proportion_range[1]:
+                    lv_text_polys[ind].append([poly[0] / lv_size_divs[ind]])
+
+        for ignore_poly in ignore_polys:
+            assert len(ignore_poly) == 1
+            text_instance = [[ignore_poly[0][i], ignore_poly[0][i + 1]]
+                             for i in range(0, len(ignore_poly[0]), 2)]
+            polygon = np.array(text_instance, dtype=np.int).reshape((1, -1, 2))
+            _, _, box_w, box_h = cv2.boundingRect(polygon)
+            proportion = max(box_h, box_w) / (h + 1e-8)
+
+            for ind, proportion_range in enumerate(lv_proportion_range):
+                if proportion_range[0] < proportion < proportion_range[1]:
+                    lv_ignore_polys[ind].append(
+                        [ignore_poly[0] / lv_size_divs[ind]])
+
+        for ind, size_divisor in enumerate(lv_size_divs):
+            current_level_maps = []
+            level_img_size = (h // size_divisor, w // size_divisor)
+
+            text_region = self.generate_text_region_mask(
+                level_img_size, lv_text_polys[ind])
+            text_region = np.expand_dims(text_region, axis=0)
+            current_level_maps.append(text_region)
+
+            center_region = self.generate_center_region_mask(
+                level_img_size, lv_text_polys[ind])
+            center_region = np.expand_dims(center_region, axis=0)
+            current_level_maps.append(center_region)
+
+            effective_mask = self.generate_effective_mask(
+                level_img_size, lv_ignore_polys[ind])
+            effective_mask = np.expand_dims(effective_mask, axis=0)
+            current_level_maps.append(effective_mask)
+
+            fourier_real_map, fourier_image_maps = self.generate_fourier_maps(
+                level_img_size, lv_text_polys[ind])
+            current_level_maps.append(fourier_real_map)
+            current_level_maps.append(fourier_image_maps)
+
+            level_maps.append(np.concatenate(current_level_maps))
+
+        return level_maps
+
+    def generate_targets(self, results):
+        """Generate the ground truth targets for FCENet.
+
+        Args:
+            results (dict): The input result dictionary.
+
+        Returns:
+            results (dict): The output result dictionary.
+        """
+
+        assert isinstance(results, dict)
+
+        polygon_masks = results['gt_masks'].masks
+        polygon_masks_ignore = results['gt_masks_ignore'].masks
+
+        h, w, _ = results['img_shape']
+
+        level_maps = self.generate_level_targets((h, w), polygon_masks,
+                                                 polygon_masks_ignore)
+
+        results['mask_fields'].clear()  # remove gt_masks encoded by polygons
+        mapping = {
+            'p3_maps': level_maps[0],
+            'p4_maps': level_maps[1],
+            'p5_maps': level_maps[2]
+        }
+        for key, value in mapping.items():
+            results[key] = value
+
+        return results
diff --git a/mmocr/datasets/pipelines/transforms.py b/mmocr/datasets/pipelines/transforms.py
index be01d659..72bf0cae 100644
--- a/mmocr/datasets/pipelines/transforms.py
+++ b/mmocr/datasets/pipelines/transforms.py
@@ -2,6 +2,8 @@
 import math
 
 import cv2
 import numpy as np
+import Polygon as plg
+import mmocr.core.evaluation.utils as eval_utils
 import torchvision.transforms as transforms
 from PIL import Image
@@ -731,3 +733,233 @@ class SquareResizePad:
     def __repr__(self):
         repr_str = self.__class__.__name__
         return repr_str
+
+
+@PIPELINES.register_module()
+class RandomScaling:
+
+    def __init__(self, size=800, scale=(3. / 4, 5. / 2)):
+        """Randomly scale the image while keeping its aspect ratio.
+
+        Args:
+            size (int) : Base size before scaling.
+            scale (tuple(float)) : The range of scaling.
+        """
+        assert isinstance(size, int)
+        assert isinstance(scale, (float, tuple))
+        self.size = size
+        self.scale = scale if isinstance(scale, tuple) \
+            else (1 - scale, 1 + scale)
+
+    def __call__(self, results):
+        image = results['img']
+        h, w, _ = results['img_shape']
+
+        scale_factor = np.random.uniform(min(self.scale), max(self.scale))
+        scales = self.size * 1.0 / max(h, w) * scale_factor
+        scales = np.array([scales, scales])
+        out_size = (int(h * scales[1]), int(w * scales[0]))
+        image = cv2.resize(image, out_size[::-1])
+
+        results['img'] = image
+        results['img_shape'] = image.shape
+
+        for key in results.get('mask_fields', []):
+            if len(results[key].masks) == 0:
+                continue
+            results[key] = results[key].resize(out_size)
+
+        return results
+
+
+@PIPELINES.register_module()
+class RandomCropFlip:
+
+    def __init__(self, crop_ratio=0.5, iter_num=1, min_area_ratio=0.2):
+        """Random crop and flip a patch of the image.
+
+        Args:
+            crop_ratio (float): The probability of cropping.
+            iter_num (int): Number of crop-and-flip iterations.
+            min_area_ratio (float): Minimal area ratio between cropped patch
+                and original image.
+        """
+        assert isinstance(crop_ratio, float)
+        assert isinstance(iter_num, int)
+        assert isinstance(min_area_ratio, float)
+
+        self.scale = 10
+        self.epsilon = 1e-2
+        self.crop_ratio = crop_ratio
+        self.iter_num = iter_num
+        self.min_area_ratio = min_area_ratio
+
+    def __call__(self, results):
+        for _ in range(self.iter_num):
+            results = self.random_crop_flip(results)
+        return results
+
+    def random_crop_flip(self, results):
+        image = results['img']
+        polygons = results['gt_masks'].masks
+        ignore_polygons = results['gt_masks_ignore'].masks
+        all_polygons = polygons + ignore_polygons
+        if len(polygons) == 0:
+            return results
+
+        if np.random.random() >= self.crop_ratio:
+            return results
+
+        h_axis, w_axis = self.crop_target(image, all_polygons, self.scale)
+        if len(h_axis) == 0 or len(w_axis) == 0:
+            return results
+
+        attempt = 0
+        h, w, _ = results['img_shape']
+        area = h * w
+        pad_h = h // self.scale
+        pad_w = w // self.scale
+        while attempt < 10:
+            attempt += 1
+            polys_keep = []
+            polys_new = []
+            ign_polys_keep = []
+            ign_polys_new = []
+            xx = np.random.choice(w_axis, size=2)
+            xmin = np.min(xx) - pad_w
+            xmax = np.max(xx) - pad_w
+            xmin = np.clip(xmin, 0, w - 1)
+            xmax = np.clip(xmax, 0, w - 1)
+            yy = np.random.choice(h_axis, size=2)
+            ymin = np.min(yy) - pad_h
+            ymax = np.max(yy) - pad_h
+            ymin = np.clip(ymin, 0, h - 1)
+            ymax = np.clip(ymax, 0, h - 1)
+            if (xmax - xmin) * (ymax - ymin) < area * self.min_area_ratio:
+                # area too small
+                continue
+
+            pts = np.stack([[xmin, xmax, xmax, xmin],
+                            [ymin, ymin, ymax, ymax]]).T.astype(np.int32)
+            pp = plg.Polygon(pts)
+            fail_flag = False
+            for polygon in polygons:
+                ppi = plg.Polygon(polygon[0].reshape(-1, 2))
+                ppiou, _ = eval_utils.poly_intersection(ppi, pp)
+                if np.abs(ppiou - float(ppi.area())) > self.epsilon and \
+                        np.abs(ppiou) > self.epsilon:
+                    fail_flag = True
+                    break
+                elif np.abs(ppiou - float(ppi.area())) < self.epsilon:
+                    polys_new.append(polygon)
+                else:
+                    polys_keep.append(polygon)
+
+            for polygon in ignore_polygons:
+                ppi = plg.Polygon(polygon[0].reshape(-1, 2))
+                ppiou, _ = eval_utils.poly_intersection(ppi, pp)
+                if np.abs(ppiou - float(ppi.area())) > self.epsilon and \
+                        np.abs(ppiou) > self.epsilon:
+                    fail_flag = True
+                    break
+                elif np.abs(ppiou - float(ppi.area())) < self.epsilon:
+                    ign_polys_new.append(polygon)
+                else:
+                    ign_polys_keep.append(polygon)
+
+            if fail_flag:
+                continue
+            else:
+                break
+
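+        # select_type: 0 = horizontal flip, 1 = vertical flip, 2 = both; the
+        # polygons falling inside the patch are remapped below to match.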
+        cropped = image[ymin:ymax, xmin:xmax, :]
+        select_type = np.random.randint(3)
+        if select_type == 0:
+            img = np.ascontiguousarray(cropped[:, ::-1])
+        elif select_type == 1:
+            img = np.ascontiguousarray(cropped[::-1, :])
+        else:
+            img = np.ascontiguousarray(cropped[::-1, ::-1])
+        image[ymin:ymax, xmin:xmax, :] = img
+        results['img'] = image
+
+        if len(polys_new) + len(ign_polys_new) != 0:
+            height, width, _ = cropped.shape
+            if select_type == 0:
+                for idx, polygon in enumerate(polys_new):
+                    poly = polygon[0].reshape(-1, 2)
+                    poly[:, 0] = width - poly[:, 0] + 2 * xmin
+                    polys_new[idx] = [poly.reshape(-1, )]
+                for idx, polygon in enumerate(ign_polys_new):
+                    poly = polygon[0].reshape(-1, 2)
+                    poly[:, 0] = width - poly[:, 0] + 2 * xmin
+                    ign_polys_new[idx] = [poly.reshape(-1, )]
+            elif select_type == 1:
+                for idx, polygon in enumerate(polys_new):
+                    poly = polygon[0].reshape(-1, 2)
+                    poly[:, 1] = height - poly[:, 1] + 2 * ymin
+                    polys_new[idx] = [poly.reshape(-1, )]
+                for idx, polygon in enumerate(ign_polys_new):
+                    poly = polygon[0].reshape(-1, 2)
+                    poly[:, 1] = height - poly[:, 1] + 2 * ymin
+                    ign_polys_new[idx] = [poly.reshape(-1, )]
+            else:
+                for idx, polygon in enumerate(polys_new):
+                    poly = polygon[0].reshape(-1, 2)
+                    poly[:, 0] = width - poly[:, 0] + 2 * xmin
+                    poly[:, 1] = height - poly[:, 1] + 2 * ymin
+                    polys_new[idx] = [poly.reshape(-1, )]
+                for idx, polygon in enumerate(ign_polys_new):
+                    poly = polygon[0].reshape(-1, 2)
+                    poly[:, 0] = width - poly[:, 0] + 2 * xmin
+                    poly[:, 1] = height - poly[:, 1] + 2 * ymin
+                    ign_polys_new[idx] = [poly.reshape(-1, )]
+            polygons = polys_keep + polys_new
+            ignore_polygons = ign_polys_keep + ign_polys_new
+            results['gt_masks'] = PolygonMasks(polygons, *(image.shape[:2]))
+            results['gt_masks_ignore'] = PolygonMasks(ignore_polygons,
+                                                      *(image.shape[:2]))
+
+        return results
+
+    def crop_target(self, image, all_polys, scale):
+        """Generate the crop range and make sure it does not cut any polygon
+        instance.
+
+        Args:
+            image (ndarray): The image to be cropped.
+            all_polys (list[list[ndarray]]): All polygons including ground
+                truth polygons and ground truth ignored polygons.
+            scale (int): A scale factor to control crop range.
+        Returns:
+            h_axis (ndarray): Vertical cropping range.
+            w_axis (ndarray): Horizontal cropping range.
+        """
+        h, w, _ = image.shape
+        pad_h = h // scale
+        pad_w = w // scale
+        h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
+        w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
+
+        text_polys = []
+        for polygon in all_polys:
+            rect = cv2.minAreaRect(polygon[0].astype(np.int32).reshape(-1, 2))
+            box = cv2.boxPoints(rect)
+            box = np.int0(box)
+            text_polys.append([box[0], box[1], box[2], box[3]])
+
+        polys = np.array(text_polys, dtype=np.int32)
+        for poly in polys:
+            poly = np.round(poly, decimals=0).astype(np.int32)  # round
+            minx = np.min(poly[:, 0])
+            maxx = np.max(poly[:, 0])
+            w_array[minx + pad_w:maxx + pad_w] = 1
+            miny = np.min(poly[:, 1])
+            maxy = np.max(poly[:, 1])
+            h_array[miny + pad_h:maxy + pad_h] = 1
+
+        h_axis = np.where(h_array == 0)[0]
+        w_axis = np.where(w_array == 0)[0]
+        return h_axis, w_axis
diff --git a/mmocr/models/textdet/dense_heads/__init__.py b/mmocr/models/textdet/dense_heads/__init__.py
index 8f227b80..ebce726f 100644
--- a/mmocr/models/textdet/dense_heads/__init__.py
+++ b/mmocr/models/textdet/dense_heads/__init__.py
@@ -1,7 +1,10 @@
 from .db_head import DBHead
+from .fce_head import FCEHead
 from .head_mixin import HeadMixin
 from .pan_head import PANHead
 from .pse_head import PSEHead
 from .textsnake_head import TextSnakeHead
 
-__all__ = ['PSEHead', 'PANHead', 'DBHead', 'HeadMixin', 'TextSnakeHead']
+__all__ = [
+    'PSEHead', 'PANHead', 'DBHead', 'FCEHead', 'HeadMixin', 'TextSnakeHead'
+]
diff --git a/mmocr/models/textdet/dense_heads/fce_head.py b/mmocr/models/textdet/dense_heads/fce_head.py
new file mode 100644
index 00000000..d1719880
--- /dev/null
+++ b/mmocr/models/textdet/dense_heads/fce_head.py
@@ -0,0 +1,137 @@
+import torch.nn as nn
+from mmcv.cnn import normal_init
+
+from mmdet.core import multi_apply
+from mmdet.models.builder import HEADS, build_loss
+from mmocr.models.textdet.postprocess import decode
+from ..postprocess.wrapper import poly_nms
+from .head_mixin import HeadMixin
+
+
+@HEADS.register_module()
+class FCEHead(HeadMixin, nn.Module):
+    """The class for implementing FCENet head.
+    FCENet(CVPR2021): Fourier Contour Embedding for Arbitrary-shaped Text
+    Detection.
+
+    [https://arxiv.org/abs/2104.10442]
+
+    Args:
+        in_channels (int): The number of input channels.
+        scales (list[int]): The downsample ratio of each layer.
+        fourier_degree (int): The maximum Fourier transform degree k.
+        sample_num (int): The number of sample points for the regression
+            loss. If it is too small, FCENet tends to overfit.
+        score_thresh (float): The threshold to filter out the final
+            candidates.
+        nms_thresh (float): The threshold of NMS.
+        alpha (float): The parameter to calculate final scores:
+            Score_{final} = (Score_{text region} ^ alpha)
+            * (Score_{text center region} ^ beta).
+        beta (float): The parameter to calculate final scores.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 scales,
+                 fourier_degree=5,
+                 sample_num=50,
+                 reconstr_points=50,
+                 decoding_type='fcenet',
+                 loss=dict(type='FCELoss'),
+                 score_thresh=0.3,
+                 nms_thresh=0.1,
+                 alpha=1.0,
+                 beta=1.0,
+                 train_cfg=None,
+                 test_cfg=None):
+
+        super().__init__()
+        assert isinstance(in_channels, int)
+
+        self.downsample_ratio = 1.0
+        self.in_channels = in_channels
+        self.scales = scales
+        self.fourier_degree = fourier_degree
+        self.sample_num = sample_num
+        self.reconstr_points = reconstr_points
+        loss['fourier_degree'] = fourier_degree
+        loss['sample_num'] = sample_num
+        self.decoding_type = decoding_type
+        self.loss_module = build_loss(loss)
+        self.score_thresh = score_thresh
+        self.nms_thresh = nms_thresh
+        self.alpha = alpha
+        self.beta = beta
+        self.train_cfg = train_cfg
+        self.test_cfg = test_cfg
+        self.out_channels_cls = 4
+        self.out_channels_reg = (2 * self.fourier_degree + 1) * 2
+
+        self.out_conv_cls = nn.Conv2d(
+            self.in_channels,
+            self.out_channels_cls,
+            kernel_size=3,
+            stride=1,
+            padding=1)
+        self.out_conv_reg = nn.Conv2d(
+            self.in_channels,
+            self.out_channels_reg,
+            kernel_size=3,
+            stride=1,
+            padding=1)
+
+        self.init_weights()
+
+    def init_weights(self):
+        normal_init(self.out_conv_cls, mean=0, std=0.01)
+        normal_init(self.out_conv_reg, mean=0, std=0.01)
+
+    def forward(self, feats):
+        cls_res, reg_res = multi_apply(self.forward_single, feats)
+        level_num = len(cls_res)
+        preds = [[cls_res[i], reg_res[i]] for i in range(level_num)]
+        return preds
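+
+    # `preds` is a list over FPN levels; each entry is [cls_map, reg_map]:
+    # cls_map holds 4 channels (text region and text center region scores),
+    # reg_map holds 2 * (2k + 1) channels (real and imaginary Fourier parts).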
+    def forward_single(self, x):
+        cls_predict = self.out_conv_cls(x)
+        reg_predict = self.out_conv_reg(x)
+        return cls_predict, reg_predict
+
+    def get_boundary(self, score_maps, img_metas, rescale):
+        assert len(score_maps) == len(self.scales)
+
+        boundaries = []
+        for idx, score_map in enumerate(score_maps):
+            scale = self.scales[idx]
+            boundaries = boundaries + self._get_boundary_single(
+                score_map, scale)
+
+        # nms
+        boundaries = poly_nms(boundaries, self.nms_thresh)
+
+        if rescale:
+            boundaries = self.resize_boundary(
+                boundaries, 1.0 / img_metas[0]['scale_factor'])
+
+        results = dict(boundary_result=boundaries)
+        return results
+
+    def _get_boundary_single(self, score_map, scale):
+        assert len(score_map) == 2
+        assert score_map[1].shape[1] == 4 * self.fourier_degree + 2
+
+        return decode(
+            decoding_type=self.decoding_type,
+            preds=score_map,
+            fourier_degree=self.fourier_degree,
+            reconstr_points=self.reconstr_points,
+            scale=scale,
+            alpha=self.alpha,
+            beta=self.beta,
+            text_repr_type='poly',
+            score_thresh=self.score_thresh,
+            nms_thresh=self.nms_thresh)
diff --git a/mmocr/models/textdet/detectors/__init__.py b/mmocr/models/textdet/detectors/__init__.py
index 6aab9c73..6612bf54 100644
--- a/mmocr/models/textdet/detectors/__init__.py
+++ b/mmocr/models/textdet/detectors/__init__.py
@@ -1,4 +1,5 @@
 from .dbnet import DBNet
+from .fcenet import FCENet
 from .ocr_mask_rcnn import OCRMaskRCNN
 from .panet import PANet
 from .psenet import PSENet
@@ -8,5 +9,5 @@ from .textsnake import TextSnake
 __all__ = [
     'TextDetectorMixin', 'SingleStageTextDetector', 'OCRMaskRCNN', 'DBNet',
-    'PANet', 'PSENet', 'TextSnake'
+    'PANet', 'PSENet', 'TextSnake', 'FCENet'
 ]
diff --git a/mmocr/models/textdet/detectors/fcenet.py b/mmocr/models/textdet/detectors/fcenet.py
new file mode 100644
index 00000000..597622ac
--- /dev/null
+++ b/mmocr/models/textdet/detectors/fcenet.py
@@ -0,0 +1,32 @@
+from mmdet.models.builder import DETECTORS
+from .single_stage_text_detector import SingleStageTextDetector
+from .text_detector_mixin import TextDetectorMixin
+
+
+@DETECTORS.register_module()
+class FCENet(TextDetectorMixin, SingleStageTextDetector):
+    """The class for implementing FCENet text detector.
+    FCENet(CVPR2021): Fourier Contour Embedding for Arbitrary-shaped Text
+    Detection.
+
+    [https://arxiv.org/abs/2104.10442]
+    """
+
+    def __init__(self,
+                 backbone,
+                 neck,
+                 bbox_head,
+                 train_cfg=None,
+                 test_cfg=None,
+                 pretrained=None,
+                 show_score=False):
+        SingleStageTextDetector.__init__(self, backbone, neck, bbox_head,
+                                         train_cfg, test_cfg, pretrained)
+        TextDetectorMixin.__init__(self, show_score)
+
+    def simple_test(self, img, img_metas, rescale=False):
+        x = self.extract_feat(img)
+        outs = self.bbox_head(x)
+        boundaries = self.bbox_head.get_boundary(outs, img_metas, rescale)
+
+        return [boundaries]
diff --git a/mmocr/models/textdet/losses/__init__.py b/mmocr/models/textdet/losses/__init__.py
index eaa4d9cd..d25218a3 100644
--- a/mmocr/models/textdet/losses/__init__.py
+++ b/mmocr/models/textdet/losses/__init__.py
@@ -1,6 +1,7 @@
 from .db_loss import DBLoss
+from .fce_loss import FCELoss
 from .pan_loss import PANLoss
 from .pse_loss import PSELoss
 from .textsnake_loss import TextSnakeLoss
 
-__all__ = ['PANLoss', 'PSELoss', 'DBLoss', 'TextSnakeLoss']
+__all__ = ['PANLoss', 'PSELoss', 'DBLoss', 'TextSnakeLoss', 'FCELoss']
diff --git a/mmocr/models/textdet/losses/fce_loss.py b/mmocr/models/textdet/losses/fce_loss.py
new file mode 100644
index 00000000..9d3f125d
--- /dev/null
+++ b/mmocr/models/textdet/losses/fce_loss.py
@@ -0,0 +1,197 @@
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from mmdet.core import multi_apply
+from mmdet.models.builder import LOSSES
+
+
+@LOSSES.register_module()
+class FCELoss(nn.Module):
+    """The class for implementing FCENet loss.
+    FCENet(CVPR2021): Fourier Contour Embedding for Arbitrary-shaped
+    Text Detection.
+
+    [https://arxiv.org/abs/2104.10442]
+
+    Args:
+        fourier_degree (int): The maximum Fourier transform degree k.
+        sample_num (int): The number of sample points for the regression
+            loss. If it is too small, FCENet tends to overfit.
+        ohem_ratio (float): The negative/positive ratio in OHEM.
+    """
+
+    def __init__(self, fourier_degree, sample_num, ohem_ratio=3.):
+        super().__init__()
+        self.fourier_degree = fourier_degree
+        self.sample_num = sample_num
+        self.ohem_ratio = ohem_ratio
+
+    def forward(self, preds, _, p3_maps, p4_maps, p5_maps):
+        assert isinstance(preds, list)
+        assert p3_maps[0].shape[0] == 4 * self.fourier_degree + 5,\
+            'fourier degree not equal in FCEHead and FCENetTargets'
+
+        device = preds[0][0].device
+        # convert the numpy ground truth maps to tensors on the same device
+        gts = [p3_maps, p4_maps, p5_maps]
+        for idx, maps in enumerate(gts):
+            gts[idx] = torch.from_numpy(np.stack(maps)).float().to(device)
+
+        losses = multi_apply(self.forward_single, preds, gts)
+
+        loss_tr = torch.tensor(0., device=device).float()
+        loss_tcl = torch.tensor(0., device=device).float()
+        loss_reg_x = torch.tensor(0., device=device).float()
+        loss_reg_y = torch.tensor(0., device=device).float()
+
+        for idx, loss in enumerate(losses):
+            if idx == 0:
+                loss_tr += sum(loss)
+            elif idx == 1:
+                loss_tcl += sum(loss)
+            elif idx == 2:
+                loss_reg_x += sum(loss)
+            else:
+                loss_reg_y += sum(loss)
+
+        results = dict(
+            loss_text=loss_tr,
+            loss_center=loss_tcl,
+            loss_reg_x=loss_reg_x,
+            loss_reg_y=loss_reg_y,
+        )
+
+        return results
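+
+    # Channel layout of each gt map: 1 text-region mask, 1 text-center-region
+    # mask, 1 effective mask, then (2k + 1) real and (2k + 1) imaginary
+    # Fourier coefficient channels, i.e. 4k + 5 channels in total.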
+    def forward_single(self, pred, gt):
+        cls_pred, reg_pred = pred[0], pred[1]
+
+        tr_pred = cls_pred[:, :2, :, :].permute(0, 2, 3, 1)\
+            .contiguous().view(-1, 2)
+        tcl_pred = cls_pred[:, 2:, :, :].permute(0, 2, 3, 1)\
+            .contiguous().view(-1, 2)
+        x_pred = reg_pred[:, 0:2 * self.fourier_degree + 1, :, :]\
+            .permute(0, 2, 3, 1).contiguous().view(
+                -1, 2 * self.fourier_degree + 1)
+        y_pred = reg_pred[:,
+                          2 * self.fourier_degree + 1:4 * self.fourier_degree +
+                          2, :, :].permute(0, 2, 3, 1).contiguous().view(
+                              -1, 2 * self.fourier_degree + 1)
+
+        tr_mask = gt[:, :1, :, :].permute(0, 2, 3, 1).contiguous().view(-1)
+        tcl_mask = gt[:, 1:2, :, :].permute(0, 2, 3, 1).contiguous().view(-1)
+        train_mask = gt[:, 2:3, :, :].permute(0, 2, 3, 1).contiguous().view(-1)
+        x_map = gt[:, 3:4 + 2 * self.fourier_degree, :, :].permute(
+            0, 2, 3, 1).contiguous().view(-1, 2 * self.fourier_degree + 1)
+        y_map = gt[:, 4 + 2 * self.fourier_degree:, :, :].permute(
+            0, 2, 3, 1).contiguous().view(-1, 2 * self.fourier_degree + 1)
+
+        tr_train_mask = train_mask * tr_mask
+        device = x_map.device
+        # tr loss
+        loss_tr = self.ohem(tr_pred, tr_mask.long(), train_mask.long())
+
+        # tcl loss
+        loss_tcl = torch.tensor(0.).float().to(device)
+        tr_neg_mask = 1 - tr_train_mask
+        if tr_train_mask.sum().item() > 0:
+            loss_tcl_pos = F.cross_entropy(
+                tcl_pred[tr_train_mask.bool()],
+                tcl_mask[tr_train_mask.bool()].long())
+            loss_tcl_neg = F.cross_entropy(tcl_pred[tr_neg_mask.bool()],
+                                           tcl_mask[tr_neg_mask.bool()].long())
+            loss_tcl = loss_tcl_pos + 0.5 * loss_tcl_neg
+
+        # regression loss
+        loss_reg_x = torch.tensor(0.).float().to(device)
+        loss_reg_y = torch.tensor(0.).float().to(device)
+        if tr_train_mask.sum().item() > 0:
+            weight = (tr_mask[tr_train_mask.bool()].float() +
+                      tcl_mask[tr_train_mask.bool()].float()) / 2
+            weight = weight.contiguous().view(-1, 1)
+
+            ft_x, ft_y = self.fourier2poly(x_map, y_map)
+            ft_x_pre, ft_y_pre = self.fourier2poly(x_pred, y_pred)
+
+            loss_reg_x = torch.mean(weight * F.smooth_l1_loss(
+                ft_x_pre[tr_train_mask.bool()],
+                ft_x[tr_train_mask.bool()],
+                reduction='none'))
+            loss_reg_y = torch.mean(weight * F.smooth_l1_loss(
+                ft_y_pre[tr_train_mask.bool()],
+                ft_y[tr_train_mask.bool()],
+                reduction='none'))
+
+        return loss_tr, loss_tcl, loss_reg_x, loss_reg_y
+
+    def ohem(self, predict, target, train_mask):
+        pos = (target * train_mask).bool()
+        neg = ((1 - target) * train_mask).bool()
+
+        n_pos = pos.float().sum()
+
+        if n_pos.item() > 0:
+            loss_pos = F.cross_entropy(
+                predict[pos], target[pos], reduction='sum')
+            loss_neg = F.cross_entropy(
+                predict[neg], target[neg], reduction='none')
+            n_neg = min(
+                int(neg.float().sum().item()),
+                int(self.ohem_ratio * n_pos.float()))
+        else:
+            loss_pos = torch.tensor(0.)
+            loss_neg = F.cross_entropy(
+                predict[neg], target[neg], reduction='none')
+            n_neg = 100
+        if len(loss_neg) > n_neg:
+            loss_neg, _ = torch.topk(loss_neg, n_neg)
+
+        return (loss_pos + loss_neg.sum()) / (n_pos + n_neg).float()
+
+    def fourier2poly(self, real_maps, imag_maps):
+        """Transform Fourier coefficient maps to polygon maps.
+
+        Args:
+            real_maps (tensor): A map composed of the real parts of the
+                Fourier coefficients, whose shape is (-1, 2k+1)
+            imag_maps (tensor): A map composed of the imaginary parts of the
+                Fourier coefficients, whose shape is (-1, 2k+1)
+
+        Returns:
+            x_maps (tensor): A map composed of the x value of the polygon
+                represented by n sample points (xn, yn), whose shape is (-1, n)
+            y_maps (tensor): A map composed of the y value of the polygon
+                represented by n sample points (xn, yn), whose shape is (-1, n)
+        """
+
+        device = real_maps.device
+
+        k_vect = torch.arange(
+            -self.fourier_degree,
+            self.fourier_degree + 1,
+            dtype=torch.float,
+            device=device).view(-1, 1)
+        i_vect = torch.arange(
+            0, self.sample_num, dtype=torch.float, device=device).view(1, -1)
+
+        transform_matrix = 2 * np.pi / self.sample_num * torch.mm(
+            k_vect, i_vect)
+
+        x1 = torch.einsum('ak, kn-> an', real_maps,
+                          torch.cos(transform_matrix))
+        x2 = torch.einsum('ak, kn-> an', imag_maps,
+                          torch.sin(transform_matrix))
+        y1 = torch.einsum('ak, kn-> an', real_maps,
+                          torch.sin(transform_matrix))
+        y2 = torch.einsum('ak, kn-> an', imag_maps,
+                          torch.cos(transform_matrix))
+
+        x_maps = x1 - x2
+        y_maps = y1 + y2
+
+        return x_maps, y_maps
diff --git a/mmocr/models/textdet/postprocess/wrapper.py b/mmocr/models/textdet/postprocess/wrapper.py
index a794a7db..4347fbc5 100644
--- a/mmocr/models/textdet/postprocess/wrapper.py
+++ b/mmocr/models/textdet/postprocess/wrapper.py
@@ -7,6 +7,7 @@
 from shapely.geometry import Polygon
 from skimage.morphology import skeletonize
 
 from mmocr.core import points2boundary
+from mmocr.core.evaluation.utils import boundary_iou
 
 
 def filter_instance(area, confidence, min_area, min_confidence):
@@ -24,6 +25,8 @@ def decode(
         return db_decode(**kwargs)
     if decoding_type == 'textsnake':
         return textsnake_decode(**kwargs)
+    if decoding_type == 'fcenet':
+        return fcenet_decode(**kwargs)
 
     raise NotImplementedError
@@ -391,3 +394,179 @@ def textsnake_decode(preds,
         boundaries.append(boundary + [score])
 
     return boundaries
+
+
+def fcenet_decode(
+    preds,
+    fourier_degree,
+    reconstr_points,
+    scale,
+    alpha=1.0,
+    beta=2.0,
+    text_repr_type='poly',
+    score_thresh=0.8,
+    nms_thresh=0.1,
+):
+    """Decoding predictions of FCENet to instances.
+
+    Args:
+        preds (list(Tensor)): The head output tensors.
+        fourier_degree (int): The maximum Fourier transform degree k.
+        reconstr_points (int): The points number of the polygon reconstructed
+            from predicted Fourier coefficients.
+        scale (int): The downsample scale of the prediction.
+        alpha (float): The parameter to calculate final scores:
+            Score_{final} = (Score_{text region} ^ alpha)
+            * (Score_{text center region} ^ beta).
+        beta (float): The parameter to calculate final scores.
+        text_repr_type (str): Boundary encoding type, 'poly' or 'quad'.
+        score_thresh (float): The threshold used to filter out the final
+            candidates.
+        nms_thresh (float): The threshold of NMS.
+
+    Returns:
+        boundaries (list[list[float]]): The instance boundary and confidence
+            list.
+    """
+    assert isinstance(preds, list)
+    assert len(preds) == 2
+    assert text_repr_type == 'poly'
+
+    cls_pred = preds[0][0]
+    tr_pred = cls_pred[0:2].softmax(dim=0).data.cpu().numpy()
+    tcl_pred = cls_pred[2:].softmax(dim=0).data.cpu().numpy()
+
+    reg_pred = preds[1][0].permute(1, 2, 0).data.cpu().numpy()
+    x_pred = reg_pred[:, :, :2 * fourier_degree + 1]
+    y_pred = reg_pred[:, :, 2 * fourier_degree + 1:]
+
+    score_pred = (tr_pred[1]**alpha) * (tcl_pred[1]**beta)
+    tr_pred_mask = (score_pred) > score_thresh
+    tr_mask = fill_hole(tr_pred_mask)
+
+    tr_contours, _ = cv2.findContours(
+        tr_mask.astype(np.uint8), cv2.RETR_TREE,
+        cv2.CHAIN_APPROX_SIMPLE)  # opencv4
+
+    mask = np.zeros_like(tr_mask)
+    exp_matrix = generate_exp_matrix(reconstr_points, fourier_degree)
+    boundaries = []
+    for cont in tr_contours:
+        deal_map = mask.copy().astype(np.int8)
+        cv2.drawContours(deal_map, [cont], -1, 1, -1)
+
+        text_map = score_pred * deal_map
+        polygons = contour_transfor_inv(fourier_degree, x_pred, y_pred,
+                                        text_map, exp_matrix, scale)
+        polygons = poly_nms(polygons, nms_thresh)
+        boundaries = boundaries + polygons
+
+    boundaries = poly_nms(boundaries, nms_thresh)
+    return boundaries
+
+
+def poly_nms(polygons, threshold):
+    assert isinstance(polygons, list)
+
+    polygons = np.array(sorted(polygons, key=lambda x: x[-1]))
+
+    keep_poly = []
+    index = [i for i in range(polygons.shape[0])]
+
+    while len(index) > 0:
+        keep_poly.append(polygons[index[-1]].tolist())
+        A = polygons[index[-1]][:-1]
+        index = np.delete(index, -1)
+
+        iou_list = np.zeros((len(index), ))
+        for i in range(len(index)):
+            B = polygons[index[i]][:-1]
+
+            iou_list[i] = boundary_iou(A, B)
+        remove_index = np.where(iou_list > threshold)
+        index = np.delete(index, remove_index)
+
+    return keep_poly
+
+
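+# The helpers below reconstruct polygon contours by sampling the inverse
+# Fourier series at n' points (Eq. 2 in the paper).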
+def contour_transfor_inv(fourier_degree, x_pred, y_pred, score_map, exp_matrix,
+                         scale):
+    """Reconstruct polygon from predicts.
+
+    Args:
+        fourier_degree (int): The maximum Fourier degree K.
+        x_pred (ndarray): The real part of predicted Fourier coefficients.
+        y_pred (ndarray): The imaginary part of predicted Fourier coefficients.
+        score_map (ndarray): The final score of candidates.
+        exp_matrix (ndarray): A matrix of e^x with x = 2Ο€i Β· kt, whose shape
+            is (2k+1, n'), where n' is the reconstructed point number. See
+            Eq. 2 in the paper.
+        scale (int): The down-sample scale.
+    Returns:
+        Polygons (list): The reconstructed polygons and scores.
+    """
+    mask = score_map > 0
+
+    xy_text = np.argwhere(mask)
+    dxy = xy_text[:, 1] + xy_text[:, 0] * 1j
+
+    x = x_pred[mask]
+    y = y_pred[mask]
+
+    c = x + y * 1j
+    c[:, fourier_degree] = c[:, fourier_degree] + dxy
+    c *= scale
+
+    polygons = fourier_inverse_matrix(c, exp_matrix=exp_matrix)
+    score = score_map[mask].reshape(-1, 1)
+    return np.hstack((polygons, score)).tolist()
+
+
+def fourier_inverse_matrix(fourier_coeff, exp_matrix):
+    """Inverse Fourier transform.
+
+    Args:
+        fourier_coeff (ndarray): Fourier coefficients shaped (n, 2k+1), with
+            n and k being candidates number and Fourier degree respectively.
+        exp_matrix (ndarray): A matrix of e^x with x = 2Ο€i Β· kt, whose shape
+            is (2k+1, n'), where n' is the reconstructed point number. See
+            Eq. 2 in the paper.
+    Returns:
+        Polygons (ndarray): The reconstructed polygons shaped (n, n')
+    """
+
+    assert type(fourier_coeff) == np.ndarray
+    assert fourier_coeff.shape[1] == exp_matrix.shape[0]
+
+    n = exp_matrix.shape[1]
+    polygons = np.zeros((fourier_coeff.shape[0], n, 2))
+
+    points = np.matmul(fourier_coeff, exp_matrix)
+    p_x = np.real(points)
+    p_y = np.imag(points)
+    polygons[:, :, 0] = p_x
+    polygons[:, :, 1] = p_y
+    return polygons.astype('int32').reshape(polygons.shape[0], -1)
+
+
+def generate_exp_matrix(point_num, fourier_degree):
+    """Generate a matrix of e^x with x = 2Ο€i Β· kt. See Eq. 2 in the paper.
+    Args:
+        point_num (int): Number of reconstructed points of the polygon.
+        fourier_degree (int): Maximum Fourier degree k.
+    Returns:
+        exp_matrix (ndarray): A matrix of e^x with x = 2Ο€i Β· kt, whose
+            shape is (2k+1, n'), where n' is the reconstructed point number.
+    """
+    e = complex(np.e)
+    exp_matrix = np.zeros([2 * fourier_degree + 1, point_num], dtype='complex')
+
+    temp = np.zeros([point_num], dtype='complex')
+    for i in range(point_num):
+        temp[i] = 2 * np.pi * 1j / point_num * i
+
+    for i in range(2 * fourier_degree + 1):
+        exp_matrix[i, :] = temp * (i - fourier_degree)
+
+    return np.power(e, exp_matrix)
diff --git a/tests/test_dataset/test_textdet_targets.py b/tests/test_dataset/test_textdet_targets.py
index 6fa53984..c26291c4 100644
--- a/tests/test_dataset/test_textdet_targets.py
+++ b/tests/test_dataset/test_textdet_targets.py
@@ -218,3 +218,27 @@ def test_gen_textsnake_targets(mock_show_feature):
     assert 'gt_sin_map' in output.keys()
     assert 'gt_cos_map' in output.keys()
     mock_show_feature.assert_called_once()
+
+
+def test_fcenet_generate_targets():
+    fourier_degree = 5
+    target_generator = textdet_targets.FCENetTargets(
+        fourier_degree=fourier_degree)
+
+    h, w, c = (64, 64, 3)
+    text_polys = [[np.array([0, 0, 10, 0, 10, 10, 0, 10])],
+                  [np.array([20, 0, 30, 0, 30, 10, 20, 10])]]
+    text_polys_ignore = [[np.array([0, 0, 15, 0, 15, 10, 0, 10])]]
+
+    results = {}
+    results['mask_fields'] = []
+    results['img_shape'] = (h, w, c)
+    results['gt_masks_ignore'] = PolygonMasks(text_polys_ignore, h, w)
+    results['gt_masks'] = PolygonMasks(text_polys, h, w)
+    results['gt_bboxes'] = np.array([[0, 0, 10, 10], [20, 0, 30, 10]])
+    results['gt_labels'] = np.array([0, 1])
+
+    target_generator.generate_targets(results)
+    assert 'p3_maps' in results.keys()
+    assert 'p4_maps' in results.keys()
+    assert 'p5_maps' in results.keys()
diff --git a/tests/test_dataset/test_transforms.py b/tests/test_dataset/test_transforms.py
index a25f308a..c2d38c50 100644
--- a/tests/test_dataset/test_transforms.py
+++ b/tests/test_dataset/test_transforms.py
@@ -166,6 +166,72 @@ def test_affine_jitter():
     assert np.allclose(np.array(output1), output2['img'])
 
 
+def test_random_scale():
+    h, w, c = 100, 100, 3
+    img = np.ones((h, w, c), dtype=np.uint8)
+    results = {'img': img, 'img_shape': (h, w, c)}
+
+    polygon = np.array([0., 0., 0., 10., 10., 10., 10., 0.])
+
+    results['gt_masks'] = PolygonMasks([[polygon]], *(img.shape[:2]))
+    results['mask_fields'] = ['gt_masks']
+
+    size = 100
+    scale = (2., 2.)
+    random_scaler = transforms.RandomScaling(size=size, scale=scale)
+
+    results = random_scaler(results)
+
+    out_img = results['img']
+    out_poly = results['gt_masks'].masks[0][0]
+    gt_poly = polygon * 2
+
+    assert np.allclose(out_img.shape, (2 * h, 2 * w, c))
+    assert np.allclose(out_poly, gt_poly)
+
+
+@mock.patch('%s.transforms.np.random.randint' % __name__)
+def test_random_crop_flip(mock_randint):
+    img = np.ones((10, 10, 3), dtype=np.uint8)
+    img[0, 0, :] = 0
+    results = {'img': img, 'img_shape': img.shape}
+
+    polygon = np.array([0., 0., 0., 10., 10., 10., 10., 0.])
+
+    results['gt_masks'] = PolygonMasks([[polygon]], *(img.shape[:2]))
+    results['gt_masks_ignore'] = PolygonMasks([], *(img.shape[:2]))
+    results['mask_fields'] = ['gt_masks', 'gt_masks_ignore']
+
+    crop_ratio = 1.1
+    iter_num = 3
+    random_crop_fliper = transforms.RandomCropFlip(
+        crop_ratio=crop_ratio, iter_num=iter_num)
+
+    # test crop_target
+    scale = 10
+    all_polys = results['gt_masks'].masks
+    h_axis, w_axis = random_crop_fliper.crop_target(img, all_polys, scale)
+
+    assert np.allclose(h_axis, (0, 11))
+    assert np.allclose(w_axis, (0, 11))
+
+    # test __call__
+    polygon = np.array([1., 1., 1., 9., 9., 9., 9., 1.])
+    results['gt_masks'] = PolygonMasks([[polygon]], *(img.shape[:2]))
+    results['gt_masks_ignore'] = PolygonMasks([[polygon]], *(img.shape[:2]))
+
+    mock_randint.side_effect = [0, 1, 2]
+    results = random_crop_fliper(results)
+
+    out_img = results['img']
+    out_poly = results['gt_masks'].masks[0][0]
+    gt_img = img
+    gt_poly = polygon
+
+    assert np.allclose(out_img, gt_img)
+    assert np.allclose(out_poly, gt_poly)
+
+
 @mock.patch('%s.transforms.np.random.random_sample' % __name__)
 @mock.patch('%s.transforms.np.random.randint' % __name__)
 def test_random_crop_poly_instances(mock_randint, mock_sample):
diff --git a/tests/test_models/test_detector.py b/tests/test_models/test_detector.py
index 87d9d0b2..f98c0f73 100644
--- a/tests/test_models/test_detector.py
+++ b/tests/test_models/test_detector.py
@@ -372,3 +372,60 @@ def test_textsnake(cfg_file):
     results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]}
     img = np.random.rand(5, 5)
     detector.show_result(img, results)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
+@pytest.mark.parametrize(
+    'cfg_file', ['textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py'])
+def test_fcenet(cfg_file):
+    model = _get_detector_cfg(cfg_file)
+    model['pretrained'] = None
+    model['backbone']['norm_cfg']['type'] = 'BN'
+
+    from mmocr.models import build_detector
+    detector = build_detector(model)
+    detector = detector.cuda()
+
+    fourier_degree = 5
+    input_shape = (1, 3, 256, 256)
+    (n, c, h, w) = input_shape
+
+    imgs = torch.randn(n, c, h, w).float().cuda()
+    img_metas = [{
+        'img_shape': (h, w, c),
+        'ori_shape': (h, w, c),
+        'pad_shape': (h, w, c),
+        'filename': '.png',
+        'scale_factor': np.array([1, 1, 1, 1]),
+        'flip': False,
+    } for _ in range(n)]
+
+    p3_maps = []
+    p4_maps = []
+    p5_maps = []
+    for _ in range(n):
+        p3_maps.append(
+            np.random.random((5 + 4 * fourier_degree, h // 8, w // 8)))
+        p4_maps.append(
+            np.random.random((5 + 4 * fourier_degree, h // 16, w // 16)))
+        p5_maps.append(
+            np.random.random((5 + 4 * fourier_degree, h // 32, w // 32)))
+
+    # Test forward train
+    losses = detector.forward(
+        imgs, img_metas, p3_maps=p3_maps, p4_maps=p4_maps, p5_maps=p5_maps)
+    assert isinstance(losses, dict)
+
+    # Test forward test
+    with torch.no_grad():
+        img_list = [g[None, :] for g in imgs]
+        batch_results = []
+        for one_img, one_meta in zip(img_list, img_metas):
+            result = detector.forward([one_img], [[one_meta]],
+                                      return_loss=False)
+            batch_results.append(result)
+
+    # Test show result
+    results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]}
+    img = np.random.rand(5, 5)
+    detector.show_result(img, results)
diff --git a/tests/test_models/test_loss.py b/tests/test_models/test_loss.py
index 88d6522c..598828ec 100644
--- a/tests/test_models/test_loss.py
+++ b/tests/test_models/test_loss.py
@@ -31,3 +31,42 @@ def test_textsnakeloss():
     bce_loss = textsnakeloss.balanced_bce_loss(pred, target, mask).item()
 
     assert np.allclose(bce_loss, 0)
+
+
+def test_fcenetloss():
+    k = 5
+    fcenetloss = losses.FCELoss(fourier_degree=k, sample_num=10)
+
+    input_shape = (1, 3, 64, 64)
+    (n, c, h, w) = input_shape
+
+    # test ohem
+    pred = torch.ones((200, 2), dtype=torch.float)
+    target = torch.ones((200, ), dtype=torch.long)
+    target[20:] = 0
+    mask = torch.ones((200, ), dtype=torch.long)
+
+    ohem_loss1 = fcenetloss.ohem(pred, target, mask)
+    ohem_loss2 = fcenetloss.ohem(pred, target, 1 - mask)
+    assert isinstance(ohem_loss1, torch.Tensor)
+    assert isinstance(ohem_loss2, torch.Tensor)
+
+    # test forward
+    preds = []
+    for i in range(n):
+        scale = 8 * 2**i
+        pred = []
+        pred.append(torch.rand(n, 4, h // scale, w // scale))
+        pred.append(torch.rand(n, 4 * k + 2, h // scale, w // scale))
+        preds.append(pred)
+
+    p3_maps = []
+    p4_maps = []
+    p5_maps = []
+    for _ in range(n):
+        p3_maps.append(np.random.random((5 + 4 * k, h // 8, w // 8)))
+        p4_maps.append(np.random.random((5 + 4 * k, h // 16, w // 16)))
+        p5_maps.append(np.random.random((5 + 4 * k, h // 32, w // 32)))
+
+    loss = fcenetloss(preds, 0, p3_maps, p4_maps, p5_maps)
+    assert isinstance(loss, dict)
diff --git a/tests/test_utils/test_wrapper.py b/tests/test_utils/test_wrapper.py
index b8083e03..92ad3f74 100644
--- a/tests/test_utils/test_wrapper.py
+++ b/tests/test_utils/test_wrapper.py
@@ -12,3 +12,17 @@ def test_db_boxes_from_bitmaps():
     boxes = db_decode(preds, text_repr_type='quad', min_text_width=0)
 
     assert len(boxes) == 1
+
+
+def test_fcenet_decode():
+    from mmocr.models.textdet.postprocess.wrapper import fcenet_decode
+
+    k = 5
+    preds = []
+    preds.append(torch.randn(1, 4, 40, 40))
+    preds.append(torch.randn(1, 4 * k + 2, 40, 40))
+
+    boundaries = fcenet_decode(
+        preds=preds, fourier_degree=k, reconstr_points=50, scale=1)
+
+    assert isinstance(boundaries, list)