Merge pull request #2 from HolyCrap96/feature/textsnake_drrg

[feature]: add textsnake_drrg
pull/2/head
Hongbin Sun 2021-04-03 00:25:29 +08:00 committed by GitHub
commit 3ed6aaa4e4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 2961 additions and 0 deletions

View File

@ -0,0 +1,566 @@
import cv2
import numpy as np
from numpy.linalg import norm
import mmocr.utils.check_argument as check_argument
from mmdet.core import BitmapMasks
from mmdet.datasets.builder import PIPELINES
# from mmocr.models.textdet import la_nms
from .textsnake_targets import TextSnakeTargets
@PIPELINES.register_module()
class DRRGTargets(TextSnakeTargets):
"""Generate the ground truth targets of DRRG: Deep Relational Reasoning
Graph Network for Arbitrary Shape Text Detection.
[https://arxiv.org/abs/2003.07493]. This was partially adapted from
https://github.com/GXYM/DRRG.
Args:
orientation_thr (float): The threshold for distinguishing between
head edge and tail edge among the horizontal and vertical edges
of a quadrangle.
resample_step (float): The step size for resampling the text center
line (TCL). Better not exceed half of the minimum width
of the text component.
min_comp_num (int): The minimum number of text components, which
should be k_hop1 + 1 on graph.
max_comp_num (int): The maximum number of text components.
min_width (float): The minimum width of text components.
max_width (float): The maximum width of text components.
center_region_shrink_ratio (float): The shrink ratio of text center
region.
comp_shrink_ratio (float): The shrink ratio of text components.
text_comp_ratio (float): The reciprocal of aspect ratio of text
components.
min_rand_half_height(float): The minimum half-height of random text
components.
max_rand_half_height (float): The maximum half-height of random
text components.
jitter_level (float): The jitter level of text components geometric
features.
"""
def __init__(self,
orientation_thr=2.0,
resample_step=8.0,
min_comp_num=9,
max_comp_num=600,
min_width=8.0,
max_width=24.0,
center_region_shrink_ratio=0.3,
comp_shrink_ratio=1.0,
text_comp_ratio=0.65,
text_comp_nms_thr=0.25,
min_rand_half_height=6.0,
max_rand_half_height=24.0,
jitter_level=0.2):
super().__init__()
self.orientation_thr = orientation_thr
self.resample_step = resample_step
self.max_comp_num = max_comp_num
self.min_comp_num = min_comp_num
self.min_width = min_width
self.max_width = max_width
self.center_region_shrink_ratio = center_region_shrink_ratio
self.comp_shrink_ratio = comp_shrink_ratio
self.text_comp_ratio = text_comp_ratio
self.text_comp_nms_thr = text_comp_nms_thr
self.min_rand_half_height = min_rand_half_height
self.max_rand_half_height = max_rand_half_height
self.jitter_level = jitter_level
def dist_point2line(self, pnt, line):
assert isinstance(line, tuple)
pnt1, pnt2 = line
d = abs(np.cross(pnt2 - pnt1, pnt - pnt1)) / (norm(pnt2 - pnt1) + 1e-8)
return d
def draw_center_region_maps(self, top_line, bot_line, center_line,
center_region_mask, top_height_map,
bot_height_map, sin_map, cos_map,
region_shrink_ratio):
"""Draw attributes on text center region.
Args:
top_line (ndarray): The points composing top curved sideline of
text polygon.
bot_line (ndarray): The points composing bottom curved sideline
of text polygon.
center_line (ndarray): The points composing the center line of text
instance.
center_region_mask (ndarray): The text center region mask.
top_height_map (ndarray): The map on which the distance from point
to top sideline will be drawn for each pixel in text center
region.
bot_height_map (ndarray): The map on which the distance from point
to bottom sideline will be drawn for each pixel in text center
region.
sin_map (ndarray): The map of vector_sin(top_point -bot_point)
that will be drawn on text center region.
cos_map (ndarray): The map of vector_cos(top_point -bot_point)
will be drawn on text center region.
region_shrink_ratio (float): The shrink ratio of text center.
"""
assert top_line.shape == bot_line.shape == center_line.shape
assert (center_region_mask.shape == top_height_map.shape ==
bot_height_map.shape == sin_map.shape == cos_map.shape)
assert isinstance(region_shrink_ratio, float)
h, w = center_region_mask.shape
for i in range(0, len(center_line) - 1):
top_mid_point = (top_line[i] + top_line[i + 1]) / 2
bot_mid_point = (bot_line[i] + bot_line[i + 1]) / 2
sin_theta = self.vector_sin(top_mid_point - bot_mid_point)
cos_theta = self.vector_cos(top_mid_point - bot_mid_point)
pnt_tl = center_line[i] + (top_line[i] -
center_line[i]) * region_shrink_ratio
pnt_tr = center_line[i + 1] + (
top_line[i + 1] - center_line[i + 1]) * region_shrink_ratio
pnt_br = center_line[i + 1] + (
bot_line[i + 1] - center_line[i + 1]) * region_shrink_ratio
pnt_bl = center_line[i] + (bot_line[i] -
center_line[i]) * region_shrink_ratio
current_center_box = np.vstack([pnt_tl, pnt_tr, pnt_br,
pnt_bl]).astype(np.int32)
cv2.fillPoly(center_region_mask, [current_center_box], color=1)
cv2.fillPoly(sin_map, [current_center_box], color=sin_theta)
cv2.fillPoly(cos_map, [current_center_box], color=cos_theta)
# x,y order
current_center_box[:, 0] = np.clip(current_center_box[:, 0], 0,
w - 1)
current_center_box[:, 1] = np.clip(current_center_box[:, 1], 0,
h - 1)
min_coord = np.min(current_center_box, axis=0).astype(np.int32)
max_coord = np.max(current_center_box, axis=0).astype(np.int32)
current_center_box = current_center_box - min_coord
sz = (max_coord - min_coord + 1)
center_box_mask = np.zeros((sz[1], sz[0]), dtype=np.uint8)
cv2.fillPoly(center_box_mask, [current_center_box], color=1)
inx = np.argwhere(center_box_mask > 0)
inx = inx + (min_coord[1], min_coord[0]) # y, x order
inx_xy = np.fliplr(inx)
top_height_map[(inx[:, 0], inx[:, 1])] = self.dist_point2line(
inx_xy, (top_line[i], top_line[i + 1]))
bot_height_map[(inx[:, 0], inx[:, 1])] = self.dist_point2line(
inx_xy, (bot_line[i], bot_line[i + 1]))
def generate_center_mask_attrib_maps(self, img_size, text_polys):
"""Generate text center region mask and geometry attribute maps.
Args:
img_size (tuple): The image size of (height, width).
text_polys (list[list[ndarray]]): The list of text polygons.
Returns:
center_lines (list): The list of text center lines.
center_region_mask (ndarray): The text center region mask.
top_height_map (ndarray): The distance map from each pixel in text
center region to top sideline.
bot_height_map (ndarray): The distance map from each pixel in text
center region to bottom sideline.
sin_map (ndarray): The sin(theta) map where theta is the angle
between vector (top point - bottom point) and vector (1, 0).
cos_map (ndarray): The cos(theta) map where theta is the angle
between vector (top point - bottom point) and vector (1, 0).
"""
assert isinstance(img_size, tuple)
assert check_argument.is_2dlist(text_polys)
h, w = img_size
center_lines = []
center_region_mask = np.zeros((h, w), np.uint8)
top_height_map = np.zeros((h, w), dtype=np.float32)
bot_height_map = np.zeros((h, w), dtype=np.float32)
sin_map = np.zeros((h, w), dtype=np.float32)
cos_map = np.zeros((h, w), dtype=np.float32)
for poly in text_polys:
assert len(poly) == 1
polygon_points = poly[0].reshape(-1, 2)
_, _, top_line, bot_line = self.reorder_poly_edge(polygon_points)
resampled_top_line, resampled_bot_line = self.resample_sidelines(
top_line, bot_line, self.resample_step)
resampled_bot_line = resampled_bot_line[::-1]
center_line = (resampled_top_line + resampled_bot_line) / 2
line_head_shrink_len = np.clip(
(norm(top_line[0] - bot_line[0]) * self.text_comp_ratio),
self.min_width, self.max_width) / 2
line_tail_shrink_len = np.clip(
(norm(top_line[-1] - bot_line[-1]) * self.text_comp_ratio),
self.min_width, self.max_width) / 2
head_shrink_num = int(line_head_shrink_len // self.resample_step)
tail_shrink_num = int(line_tail_shrink_len // self.resample_step)
if len(center_line) > head_shrink_num + tail_shrink_num + 2:
center_line = center_line[head_shrink_num:len(center_line) -
tail_shrink_num]
resampled_top_line = resampled_top_line[
head_shrink_num:len(resampled_top_line) - tail_shrink_num]
resampled_bot_line = resampled_bot_line[
head_shrink_num:len(resampled_bot_line) - tail_shrink_num]
center_lines.append(center_line.astype(np.int32))
self.draw_center_region_maps(resampled_top_line,
resampled_bot_line, center_line,
center_region_mask, top_height_map,
bot_height_map, sin_map, cos_map,
self.center_region_shrink_ratio)
return (center_lines, center_region_mask, top_height_map,
bot_height_map, sin_map, cos_map)
def generate_comp_attribs_from_maps(self, center_lines, center_region_mask,
top_height_map, bot_height_map,
sin_map, cos_map, comp_shrink_ratio):
"""Generate attributes of text components in accordance with text
center lines and geometry attribute maps.
Args:
center_lines (list[ndarray]): The list of text center lines.
center_region_mask (ndarray): The text center region mask.
top_height_map (ndarray): The distance map from each pixel in text
center region to top sideline.
bot_height_map (ndarray): The distance map from each pixel in text
center region to bottom sideline.
sin_map (ndarray): The sin(theta) map where theta is the angle
between vector (top point - bottom point) and vector (1, 0).
cos_map (ndarray): The cos(theta) map where theta is the angle
between vector (top point - bottom point) and vector (1, 0).
comp_shrink_ratio (float): The text component shrink ratio.
Returns:
comp_attribs (ndarray): All text components attributes(x, y, h, w,
cos, sin, comp_labels). The comp_labels of two text components
from the same text instance are equal.
"""
assert isinstance(center_lines, list)
assert (center_region_mask.shape == top_height_map.shape ==
bot_height_map.shape == sin_map.shape == cos_map.shape)
assert isinstance(comp_shrink_ratio, float)
center_lines_mask = np.zeros_like(center_region_mask)
cv2.polylines(center_lines_mask, center_lines, 0, (1, ), 1)
center_lines_mask = center_lines_mask * center_region_mask
comp_centers = np.argwhere(center_lines_mask > 0)
y = comp_centers[:, 0]
x = comp_centers[:, 1]
top_height = top_height_map[y, x].reshape((-1, 1)) * comp_shrink_ratio
bot_height = bot_height_map[y, x].reshape((-1, 1)) * comp_shrink_ratio
sin = sin_map[y, x].reshape((-1, 1))
cos = cos_map[y, x].reshape((-1, 1))
top_mid_x_offset = top_height * cos
top_mid_y_offset = top_height * sin
bot_mid_x_offset = bot_height * cos
bot_mid_y_offset = bot_height * sin
top_mid_pnt = comp_centers + np.hstack(
[top_mid_y_offset, top_mid_x_offset])
bot_mid_pnt = comp_centers - np.hstack(
[bot_mid_y_offset, bot_mid_x_offset])
width = (top_height + bot_height) * self.text_comp_ratio
width = np.clip(width, self.min_width, self.max_width)
top_left = (top_mid_pnt -
np.hstack([width * cos, -width * sin]))[:, ::-1]
top_right = (top_mid_pnt +
np.hstack([width * cos, -width * sin]))[:, ::-1]
bot_right = (bot_mid_pnt +
np.hstack([width * cos, -width * sin]))[:, ::-1]
bot_left = (bot_mid_pnt -
np.hstack([width * cos, -width * sin]))[:, ::-1]
text_comps = np.hstack([top_left, top_right, bot_right, bot_left])
score = center_lines_mask[y, x].reshape((-1, 1))
text_comps = np.hstack([text_comps, score]).astype(np.float32)
# text_comps = la_nms(text_comps, self.text_comp_nms_thr)
if text_comps.shape[0] < 1:
return None
img_h, img_w = center_region_mask.shape
text_comps[:, 0:8:2] = np.clip(text_comps[:, 0:8:2], 0, img_w - 1)
text_comps[:, 1:8:2] = np.clip(text_comps[:, 1:8:2], 0, img_h - 1)
comp_centers = np.mean(
text_comps[:, 0:8].reshape((-1, 4, 2)), axis=1).astype(np.int32)
x = comp_centers[:, 0]
y = comp_centers[:, 1]
height = (top_height_map[y, x] + bot_height_map[y, x]).reshape((-1, 1))
width = np.clip(height * self.text_comp_ratio, self.min_width,
self.max_width)
cos = cos_map[y, x].reshape((-1, 1))
sin = sin_map[y, x].reshape((-1, 1))
_, comp_label_mask = cv2.connectedComponents(
center_region_mask.astype(np.uint8), connectivity=8)
comp_labels = comp_label_mask[y, x].reshape((-1, 1))
x = x.reshape((-1, 1))
y = y.reshape((-1, 1))
comp_attribs = np.hstack([x, y, height, width, cos, sin, comp_labels])
return comp_attribs
def generate_rand_comp_attribs(self, comp_num, center_sample_mask):
"""Generate random text components and their attributes to ensure the
the number text components in a text image is larger than k_hop1, which
is the number of one hop neighbors in KNN graph.
Args:
comp_num (int): The number of random text components.
center_sample_mask (ndarray): The text component centers sampling
region mask.
Returns:
rand_comp_attribs (ndarray): The random text components
attributes(x, y, h, w, cos, sin, belong_instance_label=0).
"""
assert isinstance(comp_num, int)
assert comp_num > 0
assert center_sample_mask.ndim == 2
h, w = center_sample_mask.shape
max_rand_half_height = self.max_rand_half_height
min_rand_half_height = self.min_rand_half_height
max_rand_height = max_rand_half_height * 2
max_rand_width = np.clip(max_rand_height * self.text_comp_ratio,
self.min_width, self.max_width)
margin = int(
np.sqrt((max_rand_height / 2)**2 + (max_rand_width / 2)**2)) + 1
if 2 * margin + 1 > min(h, w):
assert min(h, w) > (np.sqrt(2) * (self.min_width + 1))
max_rand_half_height = max(min(h, w) / 4, self.min_width / 2 + 1)
min_rand_half_height = max(max_rand_half_height / 4,
self.min_width / 2)
max_rand_height = max_rand_half_height * 2
max_rand_width = np.clip(max_rand_height * self.text_comp_ratio,
self.min_width, self.max_width)
margin = int(
np.sqrt((max_rand_height / 2)**2 +
(max_rand_width / 2)**2)) + 1
inner_center_sample_mask = np.zeros_like(center_sample_mask)
inner_center_sample_mask[margin:h-margin, margin:w-margin] = \
center_sample_mask[margin:h - margin, margin:w - margin]
kernel_size = int(min(max(max_rand_half_height, 7), 17))
inner_center_sample_mask = cv2.erode(
inner_center_sample_mask,
np.ones((kernel_size, kernel_size), np.uint8))
center_candidates = np.argwhere(inner_center_sample_mask > 0)
center_candidate_num = len(center_candidates)
sample_inx = np.random.choice(center_candidate_num, comp_num)
rand_centers = center_candidates[sample_inx]
rand_top_height = np.random.randint(
min_rand_half_height,
max_rand_half_height,
size=(len(rand_centers), 1))
rand_bot_height = np.random.randint(
min_rand_half_height,
max_rand_half_height,
size=(len(rand_centers), 1))
rand_cos = 2 * np.random.random(size=(len(rand_centers), 1)) - 1
rand_sin = 2 * np.random.random(size=(len(rand_centers), 1)) - 1
scale = np.sqrt(1.0 / (rand_cos**2 + rand_sin**2 + 1e-8))
rand_cos = rand_cos * scale
rand_sin = rand_sin * scale
height = (rand_top_height + rand_bot_height)
width = np.clip(height * self.text_comp_ratio, self.min_width,
self.max_width)
rand_comp_attribs = np.hstack([
rand_centers[:, ::-1], height, width, rand_cos, rand_sin,
np.zeros_like(rand_sin)
])
return rand_comp_attribs
def jitter_comp_attribs(self, comp_attribs, jitter_level):
"""Jitter text components attributes.
Args:
comp_attribs (ndarray): The text components attributes.
jitter_level (float): The jitter level of text components
attributes.
Returns:
jittered_comp_attribs (ndarray): The jittered text components
attributes(x, y, h, w, cos, sin, belong_instance_label).
"""
assert comp_attribs.shape[1] == 7
assert comp_attribs.shape[0] > 0
assert isinstance(jitter_level, float)
x = comp_attribs[:, 0].reshape((-1, 1))
y = comp_attribs[:, 1].reshape((-1, 1))
h = comp_attribs[:, 2].reshape((-1, 1))
w = comp_attribs[:, 3].reshape((-1, 1))
cos = comp_attribs[:, 4].reshape((-1, 1))
sin = comp_attribs[:, 5].reshape((-1, 1))
belong_label = comp_attribs[:, 6].reshape((-1, 1))
# max jitter offset of (x, y) should be
# ((h * abs(cos) + w * abs(sin)) / 2,
# (h * abs(sin) + w * abs(cos)) / 2)
x += (np.random.random(size=(len(comp_attribs), 1)) -
0.5) * (h * np.abs(cos) + w * np.abs(sin)) * jitter_level
y += (np.random.random(size=(len(comp_attribs), 1)) -
0.5) * (h * np.abs(sin) + w * np.abs(cos)) * jitter_level
# max jitter offset of (h, w) should be (h, w)
h += (np.random.random(size=(len(comp_attribs), 1)) -
0.5) * h * jitter_level
w += (np.random.random(size=(len(comp_attribs), 1)) -
0.5) * w * jitter_level
# max jitter offset of (cos, sin) should be (1, 1)
cos += (np.random.random(size=(len(comp_attribs), 1)) -
0.5) * 2 * jitter_level
sin += (np.random.random(size=(len(comp_attribs), 1)) -
0.5) * 2 * jitter_level
scale = np.sqrt(1.0 / (cos**2 + sin**2 + 1e-8))
cos = cos * scale
sin = sin * scale
jittered_comp_attribs = np.hstack([x, y, h, w, cos, sin, belong_label])
return jittered_comp_attribs
def generate_comp_attribs(self, center_lines, text_mask,
center_region_mask, top_height_map,
bot_height_map, sin_map, cos_map):
"""Generate text components attributes.
Args:
center_lines (list[ndarray]): The text center lines list.
text_mask (ndarray): The text region mask.
center_region_mask (ndarray): The text center region mask.
top_height_map (ndarray): The distance map from each pixel in text
center region to top sideline.
bot_height_map (ndarray): The distance map from each pixel in text
center region to bottom sideline.
sin_map (ndarray): The sin(theta) map where theta is the angle
between vector (top point - bottom point) and vector (1, 0).
cos_map (ndarray): The cos(theta) map where theta is the angle
between vector (top point - bottom point) and vector (1, 0).
Returns:
pad_comp_attribs (ndarray): The padded text components attributes
with a fixed size.
"""
assert isinstance(center_lines, list)
assert (text_mask.shape == center_region_mask.shape ==
top_height_map.shape == bot_height_map.shape == sin_map.shape
== cos_map.shape)
comp_attribs = self.generate_comp_attribs_from_maps(
center_lines, center_region_mask, top_height_map, bot_height_map,
sin_map, cos_map, self.comp_shrink_ratio)
if comp_attribs is not None:
comp_attribs = self.jitter_comp_attribs(comp_attribs,
self.jitter_level)
if comp_attribs.shape[0] < self.min_comp_num:
rand_sample_num = self.min_comp_num - comp_attribs.shape[0]
rand_comp_attribs = self.generate_rand_comp_attribs(
rand_sample_num, 1 - text_mask)
comp_attribs = np.vstack([comp_attribs, rand_comp_attribs])
else:
comp_attribs = self.generate_rand_comp_attribs(
self.min_comp_num, 1 - text_mask)
comp_num = (
np.ones((comp_attribs.shape[0], 1), dtype=np.float32) *
comp_attribs.shape[0])
comp_attribs = np.hstack([comp_num, comp_attribs])
if comp_attribs.shape[0] > self.max_comp_num:
comp_attribs = comp_attribs[:self.max_comp_num, :]
comp_attribs[:, 0] = self.max_comp_num
pad_comp_attribs = np.zeros((self.max_comp_num, comp_attribs.shape[1]))
pad_comp_attribs[:comp_attribs.shape[0], :] = comp_attribs
return pad_comp_attribs
def generate_targets(self, results):
"""Generate the gt targets for DRRG.
Args:
results (dict): The input result dictionary.
Returns:
results (dict): The output result dictionary.
"""
assert isinstance(results, dict)
polygon_masks = results['gt_masks'].masks
polygon_masks_ignore = results['gt_masks_ignore'].masks
h, w, _ = results['img_shape']
gt_text_mask = self.generate_text_region_mask((h, w), polygon_masks)
gt_mask = self.generate_effective_mask((h, w), polygon_masks_ignore)
(center_lines, gt_center_region_mask, gt_top_height_map,
gt_bot_height_map, gt_sin_map,
gt_cos_map) = self.generate_center_mask_attrib_maps((h, w),
polygon_masks)
gt_comp_attribs = self.generate_comp_attribs(center_lines,
gt_text_mask,
gt_center_region_mask,
gt_top_height_map,
gt_bot_height_map,
gt_sin_map, gt_cos_map)
results['mask_fields'].clear() # rm gt_masks encoded by polygons
mapping = {
'gt_text_mask': gt_text_mask,
'gt_center_region_mask': gt_center_region_mask,
'gt_mask': gt_mask,
'gt_top_height_map': gt_top_height_map,
'gt_bot_height_map': gt_bot_height_map,
'gt_sin_map': gt_sin_map,
'gt_cos_map': gt_cos_map
}
for key, value in mapping.items():
value = value if isinstance(value, list) else [value]
results[key] = BitmapMasks(value, h, w)
results['mask_fields'].append(key)
results['gt_comp_attribs'] = gt_comp_attribs
return results

View File

@ -0,0 +1,453 @@
import cv2
import numpy as np
from numpy.linalg import norm
import mmocr.utils.check_argument as check_argument
from mmdet.core import BitmapMasks
from mmdet.datasets.builder import PIPELINES
from . import BaseTextDetTargets
@PIPELINES.register_module()
class TextSnakeTargets(BaseTextDetTargets):
"""Generate the ground truth targets of TextSnake: TextSnake: A Flexible
Representation for Detecting Text of Arbitrary Shapes.
[https://arxiv.org/abs/1807.01544]. This was partially adapted from
https://github.com/princewang1994/TextSnake.pytorch.
Args:
orientation_thr (float): The threshold for distinguishing between
head edge and tail edge among the horizontal and vertical edges
of a quadrangle.
"""
def __init__(self,
orientation_thr=2.0,
resample_step=4.0,
center_region_shrink_ratio=0.3):
super().__init__()
self.orientation_thr = orientation_thr
self.resample_step = resample_step
self.center_region_shrink_ratio = center_region_shrink_ratio
def vector_angle(self, vec1, vec2):
if vec1.ndim > 1:
unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8).reshape((-1, 1))
else:
unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8)
if vec2.ndim > 1:
unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8).reshape((-1, 1))
else:
unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8)
return np.arccos(
np.clip(np.sum(unit_vec1 * unit_vec2, axis=-1), -1.0, 1.0))
def vector_slope(self, vec):
assert len(vec) == 2
return abs(vec[1] / (vec[0] + 1e-8))
def vector_sin(self, vec):
assert len(vec) == 2
return vec[1] / (norm(vec) + 1e-8)
def vector_cos(self, vec):
assert len(vec) == 2
return vec[0] / (norm(vec) + 1e-8)
def find_head_tail(self, points, orientation_thr):
"""Find the head edge and tail edge of a text polygon.
Args:
points (ndarray): The points composing a text polygon.
orientation_thr (float): The threshold for distinguishing between
head edge and tail edge among the horizontal and vertical edges
of a quadrangle.
Returns:
head_inds (list): The indexes of two points composing head edge.
tail_inds (list): The indexes of two points composing tail edge.
"""
assert points.ndim == 2
assert points.shape[0] >= 4
assert points.shape[1] == 2
assert isinstance(orientation_thr, float)
if len(points) > 4:
pad_points = np.vstack([points, points[0]])
edge_vec = pad_points[1:] - pad_points[:-1]
theta_sum = []
for i, edge_vec1 in enumerate(edge_vec):
adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]]
adjacent_edge_vec = edge_vec[adjacent_ind]
temp_theta_sum = np.sum(
self.vector_angle(edge_vec1, adjacent_edge_vec))
theta_sum.append(temp_theta_sum)
theta_sum = np.array(theta_sum)
head_start, tail_start = np.argsort(theta_sum)[::-1][0:2]
if abs(head_start - tail_start) < 2 \
or abs(head_start - tail_start) > 12:
tail_start = (head_start + len(points) // 2) % len(points)
head_end = (head_start + 1) % len(points)
tail_end = (tail_start + 1) % len(points)
if head_end > tail_end:
head_start, tail_start = tail_start, head_start
head_end, tail_end = tail_end, head_end
head_inds = [head_start, head_end]
tail_inds = [tail_start, tail_end]
else:
if self.vector_slope(points[1] - points[0]) + self.vector_slope(
points[3] - points[2]) < self.vector_slope(
points[2] - points[1]) + self.vector_slope(points[0] -
points[3]):
horizontal_edge_inds = [[0, 1], [2, 3]]
vertical_edge_inds = [[3, 0], [1, 2]]
else:
horizontal_edge_inds = [[3, 0], [1, 2]]
vertical_edge_inds = [[0, 1], [2, 3]]
vertical_len_sum = norm(points[vertical_edge_inds[0][0]] -
points[vertical_edge_inds[0][1]]) + norm(
points[vertical_edge_inds[1][0]] -
points[vertical_edge_inds[1][1]])
horizontal_len_sum = norm(
points[horizontal_edge_inds[0][0]] -
points[horizontal_edge_inds[0][1]]) + norm(
points[horizontal_edge_inds[1][0]] -
points[horizontal_edge_inds[1][1]])
if vertical_len_sum > horizontal_len_sum * orientation_thr:
head_inds = horizontal_edge_inds[0]
tail_inds = horizontal_edge_inds[1]
else:
head_inds = vertical_edge_inds[0]
tail_inds = vertical_edge_inds[1]
return head_inds, tail_inds
def reorder_poly_edge(self, points):
"""Get the respective points composing head edge, tail edge, top
sideline and bottom sideline.
Args:
points (ndarray): The points composing a text polygon.
Returns:
head_edge (ndarray): The two points composing the head edge of text
polygon.
tail_edge (ndarray): The two points composing the tail edge of text
polygon.
top_sideline (ndarray): The points composing top curved sideline of
text polygon.
bot_sideline (ndarray): The points composing bottom curved sideline
of text polygon.
"""
assert points.ndim == 2
assert points.shape[0] >= 4
assert points.shape[1] == 2
head_inds, tail_inds = self.find_head_tail(points,
self.orientation_thr)
head_edge, tail_edge = points[head_inds], points[tail_inds]
pad_points = np.vstack([points, points])
if tail_inds[1] < 1:
tail_inds[1] = len(points)
sideline1 = pad_points[head_inds[1]:tail_inds[1]]
sideline2 = pad_points[tail_inds[1]:(head_inds[1] + len(points))]
sideline_mean_shift = np.mean(
sideline1, axis=0) - np.mean(
sideline2, axis=0)
if sideline_mean_shift[1] > 0:
top_sideline, bot_sideline = sideline2, sideline1
else:
top_sideline, bot_sideline = sideline1, sideline2
return head_edge, tail_edge, top_sideline, bot_sideline
def resample_line(self, line, n):
"""Resample n points on a line.
Args:
line (ndarray): The points composing a line.
n (int): The resampled points number.
Returns:
resampled_line (ndarray): The points composing the resampled line.
"""
assert line.ndim == 2
assert line.shape[0] >= 2
assert line.shape[1] == 2
assert isinstance(n, int)
length_list = [
norm(line[i + 1] - line[i]) for i in range(len(line) - 1)
]
total_length = sum(length_list)
length_cumsum = np.cumsum([0.0] + length_list)
delta_length = total_length / (float(n) + 1e-8)
current_edge_ind = 0
resampled_line = [line[0]]
for i in range(1, n):
current_line_len = i * delta_length
while current_line_len >= length_cumsum[current_edge_ind + 1]:
current_edge_ind += 1
current_edge_end_shift = current_line_len - length_cumsum[
current_edge_ind]
end_shift_ratio = current_edge_end_shift / length_list[
current_edge_ind]
current_point = line[current_edge_ind] + (
line[current_edge_ind + 1] -
line[current_edge_ind]) * end_shift_ratio
resampled_line.append(current_point)
resampled_line.append(line[-1])
resampled_line = np.array(resampled_line)
return resampled_line
def resample_sidelines(self, sideline1, sideline2, resample_step):
"""Resample two sidelines to be of the same points number according to
step size.
Args:
sideline1 (ndarray): The points composing a sideline of a text
polygon.
sideline2 (ndarray): The points composing another sideline of a
text polygon.
resample_step (float): The resampled step size.
Returns:
resampled_line1 (ndarray): The resampled line 1.
resampled_line2 (ndarray): The resampled line 2.
"""
assert sideline1.ndim == sideline1.ndim == 2
assert sideline1.shape[1] == sideline1.shape[1] == 2
assert sideline1.shape[0] >= 2
assert sideline2.shape[0] >= 2
assert isinstance(resample_step, float)
length1 = sum([
norm(sideline1[i + 1] - sideline1[i])
for i in range(len(sideline1) - 1)
])
length2 = sum([
norm(sideline2[i + 1] - sideline2[i])
for i in range(len(sideline2) - 1)
])
total_length = (length1 + length2) / 2
resample_point_num = int(float(total_length) / resample_step)
resampled_line1 = self.resample_line(sideline1, resample_point_num)
resampled_line2 = self.resample_line(sideline2, resample_point_num)
return resampled_line1, resampled_line2
def draw_center_region_maps(self, top_line, bot_line, center_line,
center_region_mask, radius_map, sin_map,
cos_map, region_shrink_ratio):
"""Draw attributes on text center region.
Args:
top_line (ndarray): The points composing top curved sideline of
text polygon.
bot_line (ndarray): The points composing bottom curved sideline
of text polygon.
center_line (ndarray): The points composing the center line of text
instance.
center_region_mask (ndarray): The text center region mask.
radius_map (ndarray): The map where the distance from point to
sidelines will be drawn on for each pixel in text center
region.
sin_map (ndarray): The map where vector_sin(theta) will be drawn
on text center regions. Theta is the angle between tangent
line and vector (1, 0).
cos_map (ndarray): The map where vector_cos(theta) will be drawn on
text center regions. Theta is the angle between tangent line
and vector (1, 0).
region_shrink_ratio (float): The shrink ratio of text center.
"""
assert top_line.shape == bot_line.shape == center_line.shape
assert (center_region_mask.shape == radius_map.shape == sin_map.shape
== cos_map.shape)
assert isinstance(region_shrink_ratio, float)
for i in range(0, len(center_line) - 1):
top_mid_point = (top_line[i] + top_line[i + 1]) / 2
bot_mid_point = (bot_line[i] + bot_line[i + 1]) / 2
radius = norm(top_mid_point - bot_mid_point) / 2
text_direction = center_line[i + 1] - center_line[i]
sin_theta = self.vector_sin(text_direction)
cos_theta = self.vector_cos(text_direction)
pnt_tl = center_line[i] + (top_line[i] -
center_line[i]) * region_shrink_ratio
pnt_tr = center_line[i + 1] + (
top_line[i + 1] - center_line[i + 1]) * region_shrink_ratio
pnt_br = center_line[i + 1] + (
bot_line[i + 1] - center_line[i + 1]) * region_shrink_ratio
pnt_bl = center_line[i] + (bot_line[i] -
center_line[i]) * region_shrink_ratio
current_center_box = np.vstack([pnt_tl, pnt_tr, pnt_br,
pnt_bl]).astype(np.int32)
cv2.fillPoly(center_region_mask, [current_center_box], color=1)
cv2.fillPoly(sin_map, [current_center_box], color=sin_theta)
cv2.fillPoly(cos_map, [current_center_box], color=cos_theta)
cv2.fillPoly(radius_map, [current_center_box], color=radius)
def generate_center_mask_attrib_maps(self, img_size, text_polys):
"""Generate text center region mask and geometric attribute maps.
Args:
img_size (tuple): The image size of (height, width).
text_polys (list[list[ndarray]]): The list of text polygons.
Returns:
center_region_mask (ndarray): The text center region mask.
radius_map (ndarray): The distance map from each pixel in text
center region to top sideline.
sin_map (ndarray): The sin(theta) map where theta is the angle
between vector (top point - bottom point) and vector (1, 0).
cos_map (ndarray): The cos(theta) map where theta is the angle
between vector (top point - bottom point) and vector (1, 0).
"""
assert isinstance(img_size, tuple)
assert check_argument.is_2dlist(text_polys)
h, w = img_size
center_region_mask = np.zeros((h, w), np.uint8)
radius_map = np.zeros((h, w), dtype=np.float32)
sin_map = np.zeros((h, w), dtype=np.float32)
cos_map = np.zeros((h, w), dtype=np.float32)
for poly in text_polys:
assert len(poly) == 1
text_instance = [[poly[0][i], poly[0][i + 1]]
for i in range(0, len(poly[0]), 2)]
polygon_points = np.array(
text_instance, dtype=np.int).reshape(-1, 2)
_, _, top_line, bot_line = self.reorder_poly_edge(polygon_points)
resampled_top_line, resampled_bot_line = self.resample_sidelines(
top_line, bot_line, self.resample_step)
resampled_bot_line = resampled_bot_line[::-1]
center_line = (resampled_top_line + resampled_bot_line) / 2
if self.vector_slope(center_line[-1] - center_line[0]) > 0.9:
if (center_line[-1] - center_line[0])[1] < 0:
center_line = center_line[::-1]
resampled_top_line = resampled_top_line[::-1]
resampled_bot_line = resampled_bot_line[::-1]
else:
if (center_line[-1] - center_line[0])[0] < 0:
center_line = center_line[::-1]
resampled_top_line = resampled_top_line[::-1]
resampled_bot_line = resampled_bot_line[::-1]
line_head_shrink_len = norm(resampled_top_line[0] -
resampled_bot_line[0]) / 4.0
line_tail_shrink_len = norm(resampled_top_line[-1] -
resampled_bot_line[-1]) / 4.0
head_shrink_num = int(line_head_shrink_len // self.resample_step)
tail_shrink_num = int(line_tail_shrink_len // self.resample_step)
if len(center_line) > head_shrink_num + tail_shrink_num + 2:
center_line = center_line[head_shrink_num:len(center_line) -
tail_shrink_num]
resampled_top_line = resampled_top_line[
head_shrink_num:len(resampled_top_line) - tail_shrink_num]
resampled_bot_line = resampled_bot_line[
head_shrink_num:len(resampled_bot_line) - tail_shrink_num]
self.draw_center_region_maps(resampled_top_line,
resampled_bot_line, center_line,
center_region_mask, radius_map,
sin_map, cos_map,
self.center_region_shrink_ratio)
return center_region_mask, radius_map, sin_map, cos_map
def generate_text_region_mask(self, img_size, text_polys):
"""Generate text center region mask and geometry attribute maps.
Args:
img_size (tuple): The image size (height, width).
text_polys (list[list[ndarray]]): The list of text polygons.
Returns:
text_region_mask (ndarray): The text region mask.
"""
assert isinstance(img_size, tuple)
assert check_argument.is_2dlist(text_polys)
h, w = img_size
text_region_mask = np.zeros((h, w), dtype=np.uint8)
for poly in text_polys:
assert len(poly) == 1
text_instance = [[poly[0][i], poly[0][i + 1]]
for i in range(0, len(poly[0]), 2)]
polygon = np.array(text_instance, dtype=np.int).reshape((1, -1, 2))
cv2.fillPoly(text_region_mask, polygon, 1)
return text_region_mask
def generate_targets(self, results):
"""Generate the gt targets for TextSnake.
Args:
results (dict): The input result dictionary.
Returns:
results (dict): The output result dictionary.
"""
assert isinstance(results, dict)
polygon_masks = results['gt_masks'].masks
polygon_masks_ignore = results['gt_masks_ignore'].masks
h, w, _ = results['img_shape']
gt_text_mask = self.generate_text_region_mask((h, w), polygon_masks)
gt_mask = self.generate_effective_mask((h, w), polygon_masks_ignore)
(gt_center_region_mask, gt_radius_map, gt_sin_map,
gt_cos_map) = self.generate_center_mask_attrib_maps((h, w),
polygon_masks)
results['mask_fields'].clear() # rm gt_masks encoded by polygons
mapping = {
'gt_text_mask': gt_text_mask,
'gt_center_region_mask': gt_center_region_mask,
'gt_mask': gt_mask,
'gt_radius_map': gt_radius_map,
'gt_sin_map': gt_sin_map,
'gt_cos_map': gt_cos_map
}
for key, value in mapping.items():
value = value if isinstance(value, list) else [value]
results[key] = BitmapMasks(value, h, w)
results['mask_fields'].append(key)
return results

View File

@ -0,0 +1,193 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import normal_init
from mmdet.models.builder import HEADS, build_loss
from mmocr.models.textdet.modules import (GCN, LocalGraphs,
ProposalLocalGraphs,
merge_text_comps)
from .head_mixin import HeadMixin
@HEADS.register_module()
class DRRGHead(HeadMixin, nn.Module):
"""The class for DRRG head: Deep Relational Reasoning Graph Network for
Arbitrary Shape Text Detection.
[https://arxiv.org/abs/2003.07493]
Args:
k_at_hops (tuple(int)): The number of i-hop neighbors,
i = 1, 2, ..., h.
active_connection (int): The number of two hop neighbors deem as
linked to a pivot.
node_geo_feat_dim (int): The dimension of embedded geometric features
of a component.
pooling_scale (float): The spatial scale of RRoI-Aligning.
pooling_output_size (tuple(int)): The size of RRoI-Aligning output.
graph_filter_thr (float): The threshold to filter identical local
graphs.
comp_shrink_ratio (float): The shrink ratio of text components.
nms_thr (float): The locality-aware NMS threshold.
min_width (float): The minimum width of text components.
max_width (float): The maximum width of text components.
comp_ratio (float): The reciprocal of aspect ratio of text components.
text_region_thr (float): The threshold for text region probability map.
center_region_thr (float): The threshold of text center region
probability map.
center_region_area_thr (int): The threshold of filtering small-size
text center region.
link_thr (float): The threshold for connected components searching.
"""
def __init__(self,
in_channels,
k_at_hops=(8, 4),
active_connection=3,
node_geo_feat_dim=120,
pooling_scale=1.0,
pooling_output_size=(3, 4),
graph_filter_thr=0.75,
comp_shrink_ratio=0.95,
nms_thr=0.25,
min_width=8.0,
max_width=24.0,
comp_ratio=0.65,
text_region_thr=0.6,
center_region_thr=0.4,
center_region_area_thr=100,
link_thr=0.85,
loss=dict(type='DRRGLoss'),
train_cfg=None,
test_cfg=None):
super().__init__()
assert isinstance(in_channels, int)
assert isinstance(k_at_hops, tuple)
assert isinstance(active_connection, int)
assert isinstance(node_geo_feat_dim, int)
assert isinstance(pooling_scale, float)
assert isinstance(pooling_output_size, tuple)
assert isinstance(graph_filter_thr, float)
assert isinstance(comp_shrink_ratio, float)
assert isinstance(nms_thr, float)
assert isinstance(min_width, float)
assert isinstance(max_width, float)
assert isinstance(comp_ratio, float)
assert isinstance(center_region_area_thr, int)
assert isinstance(link_thr, float)
self.in_channels = in_channels
self.out_channels = 6
self.downsample_ratio = 1.0
self.k_at_hops = k_at_hops
self.active_connection = active_connection
self.node_geo_feat_dim = node_geo_feat_dim
self.pooling_scale = pooling_scale
self.pooling_output_size = pooling_output_size
self.graph_filter_thr = graph_filter_thr
self.comp_shrink_ratio = comp_shrink_ratio
self.nms_thr = nms_thr
self.min_width = min_width
self.max_width = max_width
self.comp_ratio = comp_ratio
self.text_region_thr = text_region_thr
self.center_region_thr = center_region_thr
self.center_region_area_thr = center_region_area_thr
self.link_thr = link_thr
self.loss_module = build_loss(loss)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.out_conv = nn.Conv2d(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=1,
stride=1,
padding=0)
self.init_weights()
self.graph_train = LocalGraphs(self.k_at_hops, self.active_connection,
self.node_geo_feat_dim,
self.pooling_scale,
self.pooling_output_size,
self.graph_filter_thr)
self.graph_test = ProposalLocalGraphs(
self.k_at_hops, self.active_connection, self.node_geo_feat_dim,
self.pooling_scale, self.pooling_output_size, self.nms_thr,
self.min_width, self.max_width, self.comp_shrink_ratio,
self.comp_ratio, self.text_region_thr, self.center_region_thr,
self.center_region_area_thr)
pool_w, pool_h = self.pooling_output_size
gcn_in_dim = (pool_w * pool_h) * (
self.in_channels + self.out_channels) + self.node_geo_feat_dim
self.gcn = GCN(gcn_in_dim, 32)
def init_weights(self):
normal_init(self.out_conv, mean=0, std=0.01)
def forward(self, inputs, text_comp_feats):
pred_maps = self.out_conv(inputs)
feat_maps = torch.cat([inputs, pred_maps], dim=1)
node_feats, adjacent_matrices, knn_inx, gt_labels = self.graph_train(
feat_maps, np.array(text_comp_feats))
gcn_pred = self.gcn(node_feats, adjacent_matrices, knn_inx)
return (pred_maps, (gcn_pred, gt_labels))
def single_test(self, feat_maps):
pred_maps = self.out_conv(feat_maps)
feat_maps = torch.cat([feat_maps[0], pred_maps], dim=1)
none_flag, graph_data = self.graph_test(pred_maps, feat_maps)
(node_feats, adjacent_matrix, pivot_inx, knn_inx, local_graph_nodes,
text_comps) = graph_data
if none_flag:
return None, None, None
adjacent_matrix, pivot_inx, knn_inx = map(
lambda x: x.to(feat_maps.device),
(adjacent_matrix, pivot_inx, knn_inx))
gcn_pred = self.gcn_model(node_feats, adjacent_matrix, knn_inx)
pred_labels = F.softmax(gcn_pred, dim=1)
edges = []
scores = []
local_graph_nodes = local_graph_nodes.long().squeeze().cpu().numpy()
graph_num = node_feats.size(0)
for graph_inx in range(graph_num):
pivot = pivot_inx[graph_inx].int().item()
nodes = local_graph_nodes[graph_inx]
for neighbor_inx, neighbor in enumerate(knn_inx[graph_inx]):
neighbor = neighbor.item()
edges.append([nodes[pivot], nodes[neighbor]])
scores.append(pred_labels[graph_inx * (knn_inx.shape[1]) +
neighbor_inx, 1].item())
edges = np.asarray(edges)
scores = np.asarray(scores)
return edges, scores, text_comps
def get_boundary(self, edges, scores, text_comps):
boundaries = []
if edges is not None:
boundaries = merge_text_comps(edges, scores, text_comps,
self.link_thr)
results = dict(boundary_result=boundaries)
return results

View File

@ -0,0 +1,48 @@
import torch.nn as nn
from mmcv.cnn import normal_init
from mmdet.models.builder import HEADS, build_loss
from . import HeadMixin
@HEADS.register_module()
class TextSnakeHead(HeadMixin, nn.Module):
"""The class for TextSnake head: TextSnake: A Flexible Representation for
Detecting Text of Arbitrary Shapes.
[https://arxiv.org/abs/1807.01544]
"""
def __init__(self,
in_channels,
decoding_type='textsnake',
text_repr_type='poly',
loss=dict(type='TextSnakeLoss'),
train_cfg=None,
test_cfg=None):
super().__init__()
assert isinstance(in_channels, int)
self.in_channels = in_channels
self.out_channels = 5
self.downsample_ratio = 1.0
self.decoding_type = decoding_type
self.text_repr_type = text_repr_type
self.loss_module = build_loss(loss)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.out_conv = nn.Conv2d(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=1,
stride=1,
padding=0)
self.init_weights()
def init_weights(self):
normal_init(self.out_conv, mean=0, std=0.01)
def forward(self, inputs):
outputs = self.out_conv(inputs)
return outputs

View File

@ -0,0 +1,38 @@
from mmdet.models.builder import DETECTORS
from . import SingleStageTextDetector, TextDetectorMixin
@DETECTORS.register_module()
class DRRG(TextDetectorMixin, SingleStageTextDetector):
"""The class for implementing DRRG text detector: Deep Relational Reasoning
Graph Network for Arbitrary Shape Text Detection.
[https://arxiv.org/abs/2003.07493]
"""
def forward_train(self, img, img_metas, **kwargs):
"""
Args:
img (Tensor): Input images of shape (N, C, H, W).
Typically these should be mean centered and std scaled.
img_metas (list[dict]): A list of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details of the values of these keys, see
:class:`mmdet.datasets.pipelines.Collect`.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
x = self.extract_feat(img)
gt_comp_attribs = kwargs.pop('gt_comp_attribs')
preds = self.bbox_head(x, gt_comp_attribs)
losses = self.bbox_head.loss(preds, **kwargs)
return losses
def simple_test(self, img, img_metas, rescale=False):
x = self.extract_feat(img)
outs = self.bbox_head.single_test(x, img)
boundaries = self.bbox_head.get_boundary(*outs, img_metas, rescale)
return [boundaries]

View File

@ -0,0 +1,23 @@
from mmdet.models.builder import DETECTORS
from . import SingleStageTextDetector, TextDetectorMixin
@DETECTORS.register_module()
class TextSnake(TextDetectorMixin, SingleStageTextDetector):
"""The class for implementing TextSnake text detector: TextSnake: A
Flexible Representation for Detecting Text of Arbitrary Shapes.
[https://arxiv.org/abs/1807.01544]
"""
def __init__(self,
backbone,
neck,
bbox_head,
train_cfg=None,
test_cfg=None,
pretrained=None,
show_score=False):
SingleStageTextDetector.__init__(self, backbone, neck, bbox_head,
train_cfg, test_cfg, pretrained)
TextDetectorMixin.__init__(self, show_score)

View File

@ -0,0 +1,206 @@
import torch
import torch.nn.functional as F
from torch import nn
from mmdet.core import BitmapMasks
from mmdet.models.builder import LOSSES
from mmocr.utils import check_argument
@LOSSES.register_module()
class DRRGLoss(nn.Module):
"""The class for implementing DRRG loss: Deep Relational Reasoning Graph
Network for Arbitrary Shape Text Detection.
[https://arxiv.org/abs/1908.05900] This is partially adapted from
https://github.com/GXYM/DRRG.
"""
def __init__(self, ohem_ratio=3.0):
"""Initialization.
Args:
ohem_ratio (float): The negative/positive ratio in OHEM.
"""
super().__init__()
self.ohem_ratio = ohem_ratio
def balance_bce_loss(self, pred, gt, mask):
assert pred.shape == gt.shape == mask.shape
positive = gt * mask
negative = (1 - gt) * mask
positive_count = int(positive.float().sum())
gt = gt.float()
if positive_count > 0:
loss = F.binary_cross_entropy(pred, gt, reduction='none')
positive_loss = torch.sum(loss * positive.float())
negative_loss = loss * negative.float()
negative_count = min(
int(negative.float().sum()),
int(positive_count * self.ohem_ratio))
else:
positive_loss = torch.tensor(0.0)
loss = F.binary_cross_entropy(pred, gt, reduction='none')
negative_loss = loss * negative.float()
negative_count = 100
negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count)
balance_loss = (positive_loss + torch.sum(negative_loss)) / (
float(positive_count + negative_count) + 1e-5)
return balance_loss
def gcn_loss(self, gcn_data):
gcn_pred, gt_labels = gcn_data
gt_labels = gt_labels.view(-1).to(gcn_pred.device)
loss = F.cross_entropy(gcn_pred, gt_labels)
return loss
def bitmasks2tensor(self, bitmasks, target_sz):
"""Convert Bitmasks to tensor.
Args:
bitmasks (list[BitmapMasks]): The BitmapMasks list. Each item is
for one img.
target_sz (tuple(int, int)): The target tensor size HxW.
Returns
results (list[tensor]): The list of kernel tensors. Each
element is for one kernel level.
"""
assert check_argument.is_type_list(bitmasks, BitmapMasks)
assert isinstance(target_sz, tuple)
batch_size = len(bitmasks)
num_masks = len(bitmasks[0])
results = []
for level_inx in range(num_masks):
kernel = []
for batch_inx in range(batch_size):
mask = torch.from_numpy(bitmasks[batch_inx].masks[level_inx])
# hxw
mask_sz = mask.shape
# left, right, top, bottom
pad = [
0, target_sz[1] - mask_sz[1], 0, target_sz[0] - mask_sz[0]
]
mask = F.pad(mask, pad, mode='constant', value=0)
kernel.append(mask)
kernel = torch.stack(kernel)
results.append(kernel)
return results
def forward(self, preds, downsample_ratio, gt_text_mask,
gt_center_region_mask, gt_mask, gt_top_height_map,
gt_bot_height_map, gt_sin_map, gt_cos_map):
assert isinstance(preds, tuple)
assert isinstance(downsample_ratio, float)
assert abs(downsample_ratio - 1.0) < 1e-5
assert check_argument.is_type_list(gt_text_mask, BitmapMasks)
assert check_argument.is_type_list(gt_center_region_mask, BitmapMasks)
assert check_argument.is_type_list(gt_mask, BitmapMasks)
assert check_argument.is_type_list(gt_top_height_map, BitmapMasks)
assert check_argument.is_type_list(gt_bot_height_map, BitmapMasks)
assert check_argument.is_type_list(gt_sin_map, BitmapMasks)
assert check_argument.is_type_list(gt_cos_map, BitmapMasks)
pred_maps, gcn_data = preds
pred_text_region = pred_maps[:, 0, :, :]
pred_center_region = pred_maps[:, 1, :, :]
pred_sin_map = pred_maps[:, 2, :, :]
pred_cos_map = pred_maps[:, 3, :, :]
pred_top_height_map = pred_maps[:, 4, :, :]
pred_bot_height_map = pred_maps[:, 5, :, :]
feature_sz = pred_maps.size()
# bitmask 2 tensor
mapping = {
'gt_text_mask': gt_text_mask,
'gt_center_region_mask': gt_center_region_mask,
'gt_mask': gt_mask,
'gt_top_height_map': gt_top_height_map,
'gt_bot_height_map': gt_bot_height_map,
'gt_sin_map': gt_sin_map,
'gt_cos_map': gt_cos_map
}
gt = {}
for key, value in mapping.items():
gt[key] = value
gt[key] = self.bitmasks2tensor(gt[key], feature_sz[2:])
gt[key] = [item.to(pred_maps.device) for item in gt[key]]
scale = torch.sqrt(1.0 / (pred_sin_map**2 + pred_cos_map**2 + 1e-8))
pred_sin_map = pred_sin_map * scale
pred_cos_map = pred_cos_map * scale
loss_text = self.balance_bce_loss(
torch.sigmoid(pred_text_region), gt['gt_text_mask'][0],
gt['gt_mask'][0])
text_mask = (gt['gt_text_mask'][0] * gt['gt_mask'][0]).float()
negative_text_mask = ((1 - gt['gt_text_mask'][0]) *
gt['gt_mask'][0]).float()
gt_center_region_mask = gt['gt_center_region_mask'][0].float()
loss_center = F.binary_cross_entropy(
torch.sigmoid(pred_center_region),
gt_center_region_mask,
reduction='none')
if int(text_mask.sum()) > 0:
loss_center_positive = torch.sum(
loss_center * text_mask) / torch.sum(text_mask)
else:
loss_center_positive = torch.tensor(0.0)
loss_center_negative = torch.sum(
loss_center * negative_text_mask) / torch.sum(negative_text_mask)
loss_center = loss_center_positive + 0.5 * loss_center_negative
center_mask = (gt['gt_center_region_mask'][0] *
gt['gt_mask'][0]).float()
if int(center_mask.sum()) > 0:
ones = torch.ones_like(
gt['gt_top_height_map'][0], dtype=torch.float)
loss_top = F.smooth_l1_loss(
pred_top_height_map / (gt['gt_top_height_map'][0] + 1e-2),
ones,
reduction='none')
loss_bot = F.smooth_l1_loss(
pred_bot_height_map / (gt['gt_bot_height_map'][0] + 1e-2),
ones,
reduction='none')
gt_height = (
gt['gt_top_height_map'][0] + gt['gt_bot_height_map'][0])
loss_height = torch.sum(
(torch.log(gt_height + 1) *
(loss_top + loss_bot)) * center_mask) / torch.sum(center_mask)
loss_sin = torch.sum(
F.smooth_l1_loss(
pred_sin_map, gt['gt_sin_map'][0], reduction='none') *
center_mask) / torch.sum(center_mask)
loss_cos = torch.sum(
F.smooth_l1_loss(
pred_cos_map, gt['gt_cos_map'][0], reduction='none') *
center_mask) / torch.sum(center_mask)
else:
loss_height = torch.tensor(0.0)
loss_sin = torch.tensor(0.0)
loss_cos = torch.tensor(0.0)
loss_gcn = self.gcn_loss(gcn_data)
results = dict(
loss_text=loss_text,
loss_center=loss_center,
loss_height=loss_height,
loss_sin=loss_sin,
loss_cos=loss_cos,
loss_gcn=loss_gcn)
return results

View File

@ -0,0 +1,181 @@
import torch
import torch.nn.functional as F
from torch import nn
from mmdet.core import BitmapMasks
from mmdet.models.builder import LOSSES
from mmocr.utils import check_argument
@LOSSES.register_module()
class TextSnakeLoss(nn.Module):
"""The class for implementing TextSnake loss:
TextSnake: A Flexible Representation for Detecting Text of Arbitrary Shapes
[https://arxiv.org/abs/1807.01544].
This is partially adapted from
https://github.com/princewang1994/TextSnake.pytorch.
"""
def __init__(self, ohem_ratio=3.0):
"""Initialization.
Args:
ohem_ratio (float): The negative/positive ratio in ohem.
"""
super().__init__()
self.ohem_ratio = ohem_ratio
def balanced_bce_loss(self, pred, gt, mask):
assert pred.shape == gt.shape == mask.shape
positive = gt * mask
negative = (1 - gt) * mask
positive_count = int(positive.float().sum())
gt = gt.float()
if positive_count > 0:
loss = F.binary_cross_entropy(pred, gt, reduction='none')
positive_loss = torch.sum(loss * positive.float())
negative_loss = loss * negative.float()
negative_count = min(
int(negative.float().sum()),
int(positive_count * self.ohem_ratio))
else:
positive_loss = torch.tensor(0.0, device=pred.device)
loss = F.binary_cross_entropy(pred, gt, reduction='none')
negative_loss = loss * negative.float()
negative_count = 100
negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count)
balance_loss = (positive_loss + torch.sum(negative_loss)) / (
float(positive_count + negative_count) + 1e-5)
return balance_loss
def bitmasks2tensor(self, bitmasks, target_sz):
"""Convert Bitmasks to tensor.
Args:
bitmasks (list[BitmapMasks]): The BitmapMasks list. Each item is
for one img.
target_sz (tuple(int, int)): The target tensor size HxW.
Returns
results (list[tensor]): The list of kernel tensors. Each
element is for one kernel level.
"""
assert check_argument.is_type_list(bitmasks, BitmapMasks)
assert isinstance(target_sz, tuple)
batch_size = len(bitmasks)
num_masks = len(bitmasks[0])
results = []
for level_inx in range(num_masks):
kernel = []
for batch_inx in range(batch_size):
mask = torch.from_numpy(bitmasks[batch_inx].masks[level_inx])
# hxw
mask_sz = mask.shape
# left, right, top, bottom
pad = [
0, target_sz[1] - mask_sz[1], 0, target_sz[0] - mask_sz[0]
]
mask = F.pad(mask, pad, mode='constant', value=0)
kernel.append(mask)
kernel = torch.stack(kernel)
results.append(kernel)
return results
def forward(self, pred_maps, downsample_ratio, gt_text_mask,
gt_center_region_mask, gt_mask, gt_radius_map, gt_sin_map,
gt_cos_map):
assert isinstance(downsample_ratio, float)
assert check_argument.is_type_list(gt_text_mask, BitmapMasks)
assert check_argument.is_type_list(gt_center_region_mask, BitmapMasks)
assert check_argument.is_type_list(gt_mask, BitmapMasks)
assert check_argument.is_type_list(gt_radius_map, BitmapMasks)
assert check_argument.is_type_list(gt_sin_map, BitmapMasks)
assert check_argument.is_type_list(gt_cos_map, BitmapMasks)
pred_text_region = pred_maps[:, 0, :, :]
pred_center_region = pred_maps[:, 1, :, :]
pred_sin_map = pred_maps[:, 2, :, :]
pred_cos_map = pred_maps[:, 3, :, :]
pred_radius_map = pred_maps[:, 4, :, :]
feature_sz = pred_maps.size()
device = pred_maps.device
# bitmask 2 tensor
mapping = {
'gt_text_mask': gt_text_mask,
'gt_center_region_mask': gt_center_region_mask,
'gt_mask': gt_mask,
'gt_radius_map': gt_radius_map,
'gt_sin_map': gt_sin_map,
'gt_cos_map': gt_cos_map
}
gt = {}
for key, value in mapping.items():
gt[key] = value
if abs(downsample_ratio - 1.0) < 1e-2:
gt[key] = self.bitmasks2tensor(gt[key], feature_sz[2:])
else:
gt[key] = [item.rescale(downsample_ratio) for item in gt[key]]
gt[key] = self.bitmasks2tensor(gt[key], feature_sz[2:])
if key == 'gt_radius_map':
gt[key] = [item * downsample_ratio for item in gt[key]]
gt[key] = [item.to(device) for item in gt[key]]
scale = torch.sqrt(1.0 / (pred_sin_map**2 + pred_cos_map**2 + 1e-8))
pred_sin_map = pred_sin_map * scale
pred_cos_map = pred_cos_map * scale
loss_text = self.balanced_bce_loss(
torch.sigmoid(pred_text_region), gt['gt_text_mask'][0],
gt['gt_mask'][0])
text_mask = (gt['gt_text_mask'][0] * gt['gt_mask'][0]).float()
loss_center_map = F.binary_cross_entropy(
torch.sigmoid(pred_center_region),
gt['gt_center_region_mask'][0].float(),
reduction='none')
if int(text_mask.sum()) > 0:
loss_center = torch.sum(
loss_center_map * text_mask) / torch.sum(text_mask)
else:
loss_center = torch.tensor(0.0, device=device)
center_mask = (gt['gt_center_region_mask'][0] *
gt['gt_mask'][0]).float()
if int(center_mask.sum()) > 0:
map_sz = pred_radius_map.size()
ones = torch.ones(map_sz, dtype=torch.float, device=device)
loss_radius = torch.sum(
F.smooth_l1_loss(
pred_radius_map / (gt['gt_radius_map'][0] + 1e-2),
ones,
reduction='none') * center_mask) / torch.sum(center_mask)
loss_sin = torch.sum(
F.smooth_l1_loss(
pred_sin_map, gt['gt_sin_map'][0], reduction='none') *
center_mask) / torch.sum(center_mask)
loss_cos = torch.sum(
F.smooth_l1_loss(
pred_cos_map, gt['gt_cos_map'][0], reduction='none') *
center_mask) / torch.sum(center_mask)
else:
loss_radius = torch.tensor(0.0, device=device)
loss_sin = torch.tensor(0.0, device=device)
loss_cos = torch.tensor(0.0, device=device)
results = dict(
loss_text=loss_text,
loss_center=loss_center,
loss_radius=loss_radius,
loss_sin=loss_sin,
loss_cos=loss_cos)
return results

View File

@ -0,0 +1,6 @@
from .gcn import GCN
from .local_graph import LocalGraphs
from .proposal_local_graph import ProposalLocalGraphs
from .utils import merge_text_comps
__all__ = ['LocalGraphs', 'ProposalLocalGraphs', 'GCN', 'merge_text_comps']

View File

@ -0,0 +1,80 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
class MeanAggregator(nn.Module):
def forward(self, features, A):
x = torch.bmm(A, features)
return x
class GraphConv(nn.Module):
def __init__(self, in_dim, out_dim):
super(GraphConv, self).__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.weight = nn.Parameter(torch.FloatTensor(in_dim * 2, out_dim))
self.bias = nn.Parameter(torch.FloatTensor(out_dim))
init.xavier_uniform_(self.weight)
init.constant_(self.bias, 0)
self.agg = MeanAggregator()
def forward(self, features, A):
b, n, d = features.shape
assert d == self.in_dim
agg_feats = self.agg(features, A)
cat_feats = torch.cat([features, agg_feats], dim=2)
out = torch.einsum('bnd,df->bnf', (cat_feats, self.weight))
out = F.relu(out + self.bias)
return out
class GCN(nn.Module):
"""Predict linkage between instances. This was from repo
https://github.com/Zhongdao/gcn_clustering: Linkage Based Face Clustering
via Graph Convolution Network.
[https://arxiv.org/abs/1903.11306]
Args:
in_dim(int): The input dimension.
out_dim(int): The output dimension.
"""
def __init__(self, in_dim, out_dim):
super(GCN, self).__init__()
self.bn0 = nn.BatchNorm1d(in_dim, affine=False).float()
self.conv1 = GraphConv(in_dim, 512)
self.conv2 = GraphConv(512, 256)
self.conv3 = GraphConv(256, 128)
self.conv4 = GraphConv(128, 64)
self.classifier = nn.Sequential(
nn.Linear(64, out_dim), nn.PReLU(out_dim), nn.Linear(out_dim, 2))
def forward(self, x, A, one_hop_indexes, train=True):
B, N, D = x.shape
x = x.view(-1, D)
x = self.bn0(x)
x = x.view(B, N, D)
x = self.conv1(x, A)
x = self.conv2(x, A)
x = self.conv3(x, A)
x = self.conv4(x, A)
k1 = one_hop_indexes.size(-1)
dout = x.size(-1)
edge_feat = torch.zeros(B, k1, dout)
for b in range(B):
edge_feat[b, :, :] = x[b, one_hop_indexes[b]]
edge_feat = edge_feat.view(-1, dout).to(x.device)
pred = self.classifier(edge_feat)
# shape: (B*k1)x2
return pred

View File

@ -0,0 +1,307 @@
import numpy as np
import torch
from mmocr.models.utils import RROIAlign
from .utils import (embed_geo_feats, euclidean_distance_matrix,
normalize_adjacent_matrix)
class LocalGraphs:
"""Generate local graphs for GCN to predict which instance a text component
belongs to. This was partially adapted from https://github.com/GXYM/DRRG:
Deep Relational Reasoning Graph Network for Arbitrary Shape Text Detection.
[https://arxiv.org/abs/2003.07493]
Args:
k_at_hops (tuple(int)): The number of h-hop neighbors.
active_connection (int): The number of neighbors deem as linked to a
pivot.
node_geo_feat_dim (int): The dimension of embedded geometric features
of a component.
pooling_scale (float): The spatial scale of RRoI-Aligning.
pooling_output_size (tuple(int)): The size of RRoI-Aligning output.
local_graph_filter_thr (float): The threshold to filter out identical
local graphs.
"""
def __init__(self, k_at_hops, active_connection, node_geo_feat_dim,
pooling_scale, pooling_output_size, local_graph_filter_thr):
assert isinstance(k_at_hops, tuple)
assert isinstance(active_connection, int)
assert isinstance(node_geo_feat_dim, int)
assert isinstance(pooling_scale, float)
assert isinstance(pooling_output_size, tuple)
assert isinstance(local_graph_filter_thr, float)
self.k_at_hops = k_at_hops
self.local_graph_depth = len(self.k_at_hops)
self.active_connection = active_connection
self.node_geo_feat_dim = node_geo_feat_dim
self.pooling = RROIAlign(pooling_output_size, pooling_scale)
self.local_graph_filter_thr = local_graph_filter_thr
def generate_local_graphs(self, sorted_complete_graph, gt_belong_labels):
"""Generate local graphs for GCN to predict which instance a text
component belongs to.
Args:
sorted_complete_graph (ndarray): The complete graph where nodes are
sorted according to their Euclidean distance.
gt_belong_labels (ndarray): The ground truth labels define which
instance text components (nodes in graphs) belong to.
Returns:
local_graph_node_list (list): The list of local graph neighbors of
pivots.
knn_graph_neighbor_list (list): The list of k nearest neighbors of
pivots.
"""
assert sorted_complete_graph.ndim == 2
assert (sorted_complete_graph.shape[0] ==
sorted_complete_graph.shape[1] == gt_belong_labels.shape[0])
knn_graphs = sorted_complete_graph[:, :self.k_at_hops[0] + 1]
local_graph_node_list = list()
knn_graph_neighbor_list = list()
for pivot_inx, knn_graph in enumerate(knn_graphs):
h_hop_neighbor_list = list()
one_hop_neighbors = set(knn_graph[1:])
h_hop_neighbor_list.append(one_hop_neighbors)
for hop_inx in range(1, self.local_graph_depth):
h_hop_neighbor_list.append(set())
for last_hop_neighbor_inx in h_hop_neighbor_list[-2]:
h_hop_neighbor_list[-1].update(
set(sorted_complete_graph[last_hop_neighbor_inx]
[1:self.k_at_hops[hop_inx] + 1]))
hops_neighbor_set = set(
[node for hop in h_hop_neighbor_list for node in hop])
hops_neighbor_list = list(hops_neighbor_set)
hops_neighbor_list.insert(0, pivot_inx)
if pivot_inx < 1:
local_graph_node_list.append(hops_neighbor_list)
knn_graph_neighbor_list.append(one_hop_neighbors)
else:
add_flag = True
for graph_inx in range(len(knn_graph_neighbor_list)):
knn_graph_neighbors = knn_graph_neighbor_list[graph_inx]
local_graph_nodes = local_graph_node_list[graph_inx]
node_union_num = len(
list(
set(knn_graph_neighbors).union(
set(one_hop_neighbors))))
node_intersect_num = len(
list(
set(knn_graph_neighbors).intersection(
set(one_hop_neighbors))))
one_hop_iou = node_intersect_num / (node_union_num + 1e-8)
if (one_hop_iou > self.local_graph_filter_thr
and pivot_inx in knn_graph_neighbors
and gt_belong_labels[local_graph_nodes[0]]
== gt_belong_labels[pivot_inx]
and gt_belong_labels[local_graph_nodes[0]] != 0):
add_flag = False
break
if add_flag:
local_graph_node_list.append(hops_neighbor_list)
knn_graph_neighbor_list.append(one_hop_neighbors)
return local_graph_node_list, knn_graph_neighbor_list
def generate_gcn_input(self, node_feat_batch, belong_label_batch,
local_graph_node_batch, knn_graph_neighbor_batch,
sorted_complete_graph):
"""Generate graph convolution network input data.
Args:
node_feat_batch (List[Tensor]): The node feature batch.
belong_label_batch (List[ndarray]): The text component belong label
batch.
local_graph_node_batch (List[List[list]]): The local graph
neighbors batch.
knn_graph_neighbor_batch (List[List[set]]): The knn graph neighbor
batch.
sorted_complete_graph (List[ndarray]): The complete graph sorted
according to the Euclidean distance.
Returns:
node_feat_batch_tensor (Tensor): The node features of Graph
Convolutional Network (GCN).
adjacent_mat_batch_tensor (Tensor): The adjacent matrices.
knn_inx_batch_tensor (Tensor): The indices of k nearest neighbors.
gt_linkage_batch_tensor (Tensor): The surpervision signal of GCN
for linkage prediction.
"""
assert isinstance(node_feat_batch, list)
assert isinstance(belong_label_batch, list)
assert isinstance(local_graph_node_batch, list)
assert isinstance(knn_graph_neighbor_batch, list)
assert isinstance(sorted_complete_graph, list)
max_graph_node_num = max([
len(local_graph_nodes)
for local_graph_node_list in local_graph_node_batch
for local_graph_nodes in local_graph_node_list
])
node_feat_batch_list = list()
adjacent_matrix_batch_list = list()
knn_inx_batch_list = list()
gt_linkage_batch_list = list()
for batch_inx, sorted_neighbors in enumerate(sorted_complete_graph):
node_feats = node_feat_batch[batch_inx]
local_graph_list = local_graph_node_batch[batch_inx]
knn_graph_neighbor_list = knn_graph_neighbor_batch[batch_inx]
belong_labels = belong_label_batch[batch_inx]
for graph_inx in range(len(local_graph_list)):
local_graph_nodes = local_graph_list[graph_inx]
local_graph_node_num = len(local_graph_nodes)
pivot_inx = local_graph_nodes[0]
knn_graph_neighbors = knn_graph_neighbor_list[graph_inx]
node_to_graph_inx = {
j: i
for i, j in enumerate(local_graph_nodes)
}
knn_inx_in_local_graph = torch.tensor(
[node_to_graph_inx[i] for i in knn_graph_neighbors],
dtype=torch.long)
pivot_feats = node_feats[torch.tensor(
pivot_inx, dtype=torch.long)]
normalized_feats = node_feats[torch.tensor(
local_graph_nodes, dtype=torch.long)] - pivot_feats
adjacent_matrix = np.zeros(
(local_graph_node_num, local_graph_node_num))
pad_normalized_feats = torch.cat([
normalized_feats,
torch.zeros(max_graph_node_num - local_graph_node_num,
normalized_feats.shape[1]).to(
node_feats.device)
],
dim=0)
for node in local_graph_nodes:
neighbors = sorted_neighbors[node,
1:self.active_connection + 1]
for neighbor in neighbors:
if neighbor in local_graph_nodes:
adjacent_matrix[node_to_graph_inx[node],
node_to_graph_inx[neighbor]] = 1
adjacent_matrix[node_to_graph_inx[neighbor],
node_to_graph_inx[node]] = 1
adjacent_matrix = normalize_adjacent_matrix(
adjacent_matrix, type='DAD')
adjacent_matrix_tensor = torch.zeros(max_graph_node_num,
max_graph_node_num).to(
node_feats.device)
adjacent_matrix_tensor[:local_graph_node_num, :
local_graph_node_num] = adjacent_matrix
local_graph_labels = torch.from_numpy(
belong_labels[local_graph_nodes]).type(torch.long)
knn_labels = local_graph_labels[knn_inx_in_local_graph]
edge_labels = ((belong_labels[pivot_inx] == knn_labels)
& (belong_labels[pivot_inx] > 0)).long()
node_feat_batch_list.append(pad_normalized_feats)
adjacent_matrix_batch_list.append(adjacent_matrix_tensor)
knn_inx_batch_list.append(knn_inx_in_local_graph)
gt_linkage_batch_list.append(edge_labels)
node_feat_batch_tensor = torch.stack(node_feat_batch_list, 0)
adjacent_mat_batch_tensor = torch.stack(adjacent_matrix_batch_list, 0)
knn_inx_batch_tensor = torch.stack(knn_inx_batch_list, 0)
gt_linkage_batch_tensor = torch.stack(gt_linkage_batch_list, 0)
return (node_feat_batch_tensor, adjacent_mat_batch_tensor,
knn_inx_batch_tensor, gt_linkage_batch_tensor)
def __call__(self, feat_maps, comp_attribs):
"""Generate local graphs.
Args:
feat_maps (Tensor): The feature maps to propose node features in
graph.
comp_attribs (ndarray): The text components attributes.
Returns:
node_feats_batch (Tensor): The node features of Graph Convolutional
Network(GCN).
adjacent_matrices_batch (Tensor): The adjacent matrices.
knn_inx_batch (Tensor): The indices of k nearest neighbors.
gt_linkage_batch (Tensor): The surpervision signal of GCN for
linkage prediction.
"""
assert isinstance(feat_maps, torch.Tensor)
assert comp_attribs.shape[2] == 8
dist_sort_graph_batch_list = []
local_graph_node_batch_list = []
knn_graph_neighbor_batch_list = []
node_feature_batch_list = []
belong_label_batch_list = []
for batch_inx in range(comp_attribs.shape[0]):
comp_num = int(comp_attribs[batch_inx, 0, 0])
comp_geo_attribs = comp_attribs[batch_inx, :comp_num, 1:7]
node_belong_labels = comp_attribs[batch_inx, :comp_num,
7].astype(np.int32)
comp_centers = comp_geo_attribs[:, 0:2]
distance_matrix = euclidean_distance_matrix(
comp_centers, comp_centers)
graph_node_geo_feats = embed_geo_feats(comp_geo_attribs,
self.node_geo_feat_dim)
graph_node_geo_feats = torch.from_numpy(
graph_node_geo_feats).float().to(feat_maps.device)
batch_id = np.zeros(
(comp_geo_attribs.shape[0], 1), dtype=np.float32) * batch_inx
text_comps = np.hstack(
(batch_id, comp_geo_attribs.astype(np.float32)))
text_comps = torch.from_numpy(text_comps).float().to(
feat_maps.device)
comp_content_feats = self.pooling(
feat_maps[batch_inx].unsqueeze(0), text_comps)
comp_content_feats = comp_content_feats.view(
comp_content_feats.shape[0], -1).to(feat_maps.device)
node_feats = torch.cat((comp_content_feats, graph_node_geo_feats),
dim=-1)
dist_sort_complete_graph = np.argsort(distance_matrix, axis=1)
(local_graph_nodes,
knn_graph_neighbors) = self.generate_local_graphs(
dist_sort_complete_graph, node_belong_labels)
node_feature_batch_list.append(node_feats)
belong_label_batch_list.append(node_belong_labels)
local_graph_node_batch_list.append(local_graph_nodes)
knn_graph_neighbor_batch_list.append(knn_graph_neighbors)
dist_sort_graph_batch_list.append(dist_sort_complete_graph)
(node_feats_batch, adjacent_matrices_batch, knn_inx_batch,
gt_linkage_batch) = \
self.generate_gcn_input(node_feature_batch_list,
belong_label_batch_list,
local_graph_node_batch_list,
knn_graph_neighbor_batch_list,
dist_sort_graph_batch_list)
return (node_feats_batch, adjacent_matrices_batch, knn_inx_batch,
gt_linkage_batch)

View File

@ -0,0 +1,418 @@
import cv2
import numpy as np
import torch
# from mmocr.models.textdet.postprocess import la_nms
from mmocr.models.utils import RROIAlign
from .utils import (embed_geo_feats, euclidean_distance_matrix,
normalize_adjacent_matrix)
class ProposalLocalGraphs:
"""Propose text components and generate local graphs. This was partially
adapted from https://github.com/GXYM/DRRG: Deep Relational Reasoning Graph
Network for Arbitrary Shape Text Detection.
[https://arxiv.org/abs/2003.07493]
Args:
k_at_hops (tuple(int)): The number of i-hop neighbors,
i = 1, 2, ..., h.
active_connection (int): The number of two hop neighbors deem as linked
to a pivot.
node_geo_feat_dim (int): The dimension of embedded geometric features
of a component.
pooling_scale (float): The spatial scale of RRoI-Aligning.
pooling_output_size (tuple(int)): The size of RRoI-Aligning output.
nms_thr (float): The locality-aware NMS threshold.
min_width (float): The minimum width of text components.
max_width (float): The maximum width of text components.
comp_shrink_ratio (float): The shrink ratio of text components.
comp_ratio (float): The reciprocal of aspect ratio of text components.
text_region_thr (float): The threshold for text region probability map.
center_region_thr (float): The threshold for text center region
probability map.
center_region_area_thr (int): The threshold for filtering out
small-size text center region.
"""
def __init__(self, k_at_hops, active_connection, node_geo_feat_dim,
pooling_scale, pooling_output_size, nms_thr, min_width,
max_width, comp_shrink_ratio, comp_ratio, text_region_thr,
center_region_thr, center_region_area_thr):
assert isinstance(k_at_hops, tuple)
assert isinstance(active_connection, int)
assert isinstance(node_geo_feat_dim, int)
assert isinstance(pooling_scale, float)
assert isinstance(pooling_output_size, tuple)
assert isinstance(nms_thr, float)
assert isinstance(min_width, float)
assert isinstance(max_width, float)
assert isinstance(comp_shrink_ratio, float)
assert isinstance(comp_ratio, float)
assert isinstance(text_region_thr, float)
assert isinstance(center_region_thr, float)
assert isinstance(center_region_area_thr, int)
self.k_at_hops = k_at_hops
self.active_connection = active_connection
self.local_graph_depth = len(self.k_at_hops)
self.node_geo_feat_dim = node_geo_feat_dim
self.pooling = RROIAlign(pooling_output_size, pooling_scale)
self.nms_thr = nms_thr
self.min_width = min_width
self.max_width = max_width
self.comp_shrink_ratio = comp_shrink_ratio
self.comp_ratio = comp_ratio
self.text_region_thr = text_region_thr
self.center_region_thr = center_region_thr
self.center_region_area_thr = center_region_area_thr
def fill_hole(self, input_mask):
h, w = input_mask.shape
canvas = np.zeros((h + 2, w + 2), np.uint8)
canvas[1:h + 1, 1:w + 1] = input_mask.copy()
mask = np.zeros((h + 4, w + 4), np.uint8)
cv2.floodFill(canvas, mask, (0, 0), 1)
canvas = canvas[1:h + 1, 1:w + 1].astype(np.bool)
return (~canvas | input_mask.astype(np.uint8))
def propose_comps(self, top_radius_map, bot_radius_map, sin_map, cos_map,
score_map, min_width, max_width, comp_shrink_ratio,
comp_ratio):
"""Generate text components.
Args:
top_radius_map (ndarray): The predicted distance map from each
pixel in text center region to top sideline.
bot_radius_map (ndarray): The predicted distance map from each
pixel in text center region to bottom sideline.
sin_map (ndarray): The predicted sin(theta) map.
cos_map (ndarray): The predicted cos(theta) map.
score_map (ndarray): The score map for NMS.
min_width (float): The minimum width of text components.
max_width (float): The maximum width of text components.
comp_shrink_ratio (float): The shrink ratio of text components.
comp_ratio (float): The reciprocal of aspect ratio of text
components.
Returns:
text_comps (ndarray): The text components.
"""
comp_centers = np.argwhere(score_map > 0)
comp_centers = comp_centers[np.argsort(comp_centers[:, 0])]
y = comp_centers[:, 0]
x = comp_centers[:, 1]
top_radius = top_radius_map[y, x].reshape((-1, 1)) * comp_shrink_ratio
bot_radius = bot_radius_map[y, x].reshape((-1, 1)) * comp_shrink_ratio
sin = sin_map[y, x].reshape((-1, 1))
cos = cos_map[y, x].reshape((-1, 1))
top_mid_x_offset = top_radius * cos
top_mid_y_offset = top_radius * sin
bot_mid_x_offset = bot_radius * cos
bot_mid_y_offset = bot_radius * sin
top_mid_pnt = comp_centers + np.hstack(
[top_mid_y_offset, top_mid_x_offset])
bot_mid_pnt = comp_centers - np.hstack(
[bot_mid_y_offset, bot_mid_x_offset])
width = (top_radius + bot_radius) * comp_ratio
width = np.clip(width, min_width, max_width)
top_left = top_mid_pnt - np.hstack([width * cos, -width * sin
])[:, ::-1]
top_right = top_mid_pnt + np.hstack([width * cos, -width * sin
])[:, ::-1]
bot_right = bot_mid_pnt + np.hstack([width * cos, -width * sin
])[:, ::-1]
bot_left = bot_mid_pnt - np.hstack([width * cos, -width * sin
])[:, ::-1]
text_comps = np.hstack([top_left, top_right, bot_right, bot_left])
score = score_map[y, x].reshape((-1, 1))
text_comps = np.hstack([text_comps, score])
return text_comps
def propose_comps_and_attribs(self, text_region_map, center_region_map,
top_radius_map, bot_radius_map, sin_map,
cos_map):
"""Generate text components and attributes.
Args:
text_region_map (ndarray): The predicted text region probability
map.
center_region_map (ndarray): The predicted text center region
probability map.
top_radius_map (ndarray): The predicted distance map from each
pixel in text center region to top sideline.
bot_radius_map (ndarray): The predicted distance map from each
pixel in text center region to bottom sideline.
sin_map (ndarray): The predicted sin(theta) map.
cos_map (ndarray): The predicted cos(theta) map.
Returns:
comp_attribs (ndarray): The text components attributes.
text_comps (ndarray): The text components.
"""
assert (text_region_map.shape == center_region_map.shape ==
top_radius_map.shape == bot_radius_map == sin_map.shape ==
cos_map.shape)
text_mask = text_region_map > self.text_region_thr
center_region_mask = (center_region_map >
self.center_region_thr) * text_mask
scale = np.sqrt(1.0 / (sin_map**2 + cos_map**2))
sin_map, cos_map = sin_map * scale, cos_map * scale
center_region_mask = self.fill_hole(center_region_mask)
center_region_contours, _ = cv2.findContours(
center_region_mask.astype(np.uint8), cv2.RETR_TREE,
cv2.CHAIN_APPROX_SIMPLE)
mask = np.zeros_like(center_region_mask)
comp_list = []
for contour in center_region_contours:
current_center_mask = mask.copy()
cv2.drawContours(current_center_mask, [contour], -1, 1, -1)
if current_center_mask.sum() <= self.center_region_area_thr:
continue
score_map = text_region_map * current_center_mask
text_comp = self.propose_comps(top_radius_map, bot_radius_map,
sin_map, cos_map, score_map,
self.min_width, self.max_width,
self.comp_shrink_ratio,
self.comp_ratio)
# text_comp = la_nms(text_comp.astype('float32'), self.nms_thr)
text_comp_mask = mask.copy()
text_comps_bboxes = text_comp[:, :8].reshape(
(-1, 4, 2)).astype(np.int32)
cv2.drawContours(text_comp_mask, text_comps_bboxes, -1, 1, -1)
if (text_comp_mask * text_mask).sum() < text_comp_mask.sum() * 0.5:
continue
comp_list.append(text_comp)
if len(comp_list) <= 0:
return None, None
text_comps = np.vstack(comp_list)
centers = np.mean(
text_comps[:, :8].reshape((-1, 4, 2)), axis=1).astype(np.int32)
x = centers[:, 0]
y = centers[:, 1]
h = top_radius_map[y, x].reshape(
(-1, 1)) + bot_radius_map[y, x].reshape((-1, 1))
w = np.clip(h * self.comp_ratio, self.min_width, self.max_width)
sin = sin_map[y, x].reshape((-1, 1))
cos = cos_map[y, x].reshape((-1, 1))
x = x.reshape((-1, 1))
y = y.reshape((-1, 1))
comp_attribs = np.hstack([x, y, h, w, cos, sin])
return comp_attribs, text_comps
def generate_local_graphs(self, sorted_complete_graph, node_feats):
"""Generate local graphs and Graph Convolution Network input data.
Args:
sorted_complete_graph (ndarray): The complete graph where nodes are
sorted according to their Euclidean distance.
node_feats (tensor): The graph nodes features.
Returns:
node_feats_tensor (tensor): The graph nodes features.
adjacent_matrix_tensor (tensor): The adjacent matrix of graph.
pivot_inx_tensor (tensor): The pivot indices in local graph.
knn_inx_tensor (tensor): The k nearest neighbor nodes indexes in
local graph.
local_graph_node_tensor (tensor): The indices of nodes in local
graph.
"""
assert sorted_complete_graph.ndim == 2
assert (sorted_complete_graph.shape[0] ==
sorted_complete_graph.shape[1] == node_feats.shape[0])
knn_graphs = sorted_complete_graph[:, :self.k_at_hops[0] + 1]
local_graph_node_list = list()
knn_graph_neighbor_list = list()
for pivot_inx, knn_graph in enumerate(knn_graphs):
h_hop_neighbor_list = list()
one_hop_neighbors = set(knn_graph[1:])
h_hop_neighbor_list.append(one_hop_neighbors)
for hop_inx in range(1, self.local_graph_depth):
h_hop_neighbor_list.append(set())
for last_hop_neighbor_inx in h_hop_neighbor_list[-2]:
h_hop_neighbor_list[-1].update(
set(sorted_complete_graph[last_hop_neighbor_inx]
[1:self.k_at_hops[hop_inx] + 1]))
hops_neighbor_set = set(
[node for hop in h_hop_neighbor_list for node in hop])
hops_neighbor_list = list(hops_neighbor_set)
hops_neighbor_list.insert(0, pivot_inx)
local_graph_node_list.append(hops_neighbor_list)
knn_graph_neighbor_list.append(one_hop_neighbors)
max_graph_node_num = max([
len(local_graph_nodes)
for local_graph_nodes in local_graph_node_list
])
node_normalized_feats = list()
adjacent_matrix_list = list()
knn_inx = list()
pivot_graph_inx = list()
local_graph_tensor_list = list()
for graph_inx in range(len(local_graph_node_list)):
local_graph_nodes = local_graph_node_list[graph_inx]
local_graph_node_num = len(local_graph_nodes)
pivot_inx = local_graph_nodes[0]
knn_graph_neighbors = knn_graph_neighbor_list[graph_inx]
node_to_graph_inx = {j: i for i, j in enumerate(local_graph_nodes)}
pivot_node_inx = torch.tensor([
node_to_graph_inx[pivot_inx],
]).type(torch.long)
knn_inx_in_local_graph = torch.tensor(
[node_to_graph_inx[i] for i in knn_graph_neighbors],
dtype=torch.long)
pivot_feats = node_feats[torch.tensor(pivot_inx, dtype=torch.long)]
normalized_feats = node_feats[torch.tensor(
local_graph_nodes, dtype=torch.long)] - pivot_feats
adjacent_matrix = np.zeros(
(local_graph_node_num, local_graph_node_num))
pad_normalized_feats = torch.cat([
normalized_feats,
torch.zeros(max_graph_node_num - local_graph_node_num,
normalized_feats.shape[1]).to(node_feats.device)
],
dim=0)
for node in local_graph_nodes:
neighbors = sorted_complete_graph[node,
1:self.active_connection + 1]
for neighbor in neighbors:
if neighbor in local_graph_nodes:
adjacent_matrix[node_to_graph_inx[node],
node_to_graph_inx[neighbor]] = 1
adjacent_matrix[node_to_graph_inx[neighbor],
node_to_graph_inx[node]] = 1
adjacent_matrix = normalize_adjacent_matrix(
adjacent_matrix, type='DAD')
adjacent_matrix_tensor = torch.zeros(
max_graph_node_num, max_graph_node_num).to(node_feats.device)
adjacent_matrix_tensor[:local_graph_node_num, :
local_graph_node_num] = adjacent_matrix
local_graph_tensor = torch.tensor(local_graph_nodes)
local_graph_tensor = torch.cat([
local_graph_tensor,
torch.zeros(
max_graph_node_num - local_graph_node_num,
dtype=torch.long)
],
dim=0)
node_normalized_feats.append(pad_normalized_feats)
adjacent_matrix_list.append(adjacent_matrix_tensor)
pivot_graph_inx.append(pivot_node_inx)
knn_inx.append(knn_inx_in_local_graph)
local_graph_tensor_list.append(local_graph_tensor)
node_feats_tensor = torch.stack(node_normalized_feats, 0)
adjacent_matrix_tensor = torch.stack(adjacent_matrix_list, 0)
pivot_inx_tensor = torch.stack(pivot_graph_inx, 0)
knn_inx_tensor = torch.stack(knn_inx, 0)
local_graph_node_tensor = torch.stack(local_graph_tensor_list, 0)
return (node_feats_tensor, adjacent_matrix_tensor, pivot_inx_tensor,
knn_inx_tensor, local_graph_node_tensor)
def __call__(self, preds, feat_maps):
"""Generate local graphs and Graph Convolution Network input data.
Args:
preds (tensor): The predicted maps.
feat_maps (tensor): The feature maps to extract content features of
text components.
Returns:
node_feats_tensor (tensor): The graph nodes features.
adjacent_matrix_tensor (tensor): The adjacent matrix of graph.
pivot_inx_tensor (tensor): The pivot indices in local graph.
knn_inx_tensor (tensor): The k nearest neighbor nodes indices in
local graph.
local_graph_node_tensor (tensor): The indices of nodes in local
graph.
text_comps (ndarray): The predicted text components.
"""
pred_text_region = torch.sigmoid(preds[0, 0]).data.cpu().numpy()
pred_center_region = torch.sigmoid(preds[0, 1]).data.cpu().numpy()
pred_sin_map = preds[0, 2].data.cpu().numpy()
pred_cos_map = preds[0, 3].data.cpu().numpy()
pred_top_radius_map = preds[0, 4].data.cpu().numpy()
pred_bot_radius_map = preds[0, 5].data.cpu().numpy()
comp_attribs, text_comps = self.propose_comps_and_attribs(
pred_text_region, pred_center_region, pred_top_radius_map,
pred_bot_radius_map, pred_sin_map, pred_cos_map)
if comp_attribs is None:
none_flag = True
return none_flag, (0, 0, 0, 0, 0, 0)
comp_centers = comp_attribs[:, 0:2]
distance_matrix = euclidean_distance_matrix(comp_centers, comp_centers)
graph_node_geo_feats = embed_geo_feats(comp_attribs,
self.node_geo_feat_dim)
graph_node_geo_feats = torch.from_numpy(
graph_node_geo_feats).float().to(preds.device)
batch_id = np.zeros((comp_attribs.shape[0], 1), dtype=np.float32)
text_comps_bboxes = np.hstack(
(batch_id, comp_attribs.astype(np.float32, copy=False)))
text_comps_bboxes = torch.from_numpy(text_comps_bboxes).float().to(
preds.device)
comp_content_feats = self.pooling(feat_maps, text_comps_bboxes)
comp_content_feats = comp_content_feats.view(
comp_content_feats.shape[0], -1).to(preds.device)
node_feats = torch.cat((comp_content_feats, graph_node_geo_feats),
dim=-1)
dist_sort_complete_graph = np.argsort(distance_matrix, axis=1)
(node_feats_tensor, adjacent_matrix_tensor, pivot_inx_tensor,
knn_inx_tensor, local_graph_node_tensor) = self.generate_local_graphs(
dist_sort_complete_graph, node_feats)
none_flag = False
return none_flag, (node_feats_tensor, adjacent_matrix_tensor,
pivot_inx_tensor, knn_inx_tensor,
local_graph_node_tensor, text_comps)

View File

@ -0,0 +1,354 @@
import functools
import operator
from typing import List
import cv2
import numpy as np
import torch
from numpy.linalg import norm
def normalize_adjacent_matrix(A, type='AD'):
"""Normalize adjacent matrix for GCN.
This was from repo https://github.com/GXYM/DRRG.
"""
if type == 'DAD':
# d is Degree of nodes A=A+I
# L = D^-1/2 A D^-1/2
A = A + np.eye(A.shape[0]) # A=A+I
d = np.sum(A, axis=0)
d_inv = np.power(d, -0.5).flatten()
d_inv[np.isinf(d_inv)] = 0.0
d_inv = np.diag(d_inv)
G = A.dot(d_inv).transpose().dot(d_inv)
G = torch.from_numpy(G)
elif type == 'AD':
A = A + np.eye(A.shape[0]) # A=A+I
A = torch.from_numpy(A)
D = A.sum(1, keepdim=True)
G = A.div(D)
else:
A = A + np.eye(A.shape[0]) # A=A+I
A = torch.from_numpy(A)
D = A.sum(1, keepdim=True)
D = np.diag(D)
G = D - A
return G
def euclidean_distance_matrix(A, B):
"""Calculate the Euclidean distance matrix."""
M = A.shape[0]
N = B.shape[0]
assert A.shape[1] == B.shape[1]
A_dots = (A * A).sum(axis=1).reshape((M, 1)) * np.ones(shape=(1, N))
B_dots = (B * B).sum(axis=1) * np.ones(shape=(M, 1))
D_squared = A_dots + B_dots - 2 * A.dot(B.T)
zero_mask = np.less(D_squared, 0.0)
D_squared[zero_mask] = 0.0
return np.sqrt(D_squared)
def embed_geo_feats(geo_feats, out_dim):
"""Embed geometric features of text components. This was partially adapted
from https://github.com/GXYM/DRRG.
Args:
geo_feats (ndarray): The geometric features of text components.
out_dim (int): The output dimension.
Returns:
embedded_feats (ndarray): The embedded geometric features.
"""
assert isinstance(out_dim, int)
assert out_dim >= geo_feats.shape[1]
comp_num = geo_feats.shape[0]
feat_dim = geo_feats.shape[1]
feat_repeat_times = out_dim // feat_dim
residue_dim = out_dim % feat_dim
if residue_dim > 0:
embed_wave = np.array([
np.power(1000, 2.0 * (j // 2) / feat_repeat_times + 1)
for j in range(feat_repeat_times + 1)
]).reshape((feat_repeat_times + 1, 1, 1))
repeat_feats = np.repeat(
np.expand_dims(geo_feats, axis=0), feat_repeat_times, axis=0)
residue_feats = np.hstack([
geo_feats[:, 0:residue_dim],
np.zeros((comp_num, feat_dim - residue_dim))
])
repeat_feats = np.stack([repeat_feats, residue_feats], axis=0)
embedded_feats = repeat_feats / embed_wave
embedded_feats[:, 0::2] = np.sin(embedded_feats[:, 0::2])
embedded_feats[:, 1::2] = np.cos(embedded_feats[:, 1::2])
embedded_feats = np.transpose(embedded_feats, (1, 0, 2)).reshape(
(comp_num, -1))[:, 0:out_dim]
else:
embed_wave = np.array([
np.power(1000, 2.0 * (j // 2) / feat_repeat_times)
for j in range(feat_repeat_times)
]).reshape((feat_repeat_times, 1, 1))
repeat_feats = np.repeat(
np.expand_dims(geo_feats, axis=0), feat_repeat_times, axis=0)
embedded_feats = repeat_feats / embed_wave
embedded_feats[:, 0::2] = np.sin(embedded_feats[:, 0::2])
embedded_feats[:, 1::2] = np.cos(embedded_feats[:, 1::2])
embedded_feats = np.transpose(embedded_feats, (1, 0, 2)).reshape(
(comp_num, -1))
return embedded_feats
def min_connect_path(list_all: List[list]):
"""This is from https://github.com/GXYM/DRRG."""
list_nodo = list_all.copy()
res: List[List[int]] = []
ept = [0, 0]
def norm2(a, b):
return ((a[0] - b[0])**2 + (a[1] - b[1])**2)**0.5
dict00 = {}
dict11 = {}
ept[0] = list_nodo[0]
ept[1] = list_nodo[0]
list_nodo.remove(list_nodo[0])
while list_nodo:
for i in list_nodo:
length0 = norm2(i, ept[0])
dict00[length0] = [i, ept[0]]
length1 = norm2(ept[1], i)
dict11[length1] = [ept[1], i]
key0 = min(dict00.keys())
key1 = min(dict11.keys())
if key0 <= key1:
ss = dict00[key0][0]
ee = dict00[key0][1]
res.insert(0, [list_all.index(ss), list_all.index(ee)])
list_nodo.remove(ss)
ept[0] = ss
else:
ss = dict11[key1][0]
ee = dict11[key1][1]
res.append([list_all.index(ss), list_all.index(ee)])
list_nodo.remove(ee)
ept[1] = ee
dict00 = {}
dict11 = {}
path = functools.reduce(operator.concat, res)
path = sorted(set(path), key=path.index)
return res, path
def clusters2labels(clusters, node_num):
"""This is from https://github.com/GXYM/DRRG."""
labels = (-1) * np.ones((node_num, ))
for cluster_inx, cluster in enumerate(clusters):
for node in cluster:
labels[node.inx] = cluster_inx
assert np.sum(labels < 0) < 1
return labels
def remove_single(text_comps, pred):
"""Remove isolated single text components.
This is from https://github.com/GXYM/DRRG.
"""
single_flags = np.zeros_like(pred)
pred_labels = np.unique(pred)
for label in pred_labels:
current_label_flag = pred == label
if np.sum(current_label_flag) == 1:
single_flags[np.where(current_label_flag)[0][0]] = 1
remain_inx = [i for i in range(len(pred)) if not single_flags[i]]
remain_inx = np.asarray(remain_inx)
return text_comps[remain_inx, :], pred[remain_inx]
class Node:
def __init__(self, inx):
self.__inx = inx
self.__links = set()
@property
def inx(self):
return self.__inx
@property
def links(self):
return set(self.__links)
def add_link(self, other, score):
self.__links.add(other)
other.__links.add(self)
def connected_components(nodes, score_dict, thr):
"""Connected components searching.
This is from https://github.com/GXYM/DRRG.
"""
result = []
nodes = set(nodes)
while nodes:
node = nodes.pop()
group = {node}
queue = [node]
while queue:
node = queue.pop(0)
if thr is not None:
neighbors = {
linked_neighbor
for linked_neighbor in node.links if score_dict[tuple(
sorted([node.inx, linked_neighbor.inx]))] >= thr
}
else:
neighbors = node.links
neighbors.difference_update(group)
nodes.difference_update(neighbors)
group.update(neighbors)
queue.extend(neighbors)
result.append(group)
return result
def graph_propagation(edges,
scores,
link_thr,
bboxes=None,
dis_thr=50,
pool='avg'):
"""Propagate graph linkage score information.
This is from repo https://github.com/GXYM/DRRG.
"""
edges = np.sort(edges, axis=1)
score_dict = {}
if pool is None:
for i, edge in enumerate(edges):
score_dict[edge[0], edge[1]] = scores[i]
elif pool == 'avg':
for i, edge in enumerate(edges):
if bboxes is not None:
box1 = bboxes[edge[0]][:8].reshape(4, 2)
box2 = bboxes[edge[1]][:8].reshape(4, 2)
center1 = np.mean(box1, axis=0)
center2 = np.mean(box2, axis=0)
dst = norm(center1 - center2)
if dst > dis_thr:
scores[i] = 0
if (edge[0], edge[1]) in score_dict:
score_dict[edge[0], edge[1]] = 0.5 * (
score_dict[edge[0], edge[1]] + scores[i])
else:
score_dict[edge[0], edge[1]] = scores[i]
elif pool == 'max':
for i, edge in enumerate(edges):
if (edge[0], edge[1]) in score_dict:
score_dict[edge[0],
edge[1]] = max(score_dict[edge[0], edge[1]],
scores[i])
else:
score_dict[edge[0], edge[1]] = scores[i]
else:
raise ValueError('Pooling operation not supported')
nodes = np.sort(np.unique(edges.flatten()))
mapping = -1 * np.ones((nodes.max() + 1), dtype=np.int)
mapping[nodes] = np.arange(nodes.shape[0])
link_inx = mapping[edges]
vertex = [Node(node) for node in nodes]
for link, score in zip(link_inx, scores):
vertex[link[0]].add_link(vertex[link[1]], score)
clusters = connected_components(vertex, score_dict, link_thr)
return clusters
def in_contour(cont, point):
x, y = point
return cv2.pointPolygonTest(cont, (x, y), False) > 0
def select_edge(cont, box):
"""This is from repo https://github.com/GXYM/DRRG."""
cont = np.array(cont)
box = box.astype(np.int32)
c1 = np.array(0.5 * (box[0, :] + box[3, :]), dtype=np.int)
c2 = np.array(0.5 * (box[1, :] + box[2, :]), dtype=np.int)
if not in_contour(cont, c1):
return [box[0, :].tolist(), box[3, :].tolist()]
elif not in_contour(cont, c2):
return [box[1, :].tolist(), box[2, :].tolist()]
else:
return None
def comps2boundary(text_comps, final_pred):
"""Propose text components and generate local graphs.
This is from repo https://github.com/GXYM/DRRG.
"""
bbox_contours = list()
for inx in range(0, int(np.max(final_pred)) + 1):
current_instance = np.where(final_pred == inx)
boxes = text_comps[current_instance, :8].reshape(
(-1, 4, 2)).astype(np.int32)
boundary_point = None
if boxes.shape[0] > 1:
centers = np.mean(boxes, axis=1).astype(np.int32).tolist()
paths, routes_path = min_connect_path(centers)
boxes = boxes[routes_path]
top = np.mean(boxes[:, 0:2, :], axis=1).astype(np.int32).tolist()
bot = np.mean(boxes[:, 2:4, :], axis=1).astype(np.int32).tolist()
edge1 = select_edge(top + bot[::-1], boxes[0])
edge2 = select_edge(top + bot[::-1], boxes[-1])
if edge1 is not None:
top.insert(0, edge1[0])
bot.insert(0, edge1[1])
if edge2 is not None:
top.append(edge2[0])
bot.append(edge2[1])
boundary_point = np.array(top + bot[::-1])
elif boxes.shape[0] == 1:
top = boxes[0, 0:2, :].astype(np.int32).tolist()
bot = boxes[0, 2:4:-1, :].astype(np.int32).tolist()
boundary_point = np.array(top + bot)
if boundary_point is None:
continue
boundary_point = [p for p in boundary_point.flatten().tolist()]
bbox_contours.append(boundary_point)
return bbox_contours
def merge_text_comps(edges, scores, text_comps, link_thr):
"""Merge text components into text instance."""
clusters = graph_propagation(edges, scores, link_thr)
pred_labels = clusters2labels(clusters, text_comps.shape[0])
text_comps, final_pred = remove_single(text_comps, pred_labels)
boundaries = comps2boundary(text_comps, final_pred)
return boundaries

View File

@ -0,0 +1,88 @@
import torch
import torch.nn.functional as F
from mmcv.cnn import xavier_init
from torch import nn
from mmdet.models.builder import NECKS
class UpBlock(nn.Module):
"""Upsample block for DRRG and TextSnake."""
def __init__(self, in_channels, out_channels):
super().__init__()
assert isinstance(in_channels, int)
assert isinstance(out_channels, int)
self.conv1x1 = nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.conv3x3 = nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.deconv = nn.ConvTranspose2d(
out_channels, out_channels, kernel_size=4, stride=2, padding=1)
def forward(self, x):
x = F.relu(self.conv1x1(x))
x = F.relu(self.conv3x3(x))
x = self.deconv(x)
return x
@NECKS.register_module()
class FPN_UNET(nn.Module):
"""The class for implementing DRRG and TextSnake U-Net-like FPN.
DRRG: Deep Relational Reasoning Graph Network for Arbitrary Shape
Text Detection [https://arxiv.org/abs/2003.07493].
TextSnake: A Flexible Representation for Detecting Text of Arbitrary Shapes
[https://arxiv.org/abs/1807.01544].
"""
def __init__(self, in_channels, out_channels):
super().__init__()
assert len(in_channels) == 4
assert isinstance(out_channels, int)
blocks_out_channels = [out_channels] + [
min(out_channels * 2**i, 256) for i in range(4)
]
blocks_in_channels = [blocks_out_channels[1]] + [
in_channels[i] + blocks_out_channels[i + 2] for i in range(3)
] + [in_channels[3]]
self.up4 = nn.ConvTranspose2d(
blocks_in_channels[4],
blocks_out_channels[4],
kernel_size=4,
stride=2,
padding=1)
self.up_block3 = UpBlock(blocks_in_channels[3], blocks_out_channels[3])
self.up_block2 = UpBlock(blocks_in_channels[2], blocks_out_channels[2])
self.up_block1 = UpBlock(blocks_in_channels[1], blocks_out_channels[1])
self.up_block0 = UpBlock(blocks_in_channels[0], blocks_out_channels[0])
self.init_weights()
def init_weights(self):
for m in self.modules():
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
xavier_init(m, distribution='uniform')
def forward(self, x):
c2, c3, c4, c5 = x
x = F.relu(self.up4(c5))
x = torch.cat([x, c4], dim=1)
x = F.relu(self.up_block3(x))
x = torch.cat([x, c3], dim=1)
x = F.relu(self.up_block2(x))
x = torch.cat([x, c2], dim=1)
x = F.relu(self.up_block1(x))
x = self.up_block0(x)
# the output should be of the same height and width as backbone input
return x