add fcenet (#133)

* add fcenet

* fix linting and code style

* fcenet finetune

* Update transforms.py

* Update fcenet_r50dcnv2_fpn_1500e_ctw1500.py

* Update fcenet_targets.py

* Update fce_loss.py

* fix

* add readme

* fix config

* Update fcenet_r50dcnv2_fpn_1500e_ctw1500.py

* fix

* fix readme

* fix readme

* Update test_loss.py

Co-authored-by: Hongbin Sun <hongbin306@gmail.com>
Author: Zyq-scut, 2021-05-14 21:37:04 +08:00 (committed by GitHub)
parent 1240169f18
commit cbdd98a1e1
19 changed files with 1510 additions and 10 deletions


@@ -0,0 +1,22 @@
# Fourier Contour Embedding for Arbitrary-Shaped Text Detection
## Introduction
[ALGORITHM]
```bibtex
@InProceedings{zhu2021fourier,
title={Fourier Contour Embedding for Arbitrary-Shaped Text Detection},
author={Yiqin Zhu and Jianyong Chen and Lingyu Liang and Zhanghui Kuang and Lianwen Jin and Wayne Zhang},
year={2021},
booktitle = {CVPR}
}
```
## Results and models
### CTW1500
| Method | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download |
| :--------------------------------------------------------------------: | :--------------: | :-----------: | :----------: | :-----: | :-------: | :----: | :-------: | :---: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| [FCENet](/configs/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py) | ImageNet | CTW1500 Train | CTW1500 Test | 1500 | (736, 1080) | 0.828 | 0.875 | 0.851 | [model](https://download.openmmlab.com/mmocr/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500-05d740bb.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/fcenet/20210511_181328.log.json) |
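
For a quick qualitative check, the released checkpoint can be run on a single image with the MMOCR 0.x inference helpers (a minimal sketch following `demo/image_demo.py`; the image path is a placeholder):

```python
from mmdet.apis import init_detector

from mmocr.apis import model_inference

config_file = 'configs/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py'
checkpoint_file = 'fcenet_r50dcnv2_fpn_1500e_ctw1500-05d740bb.pth'

model = init_detector(config_file, checkpoint_file, device='cuda:0')
result = model_inference(model, 'demo/demo_text_det.jpg')  # placeholder
# Each entry of result['boundary_result'] is a flattened polygon
# [x1, y1, ..., xn, yn, score] produced by FCEHead.get_boundary().
print(len(result['boundary_result']))
```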


@@ -0,0 +1,134 @@
fourier_degree = 5
model = dict(
type='FCENet',
pretrained='torchvision://resnet50',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(1, 2, 3),
frozen_stages=-1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch',
dcn=dict(type='DCNv2', deform_groups=2, fallback_on_stride=False),
stage_with_dcn=(False, True, True, True)),
neck=dict(
type='FPN',
in_channels=[512, 1024, 2048],
out_channels=256,
add_extra_convs=True,
extra_convs_on_inputs=False, # use P5
num_outs=3,
relu_before_extra_convs=True,
act_cfg=None),
bbox_head=dict(
type='FCEHead',
in_channels=256,
scales=(8, 16, 32),
loss=dict(type='FCELoss'),
fourier_degree=fourier_degree,
))
train_cfg = None
test_cfg = None
dataset_type = 'IcdarDataset'
data_root = 'data/ctw1500/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='LoadTextAnnotations',
with_bbox=True,
with_mask=True,
poly2mask=False),
dict(
type='ColorJitter',
brightness=32.0 / 255,
saturation=0.5,
contrast=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='RandomScaling', size=800, scale=(3. / 4, 5. / 2)),
dict(
type='RandomCropFlip', crop_ratio=0.5, iter_num=1, min_area_ratio=0.2),
dict(
type='RandomCropPolyInstances',
instance_key='gt_masks',
crop_ratio=0.8,
min_side_ratio=0.3),
dict(
type='RandomRotatePolyInstances',
rotate_ratio=0.5,
max_angle=30,
pad_with_fixed_color=False),
dict(type='SquareResizePad', target_size=800, pad_ratio=0.6),
dict(type='RandomFlip', flip_ratio=0.5, direction='horizontal'),
dict(type='Pad', size_divisor=32),
dict(
type='FCENetTargets',
fourier_degree=fourier_degree,
level_proportion_range=((0, 0.25), (0.2, 0.65), (0.55, 1.0))),
dict(
type='CustomFormatBundle',
keys=['p3_maps', 'p4_maps', 'p5_maps'],
visualize=dict(flag=False, boundary_key=None)),
dict(type='Collect', keys=['img', 'p3_maps', 'p4_maps', 'p5_maps'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1080, 736),
flip=False,
transforms=[
dict(type='Resize', img_scale=(1280, 800), keep_ratio=True),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=6,
workers_per_gpu=2,
train=dict(
type=dataset_type,
ann_file=data_root + '/instances_training.json',
img_prefix=data_root + '/imgs',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=data_root + '/instances_test.json',
img_prefix=data_root + '/imgs',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + '/instances_test.json',
img_prefix=data_root + '/imgs',
pipeline=test_pipeline))
evaluation = dict(interval=5, metric='hmean-iou')
# optimizer
optimizer = dict(type='SGD', lr=1e-3, momentum=0.90, weight_decay=5e-4)
optimizer_config = dict(grad_clip=None)
lr_config = dict(policy='poly', power=0.9, min_lr=1e-7, by_epoch=True)
total_epochs = 1500
checkpoint_config = dict(interval=5)
# yapf:disable
log_config = dict(
interval=20,
hooks=[
dict(type='TextLoggerHook')
])
# yapf:enable
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
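
As a smoke test, the config above can be loaded and the detector instantiated without any data (a sketch, run from the repository root with mmcv-full and its DCN ops installed; `pretrained` is unset so no torchvision weights are fetched, mirroring the unit tests below):

```python
from mmcv import Config

from mmocr.models import build_detector

cfg = Config.fromfile(
    'configs/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py')
cfg.model.pretrained = None  # skip the torchvision://resnet50 download
detector = build_detector(cfg.model)
# 4 classification channels and 2 * (2 * fourier_degree + 1) = 22
# regression channels are predicted on each pyramid level.
print(detector.bbox_head.out_channels_cls,
      detector.bbox_head.out_channels_reg)
```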


@@ -5,7 +5,7 @@ from .icdar_dataset import IcdarDataset
from .kie_dataset import KIEDataset
from .ocr_dataset import OCRDataset
from .ocr_seg_dataset import OCRSegDataset
-from .pipelines import CustomFormatBundle, DBNetTargets
+from .pipelines import CustomFormatBundle, DBNetTargets, FCENetTargets
from .text_det_dataset import TextDetDataset
from .utils import *  # NOQA
@@ -13,7 +13,7 @@ from .utils import *  # NOQA
__all__ = [
    'DATASETS', 'IcdarDataset', 'build_dataloader', 'build_dataset',
    'BaseDataset', 'OCRDataset', 'TextDetDataset', 'CustomFormatBundle',
-    'DBNetTargets', 'OCRSegDataset', 'KIEDataset'
+    'DBNetTargets', 'OCRSegDataset', 'KIEDataset', 'FCENetTargets'
]
__all__ += utils.__all__


@@ -8,10 +8,11 @@ from .ocr_transforms import (FancyPCA, NormalizeOCR, OnlineCropOCR,
                             OpencvToPil, PilToOpencv, RandomPaddingOCR,
                             RandomRotateImageBox, ResizeOCR, ToTensorOCR)
from .test_time_aug import MultiRotateAugOCR
-from .textdet_targets import DBNetTargets, PANetTargets, TextSnakeTargets
+from .textdet_targets import (DBNetTargets, FCENetTargets, PANetTargets,
+                              TextSnakeTargets)
-from .transforms import (ColorJitter, RandomCropInstances,
+from .transforms import (ColorJitter, RandomCropFlip, RandomCropInstances,
                         RandomCropPolyInstances, RandomRotatePolyInstances,
-                         RandomRotateTextDet, ScaleAspectJitter,
+                         RandomRotateTextDet, RandomScaling, ScaleAspectJitter,
                         SquareResizePad)
__all__ = [
@@ -22,5 +23,6 @@ __all__ = [
    'RandomCropPolyInstances', 'RandomRotatePolyInstances', 'RandomPaddingOCR',
    'ImgAug', 'EastRandomCrop', 'RandomRotateImageBox', 'OpencvToPil',
    'PilToOpencv', 'KIEFormatBundle', 'SquareResizePad', 'TextSnakeTargets',
-    'sort_vertex', 'LoadImageFromNdarray', 'sort_vertex8'
+    'sort_vertex', 'LoadImageFromNdarray', 'sort_vertex8', 'FCENetTargets',
+    'RandomScaling', 'RandomCropFlip'
]


@@ -1,10 +1,11 @@
from .base_textdet_targets import BaseTextDetTargets
from .dbnet_targets import DBNetTargets
+from .fcenet_targets import FCENetTargets
from .panet_targets import PANetTargets
from .psenet_targets import PSENetTargets
from .textsnake_targets import TextSnakeTargets
__all__ = [
    'BaseTextDetTargets', 'PANetTargets', 'PSENetTargets', 'DBNetTargets',
-    'TextSnakeTargets'
+    'FCENetTargets', 'TextSnakeTargets'
]


@@ -0,0 +1,370 @@
import cv2
import numpy as np
from numpy.linalg import norm
import mmocr.utils.check_argument as check_argument
from mmdet.datasets.builder import PIPELINES
from .textsnake_targets import TextSnakeTargets
@PIPELINES.register_module()
class FCENetTargets(TextSnakeTargets):
"""Generate the ground truth targets of FCENet: Fourier Contour Embedding
for Arbitrary-Shaped Text Detection.
[https://arxiv.org/abs/2104.10442]
Args:
fourier_degree (int): The maximum Fourier transform degree k.
resample_step (float): The step size for resampling the text center
line (TCL). It's better not to exceed half of the minimum width.
center_region_shrink_ratio (float): The shrink ratio of text center
region.
level_size_divisors (tuple(int)): The downsample ratio on each level.
level_proportion_range (tuple(tuple(int))): The range of text sizes
assigned to each level.
"""
def __init__(self,
fourier_degree=5,
resample_step=4.0,
center_region_shrink_ratio=0.3,
level_size_divisors=(8, 16, 32),
level_proportion_range=((0, 0.4), (0.3, 0.7), (0.6, 1.0))):
super().__init__()
assert isinstance(level_size_divisors, tuple)
assert isinstance(level_proportion_range, tuple)
assert len(level_size_divisors) == len(level_proportion_range)
self.fourier_degree = fourier_degree
self.resample_step = resample_step
self.center_region_shrink_ratio = center_region_shrink_ratio
self.level_size_divisors = level_size_divisors
self.level_proportion_range = level_proportion_range
def generate_center_region_mask(self, img_size, text_polys):
"""Generate text center region mask.
Args:
img_size (tuple): The image size of (height, width).
text_polys (list[list[ndarray]]): The list of text polygons.
Returns:
center_region_mask (ndarray): The text center region mask.
"""
assert isinstance(img_size, tuple)
assert check_argument.is_2dlist(text_polys)
h, w = img_size
center_region_mask = np.zeros((h, w), np.uint8)
center_region_boxes = []
for poly in text_polys:
assert len(poly) == 1
polygon_points = poly[0].reshape(-1, 2)
_, _, top_line, bot_line = self.reorder_poly_edge(polygon_points)
resampled_top_line, resampled_bot_line = self.resample_sidelines(
top_line, bot_line, self.resample_step)
resampled_bot_line = resampled_bot_line[::-1]
center_line = (resampled_top_line + resampled_bot_line) / 2
line_head_shrink_len = norm(resampled_top_line[0] -
resampled_bot_line[0]) / 4.0
line_tail_shrink_len = norm(resampled_top_line[-1] -
resampled_bot_line[-1]) / 4.0
head_shrink_num = int(line_head_shrink_len // self.resample_step)
tail_shrink_num = int(line_tail_shrink_len // self.resample_step)
if len(center_line) > head_shrink_num + tail_shrink_num + 2:
center_line = center_line[head_shrink_num:len(center_line) -
tail_shrink_num]
resampled_top_line = resampled_top_line[
head_shrink_num:len(resampled_top_line) - tail_shrink_num]
resampled_bot_line = resampled_bot_line[
head_shrink_num:len(resampled_bot_line) - tail_shrink_num]
for i in range(0, len(center_line) - 1):
tl = center_line[i] + (resampled_top_line[i] - center_line[i]
) * self.center_region_shrink_ratio
tr = center_line[i + 1] + (
resampled_top_line[i + 1] -
center_line[i + 1]) * self.center_region_shrink_ratio
br = center_line[i + 1] + (
resampled_bot_line[i + 1] -
center_line[i + 1]) * self.center_region_shrink_ratio
bl = center_line[i] + (resampled_bot_line[i] - center_line[i]
) * self.center_region_shrink_ratio
current_center_box = np.vstack([tl, tr, br,
bl]).astype(np.int32)
center_region_boxes.append(current_center_box)
cv2.fillPoly(center_region_mask, center_region_boxes, 1)
return center_region_mask
def resample_polygon(self, polygon, n=400):
"""Resample one polygon with n points on its boundary.
Args:
polygon (list[float]): The input polygon.
n (int): The number of resampled points.
Returns:
resampled_polygon (list[float]): The resampled polygon.
"""
length = []
for i in range(len(polygon)):
p1 = polygon[i]
if i == len(polygon) - 1:
p2 = polygon[0]
else:
p2 = polygon[i + 1]
length.append(((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2)**0.5)
total_length = sum(length)
n_on_each_line = (np.array(length) / (total_length + 1e-8)) * n
n_on_each_line = n_on_each_line.astype(np.int32)
new_polygon = []
for i in range(len(polygon)):
num = n_on_each_line[i]
p1 = polygon[i]
if i == len(polygon) - 1:
p2 = polygon[0]
else:
p2 = polygon[i + 1]
if num == 0:
continue
dxdy = (p2 - p1) / num
for j in range(num):
point = p1 + dxdy * j
new_polygon.append(point)
return np.array(new_polygon)
def normalize_polygon(self, polygon):
"""Normalize one polygon so that its start point is at right most.
Args:
polygon (list[float]): The origin polygon.
Returns:
new_polygon (lost[float]): The polygon with start point at right.
"""
temp_polygon = polygon - polygon.mean(axis=0)
x = np.abs(temp_polygon[:, 0])
y = temp_polygon[:, 1]
index_x = np.argsort(x)
index_y = np.argmin(y[index_x[:8]])
index = index_x[index_y]
new_polygon = np.concatenate([polygon[index:], polygon[:index]])
return new_polygon
def fourier_transform(self, polygon, fourier_degree):
"""Perform Fourier transformation to generate Fourier coefficients ck
from polygon.
Args:
polygon (ndarray): An input polygon.
fourier_degree (int): The maximum Fourier degree K.
Returns:
c (ndarray(complex)): Fourier coefficients.
"""
points = polygon[:, 0] + polygon[:, 1] * 1j
n = len(points)
t = np.multiply([i / n for i in range(n)], -2 * np.pi * 1j)
e = complex(np.e)
c = np.zeros((2 * fourier_degree + 1, ), dtype='complex')
for i in range(-fourier_degree, fourier_degree + 1):
c[i + fourier_degree] = np.sum(points * np.power(e, i * t)) / n
return c
def clockwise(self, c, fourier_degree):
"""Make sure the polygon reconstructed from Fourier coefficients c in
the clockwise direction.
Args:
polygon (list[float]): The origin polygon.
Returns:
new_polygon (lost[float]): The polygon in clockwise point order.
"""
if np.abs(c[fourier_degree + 1]) > np.abs(c[fourier_degree - 1]):
return c
elif np.abs(c[fourier_degree + 1]) < np.abs(c[fourier_degree - 1]):
return c[::-1]
else:
if np.abs(c[fourier_degree + 2]) > np.abs(c[fourier_degree - 2]):
return c
else:
return c[::-1]
def cal_fourier_signature(self, polygon, fourier_degree):
"""Calculate Fourier signature from input polygon.
Args:
polygon (ndarray): The input polygon.
fourier_degree (int): The maximum Fourier degree K.
Returns:
            fourier_signature (ndarray): An array shaped (2k+1, 2) containing
                the real and imaginary parts of the 2k+1 Fourier coefficients.
"""
resampled_polygon = self.resample_polygon(polygon)
resampled_polygon = self.normalize_polygon(resampled_polygon)
fourier_coeff = self.fourier_transform(resampled_polygon,
fourier_degree)
fourier_coeff = self.clockwise(fourier_coeff, fourier_degree)
real_part = np.real(fourier_coeff).reshape((-1, 1))
image_part = np.imag(fourier_coeff).reshape((-1, 1))
fourier_signature = np.hstack([real_part, image_part])
return fourier_signature
def generate_fourier_maps(self, img_size, text_polys):
"""Generate Fourier coefficient maps.
Args:
img_size (tuple): The image size of (height, width).
text_polys (list[list[ndarray]]): The list of text polygons.
Returns:
            real_map (ndarray): The map of the real parts of the Fourier
                coefficients.
            imag_map (ndarray): The map of the imaginary parts of the
                Fourier coefficients.
"""
assert isinstance(img_size, tuple)
assert check_argument.is_2dlist(text_polys)
h, w = img_size
k = self.fourier_degree
real_map = np.zeros((k * 2 + 1, h, w), dtype=np.float32)
imag_map = np.zeros((k * 2 + 1, h, w), dtype=np.float32)
for poly in text_polys:
assert len(poly) == 1
text_instance = [[poly[0][i], poly[0][i + 1]]
for i in range(0, len(poly[0]), 2)]
mask = np.zeros((h, w), dtype=np.uint8)
polygon = np.array(text_instance).reshape((1, -1, 2))
cv2.fillPoly(mask, polygon.astype(np.int32), 1)
fourier_coeff = self.cal_fourier_signature(polygon[0], k)
for i in range(-k, k + 1):
if i != 0:
real_map[i + k, :, :] = mask * fourier_coeff[i + k, 0] + (
1 - mask) * real_map[i + k, :, :]
imag_map[i + k, :, :] = mask * fourier_coeff[i + k, 1] + (
1 - mask) * imag_map[i + k, :, :]
else:
yx = np.argwhere(mask > 0.5)
k_ind = np.ones((len(yx)), dtype=np.int64) * k
y, x = yx[:, 0], yx[:, 1]
real_map[k_ind, y, x] = fourier_coeff[k, 0] - x
imag_map[k_ind, y, x] = fourier_coeff[k, 1] - y
return real_map, imag_map
def generate_level_targets(self, img_size, text_polys, ignore_polys):
"""Generate ground truth target on each level.
Args:
img_size (list[int]): Shape of input image.
text_polys (list[list[ndarray]]): A list of ground truth polygons.
ignore_polys (list[list[ndarray]]): A list of ignored polygons.
Returns:
            level_maps (list(ndarray)): A list of ground truth targets on
                each level.
"""
h, w = img_size
lv_size_divs = self.level_size_divisors
lv_proportion_range = self.level_proportion_range
lv_text_polys = [[] for i in range(len(lv_size_divs))]
lv_ignore_polys = [[] for i in range(len(lv_size_divs))]
level_maps = []
for poly in text_polys:
assert len(poly) == 1
text_instance = [[poly[0][i], poly[0][i + 1]]
for i in range(0, len(poly[0]), 2)]
polygon = np.array(text_instance, dtype=np.int).reshape((1, -1, 2))
_, _, box_w, box_h = cv2.boundingRect(polygon)
proportion = max(box_h, box_w) / (h + 1e-8)
for ind, proportion_range in enumerate(lv_proportion_range):
if proportion_range[0] < proportion < proportion_range[1]:
lv_text_polys[ind].append([poly[0] / lv_size_divs[ind]])
for ignore_poly in ignore_polys:
assert len(ignore_poly) == 1
text_instance = [[ignore_poly[0][i], ignore_poly[0][i + 1]]
for i in range(0, len(ignore_poly[0]), 2)]
polygon = np.array(text_instance, dtype=np.int).reshape((1, -1, 2))
_, _, box_w, box_h = cv2.boundingRect(polygon)
proportion = max(box_h, box_w) / (h + 1e-8)
for ind, proportion_range in enumerate(lv_proportion_range):
if proportion_range[0] < proportion < proportion_range[1]:
lv_text_polys[ind].append(
[ignore_poly[0] / lv_size_divs[ind]])
for ind, size_divisor in enumerate(lv_size_divs):
current_level_maps = []
level_img_size = (h // size_divisor, w // size_divisor)
text_region = self.generate_text_region_mask(
level_img_size, lv_text_polys[ind])
text_region = np.expand_dims(text_region, axis=0)
current_level_maps.append(text_region)
center_region = self.generate_center_region_mask(
level_img_size, lv_text_polys[ind])
center_region = np.expand_dims(center_region, axis=0)
current_level_maps.append(center_region)
effective_mask = self.generate_effective_mask(
level_img_size, lv_ignore_polys[ind])
effective_mask = np.expand_dims(effective_mask, axis=0)
current_level_maps.append(effective_mask)
fourier_real_map, fourier_image_maps = self.generate_fourier_maps(
level_img_size, lv_text_polys[ind])
current_level_maps.append(fourier_real_map)
current_level_maps.append(fourier_image_maps)
level_maps.append(np.concatenate(current_level_maps))
return level_maps
def generate_targets(self, results):
"""Generate the ground truth targets for FCENet.
Args:
results (dict): The input result dictionary.
Returns:
results (dict): The output result dictionary.
"""
assert isinstance(results, dict)
polygon_masks = results['gt_masks'].masks
polygon_masks_ignore = results['gt_masks_ignore'].masks
h, w, _ = results['img_shape']
level_maps = self.generate_level_targets((h, w), polygon_masks,
polygon_masks_ignore)
results['mask_fields'].clear() # rm gt_masks encoded by polygons
mapping = {
'p3_maps': level_maps[0],
'p4_maps': level_maps[1],
'p5_maps': level_maps[2]
}
for key, value in mapping.items():
results[key] = value
return results
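
`fourier_transform` above is a scaled discrete Fourier transform of the resampled boundary points, which can be checked numerically against `np.fft.fft` (a sketch; the square contour is an arbitrary example):

```python
import numpy as np

from mmocr.datasets.pipelines.textdet_targets import FCENetTargets

targets = FCENetTargets(fourier_degree=5)
square = np.array([[0., 0.], [10., 0.], [10., 10.], [0., 10.]])
resampled = targets.resample_polygon(square, n=400)

c = targets.fourier_transform(resampled, fourier_degree=5)

# c[k + 5] is the k-th DFT coefficient divided by n; negative frequencies
# sit at the end of np.fft.fft's output, hence the reordering below.
points = resampled[:, 0] + resampled[:, 1] * 1j
fft = np.fft.fft(points) / len(points)
assert np.allclose(c, np.hstack([fft[-5:], fft[:6]]))
```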


@@ -2,6 +2,7 @@ import math
import cv2
import numpy as np
+import Polygon as plg
import torchvision.transforms as transforms
from PIL import Image
@@ -731,3 +732,231 @@ class SquareResizePad:
    def __repr__(self):
        repr_str = self.__class__.__name__
        return repr_str
@PIPELINES.register_module()
class RandomScaling:
def __init__(self, size=800, scale=(3. / 4, 5. / 2)):
"""Random scale the image while keeping aspect.
Args:
size (int) : Base size before scaling.
scale (tuple(float)) : The range of scaling.
"""
assert isinstance(size, int)
assert isinstance(scale, float) or isinstance(scale, tuple)
self.size = size
self.scale = scale if isinstance(scale, tuple) \
else (1 - scale, 1 + scale)
def __call__(self, results):
image = results['img']
h, w, _ = results['img_shape']
aspect_ratio = np.random.uniform(min(self.scale), max(self.scale))
scales = self.size * 1.0 / max(h, w) * aspect_ratio
scales = np.array([scales, scales])
out_size = (int(h * scales[1]), int(w * scales[0]))
image = cv2.resize(image, out_size[::-1])
results['img'] = image
results['img_shape'] = image.shape
for key in results.get('mask_fields', []):
if len(results[key].masks) == 0:
continue
results[key] = results[key].resize(out_size)
return results
@PIPELINES.register_module()
class RandomCropFlip:
def __init__(self, crop_ratio=0.5, iter_num=1, min_area_ratio=0.2):
"""Random crop and flip a patch of the image.
Args:
crop_ratio (float): The ratio of cropping.
iter_num (int): Number of operations.
min_area_ratio (float): Minimal area ratio between cropped patch
and original image.
"""
assert isinstance(crop_ratio, float)
assert isinstance(iter_num, int)
assert isinstance(min_area_ratio, float)
self.scale = 10
self.epsilon = 1e-2
self.crop_ratio = crop_ratio
self.iter_num = iter_num
self.min_area_ratio = min_area_ratio
def __call__(self, results):
for i in range(self.iter_num):
results = self.random_crop_flip(results)
return results
def random_crop_flip(self, results):
image = results['img']
polygons = results['gt_masks'].masks
ignore_polygons = results['gt_masks_ignore'].masks
all_polygons = polygons + ignore_polygons
if len(polygons) == 0:
return results
if np.random.random() >= self.crop_ratio:
return results
h_axis, w_axis = self.crop_target(image, all_polygons, self.scale)
if len(h_axis) == 0 or len(w_axis) == 0:
return results
attempt = 0
h, w, _ = results['img_shape']
area = h * w
pad_h = h // self.scale
pad_w = w // self.scale
while attempt < 10:
attempt += 1
polys_keep = []
polys_new = []
ign_polys_keep = []
ign_polys_new = []
xx = np.random.choice(w_axis, size=2)
xmin = np.min(xx) - pad_w
xmax = np.max(xx) - pad_w
xmin = np.clip(xmin, 0, w - 1)
xmax = np.clip(xmax, 0, w - 1)
yy = np.random.choice(h_axis, size=2)
ymin = np.min(yy) - pad_h
ymax = np.max(yy) - pad_h
ymin = np.clip(ymin, 0, h - 1)
ymax = np.clip(ymax, 0, h - 1)
if (xmax - xmin) * (ymax - ymin) < area * self.min_area_ratio:
# area too small
continue
pts = np.stack([[xmin, xmax, xmax, xmin],
[ymin, ymin, ymax, ymax]]).T.astype(np.int32)
pp = plg.Polygon(pts)
fail_flag = False
for polygon in polygons:
ppi = plg.Polygon(polygon[0].reshape(-1, 2))
ppiou, _ = eval_utils.poly_intersection(ppi, pp)
if np.abs(ppiou - float(ppi.area())) > self.epsilon and \
np.abs(ppiou) > self.epsilon:
fail_flag = True
break
elif np.abs(ppiou - float(ppi.area())) < self.epsilon:
polys_new.append(polygon)
else:
polys_keep.append(polygon)
for polygon in ignore_polygons:
ppi = plg.Polygon(polygon[0].reshape(-1, 2))
ppiou, _ = eval_utils.poly_intersection(ppi, pp)
if np.abs(ppiou - float(ppi.area())) > self.epsilon and \
np.abs(ppiou) > self.epsilon:
fail_flag = True
break
elif np.abs(ppiou - float(ppi.area())) < self.epsilon:
ign_polys_new.append(polygon)
else:
ign_polys_keep.append(polygon)
if fail_flag:
continue
else:
break
cropped = image[ymin:ymax, xmin:xmax, :]
select_type = np.random.randint(3)
if select_type == 0:
img = np.ascontiguousarray(cropped[:, ::-1])
elif select_type == 1:
img = np.ascontiguousarray(cropped[::-1, :])
else:
img = np.ascontiguousarray(cropped[::-1, ::-1])
image[ymin:ymax, xmin:xmax, :] = img
results['img'] = image
if len(polys_new) + len(ign_polys_new) != 0:
height, width, _ = cropped.shape
if select_type == 0:
for idx, polygon in enumerate(polys_new):
poly = polygon[0].reshape(-1, 2)
poly[:, 0] = width - poly[:, 0] + 2 * xmin
polys_new[idx] = [poly.reshape(-1, )]
for idx, polygon in enumerate(ign_polys_new):
poly = polygon[0].reshape(-1, 2)
poly[:, 0] = width - poly[:, 0] + 2 * xmin
ign_polys_new[idx] = [poly.reshape(-1, )]
elif select_type == 1:
for idx, polygon in enumerate(polys_new):
poly = polygon[0].reshape(-1, 2)
poly[:, 1] = height - poly[:, 1] + 2 * ymin
polys_new[idx] = [poly.reshape(-1, )]
for idx, polygon in enumerate(ign_polys_new):
poly = polygon[0].reshape(-1, 2)
poly[:, 1] = height - poly[:, 1] + 2 * ymin
ign_polys_new[idx] = [poly.reshape(-1, )]
else:
for idx, polygon in enumerate(polys_new):
poly = polygon[0].reshape(-1, 2)
poly[:, 0] = width - poly[:, 0] + 2 * xmin
poly[:, 1] = height - poly[:, 1] + 2 * ymin
polys_new[idx] = [poly.reshape(-1, )]
for idx, polygon in enumerate(ign_polys_new):
poly = polygon[0].reshape(-1, 2)
poly[:, 0] = width - poly[:, 0] + 2 * xmin
poly[:, 1] = height - poly[:, 1] + 2 * ymin
ign_polys_new[idx] = [poly.reshape(-1, )]
polygons = polys_keep + polys_new
ignore_polygons = ign_polys_keep + ign_polys_new
results['gt_masks'] = PolygonMasks(polygons, *(image.shape[:2]))
results['gt_masks_ignore'] = PolygonMasks(ignore_polygons,
*(image.shape[:2]))
return results
def crop_target(self, image, all_polys, scale):
"""Generate crop target and make sure not to crop the polygon
instances.
Args:
            image (ndarray): The image to be cropped.
all_polys (list[list[ndarray]]): All polygons including ground
truth polygons and ground truth ignored polygons.
scale (int): A scale factor to control crop range.
Returns:
h_axis (ndarray): Vertical cropping range.
w_axis (ndarray): Horizontal cropping range.
"""
h, w, _ = image.shape
pad_h = h // scale
pad_w = w // scale
h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
text_polys = []
for polygon in all_polys:
rect = cv2.minAreaRect(polygon[0].astype(np.int32).reshape(-1, 2))
box = cv2.boxPoints(rect)
box = np.int0(box)
text_polys.append([box[0], box[1], box[2], box[3]])
polys = np.array(text_polys, dtype=np.int32)
for poly in polys:
            poly = np.round(poly, decimals=0).astype(np.int32)  # round to int
minx = np.min(poly[:, 0])
maxx = np.max(poly[:, 0])
w_array[minx + pad_w:maxx + pad_w] = 1
miny = np.min(poly[:, 1])
maxy = np.max(poly[:, 1])
h_array[miny + pad_h:maxy + pad_h] = 1
h_axis = np.where(h_array == 0)[0]
w_axis = np.where(w_array == 0)[0]
return h_axis, w_axis
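
A deterministic usage sketch for `RandomScaling` (with the degenerate range `scale=(1., 1.)` the sampled factor is fixed at 1, so the output size is exactly `size / max(h, w)` times the input; `PolygonMasks` and the `eval_utils` helper used by `RandomCropFlip` are assumed to be imported at the top of `transforms.py`):

```python
import numpy as np
from mmdet.core import PolygonMasks

from mmocr.datasets.pipelines.transforms import RandomScaling

img = np.zeros((100, 200, 3), dtype=np.uint8)
poly = np.array([0., 0., 40., 0., 40., 20., 0., 20.])
results = dict(
    img=img,
    img_shape=img.shape,
    mask_fields=['gt_masks'],
    gt_masks=PolygonMasks([[poly]], 100, 200))

# size / max(h, w) = 800 / 200 = 4, so image and polygons scale by 4.
out = RandomScaling(size=800, scale=(1., 1.))(results)
assert out['img'].shape == (400, 800, 3)
assert np.allclose(out['gt_masks'].masks[0][0], poly * 4)
```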


@@ -1,7 +1,10 @@
from .db_head import DBHead
+from .fce_head import FCEHead
from .head_mixin import HeadMixin
from .pan_head import PANHead
from .pse_head import PSEHead
from .textsnake_head import TextSnakeHead
-__all__ = ['PSEHead', 'PANHead', 'DBHead', 'HeadMixin', 'TextSnakeHead']
+__all__ = [
+    'PSEHead', 'PANHead', 'DBHead', 'FCEHead', 'HeadMixin', 'TextSnakeHead'
+]


@@ -0,0 +1,134 @@
import torch.nn as nn
from mmcv.cnn import normal_init
from mmdet.core import multi_apply
from mmdet.models.builder import HEADS, build_loss
from mmocr.models.textdet.postprocess import decode
from ..postprocess.wrapper import poly_nms
from .head_mixin import HeadMixin
@HEADS.register_module()
class FCEHead(HeadMixin, nn.Module):
"""The class for implementing FCENet head.
FCENet(CVPR2021): Fourier Contour Embedding for Arbitrary-shaped Text
Detection.
[https://arxiv.org/abs/2104.10442]
Args:
in_channels (int): The number of input channels.
scales (list[int]) : The scale of each layer.
fourier_degree (int) : The maximum Fourier transform degree k.
        sample_num (int): The number of sample points used in the regression
            loss. If it is too small, FCENet tends to overfit.
        score_thresh (float): The threshold used to filter out the final
            candidates.
        nms_thresh (float): The threshold of NMS.
        alpha (float): The parameter to calculate the final score:
            Score_{final} = (Score_{text region})^alpha
            * (Score_{text center region})^beta.
        beta (float): The parameter to calculate the final score.
"""
def __init__(self,
in_channels,
scales,
fourier_degree=5,
sample_num=50,
reconstr_points=50,
decoding_type='fcenet',
loss=dict(type='FCELoss'),
score_thresh=0.3,
nms_thresh=0.1,
alpha=1.0,
beta=1.0,
train_cfg=None,
test_cfg=None):
super().__init__()
assert isinstance(in_channels, int)
self.downsample_ratio = 1.0
self.in_channels = in_channels
self.scales = scales
self.fourier_degree = fourier_degree
self.sample_num = sample_num
self.reconstr_points = reconstr_points
loss['fourier_degree'] = fourier_degree
loss['sample_num'] = sample_num
self.decoding_type = decoding_type
self.loss_module = build_loss(loss)
self.score_thresh = score_thresh
self.nms_thresh = nms_thresh
self.alpha = alpha
self.beta = beta
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.out_channels_cls = 4
self.out_channels_reg = (2 * self.fourier_degree + 1) * 2
self.out_conv_cls = nn.Conv2d(
self.in_channels,
self.out_channels_cls,
kernel_size=3,
stride=1,
padding=1)
self.out_conv_reg = nn.Conv2d(
self.in_channels,
self.out_channels_reg,
kernel_size=3,
stride=1,
padding=1)
self.init_weights()
def init_weights(self):
normal_init(self.out_conv_cls, mean=0, std=0.01)
normal_init(self.out_conv_reg, mean=0, std=0.01)
def forward(self, feats):
cls_res, reg_res = multi_apply(self.forward_single, feats)
level_num = len(cls_res)
preds = [[cls_res[i], reg_res[i]] for i in range(level_num)]
return preds
def forward_single(self, x):
cls_predict = self.out_conv_cls(x)
reg_predict = self.out_conv_reg(x)
return cls_predict, reg_predict
def get_boundary(self, score_maps, img_metas, rescale):
assert len(score_maps) == len(self.scales)
boundaries = []
for idx, score_map in enumerate(score_maps):
scale = self.scales[idx]
boundaries = boundaries + self._get_boundary_single(
score_map, scale)
# nms
boundaries = poly_nms(boundaries, self.nms_thresh)
if rescale:
boundaries = self.resize_boundary(
boundaries, 1.0 / img_metas[0]['scale_factor'])
results = dict(boundary_result=boundaries)
return results
def _get_boundary_single(self, score_map, scale):
assert len(score_map) == 2
assert score_map[1].shape[1] == 4 * self.fourier_degree + 2
return decode(
decoding_type=self.decoding_type,
preds=score_map,
fourier_degree=self.fourier_degree,
reconstr_points=self.reconstr_points,
scale=scale,
alpha=self.alpha,
beta=self.beta,
text_repr_type='poly',
score_thresh=self.score_thresh,
nms_thresh=self.nms_thresh)
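
A shape-level sketch of the head in isolation, with dummy FPN features (the spatial sizes are arbitrary; only the channel counts matter):

```python
import torch

from mmocr.models.textdet.dense_heads import FCEHead

head = FCEHead(in_channels=256, scales=(8, 16, 32))
feats = [
    torch.rand(1, 256, 32, 32),
    torch.rand(1, 256, 16, 16),
    torch.rand(1, 256, 8, 8)
]
preds = head(feats)
# Per level: a 4-channel classification map (2 channels each for the text
# region and the text center region) and a 2 * (2 * 5 + 1) = 22-channel
# regression map of Fourier coefficients.
for cls_pred, reg_pred in preds:
    assert cls_pred.shape[1] == 4 and reg_pred.shape[1] == 22
```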


@@ -1,4 +1,5 @@
from .dbnet import DBNet
+from .fcenet import FCENet
from .ocr_mask_rcnn import OCRMaskRCNN
from .panet import PANet
from .psenet import PSENet
@@ -8,5 +9,5 @@ from .textsnake import TextSnake
__all__ = [
    'TextDetectorMixin', 'SingleStageTextDetector', 'OCRMaskRCNN', 'DBNet',
-    'PANet', 'PSENet', 'TextSnake'
+    'PANet', 'PSENet', 'TextSnake', 'FCENet'
]


@@ -0,0 +1,32 @@
from mmdet.models.builder import DETECTORS
from .single_stage_text_detector import SingleStageTextDetector
from .text_detector_mixin import TextDetectorMixin
@DETECTORS.register_module()
class FCENet(TextDetectorMixin, SingleStageTextDetector):
"""The class for implementing FCENet text detector
FCENet(CVPR2021): Fourier Contour Embedding for Arbitrary-shaped Text
Detection
[https://arxiv.org/abs/2104.10442]
"""
def __init__(self,
backbone,
neck,
bbox_head,
train_cfg=None,
test_cfg=None,
pretrained=None,
show_score=False):
SingleStageTextDetector.__init__(self, backbone, neck, bbox_head,
train_cfg, test_cfg, pretrained)
TextDetectorMixin.__init__(self, show_score)
def simple_test(self, img, img_metas, rescale=False):
x = self.extract_feat(img)
outs = self.bbox_head(x)
boundaries = self.bbox_head.get_boundary(outs, img_metas, rescale)
return [boundaries]


@@ -1,6 +1,7 @@
from .db_loss import DBLoss
+from .fce_loss import FCELoss
from .pan_loss import PANLoss
from .pse_loss import PSELoss
from .textsnake_loss import TextSnakeLoss
-__all__ = ['PANLoss', 'PSELoss', 'DBLoss', 'TextSnakeLoss']
+__all__ = ['PANLoss', 'PSELoss', 'DBLoss', 'TextSnakeLoss', 'FCELoss']


@@ -0,0 +1,194 @@
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from mmdet.core import multi_apply
from mmdet.models.builder import LOSSES
@LOSSES.register_module()
class FCELoss(nn.Module):
"""The class for implementing FCENet loss
FCENet(CVPR2021): Fourier Contour Embedding for Arbitrary-shaped
Text Detection
[https://arxiv.org/abs/2104.10442]
Args:
fourier_degree (int) : The maximum Fourier transform degree k.
        sample_num (int): The number of sample points used in the regression
            loss. If it is too small, FCENet tends to overfit.
        ohem_ratio (float): The negative/positive ratio in OHEM.
"""
def __init__(self, fourier_degree, sample_num, ohem_ratio=3.):
super().__init__()
self.fourier_degree = fourier_degree
self.sample_num = sample_num
self.ohem_ratio = ohem_ratio
def forward(self, preds, _, p3_maps, p4_maps, p5_maps):
assert isinstance(preds, list)
assert p3_maps[0].shape[0] == 4 * self.fourier_degree + 5,\
'fourier degree not equal in FCEhead and FCEtarget'
device = preds[0][0].device
# to tensor
gts = [p3_maps, p4_maps, p5_maps]
for idx, maps in enumerate(gts):
gts[idx] = torch.from_numpy(np.stack(maps)).float().to(device)
losses = multi_apply(self.forward_single, preds, gts)
loss_tr = torch.tensor(0., device=device).float()
loss_tcl = torch.tensor(0., device=device).float()
loss_reg_x = torch.tensor(0., device=device).float()
loss_reg_y = torch.tensor(0., device=device).float()
for idx, loss in enumerate(losses):
if idx == 0:
loss_tr += sum(loss)
elif idx == 1:
loss_tcl += sum(loss)
elif idx == 2:
loss_reg_x += sum(loss)
else:
loss_reg_y += sum(loss)
results = dict(
loss_text=loss_tr,
loss_center=loss_tcl,
loss_reg_x=loss_reg_x,
loss_reg_y=loss_reg_y,
)
return results
def forward_single(self, pred, gt):
cls_pred, reg_pred = pred[0], pred[1]
tr_pred = cls_pred[:, :2, :, :].permute(0, 2, 3, 1)\
.contiguous().view(-1, 2)
tcl_pred = cls_pred[:, 2:, :, :].permute(0, 2, 3, 1)\
.contiguous().view(-1, 2)
x_pred = reg_pred[:, 0:2 * self.fourier_degree + 1, :, :]\
.permute(0, 2, 3, 1).contiguous().view(
-1, 2 * self.fourier_degree + 1)
y_pred = reg_pred[:,
2 * self.fourier_degree + 1:4 * self.fourier_degree +
2, :, :].permute(0, 2, 3, 1).contiguous().view(
-1, 2 * self.fourier_degree + 1)
tr_mask = gt[:, :1, :, :].permute(0, 2, 3, 1).contiguous().view(-1)
tcl_mask = gt[:, 1:2, :, :].permute(0, 2, 3, 1).contiguous().view(-1)
train_mask = gt[:, 2:3, :, :].permute(0, 2, 3, 1).contiguous().view(-1)
x_map = gt[:, 3:4 + 2 * self.fourier_degree, :, :].permute(
0, 2, 3, 1).contiguous().view(-1, 2 * self.fourier_degree + 1)
y_map = gt[:, 4 + 2 * self.fourier_degree:, :, :].permute(
0, 2, 3, 1).contiguous().view(-1, 2 * self.fourier_degree + 1)
tr_train_mask = train_mask * tr_mask
device = x_map.device
# tr loss
loss_tr = self.ohem(tr_pred, tr_mask.long(), train_mask.long())
# tcl loss
loss_tcl = torch.tensor(0.).float().to(device)
tr_neg_mask = 1 - tr_train_mask
if tr_train_mask.sum().item() > 0:
loss_tcl_pos = F.cross_entropy(
tcl_pred[tr_train_mask.bool()],
tcl_mask[tr_train_mask.bool()].long())
loss_tcl_neg = F.cross_entropy(tcl_pred[tr_neg_mask.bool()],
tcl_mask[tr_neg_mask.bool()].long())
loss_tcl = loss_tcl_pos + 0.5 * loss_tcl_neg
# regression loss
loss_reg_x = torch.tensor(0.).float().to(device)
loss_reg_y = torch.tensor(0.).float().to(device)
if tr_train_mask.sum().item() > 0:
weight = (tr_mask[tr_train_mask.bool()].float() +
tcl_mask[tr_train_mask.bool()].float()) / 2
weight = weight.contiguous().view(-1, 1)
ft_x, ft_y = self.fourier2poly(x_map, y_map)
ft_x_pre, ft_y_pre = self.fourier2poly(x_pred, y_pred)
loss_reg_x = torch.mean(weight * F.smooth_l1_loss(
ft_x_pre[tr_train_mask.bool()],
ft_x[tr_train_mask.bool()],
reduction='none'))
loss_reg_y = torch.mean(weight * F.smooth_l1_loss(
ft_y_pre[tr_train_mask.bool()],
ft_y[tr_train_mask.bool()],
reduction='none'))
return loss_tr, loss_tcl, loss_reg_x, loss_reg_y
def ohem(self, predict, target, train_mask):
pos = (target * train_mask).bool()
neg = ((1 - target) * train_mask).bool()
n_pos = pos.float().sum()
if n_pos.item() > 0:
loss_pos = F.cross_entropy(
predict[pos], target[pos], reduction='sum')
loss_neg = F.cross_entropy(
predict[neg], target[neg], reduction='none')
n_neg = min(
int(neg.float().sum().item()),
int(self.ohem_ratio * n_pos.float()))
else:
loss_pos = torch.tensor(0.)
loss_neg = F.cross_entropy(
predict[neg], target[neg], reduction='none')
n_neg = 100
if len(loss_neg) > n_neg:
loss_neg, _ = torch.topk(loss_neg, n_neg)
return (loss_pos + loss_neg.sum()) / (n_pos + n_neg).float()
def fourier2poly(self, real_maps, imag_maps):
"""Transform Fourier coefficient maps to polygon maps.
Args:
real_maps (tensor): A map composed of the real parts of the
Fourier coefficients, whose shape is (-1, 2k+1)
            imag_maps (tensor): A map composed of the imaginary parts of the
                Fourier coefficients, whose shape is (-1, 2k+1)
        Returns:
x_maps (tensor): A map composed of the x value of the polygon
represented by n sample points (xn, yn), whose shape is (-1, n)
y_maps (tensor): A map composed of the y value of the polygon
represented by n sample points (xn, yn), whose shape is (-1, n)
"""
device = real_maps.device
k_vect = torch.arange(
-self.fourier_degree,
self.fourier_degree + 1,
dtype=torch.float,
device=device).view(-1, 1)
i_vect = torch.arange(
0, self.sample_num, dtype=torch.float, device=device).view(1, -1)
transform_matrix = 2 * np.pi / self.sample_num * torch.mm(
k_vect, i_vect)
x1 = torch.einsum('ak, kn-> an', real_maps,
torch.cos(transform_matrix))
x2 = torch.einsum('ak, kn-> an', imag_maps,
torch.sin(transform_matrix))
y1 = torch.einsum('ak, kn-> an', real_maps,
torch.sin(transform_matrix))
y2 = torch.einsum('ak, kn-> an', imag_maps,
torch.cos(transform_matrix))
x_maps = x1 - x2
y_maps = y1 + y2
return x_maps, y_maps
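
`fourier2poly` can be sanity-checked with a single coefficient: setting only c_1 (index `fourier_degree + 1`) must reconstruct a circle centred at the origin (a minimal sketch):

```python
import torch

from mmocr.models.textdet.losses import FCELoss

fce_loss = FCELoss(fourier_degree=5, sample_num=36)
real = torch.zeros(1, 11)  # 2 * 5 + 1 coefficients
imag = torch.zeros(1, 11)
real[0, 6] = 2.0  # c_1 = 2: a circle of radius 2

x, y = fce_loss.fourier2poly(real, imag)
assert x.shape == (1, 36)
assert torch.allclose(x**2 + y**2, torch.full_like(x, 4.), atol=1e-4)
```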


@@ -7,6 +7,7 @@ from shapely.geometry import Polygon
from skimage.morphology import skeletonize
from mmocr.core import points2boundary
+from mmocr.core.evaluation.utils import boundary_iou
def filter_instance(area, confidence, min_area, min_confidence):
@@ -24,6 +25,8 @@ def decode(
        return db_decode(**kwargs)
    if decoding_type == 'textsnake':
        return textsnake_decode(**kwargs)
+    if decoding_type == 'fcenet':
+        return fcenet_decode(**kwargs)
    raise NotImplementedError
@@ -391,3 +394,177 @@ def textsnake_decode(preds,
        boundaries.append(boundary + [score])
    return boundaries
def fcenet_decode(
preds,
fourier_degree,
reconstr_points,
scale,
alpha=1.0,
beta=2.0,
text_repr_type='poly',
score_thresh=0.8,
nms_thresh=0.1,
):
"""Decoding predictions of FCENet to instances.
Args:
preds (list(Tensor)): The head output tensors.
fourier_degree (int): The maximum Fourier transform degree k.
        reconstr_points (int): The number of points on the polygon
            reconstructed from the predicted Fourier coefficients.
        scale (int): The downsample scale of the prediction.
        alpha (float): The parameter to calculate the final score:
            Score_{final} = (Score_{text region})^alpha
            * (Score_{text center region})^beta.
        beta (float): The parameter to calculate the final score.
        text_repr_type (str): Boundary encoding type, 'poly' or 'quad'.
        score_thresh (float): The threshold used to filter out the final
            candidates.
        nms_thresh (float): The threshold of NMS.
Returns:
boundaries (list[list[float]]): The instance boundary and confidence
list.
"""
assert isinstance(preds, list)
assert len(preds) == 2
assert text_repr_type == 'poly'
cls_pred = preds[0][0]
tr_pred = cls_pred[0:2].softmax(dim=0).data.cpu().numpy()
tcl_pred = cls_pred[2:].softmax(dim=0).data.cpu().numpy()
reg_pred = preds[1][0].permute(1, 2, 0).data.cpu().numpy()
x_pred = reg_pred[:, :, :2 * fourier_degree + 1]
y_pred = reg_pred[:, :, 2 * fourier_degree + 1:]
score_pred = (tr_pred[1]**alpha) * (tcl_pred[1]**beta)
tr_pred_mask = (score_pred) > score_thresh
tr_mask = fill_hole(tr_pred_mask)
tr_contours, _ = cv2.findContours(
tr_mask.astype(np.uint8), cv2.RETR_TREE,
cv2.CHAIN_APPROX_SIMPLE) # opencv4
mask = np.zeros_like(tr_mask)
exp_matrix = generate_exp_matrix(reconstr_points, fourier_degree)
boundaries = []
for cont in tr_contours:
deal_map = mask.copy().astype(np.int8)
cv2.drawContours(deal_map, [cont], -1, 1, -1)
text_map = score_pred * deal_map
polygons = contour_transfor_inv(fourier_degree, x_pred, y_pred,
text_map, exp_matrix, scale)
polygons = poly_nms(polygons, nms_thresh)
boundaries = boundaries + polygons
boundaries = poly_nms(boundaries, nms_thresh)
return boundaries
def poly_nms(polygons, threshold):
assert isinstance(polygons, list)
polygons = np.array(sorted(polygons, key=lambda x: x[-1]))
keep_poly = []
index = [i for i in range(polygons.shape[0])]
while len(index) > 0:
keep_poly.append(polygons[index[-1]].tolist())
A = polygons[index[-1]][:-1]
index = np.delete(index, -1)
iou_list = np.zeros((len(index), ))
for i in range(len(index)):
B = polygons[index[i]][:-1]
iou_list[i] = boundary_iou(A, B)
remove_index = np.where(iou_list > threshold)
index = np.delete(index, remove_index)
return keep_poly
def contour_transfor_inv(fourier_degree, x_pred, y_pred, score_map, exp_matrix,
scale):
"""Reconstruct polygon from predicts.
Args:
fourier_degree (int): The maximum Fourier degree K.
x_pred (ndarray): The real part of predicted Fourier coefficients.
y_pred (ndarray): The image part of predicted Fourier coefficients.
score_map (ndarray): The final score of candidates.
exp_matrix (ndarray): A matrix of e^x, where x = 2pi x ikt, and shape
is (2k+1, n') where n' is reconstructed point number. See Eq.2
in paper.
scale (int): The down-sample scale.
Returns:
Polygons (list): The reconstructed polygons and scores.
"""
mask = score_map > 0
xy_text = np.argwhere(mask)
dxy = xy_text[:, 1] + xy_text[:, 0] * 1j
x = x_pred[mask]
y = y_pred[mask]
c = x + y * 1j
c[:, fourier_degree] = c[:, fourier_degree] + dxy
c *= scale
polygons = fourier_inverse_matrix(c, exp_matrix=exp_matrix)
score = score_map[mask].reshape(-1, 1)
return np.hstack((polygons, score)).tolist()
def fourier_inverse_matrix(fourier_coeff, exp_matrix):
""" Inverse Fourier transform
Args:
fourier_coeff (ndarray): Fourier coefficients shaped (n, 2k+1), with
n and k being candidates number and Fourier degree respectively.
exp_matrix (ndarray): A matrix of e^x, where x = 2pi x ikt and shape
is (2k+1, n') where n' is reconstructed point number.
See Eq.2 in paper.
Returns:
Polygons (ndarray): The reconstructed polygons shaped (n, n')
"""
assert type(fourier_coeff) == np.ndarray
assert fourier_coeff.shape[1] == exp_matrix.shape[0]
n = exp_matrix.shape[1]
polygons = np.zeros((fourier_coeff.shape[0], n, 2))
points = np.matmul(fourier_coeff, exp_matrix)
p_x = np.real(points)
p_y = np.imag(points)
polygons[:, :, 0] = p_x
polygons[:, :, 1] = p_y
return polygons.astype('int32').reshape(polygons.shape[0], -1)
def generate_exp_matrix(point_num, fourier_degree):
""" Generate a matrix of e^x, where x = 2pi x ikt. See Eq.2 in paper.
Args:
point_num (int): Number of reconstruct points of polygon
fourier_degree (int): Maximum Fourier degree k
Returns:
exp_matrix (ndarray): A matrix of e^x, where x = 2pi x ikt and
shape is (2k+1, n') where n' is reconstructed point number.
"""
e = complex(np.e)
exp_matrix = np.zeros([2 * fourier_degree + 1, point_num], dtype='complex')
temp = np.zeros([point_num], dtype='complex')
for i in range(point_num):
temp[i] = 2 * np.pi * 1j / point_num * i
for i in range(2 * fourier_degree + 1):
exp_matrix[i, :] = temp * (i - fourier_degree)
return np.power(e, exp_matrix)
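
`generate_exp_matrix` and `fourier_inverse_matrix` together implement Eq. 2; a round-trip sketch (the radius is chosen large so the final int32 cast loses little precision):

```python
import numpy as np

from mmocr.models.textdet.postprocess.wrapper import (
    fourier_inverse_matrix, generate_exp_matrix)

k = 5
exp_matrix = generate_exp_matrix(point_num=50, fourier_degree=k)
coeff = np.zeros((1, 2 * k + 1), dtype='complex')
coeff[0, k] = 200 + 200j  # c_0: contour centre at (200, 200)
coeff[0, k + 1] = 100     # c_1: a circle of radius 100

poly = fourier_inverse_matrix(coeff, exp_matrix).reshape(-1, 2)
radii = np.linalg.norm(poly - np.array([200, 200]), axis=1)
assert np.all(np.abs(radii - 100) < 2)  # tolerance for the int32 cast
```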


@@ -218,3 +218,27 @@ def test_gen_textsnake_targets(mock_show_feature):
    assert 'gt_sin_map' in output.keys()
    assert 'gt_cos_map' in output.keys()
    mock_show_feature.assert_called_once()
def test_fcenet_generate_targets():
fourier_degree = 5
target_generator = textdet_targets.FCENetTargets(
fourier_degree=fourier_degree)
h, w, c = (64, 64, 3)
text_polys = [[np.array([0, 0, 10, 0, 10, 10, 0, 10])],
[np.array([20, 0, 30, 0, 30, 10, 20, 10])]]
text_polys_ignore = [[np.array([0, 0, 15, 0, 15, 10, 0, 10])]]
results = {}
results['mask_fields'] = []
results['img_shape'] = (h, w, c)
results['gt_masks_ignore'] = PolygonMasks(text_polys_ignore, h, w)
results['gt_masks'] = PolygonMasks(text_polys, h, w)
results['gt_bboxes'] = np.array([[0, 0, 10, 10], [20, 0, 30, 10]])
results['gt_labels'] = np.array([0, 1])
target_generator.generate_targets(results)
assert 'p3_maps' in results.keys()
assert 'p4_maps' in results.keys()
assert 'p5_maps' in results.keys()


@@ -166,6 +166,72 @@ def test_affine_jitter():
    assert np.allclose(np.array(output1), output2['img'])
def test_random_scale():
h, w, c = 100, 100, 3
img = np.ones((h, w, c), dtype=np.uint8)
results = {'img': img, 'img_shape': (h, w, c)}
polygon = np.array([0., 0., 0., 10., 10., 10., 10., 0.])
results['gt_masks'] = PolygonMasks([[polygon]], *(img.shape[:2]))
results['mask_fields'] = ['gt_masks']
size = 100
scale = (2., 2.)
random_scaler = transforms.RandomScaling(size=size, scale=scale)
results = random_scaler(results)
out_img = results['img']
out_poly = results['gt_masks'].masks[0][0]
gt_poly = polygon * 2
assert np.allclose(out_img.shape, (2 * h, 2 * w, c))
assert np.allclose(out_poly, gt_poly)
@mock.patch('%s.transforms.np.random.randint' % __name__)
def test_random_crop_flip(mock_randint):
img = np.ones((10, 10, 3), dtype=np.uint8)
img[0, 0, :] = 0
results = {'img': img, 'img_shape': img.shape}
polygon = np.array([0., 0., 0., 10., 10., 10., 10., 0.])
results['gt_masks'] = PolygonMasks([[polygon]], *(img.shape[:2]))
results['gt_masks_ignore'] = PolygonMasks([], *(img.shape[:2]))
results['mask_fields'] = ['gt_masks', 'gt_masks_ignore']
crop_ratio = 1.1
iter_num = 3
random_crop_fliper = transforms.RandomCropFlip(
crop_ratio=crop_ratio, iter_num=iter_num)
# test crop_target
scale = 10
all_polys = results['gt_masks'].masks
h_axis, w_axis = random_crop_fliper.crop_target(img, all_polys, scale)
assert np.allclose(h_axis, (0, 11))
assert np.allclose(w_axis, (0, 11))
# test __call__
polygon = np.array([1., 1., 1., 9., 9., 9., 9., 1.])
results['gt_masks'] = PolygonMasks([[polygon]], *(img.shape[:2]))
results['gt_masks_ignore'] = PolygonMasks([[polygon]], *(img.shape[:2]))
mock_randint.side_effect = [0, 1, 2]
results = random_crop_fliper(results)
out_img = results['img']
out_poly = results['gt_masks'].masks[0][0]
gt_img = img
gt_poly = polygon
assert np.allclose(out_img, gt_img)
assert np.allclose(out_poly, gt_poly)
@mock.patch('%s.transforms.np.random.random_sample' % __name__)
@mock.patch('%s.transforms.np.random.randint' % __name__)
def test_random_crop_poly_instances(mock_randint, mock_sample):


@@ -372,3 +372,60 @@ def test_textsnake(cfg_file):
    results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]}
    img = np.random.rand(5, 5)
    detector.show_result(img, results)
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.parametrize(
'cfg_file', ['textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py'])
def test_fcenet(cfg_file):
model = _get_detector_cfg(cfg_file)
model['pretrained'] = None
model['backbone']['norm_cfg']['type'] = 'BN'
from mmocr.models import build_detector
detector = build_detector(model)
detector = detector.cuda()
fourier_degree = 5
input_shape = (1, 3, 256, 256)
(n, c, h, w) = input_shape
imgs = torch.randn(n, c, h, w).float().cuda()
img_metas = [{
'img_shape': (h, w, c),
'ori_shape': (h, w, c),
'pad_shape': (h, w, c),
'filename': '<demo>.png',
'scale_factor': np.array([1, 1, 1, 1]),
'flip': False,
} for _ in range(n)]
p3_maps = []
p4_maps = []
p5_maps = []
for _ in range(n):
p3_maps.append(
np.random.random((5 + 4 * fourier_degree, h // 8, w // 8)))
p4_maps.append(
np.random.random((5 + 4 * fourier_degree, h // 16, w // 16)))
p5_maps.append(
np.random.random((5 + 4 * fourier_degree, h // 32, w // 32)))
# Test forward train
losses = detector.forward(
imgs, img_metas, p3_maps=p3_maps, p4_maps=p4_maps, p5_maps=p5_maps)
assert isinstance(losses, dict)
# Test forward test
with torch.no_grad():
img_list = [g[None, :] for g in imgs]
batch_results = []
for one_img, one_meta in zip(img_list, img_metas):
result = detector.forward([one_img], [[one_meta]],
return_loss=False)
batch_results.append(result)
# Test show result
results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]}
img = np.random.rand(5, 5)
detector.show_result(img, results)


@@ -31,3 +31,42 @@ def test_textsnakeloss():
    bce_loss = textsnakeloss.balanced_bce_loss(pred, target, mask).item()
    assert np.allclose(bce_loss, 0)
def test_fcenetloss():
k = 5
fcenetloss = losses.FCELoss(fourier_degree=k, sample_num=10)
input_shape = (1, 3, 64, 64)
(n, c, h, w) = input_shape
# test ohem
pred = torch.ones((200, 2), dtype=torch.float)
target = torch.ones((200, ), dtype=torch.long)
target[20:] = 0
mask = torch.ones((200, ), dtype=torch.long)
ohem_loss1 = fcenetloss.ohem(pred, target, mask)
ohem_loss2 = fcenetloss.ohem(pred, target, 1 - mask)
assert isinstance(ohem_loss1, torch.Tensor)
assert isinstance(ohem_loss2, torch.Tensor)
# test forward
preds = []
for i in range(n):
scale = 8 * 2**i
pred = []
pred.append(torch.rand(n, 4, h // scale, w // scale))
pred.append(torch.rand(n, 4 * k + 2, h // scale, w // scale))
preds.append(pred)
p3_maps = []
p4_maps = []
p5_maps = []
for _ in range(n):
p3_maps.append(np.random.random((5 + 4 * k, h // 8, w // 8)))
p4_maps.append(np.random.random((5 + 4 * k, h // 16, w // 16)))
p5_maps.append(np.random.random((5 + 4 * k, h // 32, w // 32)))
loss = fcenetloss(preds, 0, p3_maps, p4_maps, p5_maps)
assert isinstance(loss, dict)


@@ -12,3 +12,17 @@ def test_db_boxes_from_bitmaps():
    boxes = db_decode(preds, text_repr_type='quad', min_text_width=0)
    assert len(boxes) == 1
def test_fcenet_decode():
from mmocr.models.textdet.postprocess.wrapper import fcenet_decode
k = 5
preds = []
preds.append(torch.randn(1, 4, 40, 40))
preds.append(torch.randn(1, 4 * k + 2, 40, 40))
boundaries = fcenet_decode(
preds=preds, fourier_degree=k, reconstr_points=50, scale=1)
assert isinstance(boundaries, list)