mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
add fcenet (#133)
* add fcenet * fix linting and code style * fcenet finetune * Update transforms.py * Update fcenet_r50dcnv2_fpn_1500e_ctw1500.py * Update fcenet_targets.py * Update fce_loss.py * fix * add readme * fix config * Update fcenet_r50dcnv2_fpn_1500e_ctw1500.py * fix * fix readme * fix readme * Update test_loss.py Co-authored-by: Hongbin Sun <hongbin306@gmail.com>
This commit is contained in:
parent
1240169f18
commit
cbdd98a1e1
22
configs/textdet/fcenet/README.MD
Normal file
22
configs/textdet/fcenet/README.MD
Normal file
@ -0,0 +1,22 @@
|
||||
# Fourier Contour Embedding for Arbitrary-Shaped Text Detection
|
||||
|
||||
## Introduction
|
||||
|
||||
[ALGORITHM]
|
||||
|
||||
```bibtex
|
||||
@InProceedings{zhu2021fourier,
|
||||
title={Fourier Contour Embedding for Arbitrary-Shaped Text Detection},
|
||||
author={Yiqin Zhu and Jianyong Chen and Lingyu Liang and Zhanghui Kuang and Lianwen Jin and Wayne Zhang},
|
||||
year={2021},
|
||||
booktitle = {CVPR}
|
||||
}
|
||||
```
|
||||
|
||||
## Results and models
|
||||
|
||||
### CTW1500
|
||||
|
||||
| Method | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download |
|
||||
| :--------------------------------------------------------------------: | :--------------: | :-----------: | :----------: | :-----: | :-------: | :----: | :-------: | :---: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
|
||||
| [FCENet](/configs/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py) | ImageNet | CTW1500 Train | CTW1500 Test | 1500 |(736, 1080)| 0.828 | 0.875 | 0.851 | [model](https://download.openmmlab.com/mmocr/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500-05d740bb.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/fcenet/20210511_181328.log.json) |
|
134
configs/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py
Normal file
134
configs/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py
Normal file
@ -0,0 +1,134 @@
|
||||
fourier_degree = 5
|
||||
model = dict(
|
||||
type='FCENet',
|
||||
pretrained='torchvision://resnet50',
|
||||
backbone=dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
num_stages=4,
|
||||
out_indices=(1, 2, 3),
|
||||
frozen_stages=-1,
|
||||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
norm_eval=True,
|
||||
style='pytorch',
|
||||
dcn=dict(type='DCNv2', deform_groups=2, fallback_on_stride=False),
|
||||
stage_with_dcn=(False, True, True, True)),
|
||||
neck=dict(
|
||||
type='FPN',
|
||||
in_channels=[512, 1024, 2048],
|
||||
out_channels=256,
|
||||
add_extra_convs=True,
|
||||
extra_convs_on_inputs=False, # use P5
|
||||
num_outs=3,
|
||||
relu_before_extra_convs=True,
|
||||
act_cfg=None),
|
||||
bbox_head=dict(
|
||||
type='FCEHead',
|
||||
in_channels=256,
|
||||
scales=(8, 16, 32),
|
||||
loss=dict(type='FCELoss'),
|
||||
fourier_degree=fourier_degree,
|
||||
))
|
||||
|
||||
train_cfg = None
|
||||
test_cfg = None
|
||||
|
||||
dataset_type = 'IcdarDataset'
|
||||
data_root = 'data/ctw1500/'
|
||||
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='LoadTextAnnotations',
|
||||
with_bbox=True,
|
||||
with_mask=True,
|
||||
poly2mask=False),
|
||||
dict(
|
||||
type='ColorJitter',
|
||||
brightness=32.0 / 255,
|
||||
saturation=0.5,
|
||||
contrast=0.5),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='RandomScaling', size=800, scale=(3. / 4, 5. / 2)),
|
||||
dict(
|
||||
type='RandomCropFlip', crop_ratio=0.5, iter_num=1, min_area_ratio=0.2),
|
||||
dict(
|
||||
type='RandomCropPolyInstances',
|
||||
instance_key='gt_masks',
|
||||
crop_ratio=0.8,
|
||||
min_side_ratio=0.3),
|
||||
dict(
|
||||
type='RandomRotatePolyInstances',
|
||||
rotate_ratio=0.5,
|
||||
max_angle=30,
|
||||
pad_with_fixed_color=False),
|
||||
dict(type='SquareResizePad', target_size=800, pad_ratio=0.6),
|
||||
dict(type='RandomFlip', flip_ratio=0.5, direction='horizontal'),
|
||||
dict(type='Pad', size_divisor=32),
|
||||
dict(
|
||||
type='FCENetTargets',
|
||||
fourier_degree=fourier_degree,
|
||||
level_proportion_range=((0, 0.25), (0.2, 0.65), (0.55, 1.0))),
|
||||
dict(
|
||||
type='CustomFormatBundle',
|
||||
keys=['p3_maps', 'p4_maps', 'p5_maps'],
|
||||
visualize=dict(flag=False, boundary_key=None)),
|
||||
dict(type='Collect', keys=['img', 'p3_maps', 'p4_maps', 'p5_maps'])
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiScaleFlipAug',
|
||||
img_scale=(1080, 736),
|
||||
flip=False,
|
||||
transforms=[
|
||||
dict(type='Resize', img_scale=(1280, 800), keep_ratio=True),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size_divisor=32),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img']),
|
||||
])
|
||||
]
|
||||
data = dict(
|
||||
samples_per_gpu=6,
|
||||
workers_per_gpu=2,
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
ann_file=data_root + '/instances_training.json',
|
||||
img_prefix=data_root + '/imgs',
|
||||
pipeline=train_pipeline),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
ann_file=data_root + '/instances_test.json',
|
||||
img_prefix=data_root + '/imgs',
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
type=dataset_type,
|
||||
ann_file=data_root + '/instances_test.json',
|
||||
img_prefix=data_root + '/imgs',
|
||||
pipeline=test_pipeline))
|
||||
evaluation = dict(interval=5, metric='hmean-iou')
|
||||
|
||||
# optimizer
|
||||
optimizer = dict(type='SGD', lr=1e-3, momentum=0.90, weight_decay=5e-4)
|
||||
optimizer_config = dict(grad_clip=None)
|
||||
lr_config = dict(policy='poly', power=0.9, min_lr=1e-7, by_epoch=True)
|
||||
total_epochs = 1500
|
||||
|
||||
checkpoint_config = dict(interval=5)
|
||||
# yapf:disable
|
||||
log_config = dict(
|
||||
interval=20,
|
||||
hooks=[
|
||||
dict(type='TextLoggerHook')
|
||||
|
||||
])
|
||||
# yapf:enable
|
||||
dist_params = dict(backend='nccl')
|
||||
log_level = 'INFO'
|
||||
load_from = None
|
||||
resume_from = None
|
||||
workflow = [('train', 1)]
|
@ -5,7 +5,7 @@ from .icdar_dataset import IcdarDataset
|
||||
from .kie_dataset import KIEDataset
|
||||
from .ocr_dataset import OCRDataset
|
||||
from .ocr_seg_dataset import OCRSegDataset
|
||||
from .pipelines import CustomFormatBundle, DBNetTargets
|
||||
from .pipelines import CustomFormatBundle, DBNetTargets, FCENetTargets
|
||||
from .text_det_dataset import TextDetDataset
|
||||
|
||||
from .utils import * # NOQA
|
||||
@ -13,7 +13,7 @@ from .utils import * # NOQA
|
||||
__all__ = [
|
||||
'DATASETS', 'IcdarDataset', 'build_dataloader', 'build_dataset',
|
||||
'BaseDataset', 'OCRDataset', 'TextDetDataset', 'CustomFormatBundle',
|
||||
'DBNetTargets', 'OCRSegDataset', 'KIEDataset'
|
||||
'DBNetTargets', 'OCRSegDataset', 'KIEDataset', 'FCENetTargets'
|
||||
]
|
||||
|
||||
__all__ += utils.__all__
|
||||
|
@ -8,10 +8,11 @@ from .ocr_transforms import (FancyPCA, NormalizeOCR, OnlineCropOCR,
|
||||
OpencvToPil, PilToOpencv, RandomPaddingOCR,
|
||||
RandomRotateImageBox, ResizeOCR, ToTensorOCR)
|
||||
from .test_time_aug import MultiRotateAugOCR
|
||||
from .textdet_targets import DBNetTargets, PANetTargets, TextSnakeTargets
|
||||
from .transforms import (ColorJitter, RandomCropInstances,
|
||||
from .textdet_targets import (DBNetTargets, FCENetTargets, PANetTargets,
|
||||
TextSnakeTargets)
|
||||
from .transforms import (ColorJitter, RandomCropFlip, RandomCropInstances,
|
||||
RandomCropPolyInstances, RandomRotatePolyInstances,
|
||||
RandomRotateTextDet, ScaleAspectJitter,
|
||||
RandomRotateTextDet, RandomScaling, ScaleAspectJitter,
|
||||
SquareResizePad)
|
||||
|
||||
__all__ = [
|
||||
@ -22,5 +23,6 @@ __all__ = [
|
||||
'RandomCropPolyInstances', 'RandomRotatePolyInstances', 'RandomPaddingOCR',
|
||||
'ImgAug', 'EastRandomCrop', 'RandomRotateImageBox', 'OpencvToPil',
|
||||
'PilToOpencv', 'KIEFormatBundle', 'SquareResizePad', 'TextSnakeTargets',
|
||||
'sort_vertex', 'LoadImageFromNdarray', 'sort_vertex8'
|
||||
'sort_vertex', 'LoadImageFromNdarray', 'sort_vertex8', 'FCENetTargets',
|
||||
'RandomScaling', 'RandomCropFlip'
|
||||
]
|
||||
|
@ -1,10 +1,11 @@
|
||||
from .base_textdet_targets import BaseTextDetTargets
|
||||
from .dbnet_targets import DBNetTargets
|
||||
from .fcenet_targets import FCENetTargets
|
||||
from .panet_targets import PANetTargets
|
||||
from .psenet_targets import PSENetTargets
|
||||
from .textsnake_targets import TextSnakeTargets
|
||||
|
||||
__all__ = [
|
||||
'BaseTextDetTargets', 'PANetTargets', 'PSENetTargets', 'DBNetTargets',
|
||||
'TextSnakeTargets'
|
||||
'FCENetTargets', 'TextSnakeTargets'
|
||||
]
|
||||
|
370
mmocr/datasets/pipelines/textdet_targets/fcenet_targets.py
Normal file
370
mmocr/datasets/pipelines/textdet_targets/fcenet_targets.py
Normal file
@ -0,0 +1,370 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from numpy.linalg import norm
|
||||
|
||||
import mmocr.utils.check_argument as check_argument
|
||||
from mmdet.datasets.builder import PIPELINES
|
||||
from .textsnake_targets import TextSnakeTargets
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class FCENetTargets(TextSnakeTargets):
|
||||
"""Generate the ground truth targets of FCENet: Fourier Contour Embedding
|
||||
for Arbitrary-Shaped Text Detection.
|
||||
|
||||
[https://arxiv.org/abs/2104.10442]
|
||||
|
||||
Args:
|
||||
fourier_degree (int): The maximum Fourier transform degree k.
|
||||
resample_step (float): The step size for resampling the text center
|
||||
line (TCL). It's better not to exceed half of the minimum width.
|
||||
center_region_shrink_ratio (float): The shrink ratio of text center
|
||||
region.
|
||||
level_size_divisors (tuple(int)): The downsample ratio on each level.
|
||||
level_proportion_range (tuple(tuple(int))): The range of text sizes
|
||||
assigned to each level.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
fourier_degree=5,
|
||||
resample_step=4.0,
|
||||
center_region_shrink_ratio=0.3,
|
||||
level_size_divisors=(8, 16, 32),
|
||||
level_proportion_range=((0, 0.4), (0.3, 0.7), (0.6, 1.0))):
|
||||
|
||||
super().__init__()
|
||||
assert isinstance(level_size_divisors, tuple)
|
||||
assert isinstance(level_proportion_range, tuple)
|
||||
assert len(level_size_divisors) == len(level_proportion_range)
|
||||
self.fourier_degree = fourier_degree
|
||||
self.resample_step = resample_step
|
||||
self.center_region_shrink_ratio = center_region_shrink_ratio
|
||||
self.level_size_divisors = level_size_divisors
|
||||
self.level_proportion_range = level_proportion_range
|
||||
|
||||
def generate_center_region_mask(self, img_size, text_polys):
|
||||
"""Generate text center region mask.
|
||||
|
||||
Args:
|
||||
img_size (tuple): The image size of (height, width).
|
||||
text_polys (list[list[ndarray]]): The list of text polygons.
|
||||
|
||||
Returns:
|
||||
center_region_mask (ndarray): The text center region mask.
|
||||
"""
|
||||
|
||||
assert isinstance(img_size, tuple)
|
||||
assert check_argument.is_2dlist(text_polys)
|
||||
|
||||
h, w = img_size
|
||||
|
||||
center_region_mask = np.zeros((h, w), np.uint8)
|
||||
|
||||
center_region_boxes = []
|
||||
for poly in text_polys:
|
||||
assert len(poly) == 1
|
||||
polygon_points = poly[0].reshape(-1, 2)
|
||||
_, _, top_line, bot_line = self.reorder_poly_edge(polygon_points)
|
||||
resampled_top_line, resampled_bot_line = self.resample_sidelines(
|
||||
top_line, bot_line, self.resample_step)
|
||||
resampled_bot_line = resampled_bot_line[::-1]
|
||||
center_line = (resampled_top_line + resampled_bot_line) / 2
|
||||
|
||||
line_head_shrink_len = norm(resampled_top_line[0] -
|
||||
resampled_bot_line[0]) / 4.0
|
||||
line_tail_shrink_len = norm(resampled_top_line[-1] -
|
||||
resampled_bot_line[-1]) / 4.0
|
||||
head_shrink_num = int(line_head_shrink_len // self.resample_step)
|
||||
tail_shrink_num = int(line_tail_shrink_len // self.resample_step)
|
||||
if len(center_line) > head_shrink_num + tail_shrink_num + 2:
|
||||
center_line = center_line[head_shrink_num:len(center_line) -
|
||||
tail_shrink_num]
|
||||
resampled_top_line = resampled_top_line[
|
||||
head_shrink_num:len(resampled_top_line) - tail_shrink_num]
|
||||
resampled_bot_line = resampled_bot_line[
|
||||
head_shrink_num:len(resampled_bot_line) - tail_shrink_num]
|
||||
|
||||
for i in range(0, len(center_line) - 1):
|
||||
tl = center_line[i] + (resampled_top_line[i] - center_line[i]
|
||||
) * self.center_region_shrink_ratio
|
||||
tr = center_line[i + 1] + (
|
||||
resampled_top_line[i + 1] -
|
||||
center_line[i + 1]) * self.center_region_shrink_ratio
|
||||
br = center_line[i + 1] + (
|
||||
resampled_bot_line[i + 1] -
|
||||
center_line[i + 1]) * self.center_region_shrink_ratio
|
||||
bl = center_line[i] + (resampled_bot_line[i] - center_line[i]
|
||||
) * self.center_region_shrink_ratio
|
||||
current_center_box = np.vstack([tl, tr, br,
|
||||
bl]).astype(np.int32)
|
||||
center_region_boxes.append(current_center_box)
|
||||
|
||||
cv2.fillPoly(center_region_mask, center_region_boxes, 1)
|
||||
return center_region_mask
|
||||
|
||||
def resample_polygon(self, polygon, n=400):
|
||||
"""Resample one polygon with n points on its boundary.
|
||||
|
||||
Args:
|
||||
polygon (list[float]): The input polygon.
|
||||
n (int): The number of resampled points.
|
||||
Returns:
|
||||
resampled_polygon (list[float]): The resampled polygon.
|
||||
"""
|
||||
length = []
|
||||
|
||||
for i in range(len(polygon)):
|
||||
p1 = polygon[i]
|
||||
if i == len(polygon) - 1:
|
||||
p2 = polygon[0]
|
||||
else:
|
||||
p2 = polygon[i + 1]
|
||||
length.append(((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2)**0.5)
|
||||
|
||||
total_length = sum(length)
|
||||
n_on_each_line = (np.array(length) / (total_length + 1e-8)) * n
|
||||
n_on_each_line = n_on_each_line.astype(np.int32)
|
||||
new_polygon = []
|
||||
|
||||
for i in range(len(polygon)):
|
||||
num = n_on_each_line[i]
|
||||
p1 = polygon[i]
|
||||
if i == len(polygon) - 1:
|
||||
p2 = polygon[0]
|
||||
else:
|
||||
p2 = polygon[i + 1]
|
||||
|
||||
if num == 0:
|
||||
continue
|
||||
|
||||
dxdy = (p2 - p1) / num
|
||||
for j in range(num):
|
||||
point = p1 + dxdy * j
|
||||
new_polygon.append(point)
|
||||
|
||||
return np.array(new_polygon)
|
||||
|
||||
def normalize_polygon(self, polygon):
|
||||
"""Normalize one polygon so that its start point is at right most.
|
||||
|
||||
Args:
|
||||
polygon (list[float]): The origin polygon.
|
||||
Returns:
|
||||
new_polygon (lost[float]): The polygon with start point at right.
|
||||
"""
|
||||
temp_polygon = polygon - polygon.mean(axis=0)
|
||||
x = np.abs(temp_polygon[:, 0])
|
||||
y = temp_polygon[:, 1]
|
||||
index_x = np.argsort(x)
|
||||
index_y = np.argmin(y[index_x[:8]])
|
||||
index = index_x[index_y]
|
||||
new_polygon = np.concatenate([polygon[index:], polygon[:index]])
|
||||
return new_polygon
|
||||
|
||||
def fourier_transform(self, polygon, fourier_degree):
|
||||
"""Perform Fourier transformation to generate Fourier coefficients ck
|
||||
from polygon.
|
||||
|
||||
Args:
|
||||
polygon (ndarray): An input polygon.
|
||||
fourier_degree (int): The maximum Fourier degree K.
|
||||
Returns:
|
||||
c (ndarray(complex)): Fourier coefficients.
|
||||
"""
|
||||
points = polygon[:, 0] + polygon[:, 1] * 1j
|
||||
n = len(points)
|
||||
t = np.multiply([i / n for i in range(n)], -2 * np.pi * 1j)
|
||||
|
||||
e = complex(np.e)
|
||||
c = np.zeros((2 * fourier_degree + 1, ), dtype='complex')
|
||||
|
||||
for i in range(-fourier_degree, fourier_degree + 1):
|
||||
c[i + fourier_degree] = np.sum(points * np.power(e, i * t)) / n
|
||||
|
||||
return c
|
||||
|
||||
def clockwise(self, c, fourier_degree):
|
||||
"""Make sure the polygon reconstructed from Fourier coefficients c in
|
||||
the clockwise direction.
|
||||
|
||||
Args:
|
||||
polygon (list[float]): The origin polygon.
|
||||
Returns:
|
||||
new_polygon (lost[float]): The polygon in clockwise point order.
|
||||
"""
|
||||
if np.abs(c[fourier_degree + 1]) > np.abs(c[fourier_degree - 1]):
|
||||
return c
|
||||
elif np.abs(c[fourier_degree + 1]) < np.abs(c[fourier_degree - 1]):
|
||||
return c[::-1]
|
||||
else:
|
||||
if np.abs(c[fourier_degree + 2]) > np.abs(c[fourier_degree - 2]):
|
||||
return c
|
||||
else:
|
||||
return c[::-1]
|
||||
|
||||
def cal_fourier_signature(self, polygon, fourier_degree):
|
||||
"""Calculate Fourier signature from input polygon.
|
||||
|
||||
Args:
|
||||
polygon (ndarray): The input polygon.
|
||||
fourier_degree (int): The maximum Fourier degree K.
|
||||
Returns:
|
||||
fourier_signature (ndarray): An array shaped (2k+1, 2) containing
|
||||
real part and image part of 2k+1 Fourier coefficients.
|
||||
"""
|
||||
resampled_polygon = self.resample_polygon(polygon)
|
||||
resampled_polygon = self.normalize_polygon(resampled_polygon)
|
||||
|
||||
fourier_coeff = self.fourier_transform(resampled_polygon,
|
||||
fourier_degree)
|
||||
fourier_coeff = self.clockwise(fourier_coeff, fourier_degree)
|
||||
|
||||
real_part = np.real(fourier_coeff).reshape((-1, 1))
|
||||
image_part = np.imag(fourier_coeff).reshape((-1, 1))
|
||||
fourier_signature = np.hstack([real_part, image_part])
|
||||
|
||||
return fourier_signature
|
||||
|
||||
def generate_fourier_maps(self, img_size, text_polys):
|
||||
"""Generate Fourier coefficient maps.
|
||||
|
||||
Args:
|
||||
img_size (tuple): The image size of (height, width).
|
||||
text_polys (list[list[ndarray]]): The list of text polygons.
|
||||
|
||||
Returns:
|
||||
fourier_real_map (ndarray): The Fourier coefficient real part maps.
|
||||
fourier_image_map (ndarray): The Fourier coefficient image part
|
||||
maps.
|
||||
"""
|
||||
|
||||
assert isinstance(img_size, tuple)
|
||||
assert check_argument.is_2dlist(text_polys)
|
||||
|
||||
h, w = img_size
|
||||
k = self.fourier_degree
|
||||
real_map = np.zeros((k * 2 + 1, h, w), dtype=np.float32)
|
||||
imag_map = np.zeros((k * 2 + 1, h, w), dtype=np.float32)
|
||||
|
||||
for poly in text_polys:
|
||||
assert len(poly) == 1
|
||||
text_instance = [[poly[0][i], poly[0][i + 1]]
|
||||
for i in range(0, len(poly[0]), 2)]
|
||||
mask = np.zeros((h, w), dtype=np.uint8)
|
||||
polygon = np.array(text_instance).reshape((1, -1, 2))
|
||||
cv2.fillPoly(mask, polygon.astype(np.int32), 1)
|
||||
fourier_coeff = self.cal_fourier_signature(polygon[0], k)
|
||||
for i in range(-k, k + 1):
|
||||
if i != 0:
|
||||
real_map[i + k, :, :] = mask * fourier_coeff[i + k, 0] + (
|
||||
1 - mask) * real_map[i + k, :, :]
|
||||
imag_map[i + k, :, :] = mask * fourier_coeff[i + k, 1] + (
|
||||
1 - mask) * imag_map[i + k, :, :]
|
||||
else:
|
||||
yx = np.argwhere(mask > 0.5)
|
||||
k_ind = np.ones((len(yx)), dtype=np.int64) * k
|
||||
y, x = yx[:, 0], yx[:, 1]
|
||||
real_map[k_ind, y, x] = fourier_coeff[k, 0] - x
|
||||
imag_map[k_ind, y, x] = fourier_coeff[k, 1] - y
|
||||
|
||||
return real_map, imag_map
|
||||
|
||||
def generate_level_targets(self, img_size, text_polys, ignore_polys):
|
||||
"""Generate ground truth target on each level.
|
||||
|
||||
Args:
|
||||
img_size (list[int]): Shape of input image.
|
||||
text_polys (list[list[ndarray]]): A list of ground truth polygons.
|
||||
ignore_polys (list[list[ndarray]]): A list of ignored polygons.
|
||||
Returns:
|
||||
level_maps (list(ndarray)): A list of ground target on each level.
|
||||
"""
|
||||
h, w = img_size
|
||||
lv_size_divs = self.level_size_divisors
|
||||
lv_proportion_range = self.level_proportion_range
|
||||
lv_text_polys = [[] for i in range(len(lv_size_divs))]
|
||||
lv_ignore_polys = [[] for i in range(len(lv_size_divs))]
|
||||
level_maps = []
|
||||
for poly in text_polys:
|
||||
assert len(poly) == 1
|
||||
text_instance = [[poly[0][i], poly[0][i + 1]]
|
||||
for i in range(0, len(poly[0]), 2)]
|
||||
polygon = np.array(text_instance, dtype=np.int).reshape((1, -1, 2))
|
||||
_, _, box_w, box_h = cv2.boundingRect(polygon)
|
||||
proportion = max(box_h, box_w) / (h + 1e-8)
|
||||
|
||||
for ind, proportion_range in enumerate(lv_proportion_range):
|
||||
if proportion_range[0] < proportion < proportion_range[1]:
|
||||
lv_text_polys[ind].append([poly[0] / lv_size_divs[ind]])
|
||||
|
||||
for ignore_poly in ignore_polys:
|
||||
assert len(ignore_poly) == 1
|
||||
text_instance = [[ignore_poly[0][i], ignore_poly[0][i + 1]]
|
||||
for i in range(0, len(ignore_poly[0]), 2)]
|
||||
polygon = np.array(text_instance, dtype=np.int).reshape((1, -1, 2))
|
||||
_, _, box_w, box_h = cv2.boundingRect(polygon)
|
||||
proportion = max(box_h, box_w) / (h + 1e-8)
|
||||
|
||||
for ind, proportion_range in enumerate(lv_proportion_range):
|
||||
if proportion_range[0] < proportion < proportion_range[1]:
|
||||
lv_text_polys[ind].append(
|
||||
[ignore_poly[0] / lv_size_divs[ind]])
|
||||
|
||||
for ind, size_divisor in enumerate(lv_size_divs):
|
||||
current_level_maps = []
|
||||
level_img_size = (h // size_divisor, w // size_divisor)
|
||||
|
||||
text_region = self.generate_text_region_mask(
|
||||
level_img_size, lv_text_polys[ind])
|
||||
text_region = np.expand_dims(text_region, axis=0)
|
||||
current_level_maps.append(text_region)
|
||||
|
||||
center_region = self.generate_center_region_mask(
|
||||
level_img_size, lv_text_polys[ind])
|
||||
center_region = np.expand_dims(center_region, axis=0)
|
||||
current_level_maps.append(center_region)
|
||||
|
||||
effective_mask = self.generate_effective_mask(
|
||||
level_img_size, lv_ignore_polys[ind])
|
||||
effective_mask = np.expand_dims(effective_mask, axis=0)
|
||||
current_level_maps.append(effective_mask)
|
||||
|
||||
fourier_real_map, fourier_image_maps = self.generate_fourier_maps(
|
||||
level_img_size, lv_text_polys[ind])
|
||||
current_level_maps.append(fourier_real_map)
|
||||
current_level_maps.append(fourier_image_maps)
|
||||
|
||||
level_maps.append(np.concatenate(current_level_maps))
|
||||
|
||||
return level_maps
|
||||
|
||||
def generate_targets(self, results):
|
||||
"""Generate the ground truth targets for FCENet.
|
||||
|
||||
Args:
|
||||
results (dict): The input result dictionary.
|
||||
|
||||
Returns:
|
||||
results (dict): The output result dictionary.
|
||||
"""
|
||||
|
||||
assert isinstance(results, dict)
|
||||
|
||||
polygon_masks = results['gt_masks'].masks
|
||||
polygon_masks_ignore = results['gt_masks_ignore'].masks
|
||||
|
||||
h, w, _ = results['img_shape']
|
||||
|
||||
level_maps = self.generate_level_targets((h, w), polygon_masks,
|
||||
polygon_masks_ignore)
|
||||
|
||||
results['mask_fields'].clear() # rm gt_masks encoded by polygons
|
||||
mapping = {
|
||||
'p3_maps': level_maps[0],
|
||||
'p4_maps': level_maps[1],
|
||||
'p5_maps': level_maps[2]
|
||||
}
|
||||
for key, value in mapping.items():
|
||||
results[key] = value
|
||||
|
||||
return results
|
@ -2,6 +2,7 @@ import math
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import Polygon as plg
|
||||
import torchvision.transforms as transforms
|
||||
from PIL import Image
|
||||
|
||||
@ -731,3 +732,231 @@ class SquareResizePad:
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class RandomScaling:
|
||||
|
||||
def __init__(self, size=800, scale=(3. / 4, 5. / 2)):
|
||||
"""Random scale the image while keeping aspect.
|
||||
|
||||
Args:
|
||||
size (int) : Base size before scaling.
|
||||
scale (tuple(float)) : The range of scaling.
|
||||
"""
|
||||
assert isinstance(size, int)
|
||||
assert isinstance(scale, float) or isinstance(scale, tuple)
|
||||
self.size = size
|
||||
self.scale = scale if isinstance(scale, tuple) \
|
||||
else (1 - scale, 1 + scale)
|
||||
|
||||
def __call__(self, results):
|
||||
image = results['img']
|
||||
h, w, _ = results['img_shape']
|
||||
|
||||
aspect_ratio = np.random.uniform(min(self.scale), max(self.scale))
|
||||
scales = self.size * 1.0 / max(h, w) * aspect_ratio
|
||||
scales = np.array([scales, scales])
|
||||
out_size = (int(h * scales[1]), int(w * scales[0]))
|
||||
image = cv2.resize(image, out_size[::-1])
|
||||
|
||||
results['img'] = image
|
||||
results['img_shape'] = image.shape
|
||||
|
||||
for key in results.get('mask_fields', []):
|
||||
if len(results[key].masks) == 0:
|
||||
continue
|
||||
results[key] = results[key].resize(out_size)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class RandomCropFlip:
|
||||
|
||||
def __init__(self, crop_ratio=0.5, iter_num=1, min_area_ratio=0.2):
|
||||
"""Random crop and flip a patch of the image.
|
||||
|
||||
Args:
|
||||
crop_ratio (float): The ratio of cropping.
|
||||
iter_num (int): Number of operations.
|
||||
min_area_ratio (float): Minimal area ratio between cropped patch
|
||||
and original image.
|
||||
"""
|
||||
assert isinstance(crop_ratio, float)
|
||||
assert isinstance(iter_num, int)
|
||||
assert isinstance(min_area_ratio, float)
|
||||
|
||||
self.scale = 10
|
||||
self.epsilon = 1e-2
|
||||
self.crop_ratio = crop_ratio
|
||||
self.iter_num = iter_num
|
||||
self.min_area_ratio = min_area_ratio
|
||||
|
||||
def __call__(self, results):
|
||||
for i in range(self.iter_num):
|
||||
results = self.random_crop_flip(results)
|
||||
return results
|
||||
|
||||
def random_crop_flip(self, results):
|
||||
image = results['img']
|
||||
polygons = results['gt_masks'].masks
|
||||
ignore_polygons = results['gt_masks_ignore'].masks
|
||||
all_polygons = polygons + ignore_polygons
|
||||
if len(polygons) == 0:
|
||||
return results
|
||||
|
||||
if np.random.random() >= self.crop_ratio:
|
||||
return results
|
||||
|
||||
h_axis, w_axis = self.crop_target(image, all_polygons, self.scale)
|
||||
if len(h_axis) == 0 or len(w_axis) == 0:
|
||||
return results
|
||||
|
||||
attempt = 0
|
||||
h, w, _ = results['img_shape']
|
||||
area = h * w
|
||||
pad_h = h // self.scale
|
||||
pad_w = w // self.scale
|
||||
while attempt < 10:
|
||||
attempt += 1
|
||||
polys_keep = []
|
||||
polys_new = []
|
||||
ign_polys_keep = []
|
||||
ign_polys_new = []
|
||||
xx = np.random.choice(w_axis, size=2)
|
||||
xmin = np.min(xx) - pad_w
|
||||
xmax = np.max(xx) - pad_w
|
||||
xmin = np.clip(xmin, 0, w - 1)
|
||||
xmax = np.clip(xmax, 0, w - 1)
|
||||
yy = np.random.choice(h_axis, size=2)
|
||||
ymin = np.min(yy) - pad_h
|
||||
ymax = np.max(yy) - pad_h
|
||||
ymin = np.clip(ymin, 0, h - 1)
|
||||
ymax = np.clip(ymax, 0, h - 1)
|
||||
if (xmax - xmin) * (ymax - ymin) < area * self.min_area_ratio:
|
||||
# area too small
|
||||
continue
|
||||
|
||||
pts = np.stack([[xmin, xmax, xmax, xmin],
|
||||
[ymin, ymin, ymax, ymax]]).T.astype(np.int32)
|
||||
pp = plg.Polygon(pts)
|
||||
fail_flag = False
|
||||
for polygon in polygons:
|
||||
ppi = plg.Polygon(polygon[0].reshape(-1, 2))
|
||||
ppiou, _ = eval_utils.poly_intersection(ppi, pp)
|
||||
if np.abs(ppiou - float(ppi.area())) > self.epsilon and \
|
||||
np.abs(ppiou) > self.epsilon:
|
||||
fail_flag = True
|
||||
break
|
||||
elif np.abs(ppiou - float(ppi.area())) < self.epsilon:
|
||||
polys_new.append(polygon)
|
||||
else:
|
||||
polys_keep.append(polygon)
|
||||
|
||||
for polygon in ignore_polygons:
|
||||
ppi = plg.Polygon(polygon[0].reshape(-1, 2))
|
||||
ppiou, _ = eval_utils.poly_intersection(ppi, pp)
|
||||
if np.abs(ppiou - float(ppi.area())) > self.epsilon and \
|
||||
np.abs(ppiou) > self.epsilon:
|
||||
fail_flag = True
|
||||
break
|
||||
elif np.abs(ppiou - float(ppi.area())) < self.epsilon:
|
||||
ign_polys_new.append(polygon)
|
||||
else:
|
||||
ign_polys_keep.append(polygon)
|
||||
|
||||
if fail_flag:
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
cropped = image[ymin:ymax, xmin:xmax, :]
|
||||
select_type = np.random.randint(3)
|
||||
if select_type == 0:
|
||||
img = np.ascontiguousarray(cropped[:, ::-1])
|
||||
elif select_type == 1:
|
||||
img = np.ascontiguousarray(cropped[::-1, :])
|
||||
else:
|
||||
img = np.ascontiguousarray(cropped[::-1, ::-1])
|
||||
image[ymin:ymax, xmin:xmax, :] = img
|
||||
results['img'] = image
|
||||
|
||||
if len(polys_new) + len(ign_polys_new) != 0:
|
||||
height, width, _ = cropped.shape
|
||||
if select_type == 0:
|
||||
for idx, polygon in enumerate(polys_new):
|
||||
poly = polygon[0].reshape(-1, 2)
|
||||
poly[:, 0] = width - poly[:, 0] + 2 * xmin
|
||||
polys_new[idx] = [poly.reshape(-1, )]
|
||||
for idx, polygon in enumerate(ign_polys_new):
|
||||
poly = polygon[0].reshape(-1, 2)
|
||||
poly[:, 0] = width - poly[:, 0] + 2 * xmin
|
||||
ign_polys_new[idx] = [poly.reshape(-1, )]
|
||||
elif select_type == 1:
|
||||
for idx, polygon in enumerate(polys_new):
|
||||
poly = polygon[0].reshape(-1, 2)
|
||||
poly[:, 1] = height - poly[:, 1] + 2 * ymin
|
||||
polys_new[idx] = [poly.reshape(-1, )]
|
||||
for idx, polygon in enumerate(ign_polys_new):
|
||||
poly = polygon[0].reshape(-1, 2)
|
||||
poly[:, 1] = height - poly[:, 1] + 2 * ymin
|
||||
ign_polys_new[idx] = [poly.reshape(-1, )]
|
||||
else:
|
||||
for idx, polygon in enumerate(polys_new):
|
||||
poly = polygon[0].reshape(-1, 2)
|
||||
poly[:, 0] = width - poly[:, 0] + 2 * xmin
|
||||
poly[:, 1] = height - poly[:, 1] + 2 * ymin
|
||||
polys_new[idx] = [poly.reshape(-1, )]
|
||||
for idx, polygon in enumerate(ign_polys_new):
|
||||
poly = polygon[0].reshape(-1, 2)
|
||||
poly[:, 0] = width - poly[:, 0] + 2 * xmin
|
||||
poly[:, 1] = height - poly[:, 1] + 2 * ymin
|
||||
ign_polys_new[idx] = [poly.reshape(-1, )]
|
||||
polygons = polys_keep + polys_new
|
||||
ignore_polygons = ign_polys_keep + ign_polys_new
|
||||
results['gt_masks'] = PolygonMasks(polygons, *(image.shape[:2]))
|
||||
results['gt_masks_ignore'] = PolygonMasks(ignore_polygons,
|
||||
*(image.shape[:2]))
|
||||
|
||||
return results
|
||||
|
||||
def crop_target(self, image, all_polys, scale):
|
||||
"""Generate crop target and make sure not to crop the polygon
|
||||
instances.
|
||||
|
||||
Args:
|
||||
image (ndarray): The image waited to be crop.
|
||||
all_polys (list[list[ndarray]]): All polygons including ground
|
||||
truth polygons and ground truth ignored polygons.
|
||||
scale (int): A scale factor to control crop range.
|
||||
Returns:
|
||||
h_axis (ndarray): Vertical cropping range.
|
||||
w_axis (ndarray): Horizontal cropping range.
|
||||
"""
|
||||
h, w, _ = image.shape
|
||||
pad_h = h // scale
|
||||
pad_w = w // scale
|
||||
h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
|
||||
w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
|
||||
|
||||
text_polys = []
|
||||
for polygon in all_polys:
|
||||
rect = cv2.minAreaRect(polygon[0].astype(np.int32).reshape(-1, 2))
|
||||
box = cv2.boxPoints(rect)
|
||||
box = np.int0(box)
|
||||
text_polys.append([box[0], box[1], box[2], box[3]])
|
||||
|
||||
polys = np.array(text_polys, dtype=np.int32)
|
||||
for poly in polys:
|
||||
poly = np.round(poly, decimals=0).astype(np.int32) # 四舍五入
|
||||
minx = np.min(poly[:, 0])
|
||||
maxx = np.max(poly[:, 0])
|
||||
w_array[minx + pad_w:maxx + pad_w] = 1
|
||||
miny = np.min(poly[:, 1])
|
||||
maxy = np.max(poly[:, 1])
|
||||
h_array[miny + pad_h:maxy + pad_h] = 1
|
||||
|
||||
h_axis = np.where(h_array == 0)[0]
|
||||
w_axis = np.where(w_array == 0)[0]
|
||||
return h_axis, w_axis
|
||||
|
@ -1,7 +1,10 @@
|
||||
from .db_head import DBHead
|
||||
from .fce_head import FCEHead
|
||||
from .head_mixin import HeadMixin
|
||||
from .pan_head import PANHead
|
||||
from .pse_head import PSEHead
|
||||
from .textsnake_head import TextSnakeHead
|
||||
|
||||
__all__ = ['PSEHead', 'PANHead', 'DBHead', 'HeadMixin', 'TextSnakeHead']
|
||||
__all__ = [
|
||||
'PSEHead', 'PANHead', 'DBHead', 'FCEHead', 'HeadMixin', 'TextSnakeHead'
|
||||
]
|
||||
|
134
mmocr/models/textdet/dense_heads/fce_head.py
Normal file
134
mmocr/models/textdet/dense_heads/fce_head.py
Normal file
@ -0,0 +1,134 @@
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import normal_init
|
||||
|
||||
from mmdet.core import multi_apply
|
||||
from mmdet.models.builder import HEADS, build_loss
|
||||
from mmocr.models.textdet.postprocess import decode
|
||||
from ..postprocess.wrapper import poly_nms
|
||||
from .head_mixin import HeadMixin
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
class FCEHead(HeadMixin, nn.Module):
|
||||
"""The class for implementing FCENet head.
|
||||
FCENet(CVPR2021): Fourier Contour Embedding for Arbitrary-shaped Text
|
||||
Detection.
|
||||
|
||||
[https://arxiv.org/abs/2104.10442]
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of input channels.
|
||||
scales (list[int]) : The scale of each layer.
|
||||
fourier_degree (int) : The maximum Fourier transform degree k.
|
||||
sample_num (int) : The sampling points number of regression
|
||||
loss. If it is too small, FCEnet tends to be overfitting.
|
||||
score_thresh (float) : The threshold to filter out the final
|
||||
candidates.
|
||||
nms_thresh (float) : The threshold of nms.
|
||||
alpha (float) : The parameter to calculate final scores. Score_{final}
|
||||
= (Score_{text region} ^ alpha)
|
||||
* (Score{text center region} ^ beta)
|
||||
beta (float) :The parameter to calculate final scores.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
scales,
|
||||
fourier_degree=5,
|
||||
sample_num=50,
|
||||
reconstr_points=50,
|
||||
decoding_type='fcenet',
|
||||
loss=dict(type='FCELoss'),
|
||||
score_thresh=0.3,
|
||||
nms_thresh=0.1,
|
||||
alpha=1.0,
|
||||
beta=1.0,
|
||||
train_cfg=None,
|
||||
test_cfg=None):
|
||||
|
||||
super().__init__()
|
||||
assert isinstance(in_channels, int)
|
||||
|
||||
self.downsample_ratio = 1.0
|
||||
self.in_channels = in_channels
|
||||
self.scales = scales
|
||||
self.fourier_degree = fourier_degree
|
||||
self.sample_num = sample_num
|
||||
self.reconstr_points = reconstr_points
|
||||
loss['fourier_degree'] = fourier_degree
|
||||
loss['sample_num'] = sample_num
|
||||
self.decoding_type = decoding_type
|
||||
self.loss_module = build_loss(loss)
|
||||
self.score_thresh = score_thresh
|
||||
self.nms_thresh = nms_thresh
|
||||
self.alpha = alpha
|
||||
self.beta = beta
|
||||
self.train_cfg = train_cfg
|
||||
self.test_cfg = test_cfg
|
||||
self.out_channels_cls = 4
|
||||
self.out_channels_reg = (2 * self.fourier_degree + 1) * 2
|
||||
|
||||
self.out_conv_cls = nn.Conv2d(
|
||||
self.in_channels,
|
||||
self.out_channels_cls,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
self.out_conv_reg = nn.Conv2d(
|
||||
self.in_channels,
|
||||
self.out_channels_reg,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def init_weights(self):
|
||||
normal_init(self.out_conv_cls, mean=0, std=0.01)
|
||||
normal_init(self.out_conv_reg, mean=0, std=0.01)
|
||||
|
||||
def forward(self, feats):
|
||||
cls_res, reg_res = multi_apply(self.forward_single, feats)
|
||||
level_num = len(cls_res)
|
||||
preds = [[cls_res[i], reg_res[i]] for i in range(level_num)]
|
||||
return preds
|
||||
|
||||
def forward_single(self, x):
|
||||
cls_predict = self.out_conv_cls(x)
|
||||
reg_predict = self.out_conv_reg(x)
|
||||
return cls_predict, reg_predict
|
||||
|
||||
def get_boundary(self, score_maps, img_metas, rescale):
|
||||
assert len(score_maps) == len(self.scales)
|
||||
|
||||
boundaries = []
|
||||
for idx, score_map in enumerate(score_maps):
|
||||
scale = self.scales[idx]
|
||||
boundaries = boundaries + self._get_boundary_single(
|
||||
score_map, scale)
|
||||
|
||||
# nms
|
||||
boundaries = poly_nms(boundaries, self.nms_thresh)
|
||||
|
||||
if rescale:
|
||||
boundaries = self.resize_boundary(
|
||||
boundaries, 1.0 / img_metas[0]['scale_factor'])
|
||||
|
||||
results = dict(boundary_result=boundaries)
|
||||
return results
|
||||
|
||||
def _get_boundary_single(self, score_map, scale):
|
||||
assert len(score_map) == 2
|
||||
assert score_map[1].shape[1] == 4 * self.fourier_degree + 2
|
||||
|
||||
return decode(
|
||||
decoding_type=self.decoding_type,
|
||||
preds=score_map,
|
||||
fourier_degree=self.fourier_degree,
|
||||
reconstr_points=self.reconstr_points,
|
||||
scale=scale,
|
||||
alpha=self.alpha,
|
||||
beta=self.beta,
|
||||
text_repr_type='poly',
|
||||
score_thresh=self.score_thresh,
|
||||
nms_thresh=self.nms_thresh)
|
@ -1,4 +1,5 @@
|
||||
from .dbnet import DBNet
|
||||
from .fcenet import FCENet
|
||||
from .ocr_mask_rcnn import OCRMaskRCNN
|
||||
from .panet import PANet
|
||||
from .psenet import PSENet
|
||||
@ -8,5 +9,5 @@ from .textsnake import TextSnake
|
||||
|
||||
__all__ = [
|
||||
'TextDetectorMixin', 'SingleStageTextDetector', 'OCRMaskRCNN', 'DBNet',
|
||||
'PANet', 'PSENet', 'TextSnake'
|
||||
'PANet', 'PSENet', 'TextSnake', 'FCENet'
|
||||
]
|
||||
|
32
mmocr/models/textdet/detectors/fcenet.py
Normal file
32
mmocr/models/textdet/detectors/fcenet.py
Normal file
@ -0,0 +1,32 @@
|
||||
from mmdet.models.builder import DETECTORS
|
||||
from .single_stage_text_detector import SingleStageTextDetector
|
||||
from .text_detector_mixin import TextDetectorMixin
|
||||
|
||||
|
||||
@DETECTORS.register_module()
|
||||
class FCENet(TextDetectorMixin, SingleStageTextDetector):
|
||||
"""The class for implementing FCENet text detector
|
||||
FCENet(CVPR2021): Fourier Contour Embedding for Arbitrary-shaped Text
|
||||
Detection
|
||||
|
||||
[https://arxiv.org/abs/2104.10442]
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
backbone,
|
||||
neck,
|
||||
bbox_head,
|
||||
train_cfg=None,
|
||||
test_cfg=None,
|
||||
pretrained=None,
|
||||
show_score=False):
|
||||
SingleStageTextDetector.__init__(self, backbone, neck, bbox_head,
|
||||
train_cfg, test_cfg, pretrained)
|
||||
TextDetectorMixin.__init__(self, show_score)
|
||||
|
||||
def simple_test(self, img, img_metas, rescale=False):
|
||||
x = self.extract_feat(img)
|
||||
outs = self.bbox_head(x)
|
||||
boundaries = self.bbox_head.get_boundary(outs, img_metas, rescale)
|
||||
|
||||
return [boundaries]
|
@ -1,6 +1,7 @@
|
||||
from .db_loss import DBLoss
|
||||
from .fce_loss import FCELoss
|
||||
from .pan_loss import PANLoss
|
||||
from .pse_loss import PSELoss
|
||||
from .textsnake_loss import TextSnakeLoss
|
||||
|
||||
__all__ = ['PANLoss', 'PSELoss', 'DBLoss', 'TextSnakeLoss']
|
||||
__all__ = ['PANLoss', 'PSELoss', 'DBLoss', 'TextSnakeLoss', 'FCELoss']
|
||||
|
194
mmocr/models/textdet/losses/fce_loss.py
Normal file
194
mmocr/models/textdet/losses/fce_loss.py
Normal file
@ -0,0 +1,194 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from mmdet.core import multi_apply
|
||||
from mmdet.models.builder import LOSSES
|
||||
|
||||
|
||||
@LOSSES.register_module()
|
||||
class FCELoss(nn.Module):
|
||||
"""The class for implementing FCENet loss
|
||||
FCENet(CVPR2021): Fourier Contour Embedding for Arbitrary-shaped
|
||||
Text Detection
|
||||
|
||||
[https://arxiv.org/abs/2104.10442]
|
||||
|
||||
Args:
|
||||
fourier_degree (int) : The maximum Fourier transform degree k.
|
||||
sample_num (int) : The sampling points number of regression
|
||||
loss. If it is too small, fcenet tends to be overfitting.
|
||||
ohem_ratio (float): the negative/positive ratio in OHEM.
|
||||
"""
|
||||
|
||||
def __init__(self, fourier_degree, sample_num, ohem_ratio=3.):
|
||||
super().__init__()
|
||||
self.fourier_degree = fourier_degree
|
||||
self.sample_num = sample_num
|
||||
self.ohem_ratio = ohem_ratio
|
||||
|
||||
def forward(self, preds, _, p3_maps, p4_maps, p5_maps):
|
||||
assert isinstance(preds, list)
|
||||
assert p3_maps[0].shape[0] == 4 * self.fourier_degree + 5,\
|
||||
'fourier degree not equal in FCEhead and FCEtarget'
|
||||
|
||||
device = preds[0][0].device
|
||||
# to tensor
|
||||
gts = [p3_maps, p4_maps, p5_maps]
|
||||
for idx, maps in enumerate(gts):
|
||||
gts[idx] = torch.from_numpy(np.stack(maps)).float().to(device)
|
||||
|
||||
losses = multi_apply(self.forward_single, preds, gts)
|
||||
|
||||
loss_tr = torch.tensor(0., device=device).float()
|
||||
loss_tcl = torch.tensor(0., device=device).float()
|
||||
loss_reg_x = torch.tensor(0., device=device).float()
|
||||
loss_reg_y = torch.tensor(0., device=device).float()
|
||||
|
||||
for idx, loss in enumerate(losses):
|
||||
if idx == 0:
|
||||
loss_tr += sum(loss)
|
||||
elif idx == 1:
|
||||
loss_tcl += sum(loss)
|
||||
elif idx == 2:
|
||||
loss_reg_x += sum(loss)
|
||||
else:
|
||||
loss_reg_y += sum(loss)
|
||||
|
||||
results = dict(
|
||||
loss_text=loss_tr,
|
||||
loss_center=loss_tcl,
|
||||
loss_reg_x=loss_reg_x,
|
||||
loss_reg_y=loss_reg_y,
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def forward_single(self, pred, gt):
|
||||
cls_pred, reg_pred = pred[0], pred[1]
|
||||
|
||||
tr_pred = cls_pred[:, :2, :, :].permute(0, 2, 3, 1)\
|
||||
.contiguous().view(-1, 2)
|
||||
tcl_pred = cls_pred[:, 2:, :, :].permute(0, 2, 3, 1)\
|
||||
.contiguous().view(-1, 2)
|
||||
x_pred = reg_pred[:, 0:2 * self.fourier_degree + 1, :, :]\
|
||||
.permute(0, 2, 3, 1).contiguous().view(
|
||||
-1, 2 * self.fourier_degree + 1)
|
||||
y_pred = reg_pred[:,
|
||||
2 * self.fourier_degree + 1:4 * self.fourier_degree +
|
||||
2, :, :].permute(0, 2, 3, 1).contiguous().view(
|
||||
-1, 2 * self.fourier_degree + 1)
|
||||
|
||||
tr_mask = gt[:, :1, :, :].permute(0, 2, 3, 1).contiguous().view(-1)
|
||||
tcl_mask = gt[:, 1:2, :, :].permute(0, 2, 3, 1).contiguous().view(-1)
|
||||
train_mask = gt[:, 2:3, :, :].permute(0, 2, 3, 1).contiguous().view(-1)
|
||||
x_map = gt[:, 3:4 + 2 * self.fourier_degree, :, :].permute(
|
||||
0, 2, 3, 1).contiguous().view(-1, 2 * self.fourier_degree + 1)
|
||||
y_map = gt[:, 4 + 2 * self.fourier_degree:, :, :].permute(
|
||||
0, 2, 3, 1).contiguous().view(-1, 2 * self.fourier_degree + 1)
|
||||
|
||||
tr_train_mask = train_mask * tr_mask
|
||||
device = x_map.device
|
||||
# tr loss
|
||||
loss_tr = self.ohem(tr_pred, tr_mask.long(), train_mask.long())
|
||||
|
||||
# tcl loss
|
||||
loss_tcl = torch.tensor(0.).float().to(device)
|
||||
tr_neg_mask = 1 - tr_train_mask
|
||||
if tr_train_mask.sum().item() > 0:
|
||||
loss_tcl_pos = F.cross_entropy(
|
||||
tcl_pred[tr_train_mask.bool()],
|
||||
tcl_mask[tr_train_mask.bool()].long())
|
||||
loss_tcl_neg = F.cross_entropy(tcl_pred[tr_neg_mask.bool()],
|
||||
tcl_mask[tr_neg_mask.bool()].long())
|
||||
loss_tcl = loss_tcl_pos + 0.5 * loss_tcl_neg
|
||||
|
||||
# regression loss
|
||||
loss_reg_x = torch.tensor(0.).float().to(device)
|
||||
loss_reg_y = torch.tensor(0.).float().to(device)
|
||||
if tr_train_mask.sum().item() > 0:
|
||||
weight = (tr_mask[tr_train_mask.bool()].float() +
|
||||
tcl_mask[tr_train_mask.bool()].float()) / 2
|
||||
weight = weight.contiguous().view(-1, 1)
|
||||
|
||||
ft_x, ft_y = self.fourier2poly(x_map, y_map)
|
||||
ft_x_pre, ft_y_pre = self.fourier2poly(x_pred, y_pred)
|
||||
|
||||
loss_reg_x = torch.mean(weight * F.smooth_l1_loss(
|
||||
ft_x_pre[tr_train_mask.bool()],
|
||||
ft_x[tr_train_mask.bool()],
|
||||
reduction='none'))
|
||||
loss_reg_y = torch.mean(weight * F.smooth_l1_loss(
|
||||
ft_y_pre[tr_train_mask.bool()],
|
||||
ft_y[tr_train_mask.bool()],
|
||||
reduction='none'))
|
||||
|
||||
return loss_tr, loss_tcl, loss_reg_x, loss_reg_y
|
||||
|
||||
def ohem(self, predict, target, train_mask):
|
||||
pos = (target * train_mask).bool()
|
||||
neg = ((1 - target) * train_mask).bool()
|
||||
|
||||
n_pos = pos.float().sum()
|
||||
|
||||
if n_pos.item() > 0:
|
||||
loss_pos = F.cross_entropy(
|
||||
predict[pos], target[pos], reduction='sum')
|
||||
loss_neg = F.cross_entropy(
|
||||
predict[neg], target[neg], reduction='none')
|
||||
n_neg = min(
|
||||
int(neg.float().sum().item()),
|
||||
int(self.ohem_ratio * n_pos.float()))
|
||||
else:
|
||||
loss_pos = torch.tensor(0.)
|
||||
loss_neg = F.cross_entropy(
|
||||
predict[neg], target[neg], reduction='none')
|
||||
n_neg = 100
|
||||
if len(loss_neg) > n_neg:
|
||||
loss_neg, _ = torch.topk(loss_neg, n_neg)
|
||||
|
||||
return (loss_pos + loss_neg.sum()) / (n_pos + n_neg).float()
|
||||
|
||||
def fourier2poly(self, real_maps, imag_maps):
|
||||
"""Transform Fourier coefficient maps to polygon maps.
|
||||
|
||||
Args:
|
||||
real_maps (tensor): A map composed of the real parts of the
|
||||
Fourier coefficients, whose shape is (-1, 2k+1)
|
||||
imag_maps (tensor):A map composed of the imag parts of the
|
||||
Fourier coefficients, whose shape is (-1, 2k+1)
|
||||
|
||||
Returns
|
||||
x_maps (tensor): A map composed of the x value of the polygon
|
||||
represented by n sample points (xn, yn), whose shape is (-1, n)
|
||||
y_maps (tensor): A map composed of the y value of the polygon
|
||||
represented by n sample points (xn, yn), whose shape is (-1, n)
|
||||
"""
|
||||
|
||||
device = real_maps.device
|
||||
|
||||
k_vect = torch.arange(
|
||||
-self.fourier_degree,
|
||||
self.fourier_degree + 1,
|
||||
dtype=torch.float,
|
||||
device=device).view(-1, 1)
|
||||
i_vect = torch.arange(
|
||||
0, self.sample_num, dtype=torch.float, device=device).view(1, -1)
|
||||
|
||||
transform_matrix = 2 * np.pi / self.sample_num * torch.mm(
|
||||
k_vect, i_vect)
|
||||
|
||||
x1 = torch.einsum('ak, kn-> an', real_maps,
|
||||
torch.cos(transform_matrix))
|
||||
x2 = torch.einsum('ak, kn-> an', imag_maps,
|
||||
torch.sin(transform_matrix))
|
||||
y1 = torch.einsum('ak, kn-> an', real_maps,
|
||||
torch.sin(transform_matrix))
|
||||
y2 = torch.einsum('ak, kn-> an', imag_maps,
|
||||
torch.cos(transform_matrix))
|
||||
|
||||
x_maps = x1 - x2
|
||||
y_maps = y1 + y2
|
||||
|
||||
return x_maps, y_maps
|
@ -7,6 +7,7 @@ from shapely.geometry import Polygon
|
||||
from skimage.morphology import skeletonize
|
||||
|
||||
from mmocr.core import points2boundary
|
||||
from mmocr.core.evaluation.utils import boundary_iou
|
||||
|
||||
|
||||
def filter_instance(area, confidence, min_area, min_confidence):
|
||||
@ -24,6 +25,8 @@ def decode(
|
||||
return db_decode(**kwargs)
|
||||
if decoding_type == 'textsnake':
|
||||
return textsnake_decode(**kwargs)
|
||||
if decoding_type == 'fcenet':
|
||||
return fcenet_decode(**kwargs)
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@ -391,3 +394,177 @@ def textsnake_decode(preds,
|
||||
boundaries.append(boundary + [score])
|
||||
|
||||
return boundaries
|
||||
|
||||
|
||||
def fcenet_decode(
|
||||
preds,
|
||||
fourier_degree,
|
||||
reconstr_points,
|
||||
scale,
|
||||
alpha=1.0,
|
||||
beta=2.0,
|
||||
text_repr_type='poly',
|
||||
score_thresh=0.8,
|
||||
nms_thresh=0.1,
|
||||
):
|
||||
"""Decoding predictions of FCENet to instances.
|
||||
|
||||
Args:
|
||||
preds (list(Tensor)): The head output tensors.
|
||||
fourier_degree (int): The maximum Fourier transform degree k.
|
||||
reconstr_points (int): The points number of the polygon reconstructed
|
||||
from predicted Fourier coefficients.
|
||||
scale (int): The downsample scale of the prediction.
|
||||
alpha (float) : The parameter to calculate final scores. Score_{final}
|
||||
= (Score_{text region} ^ alpha)
|
||||
* (Score_{text center region}^ beta)
|
||||
beta (float) : The parameter to calculate final score.
|
||||
text_repr_type (str): Boundary encoding type 'poly' or 'quad'.
|
||||
score_thresh (float) : The threshold used to filter out the final
|
||||
candidates.
|
||||
nms_thresh (float) : The threshold of nms.
|
||||
|
||||
Returns:
|
||||
boundaries (list[list[float]]): The instance boundary and confidence
|
||||
list.
|
||||
"""
|
||||
assert isinstance(preds, list)
|
||||
assert len(preds) == 2
|
||||
assert text_repr_type == 'poly'
|
||||
|
||||
cls_pred = preds[0][0]
|
||||
tr_pred = cls_pred[0:2].softmax(dim=0).data.cpu().numpy()
|
||||
tcl_pred = cls_pred[2:].softmax(dim=0).data.cpu().numpy()
|
||||
|
||||
reg_pred = preds[1][0].permute(1, 2, 0).data.cpu().numpy()
|
||||
x_pred = reg_pred[:, :, :2 * fourier_degree + 1]
|
||||
y_pred = reg_pred[:, :, 2 * fourier_degree + 1:]
|
||||
|
||||
score_pred = (tr_pred[1]**alpha) * (tcl_pred[1]**beta)
|
||||
tr_pred_mask = (score_pred) > score_thresh
|
||||
tr_mask = fill_hole(tr_pred_mask)
|
||||
|
||||
tr_contours, _ = cv2.findContours(
|
||||
tr_mask.astype(np.uint8), cv2.RETR_TREE,
|
||||
cv2.CHAIN_APPROX_SIMPLE) # opencv4
|
||||
|
||||
mask = np.zeros_like(tr_mask)
|
||||
exp_matrix = generate_exp_matrix(reconstr_points, fourier_degree)
|
||||
boundaries = []
|
||||
for cont in tr_contours:
|
||||
deal_map = mask.copy().astype(np.int8)
|
||||
cv2.drawContours(deal_map, [cont], -1, 1, -1)
|
||||
|
||||
text_map = score_pred * deal_map
|
||||
polygons = contour_transfor_inv(fourier_degree, x_pred, y_pred,
|
||||
text_map, exp_matrix, scale)
|
||||
polygons = poly_nms(polygons, nms_thresh)
|
||||
boundaries = boundaries + polygons
|
||||
|
||||
boundaries = poly_nms(boundaries, nms_thresh)
|
||||
return boundaries
|
||||
|
||||
|
||||
def poly_nms(polygons, threshold):
|
||||
assert isinstance(polygons, list)
|
||||
|
||||
polygons = np.array(sorted(polygons, key=lambda x: x[-1]))
|
||||
|
||||
keep_poly = []
|
||||
index = [i for i in range(polygons.shape[0])]
|
||||
|
||||
while len(index) > 0:
|
||||
keep_poly.append(polygons[index[-1]].tolist())
|
||||
A = polygons[index[-1]][:-1]
|
||||
index = np.delete(index, -1)
|
||||
|
||||
iou_list = np.zeros((len(index), ))
|
||||
for i in range(len(index)):
|
||||
B = polygons[index[i]][:-1]
|
||||
|
||||
iou_list[i] = boundary_iou(A, B)
|
||||
remove_index = np.where(iou_list > threshold)
|
||||
index = np.delete(index, remove_index)
|
||||
|
||||
return keep_poly
|
||||
|
||||
|
||||
def contour_transfor_inv(fourier_degree, x_pred, y_pred, score_map, exp_matrix,
|
||||
scale):
|
||||
"""Reconstruct polygon from predicts.
|
||||
|
||||
Args:
|
||||
fourier_degree (int): The maximum Fourier degree K.
|
||||
x_pred (ndarray): The real part of predicted Fourier coefficients.
|
||||
y_pred (ndarray): The image part of predicted Fourier coefficients.
|
||||
score_map (ndarray): The final score of candidates.
|
||||
exp_matrix (ndarray): A matrix of e^x, where x = 2pi x ikt, and shape
|
||||
is (2k+1, n') where n' is reconstructed point number. See Eq.2
|
||||
in paper.
|
||||
scale (int): The down-sample scale.
|
||||
Returns:
|
||||
Polygons (list): The reconstructed polygons and scores.
|
||||
"""
|
||||
mask = score_map > 0
|
||||
|
||||
xy_text = np.argwhere(mask)
|
||||
dxy = xy_text[:, 1] + xy_text[:, 0] * 1j
|
||||
|
||||
x = x_pred[mask]
|
||||
y = y_pred[mask]
|
||||
|
||||
c = x + y * 1j
|
||||
c[:, fourier_degree] = c[:, fourier_degree] + dxy
|
||||
c *= scale
|
||||
|
||||
polygons = fourier_inverse_matrix(c, exp_matrix=exp_matrix)
|
||||
score = score_map[mask].reshape(-1, 1)
|
||||
return np.hstack((polygons, score)).tolist()
|
||||
|
||||
|
||||
def fourier_inverse_matrix(fourier_coeff, exp_matrix):
|
||||
""" Inverse Fourier transform
|
||||
Args:
|
||||
fourier_coeff (ndarray): Fourier coefficients shaped (n, 2k+1), with
|
||||
n and k being candidates number and Fourier degree respectively.
|
||||
exp_matrix (ndarray): A matrix of e^x, where x = 2pi x ikt and shape
|
||||
is (2k+1, n') where n' is reconstructed point number.
|
||||
See Eq.2 in paper.
|
||||
Returns:
|
||||
Polygons (ndarray): The reconstructed polygons shaped (n, n')
|
||||
"""
|
||||
|
||||
assert type(fourier_coeff) == np.ndarray
|
||||
assert fourier_coeff.shape[1] == exp_matrix.shape[0]
|
||||
|
||||
n = exp_matrix.shape[1]
|
||||
polygons = np.zeros((fourier_coeff.shape[0], n, 2))
|
||||
|
||||
points = np.matmul(fourier_coeff, exp_matrix)
|
||||
p_x = np.real(points)
|
||||
p_y = np.imag(points)
|
||||
polygons[:, :, 0] = p_x
|
||||
polygons[:, :, 1] = p_y
|
||||
return polygons.astype('int32').reshape(polygons.shape[0], -1)
|
||||
|
||||
|
||||
def generate_exp_matrix(point_num, fourier_degree):
|
||||
""" Generate a matrix of e^x, where x = 2pi x ikt. See Eq.2 in paper.
|
||||
Args:
|
||||
point_num (int): Number of reconstruct points of polygon
|
||||
fourier_degree (int): Maximum Fourier degree k
|
||||
Returns:
|
||||
exp_matrix (ndarray): A matrix of e^x, where x = 2pi x ikt and
|
||||
shape is (2k+1, n') where n' is reconstructed point number.
|
||||
"""
|
||||
e = complex(np.e)
|
||||
exp_matrix = np.zeros([2 * fourier_degree + 1, point_num], dtype='complex')
|
||||
|
||||
temp = np.zeros([point_num], dtype='complex')
|
||||
for i in range(point_num):
|
||||
temp[i] = 2 * np.pi * 1j / point_num * i
|
||||
|
||||
for i in range(2 * fourier_degree + 1):
|
||||
exp_matrix[i, :] = temp * (i - fourier_degree)
|
||||
|
||||
return np.power(e, exp_matrix)
|
||||
|
@ -218,3 +218,27 @@ def test_gen_textsnake_targets(mock_show_feature):
|
||||
assert 'gt_sin_map' in output.keys()
|
||||
assert 'gt_cos_map' in output.keys()
|
||||
mock_show_feature.assert_called_once()
|
||||
|
||||
|
||||
def test_fcenet_generate_targets():
|
||||
fourier_degree = 5
|
||||
target_generator = textdet_targets.FCENetTargets(
|
||||
fourier_degree=fourier_degree)
|
||||
|
||||
h, w, c = (64, 64, 3)
|
||||
text_polys = [[np.array([0, 0, 10, 0, 10, 10, 0, 10])],
|
||||
[np.array([20, 0, 30, 0, 30, 10, 20, 10])]]
|
||||
text_polys_ignore = [[np.array([0, 0, 15, 0, 15, 10, 0, 10])]]
|
||||
|
||||
results = {}
|
||||
results['mask_fields'] = []
|
||||
results['img_shape'] = (h, w, c)
|
||||
results['gt_masks_ignore'] = PolygonMasks(text_polys_ignore, h, w)
|
||||
results['gt_masks'] = PolygonMasks(text_polys, h, w)
|
||||
results['gt_bboxes'] = np.array([[0, 0, 10, 10], [20, 0, 30, 10]])
|
||||
results['gt_labels'] = np.array([0, 1])
|
||||
|
||||
target_generator.generate_targets(results)
|
||||
assert 'p3_maps' in results.keys()
|
||||
assert 'p4_maps' in results.keys()
|
||||
assert 'p5_maps' in results.keys()
|
||||
|
@ -166,6 +166,72 @@ def test_affine_jitter():
|
||||
assert np.allclose(np.array(output1), output2['img'])
|
||||
|
||||
|
||||
def test_random_scale():
|
||||
h, w, c = 100, 100, 3
|
||||
img = np.ones((h, w, c), dtype=np.uint8)
|
||||
results = {'img': img, 'img_shape': (h, w, c)}
|
||||
|
||||
polygon = np.array([0., 0., 0., 10., 10., 10., 10., 0.])
|
||||
|
||||
results['gt_masks'] = PolygonMasks([[polygon]], *(img.shape[:2]))
|
||||
results['mask_fields'] = ['gt_masks']
|
||||
|
||||
size = 100
|
||||
scale = (2., 2.)
|
||||
random_scaler = transforms.RandomScaling(size=size, scale=scale)
|
||||
|
||||
results = random_scaler(results)
|
||||
|
||||
out_img = results['img']
|
||||
out_poly = results['gt_masks'].masks[0][0]
|
||||
gt_poly = polygon * 2
|
||||
|
||||
assert np.allclose(out_img.shape, (2 * h, 2 * w, c))
|
||||
assert np.allclose(out_poly, gt_poly)
|
||||
|
||||
|
||||
@mock.patch('%s.transforms.np.random.randint' % __name__)
|
||||
def test_random_crop_flip(mock_randint):
|
||||
img = np.ones((10, 10, 3), dtype=np.uint8)
|
||||
img[0, 0, :] = 0
|
||||
results = {'img': img, 'img_shape': img.shape}
|
||||
|
||||
polygon = np.array([0., 0., 0., 10., 10., 10., 10., 0.])
|
||||
|
||||
results['gt_masks'] = PolygonMasks([[polygon]], *(img.shape[:2]))
|
||||
results['gt_masks_ignore'] = PolygonMasks([], *(img.shape[:2]))
|
||||
results['mask_fields'] = ['gt_masks', 'gt_masks_ignore']
|
||||
|
||||
crop_ratio = 1.1
|
||||
iter_num = 3
|
||||
random_crop_fliper = transforms.RandomCropFlip(
|
||||
crop_ratio=crop_ratio, iter_num=iter_num)
|
||||
|
||||
# test crop_target
|
||||
scale = 10
|
||||
all_polys = results['gt_masks'].masks
|
||||
h_axis, w_axis = random_crop_fliper.crop_target(img, all_polys, scale)
|
||||
|
||||
assert np.allclose(h_axis, (0, 11))
|
||||
assert np.allclose(w_axis, (0, 11))
|
||||
|
||||
# test __call__
|
||||
polygon = np.array([1., 1., 1., 9., 9., 9., 9., 1.])
|
||||
results['gt_masks'] = PolygonMasks([[polygon]], *(img.shape[:2]))
|
||||
results['gt_masks_ignore'] = PolygonMasks([[polygon]], *(img.shape[:2]))
|
||||
|
||||
mock_randint.side_effect = [0, 1, 2]
|
||||
results = random_crop_fliper(results)
|
||||
|
||||
out_img = results['img']
|
||||
out_poly = results['gt_masks'].masks[0][0]
|
||||
gt_img = img
|
||||
gt_poly = polygon
|
||||
|
||||
assert np.allclose(out_img, gt_img)
|
||||
assert np.allclose(out_poly, gt_poly)
|
||||
|
||||
|
||||
@mock.patch('%s.transforms.np.random.random_sample' % __name__)
|
||||
@mock.patch('%s.transforms.np.random.randint' % __name__)
|
||||
def test_random_crop_poly_instances(mock_randint, mock_sample):
|
||||
|
@ -372,3 +372,60 @@ def test_textsnake(cfg_file):
|
||||
results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]}
|
||||
img = np.random.rand(5, 5)
|
||||
detector.show_result(img, results)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
|
||||
@pytest.mark.parametrize(
|
||||
'cfg_file', ['textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py'])
|
||||
def test_fcenet(cfg_file):
|
||||
model = _get_detector_cfg(cfg_file)
|
||||
model['pretrained'] = None
|
||||
model['backbone']['norm_cfg']['type'] = 'BN'
|
||||
|
||||
from mmocr.models import build_detector
|
||||
detector = build_detector(model)
|
||||
detector = detector.cuda()
|
||||
|
||||
fourier_degree = 5
|
||||
input_shape = (1, 3, 256, 256)
|
||||
(n, c, h, w) = input_shape
|
||||
|
||||
imgs = torch.randn(n, c, h, w).float().cuda()
|
||||
img_metas = [{
|
||||
'img_shape': (h, w, c),
|
||||
'ori_shape': (h, w, c),
|
||||
'pad_shape': (h, w, c),
|
||||
'filename': '<demo>.png',
|
||||
'scale_factor': np.array([1, 1, 1, 1]),
|
||||
'flip': False,
|
||||
} for _ in range(n)]
|
||||
|
||||
p3_maps = []
|
||||
p4_maps = []
|
||||
p5_maps = []
|
||||
for _ in range(n):
|
||||
p3_maps.append(
|
||||
np.random.random((5 + 4 * fourier_degree, h // 8, w // 8)))
|
||||
p4_maps.append(
|
||||
np.random.random((5 + 4 * fourier_degree, h // 16, w // 16)))
|
||||
p5_maps.append(
|
||||
np.random.random((5 + 4 * fourier_degree, h // 32, w // 32)))
|
||||
|
||||
# Test forward train
|
||||
losses = detector.forward(
|
||||
imgs, img_metas, p3_maps=p3_maps, p4_maps=p4_maps, p5_maps=p5_maps)
|
||||
assert isinstance(losses, dict)
|
||||
|
||||
# Test forward test
|
||||
with torch.no_grad():
|
||||
img_list = [g[None, :] for g in imgs]
|
||||
batch_results = []
|
||||
for one_img, one_meta in zip(img_list, img_metas):
|
||||
result = detector.forward([one_img], [[one_meta]],
|
||||
return_loss=False)
|
||||
batch_results.append(result)
|
||||
|
||||
# Test show result
|
||||
results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]}
|
||||
img = np.random.rand(5, 5)
|
||||
detector.show_result(img, results)
|
||||
|
@ -31,3 +31,42 @@ def test_textsnakeloss():
|
||||
bce_loss = textsnakeloss.balanced_bce_loss(pred, target, mask).item()
|
||||
|
||||
assert np.allclose(bce_loss, 0)
|
||||
|
||||
|
||||
def test_fcenetloss():
|
||||
k = 5
|
||||
fcenetloss = losses.FCELoss(fourier_degree=k, sample_num=10)
|
||||
|
||||
input_shape = (1, 3, 64, 64)
|
||||
(n, c, h, w) = input_shape
|
||||
|
||||
# test ohem
|
||||
pred = torch.ones((200, 2), dtype=torch.float)
|
||||
target = torch.ones((200, ), dtype=torch.long)
|
||||
target[20:] = 0
|
||||
mask = torch.ones((200, ), dtype=torch.long)
|
||||
|
||||
ohem_loss1 = fcenetloss.ohem(pred, target, mask)
|
||||
ohem_loss2 = fcenetloss.ohem(pred, target, 1 - mask)
|
||||
assert isinstance(ohem_loss1, torch.Tensor)
|
||||
assert isinstance(ohem_loss2, torch.Tensor)
|
||||
|
||||
# test forward
|
||||
preds = []
|
||||
for i in range(n):
|
||||
scale = 8 * 2**i
|
||||
pred = []
|
||||
pred.append(torch.rand(n, 4, h // scale, w // scale))
|
||||
pred.append(torch.rand(n, 4 * k + 2, h // scale, w // scale))
|
||||
preds.append(pred)
|
||||
|
||||
p3_maps = []
|
||||
p4_maps = []
|
||||
p5_maps = []
|
||||
for _ in range(n):
|
||||
p3_maps.append(np.random.random((5 + 4 * k, h // 8, w // 8)))
|
||||
p4_maps.append(np.random.random((5 + 4 * k, h // 16, w // 16)))
|
||||
p5_maps.append(np.random.random((5 + 4 * k, h // 32, w // 32)))
|
||||
|
||||
loss = fcenetloss(preds, 0, p3_maps, p4_maps, p5_maps)
|
||||
assert isinstance(loss, dict)
|
||||
|
@ -12,3 +12,17 @@ def test_db_boxes_from_bitmaps():
|
||||
|
||||
boxes = db_decode(preds, text_repr_type='quad', min_text_width=0)
|
||||
assert len(boxes) == 1
|
||||
|
||||
|
||||
def test_fcenet_decode():
|
||||
from mmocr.models.textdet.postprocess.wrapper import fcenet_decode
|
||||
|
||||
k = 5
|
||||
preds = []
|
||||
preds.append(torch.randn(1, 4, 40, 40))
|
||||
preds.append(torch.randn(1, 4 * k + 2, 40, 40))
|
||||
|
||||
boundaries = fcenet_decode(
|
||||
preds=preds, fourier_degree=k, reconstr_points=50, scale=1)
|
||||
|
||||
assert isinstance(boundaries, list)
|
||||
|
Loading…
x
Reference in New Issue
Block a user