mirror of https://github.com/open-mmlab/mmocr.git
Merge pull request #2 from HolyCrap96/feature/textsnake_drrg
[feature]: add textsnake_drrgpull/2/head
commit
3ed6aaa4e4
|
@ -0,0 +1,566 @@
|
|||
import cv2
|
||||
import numpy as np
|
||||
from numpy.linalg import norm
|
||||
|
||||
import mmocr.utils.check_argument as check_argument
|
||||
from mmdet.core import BitmapMasks
|
||||
from mmdet.datasets.builder import PIPELINES
|
||||
# from mmocr.models.textdet import la_nms
|
||||
from .textsnake_targets import TextSnakeTargets
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class DRRGTargets(TextSnakeTargets):
|
||||
"""Generate the ground truth targets of DRRG: Deep Relational Reasoning
|
||||
Graph Network for Arbitrary Shape Text Detection.
|
||||
|
||||
[https://arxiv.org/abs/2003.07493]. This was partially adapted from
|
||||
https://github.com/GXYM/DRRG.
|
||||
|
||||
Args:
|
||||
orientation_thr (float): The threshold for distinguishing between
|
||||
head edge and tail edge among the horizontal and vertical edges
|
||||
of a quadrangle.
|
||||
resample_step (float): The step size for resampling the text center
|
||||
line (TCL). Better not exceed half of the minimum width
|
||||
of the text component.
|
||||
min_comp_num (int): The minimum number of text components, which
|
||||
should be k_hop1 + 1 on graph.
|
||||
max_comp_num (int): The maximum number of text components.
|
||||
min_width (float): The minimum width of text components.
|
||||
max_width (float): The maximum width of text components.
|
||||
center_region_shrink_ratio (float): The shrink ratio of text center
|
||||
region.
|
||||
comp_shrink_ratio (float): The shrink ratio of text components.
|
||||
text_comp_ratio (float): The reciprocal of aspect ratio of text
|
||||
components.
|
||||
min_rand_half_height(float): The minimum half-height of random text
|
||||
components.
|
||||
max_rand_half_height (float): The maximum half-height of random
|
||||
text components.
|
||||
jitter_level (float): The jitter level of text components geometric
|
||||
features.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
orientation_thr=2.0,
|
||||
resample_step=8.0,
|
||||
min_comp_num=9,
|
||||
max_comp_num=600,
|
||||
min_width=8.0,
|
||||
max_width=24.0,
|
||||
center_region_shrink_ratio=0.3,
|
||||
comp_shrink_ratio=1.0,
|
||||
text_comp_ratio=0.65,
|
||||
text_comp_nms_thr=0.25,
|
||||
min_rand_half_height=6.0,
|
||||
max_rand_half_height=24.0,
|
||||
jitter_level=0.2):
|
||||
|
||||
super().__init__()
|
||||
self.orientation_thr = orientation_thr
|
||||
self.resample_step = resample_step
|
||||
self.max_comp_num = max_comp_num
|
||||
self.min_comp_num = min_comp_num
|
||||
self.min_width = min_width
|
||||
self.max_width = max_width
|
||||
self.center_region_shrink_ratio = center_region_shrink_ratio
|
||||
self.comp_shrink_ratio = comp_shrink_ratio
|
||||
self.text_comp_ratio = text_comp_ratio
|
||||
self.text_comp_nms_thr = text_comp_nms_thr
|
||||
self.min_rand_half_height = min_rand_half_height
|
||||
self.max_rand_half_height = max_rand_half_height
|
||||
self.jitter_level = jitter_level
|
||||
|
||||
def dist_point2line(self, pnt, line):
|
||||
|
||||
assert isinstance(line, tuple)
|
||||
pnt1, pnt2 = line
|
||||
d = abs(np.cross(pnt2 - pnt1, pnt - pnt1)) / (norm(pnt2 - pnt1) + 1e-8)
|
||||
return d
|
||||
|
||||
def draw_center_region_maps(self, top_line, bot_line, center_line,
|
||||
center_region_mask, top_height_map,
|
||||
bot_height_map, sin_map, cos_map,
|
||||
region_shrink_ratio):
|
||||
"""Draw attributes on text center region.
|
||||
|
||||
Args:
|
||||
top_line (ndarray): The points composing top curved sideline of
|
||||
text polygon.
|
||||
bot_line (ndarray): The points composing bottom curved sideline
|
||||
of text polygon.
|
||||
center_line (ndarray): The points composing the center line of text
|
||||
instance.
|
||||
center_region_mask (ndarray): The text center region mask.
|
||||
top_height_map (ndarray): The map on which the distance from point
|
||||
to top sideline will be drawn for each pixel in text center
|
||||
region.
|
||||
bot_height_map (ndarray): The map on which the distance from point
|
||||
to bottom sideline will be drawn for each pixel in text center
|
||||
region.
|
||||
sin_map (ndarray): The map of vector_sin(top_point -bot_point)
|
||||
that will be drawn on text center region.
|
||||
cos_map (ndarray): The map of vector_cos(top_point -bot_point)
|
||||
will be drawn on text center region.
|
||||
region_shrink_ratio (float): The shrink ratio of text center.
|
||||
"""
|
||||
|
||||
assert top_line.shape == bot_line.shape == center_line.shape
|
||||
assert (center_region_mask.shape == top_height_map.shape ==
|
||||
bot_height_map.shape == sin_map.shape == cos_map.shape)
|
||||
assert isinstance(region_shrink_ratio, float)
|
||||
|
||||
h, w = center_region_mask.shape
|
||||
for i in range(0, len(center_line) - 1):
|
||||
|
||||
top_mid_point = (top_line[i] + top_line[i + 1]) / 2
|
||||
bot_mid_point = (bot_line[i] + bot_line[i + 1]) / 2
|
||||
|
||||
sin_theta = self.vector_sin(top_mid_point - bot_mid_point)
|
||||
cos_theta = self.vector_cos(top_mid_point - bot_mid_point)
|
||||
|
||||
pnt_tl = center_line[i] + (top_line[i] -
|
||||
center_line[i]) * region_shrink_ratio
|
||||
pnt_tr = center_line[i + 1] + (
|
||||
top_line[i + 1] - center_line[i + 1]) * region_shrink_ratio
|
||||
pnt_br = center_line[i + 1] + (
|
||||
bot_line[i + 1] - center_line[i + 1]) * region_shrink_ratio
|
||||
pnt_bl = center_line[i] + (bot_line[i] -
|
||||
center_line[i]) * region_shrink_ratio
|
||||
current_center_box = np.vstack([pnt_tl, pnt_tr, pnt_br,
|
||||
pnt_bl]).astype(np.int32)
|
||||
|
||||
cv2.fillPoly(center_region_mask, [current_center_box], color=1)
|
||||
cv2.fillPoly(sin_map, [current_center_box], color=sin_theta)
|
||||
cv2.fillPoly(cos_map, [current_center_box], color=cos_theta)
|
||||
|
||||
# x,y order
|
||||
current_center_box[:, 0] = np.clip(current_center_box[:, 0], 0,
|
||||
w - 1)
|
||||
current_center_box[:, 1] = np.clip(current_center_box[:, 1], 0,
|
||||
h - 1)
|
||||
min_coord = np.min(current_center_box, axis=0).astype(np.int32)
|
||||
max_coord = np.max(current_center_box, axis=0).astype(np.int32)
|
||||
current_center_box = current_center_box - min_coord
|
||||
sz = (max_coord - min_coord + 1)
|
||||
|
||||
center_box_mask = np.zeros((sz[1], sz[0]), dtype=np.uint8)
|
||||
cv2.fillPoly(center_box_mask, [current_center_box], color=1)
|
||||
|
||||
inx = np.argwhere(center_box_mask > 0)
|
||||
inx = inx + (min_coord[1], min_coord[0]) # y, x order
|
||||
inx_xy = np.fliplr(inx)
|
||||
top_height_map[(inx[:, 0], inx[:, 1])] = self.dist_point2line(
|
||||
inx_xy, (top_line[i], top_line[i + 1]))
|
||||
bot_height_map[(inx[:, 0], inx[:, 1])] = self.dist_point2line(
|
||||
inx_xy, (bot_line[i], bot_line[i + 1]))
|
||||
|
||||
def generate_center_mask_attrib_maps(self, img_size, text_polys):
|
||||
"""Generate text center region mask and geometry attribute maps.
|
||||
|
||||
Args:
|
||||
img_size (tuple): The image size of (height, width).
|
||||
text_polys (list[list[ndarray]]): The list of text polygons.
|
||||
|
||||
Returns:
|
||||
center_lines (list): The list of text center lines.
|
||||
center_region_mask (ndarray): The text center region mask.
|
||||
top_height_map (ndarray): The distance map from each pixel in text
|
||||
center region to top sideline.
|
||||
bot_height_map (ndarray): The distance map from each pixel in text
|
||||
center region to bottom sideline.
|
||||
sin_map (ndarray): The sin(theta) map where theta is the angle
|
||||
between vector (top point - bottom point) and vector (1, 0).
|
||||
cos_map (ndarray): The cos(theta) map where theta is the angle
|
||||
between vector (top point - bottom point) and vector (1, 0).
|
||||
"""
|
||||
|
||||
assert isinstance(img_size, tuple)
|
||||
assert check_argument.is_2dlist(text_polys)
|
||||
|
||||
h, w = img_size
|
||||
|
||||
center_lines = []
|
||||
center_region_mask = np.zeros((h, w), np.uint8)
|
||||
top_height_map = np.zeros((h, w), dtype=np.float32)
|
||||
bot_height_map = np.zeros((h, w), dtype=np.float32)
|
||||
sin_map = np.zeros((h, w), dtype=np.float32)
|
||||
cos_map = np.zeros((h, w), dtype=np.float32)
|
||||
|
||||
for poly in text_polys:
|
||||
assert len(poly) == 1
|
||||
polygon_points = poly[0].reshape(-1, 2)
|
||||
_, _, top_line, bot_line = self.reorder_poly_edge(polygon_points)
|
||||
resampled_top_line, resampled_bot_line = self.resample_sidelines(
|
||||
top_line, bot_line, self.resample_step)
|
||||
resampled_bot_line = resampled_bot_line[::-1]
|
||||
center_line = (resampled_top_line + resampled_bot_line) / 2
|
||||
|
||||
line_head_shrink_len = np.clip(
|
||||
(norm(top_line[0] - bot_line[0]) * self.text_comp_ratio),
|
||||
self.min_width, self.max_width) / 2
|
||||
line_tail_shrink_len = np.clip(
|
||||
(norm(top_line[-1] - bot_line[-1]) * self.text_comp_ratio),
|
||||
self.min_width, self.max_width) / 2
|
||||
head_shrink_num = int(line_head_shrink_len // self.resample_step)
|
||||
tail_shrink_num = int(line_tail_shrink_len // self.resample_step)
|
||||
if len(center_line) > head_shrink_num + tail_shrink_num + 2:
|
||||
center_line = center_line[head_shrink_num:len(center_line) -
|
||||
tail_shrink_num]
|
||||
resampled_top_line = resampled_top_line[
|
||||
head_shrink_num:len(resampled_top_line) - tail_shrink_num]
|
||||
resampled_bot_line = resampled_bot_line[
|
||||
head_shrink_num:len(resampled_bot_line) - tail_shrink_num]
|
||||
center_lines.append(center_line.astype(np.int32))
|
||||
|
||||
self.draw_center_region_maps(resampled_top_line,
|
||||
resampled_bot_line, center_line,
|
||||
center_region_mask, top_height_map,
|
||||
bot_height_map, sin_map, cos_map,
|
||||
self.center_region_shrink_ratio)
|
||||
|
||||
return (center_lines, center_region_mask, top_height_map,
|
||||
bot_height_map, sin_map, cos_map)
|
||||
|
||||
def generate_comp_attribs_from_maps(self, center_lines, center_region_mask,
|
||||
top_height_map, bot_height_map,
|
||||
sin_map, cos_map, comp_shrink_ratio):
|
||||
"""Generate attributes of text components in accordance with text
|
||||
center lines and geometry attribute maps.
|
||||
|
||||
Args:
|
||||
center_lines (list[ndarray]): The list of text center lines.
|
||||
center_region_mask (ndarray): The text center region mask.
|
||||
top_height_map (ndarray): The distance map from each pixel in text
|
||||
center region to top sideline.
|
||||
bot_height_map (ndarray): The distance map from each pixel in text
|
||||
center region to bottom sideline.
|
||||
sin_map (ndarray): The sin(theta) map where theta is the angle
|
||||
between vector (top point - bottom point) and vector (1, 0).
|
||||
cos_map (ndarray): The cos(theta) map where theta is the angle
|
||||
between vector (top point - bottom point) and vector (1, 0).
|
||||
comp_shrink_ratio (float): The text component shrink ratio.
|
||||
|
||||
Returns:
|
||||
comp_attribs (ndarray): All text components attributes(x, y, h, w,
|
||||
cos, sin, comp_labels). The comp_labels of two text components
|
||||
from the same text instance are equal.
|
||||
"""
|
||||
|
||||
assert isinstance(center_lines, list)
|
||||
assert (center_region_mask.shape == top_height_map.shape ==
|
||||
bot_height_map.shape == sin_map.shape == cos_map.shape)
|
||||
assert isinstance(comp_shrink_ratio, float)
|
||||
|
||||
center_lines_mask = np.zeros_like(center_region_mask)
|
||||
cv2.polylines(center_lines_mask, center_lines, 0, (1, ), 1)
|
||||
center_lines_mask = center_lines_mask * center_region_mask
|
||||
comp_centers = np.argwhere(center_lines_mask > 0)
|
||||
|
||||
y = comp_centers[:, 0]
|
||||
x = comp_centers[:, 1]
|
||||
|
||||
top_height = top_height_map[y, x].reshape((-1, 1)) * comp_shrink_ratio
|
||||
bot_height = bot_height_map[y, x].reshape((-1, 1)) * comp_shrink_ratio
|
||||
sin = sin_map[y, x].reshape((-1, 1))
|
||||
cos = cos_map[y, x].reshape((-1, 1))
|
||||
|
||||
top_mid_x_offset = top_height * cos
|
||||
top_mid_y_offset = top_height * sin
|
||||
bot_mid_x_offset = bot_height * cos
|
||||
bot_mid_y_offset = bot_height * sin
|
||||
|
||||
top_mid_pnt = comp_centers + np.hstack(
|
||||
[top_mid_y_offset, top_mid_x_offset])
|
||||
bot_mid_pnt = comp_centers - np.hstack(
|
||||
[bot_mid_y_offset, bot_mid_x_offset])
|
||||
|
||||
width = (top_height + bot_height) * self.text_comp_ratio
|
||||
width = np.clip(width, self.min_width, self.max_width)
|
||||
|
||||
top_left = (top_mid_pnt -
|
||||
np.hstack([width * cos, -width * sin]))[:, ::-1]
|
||||
top_right = (top_mid_pnt +
|
||||
np.hstack([width * cos, -width * sin]))[:, ::-1]
|
||||
bot_right = (bot_mid_pnt +
|
||||
np.hstack([width * cos, -width * sin]))[:, ::-1]
|
||||
bot_left = (bot_mid_pnt -
|
||||
np.hstack([width * cos, -width * sin]))[:, ::-1]
|
||||
text_comps = np.hstack([top_left, top_right, bot_right, bot_left])
|
||||
|
||||
score = center_lines_mask[y, x].reshape((-1, 1))
|
||||
text_comps = np.hstack([text_comps, score]).astype(np.float32)
|
||||
# text_comps = la_nms(text_comps, self.text_comp_nms_thr)
|
||||
|
||||
if text_comps.shape[0] < 1:
|
||||
return None
|
||||
|
||||
img_h, img_w = center_region_mask.shape
|
||||
text_comps[:, 0:8:2] = np.clip(text_comps[:, 0:8:2], 0, img_w - 1)
|
||||
text_comps[:, 1:8:2] = np.clip(text_comps[:, 1:8:2], 0, img_h - 1)
|
||||
|
||||
comp_centers = np.mean(
|
||||
text_comps[:, 0:8].reshape((-1, 4, 2)), axis=1).astype(np.int32)
|
||||
x = comp_centers[:, 0]
|
||||
y = comp_centers[:, 1]
|
||||
|
||||
height = (top_height_map[y, x] + bot_height_map[y, x]).reshape((-1, 1))
|
||||
width = np.clip(height * self.text_comp_ratio, self.min_width,
|
||||
self.max_width)
|
||||
|
||||
cos = cos_map[y, x].reshape((-1, 1))
|
||||
sin = sin_map[y, x].reshape((-1, 1))
|
||||
|
||||
_, comp_label_mask = cv2.connectedComponents(
|
||||
center_region_mask.astype(np.uint8), connectivity=8)
|
||||
comp_labels = comp_label_mask[y, x].reshape((-1, 1))
|
||||
|
||||
x = x.reshape((-1, 1))
|
||||
y = y.reshape((-1, 1))
|
||||
comp_attribs = np.hstack([x, y, height, width, cos, sin, comp_labels])
|
||||
|
||||
return comp_attribs
|
||||
|
||||
def generate_rand_comp_attribs(self, comp_num, center_sample_mask):
|
||||
"""Generate random text components and their attributes to ensure the
|
||||
the number text components in a text image is larger than k_hop1, which
|
||||
is the number of one hop neighbors in KNN graph.
|
||||
|
||||
Args:
|
||||
comp_num (int): The number of random text components.
|
||||
center_sample_mask (ndarray): The text component centers sampling
|
||||
region mask.
|
||||
|
||||
Returns:
|
||||
rand_comp_attribs (ndarray): The random text components
|
||||
attributes(x, y, h, w, cos, sin, belong_instance_label=0).
|
||||
"""
|
||||
|
||||
assert isinstance(comp_num, int)
|
||||
assert comp_num > 0
|
||||
assert center_sample_mask.ndim == 2
|
||||
|
||||
h, w = center_sample_mask.shape
|
||||
|
||||
max_rand_half_height = self.max_rand_half_height
|
||||
min_rand_half_height = self.min_rand_half_height
|
||||
max_rand_height = max_rand_half_height * 2
|
||||
max_rand_width = np.clip(max_rand_height * self.text_comp_ratio,
|
||||
self.min_width, self.max_width)
|
||||
margin = int(
|
||||
np.sqrt((max_rand_height / 2)**2 + (max_rand_width / 2)**2)) + 1
|
||||
|
||||
if 2 * margin + 1 > min(h, w):
|
||||
|
||||
assert min(h, w) > (np.sqrt(2) * (self.min_width + 1))
|
||||
max_rand_half_height = max(min(h, w) / 4, self.min_width / 2 + 1)
|
||||
min_rand_half_height = max(max_rand_half_height / 4,
|
||||
self.min_width / 2)
|
||||
|
||||
max_rand_height = max_rand_half_height * 2
|
||||
max_rand_width = np.clip(max_rand_height * self.text_comp_ratio,
|
||||
self.min_width, self.max_width)
|
||||
margin = int(
|
||||
np.sqrt((max_rand_height / 2)**2 +
|
||||
(max_rand_width / 2)**2)) + 1
|
||||
|
||||
inner_center_sample_mask = np.zeros_like(center_sample_mask)
|
||||
inner_center_sample_mask[margin:h-margin, margin:w-margin] = \
|
||||
center_sample_mask[margin:h - margin, margin:w - margin]
|
||||
kernel_size = int(min(max(max_rand_half_height, 7), 17))
|
||||
inner_center_sample_mask = cv2.erode(
|
||||
inner_center_sample_mask,
|
||||
np.ones((kernel_size, kernel_size), np.uint8))
|
||||
|
||||
center_candidates = np.argwhere(inner_center_sample_mask > 0)
|
||||
center_candidate_num = len(center_candidates)
|
||||
sample_inx = np.random.choice(center_candidate_num, comp_num)
|
||||
rand_centers = center_candidates[sample_inx]
|
||||
|
||||
rand_top_height = np.random.randint(
|
||||
min_rand_half_height,
|
||||
max_rand_half_height,
|
||||
size=(len(rand_centers), 1))
|
||||
rand_bot_height = np.random.randint(
|
||||
min_rand_half_height,
|
||||
max_rand_half_height,
|
||||
size=(len(rand_centers), 1))
|
||||
|
||||
rand_cos = 2 * np.random.random(size=(len(rand_centers), 1)) - 1
|
||||
rand_sin = 2 * np.random.random(size=(len(rand_centers), 1)) - 1
|
||||
scale = np.sqrt(1.0 / (rand_cos**2 + rand_sin**2 + 1e-8))
|
||||
rand_cos = rand_cos * scale
|
||||
rand_sin = rand_sin * scale
|
||||
|
||||
height = (rand_top_height + rand_bot_height)
|
||||
width = np.clip(height * self.text_comp_ratio, self.min_width,
|
||||
self.max_width)
|
||||
|
||||
rand_comp_attribs = np.hstack([
|
||||
rand_centers[:, ::-1], height, width, rand_cos, rand_sin,
|
||||
np.zeros_like(rand_sin)
|
||||
])
|
||||
|
||||
return rand_comp_attribs
|
||||
|
||||
def jitter_comp_attribs(self, comp_attribs, jitter_level):
|
||||
"""Jitter text components attributes.
|
||||
|
||||
Args:
|
||||
comp_attribs (ndarray): The text components attributes.
|
||||
jitter_level (float): The jitter level of text components
|
||||
attributes.
|
||||
|
||||
Returns:
|
||||
jittered_comp_attribs (ndarray): The jittered text components
|
||||
attributes(x, y, h, w, cos, sin, belong_instance_label).
|
||||
"""
|
||||
|
||||
assert comp_attribs.shape[1] == 7
|
||||
assert comp_attribs.shape[0] > 0
|
||||
assert isinstance(jitter_level, float)
|
||||
|
||||
x = comp_attribs[:, 0].reshape((-1, 1))
|
||||
y = comp_attribs[:, 1].reshape((-1, 1))
|
||||
h = comp_attribs[:, 2].reshape((-1, 1))
|
||||
w = comp_attribs[:, 3].reshape((-1, 1))
|
||||
cos = comp_attribs[:, 4].reshape((-1, 1))
|
||||
sin = comp_attribs[:, 5].reshape((-1, 1))
|
||||
belong_label = comp_attribs[:, 6].reshape((-1, 1))
|
||||
|
||||
# max jitter offset of (x, y) should be
|
||||
# ((h * abs(cos) + w * abs(sin)) / 2,
|
||||
# (h * abs(sin) + w * abs(cos)) / 2)
|
||||
x += (np.random.random(size=(len(comp_attribs), 1)) -
|
||||
0.5) * (h * np.abs(cos) + w * np.abs(sin)) * jitter_level
|
||||
y += (np.random.random(size=(len(comp_attribs), 1)) -
|
||||
0.5) * (h * np.abs(sin) + w * np.abs(cos)) * jitter_level
|
||||
|
||||
# max jitter offset of (h, w) should be (h, w)
|
||||
h += (np.random.random(size=(len(comp_attribs), 1)) -
|
||||
0.5) * h * jitter_level
|
||||
w += (np.random.random(size=(len(comp_attribs), 1)) -
|
||||
0.5) * w * jitter_level
|
||||
|
||||
# max jitter offset of (cos, sin) should be (1, 1)
|
||||
cos += (np.random.random(size=(len(comp_attribs), 1)) -
|
||||
0.5) * 2 * jitter_level
|
||||
sin += (np.random.random(size=(len(comp_attribs), 1)) -
|
||||
0.5) * 2 * jitter_level
|
||||
|
||||
scale = np.sqrt(1.0 / (cos**2 + sin**2 + 1e-8))
|
||||
cos = cos * scale
|
||||
sin = sin * scale
|
||||
|
||||
jittered_comp_attribs = np.hstack([x, y, h, w, cos, sin, belong_label])
|
||||
|
||||
return jittered_comp_attribs
|
||||
|
||||
def generate_comp_attribs(self, center_lines, text_mask,
|
||||
center_region_mask, top_height_map,
|
||||
bot_height_map, sin_map, cos_map):
|
||||
"""Generate text components attributes.
|
||||
|
||||
Args:
|
||||
center_lines (list[ndarray]): The text center lines list.
|
||||
text_mask (ndarray): The text region mask.
|
||||
center_region_mask (ndarray): The text center region mask.
|
||||
top_height_map (ndarray): The distance map from each pixel in text
|
||||
center region to top sideline.
|
||||
bot_height_map (ndarray): The distance map from each pixel in text
|
||||
center region to bottom sideline.
|
||||
sin_map (ndarray): The sin(theta) map where theta is the angle
|
||||
between vector (top point - bottom point) and vector (1, 0).
|
||||
cos_map (ndarray): The cos(theta) map where theta is the angle
|
||||
between vector (top point - bottom point) and vector (1, 0).
|
||||
|
||||
Returns:
|
||||
pad_comp_attribs (ndarray): The padded text components attributes
|
||||
with a fixed size.
|
||||
"""
|
||||
|
||||
assert isinstance(center_lines, list)
|
||||
assert (text_mask.shape == center_region_mask.shape ==
|
||||
top_height_map.shape == bot_height_map.shape == sin_map.shape
|
||||
== cos_map.shape)
|
||||
|
||||
comp_attribs = self.generate_comp_attribs_from_maps(
|
||||
center_lines, center_region_mask, top_height_map, bot_height_map,
|
||||
sin_map, cos_map, self.comp_shrink_ratio)
|
||||
|
||||
if comp_attribs is not None:
|
||||
comp_attribs = self.jitter_comp_attribs(comp_attribs,
|
||||
self.jitter_level)
|
||||
|
||||
if comp_attribs.shape[0] < self.min_comp_num:
|
||||
rand_sample_num = self.min_comp_num - comp_attribs.shape[0]
|
||||
rand_comp_attribs = self.generate_rand_comp_attribs(
|
||||
rand_sample_num, 1 - text_mask)
|
||||
comp_attribs = np.vstack([comp_attribs, rand_comp_attribs])
|
||||
else:
|
||||
comp_attribs = self.generate_rand_comp_attribs(
|
||||
self.min_comp_num, 1 - text_mask)
|
||||
|
||||
comp_num = (
|
||||
np.ones((comp_attribs.shape[0], 1), dtype=np.float32) *
|
||||
comp_attribs.shape[0])
|
||||
comp_attribs = np.hstack([comp_num, comp_attribs])
|
||||
|
||||
if comp_attribs.shape[0] > self.max_comp_num:
|
||||
comp_attribs = comp_attribs[:self.max_comp_num, :]
|
||||
comp_attribs[:, 0] = self.max_comp_num
|
||||
|
||||
pad_comp_attribs = np.zeros((self.max_comp_num, comp_attribs.shape[1]))
|
||||
pad_comp_attribs[:comp_attribs.shape[0], :] = comp_attribs
|
||||
|
||||
return pad_comp_attribs
|
||||
|
||||
def generate_targets(self, results):
|
||||
"""Generate the gt targets for DRRG.
|
||||
|
||||
Args:
|
||||
results (dict): The input result dictionary.
|
||||
|
||||
Returns:
|
||||
results (dict): The output result dictionary.
|
||||
"""
|
||||
|
||||
assert isinstance(results, dict)
|
||||
|
||||
polygon_masks = results['gt_masks'].masks
|
||||
polygon_masks_ignore = results['gt_masks_ignore'].masks
|
||||
|
||||
h, w, _ = results['img_shape']
|
||||
|
||||
gt_text_mask = self.generate_text_region_mask((h, w), polygon_masks)
|
||||
gt_mask = self.generate_effective_mask((h, w), polygon_masks_ignore)
|
||||
(center_lines, gt_center_region_mask, gt_top_height_map,
|
||||
gt_bot_height_map, gt_sin_map,
|
||||
gt_cos_map) = self.generate_center_mask_attrib_maps((h, w),
|
||||
polygon_masks)
|
||||
|
||||
gt_comp_attribs = self.generate_comp_attribs(center_lines,
|
||||
gt_text_mask,
|
||||
gt_center_region_mask,
|
||||
gt_top_height_map,
|
||||
gt_bot_height_map,
|
||||
gt_sin_map, gt_cos_map)
|
||||
|
||||
results['mask_fields'].clear() # rm gt_masks encoded by polygons
|
||||
mapping = {
|
||||
'gt_text_mask': gt_text_mask,
|
||||
'gt_center_region_mask': gt_center_region_mask,
|
||||
'gt_mask': gt_mask,
|
||||
'gt_top_height_map': gt_top_height_map,
|
||||
'gt_bot_height_map': gt_bot_height_map,
|
||||
'gt_sin_map': gt_sin_map,
|
||||
'gt_cos_map': gt_cos_map
|
||||
}
|
||||
for key, value in mapping.items():
|
||||
value = value if isinstance(value, list) else [value]
|
||||
results[key] = BitmapMasks(value, h, w)
|
||||
results['mask_fields'].append(key)
|
||||
|
||||
results['gt_comp_attribs'] = gt_comp_attribs
|
||||
return results
|
|
@ -0,0 +1,453 @@
|
|||
import cv2
|
||||
import numpy as np
|
||||
from numpy.linalg import norm
|
||||
|
||||
import mmocr.utils.check_argument as check_argument
|
||||
from mmdet.core import BitmapMasks
|
||||
from mmdet.datasets.builder import PIPELINES
|
||||
from . import BaseTextDetTargets
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class TextSnakeTargets(BaseTextDetTargets):
|
||||
"""Generate the ground truth targets of TextSnake: TextSnake: A Flexible
|
||||
Representation for Detecting Text of Arbitrary Shapes.
|
||||
|
||||
[https://arxiv.org/abs/1807.01544]. This was partially adapted from
|
||||
https://github.com/princewang1994/TextSnake.pytorch.
|
||||
|
||||
Args:
|
||||
orientation_thr (float): The threshold for distinguishing between
|
||||
head edge and tail edge among the horizontal and vertical edges
|
||||
of a quadrangle.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
orientation_thr=2.0,
|
||||
resample_step=4.0,
|
||||
center_region_shrink_ratio=0.3):
|
||||
|
||||
super().__init__()
|
||||
self.orientation_thr = orientation_thr
|
||||
self.resample_step = resample_step
|
||||
self.center_region_shrink_ratio = center_region_shrink_ratio
|
||||
|
||||
def vector_angle(self, vec1, vec2):
|
||||
if vec1.ndim > 1:
|
||||
unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8).reshape((-1, 1))
|
||||
else:
|
||||
unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8)
|
||||
if vec2.ndim > 1:
|
||||
unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8).reshape((-1, 1))
|
||||
else:
|
||||
unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8)
|
||||
return np.arccos(
|
||||
np.clip(np.sum(unit_vec1 * unit_vec2, axis=-1), -1.0, 1.0))
|
||||
|
||||
def vector_slope(self, vec):
|
||||
assert len(vec) == 2
|
||||
return abs(vec[1] / (vec[0] + 1e-8))
|
||||
|
||||
def vector_sin(self, vec):
|
||||
assert len(vec) == 2
|
||||
return vec[1] / (norm(vec) + 1e-8)
|
||||
|
||||
def vector_cos(self, vec):
|
||||
assert len(vec) == 2
|
||||
return vec[0] / (norm(vec) + 1e-8)
|
||||
|
||||
def find_head_tail(self, points, orientation_thr):
|
||||
"""Find the head edge and tail edge of a text polygon.
|
||||
|
||||
Args:
|
||||
points (ndarray): The points composing a text polygon.
|
||||
orientation_thr (float): The threshold for distinguishing between
|
||||
head edge and tail edge among the horizontal and vertical edges
|
||||
of a quadrangle.
|
||||
|
||||
Returns:
|
||||
head_inds (list): The indexes of two points composing head edge.
|
||||
tail_inds (list): The indexes of two points composing tail edge.
|
||||
"""
|
||||
|
||||
assert points.ndim == 2
|
||||
assert points.shape[0] >= 4
|
||||
assert points.shape[1] == 2
|
||||
assert isinstance(orientation_thr, float)
|
||||
|
||||
if len(points) > 4:
|
||||
pad_points = np.vstack([points, points[0]])
|
||||
edge_vec = pad_points[1:] - pad_points[:-1]
|
||||
|
||||
theta_sum = []
|
||||
|
||||
for i, edge_vec1 in enumerate(edge_vec):
|
||||
adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]]
|
||||
adjacent_edge_vec = edge_vec[adjacent_ind]
|
||||
temp_theta_sum = np.sum(
|
||||
self.vector_angle(edge_vec1, adjacent_edge_vec))
|
||||
theta_sum.append(temp_theta_sum)
|
||||
theta_sum = np.array(theta_sum)
|
||||
head_start, tail_start = np.argsort(theta_sum)[::-1][0:2]
|
||||
|
||||
if abs(head_start - tail_start) < 2 \
|
||||
or abs(head_start - tail_start) > 12:
|
||||
tail_start = (head_start + len(points) // 2) % len(points)
|
||||
head_end = (head_start + 1) % len(points)
|
||||
tail_end = (tail_start + 1) % len(points)
|
||||
|
||||
if head_end > tail_end:
|
||||
head_start, tail_start = tail_start, head_start
|
||||
head_end, tail_end = tail_end, head_end
|
||||
head_inds = [head_start, head_end]
|
||||
tail_inds = [tail_start, tail_end]
|
||||
else:
|
||||
if self.vector_slope(points[1] - points[0]) + self.vector_slope(
|
||||
points[3] - points[2]) < self.vector_slope(
|
||||
points[2] - points[1]) + self.vector_slope(points[0] -
|
||||
points[3]):
|
||||
horizontal_edge_inds = [[0, 1], [2, 3]]
|
||||
vertical_edge_inds = [[3, 0], [1, 2]]
|
||||
else:
|
||||
horizontal_edge_inds = [[3, 0], [1, 2]]
|
||||
vertical_edge_inds = [[0, 1], [2, 3]]
|
||||
|
||||
vertical_len_sum = norm(points[vertical_edge_inds[0][0]] -
|
||||
points[vertical_edge_inds[0][1]]) + norm(
|
||||
points[vertical_edge_inds[1][0]] -
|
||||
points[vertical_edge_inds[1][1]])
|
||||
horizontal_len_sum = norm(
|
||||
points[horizontal_edge_inds[0][0]] -
|
||||
points[horizontal_edge_inds[0][1]]) + norm(
|
||||
points[horizontal_edge_inds[1][0]] -
|
||||
points[horizontal_edge_inds[1][1]])
|
||||
|
||||
if vertical_len_sum > horizontal_len_sum * orientation_thr:
|
||||
head_inds = horizontal_edge_inds[0]
|
||||
tail_inds = horizontal_edge_inds[1]
|
||||
else:
|
||||
head_inds = vertical_edge_inds[0]
|
||||
tail_inds = vertical_edge_inds[1]
|
||||
|
||||
return head_inds, tail_inds
|
||||
|
||||
def reorder_poly_edge(self, points):
|
||||
"""Get the respective points composing head edge, tail edge, top
|
||||
sideline and bottom sideline.
|
||||
|
||||
Args:
|
||||
points (ndarray): The points composing a text polygon.
|
||||
|
||||
Returns:
|
||||
head_edge (ndarray): The two points composing the head edge of text
|
||||
polygon.
|
||||
tail_edge (ndarray): The two points composing the tail edge of text
|
||||
polygon.
|
||||
top_sideline (ndarray): The points composing top curved sideline of
|
||||
text polygon.
|
||||
bot_sideline (ndarray): The points composing bottom curved sideline
|
||||
of text polygon.
|
||||
"""
|
||||
|
||||
assert points.ndim == 2
|
||||
assert points.shape[0] >= 4
|
||||
assert points.shape[1] == 2
|
||||
|
||||
head_inds, tail_inds = self.find_head_tail(points,
|
||||
self.orientation_thr)
|
||||
head_edge, tail_edge = points[head_inds], points[tail_inds]
|
||||
|
||||
pad_points = np.vstack([points, points])
|
||||
if tail_inds[1] < 1:
|
||||
tail_inds[1] = len(points)
|
||||
sideline1 = pad_points[head_inds[1]:tail_inds[1]]
|
||||
sideline2 = pad_points[tail_inds[1]:(head_inds[1] + len(points))]
|
||||
sideline_mean_shift = np.mean(
|
||||
sideline1, axis=0) - np.mean(
|
||||
sideline2, axis=0)
|
||||
|
||||
if sideline_mean_shift[1] > 0:
|
||||
top_sideline, bot_sideline = sideline2, sideline1
|
||||
else:
|
||||
top_sideline, bot_sideline = sideline1, sideline2
|
||||
|
||||
return head_edge, tail_edge, top_sideline, bot_sideline
|
||||
|
||||
def resample_line(self, line, n):
|
||||
"""Resample n points on a line.
|
||||
|
||||
Args:
|
||||
line (ndarray): The points composing a line.
|
||||
n (int): The resampled points number.
|
||||
|
||||
Returns:
|
||||
resampled_line (ndarray): The points composing the resampled line.
|
||||
"""
|
||||
|
||||
assert line.ndim == 2
|
||||
assert line.shape[0] >= 2
|
||||
assert line.shape[1] == 2
|
||||
assert isinstance(n, int)
|
||||
|
||||
length_list = [
|
||||
norm(line[i + 1] - line[i]) for i in range(len(line) - 1)
|
||||
]
|
||||
total_length = sum(length_list)
|
||||
length_cumsum = np.cumsum([0.0] + length_list)
|
||||
delta_length = total_length / (float(n) + 1e-8)
|
||||
|
||||
current_edge_ind = 0
|
||||
resampled_line = [line[0]]
|
||||
|
||||
for i in range(1, n):
|
||||
current_line_len = i * delta_length
|
||||
|
||||
while current_line_len >= length_cumsum[current_edge_ind + 1]:
|
||||
current_edge_ind += 1
|
||||
current_edge_end_shift = current_line_len - length_cumsum[
|
||||
current_edge_ind]
|
||||
end_shift_ratio = current_edge_end_shift / length_list[
|
||||
current_edge_ind]
|
||||
current_point = line[current_edge_ind] + (
|
||||
line[current_edge_ind + 1] -
|
||||
line[current_edge_ind]) * end_shift_ratio
|
||||
resampled_line.append(current_point)
|
||||
|
||||
resampled_line.append(line[-1])
|
||||
resampled_line = np.array(resampled_line)
|
||||
|
||||
return resampled_line
|
||||
|
||||
def resample_sidelines(self, sideline1, sideline2, resample_step):
|
||||
"""Resample two sidelines to be of the same points number according to
|
||||
step size.
|
||||
|
||||
Args:
|
||||
sideline1 (ndarray): The points composing a sideline of a text
|
||||
polygon.
|
||||
sideline2 (ndarray): The points composing another sideline of a
|
||||
text polygon.
|
||||
resample_step (float): The resampled step size.
|
||||
|
||||
Returns:
|
||||
resampled_line1 (ndarray): The resampled line 1.
|
||||
resampled_line2 (ndarray): The resampled line 2.
|
||||
"""
|
||||
|
||||
assert sideline1.ndim == sideline1.ndim == 2
|
||||
assert sideline1.shape[1] == sideline1.shape[1] == 2
|
||||
assert sideline1.shape[0] >= 2
|
||||
assert sideline2.shape[0] >= 2
|
||||
assert isinstance(resample_step, float)
|
||||
|
||||
length1 = sum([
|
||||
norm(sideline1[i + 1] - sideline1[i])
|
||||
for i in range(len(sideline1) - 1)
|
||||
])
|
||||
length2 = sum([
|
||||
norm(sideline2[i + 1] - sideline2[i])
|
||||
for i in range(len(sideline2) - 1)
|
||||
])
|
||||
|
||||
total_length = (length1 + length2) / 2
|
||||
resample_point_num = int(float(total_length) / resample_step)
|
||||
|
||||
resampled_line1 = self.resample_line(sideline1, resample_point_num)
|
||||
resampled_line2 = self.resample_line(sideline2, resample_point_num)
|
||||
|
||||
return resampled_line1, resampled_line2
|
||||
|
||||
def draw_center_region_maps(self, top_line, bot_line, center_line,
|
||||
center_region_mask, radius_map, sin_map,
|
||||
cos_map, region_shrink_ratio):
|
||||
"""Draw attributes on text center region.
|
||||
|
||||
Args:
|
||||
top_line (ndarray): The points composing top curved sideline of
|
||||
text polygon.
|
||||
bot_line (ndarray): The points composing bottom curved sideline
|
||||
of text polygon.
|
||||
center_line (ndarray): The points composing the center line of text
|
||||
instance.
|
||||
center_region_mask (ndarray): The text center region mask.
|
||||
radius_map (ndarray): The map where the distance from point to
|
||||
sidelines will be drawn on for each pixel in text center
|
||||
region.
|
||||
sin_map (ndarray): The map where vector_sin(theta) will be drawn
|
||||
on text center regions. Theta is the angle between tangent
|
||||
line and vector (1, 0).
|
||||
cos_map (ndarray): The map where vector_cos(theta) will be drawn on
|
||||
text center regions. Theta is the angle between tangent line
|
||||
and vector (1, 0).
|
||||
region_shrink_ratio (float): The shrink ratio of text center.
|
||||
"""
|
||||
|
||||
assert top_line.shape == bot_line.shape == center_line.shape
|
||||
assert (center_region_mask.shape == radius_map.shape == sin_map.shape
|
||||
== cos_map.shape)
|
||||
assert isinstance(region_shrink_ratio, float)
|
||||
for i in range(0, len(center_line) - 1):
|
||||
|
||||
top_mid_point = (top_line[i] + top_line[i + 1]) / 2
|
||||
bot_mid_point = (bot_line[i] + bot_line[i + 1]) / 2
|
||||
radius = norm(top_mid_point - bot_mid_point) / 2
|
||||
|
||||
text_direction = center_line[i + 1] - center_line[i]
|
||||
sin_theta = self.vector_sin(text_direction)
|
||||
cos_theta = self.vector_cos(text_direction)
|
||||
|
||||
pnt_tl = center_line[i] + (top_line[i] -
|
||||
center_line[i]) * region_shrink_ratio
|
||||
pnt_tr = center_line[i + 1] + (
|
||||
top_line[i + 1] - center_line[i + 1]) * region_shrink_ratio
|
||||
pnt_br = center_line[i + 1] + (
|
||||
bot_line[i + 1] - center_line[i + 1]) * region_shrink_ratio
|
||||
pnt_bl = center_line[i] + (bot_line[i] -
|
||||
center_line[i]) * region_shrink_ratio
|
||||
current_center_box = np.vstack([pnt_tl, pnt_tr, pnt_br,
|
||||
pnt_bl]).astype(np.int32)
|
||||
|
||||
cv2.fillPoly(center_region_mask, [current_center_box], color=1)
|
||||
cv2.fillPoly(sin_map, [current_center_box], color=sin_theta)
|
||||
cv2.fillPoly(cos_map, [current_center_box], color=cos_theta)
|
||||
cv2.fillPoly(radius_map, [current_center_box], color=radius)
|
||||
|
||||
def generate_center_mask_attrib_maps(self, img_size, text_polys):
|
||||
"""Generate text center region mask and geometric attribute maps.
|
||||
|
||||
Args:
|
||||
img_size (tuple): The image size of (height, width).
|
||||
text_polys (list[list[ndarray]]): The list of text polygons.
|
||||
|
||||
Returns:
|
||||
center_region_mask (ndarray): The text center region mask.
|
||||
radius_map (ndarray): The distance map from each pixel in text
|
||||
center region to top sideline.
|
||||
sin_map (ndarray): The sin(theta) map where theta is the angle
|
||||
between vector (top point - bottom point) and vector (1, 0).
|
||||
cos_map (ndarray): The cos(theta) map where theta is the angle
|
||||
between vector (top point - bottom point) and vector (1, 0).
|
||||
"""
|
||||
|
||||
assert isinstance(img_size, tuple)
|
||||
assert check_argument.is_2dlist(text_polys)
|
||||
|
||||
h, w = img_size
|
||||
|
||||
center_region_mask = np.zeros((h, w), np.uint8)
|
||||
radius_map = np.zeros((h, w), dtype=np.float32)
|
||||
sin_map = np.zeros((h, w), dtype=np.float32)
|
||||
cos_map = np.zeros((h, w), dtype=np.float32)
|
||||
|
||||
for poly in text_polys:
|
||||
assert len(poly) == 1
|
||||
text_instance = [[poly[0][i], poly[0][i + 1]]
|
||||
for i in range(0, len(poly[0]), 2)]
|
||||
polygon_points = np.array(
|
||||
text_instance, dtype=np.int).reshape(-1, 2)
|
||||
|
||||
_, _, top_line, bot_line = self.reorder_poly_edge(polygon_points)
|
||||
resampled_top_line, resampled_bot_line = self.resample_sidelines(
|
||||
top_line, bot_line, self.resample_step)
|
||||
resampled_bot_line = resampled_bot_line[::-1]
|
||||
center_line = (resampled_top_line + resampled_bot_line) / 2
|
||||
|
||||
if self.vector_slope(center_line[-1] - center_line[0]) > 0.9:
|
||||
if (center_line[-1] - center_line[0])[1] < 0:
|
||||
center_line = center_line[::-1]
|
||||
resampled_top_line = resampled_top_line[::-1]
|
||||
resampled_bot_line = resampled_bot_line[::-1]
|
||||
else:
|
||||
if (center_line[-1] - center_line[0])[0] < 0:
|
||||
center_line = center_line[::-1]
|
||||
resampled_top_line = resampled_top_line[::-1]
|
||||
resampled_bot_line = resampled_bot_line[::-1]
|
||||
|
||||
line_head_shrink_len = norm(resampled_top_line[0] -
|
||||
resampled_bot_line[0]) / 4.0
|
||||
line_tail_shrink_len = norm(resampled_top_line[-1] -
|
||||
resampled_bot_line[-1]) / 4.0
|
||||
head_shrink_num = int(line_head_shrink_len // self.resample_step)
|
||||
tail_shrink_num = int(line_tail_shrink_len // self.resample_step)
|
||||
|
||||
if len(center_line) > head_shrink_num + tail_shrink_num + 2:
|
||||
center_line = center_line[head_shrink_num:len(center_line) -
|
||||
tail_shrink_num]
|
||||
resampled_top_line = resampled_top_line[
|
||||
head_shrink_num:len(resampled_top_line) - tail_shrink_num]
|
||||
resampled_bot_line = resampled_bot_line[
|
||||
head_shrink_num:len(resampled_bot_line) - tail_shrink_num]
|
||||
|
||||
self.draw_center_region_maps(resampled_top_line,
|
||||
resampled_bot_line, center_line,
|
||||
center_region_mask, radius_map,
|
||||
sin_map, cos_map,
|
||||
self.center_region_shrink_ratio)
|
||||
|
||||
return center_region_mask, radius_map, sin_map, cos_map
|
||||
|
||||
def generate_text_region_mask(self, img_size, text_polys):
|
||||
"""Generate text center region mask and geometry attribute maps.
|
||||
|
||||
Args:
|
||||
img_size (tuple): The image size (height, width).
|
||||
text_polys (list[list[ndarray]]): The list of text polygons.
|
||||
|
||||
Returns:
|
||||
text_region_mask (ndarray): The text region mask.
|
||||
"""
|
||||
|
||||
assert isinstance(img_size, tuple)
|
||||
assert check_argument.is_2dlist(text_polys)
|
||||
|
||||
h, w = img_size
|
||||
text_region_mask = np.zeros((h, w), dtype=np.uint8)
|
||||
|
||||
for poly in text_polys:
|
||||
assert len(poly) == 1
|
||||
text_instance = [[poly[0][i], poly[0][i + 1]]
|
||||
for i in range(0, len(poly[0]), 2)]
|
||||
polygon = np.array(text_instance, dtype=np.int).reshape((1, -1, 2))
|
||||
cv2.fillPoly(text_region_mask, polygon, 1)
|
||||
|
||||
return text_region_mask
|
||||
|
||||
def generate_targets(self, results):
|
||||
"""Generate the gt targets for TextSnake.
|
||||
|
||||
Args:
|
||||
results (dict): The input result dictionary.
|
||||
|
||||
Returns:
|
||||
results (dict): The output result dictionary.
|
||||
"""
|
||||
|
||||
assert isinstance(results, dict)
|
||||
|
||||
polygon_masks = results['gt_masks'].masks
|
||||
polygon_masks_ignore = results['gt_masks_ignore'].masks
|
||||
|
||||
h, w, _ = results['img_shape']
|
||||
|
||||
gt_text_mask = self.generate_text_region_mask((h, w), polygon_masks)
|
||||
gt_mask = self.generate_effective_mask((h, w), polygon_masks_ignore)
|
||||
|
||||
(gt_center_region_mask, gt_radius_map, gt_sin_map,
|
||||
gt_cos_map) = self.generate_center_mask_attrib_maps((h, w),
|
||||
polygon_masks)
|
||||
|
||||
results['mask_fields'].clear() # rm gt_masks encoded by polygons
|
||||
mapping = {
|
||||
'gt_text_mask': gt_text_mask,
|
||||
'gt_center_region_mask': gt_center_region_mask,
|
||||
'gt_mask': gt_mask,
|
||||
'gt_radius_map': gt_radius_map,
|
||||
'gt_sin_map': gt_sin_map,
|
||||
'gt_cos_map': gt_cos_map
|
||||
}
|
||||
for key, value in mapping.items():
|
||||
value = value if isinstance(value, list) else [value]
|
||||
results[key] = BitmapMasks(value, h, w)
|
||||
results['mask_fields'].append(key)
|
||||
|
||||
return results
|
|
@ -0,0 +1,193 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import normal_init
|
||||
|
||||
from mmdet.models.builder import HEADS, build_loss
|
||||
from mmocr.models.textdet.modules import (GCN, LocalGraphs,
|
||||
ProposalLocalGraphs,
|
||||
merge_text_comps)
|
||||
from .head_mixin import HeadMixin
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
class DRRGHead(HeadMixin, nn.Module):
|
||||
"""The class for DRRG head: Deep Relational Reasoning Graph Network for
|
||||
Arbitrary Shape Text Detection.
|
||||
|
||||
[https://arxiv.org/abs/2003.07493]
|
||||
|
||||
Args:
|
||||
k_at_hops (tuple(int)): The number of i-hop neighbors,
|
||||
i = 1, 2, ..., h.
|
||||
active_connection (int): The number of two hop neighbors deem as
|
||||
linked to a pivot.
|
||||
node_geo_feat_dim (int): The dimension of embedded geometric features
|
||||
of a component.
|
||||
pooling_scale (float): The spatial scale of RRoI-Aligning.
|
||||
pooling_output_size (tuple(int)): The size of RRoI-Aligning output.
|
||||
graph_filter_thr (float): The threshold to filter identical local
|
||||
graphs.
|
||||
comp_shrink_ratio (float): The shrink ratio of text components.
|
||||
nms_thr (float): The locality-aware NMS threshold.
|
||||
min_width (float): The minimum width of text components.
|
||||
max_width (float): The maximum width of text components.
|
||||
comp_ratio (float): The reciprocal of aspect ratio of text components.
|
||||
text_region_thr (float): The threshold for text region probability map.
|
||||
center_region_thr (float): The threshold of text center region
|
||||
probability map.
|
||||
center_region_area_thr (int): The threshold of filtering small-size
|
||||
text center region.
|
||||
link_thr (float): The threshold for connected components searching.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
k_at_hops=(8, 4),
|
||||
active_connection=3,
|
||||
node_geo_feat_dim=120,
|
||||
pooling_scale=1.0,
|
||||
pooling_output_size=(3, 4),
|
||||
graph_filter_thr=0.75,
|
||||
comp_shrink_ratio=0.95,
|
||||
nms_thr=0.25,
|
||||
min_width=8.0,
|
||||
max_width=24.0,
|
||||
comp_ratio=0.65,
|
||||
text_region_thr=0.6,
|
||||
center_region_thr=0.4,
|
||||
center_region_area_thr=100,
|
||||
link_thr=0.85,
|
||||
loss=dict(type='DRRGLoss'),
|
||||
train_cfg=None,
|
||||
test_cfg=None):
|
||||
super().__init__()
|
||||
|
||||
assert isinstance(in_channels, int)
|
||||
assert isinstance(k_at_hops, tuple)
|
||||
assert isinstance(active_connection, int)
|
||||
assert isinstance(node_geo_feat_dim, int)
|
||||
assert isinstance(pooling_scale, float)
|
||||
assert isinstance(pooling_output_size, tuple)
|
||||
assert isinstance(graph_filter_thr, float)
|
||||
assert isinstance(comp_shrink_ratio, float)
|
||||
assert isinstance(nms_thr, float)
|
||||
assert isinstance(min_width, float)
|
||||
assert isinstance(max_width, float)
|
||||
assert isinstance(comp_ratio, float)
|
||||
assert isinstance(center_region_area_thr, int)
|
||||
assert isinstance(link_thr, float)
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = 6
|
||||
self.downsample_ratio = 1.0
|
||||
self.k_at_hops = k_at_hops
|
||||
self.active_connection = active_connection
|
||||
self.node_geo_feat_dim = node_geo_feat_dim
|
||||
self.pooling_scale = pooling_scale
|
||||
self.pooling_output_size = pooling_output_size
|
||||
self.graph_filter_thr = graph_filter_thr
|
||||
self.comp_shrink_ratio = comp_shrink_ratio
|
||||
self.nms_thr = nms_thr
|
||||
self.min_width = min_width
|
||||
self.max_width = max_width
|
||||
self.comp_ratio = comp_ratio
|
||||
self.text_region_thr = text_region_thr
|
||||
self.center_region_thr = center_region_thr
|
||||
self.center_region_area_thr = center_region_area_thr
|
||||
self.link_thr = link_thr
|
||||
self.loss_module = build_loss(loss)
|
||||
self.train_cfg = train_cfg
|
||||
self.test_cfg = test_cfg
|
||||
|
||||
self.out_conv = nn.Conv2d(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.init_weights()
|
||||
|
||||
self.graph_train = LocalGraphs(self.k_at_hops, self.active_connection,
|
||||
self.node_geo_feat_dim,
|
||||
self.pooling_scale,
|
||||
self.pooling_output_size,
|
||||
self.graph_filter_thr)
|
||||
|
||||
self.graph_test = ProposalLocalGraphs(
|
||||
self.k_at_hops, self.active_connection, self.node_geo_feat_dim,
|
||||
self.pooling_scale, self.pooling_output_size, self.nms_thr,
|
||||
self.min_width, self.max_width, self.comp_shrink_ratio,
|
||||
self.comp_ratio, self.text_region_thr, self.center_region_thr,
|
||||
self.center_region_area_thr)
|
||||
|
||||
pool_w, pool_h = self.pooling_output_size
|
||||
gcn_in_dim = (pool_w * pool_h) * (
|
||||
self.in_channels + self.out_channels) + self.node_geo_feat_dim
|
||||
self.gcn = GCN(gcn_in_dim, 32)
|
||||
|
||||
def init_weights(self):
|
||||
normal_init(self.out_conv, mean=0, std=0.01)
|
||||
|
||||
def forward(self, inputs, text_comp_feats):
|
||||
|
||||
pred_maps = self.out_conv(inputs)
|
||||
|
||||
feat_maps = torch.cat([inputs, pred_maps], dim=1)
|
||||
node_feats, adjacent_matrices, knn_inx, gt_labels = self.graph_train(
|
||||
feat_maps, np.array(text_comp_feats))
|
||||
|
||||
gcn_pred = self.gcn(node_feats, adjacent_matrices, knn_inx)
|
||||
|
||||
return (pred_maps, (gcn_pred, gt_labels))
|
||||
|
||||
def single_test(self, feat_maps):
|
||||
|
||||
pred_maps = self.out_conv(feat_maps)
|
||||
feat_maps = torch.cat([feat_maps[0], pred_maps], dim=1)
|
||||
|
||||
none_flag, graph_data = self.graph_test(pred_maps, feat_maps)
|
||||
|
||||
(node_feats, adjacent_matrix, pivot_inx, knn_inx, local_graph_nodes,
|
||||
text_comps) = graph_data
|
||||
|
||||
if none_flag:
|
||||
return None, None, None
|
||||
|
||||
adjacent_matrix, pivot_inx, knn_inx = map(
|
||||
lambda x: x.to(feat_maps.device),
|
||||
(adjacent_matrix, pivot_inx, knn_inx))
|
||||
gcn_pred = self.gcn_model(node_feats, adjacent_matrix, knn_inx)
|
||||
|
||||
pred_labels = F.softmax(gcn_pred, dim=1)
|
||||
|
||||
edges = []
|
||||
scores = []
|
||||
local_graph_nodes = local_graph_nodes.long().squeeze().cpu().numpy()
|
||||
graph_num = node_feats.size(0)
|
||||
|
||||
for graph_inx in range(graph_num):
|
||||
pivot = pivot_inx[graph_inx].int().item()
|
||||
nodes = local_graph_nodes[graph_inx]
|
||||
for neighbor_inx, neighbor in enumerate(knn_inx[graph_inx]):
|
||||
neighbor = neighbor.item()
|
||||
edges.append([nodes[pivot], nodes[neighbor]])
|
||||
scores.append(pred_labels[graph_inx * (knn_inx.shape[1]) +
|
||||
neighbor_inx, 1].item())
|
||||
|
||||
edges = np.asarray(edges)
|
||||
scores = np.asarray(scores)
|
||||
|
||||
return edges, scores, text_comps
|
||||
|
||||
def get_boundary(self, edges, scores, text_comps):
|
||||
|
||||
boundaries = []
|
||||
if edges is not None:
|
||||
boundaries = merge_text_comps(edges, scores, text_comps,
|
||||
self.link_thr)
|
||||
|
||||
results = dict(boundary_result=boundaries)
|
||||
|
||||
return results
|
|
@ -0,0 +1,48 @@
|
|||
import torch.nn as nn
|
||||
from mmcv.cnn import normal_init
|
||||
|
||||
from mmdet.models.builder import HEADS, build_loss
|
||||
from . import HeadMixin
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
class TextSnakeHead(HeadMixin, nn.Module):
|
||||
"""The class for TextSnake head: TextSnake: A Flexible Representation for
|
||||
Detecting Text of Arbitrary Shapes.
|
||||
|
||||
[https://arxiv.org/abs/1807.01544]
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
decoding_type='textsnake',
|
||||
text_repr_type='poly',
|
||||
loss=dict(type='TextSnakeLoss'),
|
||||
train_cfg=None,
|
||||
test_cfg=None):
|
||||
super().__init__()
|
||||
|
||||
assert isinstance(in_channels, int)
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = 5
|
||||
self.downsample_ratio = 1.0
|
||||
self.decoding_type = decoding_type
|
||||
self.text_repr_type = text_repr_type
|
||||
self.loss_module = build_loss(loss)
|
||||
self.train_cfg = train_cfg
|
||||
self.test_cfg = test_cfg
|
||||
|
||||
self.out_conv = nn.Conv2d(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.init_weights()
|
||||
|
||||
def init_weights(self):
|
||||
normal_init(self.out_conv, mean=0, std=0.01)
|
||||
|
||||
def forward(self, inputs):
|
||||
outputs = self.out_conv(inputs)
|
||||
return outputs
|
|
@ -0,0 +1,38 @@
|
|||
from mmdet.models.builder import DETECTORS
|
||||
from . import SingleStageTextDetector, TextDetectorMixin
|
||||
|
||||
|
||||
@DETECTORS.register_module()
|
||||
class DRRG(TextDetectorMixin, SingleStageTextDetector):
|
||||
"""The class for implementing DRRG text detector: Deep Relational Reasoning
|
||||
Graph Network for Arbitrary Shape Text Detection.
|
||||
|
||||
[https://arxiv.org/abs/2003.07493]
|
||||
"""
|
||||
|
||||
def forward_train(self, img, img_metas, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
img (Tensor): Input images of shape (N, C, H, W).
|
||||
Typically these should be mean centered and std scaled.
|
||||
img_metas (list[dict]): A list of image info dict where each dict
|
||||
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
||||
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
||||
For details of the values of these keys, see
|
||||
:class:`mmdet.datasets.pipelines.Collect`.
|
||||
Returns:
|
||||
dict[str, Tensor]: A dictionary of loss components.
|
||||
"""
|
||||
x = self.extract_feat(img)
|
||||
gt_comp_attribs = kwargs.pop('gt_comp_attribs')
|
||||
preds = self.bbox_head(x, gt_comp_attribs)
|
||||
losses = self.bbox_head.loss(preds, **kwargs)
|
||||
return losses
|
||||
|
||||
def simple_test(self, img, img_metas, rescale=False):
|
||||
|
||||
x = self.extract_feat(img)
|
||||
outs = self.bbox_head.single_test(x, img)
|
||||
boundaries = self.bbox_head.get_boundary(*outs, img_metas, rescale)
|
||||
|
||||
return [boundaries]
|
|
@ -0,0 +1,23 @@
|
|||
from mmdet.models.builder import DETECTORS
|
||||
from . import SingleStageTextDetector, TextDetectorMixin
|
||||
|
||||
|
||||
@DETECTORS.register_module()
|
||||
class TextSnake(TextDetectorMixin, SingleStageTextDetector):
|
||||
"""The class for implementing TextSnake text detector: TextSnake: A
|
||||
Flexible Representation for Detecting Text of Arbitrary Shapes.
|
||||
|
||||
[https://arxiv.org/abs/1807.01544]
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
backbone,
|
||||
neck,
|
||||
bbox_head,
|
||||
train_cfg=None,
|
||||
test_cfg=None,
|
||||
pretrained=None,
|
||||
show_score=False):
|
||||
SingleStageTextDetector.__init__(self, backbone, neck, bbox_head,
|
||||
train_cfg, test_cfg, pretrained)
|
||||
TextDetectorMixin.__init__(self, show_score)
|
|
@ -0,0 +1,206 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from mmdet.core import BitmapMasks
|
||||
from mmdet.models.builder import LOSSES
|
||||
from mmocr.utils import check_argument
|
||||
|
||||
|
||||
@LOSSES.register_module()
|
||||
class DRRGLoss(nn.Module):
|
||||
"""The class for implementing DRRG loss: Deep Relational Reasoning Graph
|
||||
Network for Arbitrary Shape Text Detection.
|
||||
|
||||
[https://arxiv.org/abs/1908.05900] This is partially adapted from
|
||||
https://github.com/GXYM/DRRG.
|
||||
"""
|
||||
|
||||
def __init__(self, ohem_ratio=3.0):
|
||||
"""Initialization.
|
||||
|
||||
Args:
|
||||
ohem_ratio (float): The negative/positive ratio in OHEM.
|
||||
"""
|
||||
super().__init__()
|
||||
self.ohem_ratio = ohem_ratio
|
||||
|
||||
def balance_bce_loss(self, pred, gt, mask):
|
||||
|
||||
assert pred.shape == gt.shape == mask.shape
|
||||
positive = gt * mask
|
||||
negative = (1 - gt) * mask
|
||||
positive_count = int(positive.float().sum())
|
||||
gt = gt.float()
|
||||
if positive_count > 0:
|
||||
loss = F.binary_cross_entropy(pred, gt, reduction='none')
|
||||
positive_loss = torch.sum(loss * positive.float())
|
||||
negative_loss = loss * negative.float()
|
||||
negative_count = min(
|
||||
int(negative.float().sum()),
|
||||
int(positive_count * self.ohem_ratio))
|
||||
else:
|
||||
positive_loss = torch.tensor(0.0)
|
||||
loss = F.binary_cross_entropy(pred, gt, reduction='none')
|
||||
negative_loss = loss * negative.float()
|
||||
negative_count = 100
|
||||
negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count)
|
||||
|
||||
balance_loss = (positive_loss + torch.sum(negative_loss)) / (
|
||||
float(positive_count + negative_count) + 1e-5)
|
||||
|
||||
return balance_loss
|
||||
|
||||
def gcn_loss(self, gcn_data):
|
||||
|
||||
gcn_pred, gt_labels = gcn_data
|
||||
gt_labels = gt_labels.view(-1).to(gcn_pred.device)
|
||||
loss = F.cross_entropy(gcn_pred, gt_labels)
|
||||
|
||||
return loss
|
||||
|
||||
def bitmasks2tensor(self, bitmasks, target_sz):
|
||||
"""Convert Bitmasks to tensor.
|
||||
|
||||
Args:
|
||||
bitmasks (list[BitmapMasks]): The BitmapMasks list. Each item is
|
||||
for one img.
|
||||
target_sz (tuple(int, int)): The target tensor size HxW.
|
||||
|
||||
Returns
|
||||
results (list[tensor]): The list of kernel tensors. Each
|
||||
element is for one kernel level.
|
||||
"""
|
||||
assert check_argument.is_type_list(bitmasks, BitmapMasks)
|
||||
assert isinstance(target_sz, tuple)
|
||||
|
||||
batch_size = len(bitmasks)
|
||||
num_masks = len(bitmasks[0])
|
||||
|
||||
results = []
|
||||
|
||||
for level_inx in range(num_masks):
|
||||
kernel = []
|
||||
for batch_inx in range(batch_size):
|
||||
mask = torch.from_numpy(bitmasks[batch_inx].masks[level_inx])
|
||||
# hxw
|
||||
mask_sz = mask.shape
|
||||
# left, right, top, bottom
|
||||
pad = [
|
||||
0, target_sz[1] - mask_sz[1], 0, target_sz[0] - mask_sz[0]
|
||||
]
|
||||
mask = F.pad(mask, pad, mode='constant', value=0)
|
||||
kernel.append(mask)
|
||||
kernel = torch.stack(kernel)
|
||||
results.append(kernel)
|
||||
|
||||
return results
|
||||
|
||||
def forward(self, preds, downsample_ratio, gt_text_mask,
|
||||
gt_center_region_mask, gt_mask, gt_top_height_map,
|
||||
gt_bot_height_map, gt_sin_map, gt_cos_map):
|
||||
|
||||
assert isinstance(preds, tuple)
|
||||
assert isinstance(downsample_ratio, float)
|
||||
assert abs(downsample_ratio - 1.0) < 1e-5
|
||||
assert check_argument.is_type_list(gt_text_mask, BitmapMasks)
|
||||
assert check_argument.is_type_list(gt_center_region_mask, BitmapMasks)
|
||||
assert check_argument.is_type_list(gt_mask, BitmapMasks)
|
||||
assert check_argument.is_type_list(gt_top_height_map, BitmapMasks)
|
||||
assert check_argument.is_type_list(gt_bot_height_map, BitmapMasks)
|
||||
assert check_argument.is_type_list(gt_sin_map, BitmapMasks)
|
||||
assert check_argument.is_type_list(gt_cos_map, BitmapMasks)
|
||||
|
||||
pred_maps, gcn_data = preds
|
||||
pred_text_region = pred_maps[:, 0, :, :]
|
||||
pred_center_region = pred_maps[:, 1, :, :]
|
||||
pred_sin_map = pred_maps[:, 2, :, :]
|
||||
pred_cos_map = pred_maps[:, 3, :, :]
|
||||
pred_top_height_map = pred_maps[:, 4, :, :]
|
||||
pred_bot_height_map = pred_maps[:, 5, :, :]
|
||||
feature_sz = pred_maps.size()
|
||||
|
||||
# bitmask 2 tensor
|
||||
mapping = {
|
||||
'gt_text_mask': gt_text_mask,
|
||||
'gt_center_region_mask': gt_center_region_mask,
|
||||
'gt_mask': gt_mask,
|
||||
'gt_top_height_map': gt_top_height_map,
|
||||
'gt_bot_height_map': gt_bot_height_map,
|
||||
'gt_sin_map': gt_sin_map,
|
||||
'gt_cos_map': gt_cos_map
|
||||
}
|
||||
gt = {}
|
||||
for key, value in mapping.items():
|
||||
gt[key] = value
|
||||
gt[key] = self.bitmasks2tensor(gt[key], feature_sz[2:])
|
||||
gt[key] = [item.to(pred_maps.device) for item in gt[key]]
|
||||
|
||||
scale = torch.sqrt(1.0 / (pred_sin_map**2 + pred_cos_map**2 + 1e-8))
|
||||
pred_sin_map = pred_sin_map * scale
|
||||
pred_cos_map = pred_cos_map * scale
|
||||
|
||||
loss_text = self.balance_bce_loss(
|
||||
torch.sigmoid(pred_text_region), gt['gt_text_mask'][0],
|
||||
gt['gt_mask'][0])
|
||||
|
||||
text_mask = (gt['gt_text_mask'][0] * gt['gt_mask'][0]).float()
|
||||
negative_text_mask = ((1 - gt['gt_text_mask'][0]) *
|
||||
gt['gt_mask'][0]).float()
|
||||
gt_center_region_mask = gt['gt_center_region_mask'][0].float()
|
||||
loss_center = F.binary_cross_entropy(
|
||||
torch.sigmoid(pred_center_region),
|
||||
gt_center_region_mask,
|
||||
reduction='none')
|
||||
if int(text_mask.sum()) > 0:
|
||||
loss_center_positive = torch.sum(
|
||||
loss_center * text_mask) / torch.sum(text_mask)
|
||||
else:
|
||||
loss_center_positive = torch.tensor(0.0)
|
||||
loss_center_negative = torch.sum(
|
||||
loss_center * negative_text_mask) / torch.sum(negative_text_mask)
|
||||
loss_center = loss_center_positive + 0.5 * loss_center_negative
|
||||
|
||||
center_mask = (gt['gt_center_region_mask'][0] *
|
||||
gt['gt_mask'][0]).float()
|
||||
if int(center_mask.sum()) > 0:
|
||||
ones = torch.ones_like(
|
||||
gt['gt_top_height_map'][0], dtype=torch.float)
|
||||
loss_top = F.smooth_l1_loss(
|
||||
pred_top_height_map / (gt['gt_top_height_map'][0] + 1e-2),
|
||||
ones,
|
||||
reduction='none')
|
||||
loss_bot = F.smooth_l1_loss(
|
||||
pred_bot_height_map / (gt['gt_bot_height_map'][0] + 1e-2),
|
||||
ones,
|
||||
reduction='none')
|
||||
gt_height = (
|
||||
gt['gt_top_height_map'][0] + gt['gt_bot_height_map'][0])
|
||||
loss_height = torch.sum(
|
||||
(torch.log(gt_height + 1) *
|
||||
(loss_top + loss_bot)) * center_mask) / torch.sum(center_mask)
|
||||
|
||||
loss_sin = torch.sum(
|
||||
F.smooth_l1_loss(
|
||||
pred_sin_map, gt['gt_sin_map'][0], reduction='none') *
|
||||
center_mask) / torch.sum(center_mask)
|
||||
loss_cos = torch.sum(
|
||||
F.smooth_l1_loss(
|
||||
pred_cos_map, gt['gt_cos_map'][0], reduction='none') *
|
||||
center_mask) / torch.sum(center_mask)
|
||||
else:
|
||||
loss_height = torch.tensor(0.0)
|
||||
loss_sin = torch.tensor(0.0)
|
||||
loss_cos = torch.tensor(0.0)
|
||||
|
||||
loss_gcn = self.gcn_loss(gcn_data)
|
||||
|
||||
results = dict(
|
||||
loss_text=loss_text,
|
||||
loss_center=loss_center,
|
||||
loss_height=loss_height,
|
||||
loss_sin=loss_sin,
|
||||
loss_cos=loss_cos,
|
||||
loss_gcn=loss_gcn)
|
||||
|
||||
return results
|
|
@ -0,0 +1,181 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from mmdet.core import BitmapMasks
|
||||
from mmdet.models.builder import LOSSES
|
||||
from mmocr.utils import check_argument
|
||||
|
||||
|
||||
@LOSSES.register_module()
|
||||
class TextSnakeLoss(nn.Module):
|
||||
"""The class for implementing TextSnake loss:
|
||||
TextSnake: A Flexible Representation for Detecting Text of Arbitrary Shapes
|
||||
[https://arxiv.org/abs/1807.01544].
|
||||
This is partially adapted from
|
||||
https://github.com/princewang1994/TextSnake.pytorch.
|
||||
"""
|
||||
|
||||
def __init__(self, ohem_ratio=3.0):
|
||||
"""Initialization.
|
||||
|
||||
Args:
|
||||
ohem_ratio (float): The negative/positive ratio in ohem.
|
||||
"""
|
||||
super().__init__()
|
||||
self.ohem_ratio = ohem_ratio
|
||||
|
||||
def balanced_bce_loss(self, pred, gt, mask):
|
||||
|
||||
assert pred.shape == gt.shape == mask.shape
|
||||
positive = gt * mask
|
||||
negative = (1 - gt) * mask
|
||||
positive_count = int(positive.float().sum())
|
||||
gt = gt.float()
|
||||
if positive_count > 0:
|
||||
loss = F.binary_cross_entropy(pred, gt, reduction='none')
|
||||
positive_loss = torch.sum(loss * positive.float())
|
||||
negative_loss = loss * negative.float()
|
||||
negative_count = min(
|
||||
int(negative.float().sum()),
|
||||
int(positive_count * self.ohem_ratio))
|
||||
else:
|
||||
positive_loss = torch.tensor(0.0, device=pred.device)
|
||||
loss = F.binary_cross_entropy(pred, gt, reduction='none')
|
||||
negative_loss = loss * negative.float()
|
||||
negative_count = 100
|
||||
negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count)
|
||||
|
||||
balance_loss = (positive_loss + torch.sum(negative_loss)) / (
|
||||
float(positive_count + negative_count) + 1e-5)
|
||||
|
||||
return balance_loss
|
||||
|
||||
def bitmasks2tensor(self, bitmasks, target_sz):
|
||||
"""Convert Bitmasks to tensor.
|
||||
|
||||
Args:
|
||||
bitmasks (list[BitmapMasks]): The BitmapMasks list. Each item is
|
||||
for one img.
|
||||
target_sz (tuple(int, int)): The target tensor size HxW.
|
||||
|
||||
Returns
|
||||
results (list[tensor]): The list of kernel tensors. Each
|
||||
element is for one kernel level.
|
||||
"""
|
||||
assert check_argument.is_type_list(bitmasks, BitmapMasks)
|
||||
assert isinstance(target_sz, tuple)
|
||||
|
||||
batch_size = len(bitmasks)
|
||||
num_masks = len(bitmasks[0])
|
||||
|
||||
results = []
|
||||
|
||||
for level_inx in range(num_masks):
|
||||
kernel = []
|
||||
for batch_inx in range(batch_size):
|
||||
mask = torch.from_numpy(bitmasks[batch_inx].masks[level_inx])
|
||||
# hxw
|
||||
mask_sz = mask.shape
|
||||
# left, right, top, bottom
|
||||
pad = [
|
||||
0, target_sz[1] - mask_sz[1], 0, target_sz[0] - mask_sz[0]
|
||||
]
|
||||
mask = F.pad(mask, pad, mode='constant', value=0)
|
||||
kernel.append(mask)
|
||||
kernel = torch.stack(kernel)
|
||||
results.append(kernel)
|
||||
|
||||
return results
|
||||
|
||||
def forward(self, pred_maps, downsample_ratio, gt_text_mask,
|
||||
gt_center_region_mask, gt_mask, gt_radius_map, gt_sin_map,
|
||||
gt_cos_map):
|
||||
|
||||
assert isinstance(downsample_ratio, float)
|
||||
assert check_argument.is_type_list(gt_text_mask, BitmapMasks)
|
||||
assert check_argument.is_type_list(gt_center_region_mask, BitmapMasks)
|
||||
assert check_argument.is_type_list(gt_mask, BitmapMasks)
|
||||
assert check_argument.is_type_list(gt_radius_map, BitmapMasks)
|
||||
assert check_argument.is_type_list(gt_sin_map, BitmapMasks)
|
||||
assert check_argument.is_type_list(gt_cos_map, BitmapMasks)
|
||||
|
||||
pred_text_region = pred_maps[:, 0, :, :]
|
||||
pred_center_region = pred_maps[:, 1, :, :]
|
||||
pred_sin_map = pred_maps[:, 2, :, :]
|
||||
pred_cos_map = pred_maps[:, 3, :, :]
|
||||
pred_radius_map = pred_maps[:, 4, :, :]
|
||||
feature_sz = pred_maps.size()
|
||||
device = pred_maps.device
|
||||
|
||||
# bitmask 2 tensor
|
||||
mapping = {
|
||||
'gt_text_mask': gt_text_mask,
|
||||
'gt_center_region_mask': gt_center_region_mask,
|
||||
'gt_mask': gt_mask,
|
||||
'gt_radius_map': gt_radius_map,
|
||||
'gt_sin_map': gt_sin_map,
|
||||
'gt_cos_map': gt_cos_map
|
||||
}
|
||||
gt = {}
|
||||
for key, value in mapping.items():
|
||||
gt[key] = value
|
||||
if abs(downsample_ratio - 1.0) < 1e-2:
|
||||
gt[key] = self.bitmasks2tensor(gt[key], feature_sz[2:])
|
||||
else:
|
||||
gt[key] = [item.rescale(downsample_ratio) for item in gt[key]]
|
||||
gt[key] = self.bitmasks2tensor(gt[key], feature_sz[2:])
|
||||
if key == 'gt_radius_map':
|
||||
gt[key] = [item * downsample_ratio for item in gt[key]]
|
||||
gt[key] = [item.to(device) for item in gt[key]]
|
||||
|
||||
scale = torch.sqrt(1.0 / (pred_sin_map**2 + pred_cos_map**2 + 1e-8))
|
||||
pred_sin_map = pred_sin_map * scale
|
||||
pred_cos_map = pred_cos_map * scale
|
||||
|
||||
loss_text = self.balanced_bce_loss(
|
||||
torch.sigmoid(pred_text_region), gt['gt_text_mask'][0],
|
||||
gt['gt_mask'][0])
|
||||
|
||||
text_mask = (gt['gt_text_mask'][0] * gt['gt_mask'][0]).float()
|
||||
loss_center_map = F.binary_cross_entropy(
|
||||
torch.sigmoid(pred_center_region),
|
||||
gt['gt_center_region_mask'][0].float(),
|
||||
reduction='none')
|
||||
if int(text_mask.sum()) > 0:
|
||||
loss_center = torch.sum(
|
||||
loss_center_map * text_mask) / torch.sum(text_mask)
|
||||
else:
|
||||
loss_center = torch.tensor(0.0, device=device)
|
||||
|
||||
center_mask = (gt['gt_center_region_mask'][0] *
|
||||
gt['gt_mask'][0]).float()
|
||||
if int(center_mask.sum()) > 0:
|
||||
map_sz = pred_radius_map.size()
|
||||
ones = torch.ones(map_sz, dtype=torch.float, device=device)
|
||||
loss_radius = torch.sum(
|
||||
F.smooth_l1_loss(
|
||||
pred_radius_map / (gt['gt_radius_map'][0] + 1e-2),
|
||||
ones,
|
||||
reduction='none') * center_mask) / torch.sum(center_mask)
|
||||
loss_sin = torch.sum(
|
||||
F.smooth_l1_loss(
|
||||
pred_sin_map, gt['gt_sin_map'][0], reduction='none') *
|
||||
center_mask) / torch.sum(center_mask)
|
||||
loss_cos = torch.sum(
|
||||
F.smooth_l1_loss(
|
||||
pred_cos_map, gt['gt_cos_map'][0], reduction='none') *
|
||||
center_mask) / torch.sum(center_mask)
|
||||
else:
|
||||
loss_radius = torch.tensor(0.0, device=device)
|
||||
loss_sin = torch.tensor(0.0, device=device)
|
||||
loss_cos = torch.tensor(0.0, device=device)
|
||||
|
||||
results = dict(
|
||||
loss_text=loss_text,
|
||||
loss_center=loss_center,
|
||||
loss_radius=loss_radius,
|
||||
loss_sin=loss_sin,
|
||||
loss_cos=loss_cos)
|
||||
|
||||
return results
|
|
@ -0,0 +1,6 @@
|
|||
from .gcn import GCN
|
||||
from .local_graph import LocalGraphs
|
||||
from .proposal_local_graph import ProposalLocalGraphs
|
||||
from .utils import merge_text_comps
|
||||
|
||||
__all__ = ['LocalGraphs', 'ProposalLocalGraphs', 'GCN', 'merge_text_comps']
|
|
@ -0,0 +1,80 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import init
|
||||
|
||||
|
||||
class MeanAggregator(nn.Module):
|
||||
|
||||
def forward(self, features, A):
|
||||
x = torch.bmm(A, features)
|
||||
return x
|
||||
|
||||
|
||||
class GraphConv(nn.Module):
|
||||
|
||||
def __init__(self, in_dim, out_dim):
|
||||
super(GraphConv, self).__init__()
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
self.weight = nn.Parameter(torch.FloatTensor(in_dim * 2, out_dim))
|
||||
self.bias = nn.Parameter(torch.FloatTensor(out_dim))
|
||||
init.xavier_uniform_(self.weight)
|
||||
init.constant_(self.bias, 0)
|
||||
self.agg = MeanAggregator()
|
||||
|
||||
def forward(self, features, A):
|
||||
b, n, d = features.shape
|
||||
assert d == self.in_dim
|
||||
agg_feats = self.agg(features, A)
|
||||
cat_feats = torch.cat([features, agg_feats], dim=2)
|
||||
out = torch.einsum('bnd,df->bnf', (cat_feats, self.weight))
|
||||
out = F.relu(out + self.bias)
|
||||
return out
|
||||
|
||||
|
||||
class GCN(nn.Module):
|
||||
"""Predict linkage between instances. This was from repo
|
||||
https://github.com/Zhongdao/gcn_clustering: Linkage Based Face Clustering
|
||||
via Graph Convolution Network.
|
||||
|
||||
[https://arxiv.org/abs/1903.11306]
|
||||
|
||||
Args:
|
||||
in_dim(int): The input dimension.
|
||||
out_dim(int): The output dimension.
|
||||
"""
|
||||
|
||||
def __init__(self, in_dim, out_dim):
|
||||
super(GCN, self).__init__()
|
||||
self.bn0 = nn.BatchNorm1d(in_dim, affine=False).float()
|
||||
self.conv1 = GraphConv(in_dim, 512)
|
||||
self.conv2 = GraphConv(512, 256)
|
||||
self.conv3 = GraphConv(256, 128)
|
||||
self.conv4 = GraphConv(128, 64)
|
||||
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Linear(64, out_dim), nn.PReLU(out_dim), nn.Linear(out_dim, 2))
|
||||
|
||||
def forward(self, x, A, one_hop_indexes, train=True):
|
||||
|
||||
B, N, D = x.shape
|
||||
|
||||
x = x.view(-1, D)
|
||||
x = self.bn0(x)
|
||||
x = x.view(B, N, D)
|
||||
|
||||
x = self.conv1(x, A)
|
||||
x = self.conv2(x, A)
|
||||
x = self.conv3(x, A)
|
||||
x = self.conv4(x, A)
|
||||
k1 = one_hop_indexes.size(-1)
|
||||
dout = x.size(-1)
|
||||
edge_feat = torch.zeros(B, k1, dout)
|
||||
for b in range(B):
|
||||
edge_feat[b, :, :] = x[b, one_hop_indexes[b]]
|
||||
edge_feat = edge_feat.view(-1, dout).to(x.device)
|
||||
pred = self.classifier(edge_feat)
|
||||
|
||||
# shape: (B*k1)x2
|
||||
return pred
|
|
@ -0,0 +1,307 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
from mmocr.models.utils import RROIAlign
|
||||
from .utils import (embed_geo_feats, euclidean_distance_matrix,
|
||||
normalize_adjacent_matrix)
|
||||
|
||||
|
||||
class LocalGraphs:
|
||||
"""Generate local graphs for GCN to predict which instance a text component
|
||||
belongs to. This was partially adapted from https://github.com/GXYM/DRRG:
|
||||
Deep Relational Reasoning Graph Network for Arbitrary Shape Text Detection.
|
||||
|
||||
[https://arxiv.org/abs/2003.07493]
|
||||
|
||||
Args:
|
||||
k_at_hops (tuple(int)): The number of h-hop neighbors.
|
||||
active_connection (int): The number of neighbors deem as linked to a
|
||||
pivot.
|
||||
node_geo_feat_dim (int): The dimension of embedded geometric features
|
||||
of a component.
|
||||
pooling_scale (float): The spatial scale of RRoI-Aligning.
|
||||
pooling_output_size (tuple(int)): The size of RRoI-Aligning output.
|
||||
local_graph_filter_thr (float): The threshold to filter out identical
|
||||
local graphs.
|
||||
"""
|
||||
|
||||
def __init__(self, k_at_hops, active_connection, node_geo_feat_dim,
|
||||
pooling_scale, pooling_output_size, local_graph_filter_thr):
|
||||
|
||||
assert isinstance(k_at_hops, tuple)
|
||||
assert isinstance(active_connection, int)
|
||||
assert isinstance(node_geo_feat_dim, int)
|
||||
assert isinstance(pooling_scale, float)
|
||||
assert isinstance(pooling_output_size, tuple)
|
||||
assert isinstance(local_graph_filter_thr, float)
|
||||
|
||||
self.k_at_hops = k_at_hops
|
||||
self.local_graph_depth = len(self.k_at_hops)
|
||||
self.active_connection = active_connection
|
||||
self.node_geo_feat_dim = node_geo_feat_dim
|
||||
self.pooling = RROIAlign(pooling_output_size, pooling_scale)
|
||||
self.local_graph_filter_thr = local_graph_filter_thr
|
||||
|
||||
def generate_local_graphs(self, sorted_complete_graph, gt_belong_labels):
|
||||
"""Generate local graphs for GCN to predict which instance a text
|
||||
component belongs to.
|
||||
|
||||
Args:
|
||||
sorted_complete_graph (ndarray): The complete graph where nodes are
|
||||
sorted according to their Euclidean distance.
|
||||
gt_belong_labels (ndarray): The ground truth labels define which
|
||||
instance text components (nodes in graphs) belong to.
|
||||
|
||||
Returns:
|
||||
local_graph_node_list (list): The list of local graph neighbors of
|
||||
pivots.
|
||||
knn_graph_neighbor_list (list): The list of k nearest neighbors of
|
||||
pivots.
|
||||
"""
|
||||
|
||||
assert sorted_complete_graph.ndim == 2
|
||||
assert (sorted_complete_graph.shape[0] ==
|
||||
sorted_complete_graph.shape[1] == gt_belong_labels.shape[0])
|
||||
|
||||
knn_graphs = sorted_complete_graph[:, :self.k_at_hops[0] + 1]
|
||||
local_graph_node_list = list()
|
||||
knn_graph_neighbor_list = list()
|
||||
for pivot_inx, knn_graph in enumerate(knn_graphs):
|
||||
|
||||
h_hop_neighbor_list = list()
|
||||
one_hop_neighbors = set(knn_graph[1:])
|
||||
h_hop_neighbor_list.append(one_hop_neighbors)
|
||||
|
||||
for hop_inx in range(1, self.local_graph_depth):
|
||||
h_hop_neighbor_list.append(set())
|
||||
for last_hop_neighbor_inx in h_hop_neighbor_list[-2]:
|
||||
h_hop_neighbor_list[-1].update(
|
||||
set(sorted_complete_graph[last_hop_neighbor_inx]
|
||||
[1:self.k_at_hops[hop_inx] + 1]))
|
||||
|
||||
hops_neighbor_set = set(
|
||||
[node for hop in h_hop_neighbor_list for node in hop])
|
||||
hops_neighbor_list = list(hops_neighbor_set)
|
||||
hops_neighbor_list.insert(0, pivot_inx)
|
||||
|
||||
if pivot_inx < 1:
|
||||
local_graph_node_list.append(hops_neighbor_list)
|
||||
knn_graph_neighbor_list.append(one_hop_neighbors)
|
||||
else:
|
||||
add_flag = True
|
||||
for graph_inx in range(len(knn_graph_neighbor_list)):
|
||||
knn_graph_neighbors = knn_graph_neighbor_list[graph_inx]
|
||||
local_graph_nodes = local_graph_node_list[graph_inx]
|
||||
|
||||
node_union_num = len(
|
||||
list(
|
||||
set(knn_graph_neighbors).union(
|
||||
set(one_hop_neighbors))))
|
||||
node_intersect_num = len(
|
||||
list(
|
||||
set(knn_graph_neighbors).intersection(
|
||||
set(one_hop_neighbors))))
|
||||
one_hop_iou = node_intersect_num / (node_union_num + 1e-8)
|
||||
|
||||
if (one_hop_iou > self.local_graph_filter_thr
|
||||
and pivot_inx in knn_graph_neighbors
|
||||
and gt_belong_labels[local_graph_nodes[0]]
|
||||
== gt_belong_labels[pivot_inx]
|
||||
and gt_belong_labels[local_graph_nodes[0]] != 0):
|
||||
add_flag = False
|
||||
break
|
||||
if add_flag:
|
||||
local_graph_node_list.append(hops_neighbor_list)
|
||||
knn_graph_neighbor_list.append(one_hop_neighbors)
|
||||
|
||||
return local_graph_node_list, knn_graph_neighbor_list
|
||||
|
||||
def generate_gcn_input(self, node_feat_batch, belong_label_batch,
|
||||
local_graph_node_batch, knn_graph_neighbor_batch,
|
||||
sorted_complete_graph):
|
||||
"""Generate graph convolution network input data.
|
||||
|
||||
Args:
|
||||
node_feat_batch (List[Tensor]): The node feature batch.
|
||||
belong_label_batch (List[ndarray]): The text component belong label
|
||||
batch.
|
||||
local_graph_node_batch (List[List[list]]): The local graph
|
||||
neighbors batch.
|
||||
knn_graph_neighbor_batch (List[List[set]]): The knn graph neighbor
|
||||
batch.
|
||||
sorted_complete_graph (List[ndarray]): The complete graph sorted
|
||||
according to the Euclidean distance.
|
||||
|
||||
Returns:
|
||||
node_feat_batch_tensor (Tensor): The node features of Graph
|
||||
Convolutional Network (GCN).
|
||||
adjacent_mat_batch_tensor (Tensor): The adjacent matrices.
|
||||
knn_inx_batch_tensor (Tensor): The indices of k nearest neighbors.
|
||||
gt_linkage_batch_tensor (Tensor): The surpervision signal of GCN
|
||||
for linkage prediction.
|
||||
"""
|
||||
|
||||
assert isinstance(node_feat_batch, list)
|
||||
assert isinstance(belong_label_batch, list)
|
||||
assert isinstance(local_graph_node_batch, list)
|
||||
assert isinstance(knn_graph_neighbor_batch, list)
|
||||
assert isinstance(sorted_complete_graph, list)
|
||||
|
||||
max_graph_node_num = max([
|
||||
len(local_graph_nodes)
|
||||
for local_graph_node_list in local_graph_node_batch
|
||||
for local_graph_nodes in local_graph_node_list
|
||||
])
|
||||
|
||||
node_feat_batch_list = list()
|
||||
adjacent_matrix_batch_list = list()
|
||||
knn_inx_batch_list = list()
|
||||
gt_linkage_batch_list = list()
|
||||
|
||||
for batch_inx, sorted_neighbors in enumerate(sorted_complete_graph):
|
||||
node_feats = node_feat_batch[batch_inx]
|
||||
local_graph_list = local_graph_node_batch[batch_inx]
|
||||
knn_graph_neighbor_list = knn_graph_neighbor_batch[batch_inx]
|
||||
belong_labels = belong_label_batch[batch_inx]
|
||||
|
||||
for graph_inx in range(len(local_graph_list)):
|
||||
local_graph_nodes = local_graph_list[graph_inx]
|
||||
local_graph_node_num = len(local_graph_nodes)
|
||||
pivot_inx = local_graph_nodes[0]
|
||||
knn_graph_neighbors = knn_graph_neighbor_list[graph_inx]
|
||||
node_to_graph_inx = {
|
||||
j: i
|
||||
for i, j in enumerate(local_graph_nodes)
|
||||
}
|
||||
|
||||
knn_inx_in_local_graph = torch.tensor(
|
||||
[node_to_graph_inx[i] for i in knn_graph_neighbors],
|
||||
dtype=torch.long)
|
||||
pivot_feats = node_feats[torch.tensor(
|
||||
pivot_inx, dtype=torch.long)]
|
||||
normalized_feats = node_feats[torch.tensor(
|
||||
local_graph_nodes, dtype=torch.long)] - pivot_feats
|
||||
|
||||
adjacent_matrix = np.zeros(
|
||||
(local_graph_node_num, local_graph_node_num))
|
||||
pad_normalized_feats = torch.cat([
|
||||
normalized_feats,
|
||||
torch.zeros(max_graph_node_num - local_graph_node_num,
|
||||
normalized_feats.shape[1]).to(
|
||||
node_feats.device)
|
||||
],
|
||||
dim=0)
|
||||
|
||||
for node in local_graph_nodes:
|
||||
neighbors = sorted_neighbors[node,
|
||||
1:self.active_connection + 1]
|
||||
for neighbor in neighbors:
|
||||
if neighbor in local_graph_nodes:
|
||||
adjacent_matrix[node_to_graph_inx[node],
|
||||
node_to_graph_inx[neighbor]] = 1
|
||||
adjacent_matrix[node_to_graph_inx[neighbor],
|
||||
node_to_graph_inx[node]] = 1
|
||||
|
||||
adjacent_matrix = normalize_adjacent_matrix(
|
||||
adjacent_matrix, type='DAD')
|
||||
adjacent_matrix_tensor = torch.zeros(max_graph_node_num,
|
||||
max_graph_node_num).to(
|
||||
node_feats.device)
|
||||
adjacent_matrix_tensor[:local_graph_node_num, :
|
||||
local_graph_node_num] = adjacent_matrix
|
||||
|
||||
local_graph_labels = torch.from_numpy(
|
||||
belong_labels[local_graph_nodes]).type(torch.long)
|
||||
knn_labels = local_graph_labels[knn_inx_in_local_graph]
|
||||
edge_labels = ((belong_labels[pivot_inx] == knn_labels)
|
||||
& (belong_labels[pivot_inx] > 0)).long()
|
||||
|
||||
node_feat_batch_list.append(pad_normalized_feats)
|
||||
adjacent_matrix_batch_list.append(adjacent_matrix_tensor)
|
||||
knn_inx_batch_list.append(knn_inx_in_local_graph)
|
||||
gt_linkage_batch_list.append(edge_labels)
|
||||
|
||||
node_feat_batch_tensor = torch.stack(node_feat_batch_list, 0)
|
||||
adjacent_mat_batch_tensor = torch.stack(adjacent_matrix_batch_list, 0)
|
||||
knn_inx_batch_tensor = torch.stack(knn_inx_batch_list, 0)
|
||||
gt_linkage_batch_tensor = torch.stack(gt_linkage_batch_list, 0)
|
||||
|
||||
return (node_feat_batch_tensor, adjacent_mat_batch_tensor,
|
||||
knn_inx_batch_tensor, gt_linkage_batch_tensor)
|
||||
|
||||
def __call__(self, feat_maps, comp_attribs):
|
||||
"""Generate local graphs.
|
||||
|
||||
Args:
|
||||
feat_maps (Tensor): The feature maps to propose node features in
|
||||
graph.
|
||||
comp_attribs (ndarray): The text components attributes.
|
||||
|
||||
Returns:
|
||||
node_feats_batch (Tensor): The node features of Graph Convolutional
|
||||
Network(GCN).
|
||||
adjacent_matrices_batch (Tensor): The adjacent matrices.
|
||||
knn_inx_batch (Tensor): The indices of k nearest neighbors.
|
||||
gt_linkage_batch (Tensor): The surpervision signal of GCN for
|
||||
linkage prediction.
|
||||
"""
|
||||
|
||||
assert isinstance(feat_maps, torch.Tensor)
|
||||
assert comp_attribs.shape[2] == 8
|
||||
|
||||
dist_sort_graph_batch_list = []
|
||||
local_graph_node_batch_list = []
|
||||
knn_graph_neighbor_batch_list = []
|
||||
node_feature_batch_list = []
|
||||
belong_label_batch_list = []
|
||||
|
||||
for batch_inx in range(comp_attribs.shape[0]):
|
||||
comp_num = int(comp_attribs[batch_inx, 0, 0])
|
||||
comp_geo_attribs = comp_attribs[batch_inx, :comp_num, 1:7]
|
||||
node_belong_labels = comp_attribs[batch_inx, :comp_num,
|
||||
7].astype(np.int32)
|
||||
|
||||
comp_centers = comp_geo_attribs[:, 0:2]
|
||||
distance_matrix = euclidean_distance_matrix(
|
||||
comp_centers, comp_centers)
|
||||
|
||||
graph_node_geo_feats = embed_geo_feats(comp_geo_attribs,
|
||||
self.node_geo_feat_dim)
|
||||
graph_node_geo_feats = torch.from_numpy(
|
||||
graph_node_geo_feats).float().to(feat_maps.device)
|
||||
|
||||
batch_id = np.zeros(
|
||||
(comp_geo_attribs.shape[0], 1), dtype=np.float32) * batch_inx
|
||||
text_comps = np.hstack(
|
||||
(batch_id, comp_geo_attribs.astype(np.float32)))
|
||||
text_comps = torch.from_numpy(text_comps).float().to(
|
||||
feat_maps.device)
|
||||
|
||||
comp_content_feats = self.pooling(
|
||||
feat_maps[batch_inx].unsqueeze(0), text_comps)
|
||||
comp_content_feats = comp_content_feats.view(
|
||||
comp_content_feats.shape[0], -1).to(feat_maps.device)
|
||||
node_feats = torch.cat((comp_content_feats, graph_node_geo_feats),
|
||||
dim=-1)
|
||||
|
||||
dist_sort_complete_graph = np.argsort(distance_matrix, axis=1)
|
||||
(local_graph_nodes,
|
||||
knn_graph_neighbors) = self.generate_local_graphs(
|
||||
dist_sort_complete_graph, node_belong_labels)
|
||||
|
||||
node_feature_batch_list.append(node_feats)
|
||||
belong_label_batch_list.append(node_belong_labels)
|
||||
local_graph_node_batch_list.append(local_graph_nodes)
|
||||
knn_graph_neighbor_batch_list.append(knn_graph_neighbors)
|
||||
dist_sort_graph_batch_list.append(dist_sort_complete_graph)
|
||||
|
||||
(node_feats_batch, adjacent_matrices_batch, knn_inx_batch,
|
||||
gt_linkage_batch) = \
|
||||
self.generate_gcn_input(node_feature_batch_list,
|
||||
belong_label_batch_list,
|
||||
local_graph_node_batch_list,
|
||||
knn_graph_neighbor_batch_list,
|
||||
dist_sort_graph_batch_list)
|
||||
|
||||
return (node_feats_batch, adjacent_matrices_batch, knn_inx_batch,
|
||||
gt_linkage_batch)
|
|
@ -0,0 +1,418 @@
|
|||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
# from mmocr.models.textdet.postprocess import la_nms
|
||||
from mmocr.models.utils import RROIAlign
|
||||
from .utils import (embed_geo_feats, euclidean_distance_matrix,
|
||||
normalize_adjacent_matrix)
|
||||
|
||||
|
||||
class ProposalLocalGraphs:
|
||||
"""Propose text components and generate local graphs. This was partially
|
||||
adapted from https://github.com/GXYM/DRRG: Deep Relational Reasoning Graph
|
||||
Network for Arbitrary Shape Text Detection.
|
||||
|
||||
[https://arxiv.org/abs/2003.07493]
|
||||
|
||||
Args:
|
||||
k_at_hops (tuple(int)): The number of i-hop neighbors,
|
||||
i = 1, 2, ..., h.
|
||||
active_connection (int): The number of two hop neighbors deem as linked
|
||||
to a pivot.
|
||||
node_geo_feat_dim (int): The dimension of embedded geometric features
|
||||
of a component.
|
||||
pooling_scale (float): The spatial scale of RRoI-Aligning.
|
||||
pooling_output_size (tuple(int)): The size of RRoI-Aligning output.
|
||||
nms_thr (float): The locality-aware NMS threshold.
|
||||
min_width (float): The minimum width of text components.
|
||||
max_width (float): The maximum width of text components.
|
||||
comp_shrink_ratio (float): The shrink ratio of text components.
|
||||
comp_ratio (float): The reciprocal of aspect ratio of text components.
|
||||
text_region_thr (float): The threshold for text region probability map.
|
||||
center_region_thr (float): The threshold for text center region
|
||||
probability map.
|
||||
center_region_area_thr (int): The threshold for filtering out
|
||||
small-size text center region.
|
||||
"""
|
||||
|
||||
def __init__(self, k_at_hops, active_connection, node_geo_feat_dim,
|
||||
pooling_scale, pooling_output_size, nms_thr, min_width,
|
||||
max_width, comp_shrink_ratio, comp_ratio, text_region_thr,
|
||||
center_region_thr, center_region_area_thr):
|
||||
|
||||
assert isinstance(k_at_hops, tuple)
|
||||
assert isinstance(active_connection, int)
|
||||
assert isinstance(node_geo_feat_dim, int)
|
||||
assert isinstance(pooling_scale, float)
|
||||
assert isinstance(pooling_output_size, tuple)
|
||||
assert isinstance(nms_thr, float)
|
||||
assert isinstance(min_width, float)
|
||||
assert isinstance(max_width, float)
|
||||
assert isinstance(comp_shrink_ratio, float)
|
||||
assert isinstance(comp_ratio, float)
|
||||
assert isinstance(text_region_thr, float)
|
||||
assert isinstance(center_region_thr, float)
|
||||
assert isinstance(center_region_area_thr, int)
|
||||
|
||||
self.k_at_hops = k_at_hops
|
||||
self.active_connection = active_connection
|
||||
self.local_graph_depth = len(self.k_at_hops)
|
||||
self.node_geo_feat_dim = node_geo_feat_dim
|
||||
self.pooling = RROIAlign(pooling_output_size, pooling_scale)
|
||||
self.nms_thr = nms_thr
|
||||
self.min_width = min_width
|
||||
self.max_width = max_width
|
||||
self.comp_shrink_ratio = comp_shrink_ratio
|
||||
self.comp_ratio = comp_ratio
|
||||
self.text_region_thr = text_region_thr
|
||||
self.center_region_thr = center_region_thr
|
||||
self.center_region_area_thr = center_region_area_thr
|
||||
|
||||
def fill_hole(self, input_mask):
|
||||
h, w = input_mask.shape
|
||||
canvas = np.zeros((h + 2, w + 2), np.uint8)
|
||||
canvas[1:h + 1, 1:w + 1] = input_mask.copy()
|
||||
|
||||
mask = np.zeros((h + 4, w + 4), np.uint8)
|
||||
|
||||
cv2.floodFill(canvas, mask, (0, 0), 1)
|
||||
canvas = canvas[1:h + 1, 1:w + 1].astype(np.bool)
|
||||
|
||||
return (~canvas | input_mask.astype(np.uint8))
|
||||
|
||||
def propose_comps(self, top_radius_map, bot_radius_map, sin_map, cos_map,
|
||||
score_map, min_width, max_width, comp_shrink_ratio,
|
||||
comp_ratio):
|
||||
"""Generate text components.
|
||||
|
||||
Args:
|
||||
top_radius_map (ndarray): The predicted distance map from each
|
||||
pixel in text center region to top sideline.
|
||||
bot_radius_map (ndarray): The predicted distance map from each
|
||||
pixel in text center region to bottom sideline.
|
||||
sin_map (ndarray): The predicted sin(theta) map.
|
||||
cos_map (ndarray): The predicted cos(theta) map.
|
||||
score_map (ndarray): The score map for NMS.
|
||||
min_width (float): The minimum width of text components.
|
||||
max_width (float): The maximum width of text components.
|
||||
comp_shrink_ratio (float): The shrink ratio of text components.
|
||||
comp_ratio (float): The reciprocal of aspect ratio of text
|
||||
components.
|
||||
|
||||
Returns:
|
||||
text_comps (ndarray): The text components.
|
||||
"""
|
||||
|
||||
comp_centers = np.argwhere(score_map > 0)
|
||||
comp_centers = comp_centers[np.argsort(comp_centers[:, 0])]
|
||||
y = comp_centers[:, 0]
|
||||
x = comp_centers[:, 1]
|
||||
|
||||
top_radius = top_radius_map[y, x].reshape((-1, 1)) * comp_shrink_ratio
|
||||
bot_radius = bot_radius_map[y, x].reshape((-1, 1)) * comp_shrink_ratio
|
||||
sin = sin_map[y, x].reshape((-1, 1))
|
||||
cos = cos_map[y, x].reshape((-1, 1))
|
||||
|
||||
top_mid_x_offset = top_radius * cos
|
||||
top_mid_y_offset = top_radius * sin
|
||||
bot_mid_x_offset = bot_radius * cos
|
||||
bot_mid_y_offset = bot_radius * sin
|
||||
|
||||
top_mid_pnt = comp_centers + np.hstack(
|
||||
[top_mid_y_offset, top_mid_x_offset])
|
||||
bot_mid_pnt = comp_centers - np.hstack(
|
||||
[bot_mid_y_offset, bot_mid_x_offset])
|
||||
|
||||
width = (top_radius + bot_radius) * comp_ratio
|
||||
width = np.clip(width, min_width, max_width)
|
||||
|
||||
top_left = top_mid_pnt - np.hstack([width * cos, -width * sin
|
||||
])[:, ::-1]
|
||||
top_right = top_mid_pnt + np.hstack([width * cos, -width * sin
|
||||
])[:, ::-1]
|
||||
bot_right = bot_mid_pnt + np.hstack([width * cos, -width * sin
|
||||
])[:, ::-1]
|
||||
bot_left = bot_mid_pnt - np.hstack([width * cos, -width * sin
|
||||
])[:, ::-1]
|
||||
|
||||
text_comps = np.hstack([top_left, top_right, bot_right, bot_left])
|
||||
score = score_map[y, x].reshape((-1, 1))
|
||||
text_comps = np.hstack([text_comps, score])
|
||||
|
||||
return text_comps
|
||||
|
||||
def propose_comps_and_attribs(self, text_region_map, center_region_map,
|
||||
top_radius_map, bot_radius_map, sin_map,
|
||||
cos_map):
|
||||
"""Generate text components and attributes.
|
||||
|
||||
Args:
|
||||
text_region_map (ndarray): The predicted text region probability
|
||||
map.
|
||||
center_region_map (ndarray): The predicted text center region
|
||||
probability map.
|
||||
top_radius_map (ndarray): The predicted distance map from each
|
||||
pixel in text center region to top sideline.
|
||||
bot_radius_map (ndarray): The predicted distance map from each
|
||||
pixel in text center region to bottom sideline.
|
||||
sin_map (ndarray): The predicted sin(theta) map.
|
||||
cos_map (ndarray): The predicted cos(theta) map.
|
||||
|
||||
Returns:
|
||||
comp_attribs (ndarray): The text components attributes.
|
||||
text_comps (ndarray): The text components.
|
||||
"""
|
||||
|
||||
assert (text_region_map.shape == center_region_map.shape ==
|
||||
top_radius_map.shape == bot_radius_map == sin_map.shape ==
|
||||
cos_map.shape)
|
||||
text_mask = text_region_map > self.text_region_thr
|
||||
center_region_mask = (center_region_map >
|
||||
self.center_region_thr) * text_mask
|
||||
|
||||
scale = np.sqrt(1.0 / (sin_map**2 + cos_map**2))
|
||||
sin_map, cos_map = sin_map * scale, cos_map * scale
|
||||
|
||||
center_region_mask = self.fill_hole(center_region_mask)
|
||||
center_region_contours, _ = cv2.findContours(
|
||||
center_region_mask.astype(np.uint8), cv2.RETR_TREE,
|
||||
cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
mask = np.zeros_like(center_region_mask)
|
||||
comp_list = []
|
||||
for contour in center_region_contours:
|
||||
current_center_mask = mask.copy()
|
||||
cv2.drawContours(current_center_mask, [contour], -1, 1, -1)
|
||||
if current_center_mask.sum() <= self.center_region_area_thr:
|
||||
continue
|
||||
score_map = text_region_map * current_center_mask
|
||||
|
||||
text_comp = self.propose_comps(top_radius_map, bot_radius_map,
|
||||
sin_map, cos_map, score_map,
|
||||
self.min_width, self.max_width,
|
||||
self.comp_shrink_ratio,
|
||||
self.comp_ratio)
|
||||
|
||||
# text_comp = la_nms(text_comp.astype('float32'), self.nms_thr)
|
||||
|
||||
text_comp_mask = mask.copy()
|
||||
text_comps_bboxes = text_comp[:, :8].reshape(
|
||||
(-1, 4, 2)).astype(np.int32)
|
||||
|
||||
cv2.drawContours(text_comp_mask, text_comps_bboxes, -1, 1, -1)
|
||||
if (text_comp_mask * text_mask).sum() < text_comp_mask.sum() * 0.5:
|
||||
continue
|
||||
|
||||
comp_list.append(text_comp)
|
||||
|
||||
if len(comp_list) <= 0:
|
||||
return None, None
|
||||
|
||||
text_comps = np.vstack(comp_list)
|
||||
|
||||
centers = np.mean(
|
||||
text_comps[:, :8].reshape((-1, 4, 2)), axis=1).astype(np.int32)
|
||||
|
||||
x = centers[:, 0]
|
||||
y = centers[:, 1]
|
||||
|
||||
h = top_radius_map[y, x].reshape(
|
||||
(-1, 1)) + bot_radius_map[y, x].reshape((-1, 1))
|
||||
w = np.clip(h * self.comp_ratio, self.min_width, self.max_width)
|
||||
sin = sin_map[y, x].reshape((-1, 1))
|
||||
cos = cos_map[y, x].reshape((-1, 1))
|
||||
x = x.reshape((-1, 1))
|
||||
y = y.reshape((-1, 1))
|
||||
comp_attribs = np.hstack([x, y, h, w, cos, sin])
|
||||
|
||||
return comp_attribs, text_comps
|
||||
|
||||
def generate_local_graphs(self, sorted_complete_graph, node_feats):
|
||||
"""Generate local graphs and Graph Convolution Network input data.
|
||||
|
||||
Args:
|
||||
sorted_complete_graph (ndarray): The complete graph where nodes are
|
||||
sorted according to their Euclidean distance.
|
||||
node_feats (tensor): The graph nodes features.
|
||||
|
||||
Returns:
|
||||
node_feats_tensor (tensor): The graph nodes features.
|
||||
adjacent_matrix_tensor (tensor): The adjacent matrix of graph.
|
||||
pivot_inx_tensor (tensor): The pivot indices in local graph.
|
||||
knn_inx_tensor (tensor): The k nearest neighbor nodes indexes in
|
||||
local graph.
|
||||
local_graph_node_tensor (tensor): The indices of nodes in local
|
||||
graph.
|
||||
"""
|
||||
|
||||
assert sorted_complete_graph.ndim == 2
|
||||
assert (sorted_complete_graph.shape[0] ==
|
||||
sorted_complete_graph.shape[1] == node_feats.shape[0])
|
||||
|
||||
knn_graphs = sorted_complete_graph[:, :self.k_at_hops[0] + 1]
|
||||
local_graph_node_list = list()
|
||||
knn_graph_neighbor_list = list()
|
||||
for pivot_inx, knn_graph in enumerate(knn_graphs):
|
||||
|
||||
h_hop_neighbor_list = list()
|
||||
one_hop_neighbors = set(knn_graph[1:])
|
||||
h_hop_neighbor_list.append(one_hop_neighbors)
|
||||
|
||||
for hop_inx in range(1, self.local_graph_depth):
|
||||
h_hop_neighbor_list.append(set())
|
||||
for last_hop_neighbor_inx in h_hop_neighbor_list[-2]:
|
||||
h_hop_neighbor_list[-1].update(
|
||||
set(sorted_complete_graph[last_hop_neighbor_inx]
|
||||
[1:self.k_at_hops[hop_inx] + 1]))
|
||||
|
||||
hops_neighbor_set = set(
|
||||
[node for hop in h_hop_neighbor_list for node in hop])
|
||||
hops_neighbor_list = list(hops_neighbor_set)
|
||||
hops_neighbor_list.insert(0, pivot_inx)
|
||||
|
||||
local_graph_node_list.append(hops_neighbor_list)
|
||||
knn_graph_neighbor_list.append(one_hop_neighbors)
|
||||
|
||||
max_graph_node_num = max([
|
||||
len(local_graph_nodes)
|
||||
for local_graph_nodes in local_graph_node_list
|
||||
])
|
||||
|
||||
node_normalized_feats = list()
|
||||
adjacent_matrix_list = list()
|
||||
knn_inx = list()
|
||||
pivot_graph_inx = list()
|
||||
local_graph_tensor_list = list()
|
||||
|
||||
for graph_inx in range(len(local_graph_node_list)):
|
||||
|
||||
local_graph_nodes = local_graph_node_list[graph_inx]
|
||||
local_graph_node_num = len(local_graph_nodes)
|
||||
pivot_inx = local_graph_nodes[0]
|
||||
knn_graph_neighbors = knn_graph_neighbor_list[graph_inx]
|
||||
node_to_graph_inx = {j: i for i, j in enumerate(local_graph_nodes)}
|
||||
|
||||
pivot_node_inx = torch.tensor([
|
||||
node_to_graph_inx[pivot_inx],
|
||||
]).type(torch.long)
|
||||
knn_inx_in_local_graph = torch.tensor(
|
||||
[node_to_graph_inx[i] for i in knn_graph_neighbors],
|
||||
dtype=torch.long)
|
||||
pivot_feats = node_feats[torch.tensor(pivot_inx, dtype=torch.long)]
|
||||
normalized_feats = node_feats[torch.tensor(
|
||||
local_graph_nodes, dtype=torch.long)] - pivot_feats
|
||||
|
||||
adjacent_matrix = np.zeros(
|
||||
(local_graph_node_num, local_graph_node_num))
|
||||
pad_normalized_feats = torch.cat([
|
||||
normalized_feats,
|
||||
torch.zeros(max_graph_node_num - local_graph_node_num,
|
||||
normalized_feats.shape[1]).to(node_feats.device)
|
||||
],
|
||||
dim=0)
|
||||
|
||||
for node in local_graph_nodes:
|
||||
neighbors = sorted_complete_graph[node,
|
||||
1:self.active_connection + 1]
|
||||
for neighbor in neighbors:
|
||||
if neighbor in local_graph_nodes:
|
||||
adjacent_matrix[node_to_graph_inx[node],
|
||||
node_to_graph_inx[neighbor]] = 1
|
||||
adjacent_matrix[node_to_graph_inx[neighbor],
|
||||
node_to_graph_inx[node]] = 1
|
||||
|
||||
adjacent_matrix = normalize_adjacent_matrix(
|
||||
adjacent_matrix, type='DAD')
|
||||
adjacent_matrix_tensor = torch.zeros(
|
||||
max_graph_node_num, max_graph_node_num).to(node_feats.device)
|
||||
adjacent_matrix_tensor[:local_graph_node_num, :
|
||||
local_graph_node_num] = adjacent_matrix
|
||||
|
||||
local_graph_tensor = torch.tensor(local_graph_nodes)
|
||||
local_graph_tensor = torch.cat([
|
||||
local_graph_tensor,
|
||||
torch.zeros(
|
||||
max_graph_node_num - local_graph_node_num,
|
||||
dtype=torch.long)
|
||||
],
|
||||
dim=0)
|
||||
|
||||
node_normalized_feats.append(pad_normalized_feats)
|
||||
adjacent_matrix_list.append(adjacent_matrix_tensor)
|
||||
pivot_graph_inx.append(pivot_node_inx)
|
||||
knn_inx.append(knn_inx_in_local_graph)
|
||||
local_graph_tensor_list.append(local_graph_tensor)
|
||||
|
||||
node_feats_tensor = torch.stack(node_normalized_feats, 0)
|
||||
adjacent_matrix_tensor = torch.stack(adjacent_matrix_list, 0)
|
||||
pivot_inx_tensor = torch.stack(pivot_graph_inx, 0)
|
||||
knn_inx_tensor = torch.stack(knn_inx, 0)
|
||||
local_graph_node_tensor = torch.stack(local_graph_tensor_list, 0)
|
||||
|
||||
return (node_feats_tensor, adjacent_matrix_tensor, pivot_inx_tensor,
|
||||
knn_inx_tensor, local_graph_node_tensor)
|
||||
|
||||
def __call__(self, preds, feat_maps):
|
||||
"""Generate local graphs and Graph Convolution Network input data.
|
||||
|
||||
Args:
|
||||
preds (tensor): The predicted maps.
|
||||
feat_maps (tensor): The feature maps to extract content features of
|
||||
text components.
|
||||
|
||||
Returns:
|
||||
node_feats_tensor (tensor): The graph nodes features.
|
||||
adjacent_matrix_tensor (tensor): The adjacent matrix of graph.
|
||||
pivot_inx_tensor (tensor): The pivot indices in local graph.
|
||||
knn_inx_tensor (tensor): The k nearest neighbor nodes indices in
|
||||
local graph.
|
||||
local_graph_node_tensor (tensor): The indices of nodes in local
|
||||
graph.
|
||||
text_comps (ndarray): The predicted text components.
|
||||
"""
|
||||
|
||||
pred_text_region = torch.sigmoid(preds[0, 0]).data.cpu().numpy()
|
||||
pred_center_region = torch.sigmoid(preds[0, 1]).data.cpu().numpy()
|
||||
pred_sin_map = preds[0, 2].data.cpu().numpy()
|
||||
pred_cos_map = preds[0, 3].data.cpu().numpy()
|
||||
pred_top_radius_map = preds[0, 4].data.cpu().numpy()
|
||||
pred_bot_radius_map = preds[0, 5].data.cpu().numpy()
|
||||
|
||||
comp_attribs, text_comps = self.propose_comps_and_attribs(
|
||||
pred_text_region, pred_center_region, pred_top_radius_map,
|
||||
pred_bot_radius_map, pred_sin_map, pred_cos_map)
|
||||
|
||||
if comp_attribs is None:
|
||||
none_flag = True
|
||||
return none_flag, (0, 0, 0, 0, 0, 0)
|
||||
|
||||
comp_centers = comp_attribs[:, 0:2]
|
||||
distance_matrix = euclidean_distance_matrix(comp_centers, comp_centers)
|
||||
|
||||
graph_node_geo_feats = embed_geo_feats(comp_attribs,
|
||||
self.node_geo_feat_dim)
|
||||
graph_node_geo_feats = torch.from_numpy(
|
||||
graph_node_geo_feats).float().to(preds.device)
|
||||
|
||||
batch_id = np.zeros((comp_attribs.shape[0], 1), dtype=np.float32)
|
||||
text_comps_bboxes = np.hstack(
|
||||
(batch_id, comp_attribs.astype(np.float32, copy=False)))
|
||||
text_comps_bboxes = torch.from_numpy(text_comps_bboxes).float().to(
|
||||
preds.device)
|
||||
|
||||
comp_content_feats = self.pooling(feat_maps, text_comps_bboxes)
|
||||
comp_content_feats = comp_content_feats.view(
|
||||
comp_content_feats.shape[0], -1).to(preds.device)
|
||||
node_feats = torch.cat((comp_content_feats, graph_node_geo_feats),
|
||||
dim=-1)
|
||||
|
||||
dist_sort_complete_graph = np.argsort(distance_matrix, axis=1)
|
||||
(node_feats_tensor, adjacent_matrix_tensor, pivot_inx_tensor,
|
||||
knn_inx_tensor, local_graph_node_tensor) = self.generate_local_graphs(
|
||||
dist_sort_complete_graph, node_feats)
|
||||
|
||||
none_flag = False
|
||||
return none_flag, (node_feats_tensor, adjacent_matrix_tensor,
|
||||
pivot_inx_tensor, knn_inx_tensor,
|
||||
local_graph_node_tensor, text_comps)
|
|
@ -0,0 +1,354 @@
|
|||
import functools
|
||||
import operator
|
||||
from typing import List
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from numpy.linalg import norm
|
||||
|
||||
|
||||
def normalize_adjacent_matrix(A, type='AD'):
|
||||
"""Normalize adjacent matrix for GCN.
|
||||
|
||||
This was from repo https://github.com/GXYM/DRRG.
|
||||
"""
|
||||
if type == 'DAD':
|
||||
# d is Degree of nodes A=A+I
|
||||
# L = D^-1/2 A D^-1/2
|
||||
A = A + np.eye(A.shape[0]) # A=A+I
|
||||
d = np.sum(A, axis=0)
|
||||
d_inv = np.power(d, -0.5).flatten()
|
||||
d_inv[np.isinf(d_inv)] = 0.0
|
||||
d_inv = np.diag(d_inv)
|
||||
G = A.dot(d_inv).transpose().dot(d_inv)
|
||||
G = torch.from_numpy(G)
|
||||
elif type == 'AD':
|
||||
A = A + np.eye(A.shape[0]) # A=A+I
|
||||
A = torch.from_numpy(A)
|
||||
D = A.sum(1, keepdim=True)
|
||||
G = A.div(D)
|
||||
else:
|
||||
A = A + np.eye(A.shape[0]) # A=A+I
|
||||
A = torch.from_numpy(A)
|
||||
D = A.sum(1, keepdim=True)
|
||||
D = np.diag(D)
|
||||
G = D - A
|
||||
return G
|
||||
|
||||
|
||||
def euclidean_distance_matrix(A, B):
|
||||
"""Calculate the Euclidean distance matrix."""
|
||||
|
||||
M = A.shape[0]
|
||||
N = B.shape[0]
|
||||
|
||||
assert A.shape[1] == B.shape[1]
|
||||
|
||||
A_dots = (A * A).sum(axis=1).reshape((M, 1)) * np.ones(shape=(1, N))
|
||||
B_dots = (B * B).sum(axis=1) * np.ones(shape=(M, 1))
|
||||
D_squared = A_dots + B_dots - 2 * A.dot(B.T)
|
||||
|
||||
zero_mask = np.less(D_squared, 0.0)
|
||||
D_squared[zero_mask] = 0.0
|
||||
return np.sqrt(D_squared)
|
||||
|
||||
|
||||
def embed_geo_feats(geo_feats, out_dim):
|
||||
"""Embed geometric features of text components. This was partially adapted
|
||||
from https://github.com/GXYM/DRRG.
|
||||
|
||||
Args:
|
||||
geo_feats (ndarray): The geometric features of text components.
|
||||
out_dim (int): The output dimension.
|
||||
|
||||
Returns:
|
||||
embedded_feats (ndarray): The embedded geometric features.
|
||||
"""
|
||||
assert isinstance(out_dim, int)
|
||||
assert out_dim >= geo_feats.shape[1]
|
||||
comp_num = geo_feats.shape[0]
|
||||
feat_dim = geo_feats.shape[1]
|
||||
feat_repeat_times = out_dim // feat_dim
|
||||
residue_dim = out_dim % feat_dim
|
||||
|
||||
if residue_dim > 0:
|
||||
embed_wave = np.array([
|
||||
np.power(1000, 2.0 * (j // 2) / feat_repeat_times + 1)
|
||||
for j in range(feat_repeat_times + 1)
|
||||
]).reshape((feat_repeat_times + 1, 1, 1))
|
||||
repeat_feats = np.repeat(
|
||||
np.expand_dims(geo_feats, axis=0), feat_repeat_times, axis=0)
|
||||
residue_feats = np.hstack([
|
||||
geo_feats[:, 0:residue_dim],
|
||||
np.zeros((comp_num, feat_dim - residue_dim))
|
||||
])
|
||||
repeat_feats = np.stack([repeat_feats, residue_feats], axis=0)
|
||||
embedded_feats = repeat_feats / embed_wave
|
||||
embedded_feats[:, 0::2] = np.sin(embedded_feats[:, 0::2])
|
||||
embedded_feats[:, 1::2] = np.cos(embedded_feats[:, 1::2])
|
||||
embedded_feats = np.transpose(embedded_feats, (1, 0, 2)).reshape(
|
||||
(comp_num, -1))[:, 0:out_dim]
|
||||
else:
|
||||
embed_wave = np.array([
|
||||
np.power(1000, 2.0 * (j // 2) / feat_repeat_times)
|
||||
for j in range(feat_repeat_times)
|
||||
]).reshape((feat_repeat_times, 1, 1))
|
||||
repeat_feats = np.repeat(
|
||||
np.expand_dims(geo_feats, axis=0), feat_repeat_times, axis=0)
|
||||
embedded_feats = repeat_feats / embed_wave
|
||||
embedded_feats[:, 0::2] = np.sin(embedded_feats[:, 0::2])
|
||||
embedded_feats[:, 1::2] = np.cos(embedded_feats[:, 1::2])
|
||||
embedded_feats = np.transpose(embedded_feats, (1, 0, 2)).reshape(
|
||||
(comp_num, -1))
|
||||
|
||||
return embedded_feats
|
||||
|
||||
|
||||
def min_connect_path(list_all: List[list]):
|
||||
"""This is from https://github.com/GXYM/DRRG."""
|
||||
|
||||
list_nodo = list_all.copy()
|
||||
res: List[List[int]] = []
|
||||
ept = [0, 0]
|
||||
|
||||
def norm2(a, b):
|
||||
return ((a[0] - b[0])**2 + (a[1] - b[1])**2)**0.5
|
||||
|
||||
dict00 = {}
|
||||
dict11 = {}
|
||||
ept[0] = list_nodo[0]
|
||||
ept[1] = list_nodo[0]
|
||||
list_nodo.remove(list_nodo[0])
|
||||
while list_nodo:
|
||||
for i in list_nodo:
|
||||
length0 = norm2(i, ept[0])
|
||||
dict00[length0] = [i, ept[0]]
|
||||
length1 = norm2(ept[1], i)
|
||||
dict11[length1] = [ept[1], i]
|
||||
key0 = min(dict00.keys())
|
||||
key1 = min(dict11.keys())
|
||||
|
||||
if key0 <= key1:
|
||||
ss = dict00[key0][0]
|
||||
ee = dict00[key0][1]
|
||||
res.insert(0, [list_all.index(ss), list_all.index(ee)])
|
||||
list_nodo.remove(ss)
|
||||
ept[0] = ss
|
||||
else:
|
||||
ss = dict11[key1][0]
|
||||
ee = dict11[key1][1]
|
||||
res.append([list_all.index(ss), list_all.index(ee)])
|
||||
list_nodo.remove(ee)
|
||||
ept[1] = ee
|
||||
|
||||
dict00 = {}
|
||||
dict11 = {}
|
||||
|
||||
path = functools.reduce(operator.concat, res)
|
||||
path = sorted(set(path), key=path.index)
|
||||
|
||||
return res, path
|
||||
|
||||
|
||||
def clusters2labels(clusters, node_num):
|
||||
"""This is from https://github.com/GXYM/DRRG."""
|
||||
labels = (-1) * np.ones((node_num, ))
|
||||
for cluster_inx, cluster in enumerate(clusters):
|
||||
for node in cluster:
|
||||
labels[node.inx] = cluster_inx
|
||||
assert np.sum(labels < 0) < 1
|
||||
return labels
|
||||
|
||||
|
||||
def remove_single(text_comps, pred):
|
||||
"""Remove isolated single text components.
|
||||
|
||||
This is from https://github.com/GXYM/DRRG.
|
||||
"""
|
||||
single_flags = np.zeros_like(pred)
|
||||
pred_labels = np.unique(pred)
|
||||
for label in pred_labels:
|
||||
current_label_flag = pred == label
|
||||
if np.sum(current_label_flag) == 1:
|
||||
single_flags[np.where(current_label_flag)[0][0]] = 1
|
||||
remain_inx = [i for i in range(len(pred)) if not single_flags[i]]
|
||||
remain_inx = np.asarray(remain_inx)
|
||||
return text_comps[remain_inx, :], pred[remain_inx]
|
||||
|
||||
|
||||
class Node:
|
||||
|
||||
def __init__(self, inx):
|
||||
self.__inx = inx
|
||||
self.__links = set()
|
||||
|
||||
@property
|
||||
def inx(self):
|
||||
return self.__inx
|
||||
|
||||
@property
|
||||
def links(self):
|
||||
return set(self.__links)
|
||||
|
||||
def add_link(self, other, score):
|
||||
self.__links.add(other)
|
||||
other.__links.add(self)
|
||||
|
||||
|
||||
def connected_components(nodes, score_dict, thr):
|
||||
"""Connected components searching.
|
||||
|
||||
This is from https://github.com/GXYM/DRRG.
|
||||
"""
|
||||
|
||||
result = []
|
||||
nodes = set(nodes)
|
||||
while nodes:
|
||||
node = nodes.pop()
|
||||
group = {node}
|
||||
queue = [node]
|
||||
while queue:
|
||||
node = queue.pop(0)
|
||||
if thr is not None:
|
||||
neighbors = {
|
||||
linked_neighbor
|
||||
for linked_neighbor in node.links if score_dict[tuple(
|
||||
sorted([node.inx, linked_neighbor.inx]))] >= thr
|
||||
}
|
||||
else:
|
||||
neighbors = node.links
|
||||
neighbors.difference_update(group)
|
||||
nodes.difference_update(neighbors)
|
||||
group.update(neighbors)
|
||||
queue.extend(neighbors)
|
||||
result.append(group)
|
||||
return result
|
||||
|
||||
|
||||
def graph_propagation(edges,
|
||||
scores,
|
||||
link_thr,
|
||||
bboxes=None,
|
||||
dis_thr=50,
|
||||
pool='avg'):
|
||||
"""Propagate graph linkage score information.
|
||||
|
||||
This is from repo https://github.com/GXYM/DRRG.
|
||||
"""
|
||||
edges = np.sort(edges, axis=1)
|
||||
|
||||
score_dict = {}
|
||||
if pool is None:
|
||||
for i, edge in enumerate(edges):
|
||||
score_dict[edge[0], edge[1]] = scores[i]
|
||||
elif pool == 'avg':
|
||||
for i, edge in enumerate(edges):
|
||||
if bboxes is not None:
|
||||
box1 = bboxes[edge[0]][:8].reshape(4, 2)
|
||||
box2 = bboxes[edge[1]][:8].reshape(4, 2)
|
||||
center1 = np.mean(box1, axis=0)
|
||||
center2 = np.mean(box2, axis=0)
|
||||
dst = norm(center1 - center2)
|
||||
if dst > dis_thr:
|
||||
scores[i] = 0
|
||||
if (edge[0], edge[1]) in score_dict:
|
||||
score_dict[edge[0], edge[1]] = 0.5 * (
|
||||
score_dict[edge[0], edge[1]] + scores[i])
|
||||
else:
|
||||
score_dict[edge[0], edge[1]] = scores[i]
|
||||
|
||||
elif pool == 'max':
|
||||
for i, edge in enumerate(edges):
|
||||
if (edge[0], edge[1]) in score_dict:
|
||||
score_dict[edge[0],
|
||||
edge[1]] = max(score_dict[edge[0], edge[1]],
|
||||
scores[i])
|
||||
else:
|
||||
score_dict[edge[0], edge[1]] = scores[i]
|
||||
else:
|
||||
raise ValueError('Pooling operation not supported')
|
||||
|
||||
nodes = np.sort(np.unique(edges.flatten()))
|
||||
mapping = -1 * np.ones((nodes.max() + 1), dtype=np.int)
|
||||
mapping[nodes] = np.arange(nodes.shape[0])
|
||||
link_inx = mapping[edges]
|
||||
vertex = [Node(node) for node in nodes]
|
||||
for link, score in zip(link_inx, scores):
|
||||
vertex[link[0]].add_link(vertex[link[1]], score)
|
||||
|
||||
clusters = connected_components(vertex, score_dict, link_thr)
|
||||
|
||||
return clusters
|
||||
|
||||
|
||||
def in_contour(cont, point):
|
||||
x, y = point
|
||||
return cv2.pointPolygonTest(cont, (x, y), False) > 0
|
||||
|
||||
|
||||
def select_edge(cont, box):
|
||||
"""This is from repo https://github.com/GXYM/DRRG."""
|
||||
cont = np.array(cont)
|
||||
box = box.astype(np.int32)
|
||||
c1 = np.array(0.5 * (box[0, :] + box[3, :]), dtype=np.int)
|
||||
c2 = np.array(0.5 * (box[1, :] + box[2, :]), dtype=np.int)
|
||||
|
||||
if not in_contour(cont, c1):
|
||||
return [box[0, :].tolist(), box[3, :].tolist()]
|
||||
elif not in_contour(cont, c2):
|
||||
return [box[1, :].tolist(), box[2, :].tolist()]
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def comps2boundary(text_comps, final_pred):
|
||||
"""Propose text components and generate local graphs.
|
||||
|
||||
This is from repo https://github.com/GXYM/DRRG.
|
||||
"""
|
||||
bbox_contours = list()
|
||||
for inx in range(0, int(np.max(final_pred)) + 1):
|
||||
current_instance = np.where(final_pred == inx)
|
||||
boxes = text_comps[current_instance, :8].reshape(
|
||||
(-1, 4, 2)).astype(np.int32)
|
||||
|
||||
boundary_point = None
|
||||
if boxes.shape[0] > 1:
|
||||
centers = np.mean(boxes, axis=1).astype(np.int32).tolist()
|
||||
paths, routes_path = min_connect_path(centers)
|
||||
boxes = boxes[routes_path]
|
||||
top = np.mean(boxes[:, 0:2, :], axis=1).astype(np.int32).tolist()
|
||||
bot = np.mean(boxes[:, 2:4, :], axis=1).astype(np.int32).tolist()
|
||||
edge1 = select_edge(top + bot[::-1], boxes[0])
|
||||
edge2 = select_edge(top + bot[::-1], boxes[-1])
|
||||
if edge1 is not None:
|
||||
top.insert(0, edge1[0])
|
||||
bot.insert(0, edge1[1])
|
||||
if edge2 is not None:
|
||||
top.append(edge2[0])
|
||||
bot.append(edge2[1])
|
||||
boundary_point = np.array(top + bot[::-1])
|
||||
|
||||
elif boxes.shape[0] == 1:
|
||||
top = boxes[0, 0:2, :].astype(np.int32).tolist()
|
||||
bot = boxes[0, 2:4:-1, :].astype(np.int32).tolist()
|
||||
boundary_point = np.array(top + bot)
|
||||
|
||||
if boundary_point is None:
|
||||
continue
|
||||
|
||||
boundary_point = [p for p in boundary_point.flatten().tolist()]
|
||||
bbox_contours.append(boundary_point)
|
||||
|
||||
return bbox_contours
|
||||
|
||||
|
||||
def merge_text_comps(edges, scores, text_comps, link_thr):
|
||||
"""Merge text components into text instance."""
|
||||
clusters = graph_propagation(edges, scores, link_thr)
|
||||
pred_labels = clusters2labels(clusters, text_comps.shape[0])
|
||||
text_comps, final_pred = remove_single(text_comps, pred_labels)
|
||||
boundaries = comps2boundary(text_comps, final_pred)
|
||||
|
||||
return boundaries
|
|
@ -0,0 +1,88 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import xavier_init
|
||||
from torch import nn
|
||||
|
||||
from mmdet.models.builder import NECKS
|
||||
|
||||
|
||||
class UpBlock(nn.Module):
|
||||
"""Upsample block for DRRG and TextSnake."""
|
||||
|
||||
def __init__(self, in_channels, out_channels):
|
||||
super().__init__()
|
||||
|
||||
assert isinstance(in_channels, int)
|
||||
assert isinstance(out_channels, int)
|
||||
|
||||
self.conv1x1 = nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.conv3x3 = nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.deconv = nn.ConvTranspose2d(
|
||||
out_channels, out_channels, kernel_size=4, stride=2, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.relu(self.conv1x1(x))
|
||||
x = F.relu(self.conv3x3(x))
|
||||
x = self.deconv(x)
|
||||
return x
|
||||
|
||||
|
||||
@NECKS.register_module()
|
||||
class FPN_UNET(nn.Module):
|
||||
"""The class for implementing DRRG and TextSnake U-Net-like FPN.
|
||||
|
||||
DRRG: Deep Relational Reasoning Graph Network for Arbitrary Shape
|
||||
Text Detection [https://arxiv.org/abs/2003.07493].
|
||||
TextSnake: A Flexible Representation for Detecting Text of Arbitrary Shapes
|
||||
[https://arxiv.org/abs/1807.01544].
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels):
|
||||
super().__init__()
|
||||
|
||||
assert len(in_channels) == 4
|
||||
assert isinstance(out_channels, int)
|
||||
|
||||
blocks_out_channels = [out_channels] + [
|
||||
min(out_channels * 2**i, 256) for i in range(4)
|
||||
]
|
||||
blocks_in_channels = [blocks_out_channels[1]] + [
|
||||
in_channels[i] + blocks_out_channels[i + 2] for i in range(3)
|
||||
] + [in_channels[3]]
|
||||
|
||||
self.up4 = nn.ConvTranspose2d(
|
||||
blocks_in_channels[4],
|
||||
blocks_out_channels[4],
|
||||
kernel_size=4,
|
||||
stride=2,
|
||||
padding=1)
|
||||
self.up_block3 = UpBlock(blocks_in_channels[3], blocks_out_channels[3])
|
||||
self.up_block2 = UpBlock(blocks_in_channels[2], blocks_out_channels[2])
|
||||
self.up_block1 = UpBlock(blocks_in_channels[1], blocks_out_channels[1])
|
||||
self.up_block0 = UpBlock(blocks_in_channels[0], blocks_out_channels[0])
|
||||
self.init_weights()
|
||||
|
||||
def init_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
||||
xavier_init(m, distribution='uniform')
|
||||
|
||||
def forward(self, x):
|
||||
c2, c3, c4, c5 = x
|
||||
|
||||
x = F.relu(self.up4(c5))
|
||||
|
||||
x = torch.cat([x, c4], dim=1)
|
||||
x = F.relu(self.up_block3(x))
|
||||
|
||||
x = torch.cat([x, c3], dim=1)
|
||||
x = F.relu(self.up_block2(x))
|
||||
|
||||
x = torch.cat([x, c2], dim=1)
|
||||
x = F.relu(self.up_block1(x))
|
||||
|
||||
x = self.up_block0(x)
|
||||
# the output should be of the same height and width as backbone input
|
||||
return x
|
Loading…
Reference in New Issue