[DBNet] Add DBPostProcessor

Repository: https://github.com/open-mmlab/mmocr.git
Commit: 7a66a84b64
Parent: cd3d173b18
@@ -1,58 +1,81 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Sequence
+
 import cv2
 import numpy as np
+import torch
+from mmengine import InstanceData
+from shapely.geometry import Polygon
 
-from mmocr.core import points2boundary
+from mmocr.core import TextDetDataSample
 from mmocr.registry import MODELS
-from .base_postprocessor import BasePostprocessor
-from .utils import box_score_fast, unclip
+from mmocr.utils import offset_polygon
+from .base_postprocessor import BaseTextDetPostProcessor
 
 
 @MODELS.register_module()
-class DBPostprocessor(BasePostprocessor):
+class DBPostprocessor(BaseTextDetPostProcessor):
     """Decoding predictions of DbNet to instances. This is partially adapted
     from https://github.com/MhLiao/DB.
 
     Args:
         text_repr_type (str): The boundary encoding type 'poly' or 'quad'.
-        mask_thr (float): The mask threshold value for binarization.
+            Defaults to 'poly'.
+        rescale_fields (list[str]): The bbox/polygon field names to
+            be rescaled. If None, no rescaling will be performed. Defaults to
+            ['polygons'].
+        mask_thr (float): The mask threshold value for binarization. Defaults
+            to 0.3.
         min_text_score (float): The threshold value for converting binary map
-            to shrink text regions.
+            to shrink text regions. Defaults to 0.3.
         min_text_width (int): The minimum width of boundary polygon/box
-            predicted.
+            predicted. Defaults to 5.
         unclip_ratio (float): The unclip ratio for text regions dilation.
-        epsilon_ratio (float): The epsilon ratio for approximation accuracy.
-        max_candidates (int): The maximum candidate number.
+            Defaults to 1.5.
+        max_candidates (int): The maximum candidate number. Defaults to 3000.
     """
 
     def __init__(self,
-                 text_repr_type='poly',
-                 mask_thr=0.3,
-                 min_text_score=0.3,
-                 min_text_width=5,
-                 unclip_ratio=1.5,
-                 epsilon_ratio=0.01,
-                 max_candidates=3000,
-                 **kwargs):
-        super().__init__(text_repr_type)
+                 text_repr_type: str = 'poly',
+                 rescale_fields: Sequence[str] = ['polygons'],
+                 mask_thr: float = 0.3,
+                 min_text_score: float = 0.3,
+                 min_text_width: int = 5,
+                 unclip_ratio: float = 1.5,
+                 max_candidates: int = 3000,
+                 **kwargs) -> None:
+        super().__init__(
+            text_repr_type=text_repr_type,
+            rescale_fields=rescale_fields,
+            **kwargs)
         self.mask_thr = mask_thr
         self.min_text_score = min_text_score
         self.min_text_width = min_text_width
         self.unclip_ratio = unclip_ratio
-        self.epsilon_ratio = epsilon_ratio
         self.max_candidates = max_candidates
 
-    def __call__(self, preds):
-        """
+    def get_text_instances(self, pred_results: dict,
+                           data_sample: TextDetDataSample
+                           ) -> TextDetDataSample:
+        """Get text instance predictions of one image.
+
         Args:
-            preds (Tensor): Prediction map with shape :math:`(C, H, W)`.
+            pred_result (dict): Prediction results of an image containing the
+                ``prob_map``, which is a tensor of shape :math:`(N, H, W)`.
+            data_sample (TextDetDataSample): Datasample of an image.
 
         Returns:
-            list[list[float]]: The predicted text boundaries.
+            TextDetDataSample: A new DataSample with predictions filled in.
+            Polygons and results are saved in
+            ``TextDetDataSample.pred_instances.polygons``. The confidence
+            scores are saved in ``TextDetDataSample.pred_instances.scores``.
         """
-        assert preds.dim() == 3
 
-        prob_map = preds[0, :, :]
+        data_sample.pred_instances = InstanceData()
+        data_sample.pred_instances.polygons = []
+        data_sample.pred_instances.scores = []
+
+        prob_map = pred_results['prob_map']
         text_mask = prob_map > self.mask_thr
 
         score_map = prob_map.data.cpu().numpy().astype(np.float32)
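A minimal usage sketch of the new `get_text_instances` interface, mirroring the unit test added later in this commit. The map size, the `scale_factor` metainfo, and the direct call (which bypasses the base class's `__call__` and its rescaling step) are illustrative assumptions, not part of the diff.

import torch
from mmocr.core import TextDetDataSample
from mmocr.models.textdet.postprocessors import DBPostprocessor

# Build a postprocessor with default thresholds, feed it a dummy probability
# map, and decode polygons/scores into the data sample (assumed 4x5 map).
postprocessor = DBPostprocessor(text_repr_type='poly')
data_sample = TextDetDataSample(metainfo=dict(scale_factor=(0.5, 1)))
pred_results = dict(prob_map=torch.rand(4, 5))
results = postprocessor.get_text_instances(pred_results, data_sample)
print(results.pred_instances.polygons, results.pred_instances.scores)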
@@ -61,34 +84,80 @@ class DBPostprocessor(BasePostprocessor):
         contours, _ = cv2.findContours((text_mask * 255).astype(np.uint8),
                                        cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
 
-        boundaries = []
         for i, poly in enumerate(contours):
             if i > self.max_candidates:
                 break
-            epsilon = self.epsilon_ratio * cv2.arcLength(poly, True)
+            epsilon = 0.01 * cv2.arcLength(poly, True)
             approx = cv2.approxPolyDP(poly, epsilon, True)
-            points = approx.reshape((-1, 2))
-            if points.shape[0] < 4:
+            poly_pts = approx.reshape((-1, 2))
+            if poly_pts.shape[0] < 4:
                 continue
-            score = box_score_fast(score_map, points)
+            score = self._get_bbox_score(score_map, poly_pts)
             if score < self.min_text_score:
                 continue
-            poly = unclip(points, unclip_ratio=self.unclip_ratio)
-            if len(poly) == 0 or isinstance(poly[0], list):
+            poly = self._unclip(poly_pts)
+            # If the result polygon does not exist, or it is split into
+            # multiple polygons, skip it.
+            if len(poly) == 0 or isinstance(poly, list):
                 continue
             poly = poly.reshape(-1, 2)
 
             if self.text_repr_type == 'quad':
-                poly = points2boundary(poly, self.text_repr_type, score,
-                                       self.min_text_width)
+                rect = cv2.minAreaRect(poly)
+                vertices = cv2.boxPoints(rect)
+                poly = vertices.flatten() if min(
+                    rect[1]) >= self.min_text_width else []
             elif self.text_repr_type == 'poly':
-                poly = poly.flatten().tolist()
-                if score is not None:
-                    poly = poly + [score]
-                if len(poly) < 8:
-                    poly = None
+                poly = poly.flatten()
 
-            if poly is not None:
-                boundaries.append(poly)
+            if len(poly) < 8:
+                poly = np.array([], dtype=np.float32)
 
-        return boundaries
+            if len(poly) > 0:
+                data_sample.pred_instances.polygons.append(poly)
+                data_sample.pred_instances.scores.append(
+                    torch.FloatTensor([score]))
+
+        return data_sample
+
+    def _get_bbox_score(self, score_map: np.ndarray,
+                        poly_pts: np.ndarray) -> float:
+        """Compute the average score over the area of the bounding box of the
+        polygon.
+
+        Args:
+            score_map (np.ndarray): The score map.
+            poly_pts (np.ndarray): The polygon points.
+
+        Returns:
+            float: The average score.
+        """
+        h, w = score_map.shape[:2]
+        poly_pts = poly_pts.copy()
+        xmin = np.clip(
+            np.floor(poly_pts[:, 0].min()).astype(np.int32), 0, w - 1)
+        xmax = np.clip(
+            np.ceil(poly_pts[:, 0].max()).astype(np.int32), 0, w - 1)
+        ymin = np.clip(
+            np.floor(poly_pts[:, 1].min()).astype(np.int32), 0, h - 1)
+        ymax = np.clip(
+            np.ceil(poly_pts[:, 1].max()).astype(np.int32), 0, h - 1)
+
+        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
+        poly_pts[:, 0] = poly_pts[:, 0] - xmin
+        poly_pts[:, 1] = poly_pts[:, 1] - ymin
+        cv2.fillPoly(mask, poly_pts.reshape(1, -1, 2).astype(np.int32), 1)
+        return cv2.mean(score_map[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
+
+    def _unclip(self, poly_pts: np.ndarray) -> np.ndarray:
+        """Unclip a polygon.
+
+        Args:
+            poly_pts (np.ndarray): The polygon points.
+
+        Returns:
+            np.ndarray: The expanded polygon points.
+        """
+        poly = Polygon(poly_pts)
+        distance = poly.area * self.unclip_ratio / poly.length
+        return offset_polygon(poly_pts, distance)
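For reference, a short sketch of the offset used by `_unclip` above: the expansion distance is the polygon's area times `unclip_ratio` divided by its perimeter, as in the DB unclipping scheme. Shapely's `buffer` is used here only as a stand-in for `mmocr.utils.offset_polygon`, whose exact output format is not shown in this diff.

import numpy as np
from shapely.geometry import Polygon

# A 10 x 4 rectangle: area = 40, perimeter = 28, so the offset distance is
# D = A * r / L = 40 * 1.5 / 28, about 2.14 with the default unclip_ratio.
poly_pts = np.array([(0, 0), (10, 0), (10, 4), (0, 4)], dtype=np.float64)
poly = Polygon(poly_pts)
distance = poly.area * 1.5 / poly.length
expanded = poly.buffer(distance, join_style=2)  # 2 = mitre join
print(round(distance, 2))                       # 2.14
print(np.array(expanded.exterior.coords))       # dilated polygon vertices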
@@ -5,6 +5,7 @@ interrogate
 isort
 # Note: used for kwarray.group_items, this may be ported to mmcv in the future.
 kwarray
+parameterized
 pytest
 pytest-cov
 pytest-runner
@@ -0,0 +1,42 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import unittest
+
+import numpy as np
+import torch
+from mmengine import InstanceData
+from nose_parameterized import parameterized
+
+from mmocr.core import TextDetDataSample
+from mmocr.models.textdet.postprocessors import DBPostprocessor
+
+
+class TestDBPostProcessor(unittest.TestCase):
+
+    def test_get_bbox_score(self):
+        postprocessor = DBPostprocessor()
+        score_map = np.arange(0, 1, step=0.05).reshape(4, 5)
+        poly_pts = np.array(((0, 0), (0, 1), (1, 1), (1, 0)))
+        self.assertAlmostEqual(
+            postprocessor._get_bbox_score(score_map, poly_pts), 0.15)
+
+    @parameterized.expand([('poly'), ('quad')])
+    def test_get_text_instances(self, text_repr_type):
+
+        postprocessor = DBPostprocessor(text_repr_type=text_repr_type)
+        pred_result = dict(prob_map=torch.rand(4, 5))
+        data_sample = TextDetDataSample(
+            metainfo=dict(scale_factor=(0.5, 1)),
+            gt_instances=InstanceData(polygons=[
+                np.array([0, 0, 0, 1, 2, 1, 2, 0]),
+                np.array([1, 1, 1, 2, 3, 2, 3, 1])
+            ]))
+        results = postprocessor.get_text_instances(pred_result, data_sample)
+        self.assertIn('polygons', results.pred_instances)
+        self.assertIn('scores', results.pred_instances)
+
+        postprocessor = DBPostprocessor(
+            min_text_score=1, text_repr_type=text_repr_type)
+        pred_result = dict(prob_map=torch.rand(4, 5) * 0.8)
+        results = postprocessor.get_text_instances(pred_result, data_sample)
+        self.assertEqual(results.pred_instances.polygons, [])
+        self.assertEqual(results.pred_instances.scores, [])
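The expected value of 0.15 in `test_get_bbox_score` can be checked by hand: the polygon's bounding box covers the top-left 2 x 2 corner of the score map, and `_get_bbox_score` averages the masked scores there. A quick sanity check with the same values as the test:

import numpy as np

# Same score map as the test; the unit-square polygon maps to the 2x2 corner.
score_map = np.arange(0, 1, step=0.05).reshape(4, 5)
corner = score_map[0:2, 0:2]  # [[0.0, 0.05], [0.25, 0.3]]
print(corner.mean())          # 0.15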