mirror of https://github.com/open-mmlab/mmocr.git
[PSE] PSE Postprocessor
parent
4a04982806
commit
c0c0f4b565
|
@ -1,68 +1,83 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
from typing import List
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmcv.ops import contour_expand
|
||||
from mmengine.data import InstanceData
|
||||
|
||||
from mmocr.core import points2boundary
|
||||
from mmocr.core import TextDetDataSample
|
||||
from mmocr.registry import MODELS
|
||||
from .base_postprocessor import BasePostprocessor
|
||||
from .pan_postprocessor import PANPostprocessor
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class PSEPostprocessor(BasePostprocessor):
|
||||
class PSEPostprocessor(PANPostprocessor):
|
||||
"""Decoding predictions of PSENet to instances. This is partially adapted
|
||||
from https://github.com/whai362/PSENet.
|
||||
|
||||
Args:
|
||||
text_repr_type (str): The boundary encoding type 'poly' or 'quad'.
|
||||
Defaults to 'poly'.
|
||||
rescale_fields (list[str]): The bbox/polygon field names to
|
||||
be rescaled. If None, no rescaling will be performed. Defaults to
|
||||
['polygons'].
|
||||
min_kernel_confidence (float): The minimal kernel confidence.
|
||||
min_text_avg_confidence (float): The minimal text average confidence.
|
||||
min_kernel_area (int): The minimal text kernel area.
|
||||
Defaults to 0.5.
|
||||
score_threshold (float): The minimal text average confidence.
|
||||
Defaults to 0.3.
|
||||
min_kernel_area (int): The minimal text kernel area. Defaults to 0.
|
||||
min_text_area (int): The minimal text instance region area.
|
||||
Defaults to 16.
|
||||
downsample_ratio (float): Downsample ratio. Defaults to 0.25.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
text_repr_type='poly',
|
||||
min_kernel_confidence=0.5,
|
||||
min_text_avg_confidence=0.85,
|
||||
min_kernel_area=0,
|
||||
min_text_area=16,
|
||||
**kwargs):
|
||||
super().__init__(text_repr_type)
|
||||
|
||||
assert 0 <= min_kernel_confidence <= 1
|
||||
assert 0 <= min_text_avg_confidence <= 1
|
||||
assert isinstance(min_kernel_area, int)
|
||||
assert isinstance(min_text_area, int)
|
||||
|
||||
self.min_kernel_confidence = min_kernel_confidence
|
||||
self.min_text_avg_confidence = min_text_avg_confidence
|
||||
text_repr_type: str = 'poly',
|
||||
rescale_fields: List[str] = ['polygons'],
|
||||
min_kernel_confidence: float = 0.5,
|
||||
score_threshold: float = 0.3,
|
||||
min_kernel_area: int = 0,
|
||||
min_text_area: int = 16,
|
||||
downsample_ratio: float = 0.25) -> None:
|
||||
super().__init__(
|
||||
text_repr_type=text_repr_type,
|
||||
rescale_fields=rescale_fields,
|
||||
min_kernel_confidence=min_kernel_confidence,
|
||||
score_threshold=score_threshold,
|
||||
min_text_area=min_text_area,
|
||||
downsample_ratio=downsample_ratio)
|
||||
self.min_kernel_area = min_kernel_area
|
||||
self.min_text_area = min_text_area
|
||||
|
||||
def __call__(self, preds):
|
||||
def get_text_instances(self, pred_results: torch.Tensor,
|
||||
data_sample: TextDetDataSample,
|
||||
**kwargs) -> TextDetDataSample:
|
||||
"""
|
||||
Args:
|
||||
preds (Tensor): Prediction map with shape :math:`(C, H, W)`.
|
||||
pred_result (torch.Tensor): Prediction results of an image which
|
||||
is a tensor of shape :math:`(N, H, W)`.
|
||||
data_sample (TextDetDataSample): Datasample of an image.
|
||||
|
||||
Returns:
|
||||
list[list[float]]: The instance boundary and its confidence.
|
||||
TextDetDataSample: A new DataSample with predictions filled in.
|
||||
Polygons and results are saved in
|
||||
``TextDetDataSample.pred_instances.polygons``. The confidence
|
||||
scores are saved in ``TextDetDataSample.pred_instances.scores``.
|
||||
"""
|
||||
assert preds.dim() == 3
|
||||
assert pred_results.dim() == 3
|
||||
|
||||
preds = torch.sigmoid(preds) # text confidence
|
||||
pred_results = torch.sigmoid(pred_results) # text confidence
|
||||
|
||||
score = preds[0, :, :]
|
||||
masks = preds > self.min_kernel_confidence
|
||||
masks = pred_results > self.min_kernel_confidence
|
||||
text_mask = masks[0, :, :]
|
||||
kernel_masks = masks[0:, :, :] * text_mask
|
||||
|
||||
score = score.data.cpu().numpy().astype(np.float32)
|
||||
|
||||
kernel_masks = kernel_masks.data.cpu().numpy().astype(np.uint8)
|
||||
|
||||
score = pred_results[0, :, :]
|
||||
score = score.data.cpu().numpy().astype(np.float32)
|
||||
|
||||
region_num, labels = cv2.connectedComponents(
|
||||
kernel_masks[-1], connectivity=4)
|
||||
|
||||
|
@ -70,19 +85,25 @@ class PSEPostprocessor(BasePostprocessor):
|
|||
region_num)
|
||||
labels = np.array(labels)
|
||||
label_num = np.max(labels)
|
||||
boundaries = []
|
||||
|
||||
polygons = []
|
||||
scores = []
|
||||
for i in range(1, label_num + 1):
|
||||
points = np.array(np.where(labels == i)).transpose((1, 0))[:, ::-1]
|
||||
area = points.shape[0]
|
||||
score_instance = np.mean(score[labels == i])
|
||||
if not self.is_valid_instance(area, score_instance,
|
||||
self.min_text_area,
|
||||
self.min_text_avg_confidence):
|
||||
if not (area >= self.min_text_area
|
||||
or score_instance > self.score_threshold):
|
||||
continue
|
||||
|
||||
vertices_confidence = points2boundary(points, self.text_repr_type,
|
||||
score_instance)
|
||||
if vertices_confidence is not None:
|
||||
boundaries.append(vertices_confidence)
|
||||
polygon = self._points2boundary(points)
|
||||
if polygon:
|
||||
polygons.append(polygon)
|
||||
scores.append(score_instance)
|
||||
|
||||
return boundaries
|
||||
pred_instances = InstanceData()
|
||||
pred_instances.polygons = polygons
|
||||
pred_instances.scores = torch.FloatTensor(scores)
|
||||
data_sample.pred_instances = pred_instances
|
||||
|
||||
return data_sample
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from nose_parameterized import parameterized
|
||||
|
||||
from mmocr.core import TextDetDataSample
|
||||
from mmocr.models.textdet.postprocessors import PSEPostprocessor
|
||||
|
||||
|
||||
class TestPSEPostprocessor(unittest.TestCase):
|
||||
|
||||
@parameterized.expand([('poly'), ('quad')])
|
||||
def test_get_text_instances(self, text_repr_type):
|
||||
postprocessor = PSEPostprocessor(text_repr_type=text_repr_type)
|
||||
pred_result = torch.rand(6, 4, 5)
|
||||
data_sample = TextDetDataSample(metainfo=dict(scale_factor=(0.5, 1)))
|
||||
results = postprocessor.get_text_instances(pred_result, data_sample)
|
||||
self.assertIn('polygons', results.pred_instances)
|
||||
self.assertIn('scores', results.pred_instances)
|
||||
|
||||
postprocessor = PSEPostprocessor(
|
||||
score_threshold=1,
|
||||
min_kernel_confidence=1,
|
||||
text_repr_type=text_repr_type)
|
||||
pred_result = torch.rand(6, 4, 5) * 0.8
|
||||
results = postprocessor.get_text_instances(pred_result, data_sample)
|
||||
self.assertEqual(results.pred_instances.polygons, [])
|
||||
self.assertTrue(
|
||||
(results.pred_instances.scores == torch.FloatTensor([])).all())
|
Loading…
Reference in New Issue