[PSE] PSE Postprocessor

pull/1178/head
wangxinyu 2022-06-21 09:25:55 +00:00 committed by gaotongxiao
parent 4a04982806
commit c0c0f4b565
2 changed files with 91 additions and 40 deletions

View File

@ -1,68 +1,83 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import cv2
import numpy as np
import torch
from mmcv.ops import contour_expand
from mmengine.data import InstanceData
from mmocr.core import points2boundary
from mmocr.core import TextDetDataSample
from mmocr.registry import MODELS
from .base_postprocessor import BasePostprocessor
from .pan_postprocessor import PANPostprocessor
@MODELS.register_module()
class PSEPostprocessor(BasePostprocessor):
class PSEPostprocessor(PANPostprocessor):
"""Decoding predictions of PSENet to instances. This is partially adapted
from https://github.com/whai362/PSENet.
Args:
text_repr_type (str): The boundary encoding type 'poly' or 'quad'.
Defaults to 'poly'.
rescale_fields (list[str]): The bbox/polygon field names to
be rescaled. If None, no rescaling will be performed. Defaults to
['polygons'].
min_kernel_confidence (float): The minimal kernel confidence.
min_text_avg_confidence (float): The minimal text average confidence.
min_kernel_area (int): The minimal text kernel area.
Defaults to 0.5.
score_threshold (float): The minimal text average confidence.
Defaults to 0.3.
min_kernel_area (int): The minimal text kernel area. Defaults to 0.
min_text_area (int): The minimal text instance region area.
Defaults to 16.
downsample_ratio (float): Downsample ratio. Defaults to 0.25.
"""
def __init__(self,
text_repr_type='poly',
min_kernel_confidence=0.5,
min_text_avg_confidence=0.85,
min_kernel_area=0,
min_text_area=16,
**kwargs):
super().__init__(text_repr_type)
assert 0 <= min_kernel_confidence <= 1
assert 0 <= min_text_avg_confidence <= 1
assert isinstance(min_kernel_area, int)
assert isinstance(min_text_area, int)
self.min_kernel_confidence = min_kernel_confidence
self.min_text_avg_confidence = min_text_avg_confidence
text_repr_type: str = 'poly',
rescale_fields: List[str] = ['polygons'],
min_kernel_confidence: float = 0.5,
score_threshold: float = 0.3,
min_kernel_area: int = 0,
min_text_area: int = 16,
downsample_ratio: float = 0.25) -> None:
super().__init__(
text_repr_type=text_repr_type,
rescale_fields=rescale_fields,
min_kernel_confidence=min_kernel_confidence,
score_threshold=score_threshold,
min_text_area=min_text_area,
downsample_ratio=downsample_ratio)
self.min_kernel_area = min_kernel_area
self.min_text_area = min_text_area
def __call__(self, preds):
def get_text_instances(self, pred_results: torch.Tensor,
data_sample: TextDetDataSample,
**kwargs) -> TextDetDataSample:
"""
Args:
preds (Tensor): Prediction map with shape :math:`(C, H, W)`.
pred_results (torch.Tensor): Prediction results of an image which
is a tensor of shape :math:`(N, H, W)`.
data_sample (TextDetDataSample): Datasample of an image.
Returns:
list[list[float]]: The instance boundary and its confidence.
TextDetDataSample: A new DataSample with predictions filled in.
Polygons and results are saved in
``TextDetDataSample.pred_instances.polygons``. The confidence
scores are saved in ``TextDetDataSample.pred_instances.scores``.
"""
assert preds.dim() == 3
assert pred_results.dim() == 3
preds = torch.sigmoid(preds) # text confidence
pred_results = torch.sigmoid(pred_results) # text confidence
score = preds[0, :, :]
masks = preds > self.min_kernel_confidence
masks = pred_results > self.min_kernel_confidence
text_mask = masks[0, :, :]
kernel_masks = masks[0:, :, :] * text_mask
score = score.data.cpu().numpy().astype(np.float32)
kernel_masks = kernel_masks.data.cpu().numpy().astype(np.uint8)
score = pred_results[0, :, :]
score = score.data.cpu().numpy().astype(np.float32)
region_num, labels = cv2.connectedComponents(
kernel_masks[-1], connectivity=4)
@ -70,19 +85,25 @@ class PSEPostprocessor(BasePostprocessor):
region_num)
labels = np.array(labels)
label_num = np.max(labels)
boundaries = []
polygons = []
scores = []
for i in range(1, label_num + 1):
points = np.array(np.where(labels == i)).transpose((1, 0))[:, ::-1]
area = points.shape[0]
score_instance = np.mean(score[labels == i])
if not self.is_valid_instance(area, score_instance,
self.min_text_area,
self.min_text_avg_confidence):
if not (area >= self.min_text_area
or score_instance > self.score_threshold):
continue
vertices_confidence = points2boundary(points, self.text_repr_type,
score_instance)
if vertices_confidence is not None:
boundaries.append(vertices_confidence)
polygon = self._points2boundary(points)
if polygon:
polygons.append(polygon)
scores.append(score_instance)
return boundaries
pred_instances = InstanceData()
pred_instances.polygons = polygons
pred_instances.scores = torch.FloatTensor(scores)
data_sample.pred_instances = pred_instances
return data_sample

View File

@ -0,0 +1,30 @@
# Copyright (c) OpenMMLab. All rights reserved.
import unittest
import torch
from nose_parameterized import parameterized
from mmocr.core import TextDetDataSample
from mmocr.models.textdet.postprocessors import PSEPostprocessor
class TestPSEPostprocessor(unittest.TestCase):
    """Unit tests for ``PSEPostprocessor.get_text_instances``."""

    # Explicit 1-tuples: ``('poly')`` is just the string ``'poly'``, which
    # only works when the parameterization library special-cases strings.
    @parameterized.expand([('poly', ), ('quad', )])
    def test_get_text_instances(self, text_repr_type):
        """Check the output fields and the "nothing detected" case.

        Args:
            text_repr_type (str): Boundary encoding passed to the
                postprocessor, either ``'poly'`` or ``'quad'``.
        """
        # Case 1: default thresholds on random input. Only the presence of
        # the output fields is checked, since the number of detected
        # instances depends on the random values.
        postprocessor = PSEPostprocessor(text_repr_type=text_repr_type)
        pred_result = torch.rand(6, 4, 5)
        data_sample = TextDetDataSample(metainfo=dict(scale_factor=(0.5, 1)))
        results = postprocessor.get_text_instances(pred_result, data_sample)
        self.assertIn('polygons', results.pred_instances)
        self.assertIn('scores', results.pred_instances)

        # Case 2: thresholds of 1 can never be exceeded by sigmoid outputs,
        # so no kernel pixel survives and no instance may be produced.
        postprocessor = PSEPostprocessor(
            score_threshold=1,
            min_kernel_confidence=1,
            text_repr_type=text_repr_type)
        pred_result = torch.rand(6, 4, 5) * 0.8
        # ``data_sample`` is reused on purpose: ``get_text_instances``
        # assigns a fresh ``pred_instances`` on every call.
        results = postprocessor.get_text_instances(pred_result, data_sample)
        self.assertEqual(results.pred_instances.polygons, [])
        # Count elements instead of comparing against an empty tensor: an
        # elementwise ``==`` with a 0-element tensor broadcasts a 1-element
        # tensor down to an empty result, whose ``.all()`` is always true.
        self.assertEqual(results.pred_instances.scores.numel(), 0)