add fcenet (#133)

* add fcenet

* fix linting and code style

* fcenet finetune

* Update transforms.py

* Update fcenet_r50dcnv2_fpn_1500e_ctw1500.py

* Update fcenet_targets.py

* Update fce_loss.py

* fix

* add readme

* fix config

* Update fcenet_r50dcnv2_fpn_1500e_ctw1500.py

* fix

* fix readme

* fix readme

* Update test_loss.py

Co-authored-by: Hongbin Sun <hongbin306@gmail.com>
This commit is contained in:
Zyq-scut 2021-05-14 21:37:04 +08:00 committed by GitHub
parent 1240169f18
commit cbdd98a1e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 1510 additions and 10 deletions

View File

@ -0,0 +1,22 @@
# Fourier Contour Embedding for Arbitrary-Shaped Text Detection
## Introduction
[ALGORITHM]
```bibtex
@InProceedings{zhu2021fourier,
title={Fourier Contour Embedding for Arbitrary-Shaped Text Detection},
author={Yiqin Zhu and Jianyong Chen and Lingyu Liang and Zhanghui Kuang and Lianwen Jin and Wayne Zhang},
year={2021},
booktitle = {CVPR}
}
```
## Results and models
### CTW1500
| Method | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download |
| :--------------------------------------------------------------------: | :--------------: | :-----------: | :----------: | :-----: | :-------: | :----: | :-------: | :---: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| [FCENet](/configs/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py) | ImageNet | CTW1500 Train | CTW1500 Test | 1500 |(736, 1080)| 0.828 | 0.875 | 0.851 | [model](https://download.openmmlab.com/mmocr/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500-05d740bb.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/fcenet/20210511_181328.log.json) |

View File

@ -0,0 +1,134 @@
# Maximum Fourier degree k; passed to both the FCEHead here and the
# FCENetTargets step of the training pipeline so the two stay in sync.
fourier_degree = 5
# FCENet detector: ResNet-50 backbone -> FPN neck -> FCEHead.
model = dict(
type='FCENet',
pretrained='torchvision://resnet50',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
# Only the last three of the four stages are fed to the neck.
out_indices=(1, 2, 3),
frozen_stages=-1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch',
# DCNv2 is switched on for every stage except the first one.
dcn=dict(type='DCNv2', deform_groups=2, fallback_on_stride=False),
stage_with_dcn=(False, True, True, True)),
neck=dict(
type='FPN',
# One entry per backbone output selected by out_indices above.
in_channels=[512, 1024, 2048],
out_channels=256,
add_extra_convs=True,
extra_convs_on_inputs=False, # use P5
num_outs=3,
relu_before_extra_convs=True,
act_cfg=None),
bbox_head=dict(
type='FCEHead',
in_channels=256,
# One scale per neck output; same values as the default
# level_size_divisors of FCENetTargets.
scales=(8, 16, 32),
loss=dict(type='FCELoss'),
fourier_degree=fourier_degree,
))
train_cfg = None
test_cfg = None
# Dataset type and location; data_root already ends with a slash.
dataset_type = 'IcdarDataset'
data_root = 'data/ctw1500/'
# Per-channel mean/std used by every Normalize step below.
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
# Training pipeline: load image + polygon annotations, augment, then turn the
# polygons into the three per-level target maps consumed by the loss.
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='LoadTextAnnotations',
with_bbox=True,
with_mask=True,
poly2mask=False),
dict(
type='ColorJitter',
brightness=32.0 / 255,
saturation=0.5,
contrast=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='RandomScaling', size=800, scale=(3. / 4, 5. / 2)),
dict(
type='RandomCropFlip', crop_ratio=0.5, iter_num=1, min_area_ratio=0.2),
dict(
type='RandomCropPolyInstances',
instance_key='gt_masks',
crop_ratio=0.8,
min_side_ratio=0.3),
dict(
type='RandomRotatePolyInstances',
rotate_ratio=0.5,
max_angle=30,
pad_with_fixed_color=False),
dict(type='SquareResizePad', target_size=800, pad_ratio=0.6),
dict(type='RandomFlip', flip_ratio=0.5, direction='horizontal'),
dict(type='Pad', size_divisor=32),
# The three ranges overlap on purpose-looking boundaries (0.2-0.25 and
# 0.55-0.65), so a text instance can be assigned to two levels.
dict(
type='FCENetTargets',
fourier_degree=fourier_degree,
level_proportion_range=((0, 0.25), (0.2, 0.65), (0.55, 1.0))),
dict(
type='CustomFormatBundle',
keys=['p3_maps', 'p4_maps', 'p5_maps'],
visualize=dict(flag=False, boundary_key=None)),
dict(type='Collect', keys=['img', 'p3_maps', 'p4_maps', 'p5_maps'])
]
# Test pipeline: single scale, no flipping.
# NOTE(review): MultiScaleFlipAug is given img_scale=(1080, 736) while the
# nested Resize is given img_scale=(1280, 800) -- confirm which of the two
# actually takes effect at test time and align them.
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1080, 736),
flip=False,
transforms=[
dict(type='Resize', img_scale=(1280, 800), keep_ratio=True),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
# Dataloader settings. val and test read the same annotation file and share
# the test pipeline.
# NOTE(review): data_root is 'data/ctw1500/' (trailing slash) and every suffix
# below starts with '/', so the resulting paths contain a double slash, e.g.
# 'data/ctw1500//imgs'.
data = dict(
samples_per_gpu=6,
workers_per_gpu=2,
train=dict(
type=dataset_type,
ann_file=data_root + '/instances_training.json',
img_prefix=data_root + '/imgs',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=data_root + '/instances_test.json',
img_prefix=data_root + '/imgs',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + '/instances_test.json',
img_prefix=data_root + '/imgs',
pipeline=test_pipeline))
# Evaluation with the 'hmean-iou' metric; interval is presumably counted in
# epochs -- TODO confirm against the runner.
evaluation = dict(interval=5, metric='hmean-iou')
# optimizer
optimizer = dict(type='SGD', lr=1e-3, momentum=0.90, weight_decay=5e-4)
# No gradient clipping.
optimizer_config = dict(grad_clip=None)
# Polynomial learning-rate decay stepped per epoch (by_epoch=True).
lr_config = dict(policy='poly', power=0.9, min_lr=1e-7, by_epoch=True)
total_epochs = 1500
checkpoint_config = dict(interval=5)
# yapf:disable
log_config = dict(
interval=20,
hooks=[
dict(type='TextLoggerHook')
])
# yapf:enable
# Runtime settings: NCCL backend, no checkpoint to load or resume from, and a
# workflow made of the 'train' phase only.
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]

View File

@ -5,7 +5,7 @@ from .icdar_dataset import IcdarDataset
from .kie_dataset import KIEDataset
from .ocr_dataset import OCRDataset
from .ocr_seg_dataset import OCRSegDataset
from .pipelines import CustomFormatBundle, DBNetTargets
from .pipelines import CustomFormatBundle, DBNetTargets, FCENetTargets
from .text_det_dataset import TextDetDataset
from .utils import * # NOQA
@ -13,7 +13,7 @@ from .utils import * # NOQA
__all__ = [
'DATASETS', 'IcdarDataset', 'build_dataloader', 'build_dataset',
'BaseDataset', 'OCRDataset', 'TextDetDataset', 'CustomFormatBundle',
'DBNetTargets', 'OCRSegDataset', 'KIEDataset'
'DBNetTargets', 'OCRSegDataset', 'KIEDataset', 'FCENetTargets'
]
__all__ += utils.__all__

View File

@ -8,10 +8,11 @@ from .ocr_transforms import (FancyPCA, NormalizeOCR, OnlineCropOCR,
OpencvToPil, PilToOpencv, RandomPaddingOCR,
RandomRotateImageBox, ResizeOCR, ToTensorOCR)
from .test_time_aug import MultiRotateAugOCR
from .textdet_targets import DBNetTargets, PANetTargets, TextSnakeTargets
from .transforms import (ColorJitter, RandomCropInstances,
from .textdet_targets import (DBNetTargets, FCENetTargets, PANetTargets,
TextSnakeTargets)
from .transforms import (ColorJitter, RandomCropFlip, RandomCropInstances,
RandomCropPolyInstances, RandomRotatePolyInstances,
RandomRotateTextDet, ScaleAspectJitter,
RandomRotateTextDet, RandomScaling, ScaleAspectJitter,
SquareResizePad)
__all__ = [
@ -22,5 +23,6 @@ __all__ = [
'RandomCropPolyInstances', 'RandomRotatePolyInstances', 'RandomPaddingOCR',
'ImgAug', 'EastRandomCrop', 'RandomRotateImageBox', 'OpencvToPil',
'PilToOpencv', 'KIEFormatBundle', 'SquareResizePad', 'TextSnakeTargets',
'sort_vertex', 'LoadImageFromNdarray', 'sort_vertex8'
'sort_vertex', 'LoadImageFromNdarray', 'sort_vertex8', 'FCENetTargets',
'RandomScaling', 'RandomCropFlip'
]

View File

@ -1,10 +1,11 @@
from .base_textdet_targets import BaseTextDetTargets
from .dbnet_targets import DBNetTargets
from .fcenet_targets import FCENetTargets
from .panet_targets import PANetTargets
from .psenet_targets import PSENetTargets
from .textsnake_targets import TextSnakeTargets
__all__ = [
'BaseTextDetTargets', 'PANetTargets', 'PSENetTargets', 'DBNetTargets',
'TextSnakeTargets'
'FCENetTargets', 'TextSnakeTargets'
]

View File

@ -0,0 +1,370 @@
import cv2
import numpy as np
from numpy.linalg import norm
import mmocr.utils.check_argument as check_argument
from mmdet.datasets.builder import PIPELINES
from .textsnake_targets import TextSnakeTargets
@PIPELINES.register_module()
class FCENetTargets(TextSnakeTargets):
"""Generate the ground truth targets of FCENet: Fourier Contour Embedding
for Arbitrary-Shaped Text Detection.
[https://arxiv.org/abs/2104.10442]
Args:
fourier_degree (int): The maximum Fourier transform degree k.
resample_step (float): The step size for resampling the text center
line (TCL). It's better not to exceed half of the minimum width.
center_region_shrink_ratio (float): The shrink ratio of text center
region.
level_size_divisors (tuple(int)): The downsample ratio on each level.
level_proportion_range (tuple(tuple(int))): The range of text sizes
assigned to each level.
"""
def __init__(self,
             fourier_degree=5,
             resample_step=4.0,
             center_region_shrink_ratio=0.3,
             level_size_divisors=(8, 16, 32),
             level_proportion_range=((0, 0.4), (0.3, 0.7), (0.6, 1.0))):
    """Store the hyper-parameters used to build the FCENet targets.

    Args:
        fourier_degree (int): The maximum Fourier transform degree k.
        resample_step (float): The step size used when resampling the
            side lines of a text instance.
        center_region_shrink_ratio (float): The shrink ratio of the text
            center region.
        level_size_divisors (tuple(int)): The downsample ratio of each
            level.
        level_proportion_range (tuple(tuple(float))): One (low, high) size
            proportion range per level; must have as many entries as
            ``level_size_divisors``.
    """
    super().__init__()

    # Both per-level settings have to be tuples with one entry per level.
    for per_level_setting in (level_size_divisors, level_proportion_range):
        assert isinstance(per_level_setting, tuple)
    level_num = len(level_size_divisors)
    assert level_num == len(level_proportion_range)

    self.level_size_divisors = level_size_divisors
    self.level_proportion_range = level_proportion_range
    self.fourier_degree = fourier_degree
    self.resample_step = resample_step
    self.center_region_shrink_ratio = center_region_shrink_ratio
def generate_center_region_mask(self, img_size, text_polys):
    """Rasterise the shrunken center region of every text polygon.

    Args:
        img_size (tuple): The image size of (height, width).
        text_polys (list[list[ndarray]]): The list of text polygons; each
            entry must hold exactly one flattened coordinate array.

    Returns:
        center_region_mask (ndarray): A uint8 (height, width) mask that is
            1 inside the text center regions and 0 elsewhere.
    """
    assert isinstance(img_size, tuple)
    assert check_argument.is_2dlist(text_polys)

    height, width = img_size
    step = self.resample_step
    ratio = self.center_region_shrink_ratio
    center_region_mask = np.zeros((height, width), np.uint8)

    center_region_boxes = []
    for poly in text_polys:
        assert len(poly) == 1
        vertices = poly[0].reshape(-1, 2)
        _, _, top_line, bot_line = self.reorder_poly_edge(vertices)
        top_pts, bot_pts = self.resample_sidelines(top_line, bot_line, step)
        # Reverse the bottom line so both side lines run the same way.
        bot_pts = bot_pts[::-1]
        mid_pts = (top_pts + bot_pts) / 2

        # Number of samples to drop at each end: a quarter of the local
        # text height, expressed in resampling steps.
        head_height = norm(top_pts[0] - bot_pts[0])
        tail_height = norm(top_pts[-1] - bot_pts[-1])
        head_skip = int((head_height / 4.0) // step)
        tail_skip = int((tail_height / 4.0) // step)

        # Trim the ends only when more than two samples would survive.
        if len(mid_pts) > head_skip + tail_skip + 2:

            def trim(line):
                return line[head_skip:len(line) - tail_skip]

            mid_pts, top_pts, bot_pts = (trim(mid_pts), trim(top_pts),
                                         trim(bot_pts))

        # Pull both side lines towards the center line by the shrink
        # ratio, then emit one quadrilateral per consecutive sample pair.
        inner_top = mid_pts + (top_pts - mid_pts) * ratio
        inner_bot = mid_pts + (bot_pts - mid_pts) * ratio
        for left in range(0, len(mid_pts) - 1):
            right = left + 1
            quad = np.vstack([
                inner_top[left], inner_top[right], inner_bot[right],
                inner_bot[left]
            ])
            center_region_boxes.append(quad.astype(np.int32))

    cv2.fillPoly(center_region_mask, center_region_boxes, 1)
    return center_region_mask
def resample_polygon(self, polygon, n=400):
    """Resample one polygon with about n points along its boundary.

    Every edge receives ``int(edge_length / perimeter * n)`` evenly spaced
    points (the edge start is included, the edge end is not). Because of
    the truncation the result can contain fewer than n points.

    Args:
        polygon (ndarray): The input polygon as a sequence of (x, y)
            vertices.
        n (int): The target number of resampled points.

    Returns:
        resampled_polygon (ndarray): The resampled polygon.
    """
    vertex_num = len(polygon)
    # Pair every vertex with its successor; the last one wraps to the first.
    edges = [(polygon[idx], polygon[(idx + 1) % vertex_num])
             for idx in range(vertex_num)]
    edge_lengths = [((start[0] - end[0])**2 + (start[1] - end[1])**2)**0.5
                    for start, end in edges]

    perimeter = sum(edge_lengths)
    # The epsilon keeps the division defined for a zero-length boundary.
    share = (np.array(edge_lengths) / (perimeter + 1e-8)) * n
    points_per_edge = share.astype(np.int32)

    resampled = []
    for (start, end), count in zip(edges, points_per_edge):
        if count == 0:
            continue
        stride = (end - start) / count
        for step_idx in range(count):
            resampled.append(start + stride * step_idx)

    return np.array(resampled)
def normalize_polygon(self, polygon):
    """Cyclically shift a polygon so that it starts at a canonical vertex.

    The start vertex is picked among the (at most) 8 vertices whose x
    coordinate is closest to the mean x of the polygon; of those, the one
    with the smallest y coordinate is used.

    Args:
        polygon (ndarray): The original polygon, indexed as (N, 2).

    Returns:
        new_polygon (ndarray): The same vertices in the same cyclic order,
            beginning at the selected start vertex.
    """
    centered = polygon - polygon.mean(axis=0)
    # Candidates: the vertices nearest to the vertical line through the mean.
    candidates = np.argsort(np.abs(centered[:, 0]))[:8]
    start = candidates[np.argmin(centered[candidates, 1])]
    return np.concatenate([polygon[start:], polygon[:start]])
def fourier_transform(self, polygon, fourier_degree):
    """Compute the Fourier coefficients c_k of a closed polygon.

    The vertices are read as complex numbers x + iy and every coefficient
    is the mean of ``point * e^(-2*pi*i*k*m/n)`` over the n vertices.

    Args:
        polygon (ndarray): An input polygon, indexed as (N, 2).
        fourier_degree (int): The maximum Fourier degree K.

    Returns:
        c (ndarray(complex)): The 2K+1 coefficients ordered from degree
            -K to degree K.
    """
    contour = polygon[:, 0] + polygon[:, 1] * 1j
    point_num = len(contour)
    # Phase of every vertex for degree 1; other degrees scale it.
    phase = np.multiply([idx / point_num for idx in range(point_num)],
                        -2 * np.pi * 1j)
    base = complex(np.e)

    coeffs = np.zeros((2 * fourier_degree + 1, ), dtype='complex')
    degrees = range(-fourier_degree, fourier_degree + 1)
    for slot, degree in enumerate(degrees):
        basis = np.power(base, degree * phase)
        coeffs[slot] = np.sum(contour * basis) / point_num
    return coeffs
def clockwise(self, c, fourier_degree):
    """Orient the Fourier coefficients so the contour has a fixed direction.

    The magnitudes of the degree +1 and degree -1 coefficients are
    compared; on a tie the degree +2 and -2 coefficients decide. ``c`` is
    returned unchanged when the positive-degree coefficient is larger and
    reversed (which swaps c_k with c_-k) otherwise. If the tie cannot be
    broken, or the array holds no coefficient of the required degree (for
    example ``fourier_degree`` < 2), the reversed array is returned.

    Args:
        c (ndarray(complex)): Fourier coefficients ordered from degree
            ``-fourier_degree`` to ``fourier_degree``.
        fourier_degree (int): The maximum Fourier degree K, i.e. the index
            of the degree-0 coefficient in ``c``.

    Returns:
        c (ndarray(complex)): Either ``c`` itself or ``c[::-1]``.
    """
    center = fourier_degree
    for offset in (1, 2):
        # Stop instead of indexing outside the array when the requested
        # degree is not available.
        if center + offset >= len(c) or center - offset < 0:
            break
        positive = np.abs(c[center + offset])
        negative = np.abs(c[center - offset])
        if positive > negative:
            return c
        if positive < negative:
            return c[::-1]
    return c[::-1]
def cal_fourier_signature(self, polygon, fourier_degree):
"""Calculate Fourier signature from input polygon.
Args:
polygon (ndarray): The input polygon.
fourier_degree (int): The maximum Fourier degree K.
Returns:
fourier_signature (ndarray): An array shaped (2k+1, 2) containing
real part and image part of 2k+1 Fourier coefficients.
"""
resampled_polygon = self.resample_polygon(polygon)
resampled_polygon = self.normalize_polygon(resampled_polygon)
fourier_coeff = self.fourier_transform(resampled_polygon,
fourier_degree)
fourier_coeff = self.clockwise(fourier_coeff, fourier_degree)
real_part = np.real(fourier_coeff).reshape((-1, 1))
image_part = np.imag(fourier_coeff).reshape((-1, 1))
fourier_signature = np.hstack([real_part, image_part])
return fourier_signature
def generate_fourier_maps(self, img_size, text_polys):
    """Generate the per-pixel Fourier coefficient maps.

    Inside every text polygon the channels of degree != 0 hold the
    coefficient of that polygon, while the degree-0 channel holds the
    coefficient minus the pixel coordinate (x for the real map, y for the
    imaginary map). Polygons processed later overwrite earlier ones where
    they overlap.

    Args:
        img_size (tuple): The image size of (height, width).
        text_polys (list[list[ndarray]]): The list of text polygons.

    Returns:
        fourier_real_map (ndarray): The real part maps, shaped
            (2k+1, height, width), float32.
        fourier_image_map (ndarray): The imaginary part maps, same shape.
    """
    assert isinstance(img_size, tuple)
    assert check_argument.is_2dlist(text_polys)

    height, width = img_size
    degree = self.fourier_degree
    channel_num = degree * 2 + 1
    real_map = np.zeros((channel_num, height, width), dtype=np.float32)
    imag_map = np.zeros((channel_num, height, width), dtype=np.float32)

    for poly in text_polys:
        assert len(poly) == 1
        coords = poly[0]
        vertices = [[coords[j], coords[j + 1]]
                    for j in range(0, len(coords), 2)]
        contour = np.array(vertices).reshape((1, -1, 2))

        inside = np.zeros((height, width), dtype=np.uint8)
        cv2.fillPoly(inside, contour.astype(np.int32), 1)
        signature = self.cal_fourier_signature(contour[0], degree)

        for channel in range(channel_num):
            if channel == degree:
                # Degree 0: store the offset from each pixel to c_0.
                rows_cols = np.argwhere(inside > 0.5)
                rows, cols = rows_cols[:, 0], rows_cols[:, 1]
                channel_idx = np.ones((len(rows_cols)),
                                      dtype=np.int64) * degree
                real_map[channel_idx, rows,
                         cols] = signature[degree, 0] - cols
                imag_map[channel_idx, rows,
                         cols] = signature[degree, 1] - rows
                continue
            # Other degrees: paste the coefficient inside the polygon and
            # keep whatever was there outside of it.
            real_map[channel, :, :] = inside * signature[channel, 0] + (
                1 - inside) * real_map[channel, :, :]
            imag_map[channel, :, :] = inside * signature[channel, 1] + (
                1 - inside) * imag_map[channel, :, :]

    return real_map, imag_map
def generate_level_targets(self, img_size, text_polys, ignore_polys):
    """Generate the ground truth target maps of every level.

    Each polygon is assigned to every level whose proportion range
    strictly contains ``max(box_h, box_w) / h`` of its bounding box, and
    its coordinates are divided by that level's size divisor. Text
    polygons feed the text region, center region and Fourier maps;
    ignored polygons feed only the effective (training) mask.

    Args:
        img_size (tuple(int)): Shape of the input image as (h, w).
        text_polys (list[list[ndarray]]): A list of ground truth polygons.
        ignore_polys (list[list[ndarray]]): A list of ignored polygons.

    Returns:
        level_maps (list(ndarray)): One array per level, concatenating the
            text region mask, the center region mask, the effective mask
            and the real and imaginary Fourier maps along axis 0.
    """
    h, w = img_size
    lv_size_divs = self.level_size_divisors
    lv_proportion_range = self.level_proportion_range
    lv_text_polys = [[] for _ in lv_size_divs]
    lv_ignore_polys = [[] for _ in lv_size_divs]

    def assign_to_levels(polys, lv_polys):
        """Append every polygon, rescaled, to the levels it belongs to."""
        for poly in polys:
            assert len(poly) == 1
            points = [[poly[0][i], poly[0][i + 1]]
                      for i in range(0, len(poly[0]), 2)]
            # np.int was removed from NumPy; int32 is what OpenCV expects.
            polygon = np.array(points, dtype=np.int32).reshape((1, -1, 2))
            _, _, box_w, box_h = cv2.boundingRect(polygon)
            proportion = max(box_h, box_w) / (h + 1e-8)
            for ind, (low, high) in enumerate(lv_proportion_range):
                if low < proportion < high:
                    lv_polys[ind].append([poly[0] / lv_size_divs[ind]])

    assign_to_levels(text_polys, lv_text_polys)
    # Ignored polygons must go to their own per-level lists: they only
    # mask out the loss and must not be trained as positive text.
    assign_to_levels(ignore_polys, lv_ignore_polys)

    level_maps = []
    for ind, size_divisor in enumerate(lv_size_divs):
        level_img_size = (h // size_divisor, w // size_divisor)

        text_region = self.generate_text_region_mask(
            level_img_size, lv_text_polys[ind])
        center_region = self.generate_center_region_mask(
            level_img_size, lv_text_polys[ind])
        effective_mask = self.generate_effective_mask(
            level_img_size, lv_ignore_polys[ind])
        fourier_real_map, fourier_image_maps = self.generate_fourier_maps(
            level_img_size, lv_text_polys[ind])

        current_level_maps = [
            np.expand_dims(text_region, axis=0),
            np.expand_dims(center_region, axis=0),
            np.expand_dims(effective_mask, axis=0),
            fourier_real_map,
            fourier_image_maps,
        ]
        level_maps.append(np.concatenate(current_level_maps))

    return level_maps
def generate_targets(self, results):
    """Generate the ground truth targets for FCENet.

    Reads the polygons from ``results['gt_masks']`` and
    ``results['gt_masks_ignore']``, empties ``results['mask_fields']`` and
    stores the first three level maps under ``p3_maps``, ``p4_maps`` and
    ``p5_maps``.

    Args:
        results (dict): The input result dictionary.

    Returns:
        results (dict): The same dictionary, updated in place.
    """
    assert isinstance(results, dict)

    text_polys = results['gt_masks'].masks
    ignore_polys = results['gt_masks_ignore'].masks
    img_h, img_w, _ = results['img_shape']

    level_maps = self.generate_level_targets((img_h, img_w), text_polys,
                                             ignore_polys)

    # The polygon-encoded gt masks are no longer needed downstream.
    results['mask_fields'].clear()
    results.update(
        p3_maps=level_maps[0],
        p4_maps=level_maps[1],
        p5_maps=level_maps[2])
    return results

View File

@ -2,6 +2,7 @@ import math
import cv2
import numpy as np
import Polygon as plg
import torchvision.transforms as transforms
from PIL import Image
@ -731,3 +732,231 @@ class SquareResizePad:
def __repr__(self):
repr_str = self.__class__.__name__
return repr_str
@PIPELINES.register_module()
class RandomScaling:
    """Randomly rescale an image and its polygon masks, keeping the aspect."""

    def __init__(self, size=800, scale=(3. / 4, 5. / 2)):
        """Random scale the image while keeping aspect.

        Args:
            size (int) : Base size before scaling.
            scale (tuple(float) | float) : The range of scaling. A single
                float ``s`` is expanded to the range ``(1 - s, 1 + s)``.
        """
        assert isinstance(size, int)
        assert isinstance(scale, float) or isinstance(scale, tuple)
        self.size = size
        if isinstance(scale, tuple):
            self.scale = scale
        else:
            self.scale = (1 - scale, 1 + scale)

    def __call__(self, results):
        """Resize ``results['img']`` and every non-empty mask field.

        The longer image side is first mapped to ``self.size`` and then
        multiplied by a factor drawn uniformly from ``self.scale``.
        """
        image = results['img']
        src_h, src_w, _ = results['img_shape']

        jitter = np.random.uniform(min(self.scale), max(self.scale))
        factor = self.size * 1.0 / max(src_h, src_w) * jitter
        dst_h = int(src_h * factor)
        dst_w = int(src_w * factor)

        # cv2.resize takes (width, height); the masks take (height, width).
        image = cv2.resize(image, (dst_w, dst_h))
        results['img'] = image
        results['img_shape'] = image.shape

        for key in results.get('mask_fields', []):
            masks = results[key]
            if len(masks.masks) == 0:
                continue
            results[key] = masks.resize((dst_h, dst_w))

        return results
@PIPELINES.register_module()
class RandomCropFlip:
    """Flip a random rectangular patch of the image in place.

    The patch is chosen so that every polygon lies either completely
    inside or completely outside of it; polygons inside the patch are
    flipped together with the pixels.
    """

    def __init__(self, crop_ratio=0.5, iter_num=1, min_area_ratio=0.2):
        """Random crop and flip a patch of the image.

        Args:
            crop_ratio (float): The ratio of cropping, used as the
                probability of applying one crop-and-flip operation.
            iter_num (int): Number of operations.
            min_area_ratio (float): Minimal area ratio between cropped patch
                and original image.
        """
        assert isinstance(crop_ratio, float)
        assert isinstance(iter_num, int)
        assert isinstance(min_area_ratio, float)

        # The padding around the image is 1 / scale of each side length.
        self.scale = 10
        # Tolerance used when comparing polygon areas.
        self.epsilon = 1e-2
        self.crop_ratio = crop_ratio
        self.iter_num = iter_num
        self.min_area_ratio = min_area_ratio

    def __call__(self, results):
        """Apply ``random_crop_flip`` ``iter_num`` times."""
        for _ in range(self.iter_num):
            results = self.random_crop_flip(results)
        return results

    def random_crop_flip(self, results):
        """Try to flip one random patch; return ``results`` (maybe untouched).

        Nothing is changed when there is no text polygon, when the random
        draw skips the operation, when no free row/column exists, or when
        none of the 10 sampled patches is valid.

        Args:
            results (dict): Must provide ``img``, ``img_shape``,
                ``gt_masks`` and ``gt_masks_ignore``.

        Returns:
            results (dict): The updated result dictionary.
        """
        image = results['img']
        polygons = results['gt_masks'].masks
        ignore_polygons = results['gt_masks_ignore'].masks
        all_polygons = polygons + ignore_polygons
        if len(polygons) == 0:
            return results

        if np.random.random() >= self.crop_ratio:
            return results

        h_axis, w_axis = self.crop_target(image, all_polygons, self.scale)
        if len(h_axis) == 0 or len(w_axis) == 0:
            return results

        h, w, _ = results['img_shape']
        area = h * w
        pad_h = h // self.scale
        pad_w = w // self.scale

        for _ in range(10):
            xx = np.random.choice(w_axis, size=2)
            xmin = np.clip(np.min(xx) - pad_w, 0, w - 1)
            xmax = np.clip(np.max(xx) - pad_w, 0, w - 1)

            yy = np.random.choice(h_axis, size=2)
            ymin = np.clip(np.min(yy) - pad_h, 0, h - 1)
            ymax = np.clip(np.max(yy) - pad_h, 0, h - 1)

            if (xmax - xmin) * (ymax - ymin) < area * self.min_area_ratio:
                # area too small
                continue

            pts = np.stack([[xmin, xmax, xmax, xmin],
                            [ymin, ymin, ymax, ymax]]).T.astype(np.int32)
            patch = plg.Polygon(pts)

            polys_keep, polys_new, collided = self._split_polygons(
                polygons, patch)
            if collided:
                continue
            ign_polys_keep, ign_polys_new, collided = self._split_polygons(
                ignore_polygons, patch)
            if not collided:
                break
        else:
            # Every attempt was rejected: leave the sample untouched rather
            # than flipping a patch that cuts through a polygon.
            return results

        cropped = image[ymin:ymax, xmin:xmax, :]
        select_type = np.random.randint(3)
        # 0: horizontal flip, 1: vertical flip, 2: both.
        flip_x = select_type != 1
        flip_y = select_type != 0
        if select_type == 0:
            img = np.ascontiguousarray(cropped[:, ::-1])
        elif select_type == 1:
            img = np.ascontiguousarray(cropped[::-1, :])
        else:
            img = np.ascontiguousarray(cropped[::-1, ::-1])
        image[ymin:ymax, xmin:xmax, :] = img
        results['img'] = image

        height, width, _ = cropped.shape
        for inside_polys in (polys_new, ign_polys_new):
            self._flip_polygons(inside_polys, flip_x, flip_y, xmin, ymin,
                                width, height)

        polygons = polys_keep + polys_new
        ignore_polygons = ign_polys_keep + ign_polys_new
        results['gt_masks'] = PolygonMasks(polygons, *(image.shape[:2]))
        results['gt_masks_ignore'] = PolygonMasks(ignore_polygons,
                                                  *(image.shape[:2]))

        return results

    def _split_polygons(self, polys, patch):
        """Split polygons into those outside and those inside ``patch``.

        Returns:
            tuple: ``(outside, inside, collided)``; ``collided`` is True as
                soon as one polygon only partly overlaps the patch, in which
                case the two lists are incomplete.
        """
        outside, inside = [], []
        for polygon in polys:
            ppi = plg.Polygon(polygon[0].reshape(-1, 2))
            ppiou, _ = eval_utils.poly_intersection(ppi, patch)
            uncovered = np.abs(ppiou - float(ppi.area()))
            if uncovered > self.epsilon and np.abs(ppiou) > self.epsilon:
                return outside, inside, True
            if uncovered < self.epsilon:
                inside.append(polygon)
            else:
                outside.append(polygon)
        return outside, inside, False

    @staticmethod
    def _flip_polygons(polys, flip_x, flip_y, xmin, ymin, width, height):
        """Mirror polygons inside the patch, in place, along x and/or y."""
        for idx, polygon in enumerate(polys):
            poly = polygon[0].reshape(-1, 2)
            if flip_x:
                poly[:, 0] = width - poly[:, 0] + 2 * xmin
            if flip_y:
                poly[:, 1] = height - poly[:, 1] + 2 * ymin
            polys[idx] = [poly.reshape(-1, )]

    def crop_target(self, image, all_polys, scale):
        """Generate crop target and make sure not to crop the polygon
        instances.

        Args:
            image (ndarray): The image waited to be crop.
            all_polys (list[list[ndarray]]): All polygons including ground
                truth polygons and ground truth ignored polygons.
            scale (int): A scale factor to control crop range.

        Returns:
            h_axis (ndarray): Vertical cropping range, i.e. the indices of
                the padded rows that no polygon box covers.
            w_axis (ndarray): Horizontal cropping range, i.e. the indices of
                the padded columns that no polygon box covers.
        """
        h, w, _ = image.shape
        pad_h = h // scale
        pad_w = w // scale
        h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
        w_array = np.zeros((w + pad_w * 2), dtype=np.int32)

        for polygon in all_polys:
            rect = cv2.minAreaRect(polygon[0].astype(np.int32).reshape(-1, 2))
            # Truncate the box corners to integers (np.int0 no longer
            # exists in recent NumPy releases).
            box = cv2.boxPoints(rect).astype(np.int32)

            # Clamp at 0 so a box sticking out beyond the padding cannot
            # turn into a negative, wrapping slice index.
            x_from = max(int(np.min(box[:, 0])) + pad_w, 0)
            x_to = max(int(np.max(box[:, 0])) + pad_w, 0)
            w_array[x_from:x_to] = 1

            y_from = max(int(np.min(box[:, 1])) + pad_h, 0)
            y_to = max(int(np.max(box[:, 1])) + pad_h, 0)
            h_array[y_from:y_to] = 1

        h_axis = np.where(h_array == 0)[0]
        w_axis = np.where(w_array == 0)[0]
        return h_axis, w_axis

View File

@ -1,7 +1,10 @@
from .db_head import DBHead
from .fce_head import FCEHead
from .head_mixin import HeadMixin
from .pan_head import PANHead
from .pse_head import PSEHead
from .textsnake_head import TextSnakeHead
__all__ = ['PSEHead', 'PANHead', 'DBHead', 'HeadMixin', 'TextSnakeHead']
__all__ = [
'PSEHead', 'PANHead', 'DBHead', 'FCEHead', 'HeadMixin', 'TextSnakeHead'
]

View File

@ -0,0 +1,134 @@
import torch.nn as nn
from mmcv.cnn import normal_init
from mmdet.core import multi_apply
from mmdet.models.builder import HEADS, build_loss
from mmocr.models.textdet.postprocess import decode
from ..postprocess.wrapper import poly_nms
from .head_mixin import HeadMixin
@HEADS.register_module()
class FCEHead(HeadMixin, nn.Module):
    """The class for implementing FCENet head.

    FCENet(CVPR2021): Fourier Contour Embedding for Arbitrary-shaped Text
    Detection.
    [https://arxiv.org/abs/2104.10442]

    Args:
        in_channels (int): The number of input channels.
        scales (list[int]) : The scale of each layer.
        fourier_degree (int) : The maximum Fourier transform degree k.
        sample_num (int) : The sampling points number of regression
            loss. If it is too small, FCEnet tends to be overfitting.
        reconstr_points (int) : Forwarded to ``decode`` as
            ``reconstr_points``; presumably the number of points sampled
            when a contour is reconstructed.
        decoding_type (str) : Forwarded to ``decode`` as ``decoding_type``.
        loss (dict) : Config of the loss module. ``fourier_degree`` and
            ``sample_num`` are added to a copy of it before it is built.
        score_thresh (float) : The threshold to filter out the final
            candidates.
        nms_thresh (float) : The threshold of nms.
        alpha (float) : The parameter to calculate final scores. Score_{final}
            = (Score_{text region} ^ alpha)
            * (Score{text center region} ^ beta)
        beta (float) :The parameter to calculate final scores.
        train_cfg (dict | None) : Stored on the head as ``train_cfg``.
        test_cfg (dict | None) : Stored on the head as ``test_cfg``.
    """

    def __init__(self,
                 in_channels,
                 scales,
                 fourier_degree=5,
                 sample_num=50,
                 reconstr_points=50,
                 decoding_type='fcenet',
                 loss=dict(type='FCELoss'),
                 score_thresh=0.3,
                 nms_thresh=0.1,
                 alpha=1.0,
                 beta=1.0,
                 train_cfg=None,
                 test_cfg=None):

        super().__init__()
        assert isinstance(in_channels, int)

        self.downsample_ratio = 1.0
        self.in_channels = in_channels
        self.scales = scales
        self.fourier_degree = fourier_degree
        self.sample_num = sample_num
        self.reconstr_points = reconstr_points
        # Work on a copy: writing into ``loss`` itself would modify the
        # caller's config and the shared default argument.
        loss = dict(loss)
        loss['fourier_degree'] = fourier_degree
        loss['sample_num'] = sample_num
        self.decoding_type = decoding_type
        self.loss_module = build_loss(loss)
        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.alpha = alpha
        self.beta = beta
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        # 4 classification channels and (2k + 1) * 2 regression channels.
        self.out_channels_cls = 4
        self.out_channels_reg = (2 * self.fourier_degree + 1) * 2

        self.out_conv_cls = nn.Conv2d(
            self.in_channels,
            self.out_channels_cls,
            kernel_size=3,
            stride=1,
            padding=1)
        self.out_conv_reg = nn.Conv2d(
            self.in_channels,
            self.out_channels_reg,
            kernel_size=3,
            stride=1,
            padding=1)

        self.init_weights()

    def init_weights(self):
        """Initialise both output convolutions with N(0, 0.01) weights."""
        normal_init(self.out_conv_cls, mean=0, std=0.01)
        normal_init(self.out_conv_reg, mean=0, std=0.01)

    def forward(self, feats):
        """Run the head on every feature level.

        Returns:
            list[list]: One ``[cls_predict, reg_predict]`` pair per level.
        """
        cls_res, reg_res = multi_apply(self.forward_single, feats)
        return [[cls_pred, reg_pred]
                for cls_pred, reg_pred in zip(cls_res, reg_res)]

    def forward_single(self, x):
        """Apply the classification and regression convolutions to ``x``."""
        cls_predict = self.out_conv_cls(x)
        reg_predict = self.out_conv_reg(x)
        return cls_predict, reg_predict

    def get_boundary(self, score_maps, img_metas, rescale):
        """Decode the predictions of all levels into text boundaries.

        Args:
            score_maps (list): One prediction pair per level; must have as
                many entries as ``self.scales``.
            img_metas (list[dict]): Only ``img_metas[0]['scale_factor']`` is
                read, and only when ``rescale`` is true.
            rescale (bool): Whether to map the boundaries back with the
                inverse scale factor.

        Returns:
            dict: ``dict(boundary_result=boundaries)``.
        """
        assert len(score_maps) == len(self.scales)

        boundaries = []
        for score_map, scale in zip(score_maps, self.scales):
            boundaries.extend(self._get_boundary_single(score_map, scale))

        # nms
        boundaries = poly_nms(boundaries, self.nms_thresh)

        if rescale:
            boundaries = self.resize_boundary(
                boundaries, 1.0 / img_metas[0]['scale_factor'])

        results = dict(boundary_result=boundaries)
        return results

    def _get_boundary_single(self, score_map, scale):
        """Decode the prediction pair of one level with ``decode``."""
        assert len(score_map) == 2
        assert score_map[1].shape[1] == 4 * self.fourier_degree + 2

        return decode(
            decoding_type=self.decoding_type,
            preds=score_map,
            fourier_degree=self.fourier_degree,
            reconstr_points=self.reconstr_points,
            scale=scale,
            alpha=self.alpha,
            beta=self.beta,
            text_repr_type='poly',
            score_thresh=self.score_thresh,
            nms_thresh=self.nms_thresh)

View File

@ -1,4 +1,5 @@
from .dbnet import DBNet
from .fcenet import FCENet
from .ocr_mask_rcnn import OCRMaskRCNN
from .panet import PANet
from .psenet import PSENet
@ -8,5 +9,5 @@ from .textsnake import TextSnake
__all__ = [
'TextDetectorMixin', 'SingleStageTextDetector', 'OCRMaskRCNN', 'DBNet',
'PANet', 'PSENet', 'TextSnake'
'PANet', 'PSENet', 'TextSnake', 'FCENet'
]

View File

@ -0,0 +1,32 @@
from mmdet.models.builder import DETECTORS
from .single_stage_text_detector import SingleStageTextDetector
from .text_detector_mixin import TextDetectorMixin
@DETECTORS.register_module()
class FCENet(TextDetectorMixin, SingleStageTextDetector):
    """FCENet text detector.

    FCENet(CVPR2021): Fourier Contour Embedding for Arbitrary-shaped Text
    Detection
    [https://arxiv.org/abs/2104.10442]
    """

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 show_score=False):
        # Both bases are initialised explicitly because they take
        # different arguments.
        SingleStageTextDetector.__init__(self, backbone, neck, bbox_head,
                                         train_cfg, test_cfg, pretrained)
        TextDetectorMixin.__init__(self, show_score)

    def simple_test(self, img, img_metas, rescale=False):
        """Extract features, run the head and decode the text boundaries.

        Returns:
            list: A one-element list holding the result of
                ``bbox_head.get_boundary``.
        """
        feats = self.extract_feat(img)
        preds = self.bbox_head(feats)
        return [self.bbox_head.get_boundary(preds, img_metas, rescale)]

View File

@ -1,6 +1,7 @@
from .db_loss import DBLoss
from .fce_loss import FCELoss
from .pan_loss import PANLoss
from .pse_loss import PSELoss
from .textsnake_loss import TextSnakeLoss
__all__ = ['PANLoss', 'PSELoss', 'DBLoss', 'TextSnakeLoss']
__all__ = ['PANLoss', 'PSELoss', 'DBLoss', 'TextSnakeLoss', 'FCELoss']

View File

@ -0,0 +1,194 @@
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from mmdet.core import multi_apply
from mmdet.models.builder import LOSSES
@LOSSES.register_module()
class FCELoss(nn.Module):
    """The class for implementing FCENet loss.

    FCENet(CVPR2021): Fourier Contour Embedding for Arbitrary-shaped
    Text Detection

    [https://arxiv.org/abs/2104.10442]

    Args:
        fourier_degree (int) : The maximum Fourier transform degree k.
        sample_num (int) : The sampling points number of regression
            loss. If it is too small, fcenet tends to be overfitting.
        ohem_ratio (float): the negative/positive ratio in OHEM.
    """

    def __init__(self, fourier_degree, sample_num, ohem_ratio=3.):
        super().__init__()
        self.fourier_degree = fourier_degree
        self.sample_num = sample_num
        self.ohem_ratio = ohem_ratio

    def forward(self, preds, _, p3_maps, p4_maps, p5_maps):
        """Compute the FCENet losses summed over all prediction levels.

        Args:
            preds (list): One entry per level; each entry is indexable as
                ``[cls_pred, reg_pred]``.
            _ : Unused positional placeholder.
            p3_maps (list[ndarray]): Per-image target maps of the first
                level; the leading dimension of each map must be
                ``4 * fourier_degree + 5``.
            p4_maps (list[ndarray]): Per-image target maps of the second
                level.
            p5_maps (list[ndarray]): Per-image target maps of the third
                level.

        Returns:
            dict[str, Tensor]: ``loss_text``, ``loss_center``,
            ``loss_reg_x`` and ``loss_reg_y``.
        """
        assert isinstance(preds, list)
        assert p3_maps[0].shape[0] == 4 * self.fourier_degree + 5,\
            'fourier degree not equal in FCEhead and FCEtarget'

        device = preds[0][0].device
        # Stack the per-image numpy targets of each level into one tensor.
        gts = [
            torch.from_numpy(np.stack(maps)).float().to(device)
            for maps in (p3_maps, p4_maps, p5_maps)
        ]

        # ``multi_apply`` regroups the per-level 4-tuples returned by
        # ``forward_single`` into one list per loss term.
        losses = multi_apply(self.forward_single, preds, gts)

        names = ('loss_text', 'loss_center', 'loss_reg_x', 'loss_reg_y')
        results = {
            name: torch.tensor(0., device=device).float()
            for name in names
        }
        for name, level_losses in zip(names, losses):
            results[name] = results[name] + sum(level_losses)
        return results

    @staticmethod
    def _flatten(maps, channels=None):
        """Move channels last and flatten all remaining dimensions.

        Returns a 1-D tensor when ``channels`` is None, otherwise a tensor
        of shape (-1, channels).
        """
        maps = maps.permute(0, 2, 3, 1).contiguous()
        if channels is None:
            return maps.view(-1)
        return maps.view(-1, channels)

    def forward_single(self, pred, gt):
        """Compute the four loss terms for a single prediction level.

        Args:
            pred (sequence): ``pred[0]`` is the classification map (first
                two channels are read as text-region logits, the rest as
                text-center logits) and ``pred[1]`` the regression map
                (``2k+1`` x channels followed by ``2k+1`` y channels).
            gt (Tensor): Target maps; channel 0 is the text-region mask,
                channel 1 the text-center mask, channel 2 the training
                mask, followed by ``2k+1`` x and ``2k+1`` y target
                channels.

        Returns:
            tuple[Tensor]: ``(loss_tr, loss_tcl, loss_reg_x, loss_reg_y)``.
        """
        n_coeff = 2 * self.fourier_degree + 1
        cls_pred, reg_pred = pred[0], pred[1]

        tr_pred = self._flatten(cls_pred[:, :2, :, :], 2)
        tcl_pred = self._flatten(cls_pred[:, 2:, :, :], 2)
        x_pred = self._flatten(reg_pred[:, :n_coeff, :, :], n_coeff)
        y_pred = self._flatten(reg_pred[:, n_coeff:2 * n_coeff, :, :],
                               n_coeff)

        tr_mask = self._flatten(gt[:, :1, :, :])
        tcl_mask = self._flatten(gt[:, 1:2, :, :])
        train_mask = self._flatten(gt[:, 2:3, :, :])
        x_map = self._flatten(gt[:, 3:3 + n_coeff, :, :], n_coeff)
        y_map = self._flatten(gt[:, 3 + n_coeff:, :, :], n_coeff)

        tr_train_mask = train_mask * tr_mask
        device = x_map.device

        # tr loss
        loss_tr = self.ohem(tr_pred, tr_mask.long(), train_mask.long())

        # The tcl and regression losses stay at zero when no trainable
        # text pixel exists at this level.
        loss_tcl = torch.tensor(0.).float().to(device)
        loss_reg_x = torch.tensor(0.).float().to(device)
        loss_reg_y = torch.tensor(0.).float().to(device)

        if tr_train_mask.sum().item() > 0:
            pos = tr_train_mask.bool()
            neg = (1 - tr_train_mask).bool()

            # tcl loss
            loss_tcl_pos = F.cross_entropy(tcl_pred[pos],
                                           tcl_mask[pos].long())
            loss_tcl_neg = F.cross_entropy(tcl_pred[neg],
                                           tcl_mask[neg].long())
            loss_tcl = loss_tcl_pos + 0.5 * loss_tcl_neg

            # regression loss, weighted by the mean of the two masks
            weight = (tr_mask[pos].float() + tcl_mask[pos].float()) / 2
            weight = weight.contiguous().view(-1, 1)

            ft_x, ft_y = self.fourier2poly(x_map, y_map)
            ft_x_pre, ft_y_pre = self.fourier2poly(x_pred, y_pred)

            loss_reg_x = torch.mean(weight * F.smooth_l1_loss(
                ft_x_pre[pos], ft_x[pos], reduction='none'))
            loss_reg_y = torch.mean(weight * F.smooth_l1_loss(
                ft_y_pre[pos], ft_y[pos], reduction='none'))

        return loss_tr, loss_tcl, loss_reg_x, loss_reg_y

    def ohem(self, predict, target, train_mask):
        """Cross-entropy with online hard example mining on negatives.

        Args:
            predict (Tensor): Logits of shape (N, 2).
            target (Tensor): Integer labels of shape (N,).
            train_mask (Tensor): Integer mask of shape (N,); entries equal
                to zero are excluded.

        Returns:
            Tensor: Sum of the positive losses and the kept (hardest)
            negative losses, divided by ``n_pos + n_neg``.
        """
        device = predict.device
        pos = (target * train_mask).bool()
        neg = ((1 - target) * train_mask).bool()

        n_pos = pos.float().sum()
        loss_neg = F.cross_entropy(
            predict[neg], target[neg], reduction='none')

        if n_pos.item() > 0:
            loss_pos = F.cross_entropy(
                predict[pos], target[pos], reduction='sum')
            n_neg = min(
                int(neg.float().sum().item()),
                int(self.ohem_ratio * n_pos.float()))
        else:
            # Created on the prediction device so that no CPU tensor is
            # mixed into the result.
            loss_pos = torch.tensor(0., device=device)
            # Without positives, keep at most 100 negatives and normalise
            # by that constant.
            n_neg = 100

        if len(loss_neg) > n_neg:
            loss_neg, _ = torch.topk(loss_neg, n_neg)

        return (loss_pos + loss_neg.sum()) / (n_pos + n_neg).float()

    def fourier2poly(self, real_maps, imag_maps):
        """Transform Fourier coefficient maps to polygon maps.

        Args:
            real_maps (tensor): A map composed of the real parts of the
                Fourier coefficients, whose shape is (-1, 2k+1)
            imag_maps (tensor):A map composed of the imag parts of the
                Fourier coefficients, whose shape is (-1, 2k+1)

        Returns
            x_maps (tensor): A map composed of the x value of the polygon
                represented by n sample points (xn, yn), whose shape is
                (-1, n)
            y_maps (tensor): A map composed of the y value of the polygon
                represented by n sample points (xn, yn), whose shape is
                (-1, n)
        """
        device = real_maps.device

        k_vect = torch.arange(
            -self.fourier_degree,
            self.fourier_degree + 1,
            dtype=torch.float,
            device=device).view(-1, 1)
        i_vect = torch.arange(
            0, self.sample_num, dtype=torch.float, device=device).view(1, -1)

        transform_matrix = 2 * np.pi / self.sample_num * torch.mm(
            k_vect, i_vect)

        # Evaluate each trigonometric matrix once; both are used twice.
        cos_matrix = torch.cos(transform_matrix)
        sin_matrix = torch.sin(transform_matrix)

        x1 = torch.einsum('ak, kn-> an', real_maps, cos_matrix)
        x2 = torch.einsum('ak, kn-> an', imag_maps, sin_matrix)
        y1 = torch.einsum('ak, kn-> an', real_maps, sin_matrix)
        y2 = torch.einsum('ak, kn-> an', imag_maps, cos_matrix)

        x_maps = x1 - x2
        y_maps = y1 + y2

        return x_maps, y_maps

View File

@ -7,6 +7,7 @@ from shapely.geometry import Polygon
from skimage.morphology import skeletonize
from mmocr.core import points2boundary
from mmocr.core.evaluation.utils import boundary_iou
def filter_instance(area, confidence, min_area, min_confidence):
@ -24,6 +25,8 @@ def decode(
return db_decode(**kwargs)
if decoding_type == 'textsnake':
return textsnake_decode(**kwargs)
if decoding_type == 'fcenet':
return fcenet_decode(**kwargs)
raise NotImplementedError
@ -391,3 +394,177 @@ def textsnake_decode(preds,
boundaries.append(boundary + [score])
return boundaries
def fcenet_decode(
    preds,
    fourier_degree,
    reconstr_points,
    scale,
    alpha=1.0,
    beta=2.0,
    text_repr_type='poly',
    score_thresh=0.8,
    nms_thresh=0.1,
):
    """Decoding predictions of FCENet to instances.

    Args:
        preds (list(Tensor)): The head output tensors.
        fourier_degree (int): The maximum Fourier transform degree k.
        reconstr_points (int): The points number of the polygon reconstructed
            from predicted Fourier coefficients.
        scale (int): The downsample scale of the prediction.
        alpha (float) : The parameter to calculate final scores. Score_{final}
            = (Score_{text region} ^ alpha)
            * (Score_{text center region}^ beta)
        beta (float) : The parameter to calculate final score.
        text_repr_type (str): Boundary encoding type; only 'poly' is accepted
            by this function.
        score_thresh (float) : The threshold used to filter out the final
            candidates.
        nms_thresh (float) : The threshold of nms.

    Returns:
        boundaries (list[list[float]]): The instance boundary and confidence
            list.
    """
    assert isinstance(preds, list)
    assert len(preds) == 2
    assert text_repr_type == 'poly'

    # Only the first sample of the batch is decoded.
    cls_pred = preds[0][0]
    reg_pred = preds[1][0].permute(1, 2, 0).data.cpu().numpy()

    tr_pred = cls_pred[0:2].softmax(dim=0).data.cpu().numpy()
    tcl_pred = cls_pred[2:].softmax(dim=0).data.cpu().numpy()

    n_coeff = 2 * fourier_degree + 1
    x_pred = reg_pred[:, :, :n_coeff]
    y_pred = reg_pred[:, :, n_coeff:]

    score_pred = (tr_pred[1]**alpha) * (tcl_pred[1]**beta)
    tr_mask = fill_hole(score_pred > score_thresh)

    tr_contours, _ = cv2.findContours(
        tr_mask.astype(np.uint8), cv2.RETR_TREE,
        cv2.CHAIN_APPROX_SIMPLE)  # opencv4

    blank = np.zeros_like(tr_mask)
    exp_matrix = generate_exp_matrix(reconstr_points, fourier_degree)

    boundaries = []
    for contour in tr_contours:
        # Rasterise the filled contour so that only scores inside this
        # region remain non-zero.
        region = blank.copy().astype(np.int8)
        cv2.drawContours(region, [contour], -1, 1, -1)

        region_scores = score_pred * region
        candidates = contour_transfor_inv(fourier_degree, x_pred, y_pred,
                                          region_scores, exp_matrix, scale)
        boundaries.extend(poly_nms(candidates, nms_thresh))

    # A second NMS pass runs over the candidates gathered from all regions.
    return poly_nms(boundaries, nms_thresh)
def poly_nms(polygons, threshold):
    """Greedy non-maximum suppression over scored polygons.

    Args:
        polygons (list): Candidates, each a flat coordinate sequence whose
            last element is the score.
        threshold (float): Candidates whose ``boundary_iou`` with an
            already kept polygon exceeds this value are discarded.

    Returns:
        list[list]: The kept candidates, highest score first.
    """
    assert isinstance(polygons, list)

    # Ascending by score, so the best remaining candidate is always last.
    ranked = np.array(sorted(polygons, key=lambda x: x[-1]))
    remaining = np.arange(ranked.shape[0])

    kept = []
    while remaining.size > 0:
        best = ranked[remaining[-1]]
        kept.append(best.tolist())
        remaining = remaining[:-1]

        overlaps = np.array(
            [boundary_iou(best[:-1], ranked[j][:-1]) for j in remaining],
            dtype=float)
        # Negated ">" (rather than "<=") keeps candidates with a NaN IoU.
        remaining = remaining[~(overlaps > threshold)]

    return kept
def contour_transfor_inv(fourier_degree, x_pred, y_pred, score_map, exp_matrix,
                         scale):
    """Reconstruct polygons from the predicted Fourier coefficients.

    Args:
        fourier_degree (int): The maximum Fourier degree K.
        x_pred (ndarray): The real part of predicted Fourier coefficients.
        y_pred (ndarray): The imaginary part of predicted Fourier
            coefficients.
        score_map (ndarray): The final score of candidates.
        exp_matrix (ndarray): A matrix of e^x, where x = 2pi x ikt, and shape
            is (2k+1, n') where n' is reconstructed point number. See Eq.2
            in paper.
        scale (int): The down-sample scale.

    Returns:
        Polygons (list): The reconstructed polygons and scores.
    """
    candidate_mask = score_map > 0

    # argwhere yields (row, col); the column becomes the real (x) part and
    # the row the imaginary (y) part of each candidate's position.
    positions = np.argwhere(candidate_mask)
    offsets = positions[:, 1] + positions[:, 0] * 1j

    coeffs = x_pred[candidate_mask] + y_pred[candidate_mask] * 1j
    # The degree-0 coefficient sits at index ``fourier_degree``; the pixel
    # position is added to it before scaling.
    coeffs[:, fourier_degree] += offsets
    coeffs *= scale

    polygons = fourier_inverse_matrix(coeffs, exp_matrix=exp_matrix)
    scores = score_map[candidate_mask].reshape(-1, 1)
    return np.hstack((polygons, scores)).tolist()
def fourier_inverse_matrix(fourier_coeff, exp_matrix):
    """Inverse Fourier transform.

    Args:
        fourier_coeff (ndarray): Fourier coefficients shaped (n, 2k+1), with
            n and k being candidates number and Fourier degree respectively.
        exp_matrix (ndarray): A matrix of e^x, where x = 2pi x ikt and shape
            is (2k+1, n') where n' is reconstructed point number.
            See Eq.2 in paper.

    Returns:
        Polygons (ndarray): The reconstructed polygons as an int32 array of
            shape (n, 2n'), each row laid out as x0, y0, x1, y1, ...
    """
    assert isinstance(fourier_coeff, np.ndarray)
    assert fourier_coeff.shape[1] == exp_matrix.shape[0]

    points = np.matmul(fourier_coeff, exp_matrix)
    n_candidates, n_points = points.shape

    # Interleave x (real part) and y (imaginary part) along the last axis.
    xy = np.stack((np.real(points), np.imag(points)), axis=-1)

    # The target width is spelled out instead of using -1, which numpy
    # cannot infer when there are zero candidates.
    return xy.astype('int32').reshape(n_candidates, 2 * n_points)
def generate_exp_matrix(point_num, fourier_degree):
    """Generate a matrix of e^x, where x = 2pi x ikt. See Eq.2 in paper.

    Args:
        point_num (int): Number of reconstruct points of polygon
        fourier_degree (int): Maximum Fourier degree k

    Returns:
        exp_matrix (ndarray): A matrix of e^x, where x = 2pi x ikt and
            shape is (2k+1, n') where n' is reconstructed point number.
    """
    base = complex(np.e)

    # One purely imaginary phase per reconstructed point.
    phases = np.array(
        [2 * np.pi * 1j / point_num * i for i in range(point_num)],
        dtype='complex')

    # Row r of the result corresponds to degree r - fourier_degree.
    degrees = np.arange(2 * fourier_degree + 1) - fourier_degree
    exponents = phases[np.newaxis, :] * degrees[:, np.newaxis]

    return np.power(base, exponents)

View File

@ -218,3 +218,27 @@ def test_gen_textsnake_targets(mock_show_feature):
assert 'gt_sin_map' in output.keys()
assert 'gt_cos_map' in output.keys()
mock_show_feature.assert_called_once()
def test_fcenet_generate_targets():
    """FCENetTargets must add the three per-level target maps to results."""
    h, w, c = (64, 64, 3)
    target_generator = textdet_targets.FCENetTargets(fourier_degree=5)

    text_polys = [[np.array([0, 0, 10, 0, 10, 10, 0, 10])],
                  [np.array([20, 0, 30, 0, 30, 10, 20, 10])]]
    text_polys_ignore = [[np.array([0, 0, 15, 0, 15, 10, 0, 10])]]

    results = {
        'mask_fields': [],
        'img_shape': (h, w, c),
        'gt_masks_ignore': PolygonMasks(text_polys_ignore, h, w),
        'gt_masks': PolygonMasks(text_polys, h, w),
        'gt_bboxes': np.array([[0, 0, 10, 10], [20, 0, 30, 10]]),
        'gt_labels': np.array([0, 1]),
    }

    target_generator.generate_targets(results)

    for key in ('p3_maps', 'p4_maps', 'p5_maps'):
        assert key in results

View File

@ -166,6 +166,72 @@ def test_affine_jitter():
assert np.allclose(np.array(output1), output2['img'])
def test_random_scale():
    """RandomScaling with a fixed (2, 2) range doubles image and polygon."""
    h, w, c = 100, 100, 3
    img = np.ones((h, w, c), dtype=np.uint8)
    polygon = np.array([0., 0., 0., 10., 10., 10., 10., 0.])

    results = {
        'img': img,
        'img_shape': (h, w, c),
        'gt_masks': PolygonMasks([[polygon]], *(img.shape[:2])),
        'mask_fields': ['gt_masks'],
    }

    random_scaler = transforms.RandomScaling(size=100, scale=(2., 2.))
    results = random_scaler(results)

    scaled_img = results['img']
    scaled_poly = results['gt_masks'].masks[0][0]

    assert np.allclose(scaled_img.shape, (2 * h, 2 * w, c))
    assert np.allclose(scaled_poly, polygon * 2)
@mock.patch('%s.transforms.np.random.randint' % __name__)
def test_random_crop_flip(mock_randint):
    """Check RandomCropFlip.crop_target and the transform's __call__."""
    img = np.ones((10, 10, 3), dtype=np.uint8)
    img[0, 0, :] = 0
    img_hw = img.shape[:2]

    full_polygon = np.array([0., 0., 0., 10., 10., 10., 10., 0.])
    results = {
        'img': img,
        'img_shape': img.shape,
        'gt_masks': PolygonMasks([[full_polygon]], *img_hw),
        'gt_masks_ignore': PolygonMasks([], *img_hw),
        'mask_fields': ['gt_masks', 'gt_masks_ignore'],
    }

    random_crop_fliper = transforms.RandomCropFlip(crop_ratio=1.1, iter_num=3)

    # test crop_target
    h_axis, w_axis = random_crop_fliper.crop_target(
        img, results['gt_masks'].masks, 10)

    assert np.allclose(h_axis, (0, 11))
    assert np.allclose(w_axis, (0, 11))

    # test __call__
    inner_polygon = np.array([1., 1., 1., 9., 9., 9., 9., 1.])
    results['gt_masks'] = PolygonMasks([[inner_polygon]], *img_hw)
    results['gt_masks_ignore'] = PolygonMasks([[inner_polygon]], *img_hw)

    mock_randint.side_effect = [0, 1, 2]
    results = random_crop_fliper(results)

    # With these mocked draws the image and polygon come back unchanged.
    assert np.allclose(results['img'], img)
    assert np.allclose(results['gt_masks'].masks[0][0], inner_polygon)
@mock.patch('%s.transforms.np.random.random_sample' % __name__)
@mock.patch('%s.transforms.np.random.randint' % __name__)
def test_random_crop_poly_instances(mock_randint, mock_sample):

View File

@ -372,3 +372,60 @@ def test_textsnake(cfg_file):
results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]}
img = np.random.rand(5, 5)
detector.show_result(img, results)
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.parametrize(
    'cfg_file', ['textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py'])
def test_fcenet(cfg_file):
    """Smoke-test FCENet: training forward, test forward and show_result."""
    from mmocr.models import build_detector

    model = _get_detector_cfg(cfg_file)
    model['pretrained'] = None
    model['backbone']['norm_cfg']['type'] = 'BN'

    detector = build_detector(model).cuda()

    fourier_degree = 5
    target_channels = 5 + 4 * fourier_degree

    input_shape = (1, 3, 256, 256)
    (n, c, h, w) = input_shape

    imgs = torch.randn(n, c, h, w).float().cuda()
    img_metas = [{
        'img_shape': (h, w, c),
        'ori_shape': (h, w, c),
        'pad_shape': (h, w, c),
        'filename': '<demo>.png',
        'scale_factor': np.array([1, 1, 1, 1]),
        'flip': False,
    } for _ in range(n)]

    # Random targets for the three levels, down-sampled by 8, 16 and 32.
    p3_maps, p4_maps, p5_maps = [], [], []
    for _ in range(n):
        for maps, stride in ((p3_maps, 8), (p4_maps, 16), (p5_maps, 32)):
            maps.append(
                np.random.random((target_channels, h // stride, w // stride)))

    # Test forward train
    losses = detector.forward(
        imgs, img_metas, p3_maps=p3_maps, p4_maps=p4_maps, p5_maps=p5_maps)
    assert isinstance(losses, dict)

    # Test forward test
    with torch.no_grad():
        batch_results = []
        for one_img, one_meta in zip((g[None, :] for g in imgs), img_metas):
            batch_results.append(
                detector.forward([one_img], [[one_meta]], return_loss=False))

    # Test show result
    results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]}
    img = np.random.rand(5, 5)
    detector.show_result(img, results)

View File

@ -31,3 +31,42 @@ def test_textsnakeloss():
bce_loss = textsnakeloss.balanced_bce_loss(pred, target, mask).item()
assert np.allclose(bce_loss, 0)
def test_fcenetloss():
    """Exercise FCELoss.ohem and the forward pass over all three levels."""
    k = 5
    fcenetloss = losses.FCELoss(fourier_degree=k, sample_num=10)

    input_shape = (1, 3, 64, 64)
    (n, _, h, w) = input_shape

    # test ohem
    pred = torch.ones((200, 2), dtype=torch.float)
    target = torch.ones((200, ), dtype=torch.long)
    target[20:] = 0
    mask = torch.ones((200, ), dtype=torch.long)

    ohem_loss1 = fcenetloss.ohem(pred, target, mask)
    ohem_loss2 = fcenetloss.ohem(pred, target, 1 - mask)
    assert isinstance(ohem_loss1, torch.Tensor)
    assert isinstance(ohem_loss2, torch.Tensor)
    # Equal logits give log(2) per sample, so the normalised loss is log(2).
    assert abs(ohem_loss1.item() - np.log(2)) < 1e-4
    # With everything masked out there is nothing to accumulate.
    assert ohem_loss2.item() == 0

    # test forward
    # One prediction pair per pyramid level (strides 8, 16, 32) so that the
    # p3, p4 and p5 targets are all consumed by the loss.
    strides = (8, 16, 32)
    preds = [[
        torch.rand(n, 4, h // stride, w // stride),
        torch.rand(n, 4 * k + 2, h // stride, w // stride),
    ] for stride in strides]

    p3_maps = []
    p4_maps = []
    p5_maps = []
    for _ in range(n):
        p3_maps.append(np.random.random((5 + 4 * k, h // 8, w // 8)))
        p4_maps.append(np.random.random((5 + 4 * k, h // 16, w // 16)))
        p5_maps.append(np.random.random((5 + 4 * k, h // 32, w // 32)))

    loss = fcenetloss(preds, 0, p3_maps, p4_maps, p5_maps)
    assert isinstance(loss, dict)
    for key in ('loss_text', 'loss_center', 'loss_reg_x', 'loss_reg_y'):
        assert key in loss

View File

@ -12,3 +12,17 @@ def test_db_boxes_from_bitmaps():
boxes = db_decode(preds, text_repr_type='quad', min_text_width=0)
assert len(boxes) == 1
def test_fcenet_decode():
    """fcenet_decode must return a list for random head outputs."""
    from mmocr.models.textdet.postprocess.wrapper import fcenet_decode

    k = 5
    cls_pred = torch.randn(1, 4, 40, 40)
    reg_pred = torch.randn(1, 4 * k + 2, 40, 40)

    boundaries = fcenet_decode(
        preds=[cls_pred, reg_pred],
        fourier_degree=k,
        reconstr_points=50,
        scale=1)

    assert isinstance(boundaries, list)