mirror of https://github.com/open-mmlab/mmocr.git
204 lines
8.1 KiB
Python
204 lines
8.1 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from functools import partial
|
|
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
|
|
|
import mmengine
|
|
import numpy as np
|
|
from torch import Tensor
|
|
|
|
from mmocr.structures import TextDetDataSample
|
|
from mmocr.utils import boundary_iou, rescale_polygons
|
|
|
|
|
|
class BaseTextDetPostProcessor:
    """Base postprocessor for text detection models.

    Calling an instance splits the batched predictions into per-image
    results, converts each of them into text instances through
    :meth:`get_text_instances` (to be implemented by subclasses) and
    optionally rescales the fields listed in ``rescale_fields``.

    Args:
        text_repr_type (str): The boundary encoding type, 'poly' or 'quad'.
            Defaults to 'poly'.
        rescale_fields (list[str], optional): The bbox/polygon field names to
            be rescaled. Must be a ``list`` when given. If None, no rescaling
            will be performed.
        train_cfg (dict, optional): The parameters to be passed to
            ``self.get_text_instances`` in training. Defaults to None.
        test_cfg (dict, optional): The parameters to be passed to
            ``self.get_text_instances`` in testing. Defaults to None.
    """

    def __init__(self,
                 text_repr_type: str = 'poly',
                 rescale_fields: Optional[Sequence[str]] = None,
                 train_cfg: Optional[Dict] = None,
                 test_cfg: Optional[Dict] = None) -> None:
        assert text_repr_type in ['poly', 'quad']
        assert rescale_fields is None or isinstance(rescale_fields, list)
        assert train_cfg is None or isinstance(train_cfg, dict)
        assert test_cfg is None or isinstance(test_cfg, dict)
        self.text_repr_type = text_repr_type
        self.rescale_fields = rescale_fields
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def __call__(self,
                 pred_results: dict,
                 data_samples: Sequence[TextDetDataSample],
                 training: bool = False) -> Sequence[TextDetDataSample]:
        """Postprocess pred_results according to metainfos in data_samples.

        Args:
            pred_results (dict): The prediction results stored in a dictionary.
                Usually each item to be post-processed is expected to be a
                batched tensor.
            data_samples (list[TextDetDataSample]): Batch of data_samples,
                each corresponding to a prediction result.
            training (bool): Whether the model is in training mode. Defaults to
                False.

        Returns:
            list[TextDetDataSample]: Batch of post-processed datasamples.
        """
        # Pick the mode-specific keyword arguments; a missing config simply
        # means that no extra arguments are forwarded.
        cfg = self.train_cfg if training else self.test_cfg
        if cfg is None:
            cfg = {}
        pred_results = self.split_results(pred_results)
        process_single = partial(self._process_single, **cfg)
        # Pair the i-th split prediction with the i-th data sample.
        results = list(map(process_single, pred_results, data_samples))

        return results

    def _process_single(self, pred_result: dict,
                        data_sample: TextDetDataSample,
                        **kwargs) -> TextDetDataSample:
        """Process prediction results from one image.

        Args:
            pred_result (dict): Prediction results of an image.
            data_sample (TextDetDataSample): Datasample of an image.
            **kwargs: Extra parameters forwarded to
                ``self.get_text_instances``.

        Returns:
            TextDetDataSample: The result of ``self.get_text_instances``,
            rescaled by ``data_sample.scale_factor`` when
            ``self.rescale_fields`` is non-empty.
        """

        results = self.get_text_instances(pred_result, data_sample, **kwargs)

        # ``None`` and an empty list both mean "nothing to rescale".
        if self.rescale_fields:
            assert isinstance(self.rescale_fields, list)
            # Every requested field must have been produced by
            # ``get_text_instances``.
            assert set(self.rescale_fields).issubset(
                set(results.pred_instances.keys()))
            results = self.rescale(results, data_sample.scale_factor)
        return results

    def rescale(self, results: TextDetDataSample,
                scale_factor: Sequence[int]) -> TextDetDataSample:
        """Rescale results in ``results.pred_instances`` according to
        ``scale_factor``, whose keys are defined in ``self.rescale_fields``.
        Usually used to rescale bboxes and/or polygons.

        Args:
            results (TextDetDataSample): The post-processed prediction results.
            scale_factor (tuple(int)): (w_scale, h_scale)

        Returns:
            TextDetDataSample: Prediction results with rescaled results.
        """
        scale_factor = np.asarray(scale_factor)
        for key in self.rescale_fields:
            # The fields are overwritten in place on ``results``.
            results.pred_instances[key] = rescale_polygons(
                results.pred_instances[key], scale_factor, mode='div')
        return results

    def get_text_instances(self, pred_results: dict,
                           data_sample: TextDetDataSample,
                           **kwargs) -> TextDetDataSample:
        """Get text instance predictions of one image.

        Args:
            pred_results (dict): Prediction results of an image.
            data_sample (TextDetDataSample): Datasample of an image.
            **kwargs: Other parameters. Configurable via ``__init__.train_cfg``
                and ``__init__.test_cfg``.

        Returns:
            TextDetDataSample: A new DataSample with predictions filled in.
            The polygon/bbox results are usually saved in
            ``TextDetDataSample.pred_instances.polygons`` or
            ``TextDetDataSample.pred_instances.bboxes``. The confidence scores
            are saved in ``TextDetDataSample.pred_instances.scores``.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError

    def split_results(
        self, pred_results: Union[Tensor, List[Tensor]]
    ) -> Union[List[Tensor], List[List[Tensor]]]:
        """Split batched tensor(s) along the first dimension and pack split
        tensors into a list.

        Args:
            pred_results (tensor or list[tensor]): Raw result tensor(s) from
                detection head. Each tensor usually has the shape of (N, ...)

        Returns:
            list[tensor] or list[list[tensor]]: N tensors if ``pred_results``
            is a tensor, or a list of N lists of tensors if
            ``pred_results`` is a list of tensors. An empty sequence yields
            an empty list.
        """
        is_single_tensor = isinstance(pred_results, Tensor)
        assert is_single_tensor or mmengine.is_seq_of(pred_results, Tensor)

        if is_single_tensor:
            return [pred_results[i] for i in range(len(pred_results))]

        # An empty sequence that got past the check above has no first tensor
        # to infer the batch size from, so there is nothing to split.
        if len(pred_results) == 0:
            return []

        batch_num = pred_results[0].shape[0]
        for tensor in pred_results[1:]:
            assert batch_num == tensor.shape[0], \
                'The first dimension of all tensors should be the same'

        # Regroup from "one tensor per output" to "one list per image".
        return [[tensor[i] for tensor in pred_results]
                for i in range(batch_num)]

    def poly_nms(self, polygons: List[np.ndarray], scores: List[float],
                 threshold: float) -> Tuple[List[List[float]], List[float]]:
        """Non-maximum suppression for text detection.

        Polygons are visited from the highest score to the lowest. Each
        visited polygon is kept, and every remaining polygon whose
        ``boundary_iou`` with it exceeds ``threshold`` is discarded.

        Args:
            polygons (list[ndarray]): List of polygons, each a flat 1-D array
                of coordinates. Polygons may have different lengths.
            scores (list[float]): List of scores.
            threshold (float): Threshold for NMS.

        Returns:
            tuple(keep_polys, keep_scores):

            - keep_polys (list[list[float]]): List of preserved polygons
              after NMS, each converted to a flat list of coordinates.
            - keep_scores (list[float]): List of preserved scores after NMS.
        """
        assert isinstance(polygons, list)
        assert isinstance(scores, list)
        assert len(polygons) == len(scores)

        # Append every score to its polygon so both travel together through
        # the sort. The candidates are kept in a plain list rather than being
        # stacked into one ndarray: polygons with different numbers of
        # vertices would form a ragged array, which NumPy refuses to build.
        candidates = [
            np.hstack((polygon, score))
            for polygon, score in zip(polygons, scores)
        ]
        # Stable ascending sort by score; the best candidate is at the end.
        candidates.sort(key=lambda candidate: candidate[-1])

        keep_polys = []
        keep_scores = []
        remaining = list(range(len(candidates)))

        while remaining:
            best = candidates[remaining.pop()]
            best_poly = best[:-1]
            keep_polys.append(best_poly.tolist())
            keep_scores.append(best[-1])

            # Drop everything that overlaps the kept polygon too much.
            # ``not (iou > threshold)`` (instead of ``iou <= threshold``)
            # keeps candidates whose IoU is NaN, as a plain ``>`` test does.
            remaining = [
                idx for idx in remaining
                if not (boundary_iou(best_poly, candidates[idx][:-1], 1) >
                        threshold)
            ]

        return keep_polys, keep_scores
|