mirror of https://github.com/open-mmlab/mmocr.git

Add drrg (#189)

* merge drrg * directory structure & fix redundant import * docstrings * fix isort * drrg readme * merge drrg * directory structure & fix redundant import * docstrings * fix isort * drrg readme * add unittests & fix docstrings * revert test_loss * add unittest * add unittests * fix docstrings * fix docstrings * fix yapf * fix yapf * Update test_textdet_head.py * Update test_textdet_head.py * add unittests * add unittests * add unittests * fix docstrings * fix docstrings * fix docstring * fix unittests * fix pytest * fix pytest * fix pytest * fix variable names

Co-authored-by: Hongbin Sun <hongbin306@gmail.com>

pull/204/head
parent ed6b3b890a
commit 2414c65577
@@ -0,0 +1,23 @@

# DRRG

## Introduction

[ALGORITHM]

```bibtex
@inproceedings{zhang2020drrg,
  title={Deep relational reasoning graph network for arbitrary shape text detection},
  author={Zhang, Shi-Xue and Zhu, Xiaobin and Hou, Jie-Bo and Liu, Chang and Yang, Chun and Wang, Hongfa and Yin, Xu-Cheng},
  booktitle={CVPR},
  pages={9699--9708},
  year={2020}
}
```

## Results and models

### CTW1500

| Method | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download |
| :----: | :--------------: | :----------: | :------: | :-----: | :-------: | :----: | :-------: | :---: | :------: |
| [DRRG](/configs/textdet/drrg/drrg_r50_fpn_unet_1200e_ctw1500.py) | ImageNet | CTW1500 Train | CTW1500 Test | 1200 | 640 | 0.822 | 0.858 | 0.840 | [model](https://download.openmmlab.com/mmocr/textdet/drrg/drrg_r50_fpn_unet_1200e_ctw1500-1abf4f67.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/drrg/20210511_234719.log) |
@@ -0,0 +1,110 @@

_base_ = [
    '../../_base_/schedules/schedule_1200e.py',
    '../../_base_/default_runtime.py'
]
model = dict(
    type='DRRG',
    pretrained='torchvision://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=-1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='caffe'),
    neck=dict(
        type='FPN_UNet', in_channels=[256, 512, 1024, 2048], out_channels=32),
    bbox_head=dict(
        type='DRRGHead',
        in_channels=32,
        text_region_thr=0.3,
        center_region_thr=0.4,
        link_thr=0.80,
        loss=dict(type='DRRGLoss')))
train_cfg = None
test_cfg = None

dataset_type = 'IcdarDataset'
data_root = 'data/ctw1500/'

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='LoadTextAnnotations',
        with_bbox=True,
        with_mask=True,
        poly2mask=False),
    dict(type='ColorJitter', brightness=32.0 / 255, saturation=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='RandomScaling', size=800, scale=(0.75, 2.5)),
    dict(
        type='RandomCropFlip', crop_ratio=0.5, iter_num=1, min_area_ratio=0.2),
    dict(
        type='RandomCropPolyInstances',
        instance_key='gt_masks',
        crop_ratio=0.8,
        min_side_ratio=0.3),
    dict(
        type='RandomRotatePolyInstances',
        rotate_ratio=0.5,
        max_angle=60,
        pad_with_fixed_color=False),
    dict(type='SquareResizePad', target_size=800, pad_ratio=0.6),
    dict(type='RandomFlip', flip_ratio=0.5, direction='horizontal'),
    dict(type='DRRGTargets'),
    dict(type='Pad', size_divisor=32),
    dict(
        type='CustomFormatBundle',
        keys=[
            'gt_text_mask', 'gt_center_region_mask', 'gt_mask',
            'gt_top_height_map', 'gt_bot_height_map', 'gt_sin_map',
            'gt_cos_map', 'gt_comp_attribs'
        ],
        visualize=dict(flag=False, boundary_key='gt_text_mask')),
    dict(
        type='Collect',
        keys=[
            'img', 'gt_text_mask', 'gt_center_region_mask', 'gt_mask',
            'gt_top_height_map', 'gt_bot_height_map', 'gt_sin_map',
            'gt_cos_map', 'gt_comp_attribs'
        ])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1024, 640),
        flip=False,
        transforms=[
            dict(type='Resize', img_scale=(1024, 640), keep_ratio=True),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=4,
    workers_per_gpu=4,
    train=dict(
        type=dataset_type,
        ann_file=f'{data_root}/instances_training.json',
        img_prefix=f'{data_root}/imgs',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=f'{data_root}/instances_test.json',
        img_prefix=f'{data_root}/imgs',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=f'{data_root}/instances_test.json',
        img_prefix=f'{data_root}/imgs',
        pipeline=test_pipeline))

evaluation = dict(interval=20, metric='hmean-iou')
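For reference, a minimal sketch (not part of this commit) of how a config like the one above is typically loaded and built. `Config.fromfile` and `build_detector` are the standard MMCV/MMDetection entry points of this era, and the config path is assumed to follow the repo layout referenced in the README; treat this as an illustrative assumption rather than part of the PR:

```python
# Illustrative sketch: build the DRRG detector from the config above.
# Assumes mmocr (and its mmdet/mmcv dependencies) are installed.
from mmcv import Config
from mmdet.models import build_detector

import mmocr.models  # noqa: F401  (registers DRRG, DRRGHead, DRRGLoss, ...)

cfg = Config.fromfile(
    'configs/textdet/drrg/drrg_r50_fpn_unet_1200e_ctw1500.py')

# train_cfg / test_cfg are top-level entries (both None) in this config.
model = build_detector(
    cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
print(type(model).__name__)  # DRRG
```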
@@ -15,7 +15,7 @@ model = dict(
         norm_eval=True,
         style='caffe'),
     neck=dict(
-        type='FPN_UNET', in_channels=[256, 512, 1024, 2048], out_channels=32),
+        type='FPN_UNet', in_channels=[256, 512, 1024, 2048], out_channels=32),
     bbox_head=dict(
         type='TextSnakeHead',
         in_channels=32,
@@ -96,18 +96,18 @@ data = dict(
     workers_per_gpu=4,
     train=dict(
         type=dataset_type,
-        ann_file=data_root + '/instances_training.json',
-        img_prefix=data_root + '/imgs',
+        ann_file=f'{data_root}/instances_training.json',
+        img_prefix=f'{data_root}/imgs',
         pipeline=train_pipeline),
     val=dict(
         type=dataset_type,
-        ann_file=data_root + '/instances_test.json',
-        img_prefix=data_root + '/imgs',
+        ann_file=f'{data_root}/instances_test.json',
+        img_prefix=f'{data_root}/imgs',
         pipeline=test_pipeline),
     test=dict(
         type=dataset_type,
-        ann_file=data_root + '/instances_test.json',
-        img_prefix=data_root + '/imgs',
+        ann_file=f'{data_root}/instances_test.json',
+        img_prefix=f'{data_root}/imgs',
         pipeline=test_pipeline))

 evaluation = dict(interval=10, metric='hmean-iou')
@@ -1,5 +1,6 @@
 from .base_textdet_targets import BaseTextDetTargets
 from .dbnet_targets import DBNetTargets
+from .drrg_targets import DRRGTargets
 from .fcenet_targets import FCENetTargets
 from .panet_targets import PANetTargets
 from .psenet_targets import PSENetTargets
@@ -7,5 +8,5 @@ from .textsnake_targets import TextSnakeTargets

 __all__ = [
     'BaseTextDetTargets', 'PANetTargets', 'PSENetTargets', 'DBNetTargets',
-    'FCENetTargets', 'TextSnakeTargets'
+    'FCENetTargets', 'TextSnakeTargets', 'DRRGTargets'
 ]
@ -0,0 +1,533 @@
|
|||
import cv2
|
||||
import numpy as np
|
||||
from lanms import merge_quadrangle_n9 as la_nms
|
||||
from numpy.linalg import norm
|
||||
|
||||
import mmocr.utils.check_argument as check_argument
|
||||
from mmdet.core import BitmapMasks
|
||||
from mmdet.datasets.builder import PIPELINES
|
||||
from .textsnake_targets import TextSnakeTargets
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class DRRGTargets(TextSnakeTargets):
|
||||
"""Generate the ground truth targets of DRRG: Deep Relational Reasoning
|
||||
Graph Network for Arbitrary Shape Text Detection.
|
||||
|
||||
[https://arxiv.org/abs/2003.07493]. This code was partially adapted from
|
||||
https://github.com/GXYM/DRRG licensed under the MIT license.
|
||||
|
||||
Args:
|
||||
orientation_thr (float): The threshold for distinguishing between
|
||||
head edge and tail edge among the horizontal and vertical edges
|
||||
of a quadrangle.
|
||||
resample_step (float): The step size for resampling the text center
|
||||
line.
|
||||
num_min_comps (int): The minimum number of text components, which
|
||||
should be larger than k_hop1 mentioned in paper.
|
||||
num_max_comps (int): The maximum number of text components.
|
||||
min_width (float): The minimum width of text components.
|
||||
max_width (float): The maximum width of text components.
|
||||
center_region_shrink_ratio (float): The shrink ratio of text center
|
||||
regions.
|
||||
comp_shrink_ratio (float): The shrink ratio of text components.
|
||||
comp_w_h_ratio (float): The width to height ratio of text components.
|
||||
min_rand_half_height(float): The minimum half-height of random text
|
||||
components.
|
||||
max_rand_half_height (float): The maximum half-height of random
|
||||
text components.
|
||||
jitter_level (float): The jitter level of text component geometric
|
||||
features.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
orientation_thr=2.0,
|
||||
resample_step=8.0,
|
||||
num_min_comps=9,
|
||||
num_max_comps=600,
|
||||
min_width=8.0,
|
||||
max_width=24.0,
|
||||
center_region_shrink_ratio=0.3,
|
||||
comp_shrink_ratio=1.0,
|
||||
comp_w_h_ratio=0.3,
|
||||
text_comp_nms_thr=0.25,
|
||||
min_rand_half_height=8.0,
|
||||
max_rand_half_height=24.0,
|
||||
jitter_level=0.2):
|
||||
|
||||
super().__init__()
|
||||
self.orientation_thr = orientation_thr
|
||||
self.resample_step = resample_step
|
||||
self.num_max_comps = num_max_comps
|
||||
self.num_min_comps = num_min_comps
|
||||
self.min_width = min_width
|
||||
self.max_width = max_width
|
||||
self.center_region_shrink_ratio = center_region_shrink_ratio
|
||||
self.comp_shrink_ratio = comp_shrink_ratio
|
||||
self.comp_w_h_ratio = comp_w_h_ratio
|
||||
self.text_comp_nms_thr = text_comp_nms_thr
|
||||
self.min_rand_half_height = min_rand_half_height
|
||||
self.max_rand_half_height = max_rand_half_height
|
||||
self.jitter_level = jitter_level
|
||||
|
||||
def dist_point2line(self, point, line):
|
||||
|
||||
assert isinstance(line, tuple)
|
||||
point1, point2 = line
|
||||
d = abs(np.cross(point2 - point1, point - point1)) / (
|
||||
norm(point2 - point1) + 1e-8)
|
||||
return d
|
||||
|
||||
def draw_center_region_maps(self, top_line, bot_line, center_line,
|
||||
center_region_mask, top_height_map,
|
||||
bot_height_map, sin_map, cos_map,
|
||||
region_shrink_ratio):
|
||||
"""Draw attributes of text components on text center regions.
|
||||
|
||||
Args:
|
||||
top_line (ndarray): The points composing the top side lines of text
|
||||
polygons.
|
||||
bot_line (ndarray): The points composing bottom side lines of text
|
||||
polygons.
|
||||
center_line (ndarray): The points composing the center lines of
|
||||
text instances.
|
||||
center_region_mask (ndarray): The text center region mask.
|
||||
top_height_map (ndarray): The map on which the distance from points
|
||||
to top side lines will be drawn for each pixel in text center
|
||||
regions.
|
||||
bot_height_map (ndarray): The map on which the distance from points
|
||||
to bottom side lines will be drawn for each pixel in text
|
||||
center regions.
|
||||
sin_map (ndarray): The map of vector_sin(top_point - bot_point)
|
||||
that will be drawn on text center regions.
|
||||
cos_map (ndarray): The map of vector_cos(top_point - bot_point)
|
||||
that will be drawn on text center regions.
|
||||
region_shrink_ratio (float): The shrink ratio of text center
|
||||
regions.
|
||||
"""
|
||||
|
||||
assert top_line.shape == bot_line.shape == center_line.shape
|
||||
assert (center_region_mask.shape == top_height_map.shape ==
|
||||
bot_height_map.shape == sin_map.shape == cos_map.shape)
|
||||
assert isinstance(region_shrink_ratio, float)
|
||||
|
||||
h, w = center_region_mask.shape
|
||||
for i in range(0, len(center_line) - 1):
|
||||
|
||||
top_mid_point = (top_line[i] + top_line[i + 1]) / 2
|
||||
bot_mid_point = (bot_line[i] + bot_line[i + 1]) / 2
|
||||
|
||||
sin_theta = self.vector_sin(top_mid_point - bot_mid_point)
|
||||
cos_theta = self.vector_cos(top_mid_point - bot_mid_point)
|
||||
|
||||
tl = center_line[i] + (top_line[i] -
|
||||
center_line[i]) * region_shrink_ratio
|
||||
tr = center_line[i + 1] + (
|
||||
top_line[i + 1] - center_line[i + 1]) * region_shrink_ratio
|
||||
br = center_line[i + 1] + (
|
||||
bot_line[i + 1] - center_line[i + 1]) * region_shrink_ratio
|
||||
bl = center_line[i] + (bot_line[i] -
|
||||
center_line[i]) * region_shrink_ratio
|
||||
current_center_box = np.vstack([tl, tr, br, bl]).astype(np.int32)
|
||||
|
||||
cv2.fillPoly(center_region_mask, [current_center_box], color=1)
|
||||
cv2.fillPoly(sin_map, [current_center_box], color=sin_theta)
|
||||
cv2.fillPoly(cos_map, [current_center_box], color=cos_theta)
|
||||
|
||||
current_center_box[:, 0] = np.clip(current_center_box[:, 0], 0,
|
||||
w - 1)
|
||||
current_center_box[:, 1] = np.clip(current_center_box[:, 1], 0,
|
||||
h - 1)
|
||||
min_coord = np.min(current_center_box, axis=0).astype(np.int32)
|
||||
max_coord = np.max(current_center_box, axis=0).astype(np.int32)
|
||||
current_center_box = current_center_box - min_coord
|
||||
box_sz = (max_coord - min_coord + 1)
|
||||
|
||||
center_box_mask = np.zeros((box_sz[1], box_sz[0]), dtype=np.uint8)
|
||||
cv2.fillPoly(center_box_mask, [current_center_box], color=1)
|
||||
|
||||
inds = np.argwhere(center_box_mask > 0)
|
||||
inds = inds + (min_coord[1], min_coord[0])
|
||||
inds_xy = np.fliplr(inds)
|
||||
top_height_map[(inds[:, 0], inds[:, 1])] = self.dist_point2line(
|
||||
inds_xy, (top_line[i], top_line[i + 1]))
|
||||
bot_height_map[(inds[:, 0], inds[:, 1])] = self.dist_point2line(
|
||||
inds_xy, (bot_line[i], bot_line[i + 1]))
|
||||
|
||||
def generate_center_mask_attrib_maps(self, img_size, text_polys):
|
||||
"""Generate text center region masks and geometric attribute maps.
|
||||
|
||||
Args:
|
||||
img_size (tuple): The image size (height, width).
|
||||
text_polys (list[list[ndarray]]): The list of text polygons.
|
||||
|
||||
Returns:
|
||||
center_lines (list): The list of text center lines.
|
||||
center_region_mask (ndarray): The text center region mask.
|
||||
top_height_map (ndarray): The map on which the distance from points
|
||||
to top side lines will be drawn for each pixel in text center
|
||||
regions.
|
||||
bot_height_map (ndarray): The map on which the distance from points
|
||||
to bottom side lines will be drawn for each pixel in text
|
||||
center regions.
|
||||
sin_map (ndarray): The sin(theta) map where theta is the angle
|
||||
between vector (top point - bottom point) and vector (1, 0).
|
||||
cos_map (ndarray): The cos(theta) map where theta is the angle
|
||||
between vector (top point - bottom point) and vector (1, 0).
|
||||
"""
|
||||
|
||||
assert isinstance(img_size, tuple)
|
||||
assert check_argument.is_2dlist(text_polys)
|
||||
|
||||
h, w = img_size
|
||||
|
||||
center_lines = []
|
||||
center_region_mask = np.zeros((h, w), np.uint8)
|
||||
top_height_map = np.zeros((h, w), dtype=np.float32)
|
||||
bot_height_map = np.zeros((h, w), dtype=np.float32)
|
||||
sin_map = np.zeros((h, w), dtype=np.float32)
|
||||
cos_map = np.zeros((h, w), dtype=np.float32)
|
||||
|
||||
for poly in text_polys:
|
||||
assert len(poly) == 1
|
||||
polygon_points = poly[0].reshape(-1, 2)
|
||||
_, _, top_line, bot_line = self.reorder_poly_edge(polygon_points)
|
||||
resampled_top_line, resampled_bot_line = self.resample_sidelines(
|
||||
top_line, bot_line, self.resample_step)
|
||||
resampled_bot_line = resampled_bot_line[::-1]
|
||||
center_line = (resampled_top_line + resampled_bot_line) / 2
|
||||
|
||||
if self.vector_slope(center_line[-1] - center_line[0]) > 2:
|
||||
if (center_line[-1] - center_line[0])[1] < 0:
|
||||
center_line = center_line[::-1]
|
||||
resampled_top_line = resampled_top_line[::-1]
|
||||
resampled_bot_line = resampled_bot_line[::-1]
|
||||
else:
|
||||
if (center_line[-1] - center_line[0])[0] < 0:
|
||||
center_line = center_line[::-1]
|
||||
resampled_top_line = resampled_top_line[::-1]
|
||||
resampled_bot_line = resampled_bot_line[::-1]
|
||||
|
||||
line_head_shrink_len = np.clip(
|
||||
(norm(top_line[0] - bot_line[0]) * self.comp_w_h_ratio),
|
||||
self.min_width, self.max_width) / 2
|
||||
line_tail_shrink_len = np.clip(
|
||||
(norm(top_line[-1] - bot_line[-1]) * self.comp_w_h_ratio),
|
||||
self.min_width, self.max_width) / 2
|
||||
num_head_shrink = int(line_head_shrink_len // self.resample_step)
|
||||
num_tail_shrink = int(line_tail_shrink_len // self.resample_step)
|
||||
if len(center_line) > num_head_shrink + num_tail_shrink + 2:
|
||||
center_line = center_line[num_head_shrink:len(center_line) -
|
||||
num_tail_shrink]
|
||||
resampled_top_line = resampled_top_line[
|
||||
num_head_shrink:len(resampled_top_line) - num_tail_shrink]
|
||||
resampled_bot_line = resampled_bot_line[
|
||||
num_head_shrink:len(resampled_bot_line) - num_tail_shrink]
|
||||
center_lines.append(center_line.astype(np.int32))
|
||||
|
||||
self.draw_center_region_maps(resampled_top_line,
|
||||
resampled_bot_line, center_line,
|
||||
center_region_mask, top_height_map,
|
||||
bot_height_map, sin_map, cos_map,
|
||||
self.center_region_shrink_ratio)
|
||||
|
||||
return (center_lines, center_region_mask, top_height_map,
|
||||
bot_height_map, sin_map, cos_map)
|
||||
|
||||
def generate_rand_comp_attribs(self, num_rand_comps, center_sample_mask):
|
||||
"""Generate random text components and their attributes to ensure the
|
||||
the number of text components in an image is larger than k_hop1, which
|
||||
is the number of one hop neighbors in KNN graph.
|
||||
|
||||
Args:
|
||||
num_rand_comps (int): The number of random text components.
|
||||
center_sample_mask (ndarray): The region mask for sampling text
|
||||
component centers.
|
||||
|
||||
Returns:
|
||||
rand_comp_attribs (ndarray): The random text component attributes
|
||||
(x, y, h, w, cos, sin, comp_label=0).
|
||||
"""
|
||||
|
||||
assert isinstance(num_rand_comps, int)
|
||||
assert num_rand_comps > 0
|
||||
assert center_sample_mask.ndim == 2
|
||||
|
||||
h, w = center_sample_mask.shape
|
||||
|
||||
max_rand_half_height = self.max_rand_half_height
|
||||
min_rand_half_height = self.min_rand_half_height
|
||||
max_rand_height = max_rand_half_height * 2
|
||||
max_rand_width = np.clip(max_rand_height * self.comp_w_h_ratio,
|
||||
self.min_width, self.max_width)
|
||||
margin = int(
|
||||
np.sqrt((max_rand_height / 2)**2 + (max_rand_width / 2)**2)) + 1
|
||||
|
||||
if 2 * margin + 1 > min(h, w):
|
||||
|
||||
assert min(h, w) > (np.sqrt(2) * (self.min_width + 1))
|
||||
max_rand_half_height = max(min(h, w) / 4, self.min_width / 2 + 1)
|
||||
min_rand_half_height = max(max_rand_half_height / 4,
|
||||
self.min_width / 2)
|
||||
|
||||
max_rand_height = max_rand_half_height * 2
|
||||
max_rand_width = np.clip(max_rand_height * self.comp_w_h_ratio,
|
||||
self.min_width, self.max_width)
|
||||
margin = int(
|
||||
np.sqrt((max_rand_height / 2)**2 +
|
||||
(max_rand_width / 2)**2)) + 1
|
||||
|
||||
inner_center_sample_mask = np.zeros_like(center_sample_mask)
|
||||
inner_center_sample_mask[margin:h - margin, margin:w - margin] = \
|
||||
center_sample_mask[margin:h - margin, margin:w - margin]
|
||||
kernel_size = int(np.clip(max_rand_half_height, 7, 21))
|
||||
inner_center_sample_mask = cv2.erode(
|
||||
inner_center_sample_mask,
|
||||
np.ones((kernel_size, kernel_size), np.uint8))
|
||||
|
||||
center_candidates = np.argwhere(inner_center_sample_mask > 0)
|
||||
num_center_candidates = len(center_candidates)
|
||||
sample_inds = np.random.choice(num_center_candidates, num_rand_comps)
|
||||
rand_centers = center_candidates[sample_inds]
|
||||
|
||||
rand_top_height = np.random.randint(
|
||||
min_rand_half_height,
|
||||
max_rand_half_height,
|
||||
size=(len(rand_centers), 1))
|
||||
rand_bot_height = np.random.randint(
|
||||
min_rand_half_height,
|
||||
max_rand_half_height,
|
||||
size=(len(rand_centers), 1))
|
||||
|
||||
rand_cos = 2 * np.random.random(size=(len(rand_centers), 1)) - 1
|
||||
rand_sin = 2 * np.random.random(size=(len(rand_centers), 1)) - 1
|
||||
scale = np.sqrt(1.0 / (rand_cos**2 + rand_sin**2 + 1e-8))
|
||||
rand_cos = rand_cos * scale
|
||||
rand_sin = rand_sin * scale
|
||||
|
||||
height = (rand_top_height + rand_bot_height)
|
||||
width = np.clip(height * self.comp_w_h_ratio, self.min_width,
|
||||
self.max_width)
|
||||
|
||||
rand_comp_attribs = np.hstack([
|
||||
rand_centers[:, ::-1], height, width, rand_cos, rand_sin,
|
||||
np.zeros_like(rand_sin)
|
||||
]).astype(np.float32)
|
||||
|
||||
return rand_comp_attribs
|
||||
|
||||
def jitter_comp_attribs(self, comp_attribs, jitter_level):
|
||||
"""Jitter text components attributes.
|
||||
|
||||
Args:
|
||||
comp_attribs (ndarray): The text component attributes.
|
||||
jitter_level (float): The jitter level of text components
|
||||
attributes.
|
||||
|
||||
Returns:
|
||||
jittered_comp_attribs (ndarray): The jittered text component
|
||||
attributes (x, y, h, w, cos, sin, comp_label).
|
||||
"""
|
||||
|
||||
assert comp_attribs.shape[1] == 7
|
||||
assert comp_attribs.shape[0] > 0
|
||||
assert isinstance(jitter_level, float)
|
||||
|
||||
x = comp_attribs[:, 0].reshape((-1, 1))
|
||||
y = comp_attribs[:, 1].reshape((-1, 1))
|
||||
h = comp_attribs[:, 2].reshape((-1, 1))
|
||||
w = comp_attribs[:, 3].reshape((-1, 1))
|
||||
cos = comp_attribs[:, 4].reshape((-1, 1))
|
||||
sin = comp_attribs[:, 5].reshape((-1, 1))
|
||||
comp_labels = comp_attribs[:, 6].reshape((-1, 1))
|
||||
|
||||
x += (np.random.random(size=(len(comp_attribs), 1)) -
|
||||
0.5) * (h * np.abs(cos) + w * np.abs(sin)) * jitter_level
|
||||
y += (np.random.random(size=(len(comp_attribs), 1)) -
|
||||
0.5) * (h * np.abs(sin) + w * np.abs(cos)) * jitter_level
|
||||
|
||||
h += (np.random.random(size=(len(comp_attribs), 1)) -
|
||||
0.5) * h * jitter_level
|
||||
w += (np.random.random(size=(len(comp_attribs), 1)) -
|
||||
0.5) * w * jitter_level
|
||||
|
||||
cos += (np.random.random(size=(len(comp_attribs), 1)) -
|
||||
0.5) * 2 * jitter_level
|
||||
sin += (np.random.random(size=(len(comp_attribs), 1)) -
|
||||
0.5) * 2 * jitter_level
|
||||
|
||||
scale = np.sqrt(1.0 / (cos**2 + sin**2 + 1e-8))
|
||||
cos = cos * scale
|
||||
sin = sin * scale
|
||||
|
||||
jittered_comp_attribs = np.hstack([x, y, h, w, cos, sin, comp_labels])
|
||||
|
||||
return jittered_comp_attribs
|
||||
|
||||
def generate_comp_attribs(self, center_lines, text_mask,
|
||||
center_region_mask, top_height_map,
|
||||
bot_height_map, sin_map, cos_map):
|
||||
"""Generate text component attributes.
|
||||
|
||||
Args:
|
||||
center_lines (list[ndarray]): The list of text center lines.
|
||||
text_mask (ndarray): The text region mask.
|
||||
center_region_mask (ndarray): The text center region mask.
|
||||
top_height_map (ndarray): The map on which the distance from points
|
||||
to top side lines will be drawn for each pixel in text center
|
||||
regions.
|
||||
bot_height_map (ndarray): The map on which the distance from points
|
||||
to bottom side lines will be drawn for each pixel in text
|
||||
center regions.
|
||||
sin_map (ndarray): The sin(theta) map where theta is the angle
|
||||
between vector (top point - bottom point) and vector (1, 0).
|
||||
cos_map (ndarray): The cos(theta) map where theta is the angle
|
||||
between vector (top point - bottom point) and vector (1, 0).
|
||||
|
||||
Returns:
|
||||
pad_comp_attribs (ndarray): The padded text component attributes
|
||||
of a fixed size.
|
||||
"""
|
||||
|
||||
assert isinstance(center_lines, list)
|
||||
assert (text_mask.shape == center_region_mask.shape ==
|
||||
top_height_map.shape == bot_height_map.shape == sin_map.shape
|
||||
== cos_map.shape)
|
||||
|
||||
center_lines_mask = np.zeros_like(center_region_mask)
|
||||
cv2.polylines(center_lines_mask, center_lines, 0, 1, 1)
|
||||
center_lines_mask = center_lines_mask * center_region_mask
|
||||
comp_centers = np.argwhere(center_lines_mask > 0)
|
||||
|
||||
y = comp_centers[:, 0]
|
||||
x = comp_centers[:, 1]
|
||||
|
||||
top_height = top_height_map[y, x].reshape(
|
||||
(-1, 1)) * self.comp_shrink_ratio
|
||||
bot_height = bot_height_map[y, x].reshape(
|
||||
(-1, 1)) * self.comp_shrink_ratio
|
||||
sin = sin_map[y, x].reshape((-1, 1))
|
||||
cos = cos_map[y, x].reshape((-1, 1))
|
||||
|
||||
top_mid_points = comp_centers + np.hstack(
|
||||
[top_height * sin, top_height * cos])
|
||||
bot_mid_points = comp_centers - np.hstack(
|
||||
[bot_height * sin, bot_height * cos])
|
||||
|
||||
width = (top_height + bot_height) * self.comp_w_h_ratio
|
||||
width = np.clip(width, self.min_width, self.max_width)
|
||||
r = width / 2
|
||||
|
||||
tl = top_mid_points[:, ::-1] - np.hstack([-r * sin, r * cos])
|
||||
tr = top_mid_points[:, ::-1] + np.hstack([-r * sin, r * cos])
|
||||
br = bot_mid_points[:, ::-1] + np.hstack([-r * sin, r * cos])
|
||||
bl = bot_mid_points[:, ::-1] - np.hstack([-r * sin, r * cos])
|
||||
text_comps = np.hstack([tl, tr, br, bl]).astype(np.float32)
|
||||
|
||||
score = np.ones((text_comps.shape[0], 1), dtype=np.float32)
|
||||
text_comps = np.hstack([text_comps, score])
|
||||
text_comps = la_nms(text_comps, self.text_comp_nms_thr)
|
||||
|
||||
if text_comps.shape[0] >= 1:
|
||||
img_h, img_w = center_region_mask.shape
|
||||
text_comps[:, 0:8:2] = np.clip(text_comps[:, 0:8:2], 0, img_w - 1)
|
||||
text_comps[:, 1:8:2] = np.clip(text_comps[:, 1:8:2], 0, img_h - 1)
|
||||
|
||||
comp_centers = np.mean(
|
||||
text_comps[:, 0:8].reshape((-1, 4, 2)),
|
||||
axis=1).astype(np.int32)
|
||||
x = comp_centers[:, 0]
|
||||
y = comp_centers[:, 1]
|
||||
|
||||
height = (top_height_map[y, x] + bot_height_map[y, x]).reshape(
|
||||
(-1, 1))
|
||||
width = np.clip(height * self.comp_w_h_ratio, self.min_width,
|
||||
self.max_width)
|
||||
|
||||
cos = cos_map[y, x].reshape((-1, 1))
|
||||
sin = sin_map[y, x].reshape((-1, 1))
|
||||
|
||||
_, comp_label_mask = cv2.connectedComponents(
|
||||
center_region_mask, connectivity=8)
|
||||
comp_labels = comp_label_mask[y, x].reshape(
|
||||
(-1, 1)).astype(np.float32)
|
||||
|
||||
x = x.reshape((-1, 1)).astype(np.float32)
|
||||
y = y.reshape((-1, 1)).astype(np.float32)
|
||||
comp_attribs = np.hstack(
|
||||
[x, y, height, width, cos, sin, comp_labels])
|
||||
comp_attribs = self.jitter_comp_attribs(comp_attribs,
|
||||
self.jitter_level)
|
||||
|
||||
if comp_attribs.shape[0] < self.num_min_comps:
|
||||
num_rand_comps = self.num_min_comps - comp_attribs.shape[0]
|
||||
rand_comp_attribs = self.generate_rand_comp_attribs(
|
||||
num_rand_comps, 1 - text_mask)
|
||||
comp_attribs = np.vstack([comp_attribs, rand_comp_attribs])
|
||||
else:
|
||||
comp_attribs = self.generate_rand_comp_attribs(
|
||||
self.num_min_comps, 1 - text_mask)
|
||||
|
||||
num_comps = (
|
||||
np.ones((comp_attribs.shape[0], 1), dtype=np.float32) *
|
||||
comp_attribs.shape[0])
|
||||
comp_attribs = np.hstack([num_comps, comp_attribs])
|
||||
|
||||
if comp_attribs.shape[0] > self.num_max_comps:
|
||||
comp_attribs = comp_attribs[:self.num_max_comps, :]
|
||||
comp_attribs[:, 0] = self.num_max_comps
|
||||
|
||||
pad_comp_attribs = np.zeros(
|
||||
(self.num_max_comps, comp_attribs.shape[1]), dtype=np.float32)
|
||||
pad_comp_attribs[:comp_attribs.shape[0], :] = comp_attribs
|
||||
|
||||
return pad_comp_attribs
|
||||
|
||||
def generate_targets(self, results):
|
||||
"""Generate the gt targets for DRRG.
|
||||
|
||||
Args:
|
||||
results (dict): The input result dictionary.
|
||||
|
||||
Returns:
|
||||
results (dict): The output result dictionary.
|
||||
"""
|
||||
|
||||
assert isinstance(results, dict)
|
||||
|
||||
polygon_masks = results['gt_masks'].masks
|
||||
polygon_masks_ignore = results['gt_masks_ignore'].masks
|
||||
|
||||
h, w, _ = results['img_shape']
|
||||
|
||||
gt_text_mask = self.generate_text_region_mask((h, w), polygon_masks)
|
||||
gt_mask = self.generate_effective_mask((h, w), polygon_masks_ignore)
|
||||
(center_lines, gt_center_region_mask, gt_top_height_map,
|
||||
gt_bot_height_map, gt_sin_map,
|
||||
gt_cos_map) = self.generate_center_mask_attrib_maps((h, w),
|
||||
polygon_masks)
|
||||
|
||||
gt_comp_attribs = self.generate_comp_attribs(center_lines,
|
||||
gt_text_mask,
|
||||
gt_center_region_mask,
|
||||
gt_top_height_map,
|
||||
gt_bot_height_map,
|
||||
gt_sin_map, gt_cos_map)
|
||||
|
||||
results['mask_fields'].clear() # rm gt_masks encoded by polygons
|
||||
mapping = {
|
||||
'gt_text_mask': gt_text_mask,
|
||||
'gt_center_region_mask': gt_center_region_mask,
|
||||
'gt_mask': gt_mask,
|
||||
'gt_top_height_map': gt_top_height_map,
|
||||
'gt_bot_height_map': gt_bot_height_map,
|
||||
'gt_sin_map': gt_sin_map,
|
||||
'gt_cos_map': gt_cos_map
|
||||
}
|
||||
for key, value in mapping.items():
|
||||
value = value if isinstance(value, list) else [value]
|
||||
results[key] = BitmapMasks(value, h, w)
|
||||
results['mask_fields'].append(key)
|
||||
|
||||
results['gt_comp_attribs'] = gt_comp_attribs
|
||||
return results
|
|
@@ -1,4 +1,5 @@
 from .db_head import DBHead
+from .drrg_head import DRRGHead
 from .fce_head import FCEHead
 from .head_mixin import HeadMixin
 from .pan_head import PANHead
@@ -6,5 +7,5 @@ from .pse_head import PSEHead
 from .textsnake_head import TextSnakeHead

 __all__ = [
-    'PSEHead', 'PANHead', 'DBHead', 'FCEHead', 'HeadMixin', 'TextSnakeHead'
+    'PSEHead', 'PANHead', 'DBHead', 'FCEHead', 'HeadMixin', 'TextSnakeHead',
+    'DRRGHead'
 ]
@ -0,0 +1,219 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import normal_init
|
||||
|
||||
from mmdet.models.builder import HEADS, build_loss
|
||||
from mmocr.models.textdet.modules import GCN, LocalGraphs, ProposalLocalGraphs
|
||||
from mmocr.models.textdet.postprocess import decode
|
||||
from mmocr.utils import check_argument
|
||||
from .head_mixin import HeadMixin
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
class DRRGHead(HeadMixin, nn.Module):
|
||||
"""The class for DRRG head: Deep Relational Reasoning Graph Network for
|
||||
Arbitrary Shape Text Detection.
|
||||
|
||||
[https://arxiv.org/abs/2003.07493]
|
||||
|
||||
Args:
|
||||
k_at_hops (tuple(int)): The number of i-hop neighbors, i = 1, 2.
|
||||
num_adjacent_linkages (int): The number of linkages when constructing
|
||||
adjacent matrix.
|
||||
node_geo_feat_len (int): The length of embedded geometric feature
|
||||
vector of a component.
|
||||
pooling_scale (float): The spatial scale of rotated RoI-Align.
|
||||
pooling_output_size (tuple(int)): The output size of rotated RoI-Align.
|
||||
nms_thr (float): The locality-aware NMS threshold of text components.
|
||||
min_width (float): The minimum width of text components.
|
||||
max_width (float): The maximum width of text components.
|
||||
comp_shrink_ratio (float): The shrink ratio of text components.
|
||||
comp_ratio (float): The reciprocal of aspect ratio of text components.
|
||||
comp_score_thr (float): The score threshold of text components.
|
||||
text_region_thr (float): The threshold for text region probability map.
|
||||
center_region_thr (float): The threshold for text center region
|
||||
probability map.
|
||||
center_region_area_thr (int): The threshold for filtering small-sized
|
||||
text center region.
|
||||
local_graph_thr (float): The threshold to filter identical local
|
||||
graphs.
|
||||
link_thr (float): The threshold for connected component search.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
k_at_hops=(8, 4),
|
||||
num_adjacent_linkages=3,
|
||||
node_geo_feat_len=120,
|
||||
pooling_scale=1.0,
|
||||
pooling_output_size=(4, 3),
|
||||
nms_thr=0.3,
|
||||
min_width=8.0,
|
||||
max_width=24.0,
|
||||
comp_shrink_ratio=1.03,
|
||||
comp_ratio=0.4,
|
||||
comp_score_thr=0.3,
|
||||
text_region_thr=0.2,
|
||||
center_region_thr=0.2,
|
||||
center_region_area_thr=50,
|
||||
local_graph_thr=0.7,
|
||||
link_thr=0.85,
|
||||
loss=dict(type='DRRGLoss'),
|
||||
train_cfg=None,
|
||||
test_cfg=None):
|
||||
super().__init__()
|
||||
|
||||
assert isinstance(in_channels, int)
|
||||
assert isinstance(k_at_hops, tuple)
|
||||
assert isinstance(num_adjacent_linkages, int)
|
||||
assert isinstance(node_geo_feat_len, int)
|
||||
assert isinstance(pooling_scale, float)
|
||||
assert isinstance(pooling_output_size, tuple)
|
||||
assert isinstance(comp_shrink_ratio, float)
|
||||
assert isinstance(nms_thr, float)
|
||||
assert isinstance(min_width, float)
|
||||
assert isinstance(max_width, float)
|
||||
assert isinstance(comp_ratio, float)
|
||||
assert isinstance(comp_score_thr, float)
|
||||
assert isinstance(text_region_thr, float)
|
||||
assert isinstance(center_region_thr, float)
|
||||
assert isinstance(center_region_area_thr, int)
|
||||
assert isinstance(local_graph_thr, float)
|
||||
assert isinstance(link_thr, float)
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = 6
|
||||
self.downsample_ratio = 1.0
|
||||
self.k_at_hops = k_at_hops
|
||||
self.num_adjacent_linkages = num_adjacent_linkages
|
||||
self.node_geo_feat_len = node_geo_feat_len
|
||||
self.pooling_scale = pooling_scale
|
||||
self.pooling_output_size = pooling_output_size
|
||||
self.comp_shrink_ratio = comp_shrink_ratio
|
||||
self.nms_thr = nms_thr
|
||||
self.min_width = min_width
|
||||
self.max_width = max_width
|
||||
self.comp_ratio = comp_ratio
|
||||
self.comp_score_thr = comp_score_thr
|
||||
self.text_region_thr = text_region_thr
|
||||
self.center_region_thr = center_region_thr
|
||||
self.center_region_area_thr = center_region_area_thr
|
||||
self.local_graph_thr = local_graph_thr
|
||||
self.link_thr = link_thr
|
||||
self.loss_module = build_loss(loss)
|
||||
self.train_cfg = train_cfg
|
||||
self.test_cfg = test_cfg
|
||||
|
||||
self.out_conv = nn.Conv2d(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.init_weights()
|
||||
|
||||
self.graph_train = LocalGraphs(self.k_at_hops,
|
||||
self.num_adjacent_linkages,
|
||||
self.node_geo_feat_len,
|
||||
self.pooling_scale,
|
||||
self.pooling_output_size,
|
||||
self.local_graph_thr)
|
||||
|
||||
self.graph_test = ProposalLocalGraphs(
|
||||
self.k_at_hops, self.num_adjacent_linkages, self.node_geo_feat_len,
|
||||
self.pooling_scale, self.pooling_output_size, self.nms_thr,
|
||||
self.min_width, self.max_width, self.comp_shrink_ratio,
|
||||
self.comp_ratio, self.comp_score_thr, self.text_region_thr,
|
||||
self.center_region_thr, self.center_region_area_thr)
|
||||
|
||||
pool_w, pool_h = self.pooling_output_size
|
||||
node_feat_len = (pool_w * pool_h) * (
|
||||
self.in_channels + self.out_channels) + self.node_geo_feat_len
|
||||
self.gcn = GCN(node_feat_len)
|
||||
|
||||
def init_weights(self):
|
||||
normal_init(self.out_conv, mean=0, std=0.01)
|
||||
|
||||
def forward(self, inputs, gt_comp_attribs):
|
||||
|
||||
pred_maps = self.out_conv(inputs)
|
||||
feat_maps = torch.cat([inputs, pred_maps], dim=1)
|
||||
node_feats, adjacent_matrices, knn_inds, gt_labels = self.graph_train(
|
||||
feat_maps, np.stack(gt_comp_attribs))
|
||||
|
||||
gcn_pred = self.gcn(node_feats, adjacent_matrices, knn_inds)
|
||||
|
||||
return pred_maps, (gcn_pred, gt_labels)
|
||||
|
||||
def single_test(self, feat_maps):
|
||||
|
||||
pred_maps = self.out_conv(feat_maps)
|
||||
feat_maps = torch.cat([feat_maps, pred_maps], dim=1)
|
||||
|
||||
none_flag, graph_data = self.graph_test(pred_maps, feat_maps)
|
||||
|
||||
(local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
|
||||
pivot_local_graphs, text_comps) = graph_data
|
||||
|
||||
if none_flag:
|
||||
return None, None, None
|
||||
|
||||
gcn_pred = self.gcn(local_graphs_node_feat, adjacent_matrices,
|
||||
pivots_knn_inds)
|
||||
pred_labels = F.softmax(gcn_pred, dim=1)
|
||||
|
||||
edges = []
|
||||
scores = []
|
||||
pivot_local_graphs = pivot_local_graphs.long().squeeze().cpu().numpy()
|
||||
|
||||
for pivot_ind, pivot_local_graph in enumerate(pivot_local_graphs):
|
||||
pivot = pivot_local_graph[0]
|
||||
for k_ind, neighbor_ind in enumerate(pivots_knn_inds[pivot_ind]):
|
||||
neighbor = pivot_local_graph[neighbor_ind.item()]
|
||||
edges.append([pivot, neighbor])
|
||||
scores.append(
|
||||
pred_labels[pivot_ind * pivots_knn_inds.shape[1] + k_ind,
|
||||
1].item())
|
||||
|
||||
edges = np.asarray(edges)
|
||||
scores = np.asarray(scores)
|
||||
|
||||
return edges, scores, text_comps
|
||||
|
||||
def get_boundary(self, edges, scores, text_comps, img_metas, rescale):
|
||||
"""Compute text boundaries via post processing.
|
||||
|
||||
Args:
|
||||
edges (ndarray): The edge array of shape N * 2, each row is a pair
|
||||
of text component indices that makes up an edge in graph.
|
||||
scores (ndarray): The edge score array.
|
||||
text_comps (ndarray): The text components.
|
||||
img_metas (list[dict]): The image meta infos.
|
||||
rescale (bool): Rescale boundaries to the original image
|
||||
resolution.
|
||||
|
||||
Returns:
|
||||
results (dict): The result dict.
|
||||
"""
|
||||
|
||||
assert check_argument.is_type_list(img_metas, dict)
|
||||
assert isinstance(rescale, bool)
|
||||
|
||||
boundaries = []
|
||||
if edges is not None:
|
||||
boundaries = decode(
|
||||
decoding_type='drrg',
|
||||
edges=edges,
|
||||
scores=scores,
|
||||
text_comps=text_comps,
|
||||
link_thr=self.link_thr)
|
||||
if rescale:
|
||||
boundaries = self.resize_boundary(
|
||||
boundaries,
|
||||
1.0 / self.downsample_ratio / img_metas[0]['scale_factor'])
|
||||
|
||||
results = dict(boundary_result=boundaries)
|
||||
|
||||
return results
|
|
@@ -1,4 +1,5 @@
 from .dbnet import DBNet
+from .drrg import DRRG
 from .fcenet import FCENet
 from .ocr_mask_rcnn import OCRMaskRCNN
 from .panet import PANet
@@ -9,5 +10,5 @@ from .textsnake import TextSnake

 __all__ = [
     'TextDetectorMixin', 'SingleStageTextDetector', 'OCRMaskRCNN', 'DBNet',
-    'PANet', 'PSENet', 'TextSnake', 'FCENet'
+    'PANet', 'PSENet', 'TextSnake', 'FCENet', 'DRRG'
 ]
@@ -0,0 +1,53 @@

from mmdet.models.builder import DETECTORS
from mmocr.models.textdet.detectors.single_stage_text_detector import \
    SingleStageTextDetector
from mmocr.models.textdet.detectors.text_detector_mixin import \
    TextDetectorMixin


@DETECTORS.register_module()
class DRRG(TextDetectorMixin, SingleStageTextDetector):
    """The class for implementing the DRRG text detector: Deep Relational
    Reasoning Graph Network for Arbitrary Shape Text Detection.

    [https://arxiv.org/abs/2003.07493]
    """

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 show_score=False):
        SingleStageTextDetector.__init__(self, backbone, neck, bbox_head,
                                         train_cfg, test_cfg, pretrained)
        TextDetectorMixin.__init__(self, show_score)

    def forward_train(self, img, img_metas, **kwargs):
        """
        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): A list of image info dicts where each
                dict has: 'img_shape', 'scale_factor', 'flip', and may also
                contain 'filename', 'ori_shape', 'pad_shape', and
                'img_norm_cfg'. For details of the values of these keys,
                see :class:`mmdet.datasets.pipelines.Collect`.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        x = self.extract_feat(img)
        gt_comp_attribs = kwargs.pop('gt_comp_attribs')
        preds = self.bbox_head(x, gt_comp_attribs)
        losses = self.bbox_head.loss(preds, **kwargs)
        return losses

    def simple_test(self, img, img_metas, rescale=False):
        """Run inference and return the predicted text boundaries."""
        x = self.extract_feat(img)
        outs = self.bbox_head.single_test(x)
        boundaries = self.bbox_head.get_boundary(*outs, img_metas, rescale)

        return [boundaries]
@@ -1,7 +1,10 @@
 from .db_loss import DBLoss
+from .drrg_loss import DRRGLoss
 from .fce_loss import FCELoss
 from .pan_loss import PANLoss
 from .pse_loss import PSELoss
 from .textsnake_loss import TextSnakeLoss

-__all__ = ['PANLoss', 'PSELoss', 'DBLoss', 'TextSnakeLoss', 'FCELoss']
+__all__ = [
+    'PANLoss', 'PSELoss', 'DBLoss', 'TextSnakeLoss', 'FCELoss', 'DRRGLoss'
+]
@ -0,0 +1,214 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from mmdet.core import BitmapMasks
|
||||
from mmdet.models.builder import LOSSES
|
||||
from mmocr.utils import check_argument
|
||||
|
||||
|
||||
@LOSSES.register_module()
|
||||
class DRRGLoss(nn.Module):
|
||||
"""The class for implementing DRRG loss: Deep Relational Reasoning Graph
|
||||
Network for Arbitrary Shape Text Detection.
|
||||
|
||||
[https://arxiv.org/abs/2003.07493]. This is partially adapted from
|
||||
https://github.com/GXYM/DRRG licensed under the MIT license.
|
||||
"""
|
||||
|
||||
def __init__(self, ohem_ratio=3.0):
|
||||
"""Initialization.
|
||||
|
||||
Args:
|
||||
ohem_ratio (float): The negative/positive ratio in OHEM.
|
||||
"""
|
||||
super().__init__()
|
||||
self.ohem_ratio = ohem_ratio
|
||||
|
||||
def balance_bce_loss(self, pred, gt, mask):
|
||||
|
||||
assert pred.shape == gt.shape == mask.shape
|
||||
assert torch.all(pred >= 0) and torch.all(pred <= 1)
|
||||
assert torch.all(gt >= 0) and torch.all(gt <= 1)
|
||||
positive = gt * mask
|
||||
negative = (1 - gt) * mask
|
||||
positive_count = int(positive.float().sum())
|
||||
gt = gt.float()
|
||||
if positive_count > 0:
|
||||
loss = F.binary_cross_entropy(pred, gt, reduction='none')
|
||||
positive_loss = torch.sum(loss * positive.float())
|
||||
negative_loss = loss * negative.float()
|
||||
negative_count = min(
|
||||
int(negative.float().sum()),
|
||||
int(positive_count * self.ohem_ratio))
|
||||
else:
|
||||
positive_loss = torch.tensor(0.0, device=pred.device)
|
||||
loss = F.binary_cross_entropy(pred, gt, reduction='none')
|
||||
negative_loss = loss * negative.float()
|
||||
negative_count = 100
|
||||
negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count)
|
||||
|
||||
balance_loss = (positive_loss + torch.sum(negative_loss)) / (
|
||||
float(positive_count + negative_count) + 1e-5)
|
||||
|
||||
return balance_loss
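For clarity, a small usage sketch (not part of this commit) of the OHEM-balanced BCE above. It assumes mmocr and its dependencies are installed, and the tensor values are made up purely for illustration:

```python
# Illustrative sketch: exercise balance_bce_loss on tiny hand-made tensors.
# pred must hold probabilities (e.g. sigmoid outputs), gt and mask are binary.
import torch
from mmocr.models.textdet.losses import DRRGLoss

loss_module = DRRGLoss(ohem_ratio=3.0)
pred = torch.tensor([[0.9, 0.2], [0.1, 0.8]])  # probabilities in [0, 1]
gt = torch.tensor([[1, 0], [0, 1]])            # binary text mask
mask = torch.ones_like(gt)                     # all pixels are valid
print(loss_module.balance_bce_loss(pred, gt, mask))  # scalar loss tensor
```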
|
||||
|
||||
def gcn_loss(self, gcn_data):
|
||||
|
||||
gcn_pred, gt_labels = gcn_data
|
||||
gt_labels = gt_labels.view(-1).to(gcn_pred.device)
|
||||
loss = F.cross_entropy(gcn_pred, gt_labels)
|
||||
|
||||
return loss
|
||||
|
||||
def bitmasks2tensor(self, bitmasks, target_sz):
|
||||
"""Convert Bitmasks to tensor.
|
||||
|
||||
Args:
|
||||
bitmasks (list[BitmapMasks]): The BitmapMasks list. Each item is
|
||||
for one img.
|
||||
target_sz (tuple(int, int)): The target tensor size HxW.
|
||||
|
||||
Returns:
|
||||
results (list[tensor]): The list of kernel tensors. Each
|
||||
element is for one kernel level.
|
||||
"""
|
||||
assert check_argument.is_type_list(bitmasks, BitmapMasks)
|
||||
assert isinstance(target_sz, tuple)
|
||||
|
||||
batch_size = len(bitmasks)
|
||||
num_masks = len(bitmasks[0])
|
||||
|
||||
results = []
|
||||
|
||||
for level_inx in range(num_masks):
|
||||
kernel = []
|
||||
for batch_inx in range(batch_size):
|
||||
mask = torch.from_numpy(bitmasks[batch_inx].masks[level_inx])
|
||||
# hxw
|
||||
mask_sz = mask.shape
|
||||
# left, right, top, bottom
|
||||
pad = [
|
||||
0, target_sz[1] - mask_sz[1], 0, target_sz[0] - mask_sz[0]
|
||||
]
|
||||
mask = F.pad(mask, pad, mode='constant', value=0)
|
||||
kernel.append(mask)
|
||||
kernel = torch.stack(kernel)
|
||||
results.append(kernel)
|
||||
|
||||
return results
|
||||
|
||||
def forward(self, preds, downsample_ratio, gt_text_mask,
|
||||
gt_center_region_mask, gt_mask, gt_top_height_map,
|
||||
gt_bot_height_map, gt_sin_map, gt_cos_map):
|
||||
|
||||
assert isinstance(preds, tuple)
|
||||
assert isinstance(downsample_ratio, float)
|
||||
assert check_argument.is_type_list(gt_text_mask, BitmapMasks)
|
||||
assert check_argument.is_type_list(gt_center_region_mask, BitmapMasks)
|
||||
assert check_argument.is_type_list(gt_mask, BitmapMasks)
|
||||
assert check_argument.is_type_list(gt_top_height_map, BitmapMasks)
|
||||
assert check_argument.is_type_list(gt_bot_height_map, BitmapMasks)
|
||||
assert check_argument.is_type_list(gt_sin_map, BitmapMasks)
|
||||
assert check_argument.is_type_list(gt_cos_map, BitmapMasks)
|
||||
|
||||
pred_maps, gcn_data = preds
|
||||
pred_text_region = pred_maps[:, 0, :, :]
|
||||
pred_center_region = pred_maps[:, 1, :, :]
|
||||
pred_sin_map = pred_maps[:, 2, :, :]
|
||||
pred_cos_map = pred_maps[:, 3, :, :]
|
||||
pred_top_height_map = pred_maps[:, 4, :, :]
|
||||
pred_bot_height_map = pred_maps[:, 5, :, :]
|
||||
feature_sz = pred_maps.size()
|
||||
device = pred_maps.device
|
||||
|
||||
# bitmask 2 tensor
|
||||
mapping = {
|
||||
'gt_text_mask': gt_text_mask,
|
||||
'gt_center_region_mask': gt_center_region_mask,
|
||||
'gt_mask': gt_mask,
|
||||
'gt_top_height_map': gt_top_height_map,
|
||||
'gt_bot_height_map': gt_bot_height_map,
|
||||
'gt_sin_map': gt_sin_map,
|
||||
'gt_cos_map': gt_cos_map
|
||||
}
|
||||
gt = {}
|
||||
for key, value in mapping.items():
|
||||
gt[key] = value
|
||||
if abs(downsample_ratio - 1.0) < 1e-2:
|
||||
gt[key] = self.bitmasks2tensor(gt[key], feature_sz[2:])
|
||||
else:
|
||||
gt[key] = [item.rescale(downsample_ratio) for item in gt[key]]
|
||||
gt[key] = self.bitmasks2tensor(gt[key], feature_sz[2:])
|
||||
if key in ['gt_top_height_map', 'gt_bot_height_map']:
|
||||
gt[key] = [item * downsample_ratio for item in gt[key]]
|
||||
gt[key] = [item.to(device) for item in gt[key]]
|
||||
|
||||
scale = torch.sqrt(1.0 / (pred_sin_map**2 + pred_cos_map**2 + 1e-8))
|
||||
pred_sin_map = pred_sin_map * scale
|
||||
pred_cos_map = pred_cos_map * scale
|
||||
|
||||
loss_text = self.balance_bce_loss(
|
||||
torch.sigmoid(pred_text_region), gt['gt_text_mask'][0],
|
||||
gt['gt_mask'][0])
|
||||
|
||||
text_mask = (gt['gt_text_mask'][0] * gt['gt_mask'][0]).float()
|
||||
negative_text_mask = ((1 - gt['gt_text_mask'][0]) *
|
||||
gt['gt_mask'][0]).float()
|
||||
loss_center_map = F.binary_cross_entropy(
|
||||
torch.sigmoid(pred_center_region),
|
||||
gt['gt_center_region_mask'][0].float(),
|
||||
reduction='none')
|
||||
if int(text_mask.sum()) > 0:
|
||||
loss_center_positive = torch.sum(
|
||||
loss_center_map * text_mask) / torch.sum(text_mask)
|
||||
else:
|
||||
loss_center_positive = torch.tensor(0.0, device=device)
|
||||
loss_center_negative = torch.sum(
|
||||
loss_center_map *
|
||||
negative_text_mask) / torch.sum(negative_text_mask)
|
||||
loss_center = loss_center_positive + 0.5 * loss_center_negative
|
||||
|
||||
center_mask = (gt['gt_center_region_mask'][0] *
|
||||
gt['gt_mask'][0]).float()
|
||||
if int(center_mask.sum()) > 0:
|
||||
map_sz = pred_top_height_map.size()
|
||||
ones = torch.ones(map_sz, dtype=torch.float, device=device)
|
||||
loss_top = F.smooth_l1_loss(
|
||||
pred_top_height_map / (gt['gt_top_height_map'][0] + 1e-2),
|
||||
ones,
|
||||
reduction='none')
|
||||
loss_bot = F.smooth_l1_loss(
|
||||
pred_bot_height_map / (gt['gt_bot_height_map'][0] + 1e-2),
|
||||
ones,
|
||||
reduction='none')
|
||||
gt_height = (
|
||||
gt['gt_top_height_map'][0] + gt['gt_bot_height_map'][0])
|
||||
loss_height = torch.sum(
|
||||
(torch.log(gt_height + 1) *
|
||||
(loss_top + loss_bot)) * center_mask) / torch.sum(center_mask)
|
||||
|
||||
loss_sin = torch.sum(
|
||||
F.smooth_l1_loss(
|
||||
pred_sin_map, gt['gt_sin_map'][0], reduction='none') *
|
||||
center_mask) / torch.sum(center_mask)
|
||||
loss_cos = torch.sum(
|
||||
F.smooth_l1_loss(
|
||||
pred_cos_map, gt['gt_cos_map'][0], reduction='none') *
|
||||
center_mask) / torch.sum(center_mask)
|
||||
else:
|
||||
loss_height = torch.tensor(0.0, device=device)
|
||||
loss_sin = torch.tensor(0.0, device=device)
|
||||
loss_cos = torch.tensor(0.0, device=device)
|
||||
|
||||
loss_gcn = self.gcn_loss(gcn_data)
|
||||
|
||||
results = dict(
|
||||
loss_text=loss_text,
|
||||
loss_center=loss_center,
|
||||
loss_height=loss_height,
|
||||
loss_sin=loss_sin,
|
||||
loss_cos=loss_cos,
|
||||
loss_gcn=loss_gcn)
|
||||
|
||||
return results
|
|
@@ -0,0 +1,5 @@

from .gcn import GCN
from .local_graph import LocalGraphs
from .proposal_local_graph import ProposalLocalGraphs

__all__ = ['LocalGraphs', 'ProposalLocalGraphs', 'GCN']
@@ -0,0 +1,75 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init


class MeanAggregator(nn.Module):

    def forward(self, features, A):
        x = torch.bmm(A, features)
        return x


class GraphConv(nn.Module):

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.weight = nn.Parameter(torch.FloatTensor(in_dim * 2, out_dim))
        self.bias = nn.Parameter(torch.FloatTensor(out_dim))
        init.xavier_uniform_(self.weight)
        init.constant_(self.bias, 0)
        self.aggregator = MeanAggregator()

    def forward(self, features, A):
        b, n, d = features.shape
        assert d == self.in_dim
        agg_feats = self.aggregator(features, A)
        cat_feats = torch.cat([features, agg_feats], dim=2)
        out = torch.einsum('bnd,df->bnf', cat_feats, self.weight)
        out = F.relu(out + self.bias)
        return out


class GCN(nn.Module):
    """Graph convolutional network for clustering. This was adapted from the
    repo https://github.com/Zhongdao/gcn_clustering licensed under the MIT
    license.

    Args:
        feat_len (int): The input node feature length.
    """

    def __init__(self, feat_len):
        super(GCN, self).__init__()
        self.bn0 = nn.BatchNorm1d(feat_len, affine=False).float()
        self.conv1 = GraphConv(feat_len, 512)
        self.conv2 = GraphConv(512, 256)
        self.conv3 = GraphConv(256, 128)
        self.conv4 = GraphConv(128, 64)
        self.classifier = nn.Sequential(
            nn.Linear(64, 32), nn.PReLU(32), nn.Linear(32, 2))

    def forward(self, x, A, knn_inds):

        num_local_graphs, num_max_nodes, feat_len = x.shape

        x = x.view(-1, feat_len)
        x = self.bn0(x)
        x = x.view(num_local_graphs, num_max_nodes, feat_len)

        x = self.conv1(x, A)
        x = self.conv2(x, A)
        x = self.conv3(x, A)
        x = self.conv4(x, A)
        k = knn_inds.size(-1)
        mid_feat_len = x.size(-1)
        edge_feat = torch.zeros((num_local_graphs, k, mid_feat_len),
                                device=x.device)
        for graph_ind in range(num_local_graphs):
            edge_feat[graph_ind, :, :] = x[graph_ind, knn_inds[graph_ind]]
        edge_feat = edge_feat.view(-1, mid_feat_len)
        pred = self.classifier(edge_feat)

        return pred
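A minimal shape-check sketch (not part of this commit) of how this GCN is driven. The batch, node, and feature sizes below are arbitrary assumptions for illustration, not values DRRG itself uses, and the sketch assumes mmocr is installed:

```python
# Illustrative sketch: feed random local-graph tensors through the GCN and
# verify the output shape. Sizes here are arbitrary, not DRRG defaults.
import torch
from mmocr.models.textdet.modules import GCN

feat_len = 48
num_graphs, num_max_nodes, k = 4, 16, 8

gcn = GCN(feat_len)
node_feats = torch.randn(num_graphs, num_max_nodes, feat_len)
adjacency = torch.rand(num_graphs, num_max_nodes, num_max_nodes)
knn_inds = torch.randint(1, num_max_nodes, (num_graphs, k))

pred = gcn(node_feats, adjacency, knn_inds)
print(pred.shape)  # torch.Size([32, 2]) -> (num_graphs * k, 2)
```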
@ -0,0 +1,296 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
from mmcv.ops import RoIAlignRotated
|
||||
|
||||
from .utils import (euclidean_distance_matrix, feature_embedding,
|
||||
normalize_adjacent_matrix)
|
||||
|
||||
|
||||
class LocalGraphs(object):
|
||||
"""Generate local graphs for GCN to classify the neighbors of a pivot for
|
||||
DRRG: Deep Relational Reasoning Graph Network for Arbitrary Shape Text
|
||||
Detection.
|
||||
|
||||
[https://arxiv.org/abs/2003.07493]. This code was partially adapted from
|
||||
https://github.com/GXYM/DRRG licensed under the MIT license.
|
||||
|
||||
Args:
|
||||
k_at_hops (tuple(int)): The number of i-hop neighbors, i = 1, 2.
|
||||
num_adjacent_linkages (int): The number of linkages when constructing
|
||||
adjacent matrix.
|
||||
node_geo_feat_len (int): The length of embedded geometric feature
|
||||
vector of a text component.
|
||||
pooling_scale (float): The spatial scale of rotated RoI-Align.
|
||||
pooling_output_size (tuple(int)): The output size of rotated RoI-Align.
|
||||
local_graph_thr (float): The threshold for filtering out identical local
|
||||
graphs.
|
||||
"""
|
||||
|
||||
def __init__(self, k_at_hops, num_adjacent_linkages, node_geo_feat_len,
|
||||
pooling_scale, pooling_output_size, local_graph_thr):
|
||||
|
||||
assert len(k_at_hops) == 2
|
||||
assert all(isinstance(n, int) for n in k_at_hops)
|
||||
assert isinstance(num_adjacent_linkages, int)
|
||||
assert isinstance(node_geo_feat_len, int)
|
||||
assert isinstance(pooling_scale, float)
|
||||
assert all(isinstance(n, int) for n in pooling_output_size)
|
||||
assert isinstance(local_graph_thr, float)
|
||||
|
||||
self.k_at_hops = k_at_hops
|
||||
self.num_adjacent_linkages = num_adjacent_linkages
|
||||
self.node_geo_feat_dim = node_geo_feat_len
|
||||
self.pooling = RoIAlignRotated(pooling_output_size, pooling_scale)
|
||||
self.local_graph_thr = local_graph_thr
|
||||
|
||||
def generate_local_graphs(self, sorted_dist_inds, gt_comp_labels):
|
||||
"""Generate local graphs for GCN to predict which instance a text
|
||||
component belongs to.
|
||||
|
||||
Args:
|
||||
sorted_dist_inds (ndarray): The complete graph node indices, which
|
||||
is sorted according to the Euclidean distance.
|
||||
gt_comp_labels (ndarray): The ground truth labels that define the
|
||||
instance to which the text components (nodes in graphs) belong.
|
||||
|
||||
Returns:
|
||||
pivot_local_graphs (list[list[int]]): The list of local graph
|
||||
neighbor indices of pivots.
|
||||
pivot_knns (list[list[int]]): The list of k-nearest neighbor indices
|
||||
of pivots.
|
||||
"""
|
||||
|
||||
assert sorted_dist_inds.ndim == 2
|
||||
assert (sorted_dist_inds.shape[0] == sorted_dist_inds.shape[1] ==
|
||||
gt_comp_labels.shape[0])
|
||||
|
||||
knn_graph = sorted_dist_inds[:, 1:self.k_at_hops[0] + 1]
|
||||
pivot_local_graphs = []
|
||||
pivot_knns = []
|
||||
for pivot_ind, knn in enumerate(knn_graph):
|
||||
|
||||
local_graph_neighbors = set(knn)
|
||||
|
||||
for neighbor_ind in knn:
|
||||
local_graph_neighbors.update(
|
||||
set(sorted_dist_inds[neighbor_ind,
|
||||
1:self.k_at_hops[1] + 1]))
|
||||
|
||||
local_graph_neighbors.discard(pivot_ind)
|
||||
pivot_local_graph = list(local_graph_neighbors)
|
||||
pivot_local_graph.insert(0, pivot_ind)
|
||||
pivot_knn = [pivot_ind] + list(knn)
|
||||
|
||||
if pivot_ind < 1:
|
||||
pivot_local_graphs.append(pivot_local_graph)
|
||||
pivot_knns.append(pivot_knn)
|
||||
else:
|
||||
add_flag = True
|
||||
for graph_ind, added_knn in enumerate(pivot_knns):
|
||||
added_pivot_ind = added_knn[0]
|
||||
added_local_graph = pivot_local_graphs[graph_ind]
|
||||
|
||||
union = len(
|
||||
set(pivot_local_graph[1:]).union(
|
||||
set(added_local_graph[1:])))
|
||||
intersect = len(
|
||||
set(pivot_local_graph[1:]).intersection(
|
||||
set(added_local_graph[1:])))
|
||||
local_graph_iou = intersect / (union + 1e-8)
|
||||
|
||||
if (local_graph_iou > self.local_graph_thr
|
||||
and pivot_ind in added_knn
|
||||
and gt_comp_labels[added_pivot_ind]
|
||||
== gt_comp_labels[pivot_ind]
|
||||
and gt_comp_labels[pivot_ind] != 0):
|
||||
add_flag = False
|
||||
break
|
||||
if add_flag:
|
||||
pivot_local_graphs.append(pivot_local_graph)
|
||||
pivot_knns.append(pivot_knn)
|
||||
|
||||
return pivot_local_graphs, pivot_knns
|
||||
|
||||
def generate_gcn_input(self, node_feat_batch, node_label_batch,
|
||||
local_graph_batch, knn_batch,
|
||||
sorted_dist_ind_batch):
|
||||
"""Generate graph convolution network input data.
|
||||
|
||||
Args:
|
||||
node_feat_batch (List[Tensor]): The batched graph node features.
|
||||
node_label_batch (List[ndarray]): The batched text component
|
||||
labels.
|
||||
local_graph_batch (List[List[list[int]]]): The local graph node
|
||||
indices of image batch.
|
||||
knn_batch (List[List[list[int]]]): The knn graph node indices of
|
||||
image batch.
|
||||
sorted_dist_ind_batch (list[ndarray]): The node indices sorted
|
||||
according to the Euclidean distance.
|
||||
|
||||
Returns:
|
||||
local_graphs_node_feat (Tensor): The node features of graph.
|
||||
adjacent_matrices (Tensor): The adjacent matrices of local graphs.
|
||||
pivots_knn_inds (Tensor): The k-nearest neighbor indices in
|
||||
local graph.
|
||||
gt_linkage (Tensor): The supervision signal of GCN for linkage
|
||||
prediction.
|
||||
"""
|
||||
assert isinstance(node_feat_batch, list)
|
||||
assert isinstance(node_label_batch, list)
|
||||
assert isinstance(local_graph_batch, list)
|
||||
assert isinstance(knn_batch, list)
|
||||
assert isinstance(sorted_dist_ind_batch, list)
|
||||
|
||||
num_max_nodes = max([
|
||||
len(pivot_local_graph) for pivot_local_graphs in local_graph_batch
|
||||
for pivot_local_graph in pivot_local_graphs
|
||||
])
|
||||
|
||||
local_graphs_node_feat = []
|
||||
adjacent_matrices = []
|
||||
pivots_knn_inds = []
|
||||
pivots_gt_linkage = []
|
||||
|
||||
for batch_ind, sorted_dist_inds in enumerate(sorted_dist_ind_batch):
|
||||
node_feats = node_feat_batch[batch_ind]
|
||||
pivot_local_graphs = local_graph_batch[batch_ind]
|
||||
pivot_knns = knn_batch[batch_ind]
|
||||
node_labels = node_label_batch[batch_ind]
|
||||
device = node_feats.device
|
||||
|
||||
for graph_ind, pivot_knn in enumerate(pivot_knns):
|
||||
pivot_local_graph = pivot_local_graphs[graph_ind]
|
||||
num_nodes = len(pivot_local_graph)
|
||||
pivot_ind = pivot_local_graph[0]
|
||||
node2ind_map = {j: i for i, j in enumerate(pivot_local_graph)}
|
||||
|
||||
knn_inds = torch.tensor(
|
||||
[node2ind_map[i] for i in pivot_knn[1:]])
|
||||
pivot_feats = node_feats[pivot_ind]
|
||||
normalized_feats = node_feats[pivot_local_graph] - pivot_feats
|
||||
|
||||
adjacent_matrix = np.zeros((num_nodes, num_nodes),
|
||||
dtype=np.float32)
|
||||
for node in pivot_local_graph:
|
||||
neighbors = sorted_dist_inds[node,
|
||||
1:self.num_adjacent_linkages +
|
||||
1]
|
||||
for neighbor in neighbors:
|
||||
if neighbor in pivot_local_graph:
|
||||
|
||||
adjacent_matrix[node2ind_map[node],
|
||||
node2ind_map[neighbor]] = 1
|
||||
adjacent_matrix[node2ind_map[neighbor],
|
||||
node2ind_map[node]] = 1
|
||||
|
||||
adjacent_matrix = normalize_adjacent_matrix(
|
||||
adjacent_matrix, mode='DAD')
|
||||
pad_adjacent_matrix = torch.zeros(
|
||||
(num_max_nodes, num_max_nodes),
|
||||
dtype=torch.float,
|
||||
device=device)
|
||||
pad_adjacent_matrix[:num_nodes, :num_nodes] = adjacent_matrix
|
||||
|
||||
pad_normalized_feats = torch.cat([
|
||||
normalized_feats,
|
||||
torch.zeros(
|
||||
(num_max_nodes - num_nodes, normalized_feats.shape[1]),
|
||||
dtype=torch.float,
|
||||
device=device)
|
||||
],
|
||||
dim=0)
|
||||
|
||||
local_graph_labels = node_labels[pivot_local_graph]
|
||||
knn_labels = local_graph_labels[knn_inds]
|
||||
link_labels = ((node_labels[pivot_ind] == knn_labels) &
|
||||
(node_labels[pivot_ind] > 0)).astype(np.int64)
|
||||
link_labels = torch.from_numpy(link_labels)
|
||||
|
||||
local_graphs_node_feat.append(pad_normalized_feats)
|
||||
adjacent_matrices.append(pad_adjacent_matrix)
|
||||
pivots_knn_inds.append(knn_inds)
|
||||
pivots_gt_linkage.append(link_labels)
|
||||
|
||||
local_graphs_node_feat = torch.stack(local_graphs_node_feat, 0)
|
||||
adjacent_matrices = torch.stack(adjacent_matrices, 0)
|
||||
pivots_knn_inds = torch.stack(pivots_knn_inds, 0)
|
||||
pivots_gt_linkage = torch.stack(pivots_gt_linkage, 0)
|
||||
|
||||
return (local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
|
||||
pivots_gt_linkage)
|
||||
|
||||
def __call__(self, feat_maps, comp_attribs):
|
||||
"""Generate local graphs as GCN input.
|
||||
|
||||
Args:
|
||||
feat_maps (Tensor): The feature maps to extract the content
|
||||
features of text components.
|
||||
comp_attribs (ndarray): The text component attributes.
|
||||
|
||||
Returns:
|
||||
local_graphs_node_feat (Tensor): The node features of graph.
|
||||
adjacent_matrices (Tensor): The adjacent matrices of local graphs.
|
||||
pivots_knn_inds (Tensor): The k-nearest neighbor indices in local
|
||||
graph.
|
||||
gt_linkage (Tensor): The supervision signal of GCN for linkage
|
||||
prediction.
|
||||
"""
|
||||
|
||||
assert isinstance(feat_maps, torch.Tensor)
|
||||
assert comp_attribs.ndim == 3
|
||||
assert comp_attribs.shape[2] == 8
|
||||
|
||||
sorted_dist_inds_batch = []
|
||||
local_graph_batch = []
|
||||
knn_batch = []
|
||||
node_feat_batch = []
|
||||
node_label_batch = []
|
||||
device = feat_maps.device
|
||||
|
||||
for batch_ind in range(comp_attribs.shape[0]):
|
||||
num_comps = int(comp_attribs[batch_ind, 0, 0])
|
||||
comp_geo_attribs = comp_attribs[batch_ind, :num_comps, 1:7]
|
||||
node_labels = comp_attribs[batch_ind, :num_comps,
|
||||
7].astype(np.int32)
|
||||
|
||||
comp_centers = comp_geo_attribs[:, 0:2]
|
||||
distance_matrix = euclidean_distance_matrix(
|
||||
comp_centers, comp_centers)
|
||||
|
||||
batch_id = np.zeros(
|
||||
(comp_geo_attribs.shape[0], 1), dtype=np.float32) * batch_ind
|
||||
comp_geo_attribs[:, -2] = np.clip(comp_geo_attribs[:, -2], -1, 1)
|
||||
angle = np.arccos(comp_geo_attribs[:, -2]) * np.sign(
|
||||
comp_geo_attribs[:, -1])
|
||||
angle = angle.reshape((-1, 1))
|
||||
rotated_rois = np.hstack(
|
||||
[batch_id, comp_geo_attribs[:, :-2], angle])
|
||||
rois = torch.from_numpy(rotated_rois).to(device)
|
||||
content_feats = self.pooling(feat_maps[batch_ind].unsqueeze(0),
|
||||
rois)
|
||||
|
||||
content_feats = content_feats.view(content_feats.shape[0],
|
||||
-1).to(feat_maps.device)
|
||||
geo_feats = feature_embedding(comp_geo_attribs,
|
||||
self.node_geo_feat_dim)
|
||||
geo_feats = torch.from_numpy(geo_feats).to(device)
|
||||
node_feats = torch.cat([content_feats, geo_feats], dim=-1)
|
||||
|
||||
sorted_dist_inds = np.argsort(distance_matrix, axis=1)
|
||||
pivot_local_graphs, pivot_knns = self.generate_local_graphs(
|
||||
sorted_dist_inds, node_labels)
|
||||
|
||||
node_feat_batch.append(node_feats)
|
||||
node_label_batch.append(node_labels)
|
||||
local_graph_batch.append(pivot_local_graphs)
|
||||
knn_batch.append(pivot_knns)
|
||||
sorted_dist_inds_batch.append(sorted_dist_inds)
|
||||
|
||||
(node_feats, adjacent_matrices, knn_inds, gt_linkage) = \
|
||||
self.generate_gcn_input(node_feat_batch,
|
||||
node_label_batch,
|
||||
local_graph_batch,
|
||||
knn_batch,
|
||||
sorted_dist_inds_batch)
|
||||
|
||||
return node_feats, adjacent_matrices, knn_inds, gt_linkage
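
# Illustrative usage sketch (mirrors the module's unit test; all sizes and values here
# are assumptions for exposition, not fixed by the API):
import numpy as np
import torch
local_graph_generator = LocalGraphs((4, 4), 3, 24, 1.0, (2, 2), 0.5)
feat_maps = torch.randn(2, 3, 128, 128)
num_rois = 32
x = np.random.randint(4, 124, (num_rois, 1))
y = np.random.randint(4, 124, (num_rois, 1))
h = 4 * np.ones((num_rois, 1))
w = 4 * np.ones((num_rois, 1))
cos, sin = np.ones((num_rois, 1)), np.zeros((num_rois, 1))
labels = np.random.randint(1, 3, (num_rois, 1))
# each row: [num_comps, x, y, h, w, cos, sin, instance_label]
comp_attribs = np.hstack(
    [num_rois * np.ones((num_rois, 1)), x, y, h, w, cos, sin, labels])
comp_attribs = np.stack([comp_attribs, comp_attribs.copy()]).astype(np.float32)
node_feats, adj_mats, knn_inds, gt_linkage = local_graph_generator(
    feat_maps, comp_attribs)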
|
|
@ -0,0 +1,413 @@
|
|||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from lanms import merge_quadrangle_n9 as la_nms
|
||||
from mmcv.ops import RoIAlignRotated
|
||||
|
||||
from mmocr.models.textdet.postprocess.wrapper import fill_hole
|
||||
from .utils import (euclidean_distance_matrix, feature_embedding,
|
||||
normalize_adjacent_matrix)
|
||||
|
||||
|
||||
class ProposalLocalGraphs(object):
|
||||
"""Propose text components and generate local graphs for GCN to classify
|
||||
the k-nearest neighbors of a pivot in DRRG: Deep Relational Reasoning Graph
|
||||
Network for Arbitrary Shape Text Detection.
|
||||
|
||||
[https://arxiv.org/abs/2003.07493]. This code was partially adapted from
|
||||
https://github.com/GXYM/DRRG licensed under the MIT license.
|
||||
|
||||
Args:
|
||||
k_at_hops (tuple(int)): The number of i-hop neighbors, i = 1, 2.
|
||||
num_adjacent_linkages (int): The number of linkages when constructing
|
||||
adjacent matrix.
|
||||
node_geo_feat_len (int): The length of embedded geometric feature
|
||||
vector of a text component.
|
||||
pooling_scale (float): The spatial scale of rotated RoI-Align.
|
||||
pooling_output_size (tuple(int)): The output size of rotated RoI-Align.
|
||||
nms_thr (float): The locality-aware NMS threshold for text components.
|
||||
min_width (float): The minimum width of text components.
|
||||
max_width (float): The maximum width of text components.
|
||||
comp_shrink_ratio (float): The shrink ratio of text components.
|
||||
comp_w_h_ratio (float): The width to height ratio of text components.
|
||||
comp_score_thr (float): The score threshold of text component.
|
||||
text_region_thr (float): The threshold for text region probability map.
|
||||
center_region_thr (float): The threshold for text center region
|
||||
probability map.
|
||||
center_region_area_thr (int): The threshold for filtering small-sized
|
||||
text center region.
|
||||
"""
|
||||
|
||||
def __init__(self, k_at_hops, num_adjacent_linkages, node_geo_feat_len,
|
||||
pooling_scale, pooling_output_size, nms_thr, min_width,
|
||||
max_width, comp_shrink_ratio, comp_w_h_ratio, comp_score_thr,
|
||||
text_region_thr, center_region_thr, center_region_area_thr):
|
||||
|
||||
assert len(k_at_hops) == 2
|
||||
assert isinstance(k_at_hops, tuple)
|
||||
assert isinstance(num_adjacent_linkages, int)
|
||||
assert isinstance(node_geo_feat_len, int)
|
||||
assert isinstance(pooling_scale, float)
|
||||
assert isinstance(pooling_output_size, tuple)
|
||||
assert isinstance(nms_thr, float)
|
||||
assert isinstance(min_width, float)
|
||||
assert isinstance(max_width, float)
|
||||
assert isinstance(comp_shrink_ratio, float)
|
||||
assert isinstance(comp_w_h_ratio, float)
|
||||
assert isinstance(comp_score_thr, float)
|
||||
assert isinstance(text_region_thr, float)
|
||||
assert isinstance(center_region_thr, float)
|
||||
assert isinstance(center_region_area_thr, int)
|
||||
|
||||
self.k_at_hops = k_at_hops
|
||||
self.active_connection = num_adjacent_linkages
|
||||
self.local_graph_depth = len(self.k_at_hops)
|
||||
self.node_geo_feat_dim = node_geo_feat_len
|
||||
self.pooling = RoIAlignRotated(pooling_output_size, pooling_scale)
|
||||
self.nms_thr = nms_thr
|
||||
self.min_width = min_width
|
||||
self.max_width = max_width
|
||||
self.comp_shrink_ratio = comp_shrink_ratio
|
||||
self.comp_w_h_ratio = comp_w_h_ratio
|
||||
self.comp_score_thr = comp_score_thr
|
||||
self.text_region_thr = text_region_thr
|
||||
self.center_region_thr = center_region_thr
|
||||
self.center_region_area_thr = center_region_area_thr
|
||||
|
||||
def propose_comps(self, score_map, top_height_map, bot_height_map, sin_map,
|
||||
cos_map, comp_score_thr, min_width, max_width,
|
||||
comp_shrink_ratio, comp_w_h_ratio):
|
||||
"""Propose text components.
|
||||
|
||||
Args:
|
||||
score_map (ndarray): The score map for NMS.
|
||||
top_height_map (ndarray): The predicted text height map from each
|
||||
pixel in text center region to top sideline.
|
||||
bot_height_map (ndarray): The predicted text height map from each
|
||||
pixel in text center region to bottom sideline.
|
||||
sin_map (ndarray): The predicted sin(theta) map.
|
||||
cos_map (ndarray): The predicted cos(theta) map.
|
||||
comp_score_thr (float): The score threshold of text component.
|
||||
min_width (float): The minimum width of text components.
|
||||
max_width (float): The maximum width of text components.
|
||||
comp_shrink_ratio (float): The shrink ratio of text components.
|
||||
comp_w_h_ratio (float): The width to height ratio of text
|
||||
components.
|
||||
|
||||
Returns:
|
||||
text_comps (ndarray): The text components.
|
||||
"""
|
||||
|
||||
comp_centers = np.argwhere(score_map > comp_score_thr)
|
||||
comp_centers = comp_centers[np.argsort(comp_centers[:, 0])]
|
||||
y = comp_centers[:, 0]
|
||||
x = comp_centers[:, 1]
|
||||
|
||||
top_height = top_height_map[y, x].reshape((-1, 1)) * comp_shrink_ratio
|
||||
bot_height = bot_height_map[y, x].reshape((-1, 1)) * comp_shrink_ratio
|
||||
sin = sin_map[y, x].reshape((-1, 1))
|
||||
cos = cos_map[y, x].reshape((-1, 1))
|
||||
|
||||
top_mid_pts = comp_centers + np.hstack(
|
||||
[top_height * sin, top_height * cos])
|
||||
bot_mid_pts = comp_centers - np.hstack(
|
||||
[bot_height * sin, bot_height * cos])
|
||||
|
||||
width = (top_height + bot_height) * comp_w_h_ratio
|
||||
width = np.clip(width, min_width, max_width)
|
||||
r = width / 2
|
||||
|
||||
tl = top_mid_pts[:, ::-1] - np.hstack([-r * sin, r * cos])
|
||||
tr = top_mid_pts[:, ::-1] + np.hstack([-r * sin, r * cos])
|
||||
br = bot_mid_pts[:, ::-1] + np.hstack([-r * sin, r * cos])
|
||||
bl = bot_mid_pts[:, ::-1] - np.hstack([-r * sin, r * cos])
|
||||
text_comps = np.hstack([tl, tr, br, bl]).astype(np.float32)
|
||||
|
||||
score = score_map[y, x].reshape((-1, 1))
|
||||
text_comps = np.hstack([text_comps, score])
|
||||
|
||||
return text_comps
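
# Geometry note: each proposed component is a rotated box built from a candidate
# center pixel, the predicted top/bottom heights projected along (sin, cos), and a
# width derived from comp_w_h_ratio and clipped to [min_width, max_width]; the pixel
# score is appended as the last column for locality-aware NMS.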
|
||||
|
||||
def propose_comps_and_attribs(self, text_region_map, center_region_map,
|
||||
top_height_map, bot_height_map, sin_map,
|
||||
cos_map):
|
||||
"""Generate text components and attributes.
|
||||
|
||||
Args:
|
||||
text_region_map (ndarray): The predicted text region probability
|
||||
map.
|
||||
center_region_map (ndarray): The predicted text center region
|
||||
probability map.
|
||||
top_height_map (ndarray): The predicted text height map from each
|
||||
pixel in text center region to top sideline.
|
||||
bot_height_map (ndarray): The predicted text height map from each
|
||||
pixel in text center region to bottom sideline.
|
||||
sin_map (ndarray): The predicted sin(theta) map.
|
||||
cos_map (ndarray): The predicted cos(theta) map.
|
||||
|
||||
Returns:
|
||||
comp_attribs (ndarray): The text component attributes.
|
||||
text_comps (ndarray): The text components.
|
||||
"""
|
||||
|
||||
assert (text_region_map.shape == center_region_map.shape ==
|
||||
top_height_map.shape == bot_height_map.shape == sin_map.shape
|
||||
== cos_map.shape)
|
||||
text_mask = text_region_map > self.text_region_thr
|
||||
center_region_mask = (center_region_map >
|
||||
self.center_region_thr) * text_mask
|
||||
|
||||
scale = np.sqrt(1.0 / (sin_map**2 + cos_map**2))
|
||||
sin_map, cos_map = sin_map * scale, cos_map * scale
|
||||
|
||||
center_region_mask = fill_hole(center_region_mask)
|
||||
center_region_contours, _ = cv2.findContours(
|
||||
center_region_mask.astype(np.uint8), cv2.RETR_TREE,
|
||||
cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
mask_sz = center_region_map.shape
|
||||
comp_list = []
|
||||
for contour in center_region_contours:
|
||||
current_center_mask = np.zeros(mask_sz)
|
||||
cv2.drawContours(current_center_mask, [contour], -1, 1, -1)
|
||||
if current_center_mask.sum() <= self.center_region_area_thr:
|
||||
continue
|
||||
score_map = text_region_map * current_center_mask
|
||||
|
||||
text_comps = self.propose_comps(score_map, top_height_map,
|
||||
bot_height_map, sin_map, cos_map,
|
||||
self.comp_score_thr,
|
||||
self.min_width, self.max_width,
|
||||
self.comp_shrink_ratio,
|
||||
self.comp_w_h_ratio)
|
||||
|
||||
text_comps = la_nms(text_comps, self.nms_thr)
|
||||
text_comp_mask = np.zeros(mask_sz)
|
||||
text_comp_boxs = text_comps[:, :8].reshape(
|
||||
(-1, 4, 2)).astype(np.int32)
|
||||
|
||||
cv2.drawContours(text_comp_mask, text_comp_boxs, -1, 1, -1)
|
||||
if (text_comp_mask * text_mask).sum() < text_comp_mask.sum() * 0.5:
|
||||
continue
|
||||
if text_comps.shape[-1] > 0:
|
||||
comp_list.append(text_comps)
|
||||
|
||||
if len(comp_list) <= 0:
|
||||
return None, None
|
||||
|
||||
text_comps = np.vstack(comp_list)
|
||||
text_comp_boxs = text_comps[:, :8].reshape((-1, 4, 2))
|
||||
centers = np.mean(text_comp_boxs, axis=1).astype(np.int32)
|
||||
x = centers[:, 0]
|
||||
y = centers[:, 1]
|
||||
|
||||
scores = []
|
||||
for text_comp_box in text_comp_boxs:
|
||||
text_comp_box[:, 0] = np.clip(text_comp_box[:, 0], 0,
|
||||
mask_sz[1] - 1)
|
||||
text_comp_box[:, 1] = np.clip(text_comp_box[:, 1], 0,
|
||||
mask_sz[0] - 1)
|
||||
min_coord = np.min(text_comp_box, axis=0).astype(np.int32)
|
||||
max_coord = np.max(text_comp_box, axis=0).astype(np.int32)
|
||||
text_comp_box = text_comp_box - min_coord
|
||||
box_sz = (max_coord - min_coord + 1)
|
||||
temp_comp_mask = np.zeros((box_sz[1], box_sz[0]), dtype=np.uint8)
|
||||
cv2.fillPoly(temp_comp_mask, [text_comp_box.astype(np.int32)], 1)
|
||||
temp_region_patch = text_region_map[min_coord[1]:(max_coord[1] +
|
||||
1),
|
||||
min_coord[0]:(max_coord[0] +
|
||||
1)]
|
||||
score = cv2.mean(temp_region_patch, temp_comp_mask)[0]
|
||||
scores.append(score)
|
||||
scores = np.array(scores).reshape((-1, 1))
|
||||
text_comps = np.hstack([text_comps[:, :-1], scores])
|
||||
|
||||
h = top_height_map[y, x].reshape(
|
||||
(-1, 1)) + bot_height_map[y, x].reshape((-1, 1))
|
||||
w = np.clip(h * self.comp_w_h_ratio, self.min_width, self.max_width)
|
||||
sin = sin_map[y, x].reshape((-1, 1))
|
||||
cos = cos_map[y, x].reshape((-1, 1))
|
||||
|
||||
x = x.reshape((-1, 1))
|
||||
y = y.reshape((-1, 1))
|
||||
comp_attribs = np.hstack([x, y, h, w, cos, sin])
|
||||
|
||||
return comp_attribs, text_comps
|
||||
|
||||
def generate_local_graphs(self, sorted_dist_inds, node_feats):
|
||||
"""Generate local graphs and graph convolution network input data.
|
||||
|
||||
Args:
|
||||
sorted_dist_inds (ndarray): The node indices sorted according to
|
||||
the Euclidean distance.
|
||||
node_feats (tensor): The features of nodes in graph.
|
||||
|
||||
Returns:
|
||||
local_graphs_node_feats (tensor): The features of nodes in local
|
||||
graphs.
|
||||
adjacent_matrices (tensor): The adjacent matrices.
|
||||
pivots_knn_inds (tensor): The k-nearest neighbor indices in
|
||||
local graphs.
|
||||
pivots_local_graphs (tensor): The indices of nodes in local
|
||||
graphs.
|
||||
"""
|
||||
|
||||
assert sorted_dist_inds.ndim == 2
|
||||
assert (sorted_dist_inds.shape[0] == sorted_dist_inds.shape[1] ==
|
||||
node_feats.shape[0])
|
||||
|
||||
knn_graph = sorted_dist_inds[:, 1:self.k_at_hops[0] + 1]
|
||||
pivot_local_graphs = []
|
||||
pivot_knns = []
|
||||
device = node_feats.device
|
||||
|
||||
for pivot_ind, knn in enumerate(knn_graph):
|
||||
|
||||
local_graph_neighbors = set(knn)
|
||||
|
||||
for neighbor_ind in knn:
|
||||
local_graph_neighbors.update(
|
||||
set(sorted_dist_inds[neighbor_ind,
|
||||
1:self.k_at_hops[1] + 1]))
|
||||
|
||||
local_graph_neighbors.discard(pivot_ind)
|
||||
pivot_local_graph = list(local_graph_neighbors)
|
||||
pivot_local_graph.insert(0, pivot_ind)
|
||||
pivot_knn = [pivot_ind] + list(knn)
|
||||
|
||||
pivot_local_graphs.append(pivot_local_graph)
|
||||
pivot_knns.append(pivot_knn)
|
||||
|
||||
num_max_nodes = max([
|
||||
len(pivot_local_graph) for pivot_local_graph in pivot_local_graphs
|
||||
])
|
||||
|
||||
local_graphs_node_feat = []
|
||||
adjacent_matrices = []
|
||||
pivots_knn_inds = []
|
||||
pivots_local_graphs = []
|
||||
|
||||
for graph_ind, pivot_knn in enumerate(pivot_knns):
|
||||
pivot_local_graph = pivot_local_graphs[graph_ind]
|
||||
num_nodes = len(pivot_local_graph)
|
||||
pivot_ind = pivot_local_graph[0]
|
||||
node2ind_map = {j: i for i, j in enumerate(pivot_local_graph)}
|
||||
|
||||
knn_inds = torch.tensor([node2ind_map[i]
|
||||
for i in pivot_knn[1:]]).long().to(device)
|
||||
pivot_feats = node_feats[pivot_ind]
|
||||
normalized_feats = node_feats[pivot_local_graph] - pivot_feats
|
||||
|
||||
adjacent_matrix = np.zeros((num_nodes, num_nodes))
|
||||
for node in pivot_local_graph:
|
||||
neighbors = sorted_dist_inds[node,
|
||||
1:self.active_connection + 1]
|
||||
for neighbor in neighbors:
|
||||
if neighbor in pivot_local_graph:
|
||||
adjacent_matrix[node2ind_map[node],
|
||||
node2ind_map[neighbor]] = 1
|
||||
adjacent_matrix[node2ind_map[neighbor],
|
||||
node2ind_map[node]] = 1
|
||||
|
||||
adjacent_matrix = normalize_adjacent_matrix(
|
||||
adjacent_matrix, mode='DAD')
|
||||
pad_adjacent_matrix = torch.zeros((num_max_nodes, num_max_nodes),
|
||||
dtype=torch.float,
|
||||
device=device)
|
||||
pad_adjacent_matrix[:num_nodes, :num_nodes] = adjacent_matrix
|
||||
|
||||
pad_normalized_feats = torch.cat([
|
||||
normalized_feats,
|
||||
torch.zeros(
|
||||
(num_max_nodes - num_nodes, normalized_feats.shape[1]),
|
||||
dtype=torch.float,
|
||||
device=device)
|
||||
],
|
||||
dim=0)
|
||||
|
||||
local_graph_nodes = torch.tensor(pivot_local_graph)
|
||||
local_graph_nodes = torch.cat([
|
||||
local_graph_nodes,
|
||||
torch.zeros(num_max_nodes - num_nodes, dtype=torch.long)
|
||||
],
|
||||
dim=-1)
|
||||
|
||||
local_graphs_node_feat.append(pad_normalized_feats)
|
||||
adjacent_matrices.append(pad_adjacent_matrix)
|
||||
pivots_knn_inds.append(knn_inds)
|
||||
pivots_local_graphs.append(local_graph_nodes)
|
||||
|
||||
local_graphs_node_feat = torch.stack(local_graphs_node_feat, 0)
|
||||
adjacent_matrices = torch.stack(adjacent_matrices, 0)
|
||||
pivots_knn_inds = torch.stack(pivots_knn_inds, 0)
|
||||
pivots_local_graphs = torch.stack(pivots_local_graphs, 0)
|
||||
|
||||
return (local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
|
||||
pivots_local_graphs)
|
||||
|
||||
def __call__(self, preds, feat_maps):
|
||||
"""Generate local graphs and graph convolutional network input data.
|
||||
|
||||
Args:
|
||||
preds (tensor): The predicted maps.
|
||||
feat_maps (tensor): The feature maps to extract content feature of
|
||||
text components.
|
||||
|
||||
Returns:
|
||||
none_flag (bool): The flag showing whether the number of proposed
|
||||
text components is 0.
|
||||
local_graphs_node_feats (tensor): The features of nodes in local
|
||||
graphs.
|
||||
adjacent_matrices (tensor): The adjacent matrices.
|
||||
pivots_knn_inds (tensor): The k-nearest neighbor indices in
|
||||
local graphs.
|
||||
pivots_local_graphs (tensor): The indices of nodes in local
|
||||
graphs.
|
||||
text_comps (ndarray): The predicted text components.
|
||||
"""
|
||||
|
||||
if preds.ndim == 4:
|
||||
assert preds.shape[0] == 1
|
||||
preds = torch.squeeze(preds)
|
||||
pred_text_region = torch.sigmoid(preds[0]).data.cpu().numpy()
|
||||
pred_center_region = torch.sigmoid(preds[1]).data.cpu().numpy()
|
||||
pred_sin_map = preds[2].data.cpu().numpy()
|
||||
pred_cos_map = preds[3].data.cpu().numpy()
|
||||
pred_top_height_map = preds[4].data.cpu().numpy()
|
||||
pred_bot_height_map = preds[5].data.cpu().numpy()
|
||||
device = preds.device
|
||||
|
||||
comp_attribs, text_comps = self.propose_comps_and_attribs(
|
||||
pred_text_region, pred_center_region, pred_top_height_map,
|
||||
pred_bot_height_map, pred_sin_map, pred_cos_map)
|
||||
|
||||
if comp_attribs is None or len(comp_attribs) < 2:
|
||||
none_flag = True
|
||||
return none_flag, (0, 0, 0, 0, 0)
|
||||
|
||||
comp_centers = comp_attribs[:, 0:2]
|
||||
distance_matrix = euclidean_distance_matrix(comp_centers, comp_centers)
|
||||
|
||||
geo_feats = feature_embedding(comp_attribs, self.node_geo_feat_dim)
|
||||
geo_feats = torch.from_numpy(geo_feats).to(preds.device)
|
||||
|
||||
batch_id = np.zeros((comp_attribs.shape[0], 1), dtype=np.float32)
|
||||
comp_attribs = comp_attribs.astype(np.float32)
|
||||
angle = np.arccos(comp_attribs[:, -2]) * np.sign(comp_attribs[:, -1])
|
||||
angle = angle.reshape((-1, 1))
|
||||
rotated_rois = np.hstack([batch_id, comp_attribs[:, :-2], angle])
|
||||
rois = torch.from_numpy(rotated_rois).to(device)
|
||||
|
||||
content_feats = self.pooling(feat_maps, rois)
|
||||
content_feats = content_feats.view(content_feats.shape[0],
|
||||
-1).to(device)
|
||||
node_feats = torch.cat([content_feats, geo_feats], dim=-1)
|
||||
|
||||
sorted_dist_inds = np.argsort(distance_matrix, axis=1)
|
||||
(local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
|
||||
pivots_local_graphs) = self.generate_local_graphs(
|
||||
sorted_dist_inds, node_feats)
|
||||
|
||||
none_flag = False
|
||||
return none_flag, (local_graphs_node_feat, adjacent_matrices,
|
||||
pivots_knn_inds, pivots_local_graphs, text_comps)
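
# Illustrative inference-time sketch (mirrors the module's unit test; the maps, thresholds
# and sizes are assumptions for exposition). Channel order of the predicted maps:
# text region, center region, sin, cos, top height, bottom height.
import torch
proposal_graph_generator = ProposalLocalGraphs(
    (4, 4), 2, 24, 1., (2, 2), 0.1, 3., 6., 1., 0.5, 0.3, 0.5, 0.5, 2)
maps = torch.zeros((1, 6, 224, 224), dtype=torch.float)
maps[:, 0:2, :, :] = -10.
maps[:, 0, 60:100, 50:170] = 10.   # text region logits
maps[:, 1, 75:85, 60:160] = 10.    # center region logits
maps[:, 3, 75:85, 60:160] = 1.     # cos(theta)
maps[:, 4:6, 75:85, 60:160] = 10.  # top/bottom height maps
feat_maps = torch.randn(1, 6, 224, 224)
none_flag, graph_data = proposal_graph_generator(maps, feat_maps)
if not none_flag:  # graph_data is (0, 0, 0, 0, 0) when no components are proposed
    node_feats, adj_mats, knn_inds, local_graphs, text_comps = graph_data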
|
|
@ -0,0 +1,116 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def normalize_adjacent_matrix(A, mode='AD'):
|
||||
"""Normalize adjacent matrix for GCN. This code was partially adapted from
|
||||
https://github.com/GXYM/DRRG licensed under the MIT license.
|
||||
|
||||
Args:
|
||||
A (ndarray): The adjacent matrix.
|
||||
mode (string): The normalize mode.
|
||||
|
||||
Returns:
|
||||
G (ndarray): The normalized adjacent matrix.
|
||||
"""
|
||||
assert A.ndim == 2
|
||||
assert A.shape[0] == A.shape[1]
|
||||
|
||||
if mode == 'DAD':
|
||||
A = A + np.eye(A.shape[0])
|
||||
d = np.sum(A, axis=0)
|
||||
d_inv = np.power(d, -0.5).flatten()
|
||||
d_inv[np.isinf(d_inv)] = 0.0
|
||||
d_inv = np.diag(d_inv)
|
||||
G = A.dot(d_inv).transpose().dot(d_inv)
|
||||
G = torch.from_numpy(G)
|
||||
elif mode == 'AD':
|
||||
A = A + np.eye(A.shape[0])
|
||||
A = torch.from_numpy(A)
|
||||
D = A.sum(1, keepdim=True)
|
||||
G = A.div(D)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return G
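
# Minimal illustration on a toy 2-node graph (assumed input, for exposition only):
# 'DAD' yields D^(-1/2) (A + I) D^(-1/2) and 'AD' yields the row-normalized (A + I);
# both return a torch.Tensor.
toy_A = np.array([[0., 1.], [1., 0.]])
G_dad = normalize_adjacent_matrix(toy_A, mode='DAD')  # tensor of shape (2, 2)
G_ad = normalize_adjacent_matrix(toy_A, mode='AD')    # tensor of shape (2, 2)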
|
||||
|
||||
|
||||
def euclidean_distance_matrix(A, B):
|
||||
"""Calculate the Euclidean distance matrix.
|
||||
|
||||
Args:
|
||||
A (ndarray): The point sequence.
|
||||
B (ndarray): The point sequence with the same dimensions as A.
|
||||
|
||||
Returns:
|
||||
D (ndarray): The Euclidean distance matrix.
|
||||
"""
|
||||
assert A.ndim == 2
|
||||
assert B.ndim == 2
|
||||
assert A.shape[1] == B.shape[1]
|
||||
|
||||
m = A.shape[0]
|
||||
n = B.shape[0]
|
||||
|
||||
A_dots = (A * A).sum(axis=1).reshape((m, 1)) * np.ones(shape=(1, n))
|
||||
B_dots = (B * B).sum(axis=1) * np.ones(shape=(m, 1))
|
||||
D_squared = A_dots + B_dots - 2 * A.dot(B.T)
|
||||
|
||||
zero_mask = np.less(D_squared, 0.0)
|
||||
D_squared[zero_mask] = 0.0
|
||||
D = np.sqrt(D_squared)
|
||||
return D
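
# Quick sanity sketch (toy points, for exposition): D[i, j] is the distance between
# A[i] and B[j].
toy_pts = np.array([[0., 0.], [3., 4.]])
toy_dists = euclidean_distance_matrix(toy_pts, toy_pts)  # [[0., 5.], [5., 0.]]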
|
||||
|
||||
|
||||
def feature_embedding(input_feats, out_feat_len):
|
||||
"""Embed features. This code was partially adapted from
|
||||
https://github.com/GXYM/DRRG licensed under the MIT license.
|
||||
|
||||
Args:
|
||||
input_feats (ndarray): The input features of shape (N, d), where N is
|
||||
the number of nodes in the graph and d is the input feature vector length.
|
||||
out_feat_len (int): The length of output feature vector.
|
||||
|
||||
Returns:
|
||||
embedded_feats (ndarray): The embedded features.
|
||||
"""
|
||||
assert input_feats.ndim == 2
|
||||
assert isinstance(out_feat_len, int)
|
||||
assert out_feat_len >= input_feats.shape[1]
|
||||
|
||||
num_nodes = input_feats.shape[0]
|
||||
feat_dim = input_feats.shape[1]
|
||||
feat_repeat_times = out_feat_len // feat_dim
|
||||
residue_dim = out_feat_len % feat_dim
|
||||
|
||||
if residue_dim > 0:
|
||||
embed_wave = np.array([
|
||||
np.power(1000, 2.0 * (j // 2) / feat_repeat_times + 1)
|
||||
for j in range(feat_repeat_times + 1)
|
||||
]).reshape((feat_repeat_times + 1, 1, 1))
|
||||
repeat_feats = np.repeat(
|
||||
np.expand_dims(input_feats, axis=0), feat_repeat_times, axis=0)
|
||||
residue_feats = np.hstack([
|
||||
input_feats[:, 0:residue_dim],
|
||||
np.zeros((num_nodes, feat_dim - residue_dim))
|
||||
])
|
||||
residue_feats = np.expand_dims(residue_feats, axis=0)
|
||||
repeat_feats = np.concatenate([repeat_feats, residue_feats], axis=0)
|
||||
embedded_feats = repeat_feats / embed_wave
|
||||
embedded_feats[:, 0::2] = np.sin(embedded_feats[:, 0::2])
|
||||
embedded_feats[:, 1::2] = np.cos(embedded_feats[:, 1::2])
|
||||
embedded_feats = np.transpose(embedded_feats, (1, 0, 2)).reshape(
|
||||
(num_nodes, -1))[:, 0:out_feat_len]
|
||||
else:
|
||||
embed_wave = np.array([
|
||||
np.power(1000, 2.0 * (j // 2) / feat_repeat_times)
|
||||
for j in range(feat_repeat_times)
|
||||
]).reshape((feat_repeat_times, 1, 1))
|
||||
repeat_feats = np.repeat(
|
||||
np.expand_dims(input_feats, axis=0), feat_repeat_times, axis=0)
|
||||
embedded_feats = repeat_feats / embed_wave
|
||||
embedded_feats[:, 0::2] = np.sin(embedded_feats[:, 0::2])
|
||||
embedded_feats[:, 1::2] = np.cos(embedded_feats[:, 1::2])
|
||||
embedded_feats = np.transpose(embedded_feats, (1, 0, 2)).reshape(
|
||||
(num_nodes, -1)).astype(np.float32)
|
||||
|
||||
return embedded_feats
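
# Shape sketch (assumed sizes, for exposition): the sinusoidal embedding expands each
# d-dimensional geometric attribute vector to out_feat_len dimensions.
toy_geo_attribs = np.random.randn(10, 6)
toy_embedded = feature_embedding(toy_geo_attribs, 32)
assert toy_embedded.shape == (10, 32)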
|
|
@ -1,6 +1,6 @@
|
|||
from .fpem_ffm import FPEM_FFM
|
||||
from .fpn_cat import FPNC
|
||||
from .fpn_unet import FPN_UNET
|
||||
from .fpn_unet import FPN_UNet
|
||||
from .fpnf import FPNF
|
||||
|
||||
__all__ = ['FPEM_FFM', 'FPNF', 'FPNC', 'FPN_UNET']
|
||||
__all__ = ['FPEM_FFM', 'FPNF', 'FPNC', 'FPN_UNet']
|
||||
|
|
|
@ -30,7 +30,7 @@ class UpBlock(nn.Module):
|
|||
|
||||
|
||||
@NECKS.register_module()
|
||||
class FPN_UNET(nn.Module):
|
||||
class FPN_UNet(nn.Module):
|
||||
"""The class for implementing DRRG and TextSnake U-Net-like FPN.
|
||||
|
||||
DRRG: Deep Relational Reasoning Graph Network for Arbitrary Shape
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
import functools
|
||||
import operator
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pyclipper
|
||||
|
@ -28,6 +31,8 @@ def decode(
|
|||
return textsnake_decode(**kwargs)
|
||||
if decoding_type == 'fcenet':
|
||||
return fcenet_decode(**kwargs)
|
||||
if decoding_type == 'drrg':
|
||||
return drrg_decode(**kwargs)
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -553,3 +558,350 @@ def generate_exp_matrix(point_num, fourier_degree):
|
|||
exp_matrix[i, :] = temp * (i - fourier_degree)
|
||||
|
||||
return np.power(e, exp_matrix)
|
||||
|
||||
|
||||
class Node(object):
|
||||
|
||||
def __init__(self, ind):
|
||||
self.__ind = ind
|
||||
self.__links = set()
|
||||
|
||||
@property
|
||||
def ind(self):
|
||||
return self.__ind
|
||||
|
||||
@property
|
||||
def links(self):
|
||||
return set(self.__links)
|
||||
|
||||
def add_link(self, link_node):
|
||||
self.__links.add(link_node)
|
||||
link_node.__links.add(self)
|
||||
|
||||
|
||||
def graph_propagation(edges, scores, text_comps, edge_len_thr=50.):
|
||||
"""Propagate edge score information and construct graph. This code was
|
||||
partially adapted from https://github.com/GXYM/DRRG licensed under the MIT
|
||||
license.
|
||||
|
||||
Args:
|
||||
edges (ndarray): The edge array of shape N * 2, each row is a node
|
||||
index pair that makes up an edge in graph.
|
||||
scores (ndarray): The edge score array.
|
||||
text_comps (ndarray): The text components.
|
||||
edge_len_thr (float): The edge length threshold.
|
||||
|
||||
Returns:
|
||||
vertices (list[Node]): The Nodes in graph.
|
||||
score_dict (dict): The edge score dict.
|
||||
"""
|
||||
assert edges.ndim == 2
|
||||
assert edges.shape[1] == 2
|
||||
assert edges.shape[0] == scores.shape[0]
|
||||
assert text_comps.ndim == 2
|
||||
assert isinstance(edge_len_thr, float)
|
||||
|
||||
edges = np.sort(edges, axis=1)
|
||||
score_dict = {}
|
||||
for i, edge in enumerate(edges):
|
||||
if text_comps is not None:
|
||||
box1 = text_comps[edge[0], :8].reshape(4, 2)
|
||||
box2 = text_comps[edge[1], :8].reshape(4, 2)
|
||||
center1 = np.mean(box1, axis=0)
|
||||
center2 = np.mean(box2, axis=0)
|
||||
distance = norm(center1 - center2)
|
||||
if distance > edge_len_thr:
|
||||
scores[i] = 0
|
||||
if (edge[0], edge[1]) in score_dict:
|
||||
score_dict[edge[0], edge[1]] = 0.5 * (
|
||||
score_dict[edge[0], edge[1]] + scores[i])
|
||||
else:
|
||||
score_dict[edge[0], edge[1]] = scores[i]
|
||||
|
||||
nodes = np.sort(np.unique(edges.flatten()))
|
||||
mapping = -1 * np.ones((np.max(nodes) + 1), dtype=int)
|
||||
mapping[nodes] = np.arange(nodes.shape[0])
|
||||
order_inds = mapping[edges]
|
||||
vertices = [Node(node) for node in nodes]
|
||||
for ind in order_inds:
|
||||
vertices[ind[0]].add_link(vertices[ind[1]])
|
||||
|
||||
return vertices, score_dict
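
# Toy sketch (fabricated inputs, for exposition): three components and two scored edges;
# coincident dummy boxes keep every edge below edge_len_thr.
toy_edges = np.array([[0, 1], [1, 2]])
toy_scores = np.array([0.9, 0.1], dtype=np.float32)
toy_comps = np.zeros((3, 9), dtype=np.float32)
vertices, score_dict = graph_propagation(toy_edges, toy_scores, toy_comps)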
|
||||
|
||||
|
||||
def connected_components(nodes, score_dict, link_thr):
|
||||
"""Conventional connected components searching. This code was partially
|
||||
adapted from https://github.com/GXYM/DRRG licensed under the MIT license.
|
||||
|
||||
Args:
|
||||
nodes (list[Node]): The list of Node objects.
|
||||
score_dict (dict): The edge score dict.
|
||||
link_thr (float): The link threshold.
|
||||
|
||||
Returns:
|
||||
clusters (List[list[Node]]): The clustered Node objects.
|
||||
"""
|
||||
assert isinstance(nodes, list)
|
||||
assert all([isinstance(node, Node) for node in nodes])
|
||||
assert isinstance(score_dict, dict)
|
||||
assert isinstance(link_thr, float)
|
||||
|
||||
clusters = []
|
||||
nodes = set(nodes)
|
||||
while nodes:
|
||||
node = nodes.pop()
|
||||
cluster = {node}
|
||||
node_queue = [node]
|
||||
while node_queue:
|
||||
node = node_queue.pop(0)
|
||||
neighbors = set([
|
||||
neighbor for neighbor in node.links if
|
||||
score_dict[tuple(sorted([node.ind, neighbor.ind]))] >= link_thr
|
||||
])
|
||||
neighbors.difference_update(cluster)
|
||||
nodes.difference_update(neighbors)
|
||||
cluster.update(neighbors)
|
||||
node_queue.extend(neighbors)
|
||||
clusters.append(list(cluster))
|
||||
return clusters
|
||||
|
||||
|
||||
def clusters2labels(clusters, num_nodes):
|
||||
"""Convert clusters of Node to text component labels. This code was
|
||||
partially adapted from https://github.com/GXYM/DRRG licensed under the MIT
|
||||
license.
|
||||
|
||||
Args:
|
||||
clusters (List[list[Node]]): The clusters of Node objects.
|
||||
num_nodes (int): The total node number of graphs in an image.
|
||||
|
||||
Returns:
|
||||
node_labels (ndarray): The node label array.
|
||||
"""
|
||||
assert isinstance(clusters, list)
|
||||
assert all([isinstance(cluster, list) for cluster in clusters])
|
||||
assert all(
|
||||
[isinstance(node, Node) for cluster in clusters for node in cluster])
|
||||
assert isinstance(num_nodes, int)
|
||||
|
||||
node_labels = np.zeros(num_nodes)
|
||||
for cluster_ind, cluster in enumerate(clusters):
|
||||
for node in cluster:
|
||||
node_labels[node.ind] = cluster_ind
|
||||
return node_labels
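
# Continuing the toy sketch above (the link threshold is an assumption): keep only
# high-score links, then turn the resulting clusters into per-component instance labels.
toy_clusters = connected_components(vertices, score_dict, 0.5)
toy_labels = clusters2labels(toy_clusters, len(vertices))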
|
||||
|
||||
|
||||
def remove_single(text_comps, comp_pred_labels):
|
||||
"""Remove isolated text components. This code was partially adapted from
|
||||
https://github.com/GXYM/DRRG licensed under the MIT license.
|
||||
|
||||
Args:
|
||||
text_comps (ndarray): The text components.
|
||||
comp_pred_labels (ndarray): The clustering labels of text components.
|
||||
|
||||
Returns:
|
||||
filtered_text_comps (ndarray): The text components with isolated ones
|
||||
removed.
|
||||
comp_pred_labels (ndarray): The clustering labels with labels of
|
||||
isolated text components removed.
|
||||
"""
|
||||
assert text_comps.ndim == 2
|
||||
assert text_comps.shape[0] == comp_pred_labels.shape[0]
|
||||
|
||||
single_flags = np.zeros_like(comp_pred_labels)
|
||||
pred_labels = np.unique(comp_pred_labels)
|
||||
for label in pred_labels:
|
||||
current_label_flag = (comp_pred_labels == label)
|
||||
if np.sum(current_label_flag) == 1:
|
||||
single_flags[np.where(current_label_flag)[0][0]] = 1
|
||||
keep_ind = [i for i in range(len(comp_pred_labels)) if not single_flags[i]]
|
||||
filtered_text_comps = text_comps[keep_ind, :]
|
||||
filtered_labels = comp_pred_labels[keep_ind]
|
||||
|
||||
return filtered_text_comps, filtered_labels
|
||||
|
||||
|
||||
def norm2(point1, point2):
|
||||
return ((point1[0] - point2[0])**2 + (point1[1] - point2[1])**2)**0.5
|
||||
|
||||
|
||||
def min_connect_path(points):
|
||||
"""Find the shortest path to traverse all points. This code was partially
|
||||
adapted from https://github.com/GXYM/DRRG licensed under the MIT license.
|
||||
|
||||
Args:
|
||||
points(List[list[int]]): The point sequence [[x0, y0], [x1, y1], ...].
|
||||
|
||||
Returns:
|
||||
shortest_path(List[list[int]]): The shortest index path.
|
||||
"""
|
||||
assert isinstance(points, list)
|
||||
assert all([isinstance(point, list) for point in points])
|
||||
assert all([isinstance(coord, int) for point in points for coord in point])
|
||||
|
||||
points_queue = points.copy()
|
||||
shortest_path = []
|
||||
current_edge = [[], []]
|
||||
|
||||
edge_dict0 = {}
|
||||
edge_dict1 = {}
|
||||
current_edge[0] = points_queue[0]
|
||||
current_edge[1] = points_queue[0]
|
||||
points_queue.remove(points_queue[0])
|
||||
while points_queue:
|
||||
for point in points_queue:
|
||||
length0 = norm2(point, current_edge[0])
|
||||
edge_dict0[length0] = [point, current_edge[0]]
|
||||
length1 = norm2(current_edge[1], point)
|
||||
edge_dict1[length1] = [current_edge[1], point]
|
||||
key0 = min(edge_dict0.keys())
|
||||
key1 = min(edge_dict1.keys())
|
||||
|
||||
if key0 <= key1:
|
||||
start = edge_dict0[key0][0]
|
||||
end = edge_dict0[key0][1]
|
||||
shortest_path.insert(0, [points.index(start), points.index(end)])
|
||||
points_queue.remove(start)
|
||||
current_edge[0] = start
|
||||
else:
|
||||
start = edge_dict1[key1][0]
|
||||
end = edge_dict1[key1][1]
|
||||
shortest_path.append([points.index(start), points.index(end)])
|
||||
points_queue.remove(end)
|
||||
current_edge[1] = end
|
||||
|
||||
edge_dict0 = {}
|
||||
edge_dict1 = {}
|
||||
|
||||
shortest_path = functools.reduce(operator.concat, shortest_path)
|
||||
shortest_path = sorted(set(shortest_path), key=shortest_path.index)
|
||||
|
||||
return shortest_path
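
# Toy sketch (made-up centers, for exposition): the returned value is an index ordering
# that visits neighboring centers consecutively.
toy_centers = [[0, 0], [10, 0], [5, 0]]
toy_path = min_connect_path(toy_centers)  # e.g. [1, 2, 0]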
|
||||
|
||||
|
||||
def in_contour(cont, point):
|
||||
x, y = point
|
||||
is_inner = cv2.pointPolygonTest(cont, (int(x), int(y)), False) > 0.5
|
||||
return is_inner
|
||||
|
||||
|
||||
def fix_corner(top_line, bot_line, start_box, end_box):
|
||||
"""Add corner points to predicted side lines. This code was partially
|
||||
adapted from https://github.com/GXYM/DRRG licensed under the MIT license.
|
||||
|
||||
Args:
|
||||
top_line (List[list[int]]): The predicted top sidelines of text
|
||||
instance.
|
||||
bot_line (List[list[int]]): The predicted bottom sidelines of text
|
||||
instance.
|
||||
start_box (ndarray): The first text component box.
|
||||
end_box (ndarray): The last text component box.
|
||||
|
||||
Returns:
|
||||
top_line (List[list[int]]): The top sidelines with corner point added.
|
||||
bot_line (List[list[int]]): The bottom sidelines with corner point
|
||||
added.
|
||||
"""
|
||||
assert isinstance(top_line, list)
|
||||
assert all(isinstance(point, list) for point in top_line)
|
||||
assert isinstance(bot_line, list)
|
||||
assert all(isinstance(point, list) for point in bot_line)
|
||||
assert start_box.shape == end_box.shape == (4, 2)
|
||||
|
||||
contour = np.array(top_line + bot_line[::-1])
|
||||
start_left_mid = (start_box[0] + start_box[3]) / 2
|
||||
start_right_mid = (start_box[1] + start_box[2]) / 2
|
||||
end_left_mid = (end_box[0] + end_box[3]) / 2
|
||||
end_right_mid = (end_box[1] + end_box[2]) / 2
|
||||
if not in_contour(contour, start_left_mid):
|
||||
top_line.insert(0, start_box[0].tolist())
|
||||
bot_line.insert(0, start_box[3].tolist())
|
||||
elif not in_contour(contour, start_right_mid):
|
||||
top_line.insert(0, start_box[1].tolist())
|
||||
bot_line.insert(0, start_box[2].tolist())
|
||||
if not in_contour(contour, end_left_mid):
|
||||
top_line.append(end_box[0].tolist())
|
||||
bot_line.append(end_box[3].tolist())
|
||||
elif not in_contour(contour, end_right_mid):
|
||||
top_line.append(end_box[1].tolist())
|
||||
bot_line.append(end_box[2].tolist())
|
||||
return top_line, bot_line
|
||||
|
||||
|
||||
def comps2boundaries(text_comps, comp_pred_labels):
|
||||
"""Construct text instance boundaries from clustered text components. This
|
||||
code was partially adapted from https://github.com/GXYM/DRRG licensed under
|
||||
the MIT license.
|
||||
|
||||
Args:
|
||||
text_comps (ndarray): The text components.
|
||||
comp_pred_labels (ndarray): The clustering labels of text components.
|
||||
|
||||
Returns:
|
||||
boundaries (List[list[float]]): The predicted boundaries of text
|
||||
instances.
|
||||
"""
|
||||
assert text_comps.ndim == 2
|
||||
assert len(text_comps) == len(comp_pred_labels)
|
||||
boundaries = []
|
||||
if len(text_comps) < 1:
|
||||
return boundaries
|
||||
for cluster_ind in range(0, int(np.max(comp_pred_labels)) + 1):
|
||||
cluster_comp_inds = np.where(comp_pred_labels == cluster_ind)
|
||||
text_comp_boxes = text_comps[cluster_comp_inds, :8].reshape(
|
||||
(-1, 4, 2)).astype(np.int32)
|
||||
score = np.mean(text_comps[cluster_comp_inds, -1])
|
||||
|
||||
if text_comp_boxes.shape[0] < 1:
|
||||
continue
|
||||
|
||||
elif text_comp_boxes.shape[0] > 1:
|
||||
centers = np.mean(
|
||||
text_comp_boxes, axis=1).astype(np.int32).tolist()
|
||||
shortest_path = min_connect_path(centers)
|
||||
text_comp_boxes = text_comp_boxes[shortest_path]
|
||||
top_line = np.mean(
|
||||
text_comp_boxes[:, 0:2, :], axis=1).astype(np.int32).tolist()
|
||||
bot_line = np.mean(
|
||||
text_comp_boxes[:, 2:4, :], axis=1).astype(np.int32).tolist()
|
||||
top_line, bot_line = fix_corner(top_line, bot_line,
|
||||
text_comp_boxes[0],
|
||||
text_comp_boxes[-1])
|
||||
boundary_points = top_line + bot_line[::-1]
|
||||
|
||||
else:
|
||||
top_line = text_comp_boxes[0, 0:2, :].astype(np.int32).tolist()
|
||||
bot_line = text_comp_boxes[0, 2:4, :].astype(np.int32).tolist()  # (br, bl) so top_line + bot_line traces tl, tr, br, bl
|
||||
boundary_points = top_line + bot_line
|
||||
|
||||
boundary = [p for coord in boundary_points for p in coord] + [score]
|
||||
boundaries.append(boundary)
|
||||
|
||||
return boundaries
|
||||
|
||||
|
||||
def drrg_decode(edges, scores, text_comps, link_thr):
|
||||
"""Merge text components and construct boundaries of text instances.
|
||||
|
||||
Args:
|
||||
edges (ndarray): The edge array of shape N * 2, each row is a node
|
||||
index pair that makes up an edge in graph.
|
||||
scores (ndarray): The edge score array.
|
||||
text_comps (ndarray): The text components.
|
||||
link_thr (float): The edge score threshold.
|
||||
|
||||
Returns:
|
||||
boundaries (List[list[float]]): The predicted boundaries of text
|
||||
instances.
|
||||
"""
|
||||
assert len(edges) == len(scores)
|
||||
assert text_comps.ndim == 2
|
||||
assert text_comps.shape[1] == 9
|
||||
assert isinstance(link_thr, float)
|
||||
vertices, score_dict = graph_propagation(edges, scores, text_comps)
|
||||
clusters = connected_components(vertices, score_dict, link_thr)
|
||||
pred_labels = clusters2labels(clusters, text_comps.shape[0])
|
||||
text_comps, pred_labels = remove_single(text_comps, pred_labels)
|
||||
boundaries = comps2boundaries(text_comps, pred_labels)
|
||||
|
||||
return boundaries
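
# End-to-end decode sketch (toy numbers, for exposition): two overlapping components whose
# linkage score exceeds link_thr are merged into a single text instance boundary.
toy_comp1 = [0., 0., 10., 0., 10., 10., 0., 10., 0.9]   # tl, tr, br, bl, score
toy_comp2 = [8., 0., 18., 0., 18., 10., 8., 10., 0.8]
toy_text_comps = np.array([toy_comp1, toy_comp2], dtype=np.float32)
toy_edges = np.array([[0, 1]])
toy_scores = np.array([0.95], dtype=np.float32)
boundaries = drrg_decode(toy_edges, toy_scores, toy_text_comps, 0.85)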
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
imgaug
|
||||
lanms-proper
|
||||
lmdb
|
||||
matplotlib
|
||||
numba>=0.45.1
|
||||
|
|
|
@ -20,7 +20,7 @@ line_length = 79
|
|||
multi_line_output = 0
|
||||
known_standard_library = setuptools
|
||||
known_first_party = mmdet,mmocr
|
||||
known_third_party = PIL,Polygon,cv2,imgaug,lmdb,matplotlib,mmcv,numpy,pyclipper,pycocotools,pytest,rapidfuzz,scipy,shapely,skimage,titlecase,torch,torchvision
|
||||
known_third_party = PIL,Polygon,cv2,imgaug,lanms,lmdb,matplotlib,mmcv,numpy,pyclipper,pycocotools,pytest,rapidfuzz,scipy,shapely,skimage,titlecase,torch,torchvision
|
||||
no_lines_before = STDLIB,LOCALFOLDER
|
||||
default_section = THIRDPARTY
|
||||
|
||||
|
|
|
@ -242,3 +242,94 @@ def test_fcenet_generate_targets():
|
|||
assert 'p3_maps' in results.keys()
|
||||
assert 'p4_maps' in results.keys()
|
||||
assert 'p5_maps' in results.keys()
|
||||
|
||||
|
||||
def test_gen_drrg_targets():
|
||||
target_generator = textdet_targets.DRRGTargets()
|
||||
assert np.allclose(target_generator.orientation_thr, 2.0)
|
||||
assert np.allclose(target_generator.resample_step, 8.0)
|
||||
assert target_generator.num_min_comps == 9
|
||||
assert target_generator.num_max_comps == 600
|
||||
assert np.allclose(target_generator.min_width, 8.0)
|
||||
assert np.allclose(target_generator.max_width, 24.0)
|
||||
assert np.allclose(target_generator.center_region_shrink_ratio, 0.3)
|
||||
assert np.allclose(target_generator.comp_shrink_ratio, 1.0)
|
||||
assert np.allclose(target_generator.comp_w_h_ratio, 0.3)
|
||||
assert np.allclose(target_generator.text_comp_nms_thr, 0.25)
|
||||
assert np.allclose(target_generator.min_rand_half_height, 8.0)
|
||||
assert np.allclose(target_generator.max_rand_half_height, 24.0)
|
||||
assert np.allclose(target_generator.jitter_level, 0.2)
|
||||
|
||||
# test generate_targets
|
||||
target_generator = textdet_targets.DRRGTargets(
|
||||
min_width=2.,
|
||||
max_width=4.,
|
||||
min_rand_half_height=3.,
|
||||
max_rand_half_height=5.)
|
||||
|
||||
results = {}
|
||||
results['img'] = np.zeros((64, 64, 3), np.uint8)
|
||||
text_polys = [[np.array([4, 2, 30, 2, 30, 10, 4, 10])],
|
||||
[np.array([36, 12, 8, 12, 8, 22, 36, 22])],
|
||||
[np.array([48, 20, 52, 20, 52, 50, 48, 50])],
|
||||
[np.array([44, 50, 38, 50, 38, 20, 44, 20])]]
|
||||
results['gt_masks'] = PolygonMasks(text_polys, 20, 30)
|
||||
results['gt_masks_ignore'] = PolygonMasks([], 64, 64)
|
||||
results['img_shape'] = (64, 64, 3)
|
||||
results['mask_fields'] = []
|
||||
output = target_generator(results)
|
||||
assert len(output['gt_text_mask']) == 1
|
||||
assert len(output['gt_center_region_mask']) == 1
|
||||
assert len(output['gt_mask']) == 1
|
||||
assert len(output['gt_top_height_map']) == 1
|
||||
assert len(output['gt_bot_height_map']) == 1
|
||||
assert len(output['gt_sin_map']) == 1
|
||||
assert len(output['gt_cos_map']) == 1
|
||||
assert output['gt_comp_attribs'].shape[-1] == 8
|
||||
|
||||
# test generate_targets when the number of proposed text components exceeds
|
||||
# num_max_comps
|
||||
target_generator = textdet_targets.DRRGTargets(
|
||||
min_width=2.,
|
||||
max_width=4.,
|
||||
min_rand_half_height=3.,
|
||||
max_rand_half_height=5.,
|
||||
num_max_comps=6)
|
||||
output = target_generator(results)
|
||||
assert output['gt_comp_attribs'].ndim == 2
|
||||
assert output['gt_comp_attribs'].shape[0] == 6
|
||||
|
||||
# test generate_targets with blank polygon masks
|
||||
target_generator = textdet_targets.DRRGTargets(
|
||||
min_width=2.,
|
||||
max_width=4.,
|
||||
min_rand_half_height=3.,
|
||||
max_rand_half_height=5.)
|
||||
results = {}
|
||||
results['img'] = np.zeros((20, 30, 3), np.uint8)
|
||||
results['gt_masks'] = PolygonMasks([], 20, 30)
|
||||
results['gt_masks_ignore'] = PolygonMasks([], 20, 30)
|
||||
results['img_shape'] = (20, 30, 3)
|
||||
results['mask_fields'] = []
|
||||
output = target_generator(results)
|
||||
assert output['gt_comp_attribs'][0, 0] > 8
|
||||
|
||||
# test generate_targets with one proposed text component
|
||||
text_polys = [[np.array([13, 6, 17, 6, 17, 14, 13, 14])]]
|
||||
target_generator = textdet_targets.DRRGTargets(
|
||||
min_width=4.,
|
||||
max_width=8.,
|
||||
min_rand_half_height=3.,
|
||||
max_rand_half_height=5.)
|
||||
results['gt_masks'] = PolygonMasks(text_polys, 20, 30)
|
||||
output = target_generator(results)
|
||||
assert output['gt_comp_attribs'][0, 0] > 8
|
||||
|
||||
# test generate_targets with shrunk margin in generate_rand_comp_attribs
|
||||
target_generator = textdet_targets.DRRGTargets(
|
||||
min_width=2.,
|
||||
max_width=30.,
|
||||
min_rand_half_height=3.,
|
||||
max_rand_half_height=30.)
|
||||
output = target_generator(results)
|
||||
assert output['gt_comp_attribs'][0, 0] > 8
|
||||
|
|
|
@ -429,3 +429,83 @@ def test_fcenet(cfg_file):
|
|||
results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]}
|
||||
img = np.random.rand(5, 5)
|
||||
detector.show_result(img, results)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'cfg_file', ['textdet/drrg/'
|
||||
'drrg_r50_fpn_unet_1200e_ctw1500.py'])
|
||||
def test_drrg(cfg_file):
|
||||
model = _get_detector_cfg(cfg_file)
|
||||
model['pretrained'] = None
|
||||
model['backbone']['norm_cfg']['type'] = 'BN'
|
||||
|
||||
from mmocr.models import build_detector
|
||||
detector = build_detector(model)
|
||||
|
||||
input_shape = (1, 3, 224, 224)
|
||||
num_kernels = 1
|
||||
mm_inputs = _demo_mm_inputs(num_kernels, input_shape)
|
||||
|
||||
imgs = mm_inputs.pop('imgs')
|
||||
img_metas = mm_inputs.pop('img_metas')
|
||||
gt_text_mask = mm_inputs.pop('gt_text_mask')
|
||||
gt_center_region_mask = mm_inputs.pop('gt_center_region_mask')
|
||||
gt_mask = mm_inputs.pop('gt_mask')
|
||||
gt_top_height_map = mm_inputs.pop('gt_radius_map')
|
||||
gt_bot_height_map = gt_top_height_map.copy()
|
||||
gt_sin_map = mm_inputs.pop('gt_sin_map')
|
||||
gt_cos_map = mm_inputs.pop('gt_cos_map')
|
||||
num_rois = 32
|
||||
x = np.random.randint(4, 224, (num_rois, 1))
|
||||
y = np.random.randint(4, 224, (num_rois, 1))
|
||||
h = 4 * np.ones((num_rois, 1))
|
||||
w = 4 * np.ones((num_rois, 1))
|
||||
angle = (np.random.random_sample((num_rois, 1)) * 2 - 1) * np.pi / 2
|
||||
cos, sin = np.cos(angle), np.sin(angle)
|
||||
comp_labels = np.random.randint(1, 3, (num_rois, 1))
|
||||
num_rois = num_rois * np.ones((num_rois, 1))
|
||||
comp_attribs = np.hstack([num_rois, x, y, h, w, cos, sin, comp_labels])
|
||||
gt_comp_attribs = np.expand_dims(comp_attribs.astype(np.float32), axis=0)
|
||||
|
||||
# Test forward train
|
||||
losses = detector.forward(
|
||||
imgs,
|
||||
img_metas,
|
||||
gt_text_mask=gt_text_mask,
|
||||
gt_center_region_mask=gt_center_region_mask,
|
||||
gt_mask=gt_mask,
|
||||
gt_top_height_map=gt_top_height_map,
|
||||
gt_bot_height_map=gt_bot_height_map,
|
||||
gt_sin_map=gt_sin_map,
|
||||
gt_cos_map=gt_cos_map,
|
||||
gt_comp_attribs=gt_comp_attribs)
|
||||
assert isinstance(losses, dict)
|
||||
|
||||
# Test forward test
|
||||
model['bbox_head']['in_channels'] = 6
|
||||
model['bbox_head']['text_region_thr'] = 0.8
|
||||
model['bbox_head']['center_region_thr'] = 0.8
|
||||
detector = build_detector(model)
|
||||
maps = torch.zeros((1, 6, 224, 224), dtype=torch.float)
|
||||
maps[:, 0:2, :, :] = -10.
|
||||
maps[:, 0, 60:100, 50:170] = 10.
|
||||
maps[:, 1, 75:85, 60:160] = 10.
|
||||
maps[:, 2, 75:85, 60:160] = 0.
|
||||
maps[:, 3, 75:85, 60:160] = 1.
|
||||
maps[:, 4, 75:85, 60:160] = 10.
|
||||
maps[:, 5, 75:85, 60:160] = 10.
|
||||
|
||||
with torch.no_grad():
|
||||
full_pass_weight = torch.zeros((6, 6, 1, 1))
|
||||
for i in range(6):
|
||||
full_pass_weight[i, i, 0, 0] = 1
|
||||
detector.bbox_head.out_conv.weight.data = full_pass_weight
|
||||
detector.bbox_head.out_conv.bias.data.fill_(0.)
|
||||
outs = detector.bbox_head.single_test(maps)
|
||||
boundaries = detector.bbox_head.get_boundary(*outs, img_metas, True)
|
||||
assert len(boundaries)
|
||||
|
||||
# Test show result
|
||||
results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]}
|
||||
img = np.random.rand(5, 5)
|
||||
detector.show_result(img, results)
|
||||
|
|
|
@ -42,9 +42,9 @@ def test_fcenetloss():
|
|||
|
||||
# test ohem
|
||||
pred = torch.ones((200, 2), dtype=torch.float)
|
||||
target = torch.ones((200, ), dtype=torch.long)
|
||||
target = torch.ones(200, dtype=torch.long)
|
||||
target[20:] = 0
|
||||
mask = torch.ones((200, ), dtype=torch.long)
|
||||
mask = torch.ones(200, dtype=torch.long)
|
||||
|
||||
ohem_loss1 = fcenetloss.ohem(pred, target, mask)
|
||||
ohem_loss2 = fcenetloss.ohem(pred, target, 1 - mask)
|
||||
|
@ -70,3 +70,76 @@ def test_fcenetloss():
|
|||
|
||||
loss = fcenetloss(preds, 0, p3_maps, p4_maps, p5_maps)
|
||||
assert isinstance(loss, dict)
|
||||
|
||||
|
||||
def test_drrgloss():
|
||||
drrgloss = losses.DRRGLoss()
|
||||
assert np.allclose(drrgloss.ohem_ratio, 3.0)
|
||||
|
||||
# test balance_bce_loss
|
||||
pred = torch.tensor([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=torch.float)
|
||||
target = torch.tensor([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=torch.long)
|
||||
mask = torch.tensor([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=torch.long)
|
||||
bce_loss = drrgloss.balance_bce_loss(pred, target, mask).item()
|
||||
assert np.allclose(bce_loss, 0)
|
||||
|
||||
# test balance_bce_loss with positive_count equal to zero
|
||||
pred = torch.ones((16, 16), dtype=torch.float)
|
||||
target = torch.ones((16, 16), dtype=torch.long)
|
||||
mask = torch.zeros((16, 16), dtype=torch.long)
|
||||
bce_loss = drrgloss.balance_bce_loss(pred, target, mask).item()
|
||||
assert np.allclose(bce_loss, 0)
|
||||
|
||||
# test gcn_loss
|
||||
gcn_preds = torch.tensor([[0., 1.], [1., 0.]])
|
||||
labels = torch.tensor([1, 0], dtype=torch.long)
|
||||
gcn_loss = drrgloss.gcn_loss((gcn_preds, labels))
|
||||
assert gcn_loss.item()
|
||||
|
||||
# test bitmasks2tensor
|
||||
mask = [[1, 0, 1], [1, 1, 1], [0, 0, 1]]
|
||||
target = [[1, 0, 1, 0, 0], [1, 1, 1, 0, 0], [0, 0, 1, 0, 0],
|
||||
[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]
|
||||
masks = [np.array(mask)]
|
||||
bitmasks = BitmapMasks(masks, 3, 3)
|
||||
target_sz = (6, 5)
|
||||
results = drrgloss.bitmasks2tensor([bitmasks], target_sz)
|
||||
assert len(results) == 1
|
||||
assert torch.sum(torch.abs(results[0].float() -
|
||||
torch.Tensor(target))).item() == 0
|
||||
|
||||
# test forward
|
||||
target_maps = [BitmapMasks([np.random.randn(20, 20)], 20, 20)]
|
||||
target_masks = [BitmapMasks([np.ones((20, 20))], 20, 20)]
|
||||
gt_masks = [BitmapMasks([np.ones((20, 20))], 20, 20)]
|
||||
preds = (torch.randn((1, 6, 20, 20)), (gcn_preds, labels))
|
||||
loss_dict = drrgloss(preds, 1., target_masks, target_masks, gt_masks,
|
||||
target_maps, target_maps, target_maps, target_maps)
|
||||
|
||||
assert isinstance(loss_dict, dict)
|
||||
assert 'loss_text' in loss_dict.keys()
|
||||
assert 'loss_center' in loss_dict.keys()
|
||||
assert 'loss_height' in loss_dict.keys()
|
||||
assert 'loss_sin' in loss_dict.keys()
|
||||
assert 'loss_cos' in loss_dict.keys()
|
||||
assert 'loss_gcn' in loss_dict.keys()
|
||||
|
||||
# test forward with downsample_ratio less than 1.
|
||||
target_maps = [BitmapMasks([np.random.randn(40, 40)], 40, 40)]
|
||||
target_masks = [BitmapMasks([np.ones((40, 40))], 40, 40)]
|
||||
gt_masks = [BitmapMasks([np.ones((40, 40))], 40, 40)]
|
||||
preds = (torch.randn((1, 6, 20, 20)), (gcn_preds, labels))
|
||||
loss_dict = drrgloss(preds, 0.5, target_masks, target_masks, gt_masks,
|
||||
target_maps, target_maps, target_maps, target_maps)
|
||||
|
||||
assert isinstance(loss_dict, dict)
|
||||
|
||||
# test forward with blank gt_mask.
|
||||
target_maps = [BitmapMasks([np.random.randn(20, 20)], 20, 20)]
|
||||
target_masks = [BitmapMasks([np.ones((20, 20))], 20, 20)]
|
||||
gt_masks = [BitmapMasks([np.zeros((20, 20))], 20, 20)]
|
||||
preds = (torch.randn((1, 6, 20, 20)), (gcn_preds, labels))
|
||||
loss_dict = drrgloss(preds, 1., target_masks, target_masks, gt_masks,
|
||||
target_maps, target_maps, target_maps, target_maps)
|
||||
|
||||
assert isinstance(loss_dict, dict)
|
||||
|
|
|
@ -0,0 +1,140 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmocr.models.textdet.modules import GCN, LocalGraphs, ProposalLocalGraphs
|
||||
from mmocr.models.textdet.modules.utils import (feature_embedding,
|
||||
normalize_adjacent_matrix)
|
||||
|
||||
|
||||
def test_local_graph_forward_train():
|
||||
geo_feat_len = 24
|
||||
pooling_h, pooling_w = pooling_out_size = (2, 2)
|
||||
num_rois = 32
|
||||
|
||||
local_graph_generator = LocalGraphs((4, 4), 3, geo_feat_len, 1.0,
|
||||
pooling_out_size, 0.5)
|
||||
|
||||
feature_maps = torch.randn((2, 3, 128, 128), dtype=torch.float)
|
||||
x = np.random.randint(4, 124, (num_rois, 1))
|
||||
y = np.random.randint(4, 124, (num_rois, 1))
|
||||
h = 4 * np.ones((num_rois, 1))
|
||||
w = 4 * np.ones((num_rois, 1))
|
||||
angle = (np.random.random_sample((num_rois, 1)) * 2 - 1) * np.pi / 2
|
||||
cos, sin = np.cos(angle), np.sin(angle)
|
||||
comp_labels = np.random.randint(1, 3, (num_rois, 1))
|
||||
num_rois = num_rois * np.ones((num_rois, 1))
|
||||
comp_attribs = np.hstack([num_rois, x, y, h, w, cos, sin, comp_labels])
|
||||
comp_attribs = comp_attribs.astype(np.float32)
|
||||
comp_attribs_ = comp_attribs.copy()
|
||||
comp_attribs = np.stack([comp_attribs, comp_attribs_])
|
||||
|
||||
(node_feats, adjacent_matrix, knn_inds,
|
||||
linkage_labels) = local_graph_generator(feature_maps, comp_attribs)
|
||||
feat_len = geo_feat_len + feature_maps.size()[1] * pooling_h * pooling_w
|
||||
|
||||
assert node_feats.dim() == adjacent_matrix.dim() == 3
|
||||
assert node_feats.size()[-1] == feat_len
|
||||
assert knn_inds.size()[-1] == 4
|
||||
assert linkage_labels.size()[-1] == 4
|
||||
assert (node_feats.size()[0] == adjacent_matrix.size()[0] ==
|
||||
knn_inds.size()[0] == linkage_labels.size()[0])
|
||||
assert (node_feats.size()[1] == adjacent_matrix.size()[1] ==
|
||||
adjacent_matrix.size()[2])
|
||||
|
||||


def test_local_graph_forward_test():
    geo_feat_len = 24
    pooling_h, pooling_w = pooling_out_size = (2, 2)

    local_graph_generator = ProposalLocalGraphs(
        (4, 4), 2, geo_feat_len, 1., pooling_out_size, 0.1, 3., 6., 1., 0.5,
        0.3, 0.5, 0.5, 2)

    maps = torch.zeros((1, 6, 224, 224), dtype=torch.float)
    maps[:, 0:2, :, :] = -10.
    maps[:, 0, 60:100, 50:170] = 10.
    maps[:, 1, 75:85, 60:160] = 10.
    maps[:, 2, 75:85, 60:160] = 0.
    maps[:, 3, 75:85, 60:160] = 1.
    maps[:, 4, 75:85, 60:160] = 10.
    maps[:, 5, 75:85, 60:160] = 10.
    feature_maps = torch.randn((2, 6, 224, 224), dtype=torch.float)
    feat_len = geo_feat_len + feature_maps.size()[1] * pooling_h * pooling_w

    none_flag, graph_data = local_graph_generator(maps, feature_maps)
    (node_feats, adjacent_matrices, knn_inds, local_graphs,
     text_comps) = graph_data

    assert none_flag is False
    assert text_comps.ndim == 2
    assert text_comps.shape[0] > 0
    assert text_comps.shape[1] == 9
    assert (node_feats.size()[0] == adjacent_matrices.size()[0] ==
            knn_inds.size()[0] == local_graphs.size()[0] ==
            text_comps.shape[0])
    assert (node_feats.size()[1] == adjacent_matrices.size()[1] ==
            adjacent_matrices.size()[2] == local_graphs.size()[1])
    assert node_feats.size()[-1] == feat_len

    # test proposal local graphs with area of center region less than threshold
    maps[:, 1, 75:85, 60:160] = -10.
    maps[:, 1, 80, 80] = 10.
    none_flag, _ = local_graph_generator(maps, feature_maps)
    assert none_flag

    # test proposal local graphs with one text component
    local_graph_generator = ProposalLocalGraphs(
        (4, 4), 2, geo_feat_len, 1., pooling_out_size, 0.1, 8., 20., 1., 0.5,
        0.3, 0.5, 0.5, 2)
    maps[:, 1, 78:82, 78:82] = 10.
    none_flag, _ = local_graph_generator(maps, feature_maps)
    assert none_flag

    # test proposal local graphs with text components out of text region
    maps[:, 0, 60:100, 50:170] = -10.
    maps[:, 0, 78:82, 78:82] = 10.
    none_flag, _ = local_graph_generator(maps, feature_maps)
    assert none_flag
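
# NOTE: ProposalLocalGraphs returns (none_flag, graph_data); the three
# negative cases above assert none_flag, i.e. no usable local graph could be
# built (center region below the area threshold, a single text component, or
# components outside the text region). A minimal inference-time guard,
# sketched here as an assumption about how callers consume the output:
#
#     none_flag, graph_data = local_graph_generator(pred_maps, feature_maps)
#     if none_flag:
#         boundaries = []  # nothing to decode for this image
#     else:
#         node_feats, adj, knn_inds, local_graphs, text_comps = graph_data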


def test_gcn():
    num_local_graphs = 32
    num_max_graph_nodes = 16
    input_feat_len = 512
    k = 8
    gcn = GCN(input_feat_len)
    node_feat = torch.randn(
        (num_local_graphs, num_max_graph_nodes, input_feat_len))
    adjacent_matrix = torch.rand(
        (num_local_graphs, num_max_graph_nodes, num_max_graph_nodes))
    knn_inds = torch.randint(1, num_max_graph_nodes, (num_local_graphs, k))
    output = gcn(node_feat, adjacent_matrix, knn_inds)
    assert output.size() == (num_local_graphs * k, 2)
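
# NOTE: the GCN emits one two-way linkage score per (local graph, pivot
# neighbour) pair, so with 32 local graphs and k = 8 the output is flattened
# to (32 * 8, 2) = (256, 2), which is exactly what the assertion checks.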


def test_normalize_adjacent_matrix():
    adjacent_matrix = np.random.randn(32, 32)
    normalized_matrix = normalize_adjacent_matrix(adjacent_matrix, mode='AD')
    assert normalized_matrix.shape == adjacent_matrix.shape

    normalized_matrix = normalize_adjacent_matrix(adjacent_matrix, mode='DAD')
    assert normalized_matrix.shape == adjacent_matrix.shape

    with pytest.raises(NotImplementedError):
        normalized_matrix = normalize_adjacent_matrix(
            adjacent_matrix, mode='DA')
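
# NOTE (illustrative sketch, not the library implementation): 'DAD'
# conventionally denotes the symmetric normalisation D^(-1/2) A D^(-1/2) and
# 'AD' a one-sided scaling by the degree matrix. A standalone NumPy version
# of the symmetric form could read:
#
#     def symmetric_normalize(A):
#         A = A + np.eye(A.shape[0])  # self-loops are an assumption here
#         d_inv_sqrt = 1.0 / np.sqrt(np.clip(np.abs(A).sum(1), 1e-8, None))
#         return (A * d_inv_sqrt[:, None]) * d_inv_sqrt[None, :]
#
# normalize_adjacent_matrix may differ in such details; only the supported
# mode names ('AD', 'DAD') and the rejected 'DA' come from the test above.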


def test_feature_embedding():
    out_feat_len = 48

    # test without residue dimensions
    feats = np.random.randn(10, 8)
    embed_feats = feature_embedding(feats, out_feat_len)
    assert embed_feats.shape == (10, out_feat_len)

    # test with residue dimensions
    feats = np.random.randn(10, 9)
    embed_feats = feature_embedding(feats, out_feat_len)
    assert embed_feats.shape == (10, out_feat_len)
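
# NOTE: with out_feat_len = 48 an 8-dimensional input divides evenly
# (48 / 8 = 6 embedding values per dimension), while a 9-dimensional input
# leaves a remainder (48 = 9 * 5 + 3); the two checks above cover the
# without-residue and with-residue branches named in the comments.
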
@ -0,0 +1,82 @@
import numpy as np
import torch

from mmocr.models.textdet.dense_heads import DRRGHead


def test_drrg_head():
    in_channels = 10
    drrg_head = DRRGHead(in_channels)
    assert drrg_head.in_channels == in_channels
    assert drrg_head.k_at_hops == (8, 4)
    assert drrg_head.num_adjacent_linkages == 3
    assert drrg_head.node_geo_feat_len == 120
    assert np.allclose(drrg_head.pooling_scale, 1.0)
    assert drrg_head.pooling_output_size == (4, 3)
    assert np.allclose(drrg_head.nms_thr, 0.3)
    assert np.allclose(drrg_head.min_width, 8.0)
    assert np.allclose(drrg_head.max_width, 24.0)
    assert np.allclose(drrg_head.comp_shrink_ratio, 1.03)
    assert np.allclose(drrg_head.comp_ratio, 0.4)
    assert np.allclose(drrg_head.comp_score_thr, 0.3)
    assert np.allclose(drrg_head.text_region_thr, 0.2)
    assert np.allclose(drrg_head.center_region_thr, 0.2)
    assert drrg_head.center_region_area_thr == 50
    assert np.allclose(drrg_head.local_graph_thr, 0.7)
    assert np.allclose(drrg_head.link_thr, 0.85)

    # test forward train
    num_rois = 16
    feature_maps = torch.randn((2, 10, 128, 128), dtype=torch.float)
    x = np.random.randint(4, 124, (num_rois, 1))
    y = np.random.randint(4, 124, (num_rois, 1))
    h = 4 * np.ones((num_rois, 1))
    w = 4 * np.ones((num_rois, 1))
    angle = (np.random.random_sample((num_rois, 1)) * 2 - 1) * np.pi / 2
    cos, sin = np.cos(angle), np.sin(angle)
    comp_labels = np.random.randint(1, 3, (num_rois, 1))
    num_rois = num_rois * np.ones((num_rois, 1))
    comp_attribs = np.hstack([num_rois, x, y, h, w, cos, sin, comp_labels])
    comp_attribs = comp_attribs.astype(np.float32)
    comp_attribs_ = comp_attribs.copy()
    comp_attribs = np.stack([comp_attribs, comp_attribs_])
    pred_maps, gcn_data = drrg_head(feature_maps, comp_attribs)
    pred_labels, gt_labels = gcn_data
    assert pred_maps.size() == (2, 6, 128, 128)
    assert pred_labels.ndim == gt_labels.ndim == 2
    assert gt_labels.size()[0] * gt_labels.size()[1] == pred_labels.size()[0]
    assert pred_labels.size()[1] == 2
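
    # NOTE: the assertions above only pin down the shape contract of the GCN
    # data: the total number of ground-truth linkage labels (gt_labels has two
    # dimensions) equals the number of prediction rows, and every prediction
    # row carries two logits (link / no-link).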

    # test forward test
    with torch.no_grad():
        feat_maps = torch.zeros((1, 10, 128, 128))
        # with zero input features, a large negative bias keeps every score
        # map far below the detection thresholds, so nothing is detected
        drrg_head.out_conv.bias.data.fill_(-10)
        preds = drrg_head.single_test(feat_maps)
        assert all([pred is None for pred in preds])

    # test get_boundary
    edges = np.stack([np.arange(0, 10), np.arange(1, 11)]).transpose()
    edges = np.vstack([edges, np.array([1, 0])])
    scores = np.ones(11, dtype=np.float32) * 0.9
    x1 = np.arange(2, 22, 2)
    x2 = x1 + 2
    y1 = np.ones(10) * 2
    y2 = y1 + 2
    comp_scores = np.ones(10, dtype=np.float32) * 0.9
    text_comps = np.stack([x1, y1, x2, y1, x2, y2, x1, y2,
                           comp_scores]).transpose()
    outlier = np.array([50, 50, 52, 50, 52, 52, 50, 52, 0.9])
    text_comps = np.vstack([text_comps, outlier])

    (C, H, W) = (10, 128, 128)
    img_metas = [{
        'img_shape': (H, W, C),
        'ori_shape': (H, W, C),
        'pad_shape': (H, W, C),
        'filename': '<demo>.png',
        'scale_factor': np.array([1, 1, 1, 1]),
        'flip': False,
    }]
    results = drrg_head.get_boundary(
        edges, scores, text_comps, img_metas, rescale=True)
    assert 'boundary_result' in results.keys()
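
# NOTE: get_boundary stitches the chained components (edges 0-1, 1-2, ...,
# 9-10 plus the duplicate (1, 0) edge) into text boundaries; with
# rescale=True the boundaries are presumably mapped back to the original
# image through img_metas[0]['scale_factor']. The test itself only checks
# that the returned dict contains a 'boundary_result' entry.
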
@ -1,7 +1,7 @@
import pytest
import torch

-from mmocr.models.textdet.necks import FPN_UNET, FPNC
+from mmocr.models.textdet.necks import FPNC, FPN_UNet


def test_fpnc():
@ -32,18 +32,18 @@ def test_fpn_unet_neck():
    # len(in_channels) is not equal to 4
    with pytest.raises(AssertionError):
-        FPN_UNET(in_channels + [128], out_channels)
+        FPN_UNet(in_channels + [128], out_channels)

    # `out_channels` is not int type
    with pytest.raises(AssertionError):
-        FPN_UNET(in_channels, [2, 4])
+        FPN_UNet(in_channels, [2, 4])

    feats = [
        torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i])
        for i in range(len(in_channels))
    ]

-    fpn_unet_neck = FPN_UNET(in_channels, out_channels)
+    fpn_unet_neck = FPN_UNet(in_channels, out_channels)
    fpn_unet_neck.init_weights()

    out_neck = fpn_unet_neck(feats)
@ -26,3 +26,24 @@ def test_fcenet_decode():
        preds=preds, fourier_degree=k, reconstr_points=50, scale=1)

    assert isinstance(boundaries, list)


def test_comps2boundaries():
    from mmocr.models.textdet.postprocess.wrapper import comps2boundaries

    # test comps2boundaries
    x1 = np.arange(2, 18, 2)
    x2 = x1 + 2
    y1 = np.ones(8) * 2
    y2 = y1 + 2
    comp_scores = np.ones(8, dtype=np.float32) * 0.9
    text_comps = np.stack([x1, y1, x2, y1, x2, y2, x1, y2,
                           comp_scores]).transpose()
    comp_labels = np.array([1, 1, 1, 1, 1, 3, 5, 5])
    shuffle = [3, 2, 5, 7, 6, 0, 4, 1]
    boundaries = comps2boundaries(text_comps[shuffle], comp_labels[shuffle])
    assert len(boundaries) == 3

    # test comps2boundaries with blank inputs
    boundaries = comps2boundaries(text_comps[[]], comp_labels[[]])
    assert len(boundaries) == 0
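
# NOTE: comp_labels holds three distinct group ids (1, 3 and 5), so
# comps2boundaries is expected to merge the eight shuffled components into
# exactly three boundaries, and to return an empty list for empty inputs,
# which is what the two assertions verify.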