mirror of https://github.com/open-mmlab/mmocr.git
[DBNet] Add DBPostProcessor
parent
cd3d173b18
commit
7a66a84b64
mmocr/models/textdet/postprocessors
requirements
tests/test_models/test_textdet/test_postprocessors
|
@ -1,58 +1,81 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Sequence
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine import InstanceData
|
||||
from shapely.geometry import Polygon
|
||||
|
||||
from mmocr.core import points2boundary
|
||||
from mmocr.core import TextDetDataSample
|
||||
from mmocr.registry import MODELS
|
||||
from .base_postprocessor import BasePostprocessor
|
||||
from .utils import box_score_fast, unclip
|
||||
from mmocr.utils import offset_polygon
|
||||
from .base_postprocessor import BaseTextDetPostProcessor
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class DBPostprocessor(BasePostprocessor):
|
||||
class DBPostprocessor(BaseTextDetPostProcessor):
|
||||
"""Decoding predictions of DbNet to instances. This is partially adapted
|
||||
from https://github.com/MhLiao/DB.
|
||||
|
||||
Args:
|
||||
text_repr_type (str): The boundary encoding type 'poly' or 'quad'.
|
||||
mask_thr (float): The mask threshold value for binarization.
|
||||
Defaults to 'poly'.
|
||||
rescale_fields (list[str]): The bbox/polygon field names to
|
||||
be rescaled. If None, no rescaling will be performed. Defaults to
|
||||
['polygons'].
|
||||
mask_thr (float): The mask threshold value for binarization. Defaults
|
||||
to 0.3.
|
||||
min_text_score (float): The threshold value for converting binary map
|
||||
to shrink text regions.
|
||||
to shrink text regions. Defaults to 0.3.
|
||||
min_text_width (int): The minimum width of boundary polygon/box
|
||||
predicted.
|
||||
predicted. Defaults to 5.
|
||||
unclip_ratio (float): The unclip ratio for text regions dilation.
|
||||
epsilon_ratio (float): The epsilon ratio for approximation accuracy.
|
||||
max_candidates (int): The maximum candidate number.
|
||||
Defaults to 1.5.
|
||||
max_candidates (int): The maximum candidate number. Defaults to 3000.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
text_repr_type='poly',
|
||||
mask_thr=0.3,
|
||||
min_text_score=0.3,
|
||||
min_text_width=5,
|
||||
unclip_ratio=1.5,
|
||||
epsilon_ratio=0.01,
|
||||
max_candidates=3000,
|
||||
**kwargs):
|
||||
super().__init__(text_repr_type)
|
||||
text_repr_type: str = 'poly',
|
||||
rescale_fields: Sequence[str] = ['polygons'],
|
||||
mask_thr: float = 0.3,
|
||||
min_text_score: float = 0.3,
|
||||
min_text_width: int = 5,
|
||||
unclip_ratio: float = 1.5,
|
||||
max_candidates: int = 3000,
|
||||
**kwargs) -> None:
|
||||
super().__init__(
|
||||
text_repr_type=text_repr_type,
|
||||
rescale_fields=rescale_fields,
|
||||
**kwargs)
|
||||
self.mask_thr = mask_thr
|
||||
self.min_text_score = min_text_score
|
||||
self.min_text_width = min_text_width
|
||||
self.unclip_ratio = unclip_ratio
|
||||
self.epsilon_ratio = epsilon_ratio
|
||||
self.max_candidates = max_candidates
|
||||
|
||||
def __call__(self, preds):
|
||||
"""
|
||||
def get_text_instances(self, pred_results: dict,
|
||||
data_sample: TextDetDataSample
|
||||
) -> TextDetDataSample:
|
||||
"""Get text instance predictions of one image.
|
||||
|
||||
Args:
|
||||
preds (Tensor): Prediction map with shape :math:`(C, H, W)`.
|
||||
pred_results (dict): Prediction results of an image containing the
|
||||
``prob_map``, which is a tensor of shape :math:`(N, H, W)`.
|
||||
data_sample (TextDetDataSample): Datasample of an image.
|
||||
|
||||
Returns:
|
||||
list[list[float]]: The predicted text boundaries.
|
||||
TextDetDataSample: A new DataSample with predictions filled in.
|
||||
Polygons and results are saved in
|
||||
``TextDetDataSample.pred_instances.polygons``. The confidence
|
||||
scores are saved in ``TextDetDataSample.pred_instances.scores``.
|
||||
"""
|
||||
assert preds.dim() == 3
|
||||
|
||||
prob_map = preds[0, :, :]
|
||||
data_sample.pred_instances = InstanceData()
|
||||
data_sample.pred_instances.polygons = []
|
||||
data_sample.pred_instances.scores = []
|
||||
|
||||
prob_map = pred_results['prob_map']
|
||||
text_mask = prob_map > self.mask_thr
|
||||
|
||||
score_map = prob_map.data.cpu().numpy().astype(np.float32)
|
||||
|
@ -61,34 +84,80 @@ class DBPostprocessor(BasePostprocessor):
|
|||
contours, _ = cv2.findContours((text_mask * 255).astype(np.uint8),
|
||||
cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
boundaries = []
|
||||
for i, poly in enumerate(contours):
|
||||
if i > self.max_candidates:
|
||||
break
|
||||
epsilon = self.epsilon_ratio * cv2.arcLength(poly, True)
|
||||
epsilon = 0.01 * cv2.arcLength(poly, True)
|
||||
approx = cv2.approxPolyDP(poly, epsilon, True)
|
||||
points = approx.reshape((-1, 2))
|
||||
if points.shape[0] < 4:
|
||||
poly_pts = approx.reshape((-1, 2))
|
||||
if poly_pts.shape[0] < 4:
|
||||
continue
|
||||
score = box_score_fast(score_map, points)
|
||||
score = self._get_bbox_score(score_map, poly_pts)
|
||||
if score < self.min_text_score:
|
||||
continue
|
||||
poly = unclip(points, unclip_ratio=self.unclip_ratio)
|
||||
if len(poly) == 0 or isinstance(poly[0], list):
|
||||
poly = self._unclip(poly_pts)
|
||||
# If the result polygon does not exist, or it is split into
|
||||
# multiple polygons, skip it.
|
||||
if len(poly) == 0 or isinstance(poly, list):
|
||||
continue
|
||||
poly = poly.reshape(-1, 2)
|
||||
|
||||
if self.text_repr_type == 'quad':
|
||||
poly = points2boundary(poly, self.text_repr_type, score,
|
||||
self.min_text_width)
|
||||
rect = cv2.minAreaRect(poly)
|
||||
vertices = cv2.boxPoints(rect)
|
||||
poly = vertices.flatten() if min(
|
||||
rect[1]) >= self.min_text_width else []
|
||||
elif self.text_repr_type == 'poly':
|
||||
poly = poly.flatten().tolist()
|
||||
if score is not None:
|
||||
poly = poly + [score]
|
||||
if len(poly) < 8:
|
||||
poly = None
|
||||
poly = poly.flatten()
|
||||
|
||||
if poly is not None:
|
||||
boundaries.append(poly)
|
||||
if len(poly) < 8:
|
||||
poly = np.array([], dtype=np.float32)
|
||||
|
||||
return boundaries
|
||||
if len(poly) > 0:
|
||||
data_sample.pred_instances.polygons.append(poly)
|
||||
data_sample.pred_instances.scores.append(
|
||||
torch.FloatTensor([score]))
|
||||
|
||||
return data_sample
|
||||
|
||||
def _get_bbox_score(self, score_map: np.ndarray,
                    poly_pts: np.ndarray) -> float:
    """Average the score map over the pixels covered by a polygon.

    The polygon's axis-aligned bounding box is clipped to the score map,
    the polygon is rasterized into a mask of that size, and the mean of
    the score map under the mask is returned.

    Args:
        score_map (np.ndarray): The score map; its first two axes are
            read as (height, width).
        poly_pts (np.ndarray): The polygon points, indexed as
            ``[:, 0]`` for x and ``[:, 1]`` for y. Not modified.

    Returns:
        float: The average score.
    """
    height, width = score_map.shape[:2]
    # Work on a copy so the caller's points are left untouched.
    local_pts = poly_pts.copy()
    xs = local_pts[:, 0]
    ys = local_pts[:, 1]

    # Integer bounding box, clipped so it always lies inside the map.
    left = np.clip(np.floor(xs.min()).astype(np.int32), 0, width - 1)
    right = np.clip(np.ceil(xs.max()).astype(np.int32), 0, width - 1)
    top = np.clip(np.floor(ys.min()).astype(np.int32), 0, height - 1)
    bottom = np.clip(np.ceil(ys.max()).astype(np.int32), 0, height - 1)

    region_mask = np.zeros((bottom - top + 1, right - left + 1),
                           dtype=np.uint8)
    # Shift the polygon into the bounding box's local coordinates.
    local_pts[:, 0] = local_pts[:, 0] - left
    local_pts[:, 1] = local_pts[:, 1] - top
    cv2.fillPoly(region_mask,
                 local_pts.reshape(1, -1, 2).astype(np.int32), 1)

    region = score_map[top:bottom + 1, left:right + 1]
    # cv2.mean returns a per-channel tuple; only the first channel is used.
    return cv2.mean(region, region_mask)[0]
|
||||
|
||||
def _unclip(self, poly_pts: np.ndarray) -> np.ndarray:
    """Unclip a polygon.

    The offset distance is ``area * unclip_ratio / perimeter`` of the
    input polygon; the actual offsetting is delegated to
    ``offset_polygon``.

    Args:
        poly_pts (np.ndarray): The polygon points.

    Returns:
        np.ndarray: The expanded polygon points.
    """
    # NOTE(review): the exact return type for failed or split results is
    # decided by ``offset_polygon`` - confirm against that helper.
    shape = Polygon(poly_pts)
    offset = shape.area * self.unclip_ratio / shape.length
    return offset_polygon(poly_pts, offset)
|
||||
|
|
|
@ -5,6 +5,7 @@ interrogate
|
|||
isort
|
||||
# Note: used for kwarray.group_items, this may be ported to mmcv in the future.
|
||||
kwarray
|
||||
parameterized
|
||||
pytest
|
||||
pytest-cov
|
||||
pytest-runner
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine import InstanceData
|
||||
from nose_parameterized import parameterized
|
||||
|
||||
from mmocr.core import TextDetDataSample
|
||||
from mmocr.models.textdet.postprocessors import DBPostprocessor
|
||||
|
||||
|
||||
class TestDBPostProcessor(unittest.TestCase):
    """Unit tests for ``DBPostprocessor``."""

    def test_get_bbox_score(self):
        """The score of a small square is the mean of the cells under it."""
        postprocessor = DBPostprocessor()
        # 4x5 ramp holding 0.00, 0.05, ..., 0.95 in row-major order.
        ramp = np.arange(0, 1, step=0.05).reshape(4, 5)
        square = np.array(((0, 0), (0, 1), (1, 1), (1, 0)))
        # 0.15 is the mean of the four top-left cells
        # (0.00, 0.05, 0.25, 0.30).
        score = postprocessor._get_bbox_score(ramp, square)
        self.assertAlmostEqual(score, 0.15)

    @parameterized.expand(['poly', 'quad'])
    def test_get_text_instances(self, text_repr_type):
        """Predictions are filled in, and filtered out by a high threshold."""
        postprocessor = DBPostprocessor(text_repr_type=text_repr_type)
        pred_result = {'prob_map': torch.rand(4, 5)}
        gt_polygons = [
            np.array([0, 0, 0, 1, 2, 1, 2, 0]),
            np.array([1, 1, 1, 2, 3, 2, 3, 1]),
        ]
        data_sample = TextDetDataSample(
            metainfo={'scale_factor': (0.5, 1)},
            gt_instances=InstanceData(polygons=gt_polygons))

        # With default thresholds only the presence of the fields is checked,
        # because the probability map is random.
        results = postprocessor.get_text_instances(pred_result, data_sample)
        pred_instances = results.pred_instances
        self.assertIn('polygons', pred_instances)
        self.assertIn('scores', pred_instances)

        # Probabilities are capped below 0.8, so a minimum text score of 1
        # must reject every candidate.
        postprocessor = DBPostprocessor(
            min_text_score=1, text_repr_type=text_repr_type)
        pred_result = {'prob_map': torch.rand(4, 5) * 0.8}
        results = postprocessor.get_text_instances(pred_result, data_sample)
        self.assertEqual(results.pred_instances.polygons, [])
        self.assertEqual(results.pred_instances.scores, [])
|
Loading…
Reference in New Issue