[Feature] TextRecogPostprocessor

pull/1178/head
liukuikun 2022-05-20 06:17:13 +00:00 committed by gaotongxiao
parent e8f57d6540
commit a05e3f19c5
8 changed files with 439 additions and 2 deletions

View File

@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import (backbones, convertors, decoders, dictionary, encoders, fusers,
heads, losses, necks, plugins, preprocessor, recognizer)
heads, losses, necks, plugins, postprocessor, preprocessor,
recognizer)
from .backbones import * # NOQA
from .convertors import * # NOQA
from .decoders import * # NOQA
@ -11,6 +12,7 @@ from .heads import * # NOQA
from .losses import * # NOQA
from .necks import * # NOQA
from .plugins import * # NOQA
from .postprocessor import * # NOQA
from .preprocessor import * # NOQA
from .recognizer import * # NOQA
@ -18,4 +20,4 @@ __all__ = (
backbones.__all__ + convertors.__all__ + decoders.__all__ +
encoders.__all__ + heads.__all__ + losses.__all__ + necks.__all__ +
preprocessor.__all__ + recognizer.__all__ + fusers.__all__ +
plugins.__all__ + dictionary.__all__)
plugins.__all__ + dictionary.__all__ + postprocessor.__all__)

View File

@ -0,0 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .attn_postprocessor import AttentionPostprocessor
from .base_textrecog_postprocessor import BaseTextRecogPostprocessor
from .ctc_postprocessor import CTCPostProcessor
__all__ = [
'BaseTextRecogPostprocessor', 'AttentionPostprocessor', 'CTCPostProcessor'
]

View File

@ -0,0 +1,41 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence, Tuple
import torch
from mmocr.core.data_structures import TextRecogDataSample
from mmocr.registry import MODELS
from .base_textrecog_postprocessor import BaseTextRecogPostprocessor
@MODELS.register_module()
class AttentionPostprocessor(BaseTextRecogPostprocessor):
    """PostProcessor for seq2seq."""

    def get_single_prediction(
        self,
        output: torch.Tensor,
        data_sample: Optional[TextRecogDataSample] = None,
    ) -> Tuple[Sequence[int], Sequence[float]]:
        """Decode the output of a single image into indexes and scores.

        The best-scoring entry along the last dimension of ``output`` is
        taken at every step. Entries found in ``self.ignore_indexes`` are
        skipped, and decoding stops at the first non-ignored entry equal to
        ``self.dictionary.end_idx``.

        Args:
            output (torch.Tensor): Single image output.
            data_sample (TextRecogDataSample, optional): Datasample of an
                image. It is not read by this method. Defaults to None.

        Returns:
            tuple(list[int], list[float]): The kept indexes and the score
            of each kept index.
        """
        best_scores, best_indexes = torch.max(output, -1)
        step_indexes = best_indexes.cpu().detach().numpy().tolist()
        step_scores = best_scores.cpu().detach().numpy().tolist()

        kept_indexes, kept_scores = [], []
        for step, candidate in enumerate(step_indexes):
            # The ignore check runs first, so an ignored index never
            # terminates decoding.
            if candidate in self.ignore_indexes:
                continue
            if candidate == self.dictionary.end_idx:
                break
            kept_indexes.append(candidate)
            kept_scores.append(step_scores[step])
        return kept_indexes, kept_scores

View File

@ -0,0 +1,100 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Sequence, Tuple, Union
import mmcv
import torch
from mmocr.core.data_structures import TextRecogDataSample
from mmocr.models.textrecog.dictionary import Dictionary
from mmocr.registry import TASK_UTILS
class BaseTextRecogPostprocessor:
    """Base text recognition postprocessor.

    Args:
        dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
            the instance of `Dictionary`.
        max_seq_len (int): Maximum sequence length. The sequence is usually
            generated from decoder. Defaults to 40.
        ignore_chars (list[str]): A list of characters to be ignored from the
            final results. Postprocessor will skip over these characters when
            converting raw indexes to characters. Apart from single characters,
            each item can be one of the following reserved keywords: 'padding',
            'end' and 'unknown', which refer to their corresponding special
            tokens in the dictionary. Defaults to ['padding'].
        **kwargs: Additional keyword arguments. They are accepted but not
            used by this class.

    Raises:
        TypeError: If ``dictionary`` is neither a dict nor a `Dictionary`, or
            if ``ignore_chars`` is not a list of str.
        ValueError: If an item of ``ignore_chars`` cannot be resolved to an
            index in the dictionary.
    """

    def __init__(self,
                 dictionary: Union[Dictionary, Dict],
                 max_seq_len: int = 40,
                 ignore_chars: Sequence[str] = ['padding'],
                 **kwargs) -> None:
        # Reject unsupported types up front, then build from config if needed.
        if not isinstance(dictionary, (dict, Dictionary)):
            raise TypeError(
                'The type of dictionary should be `Dictionary` or dict, '
                f'but got {type(dictionary)}')
        if isinstance(dictionary, dict):
            dictionary = TASK_UTILS.build(dictionary)
        self.dictionary = dictionary
        self.max_seq_len = max_seq_len

        # Reserved keywords map to the dictionary's special-token indexes.
        keyword_indexes = dict(
            padding=self.dictionary.padding_idx,
            end=self.dictionary.end_idx,
            unknown=self.dictionary.unknown_idx)
        if not mmcv.is_list_of(ignore_chars, str):
            raise TypeError('ignore_chars must be list of str')

        resolved_indexes = []
        for ignore_char in ignore_chars:
            # TODO add char2id in Dictionary
            char_index = self.dictionary._char2idx.get(ignore_char)
            # A reserved keyword takes precedence over a plain character.
            resolved = keyword_indexes.get(ignore_char, char_index)
            if resolved is None:
                raise ValueError(f'{ignore_char} is not exist in dictionary')
            resolved_indexes.append(resolved)
        self.ignore_indexes = resolved_indexes

    def get_single_prediction(
        self,
        output: torch.Tensor,
        data_sample: Optional[TextRecogDataSample] = None,
    ) -> Tuple[Sequence[int], Sequence[float]]:
        """Convert the output of a single image to index and score.

        Subclasses must override this method; the base implementation only
        raises ``NotImplementedError``.

        Args:
            output (torch.Tensor): Single image output.
            data_sample (TextRecogDataSample, optional): Datasample of an
                image. Defaults to None.

        Returns:
            tuple(list[int], list[float]): index and score.
        """
        raise NotImplementedError

    def __call__(
        self, outputs: torch.Tensor,
        data_samples: Sequence[TextRecogDataSample]
    ) -> Sequence[TextRecogDataSample]:
        """Convert outputs to strings and scores.

        Each image is decoded with ``get_single_prediction``; the indexes
        are turned into text via ``self.dictionary.idx2str`` and the result
        is written onto the matching data sample in place.

        Args:
            outputs (torch.Tensor): The model outputs in size: N * T * C
            data_samples (list[TextRecogDataSample]): The list of
                TextRecogDataSample.

        Returns:
            list(TextRecogDataSample): The same ``data_samples`` sequence,
            with ``pred_text.score`` and ``pred_text.item`` set on each of
            the first N samples.
        """
        num_images = outputs.size(0)
        for idx in range(num_images):
            single_output = outputs[idx, :, :]
            sample = data_samples[idx]
            char_indexes, char_scores = self.get_single_prediction(
                single_output, sample)
            decoded_text = self.dictionary.idx2str(char_indexes)
            sample.pred_text.score = char_scores
            sample.pred_text.item = decoded_text
        return data_samples

View File

@ -0,0 +1,53 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Sequence, Tuple
import torch
import torch.nn.functional as F
from mmocr.core.data_structures import TextRecogDataSample
from mmocr.registry import MODELS
from .base_textrecog_postprocessor import BaseTextRecogPostprocessor
# TODO support beam search
@MODELS.register_module()
class CTCPostProcessor(BaseTextRecogPostprocessor):
    """PostProcessor for CTC."""

    def get_single_prediction(self, output: torch.Tensor,
                              data_sample: TextRecogDataSample
                              ) -> Tuple[Sequence[int], Sequence[float]]:
        """Decode the output of a single image into indexes and scores.

        The best-scoring entry along the last dimension is taken at every
        step. A step is kept only if its index differs from the previous
        step's index and is not listed in ``self.ignore_indexes``.

        Args:
            output (torch.Tensor): Single image output.
            data_sample (TextRecogDataSample): Datasample of an image. Its
                ``valid_ratio`` (1 when absent) limits how many steps are
                decoded.

        Returns:
            tuple(list[int], list[float]): index and score.
        """
        num_steps = output.size(0)
        best_scores, best_indexes = torch.max(output, -1)
        ratio = data_sample.get('valid_ratio', 1)
        # Never decode more steps than the output actually has.
        steps_to_decode = min(num_steps, math.ceil(num_steps * ratio))

        skipped = tuple(self.ignore_indexes)
        kept_indexes, kept_scores = [], []
        # Seeding with the padding index drops a leading padding step.
        last_seen = self.dictionary.padding_idx
        for step in range(steps_to_decode):
            current = best_indexes[step].item()
            is_repeat = current == last_seen
            if not is_repeat and current not in skipped:
                kept_indexes.append(current)
                kept_scores.append(best_scores[step].item())
            # Updated even for skipped steps, so an ignored index between
            # two equal indexes keeps them from being merged.
            last_seen = current
        return kept_indexes, kept_scores

    def __call__(
        self, outputs: torch.Tensor,
        data_samples: Sequence[TextRecogDataSample]
    ) -> Sequence[TextRecogDataSample]:
        """Apply softmax over dim 2, then run the base-class conversion.

        Args:
            outputs (torch.Tensor): The model outputs; softmax is applied
                along dimension 2.
            data_samples (list[TextRecogDataSample]): The list of
                TextRecogDataSample.

        Returns:
            list(TextRecogDataSample): The value returned by the parent
            ``__call__``.
        """
        # TODO move to decoder
        probabilities = F.softmax(outputs, dim=2).cpu().detach()
        return super().__call__(probabilities, data_samples)

View File

@ -0,0 +1,51 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from unittest import TestCase
import torch
from mmengine.data import LabelData
from mmocr.core.data_structures import TextRecogDataSample
from mmocr.models.textrecog.dictionary import Dictionary
from mmocr.models.textrecog.postprocessor.attn_postprocessor import \
AttentionPostprocessor
class TestAttentionPostprocessor(TestCase):
    """Unit tests for ``AttentionPostprocessor``."""

    def _create_dummy_dict_file(
        self, dict_file,
        chars=list('0123456789abcdefghijklmnopqrstuvwxyz')):  # NOQA
        """Write ``chars`` to ``dict_file``, one character per line."""
        with open(dict_file, 'w') as f:
            for char in chars:
                f.write(char + '\n')

    def test_call(self):
        """Calling the postprocessor fills ``pred_text.item``."""
        tmp_dir = tempfile.TemporaryDirectory()
        # Register cleanup immediately so the directory is removed even when
        # an assertion below fails (it was previously never cleaned up).
        self.addCleanup(tmp_dir.cleanup)
        dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
        self._create_dummy_dict_file(dict_file)
        dict_gen = Dictionary(
            dict_file=dict_file,
            with_start=True,
            with_end=True,
            same_start_end=True,
            with_padding=True,
            with_unknown=False)
        pred_text = LabelData(valid_ratio=1.0)
        data_samples = [TextRecogDataSample(pred_text=pred_text)]
        postprocessor = AttentionPostprocessor(
            max_seq_len=None, dictionary=dict_gen, ignore_chars=['0'])
        # Override the end index after construction; the postprocessor reads
        # it from the shared dictionary object at decode time.
        dict_gen.end_idx = 3
        # test decode output to index
        dummy_output = torch.Tensor([[[1, 100, 3, 4, 5, 6, 7, 8],
                                      [100, 2, 3, 4, 5, 6, 7, 8],
                                      [1, 2, 100, 4, 5, 6, 7, 8],
                                      [1, 2, 100, 4, 5, 6, 7, 8],
                                      [100, 2, 3, 4, 5, 6, 7, 8],
                                      [1, 2, 3, 100, 5, 6, 7, 8],
                                      [100, 2, 3, 4, 5, 6, 7, 8],
                                      [1, 2, 3, 100, 5, 6, 7, 8]]])
        data_samples = postprocessor(dummy_output, data_samples)
        self.assertEqual(data_samples[0].pred_text.item, '122')

View File

@ -0,0 +1,96 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from unittest import TestCase, mock
import torch
from mmengine.data import LabelData
from mmocr.core.data_structures import TextRecogDataSample
from mmocr.models.textrecog.dictionary import Dictionary
from mmocr.models.textrecog.postprocessor.base_textrecog_postprocessor import \
BaseTextRecogPostprocessor
class TestBaseTextRecogPostprocessor(TestCase):
    """Unit tests for ``BaseTextRecogPostprocessor``."""

    def _create_dummy_dict_file(
        self, dict_file,
        chars=list('0123456789abcdefghijklmnopqrstuvwxyz')):  # NOQA
        """Write ``chars`` to ``dict_file``, one character per line."""
        with open(dict_file, 'w') as f:
            for char in chars:
                f.write(char + '\n')

    def test_init(self):
        """Constructor builds the dictionary and validates its arguments."""
        tmp_dir = tempfile.TemporaryDirectory()
        # Register cleanup immediately so the directory is removed even when
        # an assertion below fails.
        self.addCleanup(tmp_dir.cleanup)
        dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
        self._create_dummy_dict_file(dict_file)
        # test diction cfg
        dict_cfg = dict(
            type='Dictionary',
            dict_file=dict_file,
            with_start=True,
            with_end=True,
            same_start_end=False,
            with_padding=True,
            with_unknown=True)
        base_postprocessor = BaseTextRecogPostprocessor(dict_cfg)
        self.assertIsInstance(base_postprocessor.dictionary, Dictionary)
        self.assertListEqual(base_postprocessor.ignore_indexes,
                             [base_postprocessor.dictionary.padding_idx])

        base_postprocessor = BaseTextRecogPostprocessor(
            dict_cfg, ignore_chars=['1', '2', '3'])

        self.assertListEqual(base_postprocessor.ignore_indexes, [1, 2, 3])

        # test ignore_chars
        with self.assertRaisesRegex(TypeError,
                                    'ignore_chars must be list of str'):
            base_postprocessor = BaseTextRecogPostprocessor(
                dict_cfg, ignore_chars=[1, 2, 3])
        with self.assertRaisesRegex(ValueError,
                                    'M is not exist in dictionary'):
            base_postprocessor = BaseTextRecogPostprocessor(
                dict_cfg, ignore_chars=['M'])

        base_postprocessor = BaseTextRecogPostprocessor(
            dict_cfg, ignore_chars=['1', '2', '3'])
        # test dictionary is invalid type
        dict_cfg = ['tmp']
        with self.assertRaisesRegex(
                TypeError, ('The type of dictionary should be `Dictionary`'
                            ' or dict, '
                            f'but got {type(dict_cfg)}')):
            base_postprocessor = BaseTextRecogPostprocessor(dict_cfg)

    @mock.patch(f'{__name__}.BaseTextRecogPostprocessor.get_single_prediction')
    def test_call(self, mock_get_single_prediction):
        """``__call__`` turns the per-image indexes into ``pred_text.item``."""

        def mock_func(output, data_sample):
            # Fixed prediction standing in for the unimplemented base method.
            return [0, 1, 2], [0.8, 0.7, 0.9]

        tmp_dir = tempfile.TemporaryDirectory()
        # Register cleanup immediately so the directory is removed even when
        # an assertion below fails.
        self.addCleanup(tmp_dir.cleanup)
        dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
        self._create_dummy_dict_file(dict_file)
        dict_cfg = dict(
            type='Dictionary',
            dict_file=dict_file,
            with_start=True,
            with_end=True,
            same_start_end=False,
            with_padding=True,
            with_unknown=True)
        mock_get_single_prediction.side_effect = mock_func
        pred_text = LabelData(valid_ratio=1.0)
        data_samples = [TextRecogDataSample(pred_text=pred_text)]
        postprocessor = BaseTextRecogPostprocessor(
            max_seq_len=None, dictionary=dict_cfg)

        # test decode output to index
        dummy_output = torch.Tensor([[[1, 100, 3, 4, 5, 6, 7, 8]]])
        data_samples = postprocessor(dummy_output, data_samples)
        self.assertEqual(data_samples[0].pred_text.item, '012')

View File

@ -0,0 +1,86 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from unittest import TestCase
import torch
from mmengine.data import LabelData
from mmocr.core.data_structures import TextRecogDataSample
from mmocr.models.textrecog.dictionary import Dictionary
from mmocr.models.textrecog.postprocessor.ctc_postprocessor import \
CTCPostProcessor
class TestCTCPostProcessor(TestCase):
    """Unit tests for ``CTCPostProcessor``."""

    def _create_dummy_dict_file(
        self, dict_file,
        chars=list('0123456789abcdefghijklmnopqrstuvwxyz')):  # NOQA
        """Write ``chars`` to ``dict_file``, one character per line."""
        with open(dict_file, 'w') as f:
            for char in chars:
                f.write(char + '\n')

    def test_get_single_prediction(self):
        """Per-image decoding, with and without an extra ignored character."""
        tmp_dir = tempfile.TemporaryDirectory()
        # Register cleanup immediately so the directory is removed even when
        # an assertion below fails.
        self.addCleanup(tmp_dir.cleanup)
        dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
        self._create_dummy_dict_file(dict_file)
        dict_gen = Dictionary(
            dict_file=dict_file,
            with_start=False,
            with_end=False,
            with_padding=True,
            with_unknown=False)
        pred_text = LabelData(valid_ratio=1.0)
        data_samples = [TextRecogDataSample(pred_text=pred_text)]
        postprocessor = CTCPostProcessor(max_seq_len=None, dictionary=dict_gen)

        # test decode output to index
        dummy_output = torch.Tensor([[[1, 100, 3, 4, 5, 6, 7, 8],
                                      [100, 2, 3, 4, 5, 6, 7, 8],
                                      [1, 2, 100, 4, 5, 6, 7, 8],
                                      [1, 2, 100, 4, 5, 6, 7, 8],
                                      [100, 2, 3, 4, 5, 6, 7, 8],
                                      [1, 2, 3, 100, 5, 6, 7, 8],
                                      [100, 2, 3, 4, 5, 6, 7, 8],
                                      [1, 2, 3, 100, 5, 6, 7, 8]]])
        # get_single_prediction is called directly, so no softmax is applied
        # and the scores are the raw maxima.
        index, score = postprocessor.get_single_prediction(
            dummy_output[0], data_samples[0])
        self.assertListEqual(index, [1, 0, 2, 0, 3, 0, 3])
        self.assertListEqual(score,
                             [100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0])

        postprocessor = CTCPostProcessor(
            max_seq_len=None, dictionary=dict_gen, ignore_chars=['0'])
        index, score = postprocessor.get_single_prediction(
            dummy_output[0], data_samples[0])
        self.assertListEqual(index, [1, 2, 3, 3])
        self.assertListEqual(score, [100.0, 100.0, 100.0, 100.0])

    def test_call(self):
        """Calling the postprocessor fills ``pred_text.item``."""
        tmp_dir = tempfile.TemporaryDirectory()
        # Register cleanup immediately; this directory was previously never
        # cleaned up.
        self.addCleanup(tmp_dir.cleanup)
        dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
        self._create_dummy_dict_file(dict_file)
        dict_gen = Dictionary(
            dict_file=dict_file,
            with_start=False,
            with_end=False,
            with_padding=True,
            with_unknown=False)
        pred_text = LabelData(valid_ratio=1.0)
        data_samples = [TextRecogDataSample(pred_text=pred_text)]
        postprocessor = CTCPostProcessor(max_seq_len=None, dictionary=dict_gen)

        # test decode output to index
        dummy_output = torch.Tensor([[[1, 100, 3, 4, 5, 6, 7, 8],
                                      [100, 2, 3, 4, 5, 6, 7, 8],
                                      [1, 2, 100, 4, 5, 6, 7, 8],
                                      [1, 2, 100, 4, 5, 6, 7, 8],
                                      [100, 2, 3, 4, 5, 6, 7, 8],
                                      [1, 2, 3, 100, 5, 6, 7, 8],
                                      [100, 2, 3, 4, 5, 6, 7, 8],
                                      [1, 2, 3, 100, 5, 6, 7, 8]]])
        data_samples = postprocessor(dummy_output, data_samples)
        self.assertEqual(data_samples[0].pred_text.item, '1020303')