[DBNet] Add DBPostProcessor

pull/1178/head
gaotongxiao 2022-05-30 08:32:04 +00:00
parent cd3d173b18
commit 7a66a84b64
4 changed files with 154 additions and 42 deletions
mmocr/models/textdet/postprocessors
requirements
tests/test_models/test_textdet/test_postprocessors

View File

@ -1,58 +1,81 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence
import cv2
import numpy as np
import torch
from mmengine import InstanceData
from shapely.geometry import Polygon
from mmocr.core import points2boundary
from mmocr.core import TextDetDataSample
from mmocr.registry import MODELS
from .base_postprocessor import BasePostprocessor
from .utils import box_score_fast, unclip
from mmocr.utils import offset_polygon
from .base_postprocessor import BaseTextDetPostProcessor
@MODELS.register_module()
class DBPostprocessor(BasePostprocessor):
class DBPostprocessor(BaseTextDetPostProcessor):
"""Decoding predictions of DbNet to instances. This is partially adapted
from https://github.com/MhLiao/DB.
Args:
text_repr_type (str): The boundary encoding type 'poly' or 'quad'.
mask_thr (float): The mask threshold value for binarization.
Defaults to 'poly'.
rescale_fields (list[str]): The bbox/polygon field names to
be rescaled. If None, no rescaling will be performed. Defaults to
['polygons'].
mask_thr (float): The mask threshold value for binarization. Defaults
to 0.3.
min_text_score (float): The threshold value for converting binary map
to shrink text regions.
to shrink text regions. Defaults to 0.3.
min_text_width (int): The minimum width of boundary polygon/box
predicted.
predicted. Defaults to 5.
unclip_ratio (float): The unclip ratio for text regions dilation.
epsilon_ratio (float): The epsilon ratio for approximation accuracy.
max_candidates (int): The maximum candidate number.
Defaults to 1.5.
max_candidates (int): The maximum candidate number. Defaults to 3000.
"""
def __init__(self,
text_repr_type='poly',
mask_thr=0.3,
min_text_score=0.3,
min_text_width=5,
unclip_ratio=1.5,
epsilon_ratio=0.01,
max_candidates=3000,
**kwargs):
super().__init__(text_repr_type)
text_repr_type: str = 'poly',
rescale_fields: Sequence[str] = ['polygons'],
mask_thr: float = 0.3,
min_text_score: float = 0.3,
min_text_width: int = 5,
unclip_ratio: float = 1.5,
max_candidates: int = 3000,
**kwargs) -> None:
super().__init__(
text_repr_type=text_repr_type,
rescale_fields=rescale_fields,
**kwargs)
self.mask_thr = mask_thr
self.min_text_score = min_text_score
self.min_text_width = min_text_width
self.unclip_ratio = unclip_ratio
self.epsilon_ratio = epsilon_ratio
self.max_candidates = max_candidates
def __call__(self, preds):
"""
def get_text_instances(self, pred_results: dict,
data_sample: TextDetDataSample
) -> TextDetDataSample:
"""Get text instance predictions of one image.
Args:
preds (Tensor): Prediction map with shape :math:`(C, H, W)`.
pred_result (dict): Prediction results of an image containing the
``prob_map``, which is a tensor of shape :math:`(N, H, W)`.
data_sample (TextDetDataSample): Datasample of an image.
Returns:
list[list[float]]: The predicted text boundaries.
TextDetDataSample: A new DataSample with predictions filled in.
Polygons and results are saved in
``TextDetDataSample.pred_instances.polygons``. The confidence
scores are saved in ``TextDetDataSample.pred_instances.scores``.
"""
assert preds.dim() == 3
prob_map = preds[0, :, :]
data_sample.pred_instances = InstanceData()
data_sample.pred_instances.polygons = []
data_sample.pred_instances.scores = []
prob_map = pred_results['prob_map']
text_mask = prob_map > self.mask_thr
score_map = prob_map.data.cpu().numpy().astype(np.float32)
@ -61,34 +84,80 @@ class DBPostprocessor(BasePostprocessor):
contours, _ = cv2.findContours((text_mask * 255).astype(np.uint8),
cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
boundaries = []
for i, poly in enumerate(contours):
if i > self.max_candidates:
break
epsilon = self.epsilon_ratio * cv2.arcLength(poly, True)
epsilon = 0.01 * cv2.arcLength(poly, True)
approx = cv2.approxPolyDP(poly, epsilon, True)
points = approx.reshape((-1, 2))
if points.shape[0] < 4:
poly_pts = approx.reshape((-1, 2))
if poly_pts.shape[0] < 4:
continue
score = box_score_fast(score_map, points)
score = self._get_bbox_score(score_map, poly_pts)
if score < self.min_text_score:
continue
poly = unclip(points, unclip_ratio=self.unclip_ratio)
if len(poly) == 0 or isinstance(poly[0], list):
poly = self._unclip(poly_pts)
# If the result polygon does not exist, or it is split into
# multiple polygons, skip it.
if len(poly) == 0 or isinstance(poly, list):
continue
poly = poly.reshape(-1, 2)
if self.text_repr_type == 'quad':
poly = points2boundary(poly, self.text_repr_type, score,
self.min_text_width)
rect = cv2.minAreaRect(poly)
vertices = cv2.boxPoints(rect)
poly = vertices.flatten() if min(
rect[1]) >= self.min_text_width else []
elif self.text_repr_type == 'poly':
poly = poly.flatten().tolist()
if score is not None:
poly = poly + [score]
if len(poly) < 8:
poly = None
poly = poly.flatten()
if poly is not None:
boundaries.append(poly)
if len(poly) < 8:
poly = np.array([], dtype=np.float32)
return boundaries
if len(poly) > 0:
data_sample.pred_instances.polygons.append(poly)
data_sample.pred_instances.scores.append(
torch.FloatTensor([score]))
return data_sample
def _get_bbox_score(self, score_map: np.ndarray,
poly_pts: np.ndarray) -> float:
"""Compute the average score over the area of the bounding box of the
polygon.
Args:
score_map (np.ndarray): The score map.
poly_pts (np.ndarray): The polygon points.
Returns:
float: The average score.
"""
h, w = score_map.shape[:2]
poly_pts = poly_pts.copy()
xmin = np.clip(
np.floor(poly_pts[:, 0].min()).astype(np.int32), 0, w - 1)
xmax = np.clip(
np.ceil(poly_pts[:, 0].max()).astype(np.int32), 0, w - 1)
ymin = np.clip(
np.floor(poly_pts[:, 1].min()).astype(np.int32), 0, h - 1)
ymax = np.clip(
np.ceil(poly_pts[:, 1].max()).astype(np.int32), 0, h - 1)
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
poly_pts[:, 0] = poly_pts[:, 0] - xmin
poly_pts[:, 1] = poly_pts[:, 1] - ymin
cv2.fillPoly(mask, poly_pts.reshape(1, -1, 2).astype(np.int32), 1)
return cv2.mean(score_map[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def _unclip(self, poly_pts: np.ndarray) -> np.ndarray:
"""Unclip a polygon.
Args:
poly_pts (np.ndarray): The polygon points.
Returns:
np.ndarray: The expanded polygon points.
"""
poly = Polygon(poly_pts)
distance = poly.area * self.unclip_ratio / poly.length
return offset_polygon(poly_pts, distance)

View File

@ -5,6 +5,7 @@ interrogate
isort
# Note: used for kwarray.group_items, this may be ported to mmcv in the future.
kwarray
parameterized
pytest
pytest-cov
pytest-runner

View File

@ -0,0 +1,42 @@
# Copyright (c) OpenMMLab. All rights reserved.
import unittest
import numpy as np
import torch
from mmengine import InstanceData
from nose_parameterized import parameterized
from mmocr.core import TextDetDataSample
from mmocr.models.textdet.postprocessors import DBPostprocessor
class TestDBPostProcessor(unittest.TestCase):
def test_get_bbox_score(self):
postprocessor = DBPostprocessor()
score_map = np.arange(0, 1, step=0.05).reshape(4, 5)
poly_pts = np.array(((0, 0), (0, 1), (1, 1), (1, 0)))
self.assertAlmostEqual(
postprocessor._get_bbox_score(score_map, poly_pts), 0.15)
@parameterized.expand([('poly'), ('quad')])
def test_get_text_instances(self, text_repr_type):
postprocessor = DBPostprocessor(text_repr_type=text_repr_type)
pred_result = dict(prob_map=torch.rand(4, 5))
data_sample = TextDetDataSample(
metainfo=dict(scale_factor=(0.5, 1)),
gt_instances=InstanceData(polygons=[
np.array([0, 0, 0, 1, 2, 1, 2, 0]),
np.array([1, 1, 1, 2, 3, 2, 3, 1])
]))
results = postprocessor.get_text_instances(pred_result, data_sample)
self.assertIn('polygons', results.pred_instances)
self.assertIn('scores', results.pred_instances)
postprocessor = DBPostprocessor(
min_text_score=1, text_repr_type=text_repr_type)
pred_result = dict(prob_map=torch.rand(4, 5) * 0.8)
results = postprocessor.get_text_instances(pred_result, data_sample)
self.assertEqual(results.pred_instances.polygons, [])
self.assertEqual(results.pred_instances.scores, [])