* merge drrg

* directory structure&fix redundant import

* docstrings

* fix isort

* drrg readme

* merge drrg

* directory structure&fix redundant import

* docstrings

* fix isort

* drrg readme

* add unittests&fix docstrings

* revert test_loss

* add unittest

* add unittests

* fix docstrings

* fix docstrings

* fix yapf

* fix yapf

* Update test_textdet_head.py

* Update test_textdet_head.py

* add unittests

* add unittests

* add unittests

* fix docstrings

* fix docstrings

* fix docstring

* fix unittests

* fix pytest

* fix pytest

* fix pytest

* fix variable names

Co-authored-by: Hongbin Sun <hongbin306@gmail.com>
pull/204/head
Jianyong Chen 2021-05-17 22:15:47 -05:00 committed by GitHub
parent ed6b3b890a
commit 2414c65577
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
28 changed files with 2925 additions and 21 deletions

View File

@ -0,0 +1,23 @@
# DRRG
## Introduction
[ALGORITHM]
```bibtex
@article{zhang2020drrg,
title={Deep relational reasoning graph network for arbitrary shape text detection},
author={Zhang, Shi-Xue and Zhu, Xiaobin and Hou, Jie-Bo and Liu, Chang and Yang, Chun and Wang, Hongfa and Yin, Xu-Cheng},
booktitle={CVPR},
pages={9699-9708},
year={2020}
}
```
## Results and models
### CTW1500
| Method | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download |
| :--------------------------------------------------------------: | :--------------: | :-----------: | :----------: | :-----: | :-------: | :----: | :-------: | :---: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| [DRRG](/configs/textdet/drrg/drrg_r50_fpn_unet_1200e_ctw1500.py) | ImageNet | CTW1500 Train | CTW1500 Test | 1200 | 640 | 0.822 | 0.858 | 0.840 | [model](https://download.openmmlab.com/mmocr/textdet/drrg/drrg_r50_fpn_unet_1200e_ctw1500-1abf4f67.pth) \ [log](https://download.openmmlab.com/mmocr/textdet/drrg/20210511_234719.log) |

View File

@ -0,0 +1,110 @@
_base_ = [
'../../_base_/schedules/schedule_1200e.py',
'../../_base_/default_runtime.py'
]
model = dict(
type='DRRG',
pretrained='torchvision://resnet50',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=-1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='caffe'),
neck=dict(
type='FPN_UNet', in_channels=[256, 512, 1024, 2048], out_channels=32),
bbox_head=dict(
type='DRRGHead',
in_channels=32,
text_region_thr=0.3,
center_region_thr=0.4,
link_thr=0.80,
loss=dict(type='DRRGLoss')))
train_cfg = None
test_cfg = None
dataset_type = 'IcdarDataset'
data_root = 'data/ctw1500/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='LoadTextAnnotations',
with_bbox=True,
with_mask=True,
poly2mask=False),
dict(type='ColorJitter', brightness=32.0 / 255, saturation=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='RandomScaling', size=800, scale=(0.75, 2.5)),
dict(
type='RandomCropFlip', crop_ratio=0.5, iter_num=1, min_area_ratio=0.2),
dict(
type='RandomCropPolyInstances',
instance_key='gt_masks',
crop_ratio=0.8,
min_side_ratio=0.3),
dict(
type='RandomRotatePolyInstances',
rotate_ratio=0.5,
max_angle=60,
pad_with_fixed_color=False),
dict(type='SquareResizePad', target_size=800, pad_ratio=0.6),
dict(type='RandomFlip', flip_ratio=0.5, direction='horizontal'),
dict(type='DRRGTargets'),
dict(type='Pad', size_divisor=32),
dict(
type='CustomFormatBundle',
keys=[
'gt_text_mask', 'gt_center_region_mask', 'gt_mask',
'gt_top_height_map', 'gt_bot_height_map', 'gt_sin_map',
'gt_cos_map', 'gt_comp_attribs'
],
visualize=dict(flag=False, boundary_key='gt_text_mask')),
dict(
type='Collect',
keys=[
'img', 'gt_text_mask', 'gt_center_region_mask', 'gt_mask',
'gt_top_height_map', 'gt_bot_height_map', 'gt_sin_map',
'gt_cos_map', 'gt_comp_attribs'
])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1024, 640),
flip=False,
transforms=[
dict(type='Resize', img_scale=(1024, 640), keep_ratio=True),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
train=dict(
type=dataset_type,
ann_file=f'{data_root}/instances_training.json',
img_prefix=f'{data_root}/imgs',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=f'{data_root}/instances_test.json',
img_prefix=f'{data_root}/imgs',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=f'{data_root}/instances_test.json',
img_prefix=f'{data_root}/imgs',
pipeline=test_pipeline))
evaluation = dict(interval=20, metric='hmean-iou')

View File

@ -15,7 +15,7 @@ model = dict(
norm_eval=True,
style='caffe'),
neck=dict(
type='FPN_UNET', in_channels=[256, 512, 1024, 2048], out_channels=32),
type='FPN_UNet', in_channels=[256, 512, 1024, 2048], out_channels=32),
bbox_head=dict(
type='TextSnakeHead',
in_channels=32,
@ -96,18 +96,18 @@ data = dict(
workers_per_gpu=4,
train=dict(
type=dataset_type,
ann_file=data_root + '/instances_training.json',
img_prefix=data_root + '/imgs',
ann_file=f'{data_root}/instances_training.json',
img_prefix=f'{data_root}/imgs',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=data_root + '/instances_test.json',
img_prefix=data_root + '/imgs',
ann_file=f'{data_root}/instances_test.json',
img_prefix=f'{data_root}/imgs',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + '/instances_test.json',
img_prefix=data_root + '/imgs',
ann_file=f'{data_root}/instances_test.json',
img_prefix=f'{data_root}/imgs',
pipeline=test_pipeline))
evaluation = dict(interval=10, metric='hmean-iou')

View File

@ -1,5 +1,6 @@
from .base_textdet_targets import BaseTextDetTargets
from .dbnet_targets import DBNetTargets
from .drrg_targets import DRRGTargets
from .fcenet_targets import FCENetTargets
from .panet_targets import PANetTargets
from .psenet_targets import PSENetTargets
@ -7,5 +8,5 @@ from .textsnake_targets import TextSnakeTargets
__all__ = [
'BaseTextDetTargets', 'PANetTargets', 'PSENetTargets', 'DBNetTargets',
'FCENetTargets', 'TextSnakeTargets'
'FCENetTargets', 'TextSnakeTargets', 'DRRGTargets'
]

View File

@ -0,0 +1,533 @@
import cv2
import numpy as np
from lanms import merge_quadrangle_n9 as la_nms
from numpy.linalg import norm
import mmocr.utils.check_argument as check_argument
from mmdet.core import BitmapMasks
from mmdet.datasets.builder import PIPELINES
from .textsnake_targets import TextSnakeTargets
@PIPELINES.register_module()
class DRRGTargets(TextSnakeTargets):
"""Generate the ground truth targets of DRRG: Deep Relational Reasoning
Graph Network for Arbitrary Shape Text Detection.
[https://arxiv.org/abs/2003.07493]. This code was partially adapted from
https://github.com/GXYM/DRRG licensed under the MIT license.
Args:
orientation_thr (float): The threshold for distinguishing between
head edge and tail edge among the horizontal and vertical edges
of a quadrangle.
resample_step (float): The step size for resampling the text center
line.
num_min_comps (int): The minimum number of text components, which
should be larger than k_hop1 mentioned in paper.
num_max_comps (int): The maximum number of text components.
min_width (float): The minimum width of text components.
max_width (float): The maximum width of text components.
center_region_shrink_ratio (float): The shrink ratio of text center
regions.
comp_shrink_ratio (float): The shrink ratio of text components.
comp_w_h_ratio (float): The width to height ratio of text components.
min_rand_half_height(float): The minimum half-height of random text
components.
max_rand_half_height (float): The maximum half-height of random
text components.
jitter_level (float): The jitter level of text component geometric
features.
"""
def __init__(self,
orientation_thr=2.0,
resample_step=8.0,
num_min_comps=9,
num_max_comps=600,
min_width=8.0,
max_width=24.0,
center_region_shrink_ratio=0.3,
comp_shrink_ratio=1.0,
comp_w_h_ratio=0.3,
text_comp_nms_thr=0.25,
min_rand_half_height=8.0,
max_rand_half_height=24.0,
jitter_level=0.2):
super().__init__()
self.orientation_thr = orientation_thr
self.resample_step = resample_step
self.num_max_comps = num_max_comps
self.num_min_comps = num_min_comps
self.min_width = min_width
self.max_width = max_width
self.center_region_shrink_ratio = center_region_shrink_ratio
self.comp_shrink_ratio = comp_shrink_ratio
self.comp_w_h_ratio = comp_w_h_ratio
self.text_comp_nms_thr = text_comp_nms_thr
self.min_rand_half_height = min_rand_half_height
self.max_rand_half_height = max_rand_half_height
self.jitter_level = jitter_level
def dist_point2line(self, point, line):
assert isinstance(line, tuple)
point1, point2 = line
d = abs(np.cross(point2 - point1, point - point1)) / (
norm(point2 - point1) + 1e-8)
return d
def draw_center_region_maps(self, top_line, bot_line, center_line,
center_region_mask, top_height_map,
bot_height_map, sin_map, cos_map,
region_shrink_ratio):
"""Draw attributes of text components on text center regions.
Args:
top_line (ndarray): The points composing the top side lines of text
polygons.
bot_line (ndarray): The points composing bottom side lines of text
polygons.
center_line (ndarray): The points composing the center lines of
text instances.
center_region_mask (ndarray): The text center region mask.
top_height_map (ndarray): The map on which the distance from points
to top side lines will be drawn for each pixel in text center
regions.
bot_height_map (ndarray): The map on which the distance from points
to bottom side lines will be drawn for each pixel in text
center regions.
sin_map (ndarray): The map of vector_sin(top_point - bot_point)
that will be drawn on text center regions.
cos_map (ndarray): The map of vector_cos(top_point - bot_point)
will be drawn on text center regions.
region_shrink_ratio (float): The shrink ratio of text center
regions.
"""
assert top_line.shape == bot_line.shape == center_line.shape
assert (center_region_mask.shape == top_height_map.shape ==
bot_height_map.shape == sin_map.shape == cos_map.shape)
assert isinstance(region_shrink_ratio, float)
h, w = center_region_mask.shape
for i in range(0, len(center_line) - 1):
top_mid_point = (top_line[i] + top_line[i + 1]) / 2
bot_mid_point = (bot_line[i] + bot_line[i + 1]) / 2
sin_theta = self.vector_sin(top_mid_point - bot_mid_point)
cos_theta = self.vector_cos(top_mid_point - bot_mid_point)
tl = center_line[i] + (top_line[i] -
center_line[i]) * region_shrink_ratio
tr = center_line[i + 1] + (
top_line[i + 1] - center_line[i + 1]) * region_shrink_ratio
br = center_line[i + 1] + (
bot_line[i + 1] - center_line[i + 1]) * region_shrink_ratio
bl = center_line[i] + (bot_line[i] -
center_line[i]) * region_shrink_ratio
current_center_box = np.vstack([tl, tr, br, bl]).astype(np.int32)
cv2.fillPoly(center_region_mask, [current_center_box], color=1)
cv2.fillPoly(sin_map, [current_center_box], color=sin_theta)
cv2.fillPoly(cos_map, [current_center_box], color=cos_theta)
current_center_box[:, 0] = np.clip(current_center_box[:, 0], 0,
w - 1)
current_center_box[:, 1] = np.clip(current_center_box[:, 1], 0,
h - 1)
min_coord = np.min(current_center_box, axis=0).astype(np.int32)
max_coord = np.max(current_center_box, axis=0).astype(np.int32)
current_center_box = current_center_box - min_coord
box_sz = (max_coord - min_coord + 1)
center_box_mask = np.zeros((box_sz[1], box_sz[0]), dtype=np.uint8)
cv2.fillPoly(center_box_mask, [current_center_box], color=1)
inds = np.argwhere(center_box_mask > 0)
inds = inds + (min_coord[1], min_coord[0])
inds_xy = np.fliplr(inds)
top_height_map[(inds[:, 0], inds[:, 1])] = self.dist_point2line(
inds_xy, (top_line[i], top_line[i + 1]))
bot_height_map[(inds[:, 0], inds[:, 1])] = self.dist_point2line(
inds_xy, (bot_line[i], bot_line[i + 1]))
def generate_center_mask_attrib_maps(self, img_size, text_polys):
"""Generate text center region masks and geometric attribute maps.
Args:
img_size (tuple): The image size (height, width).
text_polys (list[list[ndarray]]): The list of text polygons.
Returns:
center_lines (list): The list of text center lines.
center_region_mask (ndarray): The text center region mask.
top_height_map (ndarray): The map on which the distance from points
to top side lines will be drawn for each pixel in text center
regions.
bot_height_map (ndarray): The map on which the distance from points
to bottom side lines will be drawn for each pixel in text
center regions.
sin_map (ndarray): The sin(theta) map where theta is the angle
between vector (top point - bottom point) and vector (1, 0).
cos_map (ndarray): The cos(theta) map where theta is the angle
between vector (top point - bottom point) and vector (1, 0).
"""
assert isinstance(img_size, tuple)
assert check_argument.is_2dlist(text_polys)
h, w = img_size
center_lines = []
center_region_mask = np.zeros((h, w), np.uint8)
top_height_map = np.zeros((h, w), dtype=np.float32)
bot_height_map = np.zeros((h, w), dtype=np.float32)
sin_map = np.zeros((h, w), dtype=np.float32)
cos_map = np.zeros((h, w), dtype=np.float32)
for poly in text_polys:
assert len(poly) == 1
polygon_points = poly[0].reshape(-1, 2)
_, _, top_line, bot_line = self.reorder_poly_edge(polygon_points)
resampled_top_line, resampled_bot_line = self.resample_sidelines(
top_line, bot_line, self.resample_step)
resampled_bot_line = resampled_bot_line[::-1]
center_line = (resampled_top_line + resampled_bot_line) / 2
if self.vector_slope(center_line[-1] - center_line[0]) > 2:
if (center_line[-1] - center_line[0])[1] < 0:
center_line = center_line[::-1]
resampled_top_line = resampled_top_line[::-1]
resampled_bot_line = resampled_bot_line[::-1]
else:
if (center_line[-1] - center_line[0])[0] < 0:
center_line = center_line[::-1]
resampled_top_line = resampled_top_line[::-1]
resampled_bot_line = resampled_bot_line[::-1]
line_head_shrink_len = np.clip(
(norm(top_line[0] - bot_line[0]) * self.comp_w_h_ratio),
self.min_width, self.max_width) / 2
line_tail_shrink_len = np.clip(
(norm(top_line[-1] - bot_line[-1]) * self.comp_w_h_ratio),
self.min_width, self.max_width) / 2
num_head_shrink = int(line_head_shrink_len // self.resample_step)
num_tail_shrink = int(line_tail_shrink_len // self.resample_step)
if len(center_line) > num_head_shrink + num_tail_shrink + 2:
center_line = center_line[num_head_shrink:len(center_line) -
num_tail_shrink]
resampled_top_line = resampled_top_line[
num_head_shrink:len(resampled_top_line) - num_tail_shrink]
resampled_bot_line = resampled_bot_line[
num_head_shrink:len(resampled_bot_line) - num_tail_shrink]
center_lines.append(center_line.astype(np.int32))
self.draw_center_region_maps(resampled_top_line,
resampled_bot_line, center_line,
center_region_mask, top_height_map,
bot_height_map, sin_map, cos_map,
self.center_region_shrink_ratio)
return (center_lines, center_region_mask, top_height_map,
bot_height_map, sin_map, cos_map)
def generate_rand_comp_attribs(self, num_rand_comps, center_sample_mask):
"""Generate random text components and their attributes to ensure the
the number of text components in an image is larger than k_hop1, which
is the number of one hop neighbors in KNN graph.
Args:
num_rand_comps (int): The number of random text components.
center_sample_mask (ndarray): The region mask for sampling text
component centers .
Returns:
rand_comp_attribs (ndarray): The random text component attributes
(x, y, h, w, cos, sin, comp_label=0).
"""
assert isinstance(num_rand_comps, int)
assert num_rand_comps > 0
assert center_sample_mask.ndim == 2
h, w = center_sample_mask.shape
max_rand_half_height = self.max_rand_half_height
min_rand_half_height = self.min_rand_half_height
max_rand_height = max_rand_half_height * 2
max_rand_width = np.clip(max_rand_height * self.comp_w_h_ratio,
self.min_width, self.max_width)
margin = int(
np.sqrt((max_rand_height / 2)**2 + (max_rand_width / 2)**2)) + 1
if 2 * margin + 1 > min(h, w):
assert min(h, w) > (np.sqrt(2) * (self.min_width + 1))
max_rand_half_height = max(min(h, w) / 4, self.min_width / 2 + 1)
min_rand_half_height = max(max_rand_half_height / 4,
self.min_width / 2)
max_rand_height = max_rand_half_height * 2
max_rand_width = np.clip(max_rand_height * self.comp_w_h_ratio,
self.min_width, self.max_width)
margin = int(
np.sqrt((max_rand_height / 2)**2 +
(max_rand_width / 2)**2)) + 1
inner_center_sample_mask = np.zeros_like(center_sample_mask)
inner_center_sample_mask[margin:h - margin, margin:w - margin] = \
center_sample_mask[margin:h - margin, margin:w - margin]
kernel_size = int(np.clip(max_rand_half_height, 7, 21))
inner_center_sample_mask = cv2.erode(
inner_center_sample_mask,
np.ones((kernel_size, kernel_size), np.uint8))
center_candidates = np.argwhere(inner_center_sample_mask > 0)
num_center_candidates = len(center_candidates)
sample_inds = np.random.choice(num_center_candidates, num_rand_comps)
rand_centers = center_candidates[sample_inds]
rand_top_height = np.random.randint(
min_rand_half_height,
max_rand_half_height,
size=(len(rand_centers), 1))
rand_bot_height = np.random.randint(
min_rand_half_height,
max_rand_half_height,
size=(len(rand_centers), 1))
rand_cos = 2 * np.random.random(size=(len(rand_centers), 1)) - 1
rand_sin = 2 * np.random.random(size=(len(rand_centers), 1)) - 1
scale = np.sqrt(1.0 / (rand_cos**2 + rand_sin**2 + 1e-8))
rand_cos = rand_cos * scale
rand_sin = rand_sin * scale
height = (rand_top_height + rand_bot_height)
width = np.clip(height * self.comp_w_h_ratio, self.min_width,
self.max_width)
rand_comp_attribs = np.hstack([
rand_centers[:, ::-1], height, width, rand_cos, rand_sin,
np.zeros_like(rand_sin)
]).astype(np.float32)
return rand_comp_attribs
def jitter_comp_attribs(self, comp_attribs, jitter_level):
"""Jitter text components attributes.
Args:
comp_attribs (ndarray): The text component attributes.
jitter_level (float): The jitter level of text components
attributes.
Returns:
jittered_comp_attribs (ndarray): The jittered text component
attributes (x, y, h, w, cos, sin, comp_label).
"""
assert comp_attribs.shape[1] == 7
assert comp_attribs.shape[0] > 0
assert isinstance(jitter_level, float)
x = comp_attribs[:, 0].reshape((-1, 1))
y = comp_attribs[:, 1].reshape((-1, 1))
h = comp_attribs[:, 2].reshape((-1, 1))
w = comp_attribs[:, 3].reshape((-1, 1))
cos = comp_attribs[:, 4].reshape((-1, 1))
sin = comp_attribs[:, 5].reshape((-1, 1))
comp_labels = comp_attribs[:, 6].reshape((-1, 1))
x += (np.random.random(size=(len(comp_attribs), 1)) -
0.5) * (h * np.abs(cos) + w * np.abs(sin)) * jitter_level
y += (np.random.random(size=(len(comp_attribs), 1)) -
0.5) * (h * np.abs(sin) + w * np.abs(cos)) * jitter_level
h += (np.random.random(size=(len(comp_attribs), 1)) -
0.5) * h * jitter_level
w += (np.random.random(size=(len(comp_attribs), 1)) -
0.5) * w * jitter_level
cos += (np.random.random(size=(len(comp_attribs), 1)) -
0.5) * 2 * jitter_level
sin += (np.random.random(size=(len(comp_attribs), 1)) -
0.5) * 2 * jitter_level
scale = np.sqrt(1.0 / (cos**2 + sin**2 + 1e-8))
cos = cos * scale
sin = sin * scale
jittered_comp_attribs = np.hstack([x, y, h, w, cos, sin, comp_labels])
return jittered_comp_attribs
def generate_comp_attribs(self, center_lines, text_mask,
center_region_mask, top_height_map,
bot_height_map, sin_map, cos_map):
"""Generate text component attributes.
Args:
center_lines (list[ndarray]): The list of text center lines .
text_mask (ndarray): The text region mask.
center_region_mask (ndarray): The text center region mask.
top_height_map (ndarray): The map on which the distance from points
to top side lines will be drawn for each pixel in text center
regions.
bot_height_map (ndarray): The map on which the distance from points
to bottom side lines will be drawn for each pixel in text
center regions.
sin_map (ndarray): The sin(theta) map where theta is the angle
between vector (top point - bottom point) and vector (1, 0).
cos_map (ndarray): The cos(theta) map where theta is the angle
between vector (top point - bottom point) and vector (1, 0).
Returns:
pad_comp_attribs (ndarray): The padded text component attributes
of a fixed size.
"""
assert isinstance(center_lines, list)
assert (text_mask.shape == center_region_mask.shape ==
top_height_map.shape == bot_height_map.shape == sin_map.shape
== cos_map.shape)
center_lines_mask = np.zeros_like(center_region_mask)
cv2.polylines(center_lines_mask, center_lines, 0, 1, 1)
center_lines_mask = center_lines_mask * center_region_mask
comp_centers = np.argwhere(center_lines_mask > 0)
y = comp_centers[:, 0]
x = comp_centers[:, 1]
top_height = top_height_map[y, x].reshape(
(-1, 1)) * self.comp_shrink_ratio
bot_height = bot_height_map[y, x].reshape(
(-1, 1)) * self.comp_shrink_ratio
sin = sin_map[y, x].reshape((-1, 1))
cos = cos_map[y, x].reshape((-1, 1))
top_mid_points = comp_centers + np.hstack(
[top_height * sin, top_height * cos])
bot_mid_points = comp_centers - np.hstack(
[bot_height * sin, bot_height * cos])
width = (top_height + bot_height) * self.comp_w_h_ratio
width = np.clip(width, self.min_width, self.max_width)
r = width / 2
tl = top_mid_points[:, ::-1] - np.hstack([-r * sin, r * cos])
tr = top_mid_points[:, ::-1] + np.hstack([-r * sin, r * cos])
br = bot_mid_points[:, ::-1] + np.hstack([-r * sin, r * cos])
bl = bot_mid_points[:, ::-1] - np.hstack([-r * sin, r * cos])
text_comps = np.hstack([tl, tr, br, bl]).astype(np.float32)
score = np.ones((text_comps.shape[0], 1), dtype=np.float32)
text_comps = np.hstack([text_comps, score])
text_comps = la_nms(text_comps, self.text_comp_nms_thr)
if text_comps.shape[0] >= 1:
img_h, img_w = center_region_mask.shape
text_comps[:, 0:8:2] = np.clip(text_comps[:, 0:8:2], 0, img_w - 1)
text_comps[:, 1:8:2] = np.clip(text_comps[:, 1:8:2], 0, img_h - 1)
comp_centers = np.mean(
text_comps[:, 0:8].reshape((-1, 4, 2)),
axis=1).astype(np.int32)
x = comp_centers[:, 0]
y = comp_centers[:, 1]
height = (top_height_map[y, x] + bot_height_map[y, x]).reshape(
(-1, 1))
width = np.clip(height * self.comp_w_h_ratio, self.min_width,
self.max_width)
cos = cos_map[y, x].reshape((-1, 1))
sin = sin_map[y, x].reshape((-1, 1))
_, comp_label_mask = cv2.connectedComponents(
center_region_mask, connectivity=8)
comp_labels = comp_label_mask[y, x].reshape(
(-1, 1)).astype(np.float32)
x = x.reshape((-1, 1)).astype(np.float32)
y = y.reshape((-1, 1)).astype(np.float32)
comp_attribs = np.hstack(
[x, y, height, width, cos, sin, comp_labels])
comp_attribs = self.jitter_comp_attribs(comp_attribs,
self.jitter_level)
if comp_attribs.shape[0] < self.num_min_comps:
num_rand_comps = self.num_min_comps - comp_attribs.shape[0]
rand_comp_attribs = self.generate_rand_comp_attribs(
num_rand_comps, 1 - text_mask)
comp_attribs = np.vstack([comp_attribs, rand_comp_attribs])
else:
comp_attribs = self.generate_rand_comp_attribs(
self.num_min_comps, 1 - text_mask)
num_comps = (
np.ones((comp_attribs.shape[0], 1), dtype=np.float32) *
comp_attribs.shape[0])
comp_attribs = np.hstack([num_comps, comp_attribs])
if comp_attribs.shape[0] > self.num_max_comps:
comp_attribs = comp_attribs[:self.num_max_comps, :]
comp_attribs[:, 0] = self.num_max_comps
pad_comp_attribs = np.zeros(
(self.num_max_comps, comp_attribs.shape[1]), dtype=np.float32)
pad_comp_attribs[:comp_attribs.shape[0], :] = comp_attribs
return pad_comp_attribs
def generate_targets(self, results):
"""Generate the gt targets for DRRG.
Args:
results (dict): The input result dictionary.
Returns:
results (dict): The output result dictionary.
"""
assert isinstance(results, dict)
polygon_masks = results['gt_masks'].masks
polygon_masks_ignore = results['gt_masks_ignore'].masks
h, w, _ = results['img_shape']
gt_text_mask = self.generate_text_region_mask((h, w), polygon_masks)
gt_mask = self.generate_effective_mask((h, w), polygon_masks_ignore)
(center_lines, gt_center_region_mask, gt_top_height_map,
gt_bot_height_map, gt_sin_map,
gt_cos_map) = self.generate_center_mask_attrib_maps((h, w),
polygon_masks)
gt_comp_attribs = self.generate_comp_attribs(center_lines,
gt_text_mask,
gt_center_region_mask,
gt_top_height_map,
gt_bot_height_map,
gt_sin_map, gt_cos_map)
results['mask_fields'].clear() # rm gt_masks encoded by polygons
mapping = {
'gt_text_mask': gt_text_mask,
'gt_center_region_mask': gt_center_region_mask,
'gt_mask': gt_mask,
'gt_top_height_map': gt_top_height_map,
'gt_bot_height_map': gt_bot_height_map,
'gt_sin_map': gt_sin_map,
'gt_cos_map': gt_cos_map
}
for key, value in mapping.items():
value = value if isinstance(value, list) else [value]
results[key] = BitmapMasks(value, h, w)
results['mask_fields'].append(key)
results['gt_comp_attribs'] = gt_comp_attribs
return results

View File

@ -1,4 +1,5 @@
from .db_head import DBHead
from .drrg_head import DRRGHead
from .fce_head import FCEHead
from .head_mixin import HeadMixin
from .pan_head import PANHead
@ -6,5 +7,6 @@ from .pse_head import PSEHead
from .textsnake_head import TextSnakeHead
__all__ = [
'PSEHead', 'PANHead', 'DBHead', 'FCEHead', 'HeadMixin', 'TextSnakeHead'
'PSEHead', 'PANHead', 'DBHead', 'FCEHead', 'HeadMixin', 'TextSnakeHead',
'DRRGHead'
]

View File

@ -0,0 +1,219 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import normal_init
from mmdet.models.builder import HEADS, build_loss
from mmocr.models.textdet.modules import GCN, LocalGraphs, ProposalLocalGraphs
from mmocr.models.textdet.postprocess import decode
from mmocr.utils import check_argument
from .head_mixin import HeadMixin
@HEADS.register_module()
class DRRGHead(HeadMixin, nn.Module):
"""The class for DRRG head: Deep Relational Reasoning Graph Network for
Arbitrary Shape Text Detection.
[https://arxiv.org/abs/2003.07493]
Args:
k_at_hops (tuple(int)): The number of i-hop neighbors, i = 1, 2.
num_adjacent_linkages (int): The number of linkages when constructing
adjacent matrix.
node_geo_feat_len (int): The length of embedded geometric feature
vector of a component.
pooling_scale (float): The spatial scale of rotated RoI-Align.
pooling_output_size (tuple(int)): The output size of RRoI-Aligning.
nms_thr (float): The locality-aware NMS threshold of text components.
min_width (float): The minimum width of text components.
max_width (float): The maximum width of text components.
comp_shrink_ratio (float): The shrink ratio of text components.
comp_ratio (float): The reciprocal of aspect ratio of text components.
comp_score_thr (float): The score threshold of text components.
text_region_thr (float): The threshold for text region probability map.
center_region_thr (float): The threshold for text center region
probability map.
center_region_area_thr (int): The threshold for filtering small-sized
text center region.
local_graph_thr (float): The threshold to filter identical local
graphs.
link_thr(float): The threshold for connected components search.
"""
def __init__(self,
in_channels,
k_at_hops=(8, 4),
num_adjacent_linkages=3,
node_geo_feat_len=120,
pooling_scale=1.0,
pooling_output_size=(4, 3),
nms_thr=0.3,
min_width=8.0,
max_width=24.0,
comp_shrink_ratio=1.03,
comp_ratio=0.4,
comp_score_thr=0.3,
text_region_thr=0.2,
center_region_thr=0.2,
center_region_area_thr=50,
local_graph_thr=0.7,
link_thr=0.85,
loss=dict(type='DRRGLoss'),
train_cfg=None,
test_cfg=None):
super().__init__()
assert isinstance(in_channels, int)
assert isinstance(k_at_hops, tuple)
assert isinstance(num_adjacent_linkages, int)
assert isinstance(node_geo_feat_len, int)
assert isinstance(pooling_scale, float)
assert isinstance(pooling_output_size, tuple)
assert isinstance(comp_shrink_ratio, float)
assert isinstance(nms_thr, float)
assert isinstance(min_width, float)
assert isinstance(max_width, float)
assert isinstance(comp_ratio, float)
assert isinstance(comp_score_thr, float)
assert isinstance(text_region_thr, float)
assert isinstance(center_region_thr, float)
assert isinstance(center_region_area_thr, int)
assert isinstance(local_graph_thr, float)
assert isinstance(link_thr, float)
self.in_channels = in_channels
self.out_channels = 6
self.downsample_ratio = 1.0
self.k_at_hops = k_at_hops
self.num_adjacent_linkages = num_adjacent_linkages
self.node_geo_feat_len = node_geo_feat_len
self.pooling_scale = pooling_scale
self.pooling_output_size = pooling_output_size
self.comp_shrink_ratio = comp_shrink_ratio
self.nms_thr = nms_thr
self.min_width = min_width
self.max_width = max_width
self.comp_ratio = comp_ratio
self.comp_score_thr = comp_score_thr
self.text_region_thr = text_region_thr
self.center_region_thr = center_region_thr
self.center_region_area_thr = center_region_area_thr
self.local_graph_thr = local_graph_thr
self.link_thr = link_thr
self.loss_module = build_loss(loss)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.out_conv = nn.Conv2d(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=1,
stride=1,
padding=0)
self.init_weights()
self.graph_train = LocalGraphs(self.k_at_hops,
self.num_adjacent_linkages,
self.node_geo_feat_len,
self.pooling_scale,
self.pooling_output_size,
self.local_graph_thr)
self.graph_test = ProposalLocalGraphs(
self.k_at_hops, self.num_adjacent_linkages, self.node_geo_feat_len,
self.pooling_scale, self.pooling_output_size, self.nms_thr,
self.min_width, self.max_width, self.comp_shrink_ratio,
self.comp_ratio, self.comp_score_thr, self.text_region_thr,
self.center_region_thr, self.center_region_area_thr)
pool_w, pool_h = self.pooling_output_size
node_feat_len = (pool_w * pool_h) * (
self.in_channels + self.out_channels) + self.node_geo_feat_len
self.gcn = GCN(node_feat_len)
def init_weights(self):
normal_init(self.out_conv, mean=0, std=0.01)
def forward(self, inputs, gt_comp_attribs):
pred_maps = self.out_conv(inputs)
feat_maps = torch.cat([inputs, pred_maps], dim=1)
node_feats, adjacent_matrices, knn_inds, gt_labels = self.graph_train(
feat_maps, np.stack(gt_comp_attribs))
gcn_pred = self.gcn(node_feats, adjacent_matrices, knn_inds)
return pred_maps, (gcn_pred, gt_labels)
def single_test(self, feat_maps):
pred_maps = self.out_conv(feat_maps)
feat_maps = torch.cat([feat_maps, pred_maps], dim=1)
none_flag, graph_data = self.graph_test(pred_maps, feat_maps)
(local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
pivot_local_graphs, text_comps) = graph_data
if none_flag:
return None, None, None
gcn_pred = self.gcn(local_graphs_node_feat, adjacent_matrices,
pivots_knn_inds)
pred_labels = F.softmax(gcn_pred, dim=1)
edges = []
scores = []
pivot_local_graphs = pivot_local_graphs.long().squeeze().cpu().numpy()
for pivot_ind, pivot_local_graph in enumerate(pivot_local_graphs):
pivot = pivot_local_graph[0]
for k_ind, neighbor_ind in enumerate(pivots_knn_inds[pivot_ind]):
neighbor = pivot_local_graph[neighbor_ind.item()]
edges.append([pivot, neighbor])
scores.append(
pred_labels[pivot_ind * pivots_knn_inds.shape[1] + k_ind,
1].item())
edges = np.asarray(edges)
scores = np.asarray(scores)
return edges, scores, text_comps
def get_boundary(self, edges, scores, text_comps, img_metas, rescale):
"""Compute text boundaries via post processing.
Args:
edges (ndarray): The edge array of shape N * 2, each row is a pair
of text component indices that makes up an edge in graph.
scores (ndarray): The edge score array.
text_comps (ndarray): The text components.
img_metas (list[dict]): The image meta infos.
rescale (bool): Rescale boundaries to the original image
resolution.
Returns:
results (dict): The result dict.
"""
assert check_argument.is_type_list(img_metas, dict)
assert isinstance(rescale, bool)
boundaries = []
if edges is not None:
boundaries = decode(
decoding_type='drrg',
edges=edges,
scores=scores,
text_comps=text_comps,
link_thr=self.link_thr)
if rescale:
boundaries = self.resize_boundary(
boundaries,
1.0 / self.downsample_ratio / img_metas[0]['scale_factor'])
results = dict(boundary_result=boundaries)
return results

View File

@ -1,4 +1,5 @@
from .dbnet import DBNet
from .drrg import DRRG
from .fcenet import FCENet
from .ocr_mask_rcnn import OCRMaskRCNN
from .panet import PANet
@ -9,5 +10,5 @@ from .textsnake import TextSnake
__all__ = [
'TextDetectorMixin', 'SingleStageTextDetector', 'OCRMaskRCNN', 'DBNet',
'PANet', 'PSENet', 'TextSnake', 'FCENet'
'PANet', 'PSENet', 'TextSnake', 'FCENet', 'DRRG'
]

View File

@ -0,0 +1,53 @@
from mmdet.models.builder import DETECTORS
from mmocr.models.textdet.detectors.single_stage_text_detector import \
SingleStageTextDetector
from mmocr.models.textdet.detectors.text_detector_mixin import \
TextDetectorMixin
@DETECTORS.register_module()
class DRRG(TextDetectorMixin, SingleStageTextDetector):
"""The class for implementing DRRG text detector. Deep Relational Reasoning
Graph Network for Arbitrary Shape Text Detection.
[https://arxiv.org/abs/2003.07493]
"""
def __init__(self,
backbone,
neck,
bbox_head,
train_cfg=None,
test_cfg=None,
pretrained=None,
show_score=False):
SingleStageTextDetector.__init__(self, backbone, neck, bbox_head,
train_cfg, test_cfg, pretrained)
TextDetectorMixin.__init__(self, show_score)
def forward_train(self, img, img_metas, **kwargs):
"""
Args:
img (Tensor): Input images of shape (N, C, H, W).
Typically these should be mean centered and std scaled.
img_metas (list[dict]): A List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details of the values of these keys see
:class:`mmdet.datasets.pipelines.Collect`.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
x = self.extract_feat(img)
gt_comp_attribs = kwargs.pop('gt_comp_attribs')
preds = self.bbox_head(x, gt_comp_attribs)
losses = self.bbox_head.loss(preds, **kwargs)
return losses
def simple_test(self, img, img_metas, rescale=False):
x = self.extract_feat(img)
outs = self.bbox_head.single_test(x)
boundaries = self.bbox_head.get_boundary(*outs, img_metas, rescale)
return [boundaries]

View File

@ -1,7 +1,10 @@
from .db_loss import DBLoss
from .drrg_loss import DRRGLoss
from .fce_loss import FCELoss
from .pan_loss import PANLoss
from .pse_loss import PSELoss
from .textsnake_loss import TextSnakeLoss
__all__ = ['PANLoss', 'PSELoss', 'DBLoss', 'TextSnakeLoss', 'FCELoss']
__all__ = [
'PANLoss', 'PSELoss', 'DBLoss', 'TextSnakeLoss', 'FCELoss', 'DRRGLoss'
]

View File

@ -0,0 +1,214 @@
import torch
import torch.nn.functional as F
from torch import nn
from mmdet.core import BitmapMasks
from mmdet.models.builder import LOSSES
from mmocr.utils import check_argument
@LOSSES.register_module()
class DRRGLoss(nn.Module):
"""The class for implementing DRRG loss: Deep Relational Reasoning Graph
Network for Arbitrary Shape Text Detection.
[https://arxiv.org/abs/1908.05900] This is partially adapted from
https://github.com/GXYM/DRRG licensed under the MIT license.
"""
def __init__(self, ohem_ratio=3.0):
"""Initialization.
Args:
ohem_ratio (float): The negative/positive ratio in OHEM.
"""
super().__init__()
self.ohem_ratio = ohem_ratio
def balance_bce_loss(self, pred, gt, mask):
assert pred.shape == gt.shape == mask.shape
assert torch.all(pred >= 0) and torch.all(pred <= 1)
assert torch.all(gt >= 0) and torch.all(gt <= 1)
positive = gt * mask
negative = (1 - gt) * mask
positive_count = int(positive.float().sum())
gt = gt.float()
if positive_count > 0:
loss = F.binary_cross_entropy(pred, gt, reduction='none')
positive_loss = torch.sum(loss * positive.float())
negative_loss = loss * negative.float()
negative_count = min(
int(negative.float().sum()),
int(positive_count * self.ohem_ratio))
else:
positive_loss = torch.tensor(0.0, device=pred.device)
loss = F.binary_cross_entropy(pred, gt, reduction='none')
negative_loss = loss * negative.float()
negative_count = 100
negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count)
balance_loss = (positive_loss + torch.sum(negative_loss)) / (
float(positive_count + negative_count) + 1e-5)
return balance_loss
def gcn_loss(self, gcn_data):
gcn_pred, gt_labels = gcn_data
gt_labels = gt_labels.view(-1).to(gcn_pred.device)
loss = F.cross_entropy(gcn_pred, gt_labels)
return loss
def bitmasks2tensor(self, bitmasks, target_sz):
"""Convert Bitmasks to tensor.
Args:
bitmasks (list[BitmapMasks]): The BitmapMasks list. Each item is
for one img.
target_sz (tuple(int, int)): The target tensor size HxW.
Returns
results (list[tensor]): The list of kernel tensors. Each
element is for one kernel level.
"""
assert check_argument.is_type_list(bitmasks, BitmapMasks)
assert isinstance(target_sz, tuple)
batch_size = len(bitmasks)
num_masks = len(bitmasks[0])
results = []
for level_inx in range(num_masks):
kernel = []
for batch_inx in range(batch_size):
mask = torch.from_numpy(bitmasks[batch_inx].masks[level_inx])
# hxw
mask_sz = mask.shape
# left, right, top, bottom
pad = [
0, target_sz[1] - mask_sz[1], 0, target_sz[0] - mask_sz[0]
]
mask = F.pad(mask, pad, mode='constant', value=0)
kernel.append(mask)
kernel = torch.stack(kernel)
results.append(kernel)
return results
def forward(self, preds, downsample_ratio, gt_text_mask,
gt_center_region_mask, gt_mask, gt_top_height_map,
gt_bot_height_map, gt_sin_map, gt_cos_map):
assert isinstance(preds, tuple)
assert isinstance(downsample_ratio, float)
assert check_argument.is_type_list(gt_text_mask, BitmapMasks)
assert check_argument.is_type_list(gt_center_region_mask, BitmapMasks)
assert check_argument.is_type_list(gt_mask, BitmapMasks)
assert check_argument.is_type_list(gt_top_height_map, BitmapMasks)
assert check_argument.is_type_list(gt_bot_height_map, BitmapMasks)
assert check_argument.is_type_list(gt_sin_map, BitmapMasks)
assert check_argument.is_type_list(gt_cos_map, BitmapMasks)
pred_maps, gcn_data = preds
pred_text_region = pred_maps[:, 0, :, :]
pred_center_region = pred_maps[:, 1, :, :]
pred_sin_map = pred_maps[:, 2, :, :]
pred_cos_map = pred_maps[:, 3, :, :]
pred_top_height_map = pred_maps[:, 4, :, :]
pred_bot_height_map = pred_maps[:, 5, :, :]
feature_sz = pred_maps.size()
device = pred_maps.device
# bitmask 2 tensor
mapping = {
'gt_text_mask': gt_text_mask,
'gt_center_region_mask': gt_center_region_mask,
'gt_mask': gt_mask,
'gt_top_height_map': gt_top_height_map,
'gt_bot_height_map': gt_bot_height_map,
'gt_sin_map': gt_sin_map,
'gt_cos_map': gt_cos_map
}
gt = {}
for key, value in mapping.items():
gt[key] = value
if abs(downsample_ratio - 1.0) < 1e-2:
gt[key] = self.bitmasks2tensor(gt[key], feature_sz[2:])
else:
gt[key] = [item.rescale(downsample_ratio) for item in gt[key]]
gt[key] = self.bitmasks2tensor(gt[key], feature_sz[2:])
if key in ['gt_top_height_map', 'gt_bot_height_map']:
gt[key] = [item * downsample_ratio for item in gt[key]]
gt[key] = [item.to(device) for item in gt[key]]
scale = torch.sqrt(1.0 / (pred_sin_map**2 + pred_cos_map**2 + 1e-8))
pred_sin_map = pred_sin_map * scale
pred_cos_map = pred_cos_map * scale
loss_text = self.balance_bce_loss(
torch.sigmoid(pred_text_region), gt['gt_text_mask'][0],
gt['gt_mask'][0])
text_mask = (gt['gt_text_mask'][0] * gt['gt_mask'][0]).float()
negative_text_mask = ((1 - gt['gt_text_mask'][0]) *
gt['gt_mask'][0]).float()
loss_center_map = F.binary_cross_entropy(
torch.sigmoid(pred_center_region),
gt['gt_center_region_mask'][0].float(),
reduction='none')
if int(text_mask.sum()) > 0:
loss_center_positive = torch.sum(
loss_center_map * text_mask) / torch.sum(text_mask)
else:
loss_center_positive = torch.tensor(0.0, device=device)
loss_center_negative = torch.sum(
loss_center_map *
negative_text_mask) / torch.sum(negative_text_mask)
loss_center = loss_center_positive + 0.5 * loss_center_negative
center_mask = (gt['gt_center_region_mask'][0] *
gt['gt_mask'][0]).float()
if int(center_mask.sum()) > 0:
map_sz = pred_top_height_map.size()
ones = torch.ones(map_sz, dtype=torch.float, device=device)
loss_top = F.smooth_l1_loss(
pred_top_height_map / (gt['gt_top_height_map'][0] + 1e-2),
ones,
reduction='none')
loss_bot = F.smooth_l1_loss(
pred_bot_height_map / (gt['gt_bot_height_map'][0] + 1e-2),
ones,
reduction='none')
gt_height = (
gt['gt_top_height_map'][0] + gt['gt_bot_height_map'][0])
loss_height = torch.sum(
(torch.log(gt_height + 1) *
(loss_top + loss_bot)) * center_mask) / torch.sum(center_mask)
loss_sin = torch.sum(
F.smooth_l1_loss(
pred_sin_map, gt['gt_sin_map'][0], reduction='none') *
center_mask) / torch.sum(center_mask)
loss_cos = torch.sum(
F.smooth_l1_loss(
pred_cos_map, gt['gt_cos_map'][0], reduction='none') *
center_mask) / torch.sum(center_mask)
else:
loss_height = torch.tensor(0.0, device=device)
loss_sin = torch.tensor(0.0, device=device)
loss_cos = torch.tensor(0.0, device=device)
loss_gcn = self.gcn_loss(gcn_data)
results = dict(
loss_text=loss_text,
loss_center=loss_center,
loss_height=loss_height,
loss_sin=loss_sin,
loss_cos=loss_cos,
loss_gcn=loss_gcn)
return results

View File

@ -0,0 +1,5 @@
from .gcn import GCN
from .local_graph import LocalGraphs
from .proposal_local_graph import ProposalLocalGraphs
__all__ = ['LocalGraphs', 'ProposalLocalGraphs', 'GCN']

View File

@ -0,0 +1,75 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
class MeanAggregator(nn.Module):
def forward(self, features, A):
x = torch.bmm(A, features)
return x
class GraphConv(nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.weight = nn.Parameter(torch.FloatTensor(in_dim * 2, out_dim))
self.bias = nn.Parameter(torch.FloatTensor(out_dim))
init.xavier_uniform_(self.weight)
init.constant_(self.bias, 0)
self.aggregator = MeanAggregator()
def forward(self, features, A):
b, n, d = features.shape
assert d == self.in_dim
agg_feats = self.aggregator(features, A)
cat_feats = torch.cat([features, agg_feats], dim=2)
out = torch.einsum('bnd,df->bnf', cat_feats, self.weight)
out = F.relu(out + self.bias)
return out
class GCN(nn.Module):
"""Graph convolutional network for clustering. This was from repo
https://github.com/Zhongdao/gcn_clustering licensed under the MIT license.
Args:
feat_len(int): The input node feature length.
"""
def __init__(self, feat_len):
super(GCN, self).__init__()
self.bn0 = nn.BatchNorm1d(feat_len, affine=False).float()
self.conv1 = GraphConv(feat_len, 512)
self.conv2 = GraphConv(512, 256)
self.conv3 = GraphConv(256, 128)
self.conv4 = GraphConv(128, 64)
self.classifier = nn.Sequential(
nn.Linear(64, 32), nn.PReLU(32), nn.Linear(32, 2))
def forward(self, x, A, knn_inds):
num_local_graphs, num_max_nodes, feat_len = x.shape
x = x.view(-1, feat_len)
x = self.bn0(x)
x = x.view(num_local_graphs, num_max_nodes, feat_len)
x = self.conv1(x, A)
x = self.conv2(x, A)
x = self.conv3(x, A)
x = self.conv4(x, A)
k = knn_inds.size(-1)
mid_feat_len = x.size(-1)
edge_feat = torch.zeros((num_local_graphs, k, mid_feat_len),
device=x.device)
for graph_ind in range(num_local_graphs):
edge_feat[graph_ind, :, :] = x[graph_ind, knn_inds[graph_ind]]
edge_feat = edge_feat.view(-1, mid_feat_len)
pred = self.classifier(edge_feat)
return pred

View File

@ -0,0 +1,296 @@
import numpy as np
import torch
from mmcv.ops import RoIAlignRotated
from .utils import (euclidean_distance_matrix, feature_embedding,
normalize_adjacent_matrix)
class LocalGraphs(object):
"""Generate local graphs for GCN to classify the neighbors of a pivot for
DRRG: Deep Relational Reasoning Graph Network for Arbitrary Shape Text
Detection.
[https://arxiv.org/abs/2003.07493]. This code was partially adapted from
https://github.com/GXYM/DRRG licensed under the MIT license.
Args:
k_at_hops (tuple(int)): The number of i-hop neighbors, i = 1, 2.
num_adjacent_linkages (int): The number of linkages when constructing
adjacent matrix.
node_geo_feat_len (int): The length of embedded geometric feature
vector of a text component.
pooling_scale (float): The spatial scale of rotated RoI-Align.
pooling_output_size (tuple(int)): The output size of rotated RoI-Align.
local_graph_thr(float): The threshold for filtering out identical local
graphs.
"""
def __init__(self, k_at_hops, num_adjacent_linkages, node_geo_feat_len,
pooling_scale, pooling_output_size, local_graph_thr):
assert len(k_at_hops) == 2
assert all(isinstance(n, int) for n in k_at_hops)
assert isinstance(num_adjacent_linkages, int)
assert isinstance(node_geo_feat_len, int)
assert isinstance(pooling_scale, float)
assert all(isinstance(n, int) for n in pooling_output_size)
assert isinstance(local_graph_thr, float)
self.k_at_hops = k_at_hops
self.num_adjacent_linkages = num_adjacent_linkages
self.node_geo_feat_dim = node_geo_feat_len
self.pooling = RoIAlignRotated(pooling_output_size, pooling_scale)
self.local_graph_thr = local_graph_thr
def generate_local_graphs(self, sorted_dist_inds, gt_comp_labels):
"""Generate local graphs for GCN to predict which instance a text
component belongs to.
Args:
sorted_dist_inds (ndarray): The complete graph node indices, which
is sorted according to the Euclidean distance.
gt_comp_labels(ndarray): The ground truth labels define the
instance to which the text components (nodes in graphs) belong.
Returns:
pivot_local_graphs(list[list[int]]): The list of local graph
neighbor indices of pivots.
pivot_knns(list[list[int]]): The list of k-nearest neighbor indices
of pivots.
"""
assert sorted_dist_inds.ndim == 2
assert (sorted_dist_inds.shape[0] == sorted_dist_inds.shape[1] ==
gt_comp_labels.shape[0])
knn_graph = sorted_dist_inds[:, 1:self.k_at_hops[0] + 1]
pivot_local_graphs = []
pivot_knns = []
for pivot_ind, knn in enumerate(knn_graph):
local_graph_neighbors = set(knn)
for neighbor_ind in knn:
local_graph_neighbors.update(
set(sorted_dist_inds[neighbor_ind,
1:self.k_at_hops[1] + 1]))
local_graph_neighbors.discard(pivot_ind)
pivot_local_graph = list(local_graph_neighbors)
pivot_local_graph.insert(0, pivot_ind)
pivot_knn = [pivot_ind] + list(knn)
if pivot_ind < 1:
pivot_local_graphs.append(pivot_local_graph)
pivot_knns.append(pivot_knn)
else:
add_flag = True
for graph_ind, added_knn in enumerate(pivot_knns):
added_pivot_ind = added_knn[0]
added_local_graph = pivot_local_graphs[graph_ind]
union = len(
set(pivot_local_graph[1:]).union(
set(added_local_graph[1:])))
intersect = len(
set(pivot_local_graph[1:]).intersection(
set(added_local_graph[1:])))
local_graph_iou = intersect / (union + 1e-8)
if (local_graph_iou > self.local_graph_thr
and pivot_ind in added_knn
and gt_comp_labels[added_pivot_ind]
== gt_comp_labels[pivot_ind]
and gt_comp_labels[pivot_ind] != 0):
add_flag = False
break
if add_flag:
pivot_local_graphs.append(pivot_local_graph)
pivot_knns.append(pivot_knn)
return pivot_local_graphs, pivot_knns
def generate_gcn_input(self, node_feat_batch, node_label_batch,
local_graph_batch, knn_batch,
sorted_dist_ind_batch):
"""Generate graph convolution network input data.
Args:
node_feat_batch (List[Tensor]): The batched graph node features.
node_label_batch (List[ndarray]): The batched text component
labels.
local_graph_batch (List[List[list[int]]]): The local graph node
indices of image batch.
knn_batch (List[List[list[int]]]): The knn graph node indices of
image batch.
sorted_dist_ind_batch (list[ndarray]): The node indices sorted
according to the Euclidean distance.
Returns:
local_graphs_node_feat (Tensor): The node features of graph.
adjacent_matrices (Tensor): The adjacent matrices of local graphs.
pivots_knn_inds (Tensor): The k-nearest neighbor indices in
local graph.
gt_linkage (Tensor): The surpervision signal of GCN for linkage
prediction.
"""
assert isinstance(node_feat_batch, list)
assert isinstance(node_label_batch, list)
assert isinstance(local_graph_batch, list)
assert isinstance(knn_batch, list)
assert isinstance(sorted_dist_ind_batch, list)
num_max_nodes = max([
len(pivot_local_graph) for pivot_local_graphs in local_graph_batch
for pivot_local_graph in pivot_local_graphs
])
local_graphs_node_feat = []
adjacent_matrices = []
pivots_knn_inds = []
pivots_gt_linkage = []
for batch_ind, sorted_dist_inds in enumerate(sorted_dist_ind_batch):
node_feats = node_feat_batch[batch_ind]
pivot_local_graphs = local_graph_batch[batch_ind]
pivot_knns = knn_batch[batch_ind]
node_labels = node_label_batch[batch_ind]
device = node_feats.device
for graph_ind, pivot_knn in enumerate(pivot_knns):
pivot_local_graph = pivot_local_graphs[graph_ind]
num_nodes = len(pivot_local_graph)
pivot_ind = pivot_local_graph[0]
node2ind_map = {j: i for i, j in enumerate(pivot_local_graph)}
knn_inds = torch.tensor(
[node2ind_map[i] for i in pivot_knn[1:]])
pivot_feats = node_feats[pivot_ind]
normalized_feats = node_feats[pivot_local_graph] - pivot_feats
adjacent_matrix = np.zeros((num_nodes, num_nodes),
dtype=np.float32)
for node in pivot_local_graph:
neighbors = sorted_dist_inds[node,
1:self.num_adjacent_linkages +
1]
for neighbor in neighbors:
if neighbor in pivot_local_graph:
adjacent_matrix[node2ind_map[node],
node2ind_map[neighbor]] = 1
adjacent_matrix[node2ind_map[neighbor],
node2ind_map[node]] = 1
adjacent_matrix = normalize_adjacent_matrix(
adjacent_matrix, mode='DAD')
pad_adjacent_matrix = torch.zeros(
(num_max_nodes, num_max_nodes),
dtype=torch.float,
device=device)
pad_adjacent_matrix[:num_nodes, :num_nodes] = adjacent_matrix
pad_normalized_feats = torch.cat([
normalized_feats,
torch.zeros(
(num_max_nodes - num_nodes, normalized_feats.shape[1]),
dtype=torch.float,
device=device)
],
dim=0)
local_graph_labels = node_labels[pivot_local_graph]
knn_labels = local_graph_labels[knn_inds]
link_labels = ((node_labels[pivot_ind] == knn_labels) &
(node_labels[pivot_ind] > 0)).astype(np.int64)
link_labels = torch.from_numpy(link_labels)
local_graphs_node_feat.append(pad_normalized_feats)
adjacent_matrices.append(pad_adjacent_matrix)
pivots_knn_inds.append(knn_inds)
pivots_gt_linkage.append(link_labels)
local_graphs_node_feat = torch.stack(local_graphs_node_feat, 0)
adjacent_matrices = torch.stack(adjacent_matrices, 0)
pivots_knn_inds = torch.stack(pivots_knn_inds, 0)
pivots_gt_linkage = torch.stack(pivots_gt_linkage, 0)
return (local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
pivots_gt_linkage)
def __call__(self, feat_maps, comp_attribs):
"""Generate local graphs as GCN input.
Args:
feat_maps (Tensor): The feature maps to extract the content
features of text components.
comp_attribs (ndarray): The text component attributes.
Returns:
local_graphs_node_feat (Tensor): The node features of graph.
adjacent_matrices (Tensor): The adjacent matrices of local graphs.
pivots_knn_inds (Tensor): The k-nearest neighbor indices in local
graph.
gt_linkage (Tensor): The surpervision signal of GCN for linkage
prediction.
"""
assert isinstance(feat_maps, torch.Tensor)
assert comp_attribs.ndim == 3
assert comp_attribs.shape[2] == 8
sorted_dist_inds_batch = []
local_graph_batch = []
knn_batch = []
node_feat_batch = []
node_label_batch = []
device = feat_maps.device
for batch_ind in range(comp_attribs.shape[0]):
num_comps = int(comp_attribs[batch_ind, 0, 0])
comp_geo_attribs = comp_attribs[batch_ind, :num_comps, 1:7]
node_labels = comp_attribs[batch_ind, :num_comps,
7].astype(np.int32)
comp_centers = comp_geo_attribs[:, 0:2]
distance_matrix = euclidean_distance_matrix(
comp_centers, comp_centers)
batch_id = np.zeros(
(comp_geo_attribs.shape[0], 1), dtype=np.float32) * batch_ind
comp_geo_attribs[:, -2] = np.clip(comp_geo_attribs[:, -2], -1, 1)
angle = np.arccos(comp_geo_attribs[:, -2]) * np.sign(
comp_geo_attribs[:, -1])
angle = angle.reshape((-1, 1))
rotated_rois = np.hstack(
[batch_id, comp_geo_attribs[:, :-2], angle])
rois = torch.from_numpy(rotated_rois).to(device)
content_feats = self.pooling(feat_maps[batch_ind].unsqueeze(0),
rois)
content_feats = content_feats.view(content_feats.shape[0],
-1).to(feat_maps.device)
geo_feats = feature_embedding(comp_geo_attribs,
self.node_geo_feat_dim)
geo_feats = torch.from_numpy(geo_feats).to(device)
node_feats = torch.cat([content_feats, geo_feats], dim=-1)
sorted_dist_inds = np.argsort(distance_matrix, axis=1)
pivot_local_graphs, pivot_knns = self.generate_local_graphs(
sorted_dist_inds, node_labels)
node_feat_batch.append(node_feats)
node_label_batch.append(node_labels)
local_graph_batch.append(pivot_local_graphs)
knn_batch.append(pivot_knns)
sorted_dist_inds_batch.append(sorted_dist_inds)
(node_feats, adjacent_matrices, knn_inds, gt_linkage) = \
self.generate_gcn_input(node_feat_batch,
node_label_batch,
local_graph_batch,
knn_batch,
sorted_dist_inds_batch)
return node_feats, adjacent_matrices, knn_inds, gt_linkage

View File

@ -0,0 +1,413 @@
import cv2
import numpy as np
import torch
from lanms import merge_quadrangle_n9 as la_nms
from mmcv.ops import RoIAlignRotated
from mmocr.models.textdet.postprocess.wrapper import fill_hole
from .utils import (euclidean_distance_matrix, feature_embedding,
normalize_adjacent_matrix)
class ProposalLocalGraphs(object):
"""Propose text components and generate local graphs for GCN to classify
the k-nearest neighbors of a pivot in DRRG: Deep Relational Reasoning Graph
Network for Arbitrary Shape Text Detection.
[https://arxiv.org/abs/2003.07493]. This code was partially adapted from
https://github.com/GXYM/DRRG licensed under the MIT license.
Args:
k_at_hops (tuple(int)): The number of i-hop neighbors, i = 1, 2.
num_adjacent_linkages (int): The number of linkages when constructing
adjacent matrix.
node_geo_feat_len (int): The length of embedded geometric feature
vector of a text component.
pooling_scale (float): The spatial scale of rotated RoI-Align.
pooling_output_size (tuple(int)): The output size of rotated RoI-Align.
nms_thr (float): The locality-aware NMS threshold for text components.
min_width (float): The minimum width of text components.
max_width (float): The maximum width of text components.
comp_shrink_ratio (float): The shrink ratio of text components.
comp_w_h_ratio (float): The width to height ratio of text components.
comp_score_thr (float): The score threshold of text component.
text_region_thr (float): The threshold for text region probability map.
center_region_thr (float): The threshold for text center region
probability map.
center_region_area_thr (int): The threshold for filtering small-sized
text center region.
"""
def __init__(self, k_at_hops, num_adjacent_linkages, node_geo_feat_len,
pooling_scale, pooling_output_size, nms_thr, min_width,
max_width, comp_shrink_ratio, comp_w_h_ratio, comp_score_thr,
text_region_thr, center_region_thr, center_region_area_thr):
assert len(k_at_hops) == 2
assert isinstance(k_at_hops, tuple)
assert isinstance(num_adjacent_linkages, int)
assert isinstance(node_geo_feat_len, int)
assert isinstance(pooling_scale, float)
assert isinstance(pooling_output_size, tuple)
assert isinstance(nms_thr, float)
assert isinstance(min_width, float)
assert isinstance(max_width, float)
assert isinstance(comp_shrink_ratio, float)
assert isinstance(comp_w_h_ratio, float)
assert isinstance(comp_score_thr, float)
assert isinstance(text_region_thr, float)
assert isinstance(center_region_thr, float)
assert isinstance(center_region_area_thr, int)
self.k_at_hops = k_at_hops
self.active_connection = num_adjacent_linkages
self.local_graph_depth = len(self.k_at_hops)
self.node_geo_feat_dim = node_geo_feat_len
self.pooling = RoIAlignRotated(pooling_output_size, pooling_scale)
self.nms_thr = nms_thr
self.min_width = min_width
self.max_width = max_width
self.comp_shrink_ratio = comp_shrink_ratio
self.comp_w_h_ratio = comp_w_h_ratio
self.comp_score_thr = comp_score_thr
self.text_region_thr = text_region_thr
self.center_region_thr = center_region_thr
self.center_region_area_thr = center_region_area_thr
def propose_comps(self, score_map, top_height_map, bot_height_map, sin_map,
cos_map, comp_score_thr, min_width, max_width,
comp_shrink_ratio, comp_w_h_ratio):
"""Propose text components.
Args:
score_map (ndarray): The score map for NMS.
top_height_map (ndarray): The predicted text height map from each
pixel in text center region to top sideline.
bot_height_map (ndarray): The predicted text height map from each
pixel in text center region to bottom sideline.
sin_map (ndarray): The predicted sin(theta) map.
cos_map (ndarray): The predicted cos(theta) map.
comp_score_thr (float): The score threshold of text component.
min_width (float): The minimum width of text components.
max_width (float): The maximum width of text components.
comp_shrink_ratio (float): The shrink ratio of text components.
comp_w_h_ratio (float): The width to height ratio of text
components.
Returns:
text_comps (ndarray): The text components.
"""
comp_centers = np.argwhere(score_map > comp_score_thr)
comp_centers = comp_centers[np.argsort(comp_centers[:, 0])]
y = comp_centers[:, 0]
x = comp_centers[:, 1]
top_height = top_height_map[y, x].reshape((-1, 1)) * comp_shrink_ratio
bot_height = bot_height_map[y, x].reshape((-1, 1)) * comp_shrink_ratio
sin = sin_map[y, x].reshape((-1, 1))
cos = cos_map[y, x].reshape((-1, 1))
top_mid_pts = comp_centers + np.hstack(
[top_height * sin, top_height * cos])
bot_mid_pts = comp_centers - np.hstack(
[bot_height * sin, bot_height * cos])
width = (top_height + bot_height) * comp_w_h_ratio
width = np.clip(width, min_width, max_width)
r = width / 2
tl = top_mid_pts[:, ::-1] - np.hstack([-r * sin, r * cos])
tr = top_mid_pts[:, ::-1] + np.hstack([-r * sin, r * cos])
br = bot_mid_pts[:, ::-1] + np.hstack([-r * sin, r * cos])
bl = bot_mid_pts[:, ::-1] - np.hstack([-r * sin, r * cos])
text_comps = np.hstack([tl, tr, br, bl]).astype(np.float32)
score = score_map[y, x].reshape((-1, 1))
text_comps = np.hstack([text_comps, score])
return text_comps
def propose_comps_and_attribs(self, text_region_map, center_region_map,
top_height_map, bot_height_map, sin_map,
cos_map):
"""Generate text components and attributes.
Args:
text_region_map (ndarray): The predicted text region probability
map.
center_region_map (ndarray): The predicted text center region
probability map.
top_height_map (ndarray): The predicted text height map from each
pixel in text center region to top sideline.
bot_height_map (ndarray): The predicted text height map from each
pixel in text center region to bottom sideline.
sin_map (ndarray): The predicted sin(theta) map.
cos_map (ndarray): The predicted cos(theta) map.
Returns:
comp_attribs (ndarray): The text component attributes.
text_comps (ndarray): The text components.
"""
assert (text_region_map.shape == center_region_map.shape ==
top_height_map.shape == bot_height_map.shape == sin_map.shape
== cos_map.shape)
text_mask = text_region_map > self.text_region_thr
center_region_mask = (center_region_map >
self.center_region_thr) * text_mask
scale = np.sqrt(1.0 / (sin_map**2 + cos_map**2))
sin_map, cos_map = sin_map * scale, cos_map * scale
center_region_mask = fill_hole(center_region_mask)
center_region_contours, _ = cv2.findContours(
center_region_mask.astype(np.uint8), cv2.RETR_TREE,
cv2.CHAIN_APPROX_SIMPLE)
mask_sz = center_region_map.shape
comp_list = []
for contour in center_region_contours:
current_center_mask = np.zeros(mask_sz)
cv2.drawContours(current_center_mask, [contour], -1, 1, -1)
if current_center_mask.sum() <= self.center_region_area_thr:
continue
score_map = text_region_map * current_center_mask
text_comps = self.propose_comps(score_map, top_height_map,
bot_height_map, sin_map, cos_map,
self.comp_score_thr,
self.min_width, self.max_width,
self.comp_shrink_ratio,
self.comp_w_h_ratio)
text_comps = la_nms(text_comps, self.nms_thr)
text_comp_mask = np.zeros(mask_sz)
text_comp_boxs = text_comps[:, :8].reshape(
(-1, 4, 2)).astype(np.int32)
cv2.drawContours(text_comp_mask, text_comp_boxs, -1, 1, -1)
if (text_comp_mask * text_mask).sum() < text_comp_mask.sum() * 0.5:
continue
if text_comps.shape[-1] > 0:
comp_list.append(text_comps)
if len(comp_list) <= 0:
return None, None
text_comps = np.vstack(comp_list)
text_comp_boxs = text_comps[:, :8].reshape((-1, 4, 2))
centers = np.mean(text_comp_boxs, axis=1).astype(np.int32)
x = centers[:, 0]
y = centers[:, 1]
scores = []
for text_comp_box in text_comp_boxs:
text_comp_box[:, 0] = np.clip(text_comp_box[:, 0], 0,
mask_sz[1] - 1)
text_comp_box[:, 1] = np.clip(text_comp_box[:, 1], 0,
mask_sz[0] - 1)
min_coord = np.min(text_comp_box, axis=0).astype(np.int32)
max_coord = np.max(text_comp_box, axis=0).astype(np.int32)
text_comp_box = text_comp_box - min_coord
box_sz = (max_coord - min_coord + 1)
temp_comp_mask = np.zeros((box_sz[1], box_sz[0]), dtype=np.uint8)
cv2.fillPoly(temp_comp_mask, [text_comp_box.astype(np.int32)], 1)
temp_region_patch = text_region_map[min_coord[1]:(max_coord[1] +
1),
min_coord[0]:(max_coord[0] +
1)]
score = cv2.mean(temp_region_patch, temp_comp_mask)[0]
scores.append(score)
scores = np.array(scores).reshape((-1, 1))
text_comps = np.hstack([text_comps[:, :-1], scores])
h = top_height_map[y, x].reshape(
(-1, 1)) + bot_height_map[y, x].reshape((-1, 1))
w = np.clip(h * self.comp_w_h_ratio, self.min_width, self.max_width)
sin = sin_map[y, x].reshape((-1, 1))
cos = cos_map[y, x].reshape((-1, 1))
x = x.reshape((-1, 1))
y = y.reshape((-1, 1))
comp_attribs = np.hstack([x, y, h, w, cos, sin])
return comp_attribs, text_comps
def generate_local_graphs(self, sorted_dist_inds, node_feats):
"""Generate local graphs and graph convolution network input data.
Args:
sorted_dist_inds (ndarray): The node indices sorted according to
the Euclidean distance.
node_feats (tensor): The features of nodes in graph.
Returns:
local_graphs_node_feats (tensor): The features of nodes in local
graphs.
adjacent_matrices (tensor): The adjacent matrices.
pivots_knn_inds (tensor): The k-nearest neighbor indices in
local graphs.
pivots_local_graphs (tensor): The indices of nodes in local
graphs.
"""
assert sorted_dist_inds.ndim == 2
assert (sorted_dist_inds.shape[0] == sorted_dist_inds.shape[1] ==
node_feats.shape[0])
knn_graph = sorted_dist_inds[:, 1:self.k_at_hops[0] + 1]
pivot_local_graphs = []
pivot_knns = []
device = node_feats.device
for pivot_ind, knn in enumerate(knn_graph):
local_graph_neighbors = set(knn)
for neighbor_ind in knn:
local_graph_neighbors.update(
set(sorted_dist_inds[neighbor_ind,
1:self.k_at_hops[1] + 1]))
local_graph_neighbors.discard(pivot_ind)
pivot_local_graph = list(local_graph_neighbors)
pivot_local_graph.insert(0, pivot_ind)
pivot_knn = [pivot_ind] + list(knn)
pivot_local_graphs.append(pivot_local_graph)
pivot_knns.append(pivot_knn)
num_max_nodes = max([
len(pivot_local_graph) for pivot_local_graph in pivot_local_graphs
])
local_graphs_node_feat = []
adjacent_matrices = []
pivots_knn_inds = []
pivots_local_graphs = []
for graph_ind, pivot_knn in enumerate(pivot_knns):
pivot_local_graph = pivot_local_graphs[graph_ind]
num_nodes = len(pivot_local_graph)
pivot_ind = pivot_local_graph[0]
node2ind_map = {j: i for i, j in enumerate(pivot_local_graph)}
knn_inds = torch.tensor([node2ind_map[i]
for i in pivot_knn[1:]]).long().to(device)
pivot_feats = node_feats[pivot_ind]
normalized_feats = node_feats[pivot_local_graph] - pivot_feats
adjacent_matrix = np.zeros((num_nodes, num_nodes))
for node in pivot_local_graph:
neighbors = sorted_dist_inds[node,
1:self.active_connection + 1]
for neighbor in neighbors:
if neighbor in pivot_local_graph:
adjacent_matrix[node2ind_map[node],
node2ind_map[neighbor]] = 1
adjacent_matrix[node2ind_map[neighbor],
node2ind_map[node]] = 1
adjacent_matrix = normalize_adjacent_matrix(
adjacent_matrix, mode='DAD')
pad_adjacent_matrix = torch.zeros((num_max_nodes, num_max_nodes),
dtype=torch.float,
device=device)
pad_adjacent_matrix[:num_nodes, :num_nodes] = adjacent_matrix
pad_normalized_feats = torch.cat([
normalized_feats,
torch.zeros(
(num_max_nodes - num_nodes, normalized_feats.shape[1]),
dtype=torch.float,
device=device)
],
dim=0)
local_graph_nodes = torch.tensor(pivot_local_graph)
local_graph_nodes = torch.cat([
local_graph_nodes,
torch.zeros(num_max_nodes - num_nodes, dtype=torch.long)
],
dim=-1)
local_graphs_node_feat.append(pad_normalized_feats)
adjacent_matrices.append(pad_adjacent_matrix)
pivots_knn_inds.append(knn_inds)
pivots_local_graphs.append(local_graph_nodes)
local_graphs_node_feat = torch.stack(local_graphs_node_feat, 0)
adjacent_matrices = torch.stack(adjacent_matrices, 0)
pivots_knn_inds = torch.stack(pivots_knn_inds, 0)
pivots_local_graphs = torch.stack(pivots_local_graphs, 0)
return (local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
pivots_local_graphs)
def __call__(self, preds, feat_maps):
"""Generate local graphs and graph convolutional network input data.
Args:
preds (tensor): The predicted maps.
feat_maps (tensor): The feature maps to extract content feature of
text components.
Returns:
none_flag (bool): The flag showing whether the number of proposed
text components is 0.
local_graphs_node_feats (tensor): The features of nodes in local
graphs.
adjacent_matrices (tensor): The adjacent matrices.
pivots_knn_inds (tensor): The k-nearest neighbor indices in
local graphs.
pivots_local_graphs (tensor): The indices of nodes in local
graphs.
text_comps (ndarray): The predicted text components.
"""
if preds.ndim == 4:
assert preds.shape[0] == 1
preds = torch.squeeze(preds)
pred_text_region = torch.sigmoid(preds[0]).data.cpu().numpy()
pred_center_region = torch.sigmoid(preds[1]).data.cpu().numpy()
pred_sin_map = preds[2].data.cpu().numpy()
pred_cos_map = preds[3].data.cpu().numpy()
pred_top_height_map = preds[4].data.cpu().numpy()
pred_bot_height_map = preds[5].data.cpu().numpy()
device = preds.device
comp_attribs, text_comps = self.propose_comps_and_attribs(
pred_text_region, pred_center_region, pred_top_height_map,
pred_bot_height_map, pred_sin_map, pred_cos_map)
if comp_attribs is None or len(comp_attribs) < 2:
none_flag = True
return none_flag, (0, 0, 0, 0, 0)
comp_centers = comp_attribs[:, 0:2]
distance_matrix = euclidean_distance_matrix(comp_centers, comp_centers)
geo_feats = feature_embedding(comp_attribs, self.node_geo_feat_dim)
geo_feats = torch.from_numpy(geo_feats).to(preds.device)
batch_id = np.zeros((comp_attribs.shape[0], 1), dtype=np.float32)
comp_attribs = comp_attribs.astype(np.float32)
angle = np.arccos(comp_attribs[:, -2]) * np.sign(comp_attribs[:, -1])
angle = angle.reshape((-1, 1))
rotated_rois = np.hstack([batch_id, comp_attribs[:, :-2], angle])
rois = torch.from_numpy(rotated_rois).to(device)
content_feats = self.pooling(feat_maps, rois)
content_feats = content_feats.view(content_feats.shape[0],
-1).to(device)
node_feats = torch.cat([content_feats, geo_feats], dim=-1)
sorted_dist_inds = np.argsort(distance_matrix, axis=1)
(local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
pivots_local_graphs) = self.generate_local_graphs(
sorted_dist_inds, node_feats)
none_flag = False
return none_flag, (local_graphs_node_feat, adjacent_matrices,
pivots_knn_inds, pivots_local_graphs, text_comps)

View File

@ -0,0 +1,116 @@
import numpy as np
import torch
def normalize_adjacent_matrix(A, mode='AD'):
"""Normalize adjacent matrix for GCN. This code was partially adapted from
https://github.com/GXYM/DRRG licensed under the MIT license.
Args:
A (ndarray): The adjacent matrix.
mode (string): The normalize mode.
returns:
G (ndarray): The normalized adjacent matrix.
"""
assert A.ndim == 2
assert A.shape[0] == A.shape[1]
if mode == 'DAD':
A = A + np.eye(A.shape[0])
d = np.sum(A, axis=0)
d_inv = np.power(d, -0.5).flatten()
d_inv[np.isinf(d_inv)] = 0.0
d_inv = np.diag(d_inv)
G = A.dot(d_inv).transpose().dot(d_inv)
G = torch.from_numpy(G)
elif mode == 'AD':
A = A + np.eye(A.shape[0])
A = torch.from_numpy(A)
D = A.sum(1, keepdim=True)
G = A.div(D)
else:
raise NotImplementedError
return G
def euclidean_distance_matrix(A, B):
"""Calculate the Euclidean distance matrix.
Args:
A (ndarray): The point sequence.
B (ndarray): The point sequence with the same dimensions as A.
returns:
D (ndarray): The Euclidean distance matrix.
"""
assert A.ndim == 2
assert B.ndim == 2
assert A.shape[1] == B.shape[1]
m = A.shape[0]
n = B.shape[0]
A_dots = (A * A).sum(axis=1).reshape((m, 1)) * np.ones(shape=(1, n))
B_dots = (B * B).sum(axis=1) * np.ones(shape=(m, 1))
D_squared = A_dots + B_dots - 2 * A.dot(B.T)
zero_mask = np.less(D_squared, 0.0)
D_squared[zero_mask] = 0.0
D = np.sqrt(D_squared)
return D
def feature_embedding(input_feats, out_feat_len):
"""Embed features. This code was partially adapted from
https://github.com/GXYM/DRRG licensed under the MIT license.
Args:
input_feats (ndarray): The input features of shape (N, d), where N is
the number of nodes in graph, d is the input feature vector length.
out_feat_len (int): The length of output feature vector.
Returns:
embedded_feats (ndarray): The embedded features.
"""
assert input_feats.ndim == 2
assert isinstance(out_feat_len, int)
assert out_feat_len >= input_feats.shape[1]
num_nodes = input_feats.shape[0]
feat_dim = input_feats.shape[1]
feat_repeat_times = out_feat_len // feat_dim
residue_dim = out_feat_len % feat_dim
if residue_dim > 0:
embed_wave = np.array([
np.power(1000, 2.0 * (j // 2) / feat_repeat_times + 1)
for j in range(feat_repeat_times + 1)
]).reshape((feat_repeat_times + 1, 1, 1))
repeat_feats = np.repeat(
np.expand_dims(input_feats, axis=0), feat_repeat_times, axis=0)
residue_feats = np.hstack([
input_feats[:, 0:residue_dim],
np.zeros((num_nodes, feat_dim - residue_dim))
])
residue_feats = np.expand_dims(residue_feats, axis=0)
repeat_feats = np.concatenate([repeat_feats, residue_feats], axis=0)
embedded_feats = repeat_feats / embed_wave
embedded_feats[:, 0::2] = np.sin(embedded_feats[:, 0::2])
embedded_feats[:, 1::2] = np.cos(embedded_feats[:, 1::2])
embedded_feats = np.transpose(embedded_feats, (1, 0, 2)).reshape(
(num_nodes, -1))[:, 0:out_feat_len]
else:
embed_wave = np.array([
np.power(1000, 2.0 * (j // 2) / feat_repeat_times)
for j in range(feat_repeat_times)
]).reshape((feat_repeat_times, 1, 1))
repeat_feats = np.repeat(
np.expand_dims(input_feats, axis=0), feat_repeat_times, axis=0)
embedded_feats = repeat_feats / embed_wave
embedded_feats[:, 0::2] = np.sin(embedded_feats[:, 0::2])
embedded_feats[:, 1::2] = np.cos(embedded_feats[:, 1::2])
embedded_feats = np.transpose(embedded_feats, (1, 0, 2)).reshape(
(num_nodes, -1)).astype(np.float32)
return embedded_feats

View File

@ -1,6 +1,6 @@
from .fpem_ffm import FPEM_FFM
from .fpn_cat import FPNC
from .fpn_unet import FPN_UNET
from .fpn_unet import FPN_UNet
from .fpnf import FPNF
__all__ = ['FPEM_FFM', 'FPNF', 'FPNC', 'FPN_UNET']
__all__ = ['FPEM_FFM', 'FPNF', 'FPNC', 'FPN_UNet']

View File

@ -30,7 +30,7 @@ class UpBlock(nn.Module):
@NECKS.register_module()
class FPN_UNET(nn.Module):
class FPN_UNet(nn.Module):
"""The class for implementing DRRG and TextSnake U-Net-like FPN.
DRRG: Deep Relational Reasoning Graph Network for Arbitrary Shape

View File

@ -1,3 +1,6 @@
import functools
import operator
import cv2
import numpy as np
import pyclipper
@ -28,6 +31,8 @@ def decode(
return textsnake_decode(**kwargs)
if decoding_type == 'fcenet':
return fcenet_decode(**kwargs)
if decoding_type == 'drrg':
return drrg_decode(**kwargs)
raise NotImplementedError
@ -553,3 +558,350 @@ def generate_exp_matrix(point_num, fourier_degree):
exp_matrix[i, :] = temp * (i - fourier_degree)
return np.power(e, exp_matrix)
class Node(object):
def __init__(self, ind):
self.__ind = ind
self.__links = set()
@property
def ind(self):
return self.__ind
@property
def links(self):
return set(self.__links)
def add_link(self, link_node):
self.__links.add(link_node)
link_node.__links.add(self)
def graph_propagation(edges, scores, text_comps, edge_len_thr=50.):
"""Propagate edge score information and construct graph. This code was
partially adapted from https://github.com/GXYM/DRRG licensed under the MIT
license.
Args:
edges (ndarray): The edge array of shape N * 2, each row is a node
index pair that makes up an edge in graph.
scores (ndarray): The edge score array.
text_comps (ndarray): The text components.
edge_len_thr (float): The edge length threshold.
Returns:
vertices (list[Node]): The Nodes in graph.
score_dict (dict): The edge score dict.
"""
assert edges.ndim == 2
assert edges.shape[1] == 2
assert edges.shape[0] == scores.shape[0]
assert text_comps.ndim == 2
assert isinstance(edge_len_thr, float)
edges = np.sort(edges, axis=1)
score_dict = {}
for i, edge in enumerate(edges):
if text_comps is not None:
box1 = text_comps[edge[0], :8].reshape(4, 2)
box2 = text_comps[edge[1], :8].reshape(4, 2)
center1 = np.mean(box1, axis=0)
center2 = np.mean(box2, axis=0)
distance = norm(center1 - center2)
if distance > edge_len_thr:
scores[i] = 0
if (edge[0], edge[1]) in score_dict:
score_dict[edge[0], edge[1]] = 0.5 * (
score_dict[edge[0], edge[1]] + scores[i])
else:
score_dict[edge[0], edge[1]] = scores[i]
nodes = np.sort(np.unique(edges.flatten()))
mapping = -1 * np.ones((np.max(nodes) + 1), dtype=np.int)
mapping[nodes] = np.arange(nodes.shape[0])
order_inds = mapping[edges]
vertices = [Node(node) for node in nodes]
for ind in order_inds:
vertices[ind[0]].add_link(vertices[ind[1]])
return vertices, score_dict
def connected_components(nodes, score_dict, link_thr):
"""Conventional connected components searching. This code was partially
adapted from https://github.com/GXYM/DRRG licensed under the MIT license.
Args:
nodes (list[Node]): The list of Node objects.
score_dict (dict): The edge score dict.
link_thr (float): The link threshold.
Returns:
clusters (List[list[Node]]): The clustered Node objects.
"""
assert isinstance(nodes, list)
assert all([isinstance(node, Node) for node in nodes])
assert isinstance(score_dict, dict)
assert isinstance(link_thr, float)
clusters = []
nodes = set(nodes)
while nodes:
node = nodes.pop()
cluster = {node}
node_queue = [node]
while node_queue:
node = node_queue.pop(0)
neighbors = set([
neighbor for neighbor in node.links if
score_dict[tuple(sorted([node.ind, neighbor.ind]))] >= link_thr
])
neighbors.difference_update(cluster)
nodes.difference_update(neighbors)
cluster.update(neighbors)
node_queue.extend(neighbors)
clusters.append(list(cluster))
return clusters
def clusters2labels(clusters, num_nodes):
"""Convert clusters of Node to text component labels. This code was
partially adapted from https://github.com/GXYM/DRRG licensed under the MIT
license.
Args:
clusters (List[list[Node]]): The clusters of Node objects.
num_nodes (int): The total node number of graphs in an image.
Returns:
node_labels (ndarray): The node label array.
"""
assert isinstance(clusters, list)
assert all([isinstance(cluster, list) for cluster in clusters])
assert all(
[isinstance(node, Node) for cluster in clusters for node in cluster])
assert isinstance(num_nodes, int)
node_labels = np.zeros(num_nodes)
for cluster_ind, cluster in enumerate(clusters):
for node in cluster:
node_labels[node.ind] = cluster_ind
return node_labels
def remove_single(text_comps, comp_pred_labels):
"""Remove isolated text components. This code was partially adapted from
https://github.com/GXYM/DRRG licensed under the MIT license.
Args:
text_comps (ndarray): The text components.
comp_pred_labels (ndarray): The clustering labels of text components.
Returns:
filtered_text_comps (ndarray): The text components with isolated ones
removed.
comp_pred_labels (ndarray): The clustering labels with labels of
isolated text components removed.
"""
assert text_comps.ndim == 2
assert text_comps.shape[0] == comp_pred_labels.shape[0]
single_flags = np.zeros_like(comp_pred_labels)
pred_labels = np.unique(comp_pred_labels)
for label in pred_labels:
current_label_flag = (comp_pred_labels == label)
if np.sum(current_label_flag) == 1:
single_flags[np.where(current_label_flag)[0][0]] = 1
keep_ind = [i for i in range(len(comp_pred_labels)) if not single_flags[i]]
filtered_text_comps = text_comps[keep_ind, :]
filtered_labels = comp_pred_labels[keep_ind]
return filtered_text_comps, filtered_labels
def norm2(point1, point2):
return ((point1[0] - point2[0])**2 + (point1[1] - point2[1])**2)**0.5
def min_connect_path(points):
"""Find the shortest path to traverse all points. This code was partially
adapted from https://github.com/GXYM/DRRG licensed under the MIT license.
Args:
points(List[list[int]]): The point sequence [[x0, y0], [x1, y1], ...].
Returns:
shortest_path(List[list[int]]): The shortest index path.
"""
assert isinstance(points, list)
assert all([isinstance(point, list) for point in points])
assert all([isinstance(coord, int) for point in points for coord in point])
points_queue = points.copy()
shortest_path = []
current_edge = [[], []]
edge_dict0 = {}
edge_dict1 = {}
current_edge[0] = points_queue[0]
current_edge[1] = points_queue[0]
points_queue.remove(points_queue[0])
while points_queue:
for point in points_queue:
length0 = norm2(point, current_edge[0])
edge_dict0[length0] = [point, current_edge[0]]
length1 = norm2(current_edge[1], point)
edge_dict1[length1] = [current_edge[1], point]
key0 = min(edge_dict0.keys())
key1 = min(edge_dict1.keys())
if key0 <= key1:
start = edge_dict0[key0][0]
end = edge_dict0[key0][1]
shortest_path.insert(0, [points.index(start), points.index(end)])
points_queue.remove(start)
current_edge[0] = start
else:
start = edge_dict1[key1][0]
end = edge_dict1[key1][1]
shortest_path.append([points.index(start), points.index(end)])
points_queue.remove(end)
current_edge[1] = end
edge_dict0 = {}
edge_dict1 = {}
shortest_path = functools.reduce(operator.concat, shortest_path)
shortest_path = sorted(set(shortest_path), key=shortest_path.index)
return shortest_path
def in_contour(cont, point):
x, y = point
is_inner = cv2.pointPolygonTest(cont, (int(x), int(y)), False) > 0.5
return is_inner
def fix_corner(top_line, bot_line, start_box, end_box):
"""Add corner points to predicted side lines. This code was partially
adapted from https://github.com/GXYM/DRRG licensed under the MIT license.
Args:
top_line (List[list[int]]): The predicted top sidelines of text
instance.
bot_line (List[list[int]]): The predicted bottom sidelines of text
instance.
start_box (ndarray): The first text component box.
end_box (ndarray): The last text component box.
Returns:
top_line (List[list[int]]): The top sidelines with corner point added.
bot_line (List[list[int]]): The bottom sidelines with corner point
added.
"""
assert isinstance(top_line, list)
assert all(isinstance(point, list) for point in top_line)
assert isinstance(bot_line, list)
assert all(isinstance(point, list) for point in bot_line)
assert start_box.shape == end_box.shape == (4, 2)
contour = np.array(top_line + bot_line[::-1])
start_left_mid = (start_box[0] + start_box[3]) / 2
start_right_mid = (start_box[1] + start_box[2]) / 2
end_left_mid = (end_box[0] + end_box[3]) / 2
end_right_mid = (end_box[1] + end_box[2]) / 2
if not in_contour(contour, start_left_mid):
top_line.insert(0, start_box[0].tolist())
bot_line.insert(0, start_box[3].tolist())
elif not in_contour(contour, start_right_mid):
top_line.insert(0, start_box[1].tolist())
bot_line.insert(0, start_box[2].tolist())
if not in_contour(contour, end_left_mid):
top_line.append(end_box[0].tolist())
bot_line.append(end_box[3].tolist())
elif not in_contour(contour, end_right_mid):
top_line.append(end_box[1].tolist())
bot_line.append(end_box[2].tolist())
return top_line, bot_line
def comps2boundaries(text_comps, comp_pred_labels):
"""Construct text instance boundaries from clustered text components. This
code was partially adapted from https://github.com/GXYM/DRRG licensed under
the MIT license.
Args:
text_comps (ndarray): The text components.
comp_pred_labels (ndarray): The clustering labels of text components.
Returns:
boundaries (List[list[float]]): The predicted boundaries of text
instances.
"""
assert text_comps.ndim == 2
assert len(text_comps) == len(comp_pred_labels)
boundaries = []
if len(text_comps) < 1:
return boundaries
for cluster_ind in range(0, int(np.max(comp_pred_labels)) + 1):
cluster_comp_inds = np.where(comp_pred_labels == cluster_ind)
text_comp_boxes = text_comps[cluster_comp_inds, :8].reshape(
(-1, 4, 2)).astype(np.int32)
score = np.mean(text_comps[cluster_comp_inds, -1])
if text_comp_boxes.shape[0] < 1:
continue
elif text_comp_boxes.shape[0] > 1:
centers = np.mean(
text_comp_boxes, axis=1).astype(np.int32).tolist()
shortest_path = min_connect_path(centers)
text_comp_boxes = text_comp_boxes[shortest_path]
top_line = np.mean(
text_comp_boxes[:, 0:2, :], axis=1).astype(np.int32).tolist()
bot_line = np.mean(
text_comp_boxes[:, 2:4, :], axis=1).astype(np.int32).tolist()
top_line, bot_line = fix_corner(top_line, bot_line,
text_comp_boxes[0],
text_comp_boxes[-1])
boundary_points = top_line + bot_line[::-1]
else:
top_line = text_comp_boxes[0, 0:2, :].astype(np.int32).tolist()
bot_line = text_comp_boxes[0, 2:4:-1, :].astype(np.int32).tolist()
boundary_points = top_line + bot_line
boundary = [p for coord in boundary_points for p in coord] + [score]
boundaries.append(boundary)
return boundaries
def drrg_decode(edges, scores, text_comps, link_thr):
"""Merge text components and construct boundaries of text instances.
Args:
edges (ndarray): The edge array of shape N * 2, each row is a node
index pair that makes up an edge in graph.
scores (ndarray): The edge score array.
text_comps (ndarray): The text components.
link_thr (float): The edge score threshold.
Returns:
boundaries (List[list[float]]): The predicted boundaries of text
instances.
"""
assert len(edges) == len(scores)
assert text_comps.ndim == 2
assert text_comps.shape[1] == 9
assert isinstance(link_thr, float)
vertices, score_dict = graph_propagation(edges, scores, text_comps)
clusters = connected_components(vertices, score_dict, link_thr)
pred_labels = clusters2labels(clusters, text_comps.shape[0])
text_comps, pred_labels = remove_single(text_comps, pred_labels)
boundaries = comps2boundaries(text_comps, pred_labels)
return boundaries

View File

@ -1,4 +1,5 @@
imgaug
lanms-proper
lmdb
matplotlib
numba>=0.45.1

View File

@ -20,7 +20,7 @@ line_length = 79
multi_line_output = 0
known_standard_library = setuptools
known_first_party = mmdet,mmocr
known_third_party = PIL,Polygon,cv2,imgaug,lmdb,matplotlib,mmcv,numpy,pyclipper,pycocotools,pytest,rapidfuzz,scipy,shapely,skimage,titlecase,torch,torchvision
known_third_party = PIL,Polygon,cv2,imgaug,lanms,lmdb,matplotlib,mmcv,numpy,pyclipper,pycocotools,pytest,rapidfuzz,scipy,shapely,skimage,titlecase,torch,torchvision
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY

View File

@ -242,3 +242,94 @@ def test_fcenet_generate_targets():
assert 'p3_maps' in results.keys()
assert 'p4_maps' in results.keys()
assert 'p5_maps' in results.keys()
def test_gen_drrg_targets():
target_generator = textdet_targets.DRRGTargets()
assert np.allclose(target_generator.orientation_thr, 2.0)
assert np.allclose(target_generator.resample_step, 8.0)
assert target_generator.num_min_comps == 9
assert target_generator.num_max_comps == 600
assert np.allclose(target_generator.min_width, 8.0)
assert np.allclose(target_generator.max_width, 24.0)
assert np.allclose(target_generator.center_region_shrink_ratio, 0.3)
assert np.allclose(target_generator.comp_shrink_ratio, 1.0)
assert np.allclose(target_generator.comp_w_h_ratio, 0.3)
assert np.allclose(target_generator.text_comp_nms_thr, 0.25)
assert np.allclose(target_generator.min_rand_half_height, 8.0)
assert np.allclose(target_generator.max_rand_half_height, 24.0)
assert np.allclose(target_generator.jitter_level, 0.2)
# test generate_targets
target_generator = textdet_targets.DRRGTargets(
min_width=2.,
max_width=4.,
min_rand_half_height=3.,
max_rand_half_height=5.)
results = {}
results['img'] = np.zeros((64, 64, 3), np.uint8)
text_polys = [[np.array([4, 2, 30, 2, 30, 10, 4, 10])],
[np.array([36, 12, 8, 12, 8, 22, 36, 22])],
[np.array([48, 20, 52, 20, 52, 50, 48, 50])],
[np.array([44, 50, 38, 50, 38, 20, 44, 20])]]
results['gt_masks'] = PolygonMasks(text_polys, 20, 30)
results['gt_masks_ignore'] = PolygonMasks([], 64, 64)
results['img_shape'] = (64, 64, 3)
results['mask_fields'] = []
output = target_generator(results)
assert len(output['gt_text_mask']) == 1
assert len(output['gt_center_region_mask']) == 1
assert len(output['gt_mask']) == 1
assert len(output['gt_top_height_map']) == 1
assert len(output['gt_bot_height_map']) == 1
assert len(output['gt_sin_map']) == 1
assert len(output['gt_cos_map']) == 1
assert output['gt_comp_attribs'].shape[-1] == 8
# test generate_targets with the number of proposed text components exceeds
# num_max_comps
target_generator = textdet_targets.DRRGTargets(
min_width=2.,
max_width=4.,
min_rand_half_height=3.,
max_rand_half_height=5.,
num_max_comps=6)
output = target_generator(results)
assert output['gt_comp_attribs'].ndim == 2
assert output['gt_comp_attribs'].shape[0] == 6
# test generate_targets with blank polygon masks
target_generator = textdet_targets.DRRGTargets(
min_width=2.,
max_width=4.,
min_rand_half_height=3.,
max_rand_half_height=5.)
results = {}
results['img'] = np.zeros((20, 30, 3), np.uint8)
results['gt_masks'] = PolygonMasks([], 20, 30)
results['gt_masks_ignore'] = PolygonMasks([], 20, 30)
results['img_shape'] = (20, 30, 3)
results['mask_fields'] = []
output = target_generator(results)
assert output['gt_comp_attribs'][0, 0] > 8
# test generate_targets with one proposed text component
text_polys = [[np.array([13, 6, 17, 6, 17, 14, 13, 14])]]
target_generator = textdet_targets.DRRGTargets(
min_width=4.,
max_width=8.,
min_rand_half_height=3.,
max_rand_half_height=5.)
results['gt_masks'] = PolygonMasks(text_polys, 20, 30)
output = target_generator(results)
assert output['gt_comp_attribs'][0, 0] > 8
# test generate_targets with shrunk margin in generate_rand_comp_attribs
target_generator = textdet_targets.DRRGTargets(
min_width=2.,
max_width=30.,
min_rand_half_height=3.,
max_rand_half_height=30.)
output = target_generator(results)
assert output['gt_comp_attribs'][0, 0] > 8

View File

@ -429,3 +429,83 @@ def test_fcenet(cfg_file):
results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]}
img = np.random.rand(5, 5)
detector.show_result(img, results)
@pytest.mark.parametrize(
'cfg_file', ['textdet/drrg/'
'drrg_r50_fpn_unet_1200e_ctw1500.py'])
def test_drrg(cfg_file):
model = _get_detector_cfg(cfg_file)
model['pretrained'] = None
model['backbone']['norm_cfg']['type'] = 'BN'
from mmocr.models import build_detector
detector = build_detector(model)
input_shape = (1, 3, 224, 224)
num_kernels = 1
mm_inputs = _demo_mm_inputs(num_kernels, input_shape)
imgs = mm_inputs.pop('imgs')
img_metas = mm_inputs.pop('img_metas')
gt_text_mask = mm_inputs.pop('gt_text_mask')
gt_center_region_mask = mm_inputs.pop('gt_center_region_mask')
gt_mask = mm_inputs.pop('gt_mask')
gt_top_height_map = mm_inputs.pop('gt_radius_map')
gt_bot_height_map = gt_top_height_map.copy()
gt_sin_map = mm_inputs.pop('gt_sin_map')
gt_cos_map = mm_inputs.pop('gt_cos_map')
num_rois = 32
x = np.random.randint(4, 224, (num_rois, 1))
y = np.random.randint(4, 224, (num_rois, 1))
h = 4 * np.ones((num_rois, 1))
w = 4 * np.ones((num_rois, 1))
angle = (np.random.random_sample((num_rois, 1)) * 2 - 1) * np.pi / 2
cos, sin = np.cos(angle), np.sin(angle)
comp_labels = np.random.randint(1, 3, (num_rois, 1))
num_rois = num_rois * np.ones((num_rois, 1))
comp_attribs = np.hstack([num_rois, x, y, h, w, cos, sin, comp_labels])
gt_comp_attribs = np.expand_dims(comp_attribs.astype(np.float32), axis=0)
# Test forward train
losses = detector.forward(
imgs,
img_metas,
gt_text_mask=gt_text_mask,
gt_center_region_mask=gt_center_region_mask,
gt_mask=gt_mask,
gt_top_height_map=gt_top_height_map,
gt_bot_height_map=gt_bot_height_map,
gt_sin_map=gt_sin_map,
gt_cos_map=gt_cos_map,
gt_comp_attribs=gt_comp_attribs)
assert isinstance(losses, dict)
# Test forward test
model['bbox_head']['in_channels'] = 6
model['bbox_head']['text_region_thr'] = 0.8
model['bbox_head']['center_region_thr'] = 0.8
detector = build_detector(model)
maps = torch.zeros((1, 6, 224, 224), dtype=torch.float)
maps[:, 0:2, :, :] = -10.
maps[:, 0, 60:100, 50:170] = 10.
maps[:, 1, 75:85, 60:160] = 10.
maps[:, 2, 75:85, 60:160] = 0.
maps[:, 3, 75:85, 60:160] = 1.
maps[:, 4, 75:85, 60:160] = 10.
maps[:, 5, 75:85, 60:160] = 10.
with torch.no_grad():
full_pass_weight = torch.zeros((6, 6, 1, 1))
for i in range(6):
full_pass_weight[i, i, 0, 0] = 1
detector.bbox_head.out_conv.weight.data = full_pass_weight
detector.bbox_head.out_conv.bias.data.fill_(0.)
outs = detector.bbox_head.single_test(maps)
boundaries = detector.bbox_head.get_boundary(*outs, img_metas, True)
assert len(boundaries)
# Test show result
results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]}
img = np.random.rand(5, 5)
detector.show_result(img, results)

View File

@ -42,9 +42,9 @@ def test_fcenetloss():
# test ohem
pred = torch.ones((200, 2), dtype=torch.float)
target = torch.ones((200, ), dtype=torch.long)
target = torch.ones(200, dtype=torch.long)
target[20:] = 0
mask = torch.ones((200, ), dtype=torch.long)
mask = torch.ones(200, dtype=torch.long)
ohem_loss1 = fcenetloss.ohem(pred, target, mask)
ohem_loss2 = fcenetloss.ohem(pred, target, 1 - mask)
@ -70,3 +70,76 @@ def test_fcenetloss():
loss = fcenetloss(preds, 0, p3_maps, p4_maps, p5_maps)
assert isinstance(loss, dict)
def test_drrgloss():
drrgloss = losses.DRRGLoss()
assert np.allclose(drrgloss.ohem_ratio, 3.0)
# test balance_bce_loss
pred = torch.tensor([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=torch.float)
target = torch.tensor([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=torch.long)
mask = torch.tensor([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=torch.long)
bce_loss = drrgloss.balance_bce_loss(pred, target, mask).item()
assert np.allclose(bce_loss, 0)
# test balance_bce_loss with positive_count equal to zero
pred = torch.ones((16, 16), dtype=torch.float)
target = torch.ones((16, 16), dtype=torch.long)
mask = torch.zeros((16, 16), dtype=torch.long)
bce_loss = drrgloss.balance_bce_loss(pred, target, mask).item()
assert np.allclose(bce_loss, 0)
# test gcn_loss
gcn_preds = torch.tensor([[0., 1.], [1., 0.]])
labels = torch.tensor([1, 0], dtype=torch.long)
gcn_loss = drrgloss.gcn_loss((gcn_preds, labels))
assert gcn_loss.item()
# test bitmasks2tensor
mask = [[1, 0, 1], [1, 1, 1], [0, 0, 1]]
target = [[1, 0, 1, 0, 0], [1, 1, 1, 0, 0], [0, 0, 1, 0, 0],
[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]
masks = [np.array(mask)]
bitmasks = BitmapMasks(masks, 3, 3)
target_sz = (6, 5)
results = drrgloss.bitmasks2tensor([bitmasks], target_sz)
assert len(results) == 1
assert torch.sum(torch.abs(results[0].float() -
torch.Tensor(target))).item() == 0
# test forward
target_maps = [BitmapMasks([np.random.randn(20, 20)], 20, 20)]
target_masks = [BitmapMasks([np.ones((20, 20))], 20, 20)]
gt_masks = [BitmapMasks([np.ones((20, 20))], 20, 20)]
preds = (torch.randn((1, 6, 20, 20)), (gcn_preds, labels))
loss_dict = drrgloss(preds, 1., target_masks, target_masks, gt_masks,
target_maps, target_maps, target_maps, target_maps)
assert isinstance(loss_dict, dict)
assert 'loss_text' in loss_dict.keys()
assert 'loss_center' in loss_dict.keys()
assert 'loss_height' in loss_dict.keys()
assert 'loss_sin' in loss_dict.keys()
assert 'loss_cos' in loss_dict.keys()
assert 'loss_gcn' in loss_dict.keys()
# test forward with downsample_ratio less than 1.
target_maps = [BitmapMasks([np.random.randn(40, 40)], 40, 40)]
target_masks = [BitmapMasks([np.ones((40, 40))], 40, 40)]
gt_masks = [BitmapMasks([np.ones((40, 40))], 40, 40)]
preds = (torch.randn((1, 6, 20, 20)), (gcn_preds, labels))
loss_dict = drrgloss(preds, 0.5, target_masks, target_masks, gt_masks,
target_maps, target_maps, target_maps, target_maps)
assert isinstance(loss_dict, dict)
# test forward with blank gt_mask.
target_maps = [BitmapMasks([np.random.randn(20, 20)], 20, 20)]
target_masks = [BitmapMasks([np.ones((20, 20))], 20, 20)]
gt_masks = [BitmapMasks([np.zeros((20, 20))], 20, 20)]
preds = (torch.randn((1, 6, 20, 20)), (gcn_preds, labels))
loss_dict = drrgloss(preds, 1., target_masks, target_masks, gt_masks,
target_maps, target_maps, target_maps, target_maps)
assert isinstance(loss_dict, dict)

View File

@ -0,0 +1,140 @@
import numpy as np
import pytest
import torch
from mmocr.models.textdet.modules import GCN, LocalGraphs, ProposalLocalGraphs
from mmocr.models.textdet.modules.utils import (feature_embedding,
normalize_adjacent_matrix)
def test_local_graph_forward_train():
geo_feat_len = 24
pooling_h, pooling_w = pooling_out_size = (2, 2)
num_rois = 32
local_graph_generator = LocalGraphs((4, 4), 3, geo_feat_len, 1.0,
pooling_out_size, 0.5)
feature_maps = torch.randn((2, 3, 128, 128), dtype=torch.float)
x = np.random.randint(4, 124, (num_rois, 1))
y = np.random.randint(4, 124, (num_rois, 1))
h = 4 * np.ones((num_rois, 1))
w = 4 * np.ones((num_rois, 1))
angle = (np.random.random_sample((num_rois, 1)) * 2 - 1) * np.pi / 2
cos, sin = np.cos(angle), np.sin(angle)
comp_labels = np.random.randint(1, 3, (num_rois, 1))
num_rois = num_rois * np.ones((num_rois, 1))
comp_attribs = np.hstack([num_rois, x, y, h, w, cos, sin, comp_labels])
comp_attribs = comp_attribs.astype(np.float32)
comp_attribs_ = comp_attribs.copy()
comp_attribs = np.stack([comp_attribs, comp_attribs_])
(node_feats, adjacent_matrix, knn_inds,
linkage_labels) = local_graph_generator(feature_maps, comp_attribs)
feat_len = geo_feat_len + feature_maps.size()[1] * pooling_h * pooling_w
assert node_feats.dim() == adjacent_matrix.dim() == 3
assert node_feats.size()[-1] == feat_len
assert knn_inds.size()[-1] == 4
assert linkage_labels.size()[-1] == 4
assert (node_feats.size()[0] == adjacent_matrix.size()[0] ==
knn_inds.size()[0] == linkage_labels.size()[0])
assert (node_feats.size()[1] == adjacent_matrix.size()[1] ==
adjacent_matrix.size()[2])
def test_local_graph_forward_test():
geo_feat_len = 24
pooling_h, pooling_w = pooling_out_size = (2, 2)
local_graph_generator = ProposalLocalGraphs(
(4, 4), 2, geo_feat_len, 1., pooling_out_size, 0.1, 3., 6., 1., 0.5,
0.3, 0.5, 0.5, 2)
maps = torch.zeros((1, 6, 224, 224), dtype=torch.float)
maps[:, 0:2, :, :] = -10.
maps[:, 0, 60:100, 50:170] = 10.
maps[:, 1, 75:85, 60:160] = 10.
maps[:, 2, 75:85, 60:160] = 0.
maps[:, 3, 75:85, 60:160] = 1.
maps[:, 4, 75:85, 60:160] = 10.
maps[:, 5, 75:85, 60:160] = 10.
feature_maps = torch.randn((2, 6, 224, 224), dtype=torch.float)
feat_len = geo_feat_len + feature_maps.size()[1] * pooling_h * pooling_w
none_flag, graph_data = local_graph_generator(maps, feature_maps)
(node_feats, adjacent_matrices, knn_inds, local_graphs,
text_comps) = graph_data
assert none_flag is False
assert text_comps.ndim == 2
assert text_comps.shape[0] > 0
assert text_comps.shape[1] == 9
assert (node_feats.size()[0] == adjacent_matrices.size()[0] ==
knn_inds.size()[0] == local_graphs.size()[0] ==
text_comps.shape[0])
assert (node_feats.size()[1] == adjacent_matrices.size()[1] ==
adjacent_matrices.size()[2] == local_graphs.size()[1])
assert node_feats.size()[-1] == feat_len
# test proposal local graphs with area of center region less than threshold
maps[:, 1, 75:85, 60:160] = -10.
maps[:, 1, 80, 80] = 10.
none_flag, _ = local_graph_generator(maps, feature_maps)
assert none_flag
# test proposal local graphs with one text component
local_graph_generator = ProposalLocalGraphs(
(4, 4), 2, geo_feat_len, 1., pooling_out_size, 0.1, 8., 20., 1., 0.5,
0.3, 0.5, 0.5, 2)
maps[:, 1, 78:82, 78:82] = 10.
none_flag, _ = local_graph_generator(maps, feature_maps)
assert none_flag
# test proposal local graphs with text components out of text region
maps[:, 0, 60:100, 50:170] = -10.
maps[:, 0, 78:82, 78:82] = 10.
none_flag, _ = local_graph_generator(maps, feature_maps)
assert none_flag
def test_gcn():
num_local_graphs = 32
num_max_graph_nodes = 16
input_feat_len = 512
k = 8
gcn = GCN(input_feat_len)
node_feat = torch.randn(
(num_local_graphs, num_max_graph_nodes, input_feat_len))
adjacent_matrix = torch.rand(
(num_local_graphs, num_max_graph_nodes, num_max_graph_nodes))
knn_inds = torch.randint(1, num_max_graph_nodes, (num_local_graphs, k))
output = gcn(node_feat, adjacent_matrix, knn_inds)
assert output.size() == (num_local_graphs * k, 2)
def test_normalize_adjacent_matrix():
adjacent_matrix = np.random.randn(32, 32)
normalized_matrix = normalize_adjacent_matrix(adjacent_matrix, mode='AD')
assert normalized_matrix.shape == adjacent_matrix.shape
normalized_matrix = normalize_adjacent_matrix(adjacent_matrix, mode='DAD')
assert normalized_matrix.shape == adjacent_matrix.shape
with pytest.raises(NotImplementedError):
normalized_matrix = normalize_adjacent_matrix(
adjacent_matrix, mode='DA')
def test_feature_embedding():
out_feat_len = 48
# test without residue dimensions
feats = np.random.randn(10, 8)
embed_feats = feature_embedding(feats, out_feat_len)
assert embed_feats.shape == (10, out_feat_len)
# test with residue dimensions
feats = np.random.randn(10, 9)
embed_feats = feature_embedding(feats, out_feat_len)
assert embed_feats.shape == (10, out_feat_len)

View File

@ -0,0 +1,82 @@
import numpy as np
import torch
from mmocr.models.textdet.dense_heads import DRRGHead
def test_drrg_head():
in_channels = 10
drrg_head = DRRGHead(in_channels)
assert drrg_head.in_channels == in_channels
assert drrg_head.k_at_hops == (8, 4)
assert drrg_head.num_adjacent_linkages == 3
assert drrg_head.node_geo_feat_len == 120
assert np.allclose(drrg_head.pooling_scale, 1.0)
assert drrg_head.pooling_output_size == (4, 3)
assert np.allclose(drrg_head.nms_thr, 0.3)
assert np.allclose(drrg_head.min_width, 8.0)
assert np.allclose(drrg_head.max_width, 24.0)
assert np.allclose(drrg_head.comp_shrink_ratio, 1.03)
assert np.allclose(drrg_head.comp_ratio, 0.4)
assert np.allclose(drrg_head.comp_score_thr, 0.3)
assert np.allclose(drrg_head.text_region_thr, 0.2)
assert np.allclose(drrg_head.center_region_thr, 0.2)
assert drrg_head.center_region_area_thr == 50
assert np.allclose(drrg_head.local_graph_thr, 0.7)
assert np.allclose(drrg_head.link_thr, 0.85)
# test forward train
num_rois = 16
feature_maps = torch.randn((2, 10, 128, 128), dtype=torch.float)
x = np.random.randint(4, 124, (num_rois, 1))
y = np.random.randint(4, 124, (num_rois, 1))
h = 4 * np.ones((num_rois, 1))
w = 4 * np.ones((num_rois, 1))
angle = (np.random.random_sample((num_rois, 1)) * 2 - 1) * np.pi / 2
cos, sin = np.cos(angle), np.sin(angle)
comp_labels = np.random.randint(1, 3, (num_rois, 1))
num_rois = num_rois * np.ones((num_rois, 1))
comp_attribs = np.hstack([num_rois, x, y, h, w, cos, sin, comp_labels])
comp_attribs = comp_attribs.astype(np.float32)
comp_attribs_ = comp_attribs.copy()
comp_attribs = np.stack([comp_attribs, comp_attribs_])
pred_maps, gcn_data = drrg_head(feature_maps, comp_attribs)
pred_labels, gt_labels = gcn_data
assert pred_maps.size() == (2, 6, 128, 128)
assert pred_labels.ndim == gt_labels.ndim == 2
assert gt_labels.size()[0] * gt_labels.size()[1] == pred_labels.size()[0]
assert pred_labels.size()[1] == 2
# test forward test
with torch.no_grad():
feat_maps = torch.zeros((1, 10, 128, 128))
drrg_head.out_conv.bias.data.fill_(-10)
preds = drrg_head.single_test(feat_maps)
assert all([pred is None for pred in preds])
# test get_boundary
edges = np.stack([np.arange(0, 10), np.arange(1, 11)]).transpose()
edges = np.vstack([edges, np.array([1, 0])])
scores = np.ones(11, dtype=np.float32) * 0.9
x1 = np.arange(2, 22, 2)
x2 = x1 + 2
y1 = np.ones(10) * 2
y2 = y1 + 2
comp_scores = np.ones(10, dtype=np.float32) * 0.9
text_comps = np.stack([x1, y1, x2, y1, x2, y2, x1, y2,
comp_scores]).transpose()
outlier = np.array([50, 50, 52, 50, 52, 52, 50, 52, 0.9])
text_comps = np.vstack([text_comps, outlier])
(C, H, W) = (10, 128, 128)
img_metas = [{
'img_shape': (H, W, C),
'ori_shape': (H, W, C),
'pad_shape': (H, W, C),
'filename': '<demo>.png',
'scale_factor': np.array([1, 1, 1, 1]),
'flip': False,
}]
results = drrg_head.get_boundary(
edges, scores, text_comps, img_metas, rescale=True)
assert 'boundary_result' in results.keys()

View File

@ -1,7 +1,7 @@
import pytest
import torch
from mmocr.models.textdet.necks import FPN_UNET, FPNC
from mmocr.models.textdet.necks import FPNC, FPN_UNet
def test_fpnc():
@ -32,18 +32,18 @@ def test_fpn_unet_neck():
# len(in_channcels) is not equal to 4
with pytest.raises(AssertionError):
FPN_UNET(in_channels + [128], out_channels)
FPN_UNet(in_channels + [128], out_channels)
# `out_channels` is not int type
with pytest.raises(AssertionError):
FPN_UNET(in_channels, [2, 4])
FPN_UNet(in_channels, [2, 4])
feats = [
torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i])
for i in range(len(in_channels))
]
fpn_unet_neck = FPN_UNET(in_channels, out_channels)
fpn_unet_neck = FPN_UNet(in_channels, out_channels)
fpn_unet_neck.init_weights()
out_neck = fpn_unet_neck(feats)

View File

@ -26,3 +26,24 @@ def test_fcenet_decode():
preds=preds, fourier_degree=k, reconstr_points=50, scale=1)
assert isinstance(boundaries, list)
def test_comps2boundaries():
from mmocr.models.textdet.postprocess.wrapper import comps2boundaries
# test comps2boundaries
x1 = np.arange(2, 18, 2)
x2 = x1 + 2
y1 = np.ones(8) * 2
y2 = y1 + 2
comp_scores = np.ones(8, dtype=np.float32) * 0.9
text_comps = np.stack([x1, y1, x2, y1, x2, y2, x1, y2,
comp_scores]).transpose()
comp_labels = np.array([1, 1, 1, 1, 1, 3, 5, 5])
shuffle = [3, 2, 5, 7, 6, 0, 4, 1]
boundaries = comps2boundaries(text_comps[shuffle], comp_labels[shuffle])
assert len(boundaries) == 3
# test comps2boundaries with blank inputs
boundaries = comps2boundaries(text_comps[[]], comp_labels[[]])
assert len(boundaries) == 0