# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
"""
|
|
This code is adapted from:
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/modules/local_graph.py
|
|
"""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import numpy as np
|
|
import paddle
|
|
import paddle.nn as nn
|
|
from ppocr.ext_op import RoIAlignRotated
|
|
|
|
|
|
def normalize_adjacent_matrix(A):
    """Add self-loops to an adjacency matrix and normalize it by node degree.

    With ``A_hat = A + I`` and ``D = diag(d ** -0.5)``, where ``d`` holds the
    column sums of ``A_hat`` clipped at zero, the result is
    ``D @ A_hat.T @ D`` (which equals ``D @ A_hat @ D`` when ``A`` is
    symmetric). Entries of ``d`` whose inverse square root is infinite
    (zero degree) contribute 0 instead.

    Args:
        A (ndarray): Square 2-D adjacency matrix.

    Returns:
        G (ndarray): The normalized adjacency matrix, same shape as ``A``.
    """
    assert A.ndim == 2
    assert A.shape[0] == A.shape[1]

    size = A.shape[0]
    adj_hat = A + np.eye(size)

    degree = np.clip(np.sum(adj_hat, axis=0), 0, None)
    scale = np.power(degree, -0.5).flatten()
    # Zero degree gives inf; treat those nodes as contributing nothing.
    scale[np.isinf(scale)] = 0.0
    scale_mat = np.diag(scale)

    scaled = adj_hat.dot(scale_mat)
    return scaled.transpose().dot(scale_mat)
|
|
|
|
|
|
def euclidean_distance_matrix(A, B):
    """Compute pairwise Euclidean distances between two point sets.

    Uses the expansion ``||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b`` and clamps
    negative squared distances (floating-point round-off) to zero before
    taking the square root.

    Args:
        A (ndarray): Points of shape (m, d).
        B (ndarray): Points of shape (n, d), with the same d as ``A``.

    Returns:
        D (ndarray): Distance matrix of shape (m, n), where ``D[i, j]`` is the
            distance between ``A[i]`` and ``B[j]``.
    """
    assert A.ndim == 2
    assert B.ndim == 2
    assert A.shape[1] == B.shape[1]

    num_a, num_b = A.shape[0], B.shape[0]

    # Multiplying by explicit ``np.ones`` (float64 by default) broadcasts the
    # squared norms to (m, n) and promotes them to at least float64.
    sq_norm_a = np.ones(shape=(1, num_b)) * (A * A).sum(axis=1).reshape(
        (num_a, 1))
    sq_norm_b = np.ones(shape=(num_a, 1)) * (B * B).sum(axis=1)
    squared = sq_norm_a + sq_norm_b - 2 * A.dot(B.T)

    squared[squared < 0.0] = 0.0
    return np.sqrt(squared)
|
|
|
|
|
|
def feature_embedding(input_feats, out_feat_len):
    """Embed features with sinusoids. This code was partially adapted from
    https://github.com/GXYM/DRRG licensed under the MIT license.

    The d input columns are tiled ``out_feat_len // d`` times. When
    ``out_feat_len`` is not a multiple of d, one extra copy is added that
    holds the first ``out_feat_len % d`` columns, zero-padded to d. Each copy
    is divided by a power of 1000 and passed through ``sin``/``cos``. The
    result is flattened per node and truncated to ``out_feat_len`` columns.

    Args:
        input_feats (ndarray): The input features of shape (N, d), where N is
            the number of nodes in graph, d is the input feature vector length.
        out_feat_len (int): The length of output feature vector; must be at
            least d.

    Returns:
        embedded_feats (ndarray): float32 array of shape (N, out_feat_len).
    """
    assert input_feats.ndim == 2
    assert isinstance(out_feat_len, int)
    assert out_feat_len >= input_feats.shape[1]

    num_nodes, feat_dim = input_feats.shape
    feat_repeat_times = out_feat_len // feat_dim
    residue_dim = out_feat_len % feat_dim
    has_residue = residue_dim > 0

    # NOTE(review): with a residue the exponent is
    # ``2.0 * (j // 2) / feat_repeat_times + 1``; the ``+ 1`` applies to the
    # whole quotient, not the denominator, so this wave schedule differs from
    # the no-residue one. Kept unchanged so trained models see identical
    # features -- confirm whether this was intended.
    num_waves = feat_repeat_times + 1 if has_residue else feat_repeat_times
    exponent_offset = 1 if has_residue else 0
    embed_wave = np.array([
        np.power(1000, 2.0 * (j // 2) / feat_repeat_times + exponent_offset)
        for j in range(num_waves)
    ]).reshape((num_waves, 1, 1))

    # Shape (feat_repeat_times, N, d): one full copy of the input per repeat.
    repeat_feats = np.repeat(
        np.expand_dims(input_feats, axis=0), feat_repeat_times, axis=0)
    if has_residue:
        # Partial copy: the first residue_dim columns, zero-padded to d.
        residue_feats = np.hstack([
            input_feats[:, 0:residue_dim],
            np.zeros((num_nodes, feat_dim - residue_dim))
        ])
        repeat_feats = np.concatenate(
            [repeat_feats, np.expand_dims(residue_feats, axis=0)], axis=0)

    embedded_feats = repeat_feats / embed_wave
    # NOTE(review): axis 1 is the node axis here, so sin/cos alternate between
    # even- and odd-indexed nodes rather than between repeats (the ``j // 2``
    # pairing above looks like it targeted the repeat axis). Kept as-is for
    # compatibility with trained models -- confirm intent.
    embedded_feats[:, 0::2] = np.sin(embedded_feats[:, 0::2])
    embedded_feats[:, 1::2] = np.cos(embedded_feats[:, 1::2])
    embedded_feats = np.transpose(embedded_feats, (1, 0, 2)).reshape(
        (num_nodes, -1))[:, 0:out_feat_len]

    # Cast on both paths so the dtype does not depend on whether out_feat_len
    # is a multiple of d (LocalGraphs concatenates this with pooled content
    # features, which requires matching dtypes).
    return embedded_feats.astype(np.float32)
|
|
|
|
|
|
class LocalGraphs:
    """Generate pivot-centred local graphs and pack them as GCN input.

    For every text component (graph node) a local graph is formed from the
    node's ``k_at_hops[0]`` nearest neighbours plus the ``k_at_hops[1]``
    nearest neighbours of each of those. Redundant local graphs are pruned.
    Node features concatenate rotated-RoI pooled content features with a
    sinusoidal embedding of the geometric attributes. The per-graph data is
    zero-padded to a common size and stacked into batch tensors.
    """

    def __init__(self, k_at_hops, num_adjacent_linkages, node_geo_feat_len,
                 pooling_scale, pooling_output_size, local_graph_thr):
        """
        Args:
            k_at_hops (tuple[int]): Number of neighbours taken at the first
                and second hop when building a local graph (two values).
            num_adjacent_linkages (int): Number of nearest neighbours each
                node is linked to in a local graph's adjacency matrix.
            node_geo_feat_len (int): Length of the embedded geometric feature
                vector of each node.
            pooling_scale (float): Scale argument forwarded to
                ``RoIAlignRotated``.
            pooling_output_size (tuple[int]): Output size forwarded to
                ``RoIAlignRotated``.
            local_graph_thr (float): IoU threshold used when pruning local
                graphs that overlap an already kept one.
        """
        assert len(k_at_hops) == 2
        assert all(isinstance(n, int) for n in k_at_hops)
        assert isinstance(num_adjacent_linkages, int)
        assert isinstance(node_geo_feat_len, int)
        assert isinstance(pooling_scale, float)
        assert all(isinstance(n, int) for n in pooling_output_size)
        assert isinstance(local_graph_thr, float)

        self.k_at_hops = k_at_hops
        self.num_adjacent_linkages = num_adjacent_linkages
        self.node_geo_feat_dim = node_geo_feat_len
        self.pooling = RoIAlignRotated(pooling_output_size, pooling_scale)
        self.local_graph_thr = local_graph_thr

    def generate_local_graphs(self, sorted_dist_inds, gt_comp_labels):
        """Generate local graphs for GCN to predict which instance a text
        component belongs to.

        Args:
            sorted_dist_inds (ndarray): The complete graph node indices, which
                is sorted according to the Euclidean distance.
            gt_comp_labels(ndarray): The ground truth labels define the
                instance to which the text components (nodes in graphs) belong.

        Returns:
            pivot_local_graphs(list[list[int]]): The list of local graph
                neighbor indices of pivots; the pivot is element 0.
            pivot_knns(list[list[int]]): The list of k-nearest neighbor indices
                of pivots; the pivot is element 0.
        """
        assert sorted_dist_inds.ndim == 2
        assert (sorted_dist_inds.shape[0] == sorted_dist_inds.shape[1] ==
                gt_comp_labels.shape[0])

        # Column 0 of each sorted row is skipped; it is presumably the node
        # itself (self-distance 0).
        knn_graph = sorted_dist_inds[:, 1:self.k_at_hops[0] + 1]
        pivot_local_graphs = []
        pivot_knns = []
        for pivot_ind, knn in enumerate(knn_graph):
            local_graph_neighbors = set(knn)
            for neighbor_ind in knn:
                local_graph_neighbors.update(
                    set(sorted_dist_inds[neighbor_ind, 1:self.k_at_hops[1] +
                                         1]))
            local_graph_neighbors.discard(pivot_ind)

            # The pivot always occupies position 0 of its graph and knn list.
            pivot_local_graph = [pivot_ind] + list(local_graph_neighbors)
            pivot_knn = [pivot_ind] + list(knn)

            if not self._is_redundant_local_graph(
                    pivot_ind, pivot_local_graph, pivot_local_graphs,
                    pivot_knns, gt_comp_labels):
                pivot_local_graphs.append(pivot_local_graph)
                pivot_knns.append(pivot_knn)

        return pivot_local_graphs, pivot_knns

    def _is_redundant_local_graph(self, pivot_ind, pivot_local_graph,
                                  kept_local_graphs, kept_knns,
                                  gt_comp_labels):
        """Return True if ``pivot_local_graph`` duplicates a kept local graph.

        A candidate is redundant when some kept graph satisfies all of:
        its pivot has the same non-zero ground-truth label as the candidate's
        pivot, its knn list contains ``pivot_ind``, and its neighbour set
        (pivot excluded) overlaps the candidate's with IoU above
        ``self.local_graph_thr``.
        """
        pivot_label = gt_comp_labels[pivot_ind]
        if pivot_label == 0:
            return False

        pivot_neighbors = set(pivot_local_graph[1:])
        for kept_knn, kept_local_graph in zip(kept_knns, kept_local_graphs):
            if pivot_ind not in kept_knn:
                continue
            if gt_comp_labels[kept_knn[0]] != pivot_label:
                continue
            kept_neighbors = set(kept_local_graph[1:])
            union = len(pivot_neighbors | kept_neighbors)
            intersect = len(pivot_neighbors & kept_neighbors)
            if intersect / (union + 1e-8) > self.local_graph_thr:
                return True
        return False

    def _build_adjacent_matrix(self, pivot_local_graph, node2ind_map,
                               sorted_dist_inds):
        """Return the unnormalized, symmetric 0/1 adjacency of a local graph.

        Each node is linked to those of its ``num_adjacent_linkages`` nearest
        neighbours (taken from ``sorted_dist_inds``) that are also members of
        the local graph. Rows and columns follow ``node2ind_map``.
        """
        num_nodes = len(pivot_local_graph)
        adjacent_matrix = np.zeros((num_nodes, num_nodes), dtype=np.float32)
        for node in pivot_local_graph:
            row = node2ind_map[node]
            neighbors = sorted_dist_inds[node, 1:
                                         self.num_adjacent_linkages + 1]
            for neighbor in neighbors:
                # O(1) membership test via the index map instead of scanning
                # the node list.
                col = node2ind_map.get(neighbor)
                if col is None:
                    continue
                adjacent_matrix[row, col] = 1
                adjacent_matrix[col, row] = 1
        return adjacent_matrix

    def generate_gcn_input(self, node_feat_batch, node_label_batch,
                           local_graph_batch, knn_batch,
                           sorted_dist_ind_batch):
        """Generate graph convolution network input data.

        Every local graph is zero-padded to the size of the largest local
        graph in the batch so that the results can be stacked.

        Args:
            node_feat_batch (List[Tensor]): The batched graph node features.
            node_label_batch (List[ndarray]): The batched text component
                labels.
            local_graph_batch (List[List[list[int]]]): The local graph node
                indices of image batch.
            knn_batch (List[List[list[int]]]): The knn graph node indices of
                image batch.
            sorted_dist_ind_batch (list[ndarray]): The node indices sorted
                according to the Euclidean distance.

        Returns:
            local_graphs_node_feat (Tensor): Node features of each local graph
                minus the pivot's feature, zero-padded to the max node count.
            adjacent_matrices (Tensor): The normalized adjacent matrices of
                local graphs, zero-padded to the max node count.
            pivots_knn_inds (Tensor): Positions of each pivot's k-nearest
                neighbors within its local graph.
            gt_linkage (Tensor): The supervision signal of GCN for linkage
                prediction: 1 where a knn neighbor has the pivot's label and
                that label is positive, else 0.
        """
        assert isinstance(node_feat_batch, list)
        assert isinstance(node_label_batch, list)
        assert isinstance(local_graph_batch, list)
        assert isinstance(knn_batch, list)
        assert isinstance(sorted_dist_ind_batch, list)

        num_max_nodes = max(
            len(pivot_local_graph)
            for pivot_local_graphs in local_graph_batch
            for pivot_local_graph in pivot_local_graphs)

        local_graphs_node_feat = []
        adjacent_matrices = []
        pivots_knn_inds = []
        pivots_gt_linkage = []

        for batch_ind, sorted_dist_inds in enumerate(sorted_dist_ind_batch):
            node_feats = node_feat_batch[batch_ind]
            pivot_local_graphs = local_graph_batch[batch_ind]
            pivot_knns = knn_batch[batch_ind]
            node_labels = node_label_batch[batch_ind]

            for pivot_local_graph, pivot_knn in zip(pivot_local_graphs,
                                                    pivot_knns):
                num_nodes = len(pivot_local_graph)
                pivot_ind = pivot_local_graph[0]
                # Global node index -> row/column inside this local graph.
                node2ind_map = {j: i for i, j in enumerate(pivot_local_graph)}

                knn_inds = paddle.to_tensor(
                    [node2ind_map[i] for i in pivot_knn[1:]])
                pivot_feats = node_feats[pivot_ind]
                normalized_feats = node_feats[paddle.to_tensor(
                    pivot_local_graph)] - pivot_feats

                adjacent_matrix = normalize_adjacent_matrix(
                    self._build_adjacent_matrix(pivot_local_graph,
                                                node2ind_map,
                                                sorted_dist_inds))
                pad_adjacent_matrix = paddle.zeros(
                    (num_max_nodes, num_max_nodes))
                pad_adjacent_matrix[:num_nodes, :num_nodes] = paddle.cast(
                    paddle.to_tensor(adjacent_matrix), 'float32')

                pad_normalized_feats = paddle.concat(
                    [
                        normalized_feats, paddle.zeros(
                            (num_max_nodes - num_nodes,
                             normalized_feats.shape[1]))
                    ],
                    axis=0)
                local_graph_labels = node_labels[pivot_local_graph]
                knn_labels = local_graph_labels[knn_inds.numpy()]
                link_labels = ((node_labels[pivot_ind] == knn_labels) &
                               (node_labels[pivot_ind] > 0)).astype(np.int64)
                link_labels = paddle.to_tensor(link_labels)

                local_graphs_node_feat.append(pad_normalized_feats)
                adjacent_matrices.append(pad_adjacent_matrix)
                pivots_knn_inds.append(knn_inds)
                pivots_gt_linkage.append(link_labels)

        local_graphs_node_feat = paddle.stack(local_graphs_node_feat, 0)
        adjacent_matrices = paddle.stack(adjacent_matrices, 0)
        pivots_knn_inds = paddle.stack(pivots_knn_inds, 0)
        pivots_gt_linkage = paddle.stack(pivots_gt_linkage, 0)

        return (local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
                pivots_gt_linkage)

    def __call__(self, feat_maps, comp_attribs):
        """Generate local graphs as GCN input.

        Args:
            feat_maps (Tensor): The feature maps to extract the content
                features of text components; indexed per image along axis 0.
            comp_attribs (ndarray): The text component attributes, shape
                (batch, max_comps, 8). Per image, element ``[0, 0]`` holds
                the number of valid components, columns 1-6 their geometric
                attributes and column 7 their labels. Not modified.

        Returns:
            local_graphs_node_feat (Tensor): The node features of graph.
            adjacent_matrices (Tensor): The adjacent matrices of local graphs.
            pivots_knn_inds (Tensor): The k-nearest neighbor indices in local
                graph.
            gt_linkage (Tensor): The supervision signal of GCN for linkage
                prediction.
        """
        assert isinstance(feat_maps, paddle.Tensor)
        assert comp_attribs.ndim == 3
        assert comp_attribs.shape[2] == 8

        sorted_dist_inds_batch = []
        local_graph_batch = []
        knn_batch = []
        node_feat_batch = []
        node_label_batch = []

        for batch_ind in range(comp_attribs.shape[0]):
            num_comps = int(comp_attribs[batch_ind, 0, 0])
            # Copy: the basic slice is a view into comp_attribs, and the
            # in-place clipping below would otherwise write back into the
            # caller's array.
            comp_geo_attribs = comp_attribs[batch_ind, :num_comps, 1:7].copy()
            node_labels = comp_attribs[batch_ind, :num_comps, 7].astype(
                np.int32)

            comp_centers = comp_geo_attribs[:, 0:2]
            distance_matrix = euclidean_distance_matrix(comp_centers,
                                                        comp_centers)

            # Each image's feature map is pooled as a batch of one (see the
            # unsqueeze(0) below), so every RoI refers to batch index 0.
            batch_id = np.zeros(
                (comp_geo_attribs.shape[0], 1), dtype=np.float32)
            # The last two geometric columns become a signed angle: arccos of
            # the clipped second-to-last column, signed by the last column
            # (presumably cos/sin of the component's rotation).
            comp_geo_attribs[:, -2] = np.clip(comp_geo_attribs[:, -2], -1, 1)
            angle = np.arccos(comp_geo_attribs[:, -2]) * np.sign(
                comp_geo_attribs[:, -1])
            angle = angle.reshape((-1, 1))
            rotated_rois = np.hstack(
                [batch_id, comp_geo_attribs[:, :-2], angle])
            rois = paddle.to_tensor(rotated_rois)
            content_feats = self.pooling(feat_maps[batch_ind].unsqueeze(0),
                                         rois)

            content_feats = content_feats.reshape([content_feats.shape[0], -1])
            geo_feats = feature_embedding(comp_geo_attribs,
                                          self.node_geo_feat_dim)
            geo_feats = paddle.to_tensor(geo_feats)
            node_feats = paddle.concat([content_feats, geo_feats], axis=-1)

            sorted_dist_inds = np.argsort(distance_matrix, axis=1)
            pivot_local_graphs, pivot_knns = self.generate_local_graphs(
                sorted_dist_inds, node_labels)

            node_feat_batch.append(node_feats)
            node_label_batch.append(node_labels)
            local_graph_batch.append(pivot_local_graphs)
            knn_batch.append(pivot_knns)
            sorted_dist_inds_batch.append(sorted_dist_inds)

        (node_feats, adjacent_matrices, knn_inds, gt_linkage) = \
            self.generate_gcn_input(node_feat_batch,
                                    node_label_batch,
                                    local_graph_batch,
                                    knn_batch,
                                    sorted_dist_inds_batch)

        return node_feats, adjacent_matrices, knn_inds, gt_linkage
|