mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
[DRRG] DRRG postprocessor
This commit is contained in:
parent
9db0941837
commit
ed9e8d150c
@ -50,3 +50,6 @@ mmocr/datasets/utils/backend.py
|
||||
mmocr/datasets/utils/loader.py
|
||||
# It will be removed after TTA refactor
|
||||
mmocr/datasets/pipelines/test_time_aug.py
|
||||
|
||||
# Cover it by tests seems like an impossible mission
|
||||
mmocr/models/textdet/postprocessors/drrg_postprocessor.py
|
||||
|
@ -1,41 +1,447 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import functools
|
||||
import operator
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine import InstanceData
|
||||
from numpy import ndarray
|
||||
|
||||
from mmocr.core import TextDetDataSample
|
||||
from mmocr.registry import MODELS
|
||||
from .base_postprocessor import BasePostprocessor
|
||||
from .utils import (clusters2labels, comps2boundaries, connected_components,
|
||||
graph_propagation, remove_single)
|
||||
from .base_postprocessor import BaseTextDetPostProcessor
|
||||
|
||||
|
||||
class Node:
|
||||
"""A simple graph node.
|
||||
|
||||
Args:
|
||||
ind (int): The index of the node.
|
||||
"""
|
||||
|
||||
def __init__(self, ind: int) -> None:
|
||||
self.__ind = ind
|
||||
self.__links = set()
|
||||
|
||||
@property
|
||||
def ind(self) -> int:
|
||||
"""Current node index."""
|
||||
return self.__ind
|
||||
|
||||
@property
|
||||
def links(self) -> set:
|
||||
"""A set of links."""
|
||||
return set(self.__links)
|
||||
|
||||
def add_link(self, link_node: 'Node') -> None:
|
||||
"""Add a link to the node.
|
||||
|
||||
Args:
|
||||
link_node (Node): The link node.
|
||||
"""
|
||||
self.__links.add(link_node)
|
||||
link_node.__links.add(self)
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class DRRGPostprocessor(BasePostprocessor):
|
||||
class DRRGPostprocessor(BaseTextDetPostProcessor):
|
||||
"""Merge text components and construct boundaries of text instances.
|
||||
|
||||
Args:
|
||||
link_thr (float): The edge score threshold.
|
||||
link_thr (float): The edge score threshold. Defaults to 0.8.
|
||||
edge_len_thr (int or float): The edge length threshold. Defaults to 50.
|
||||
rescale_fields (list[str]): The bbox/polygon field names to
|
||||
be rescaled. If None, no rescaling will be performed. Defaults to
|
||||
[polygons'].
|
||||
"""
|
||||
|
||||
def __init__(self, link_thr, **kwargs):
|
||||
def __init__(self,
|
||||
link_thr: float = 0.8,
|
||||
edge_len_thr: Union[int, float] = 50.,
|
||||
rescale_fields=['polygons'],
|
||||
**kwargs) -> None:
|
||||
super().__init__(rescale_fields=rescale_fields)
|
||||
assert isinstance(link_thr, float)
|
||||
assert isinstance(edge_len_thr, (int, float))
|
||||
self.link_thr = link_thr
|
||||
self.edge_len_thr = edge_len_thr
|
||||
|
||||
def __call__(self, edges, scores, text_comps):
|
||||
def get_text_instances(self, pred_results: Tuple[ndarray, ndarray,
|
||||
ndarray],
|
||||
data_sample: TextDetDataSample
|
||||
) -> TextDetDataSample:
|
||||
"""Get text instance predictions of one image.
|
||||
|
||||
Args:
|
||||
pred_result (tuple(ndarray, ndarray, ndarray)): Prediction results
|
||||
edge, score and text_comps. Each of shape
|
||||
:math:`(N_{edges}, 2)`, :math:`(N_{edges},)` and
|
||||
:math:`(M, 9)`, respectively.
|
||||
data_sample (TextDetDataSample): Datasample of an image.
|
||||
|
||||
Returns:
|
||||
TextDetDataSample: The original dataSample with predictions filled
|
||||
in. Polygons and results are saved in
|
||||
``TextDetDataSample.pred_instances.polygons``. The confidence
|
||||
scores are saved in ``TextDetDataSample.pred_instances.scores``.
|
||||
"""
|
||||
|
||||
data_sample.pred_instances = InstanceData()
|
||||
polys = []
|
||||
scores = []
|
||||
|
||||
pred_edges, pred_scores, text_comps = pred_results
|
||||
|
||||
if pred_edges is not None:
|
||||
assert len(pred_edges) == len(pred_scores)
|
||||
assert text_comps.ndim == 2
|
||||
assert text_comps.shape[1] == 9
|
||||
|
||||
vertices, score_dict = self._graph_propagation(
|
||||
pred_edges, pred_scores, text_comps)
|
||||
clusters = self._connected_components(vertices, score_dict)
|
||||
pred_labels = self._clusters2labels(clusters, text_comps.shape[0])
|
||||
text_comps, pred_labels = self._remove_single(
|
||||
text_comps, pred_labels)
|
||||
polys, scores = self._comps2polys(text_comps, pred_labels)
|
||||
|
||||
data_sample.pred_instances.polygons = polys
|
||||
data_sample.pred_instances.scores = torch.FloatTensor(scores)
|
||||
|
||||
return data_sample
|
||||
|
||||
def split_results(self, pred_results: Tuple[ndarray, ndarray,
|
||||
ndarray]) -> List[Tuple]:
|
||||
"""Split batched elements in pred_results along the first dimension
|
||||
into ``batch_num`` sub-elements and regather them into a list of dicts.
|
||||
|
||||
However, DRRG only outputs one batch at inference time, so this
|
||||
function is a no-op.
|
||||
"""
|
||||
return [pred_results]
|
||||
|
||||
def _graph_propagation(self, edges: ndarray, scores: ndarray,
|
||||
text_comps: ndarray) -> Tuple[List[Node], Dict]:
|
||||
"""Propagate edge score information and construct graph. This code was
|
||||
partially adapted from https://github.com/GXYM/DRRG licensed under the
|
||||
MIT license.
|
||||
|
||||
Args:
|
||||
edges (ndarray): The edge array of shape N * 2, each row is a node
|
||||
index pair that makes up an edge in graph.
|
||||
scores (ndarray): The edge score array of shape (N,).
|
||||
scores (ndarray): The edge score array.
|
||||
text_comps (ndarray): The text components.
|
||||
|
||||
Returns:
|
||||
List[list[float]]: The predicted boundaries of text instances.
|
||||
tuple(vertices, score_dict):
|
||||
|
||||
- vertices (list[Node]): The Nodes in graph.
|
||||
- score_dict (dict): The edge score dict.
|
||||
"""
|
||||
assert len(edges) == len(scores)
|
||||
assert edges.ndim == 2
|
||||
assert edges.shape[1] == 2
|
||||
assert edges.shape[0] == scores.shape[0]
|
||||
assert text_comps.ndim == 2
|
||||
assert text_comps.shape[1] == 9
|
||||
|
||||
vertices, score_dict = graph_propagation(edges, scores, text_comps)
|
||||
clusters = connected_components(vertices, score_dict, self.link_thr)
|
||||
pred_labels = clusters2labels(clusters, text_comps.shape[0])
|
||||
text_comps, pred_labels = remove_single(text_comps, pred_labels)
|
||||
boundaries = comps2boundaries(text_comps, pred_labels)
|
||||
edges = np.sort(edges, axis=1)
|
||||
score_dict = {}
|
||||
for i, edge in enumerate(edges):
|
||||
if text_comps is not None:
|
||||
box1 = text_comps[edge[0], :8].reshape(4, 2)
|
||||
box2 = text_comps[edge[1], :8].reshape(4, 2)
|
||||
center1 = np.mean(box1, axis=0)
|
||||
center2 = np.mean(box2, axis=0)
|
||||
distance = np.linalg.norm(center1 - center2)
|
||||
if distance > self.edge_len_thr:
|
||||
scores[i] = 0
|
||||
if (edge[0], edge[1]) in score_dict:
|
||||
score_dict[edge[0], edge[1]] = 0.5 * (
|
||||
score_dict[edge[0], edge[1]] + scores[i])
|
||||
else:
|
||||
score_dict[edge[0], edge[1]] = scores[i]
|
||||
|
||||
return boundaries
|
||||
nodes = np.sort(np.unique(edges.flatten()))
|
||||
mapping = -1 * np.ones((np.max(nodes) + 1), dtype=np.int)
|
||||
mapping[nodes] = np.arange(nodes.shape[0])
|
||||
order_inds = mapping[edges]
|
||||
vertices = [Node(node) for node in nodes]
|
||||
for ind in order_inds:
|
||||
vertices[ind[0]].add_link(vertices[ind[1]])
|
||||
|
||||
return vertices, score_dict
|
||||
|
||||
def _connected_components(self, nodes: List[Node],
|
||||
score_dict: Dict) -> List[List[Node]]:
|
||||
"""Conventional connected components searching. This code was partially
|
||||
adapted from https://github.com/GXYM/DRRG licensed under the MIT
|
||||
license.
|
||||
|
||||
Args:
|
||||
nodes (list[Node]): The list of Node objects.
|
||||
score_dict (dict): The edge score dict.
|
||||
|
||||
Returns:
|
||||
List[list[Node]]: The clustered Node objects.
|
||||
"""
|
||||
assert isinstance(nodes, list)
|
||||
assert all([isinstance(node, Node) for node in nodes])
|
||||
assert isinstance(score_dict, dict)
|
||||
|
||||
clusters = []
|
||||
nodes = set(nodes)
|
||||
while nodes:
|
||||
node = nodes.pop()
|
||||
cluster = {node}
|
||||
node_queue = [node]
|
||||
while node_queue:
|
||||
node = node_queue.pop(0)
|
||||
neighbors = {
|
||||
neighbor
|
||||
for neighbor in node.links if score_dict[tuple(
|
||||
sorted([node.ind, neighbor.ind]))] >= self.link_thr
|
||||
}
|
||||
neighbors.difference_update(cluster)
|
||||
nodes.difference_update(neighbors)
|
||||
cluster.update(neighbors)
|
||||
node_queue.extend(neighbors)
|
||||
clusters.append(list(cluster))
|
||||
return clusters
|
||||
|
||||
def _clusters2labels(self, clusters: List[List[Node]],
|
||||
num_nodes: int) -> ndarray:
|
||||
"""Convert clusters of Node to text component labels. This code was
|
||||
partially adapted from https://github.com/GXYM/DRRG licensed under the
|
||||
MIT license.
|
||||
|
||||
Args:
|
||||
clusters (List[list[Node]]): The clusters of Node objects.
|
||||
num_nodes (int): The total node number of graphs in an image.
|
||||
|
||||
Returns:
|
||||
ndarray: The node label array.
|
||||
"""
|
||||
assert isinstance(clusters, list)
|
||||
assert all([isinstance(cluster, list) for cluster in clusters])
|
||||
assert all([
|
||||
isinstance(node, Node) for cluster in clusters for node in cluster
|
||||
])
|
||||
assert isinstance(num_nodes, int)
|
||||
|
||||
node_labels = np.zeros(num_nodes)
|
||||
for cluster_ind, cluster in enumerate(clusters):
|
||||
for node in cluster:
|
||||
node_labels[node.ind] = cluster_ind
|
||||
return node_labels
|
||||
|
||||
def _remove_single(self, text_comps: ndarray,
|
||||
comp_pred_labels: ndarray) -> Tuple[ndarray, ndarray]:
|
||||
"""Remove isolated text components. This code was partially adapted
|
||||
from https://github.com/GXYM/DRRG licensed under the MIT license.
|
||||
|
||||
Args:
|
||||
text_comps (ndarray): The text components.
|
||||
comp_pred_labels (ndarray): The clustering labels of text
|
||||
components.
|
||||
|
||||
Returns:
|
||||
tuple(filtered_text_comps, comp_pred_labels):
|
||||
|
||||
- filtered_text_comps (ndarray): The text components with isolated
|
||||
ones removed.
|
||||
- comp_pred_labels (ndarray): The clustering labels with labels of
|
||||
isolated text components removed.
|
||||
"""
|
||||
assert text_comps.ndim == 2
|
||||
assert text_comps.shape[0] == comp_pred_labels.shape[0]
|
||||
|
||||
single_flags = np.zeros_like(comp_pred_labels)
|
||||
pred_labels = np.unique(comp_pred_labels)
|
||||
for label in pred_labels:
|
||||
current_label_flag = (comp_pred_labels == label)
|
||||
if np.sum(current_label_flag) == 1:
|
||||
single_flags[np.where(current_label_flag)[0][0]] = 1
|
||||
keep_ind = [
|
||||
i for i in range(len(comp_pred_labels)) if not single_flags[i]
|
||||
]
|
||||
filtered_text_comps = text_comps[keep_ind, :]
|
||||
filtered_labels = comp_pred_labels[keep_ind]
|
||||
|
||||
return filtered_text_comps, filtered_labels
|
||||
|
||||
def _comps2polys(self, text_comps: ndarray, comp_pred_labels: ndarray
|
||||
) -> Tuple[List[ndarray], List[float]]:
|
||||
"""Construct text instance boundaries from clustered text components.
|
||||
This code was partially adapted from https://github.com/GXYM/DRRG
|
||||
licensed under the MIT license.
|
||||
|
||||
Args:
|
||||
text_comps (ndarray): The text components.
|
||||
comp_pred_labels (ndarray): The clustering labels of text
|
||||
components.
|
||||
|
||||
Returns:
|
||||
tuple(boundaries, scores):
|
||||
|
||||
- boundaries (list[ndarray]): The predicted boundaries of text
|
||||
instances.
|
||||
- scores (list[float]): The boundary scores.
|
||||
"""
|
||||
assert text_comps.ndim == 2
|
||||
assert len(text_comps) == len(comp_pred_labels)
|
||||
boundaries = []
|
||||
scores = []
|
||||
if len(text_comps) < 1:
|
||||
return boundaries, scores
|
||||
for cluster_ind in range(0, int(np.max(comp_pred_labels)) + 1):
|
||||
cluster_comp_inds = np.where(comp_pred_labels == cluster_ind)
|
||||
text_comp_boxes = text_comps[cluster_comp_inds, :8].reshape(
|
||||
(-1, 4, 2)).astype(np.int32)
|
||||
score = np.mean(text_comps[cluster_comp_inds, -1])
|
||||
|
||||
if text_comp_boxes.shape[0] < 1:
|
||||
continue
|
||||
|
||||
elif text_comp_boxes.shape[0] > 1:
|
||||
centers = np.mean(
|
||||
text_comp_boxes, axis=1).astype(np.int32).tolist()
|
||||
shortest_path = self._min_connect_path(centers)
|
||||
text_comp_boxes = text_comp_boxes[shortest_path]
|
||||
top_line = np.mean(
|
||||
text_comp_boxes[:, 0:2, :],
|
||||
axis=1).astype(np.int32).tolist()
|
||||
bot_line = np.mean(
|
||||
text_comp_boxes[:, 2:4, :],
|
||||
axis=1).astype(np.int32).tolist()
|
||||
top_line, bot_line = self._fix_corner(top_line, bot_line,
|
||||
text_comp_boxes[0],
|
||||
text_comp_boxes[-1])
|
||||
boundary_points = top_line + bot_line[::-1]
|
||||
|
||||
else:
|
||||
top_line = text_comp_boxes[0, 0:2, :].astype(np.int32).tolist()
|
||||
bot_line = text_comp_boxes[0, 2:4:-1, :].astype(
|
||||
np.int32).tolist()
|
||||
boundary_points = top_line + bot_line
|
||||
|
||||
boundary = [p for coord in boundary_points for p in coord]
|
||||
boundaries.append(np.array(boundary, dtype=np.float32))
|
||||
scores.append(score)
|
||||
|
||||
return boundaries, scores
|
||||
|
||||
def _norm2(self, point1: List[int], point2: List[int]) -> float:
|
||||
"""Calculate the norm of two points."""
|
||||
return ((point1[0] - point2[0])**2 + (point1[1] - point2[1])**2)**0.5
|
||||
|
||||
def _min_connect_path(self, points: List[List[int]]) -> List[List[int]]:
|
||||
"""Find the shortest path to traverse all points. This code was
|
||||
partially adapted from https://github.com/GXYM/DRRG licensed under the
|
||||
MIT license.
|
||||
|
||||
Args:
|
||||
points(List[list[int]]): The point sequence
|
||||
[[x0, y0], [x1, y1], ...].
|
||||
|
||||
Returns:
|
||||
List[list[int]]: The shortest index path.
|
||||
"""
|
||||
assert isinstance(points, list)
|
||||
assert all([isinstance(point, list) for point in points])
|
||||
assert all(
|
||||
[isinstance(coord, int) for point in points for coord in point])
|
||||
|
||||
points_queue = points.copy()
|
||||
shortest_path = []
|
||||
current_edge = [[], []]
|
||||
|
||||
edge_dict0 = {}
|
||||
edge_dict1 = {}
|
||||
current_edge[0] = points_queue[0]
|
||||
current_edge[1] = points_queue[0]
|
||||
points_queue.remove(points_queue[0])
|
||||
while points_queue:
|
||||
for point in points_queue:
|
||||
length0 = self._norm2(point, current_edge[0])
|
||||
edge_dict0[length0] = [point, current_edge[0]]
|
||||
length1 = self._norm2(current_edge[1], point)
|
||||
edge_dict1[length1] = [current_edge[1], point]
|
||||
key0 = min(edge_dict0.keys())
|
||||
key1 = min(edge_dict1.keys())
|
||||
|
||||
if key0 <= key1:
|
||||
start = edge_dict0[key0][0]
|
||||
end = edge_dict0[key0][1]
|
||||
shortest_path.insert(0,
|
||||
[points.index(start),
|
||||
points.index(end)])
|
||||
points_queue.remove(start)
|
||||
current_edge[0] = start
|
||||
else:
|
||||
start = edge_dict1[key1][0]
|
||||
end = edge_dict1[key1][1]
|
||||
shortest_path.append([points.index(start), points.index(end)])
|
||||
points_queue.remove(end)
|
||||
current_edge[1] = end
|
||||
|
||||
edge_dict0 = {}
|
||||
edge_dict1 = {}
|
||||
|
||||
shortest_path = functools.reduce(operator.concat, shortest_path)
|
||||
shortest_path = sorted(set(shortest_path), key=shortest_path.index)
|
||||
|
||||
return shortest_path
|
||||
|
||||
def _in_contour(self, contour: ndarray, point: ndarray) -> bool:
|
||||
"""Whether a point is in a contour."""
|
||||
x, y = point
|
||||
return cv2.pointPolygonTest(contour, (int(x), int(y)), False) > 0.5
|
||||
|
||||
def _fix_corner(self, top_line: List[List[int]], btm_line: List[List[int]],
|
||||
start_box: ndarray, end_box: ndarray
|
||||
) -> Tuple[List[List[int]], List[List[int]]]:
|
||||
"""Add corner points to predicted side lines. This code was partially
|
||||
adapted from https://github.com/GXYM/DRRG licensed under the MIT
|
||||
license.
|
||||
|
||||
Args:
|
||||
top_line (List[list[int]]): The predicted top sidelines of text
|
||||
instance.
|
||||
btm_line (List[list[int]]): The predicted bottom sidelines of text
|
||||
instance.
|
||||
start_box (ndarray): The first text component box.
|
||||
end_box (ndarray): The last text component box.
|
||||
|
||||
Returns:
|
||||
tuple(top_line, bot_line):
|
||||
|
||||
- top_line (List[list[int]]): The top sidelines with corner point
|
||||
added.
|
||||
- bot_line (List[list[int]]): The bottom sidelines with corner
|
||||
point added.
|
||||
"""
|
||||
assert isinstance(top_line, list)
|
||||
assert all(isinstance(point, list) for point in top_line)
|
||||
assert isinstance(btm_line, list)
|
||||
assert all(isinstance(point, list) for point in btm_line)
|
||||
assert start_box.shape == end_box.shape == (4, 2)
|
||||
|
||||
contour = np.array(top_line + btm_line[::-1])
|
||||
start_left_mid = (start_box[0] + start_box[3]) / 2
|
||||
start_right_mid = (start_box[1] + start_box[2]) / 2
|
||||
end_left_mid = (end_box[0] + end_box[3]) / 2
|
||||
end_right_mid = (end_box[1] + end_box[2]) / 2
|
||||
if not self._in_contour(contour, start_left_mid):
|
||||
top_line.insert(0, start_box[0].tolist())
|
||||
btm_line.insert(0, start_box[3].tolist())
|
||||
elif not self._in_contour(contour, start_right_mid):
|
||||
top_line.insert(0, start_box[1].tolist())
|
||||
btm_line.insert(0, start_box[2].tolist())
|
||||
if not self._in_contour(contour, end_left_mid):
|
||||
top_line.append(end_box[0].tolist())
|
||||
btm_line.append(end_box[3].tolist())
|
||||
elif not self._in_contour(contour, end_right_mid):
|
||||
top_line.append(end_box[1].tolist())
|
||||
btm_line.append(end_box[2].tolist())
|
||||
return top_line, btm_line
|
||||
|
@ -0,0 +1,29 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine import InstanceData
|
||||
|
||||
from mmocr.core import TextDetDataSample
|
||||
from mmocr.models.textdet.postprocessors import DRRGPostprocessor
|
||||
|
||||
|
||||
class TestDRRGPostProcessor(unittest.TestCase):
|
||||
|
||||
def test_call(self):
|
||||
|
||||
postprocessor = DRRGPostprocessor()
|
||||
pred_results = (np.random.randint(0, 2, (10, 2)), np.random.rand(10),
|
||||
np.random.rand(2, 9))
|
||||
data_sample = TextDetDataSample(
|
||||
metainfo=dict(scale_factor=(0.5, 1)),
|
||||
gt_instances=InstanceData(polygons=[
|
||||
np.array([0, 0, 0, 1, 2, 1, 2, 0]),
|
||||
np.array([1, 1, 1, 2, 3, 2, 3, 1])
|
||||
]))
|
||||
result = postprocessor(pred_results, [data_sample])[0]
|
||||
self.assertIn('polygons', result.pred_instances)
|
||||
self.assertIn('scores', result.pred_instances)
|
||||
self.assertTrue(
|
||||
isinstance(result.pred_instances['scores'], torch.FloatTensor))
|
Loading…
x
Reference in New Issue
Block a user