mirror of https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00

[DRRG] DRRG postprocessor

This commit is contained in:
parent 9db0941837
commit ed9e8d150c

@@ -50,3 +50,6 @@ mmocr/datasets/utils/backend.py
mmocr/datasets/utils/loader.py
# It will be removed after TTA refactor
mmocr/datasets/pipelines/test_time_aug.py

# Covering it with tests seems like an impossible mission
mmocr/models/textdet/postprocessors/drrg_postprocessor.py

@@ -1,41 +1,447 @@
# Copyright (c) OpenMMLab. All rights reserved.
import functools
import operator
from typing import Dict, List, Tuple, Union

import cv2
import numpy as np
import torch
from mmengine import InstanceData
from numpy import ndarray

from mmocr.core import TextDetDataSample
from mmocr.registry import MODELS
from .base_postprocessor import BaseTextDetPostProcessor


class Node:
    """A simple graph node.

    Args:
        ind (int): The index of the node.
    """

    def __init__(self, ind: int) -> None:
        self.__ind = ind
        self.__links = set()

    @property
    def ind(self) -> int:
        """Current node index."""
        return self.__ind

    @property
    def links(self) -> set:
        """A set of links."""
        return set(self.__links)

    def add_link(self, link_node: 'Node') -> None:
        """Add a link to the node.

        Args:
            link_node (Node): The link node.
        """
        self.__links.add(link_node)
        link_node.__links.add(self)
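
# Illustration only, not part of the diff: `add_link` writes to the private
# link sets of both endpoints (name mangling still resolves inside the class
# body), so links are always symmetric, and the `links` property hands back
# a copy that callers can safely mutate.
#
#   >>> a, b = Node(0), Node(1)
#   >>> a.add_link(b)
#   >>> a in b.links and b in a.links
#   True
#   >>> a.links.clear()   # clears the copy, not the node's own set
#   >>> b in a.links
#   True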


@MODELS.register_module()
class DRRGPostprocessor(BaseTextDetPostProcessor):
    """Merge text components and construct boundaries of text instances.

    Args:
        link_thr (float): The edge score threshold. Defaults to 0.8.
        edge_len_thr (int or float): The edge length threshold. Defaults
            to 50.
        rescale_fields (list[str]): The bbox/polygon field names to be
            rescaled. If None, no rescaling will be performed. Defaults to
            ['polygons'].
    """

    def __init__(self,
                 link_thr: float = 0.8,
                 edge_len_thr: Union[int, float] = 50.,
                 rescale_fields=['polygons'],
                 **kwargs) -> None:
        super().__init__(rescale_fields=rescale_fields)
        assert isinstance(link_thr, float)
        assert isinstance(edge_len_thr, (int, float))
        self.link_thr = link_thr
        self.edge_len_thr = edge_len_thr
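
    # Illustration only, not part of the diff: because the class is
    # registered in MODELS, a config dict can build it through the mmengine
    # registry. The config keys below mirror __init__ and are otherwise an
    # assumption of this sketch.
    #
    #   >>> from mmocr.registry import MODELS
    #   >>> cfg = dict(type='DRRGPostprocessor', link_thr=0.8,
    #   ...            edge_len_thr=50.)
    #   >>> postprocessor = MODELS.build(cfg)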

    def get_text_instances(self, pred_results: Tuple[ndarray, ndarray,
                                                     ndarray],
                           data_sample: TextDetDataSample
                           ) -> TextDetDataSample:
        """Get text instance predictions of one image.

        Args:
            pred_results (tuple(ndarray, ndarray, ndarray)): Prediction
                results edge, score and text_comps. Each of shape
                :math:`(N_{edges}, 2)`, :math:`(N_{edges},)` and
                :math:`(M, 9)`, respectively.
            data_sample (TextDetDataSample): Datasample of an image.

        Returns:
            TextDetDataSample: The original datasample with predictions
            filled in. Polygons and results are saved in
            ``TextDetDataSample.pred_instances.polygons``. The confidence
            scores are saved in ``TextDetDataSample.pred_instances.scores``.
        """

        data_sample.pred_instances = InstanceData()
        polys = []
        scores = []

        pred_edges, pred_scores, text_comps = pred_results

        if pred_edges is not None:
            assert len(pred_edges) == len(pred_scores)
            assert text_comps.ndim == 2
            assert text_comps.shape[1] == 9

            vertices, score_dict = self._graph_propagation(
                pred_edges, pred_scores, text_comps)
            clusters = self._connected_components(vertices, score_dict)
            pred_labels = self._clusters2labels(clusters, text_comps.shape[0])
            text_comps, pred_labels = self._remove_single(
                text_comps, pred_labels)
            polys, scores = self._comps2polys(text_comps, pred_labels)

        data_sample.pred_instances.polygons = polys
        data_sample.pred_instances.scores = torch.FloatTensor(scores)

        return data_sample
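
    # Illustration only, not part of the diff: the input shapes the
    # assertions above expect, mirroring the unit test at the bottom of
    # this commit (values are random and purely illustrative).
    #
    #   >>> import numpy as np
    #   >>> n_edges, n_comps = 10, 2
    #   >>> pred_edges = np.random.randint(0, n_comps, (n_edges, 2))
    #   >>> pred_scores = np.random.rand(n_edges)
    #   >>> text_comps = np.random.rand(n_comps, 9)  # 8 box coords + score
    #   >>> pred_results = (pred_edges, pred_scores, text_comps)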

    def split_results(self, pred_results: Tuple[ndarray, ndarray,
                                                ndarray]) -> List[Tuple]:
        """Split batched elements in pred_results along the first dimension
        into ``batch_num`` sub-elements and regather them into a list of
        dicts.

        However, DRRG only outputs one batch at inference time, so this
        function is a no-op.
        """
        return [pred_results]

    def _graph_propagation(self, edges: ndarray, scores: ndarray,
                           text_comps: ndarray) -> Tuple[List[Node], Dict]:
        """Propagate edge score information and construct graph. This code
        was partially adapted from https://github.com/GXYM/DRRG licensed
        under the MIT license.

        Args:
            edges (ndarray): The edge array of shape N * 2, each row is a
                node index pair that makes up an edge in graph.
            scores (ndarray): The edge score array of shape (N,).
            text_comps (ndarray): The text components.

        Returns:
            tuple(vertices, score_dict):

            - vertices (list[Node]): The Nodes in graph.
            - score_dict (dict): The edge score dict.
        """
        assert edges.ndim == 2
        assert edges.shape[1] == 2
        assert edges.shape[0] == scores.shape[0]
        assert text_comps.ndim == 2
        assert text_comps.shape[1] == 9

        edges = np.sort(edges, axis=1)
        score_dict = {}
        for i, edge in enumerate(edges):
            if text_comps is not None:
                box1 = text_comps[edge[0], :8].reshape(4, 2)
                box2 = text_comps[edge[1], :8].reshape(4, 2)
                center1 = np.mean(box1, axis=0)
                center2 = np.mean(box2, axis=0)
                distance = np.linalg.norm(center1 - center2)
                if distance > self.edge_len_thr:
                    scores[i] = 0
            if (edge[0], edge[1]) in score_dict:
                score_dict[edge[0], edge[1]] = 0.5 * (
                    score_dict[edge[0], edge[1]] + scores[i])
            else:
                score_dict[edge[0], edge[1]] = scores[i]

        nodes = np.sort(np.unique(edges.flatten()))
        mapping = -1 * np.ones((np.max(nodes) + 1), dtype=np.int)
        mapping[nodes] = np.arange(nodes.shape[0])
        order_inds = mapping[edges]
        vertices = [Node(node) for node in nodes]
        for ind in order_inds:
            vertices[ind[0]].add_link(vertices[ind[1]])

        return vertices, score_dict
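
    # Illustration only, not part of the diff: rows of `edges` are sorted,
    # so (1, 0) and (0, 1) collapse onto the single key (0, 1) and their
    # scores are averaged; edges longer than `edge_len_thr` are zeroed out
    # first. Assumes the `postprocessor` built in the earlier sketch.
    #
    #   >>> import numpy as np
    #   >>> edges = np.array([[1, 0], [0, 1]])  # same edge, both orders
    #   >>> scores = np.array([1.0, 0.5])
    #   >>> comps = np.random.rand(2, 9)        # centers well under 50 px
    #   >>> _, score_dict = postprocessor._graph_propagation(
    #   ...     edges, scores, comps)
    #   >>> float(score_dict[(0, 1)])           # 0.5 * (1.0 + 0.5)
    #   0.75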

    def _connected_components(self, nodes: List[Node],
                              score_dict: Dict) -> List[List[Node]]:
        """Conventional connected components searching. This code was
        partially adapted from https://github.com/GXYM/DRRG licensed under
        the MIT license.

        Args:
            nodes (list[Node]): The list of Node objects.
            score_dict (dict): The edge score dict.

        Returns:
            List[list[Node]]: The clustered Node objects.
        """
        assert isinstance(nodes, list)
        assert all([isinstance(node, Node) for node in nodes])
        assert isinstance(score_dict, dict)

        clusters = []
        nodes = set(nodes)
        while nodes:
            node = nodes.pop()
            cluster = {node}
            node_queue = [node]
            while node_queue:
                node = node_queue.pop(0)
                neighbors = {
                    neighbor
                    for neighbor in node.links if score_dict[tuple(
                        sorted([node.ind, neighbor.ind]))] >= self.link_thr
                }
                neighbors.difference_update(cluster)
                nodes.difference_update(neighbors)
                cluster.update(neighbors)
                node_queue.extend(neighbors)
            clusters.append(list(cluster))
        return clusters
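
    # Illustration only, not part of the diff: a plain BFS over `links` that
    # keeps only neighbors whose sorted-pair edge score clears `link_thr`
    # (0.8 by default), so one strong edge (0-1) and one weak edge (1-2)
    # yield the clusters {0, 1} and {2}. Assumes `postprocessor` from above.
    #
    #   >>> n0, n1, n2 = Node(0), Node(1), Node(2)
    #   >>> n0.add_link(n1)
    #   >>> n1.add_link(n2)
    #   >>> clusters = postprocessor._connected_components(
    #   ...     [n0, n1, n2], {(0, 1): 0.9, (1, 2): 0.3})
    #   >>> sorted(len(c) for c in clusters)
    #   [1, 2]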

    def _clusters2labels(self, clusters: List[List[Node]],
                         num_nodes: int) -> ndarray:
        """Convert clusters of Node to text component labels. This code was
        partially adapted from https://github.com/GXYM/DRRG licensed under
        the MIT license.

        Args:
            clusters (List[list[Node]]): The clusters of Node objects.
            num_nodes (int): The total node number of graphs in an image.

        Returns:
            ndarray: The node label array.
        """
        assert isinstance(clusters, list)
        assert all([isinstance(cluster, list) for cluster in clusters])
        assert all([
            isinstance(node, Node) for cluster in clusters for node in cluster
        ])
        assert isinstance(num_nodes, int)

        node_labels = np.zeros(num_nodes)
        for cluster_ind, cluster in enumerate(clusters):
            for node in cluster:
                node_labels[node.ind] = cluster_ind
        return node_labels
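
    # Illustration only, not part of the diff: each Node is stamped with the
    # index of the cluster it belongs to (reusing the nodes from the sketch
    # above).
    #
    #   >>> postprocessor._clusters2labels([[n0, n1], [n2]], 3).tolist()
    #   [0.0, 0.0, 1.0]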

    def _remove_single(self, text_comps: ndarray,
                       comp_pred_labels: ndarray) -> Tuple[ndarray, ndarray]:
        """Remove isolated text components. This code was partially adapted
        from https://github.com/GXYM/DRRG licensed under the MIT license.

        Args:
            text_comps (ndarray): The text components.
            comp_pred_labels (ndarray): The clustering labels of text
                components.

        Returns:
            tuple(filtered_text_comps, comp_pred_labels):

            - filtered_text_comps (ndarray): The text components with
              isolated ones removed.
            - comp_pred_labels (ndarray): The clustering labels with labels
              of isolated text components removed.
        """
        assert text_comps.ndim == 2
        assert text_comps.shape[0] == comp_pred_labels.shape[0]

        single_flags = np.zeros_like(comp_pred_labels)
        pred_labels = np.unique(comp_pred_labels)
        for label in pred_labels:
            current_label_flag = (comp_pred_labels == label)
            if np.sum(current_label_flag) == 1:
                single_flags[np.where(current_label_flag)[0][0]] = 1
        keep_ind = [
            i for i in range(len(comp_pred_labels)) if not single_flags[i]
        ]
        filtered_text_comps = text_comps[keep_ind, :]
        filtered_labels = comp_pred_labels[keep_ind]

        return filtered_text_comps, filtered_labels
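
    # Illustration only, not part of the diff: label 1 occurs only once
    # below, so its component is dropped and the two components labelled 0
    # survive.
    #
    #   >>> comps = np.random.rand(3, 9)
    #   >>> labels = np.array([0., 0., 1.])
    #   >>> kept, kept_labels = postprocessor._remove_single(comps, labels)
    #   >>> kept.shape, kept_labels.tolist()
    #   ((2, 9), [0.0, 0.0])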

    def _comps2polys(self, text_comps: ndarray, comp_pred_labels: ndarray
                     ) -> Tuple[List[ndarray], List[float]]:
        """Construct text instance boundaries from clustered text components.
        This code was partially adapted from https://github.com/GXYM/DRRG
        licensed under the MIT license.

        Args:
            text_comps (ndarray): The text components.
            comp_pred_labels (ndarray): The clustering labels of text
                components.

        Returns:
            tuple(boundaries, scores):

            - boundaries (list[ndarray]): The predicted boundaries of text
              instances.
            - scores (list[float]): The boundary scores.
        """
        assert text_comps.ndim == 2
        assert len(text_comps) == len(comp_pred_labels)
        boundaries = []
        scores = []
        if len(text_comps) < 1:
            return boundaries, scores
        for cluster_ind in range(0, int(np.max(comp_pred_labels)) + 1):
            cluster_comp_inds = np.where(comp_pred_labels == cluster_ind)
            text_comp_boxes = text_comps[cluster_comp_inds, :8].reshape(
                (-1, 4, 2)).astype(np.int32)
            score = np.mean(text_comps[cluster_comp_inds, -1])

            if text_comp_boxes.shape[0] < 1:
                continue

            elif text_comp_boxes.shape[0] > 1:
                centers = np.mean(
                    text_comp_boxes, axis=1).astype(np.int32).tolist()
                shortest_path = self._min_connect_path(centers)
                text_comp_boxes = text_comp_boxes[shortest_path]
                top_line = np.mean(
                    text_comp_boxes[:, 0:2, :],
                    axis=1).astype(np.int32).tolist()
                bot_line = np.mean(
                    text_comp_boxes[:, 2:4, :],
                    axis=1).astype(np.int32).tolist()
                top_line, bot_line = self._fix_corner(top_line, bot_line,
                                                      text_comp_boxes[0],
                                                      text_comp_boxes[-1])
                boundary_points = top_line + bot_line[::-1]

            else:
                top_line = text_comp_boxes[0, 0:2, :].astype(
                    np.int32).tolist()
                bot_line = text_comp_boxes[0, 2:4:-1, :].astype(
                    np.int32).tolist()
                boundary_points = top_line + bot_line

            boundary = [p for coord in boundary_points for p in coord]
            boundaries.append(np.array(boundary, dtype=np.float32))
            scores.append(score)

        return boundaries, scores
def _norm2(self, point1: List[int], point2: List[int]) -> float:
|
||||||
|
"""Calculate the norm of two points."""
|
||||||
|
return ((point1[0] - point2[0])**2 + (point1[1] - point2[1])**2)**0.5
|
||||||
|
|
||||||
|
def _min_connect_path(self, points: List[List[int]]) -> List[List[int]]:
|
||||||
|
"""Find the shortest path to traverse all points. This code was
|
||||||
|
partially adapted from https://github.com/GXYM/DRRG licensed under the
|
||||||
|
MIT license.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
points(List[list[int]]): The point sequence
|
||||||
|
[[x0, y0], [x1, y1], ...].
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[list[int]]: The shortest index path.
|
||||||
|
"""
|
||||||
|
assert isinstance(points, list)
|
||||||
|
assert all([isinstance(point, list) for point in points])
|
||||||
|
assert all(
|
||||||
|
[isinstance(coord, int) for point in points for coord in point])
|
||||||
|
|
||||||
|
points_queue = points.copy()
|
||||||
|
shortest_path = []
|
||||||
|
current_edge = [[], []]
|
||||||
|
|
||||||
|
edge_dict0 = {}
|
||||||
|
edge_dict1 = {}
|
||||||
|
current_edge[0] = points_queue[0]
|
||||||
|
current_edge[1] = points_queue[0]
|
||||||
|
points_queue.remove(points_queue[0])
|
||||||
|
while points_queue:
|
||||||
|
for point in points_queue:
|
||||||
|
length0 = self._norm2(point, current_edge[0])
|
||||||
|
edge_dict0[length0] = [point, current_edge[0]]
|
||||||
|
length1 = self._norm2(current_edge[1], point)
|
||||||
|
edge_dict1[length1] = [current_edge[1], point]
|
||||||
|
key0 = min(edge_dict0.keys())
|
||||||
|
key1 = min(edge_dict1.keys())
|
||||||
|
|
||||||
|
if key0 <= key1:
|
||||||
|
start = edge_dict0[key0][0]
|
||||||
|
end = edge_dict0[key0][1]
|
||||||
|
shortest_path.insert(0,
|
||||||
|
[points.index(start),
|
||||||
|
points.index(end)])
|
||||||
|
points_queue.remove(start)
|
||||||
|
current_edge[0] = start
|
||||||
|
else:
|
||||||
|
start = edge_dict1[key1][0]
|
||||||
|
end = edge_dict1[key1][1]
|
||||||
|
shortest_path.append([points.index(start), points.index(end)])
|
||||||
|
points_queue.remove(end)
|
||||||
|
current_edge[1] = end
|
||||||
|
|
||||||
|
edge_dict0 = {}
|
||||||
|
edge_dict1 = {}
|
||||||
|
|
||||||
|
shortest_path = functools.reduce(operator.concat, shortest_path)
|
||||||
|
shortest_path = sorted(set(shortest_path), key=shortest_path.index)
|
||||||
|
|
||||||
|
return shortest_path
|
||||||
|
|
||||||
|
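
    # Illustration only, not part of the diff: despite the "shortest path"
    # phrasing, this is a greedy nearest-neighbour chain grown from both
    # ends, so on collinear points it recovers a monotone ordering (possibly
    # reversed, depending on which end grows last).
    #
    #   >>> pts = [[0, 0], [10, 0], [5, 0]]
    #   >>> order = postprocessor._min_connect_path(pts)
    #   >>> [pts[i] for i in order] in ([[0, 0], [5, 0], [10, 0]],
    #   ...                             [[10, 0], [5, 0], [0, 0]])
    #   True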
def _in_contour(self, contour: ndarray, point: ndarray) -> bool:
|
||||||
|
"""Whether a point is in a contour."""
|
||||||
|
x, y = point
|
||||||
|
return cv2.pointPolygonTest(contour, (int(x), int(y)), False) > 0.5
|
||||||
|
|
||||||
|
def _fix_corner(self, top_line: List[List[int]], btm_line: List[List[int]],
|
||||||
|
start_box: ndarray, end_box: ndarray
|
||||||
|
) -> Tuple[List[List[int]], List[List[int]]]:
|
||||||
|
"""Add corner points to predicted side lines. This code was partially
|
||||||
|
adapted from https://github.com/GXYM/DRRG licensed under the MIT
|
||||||
|
license.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
top_line (List[list[int]]): The predicted top sidelines of text
|
||||||
|
instance.
|
||||||
|
btm_line (List[list[int]]): The predicted bottom sidelines of text
|
||||||
|
instance.
|
||||||
|
start_box (ndarray): The first text component box.
|
||||||
|
end_box (ndarray): The last text component box.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple(top_line, bot_line):
|
||||||
|
|
||||||
|
- top_line (List[list[int]]): The top sidelines with corner point
|
||||||
|
added.
|
||||||
|
- bot_line (List[list[int]]): The bottom sidelines with corner
|
||||||
|
point added.
|
||||||
|
"""
|
||||||
|
assert isinstance(top_line, list)
|
||||||
|
assert all(isinstance(point, list) for point in top_line)
|
||||||
|
assert isinstance(btm_line, list)
|
||||||
|
assert all(isinstance(point, list) for point in btm_line)
|
||||||
|
assert start_box.shape == end_box.shape == (4, 2)
|
||||||
|
|
||||||
|
contour = np.array(top_line + btm_line[::-1])
|
||||||
|
start_left_mid = (start_box[0] + start_box[3]) / 2
|
||||||
|
start_right_mid = (start_box[1] + start_box[2]) / 2
|
||||||
|
end_left_mid = (end_box[0] + end_box[3]) / 2
|
||||||
|
end_right_mid = (end_box[1] + end_box[2]) / 2
|
||||||
|
if not self._in_contour(contour, start_left_mid):
|
||||||
|
top_line.insert(0, start_box[0].tolist())
|
||||||
|
btm_line.insert(0, start_box[3].tolist())
|
||||||
|
elif not self._in_contour(contour, start_right_mid):
|
||||||
|
top_line.insert(0, start_box[1].tolist())
|
||||||
|
btm_line.insert(0, start_box[2].tolist())
|
||||||
|
if not self._in_contour(contour, end_left_mid):
|
||||||
|
top_line.append(end_box[0].tolist())
|
||||||
|
btm_line.append(end_box[3].tolist())
|
||||||
|
elif not self._in_contour(contour, end_right_mid):
|
||||||
|
top_line.append(end_box[1].tolist())
|
||||||
|
btm_line.append(end_box[2].tolist())
|
||||||
|
return top_line, btm_line
|
||||||
|
@ -0,0 +1,29 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from mmengine import InstanceData
|
||||||
|
|
||||||
|
from mmocr.core import TextDetDataSample
|
||||||
|
from mmocr.models.textdet.postprocessors import DRRGPostprocessor
|
||||||
|
|
||||||
|
|
||||||
|
class TestDRRGPostProcessor(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_call(self):
|
||||||
|
|
||||||
|
postprocessor = DRRGPostprocessor()
|
||||||
|
pred_results = (np.random.randint(0, 2, (10, 2)), np.random.rand(10),
|
||||||
|
np.random.rand(2, 9))
|
||||||
|
data_sample = TextDetDataSample(
|
||||||
|
metainfo=dict(scale_factor=(0.5, 1)),
|
||||||
|
gt_instances=InstanceData(polygons=[
|
||||||
|
np.array([0, 0, 0, 1, 2, 1, 2, 0]),
|
||||||
|
np.array([1, 1, 1, 2, 3, 2, 3, 1])
|
||||||
|
]))
|
||||||
|
result = postprocessor(pred_results, [data_sample])[0]
|
||||||
|
self.assertIn('polygons', result.pred_instances)
|
||||||
|
self.assertIn('scores', result.pred_instances)
|
||||||
|
self.assertTrue(
|
||||||
|
isinstance(result.pred_instances['scores'], torch.FloatTensor))
|