mirror of https://github.com/open-mmlab/mmocr.git
[Feature] TextRecogPostprocessor
parent
e8f57d6540
commit
a05e3f19c5
|
@ -1,6 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from . import (backbones, convertors, decoders, dictionary, encoders, fusers,
|
||||
heads, losses, necks, plugins, preprocessor, recognizer)
|
||||
heads, losses, necks, plugins, postprocessor, preprocessor,
|
||||
recognizer)
|
||||
from .backbones import * # NOQA
|
||||
from .convertors import * # NOQA
|
||||
from .decoders import * # NOQA
|
||||
|
@ -11,6 +12,7 @@ from .heads import * # NOQA
|
|||
from .losses import * # NOQA
|
||||
from .necks import * # NOQA
|
||||
from .plugins import * # NOQA
|
||||
from .postprocessor import * # NOQA
|
||||
from .preprocessor import * # NOQA
|
||||
from .recognizer import * # NOQA
|
||||
|
||||
|
@ -18,4 +20,4 @@ __all__ = (
|
|||
backbones.__all__ + convertors.__all__ + decoders.__all__ +
|
||||
encoders.__all__ + heads.__all__ + losses.__all__ + necks.__all__ +
|
||||
preprocessor.__all__ + recognizer.__all__ + fusers.__all__ +
|
||||
plugins.__all__ + dictionary.__all__)
|
||||
plugins.__all__ + dictionary.__all__ + postprocessor.__all__)
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .attn_postprocessor import AttentionPostprocessor
|
||||
from .base_textrecog_postprocessor import BaseTextRecogPostprocessor
|
||||
from .ctc_postprocessor import CTCPostProcessor
|
||||
|
||||
# Public names re-exported by the postprocessor package.
__all__ = [
    'BaseTextRecogPostprocessor',
    'AttentionPostprocessor',
    'CTCPostProcessor',
]
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from mmocr.core.data_structures import TextRecogDataSample
|
||||
from mmocr.registry import MODELS
|
||||
from .base_textrecog_postprocessor import BaseTextRecogPostprocessor
|
||||
|
||||
|
||||
@MODELS.register_module()
class AttentionPostprocessor(BaseTextRecogPostprocessor):
    """PostProcessor for seq2seq."""

    def get_single_prediction(
        self,
        output: torch.Tensor,
        data_sample: Optional[TextRecogDataSample] = None,
    ) -> Tuple[Sequence[int], Sequence[float]]:
        """Greedily decode the output of a single image.

        The highest-scoring index is taken along the last dimension of
        ``output``. Indexes listed in ``self.ignore_indexes`` are skipped, and
        decoding stops at the first non-ignored index that equals
        ``self.dictionary.end_idx``.

        Args:
            output (torch.Tensor): Single image output.
            data_sample (TextRecogDataSample, optional): Datasample of an
                image. It is not read by this method. Defaults to None.

        Returns:
            tuple(list[int], list[float]): index and score.
        """
        best_scores, best_indexes = torch.max(output, -1)
        decoded_steps = zip(
            best_indexes.cpu().detach().numpy().tolist(),
            best_scores.cpu().detach().numpy().tolist(),
        )

        kept_indexes, kept_scores = [], []
        for char_index, char_score in decoded_steps:
            if char_index not in self.ignore_indexes:
                # The end token is only honoured when it is not ignored.
                if char_index == self.dictionary.end_idx:
                    break
                kept_indexes.append(char_index)
                kept_scores.append(char_score)
        return kept_indexes, kept_scores
|
|
@ -0,0 +1,100 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, Optional, Sequence, Tuple, Union
|
||||
|
||||
import mmcv
|
||||
import torch
|
||||
|
||||
from mmocr.core.data_structures import TextRecogDataSample
|
||||
from mmocr.models.textrecog.dictionary import Dictionary
|
||||
from mmocr.registry import TASK_UTILS
|
||||
|
||||
|
||||
class BaseTextRecogPostprocessor:
    """Base text recognition postprocessor.

    Args:
        dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
            the instance of `Dictionary`.
        max_seq_len (int): Maximum sequence length. The sequence is usually
            generated from decoder. Defaults to 40.
        ignore_chars (list[str], optional): A list of characters to be
            ignored from the final results. Postprocessor will skip over
            these characters when converting raw indexes to characters.
            Apart from single characters, each item can be one of the
            following reserved keywords: 'padding', 'end' and 'unknown',
            which refer to their corresponding special tokens in the
            dictionary. Defaults to None, which is treated as
            ``['padding']``.

    Raises:
        TypeError: If ``dictionary`` is neither a dict nor a `Dictionary`, or
            if ``ignore_chars`` is not a list of str.
        ValueError: If an item of ``ignore_chars`` cannot be resolved to an
            index in the dictionary.
    """

    def __init__(self,
                 dictionary: Union[Dictionary, Dict],
                 max_seq_len: int = 40,
                 ignore_chars: Optional[Sequence[str]] = None,
                 **kwargs) -> None:

        if isinstance(dictionary, dict):
            self.dictionary = TASK_UTILS.build(dictionary)
        elif isinstance(dictionary, Dictionary):
            self.dictionary = dictionary
        else:
            raise TypeError(
                'The type of dictionary should be `Dictionary` or dict, '
                f'but got {type(dictionary)}')
        self.max_seq_len = max_seq_len

        # Reserved keywords that map to the dictionary's special tokens.
        mapping_table = {
            'padding': self.dictionary.padding_idx,
            'end': self.dictionary.end_idx,
            'unknown': self.dictionary.unknown_idx,
        }
        # Build a fresh list per instance instead of sharing a mutable
        # default object across calls.
        if ignore_chars is None:
            ignore_chars = ['padding']
        if not mmcv.is_list_of(ignore_chars, str):
            raise TypeError('ignore_chars must be list of str')
        ignore_indexes = list()
        for ignore_char in ignore_chars:
            # TODO add char2id in Dictionary
            index = mapping_table.get(
                ignore_char, self.dictionary._char2idx.get(ignore_char, None))
            if index is None:
                raise ValueError(f'{ignore_char} is not exist in dictionary')
            ignore_indexes.append(index)
        self.ignore_indexes = ignore_indexes

    def get_single_prediction(
        self,
        output: torch.Tensor,
        data_sample: Optional[TextRecogDataSample] = None,
    ) -> Tuple[Sequence[int], Sequence[float]]:
        """Convert the output of a single image to index and score.

        Subclasses must override this method.

        Args:
            output (torch.Tensor): Single image output.
            data_sample (TextRecogDataSample, optional): Datasample of an
                image. Defaults to None.

        Returns:
            tuple(list[int], list[float]): index and score.
        """
        raise NotImplementedError

    def __call__(
        self, outputs: torch.Tensor,
        data_samples: Sequence[TextRecogDataSample]
    ) -> Sequence[TextRecogDataSample]:
        """Convert outputs to strings and scores.

        Each entry of ``data_samples`` is updated in place: its
        ``pred_text.score`` and ``pred_text.item`` are overwritten.

        Args:
            outputs (torch.Tensor): The model outputs in size: N * T * C
            data_samples (list[TextRecogDataSample]): The list of
                TextRecogDataSample.

        Returns:
            list(TextRecogDataSample): The list of TextRecogDataSample. It
            usually contain ``pred_text`` information.
        """
        batch_size = outputs.size(0)

        for idx in range(batch_size):
            index, score = self.get_single_prediction(outputs[idx, :, :],
                                                      data_samples[idx])
            text = self.dictionary.idx2str(index)
            data_samples[idx].pred_text.score = score
            data_samples[idx].pred_text.item = text
        return data_samples
|
|
@ -0,0 +1,53 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
from typing import Sequence, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from mmocr.core.data_structures import TextRecogDataSample
|
||||
from mmocr.registry import MODELS
|
||||
from .base_textrecog_postprocessor import BaseTextRecogPostprocessor
|
||||
|
||||
|
||||
# TODO support beam search
|
||||
@MODELS.register_module()
class CTCPostProcessor(BaseTextRecogPostprocessor):
    """PostProcessor for CTC."""

    def get_single_prediction(self, output: torch.Tensor,
                              data_sample: TextRecogDataSample
                              ) -> Tuple[Sequence[int], Sequence[float]]:
        """Greedily decode the output of a single image.

        Only the first ``ceil(T * valid_ratio)`` steps (capped at ``T``, the
        size of the first dimension of ``output``) are decoded. A step is
        kept when its best index differs from the previous step's best index
        and is not listed in ``self.ignore_indexes``.

        Args:
            output (torch.Tensor): Single image output.
            data_sample (TextRecogDataSample): Datasample of an image. Its
                ``valid_ratio`` entry is used when present, otherwise 1.

        Returns:
            tuple(list[int], list[float]): index and score.
        """
        num_steps = output.size(0)
        best_scores, best_indexes = torch.max(output, -1)
        valid_ratio = data_sample.get('valid_ratio', 1)
        decode_len = min(num_steps, math.ceil(num_steps * valid_ratio))

        index, score = [], []
        # Start from the padding index so that it never opens the sequence.
        last_seen = self.dictionary.padding_idx
        for step in range(decode_len):
            current = best_indexes[step].item()
            repeated = current == last_seen
            if not repeated and current not in self.ignore_indexes:
                index.append(current)
                score.append(best_scores[step].item())
            # Track every step, including ignored ones, for the repeat check.
            last_seen = current
        return index, score

    def __call__(
        self, outputs: torch.Tensor,
        data_samples: Sequence[TextRecogDataSample]
    ) -> Sequence[TextRecogDataSample]:
        """Apply softmax over dim 2 and delegate to the base postprocessor.

        Args:
            outputs (torch.Tensor): The raw model outputs.
            data_samples (list[TextRecogDataSample]): The list of
                TextRecogDataSample.

        Returns:
            list(TextRecogDataSample): The data samples returned by the base
            class ``__call__``.
        """
        # TODO move to decoder
        probs = F.softmax(outputs, dim=2).cpu().detach()
        return super().__call__(probs, data_samples)
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
import os.path as osp
|
||||
import tempfile
|
||||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
from mmengine.data import LabelData
|
||||
|
||||
from mmocr.core.data_structures import TextRecogDataSample
|
||||
from mmocr.models.textrecog.dictionary import Dictionary
|
||||
from mmocr.models.textrecog.postprocessor.attn_postprocessor import \
|
||||
AttentionPostprocessor
|
||||
|
||||
|
||||
class TestAttentionPostprocessor(TestCase):
    """Unit tests for ``AttentionPostprocessor``."""

    def _create_dummy_dict_file(
            self, dict_file, chars='0123456789abcdefghijklmnopqrstuvwxyz'):
        """Write ``chars`` to ``dict_file``, one character per line."""
        with open(dict_file, 'w') as f:
            for char in chars:
                f.write(char + '\n')

    def test_call(self):
        # The context manager removes the temporary directory even when an
        # assertion below fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            dict_file = osp.join(tmp_dir, 'fake_chars.txt')
            self._create_dummy_dict_file(dict_file)
            dict_gen = Dictionary(
                dict_file=dict_file,
                with_start=True,
                with_end=True,
                same_start_end=True,
                with_padding=True,
                with_unknown=False)
            pred_text = LabelData(valid_ratio=1.0)
            data_samples = [TextRecogDataSample(pred_text=pred_text)]
            postprocessor = AttentionPostprocessor(
                max_seq_len=None, dictionary=dict_gen, ignore_chars=['0'])
            # Force the end token onto index 3 so that the dummy output below
            # terminates decoding at its sixth step.
            dict_gen.end_idx = 3
            # test decode output to index
            dummy_output = torch.Tensor([[[1, 100, 3, 4, 5, 6, 7, 8],
                                          [100, 2, 3, 4, 5, 6, 7, 8],
                                          [1, 2, 100, 4, 5, 6, 7, 8],
                                          [1, 2, 100, 4, 5, 6, 7, 8],
                                          [100, 2, 3, 4, 5, 6, 7, 8],
                                          [1, 2, 3, 100, 5, 6, 7, 8],
                                          [100, 2, 3, 4, 5, 6, 7, 8],
                                          [1, 2, 3, 100, 5, 6, 7, 8]]])
            data_samples = postprocessor(dummy_output, data_samples)
            self.assertEqual(data_samples[0].pred_text.item, '122')
|
|
@ -0,0 +1,96 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
import tempfile
|
||||
from unittest import TestCase, mock
|
||||
|
||||
import torch
|
||||
from mmengine.data import LabelData
|
||||
|
||||
from mmocr.core.data_structures import TextRecogDataSample
|
||||
from mmocr.models.textrecog.dictionary import Dictionary
|
||||
from mmocr.models.textrecog.postprocessor.base_textrecog_postprocessor import \
|
||||
BaseTextRecogPostprocessor
|
||||
|
||||
|
||||
class TestBaseTextRecogPostprocessor(TestCase):
    """Unit tests for ``BaseTextRecogPostprocessor``."""

    def _create_dummy_dict_file(
            self, dict_file, chars='0123456789abcdefghijklmnopqrstuvwxyz'):
        """Write ``chars`` to ``dict_file``, one character per line."""
        with open(dict_file, 'w') as f:
            for char in chars:
                f.write(char + '\n')

    def test_init(self):
        # The context manager removes the temporary directory even when an
        # assertion below fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            dict_file = osp.join(tmp_dir, 'fake_chars.txt')
            self._create_dummy_dict_file(dict_file)
            # test dictionary cfg
            dict_cfg = dict(
                type='Dictionary',
                dict_file=dict_file,
                with_start=True,
                with_end=True,
                same_start_end=False,
                with_padding=True,
                with_unknown=True)
            base_postprocessor = BaseTextRecogPostprocessor(dict_cfg)
            self.assertIsInstance(base_postprocessor.dictionary, Dictionary)
            self.assertListEqual(base_postprocessor.ignore_indexes,
                                 [base_postprocessor.dictionary.padding_idx])

            base_postprocessor = BaseTextRecogPostprocessor(
                dict_cfg, ignore_chars=['1', '2', '3'])

            self.assertListEqual(base_postprocessor.ignore_indexes, [1, 2, 3])

            # test ignore_chars
            with self.assertRaisesRegex(TypeError,
                                        'ignore_chars must be list of str'):
                BaseTextRecogPostprocessor(dict_cfg, ignore_chars=[1, 2, 3])
            with self.assertRaisesRegex(ValueError,
                                        'M is not exist in dictionary'):
                BaseTextRecogPostprocessor(dict_cfg, ignore_chars=['M'])

            # test dictionary is invalid type
            dict_cfg = ['tmp']
            with self.assertRaisesRegex(
                    TypeError, ('The type of dictionary should be `Dictionary`'
                                ' or dict, '
                                f'but got {type(dict_cfg)}')):
                BaseTextRecogPostprocessor(dict_cfg)

    @mock.patch(f'{__name__}.BaseTextRecogPostprocessor.get_single_prediction')
    def test_call(self, mock_get_single_prediction):

        def mock_func(output, data_sample):
            # Fixed prediction so that only ``__call__`` itself is exercised.
            return [0, 1, 2], [0.8, 0.7, 0.9]

        # The context manager removes the temporary directory even when an
        # assertion below fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            dict_file = osp.join(tmp_dir, 'fake_chars.txt')
            self._create_dummy_dict_file(dict_file)
            dict_cfg = dict(
                type='Dictionary',
                dict_file=dict_file,
                with_start=True,
                with_end=True,
                same_start_end=False,
                with_padding=True,
                with_unknown=True)
            mock_get_single_prediction.side_effect = mock_func
            pred_text = LabelData(valid_ratio=1.0)
            data_samples = [TextRecogDataSample(pred_text=pred_text)]
            postprocessor = BaseTextRecogPostprocessor(
                max_seq_len=None, dictionary=dict_cfg)

            # test decode output to index
            dummy_output = torch.Tensor([[[1, 100, 3, 4, 5, 6, 7, 8]]])
            data_samples = postprocessor(dummy_output, data_samples)
            self.assertEqual(data_samples[0].pred_text.item, '012')
|
|
@ -0,0 +1,86 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
import os.path as osp
|
||||
import tempfile
|
||||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
from mmengine.data import LabelData
|
||||
|
||||
from mmocr.core.data_structures import TextRecogDataSample
|
||||
from mmocr.models.textrecog.dictionary import Dictionary
|
||||
from mmocr.models.textrecog.postprocessor.ctc_postprocessor import \
|
||||
CTCPostProcessor
|
||||
|
||||
|
||||
class TestCTCPostProcessor(TestCase):
    """Unit tests for ``CTCPostProcessor``."""

    def _create_dummy_dict_file(
            self, dict_file, chars='0123456789abcdefghijklmnopqrstuvwxyz'):
        """Write ``chars`` to ``dict_file``, one character per line."""
        with open(dict_file, 'w') as f:
            for char in chars:
                f.write(char + '\n')

    def _create_dictionary(self, tmp_dir):
        """Build the padding-only ``Dictionary`` shared by both tests."""
        dict_file = osp.join(tmp_dir, 'fake_chars.txt')
        self._create_dummy_dict_file(dict_file)
        return Dictionary(
            dict_file=dict_file,
            with_start=False,
            with_end=False,
            with_padding=True,
            with_unknown=False)

    @staticmethod
    def _create_dummy_output():
        """Return a 1 x 8 x 8 tensor whose per-step argmax is 1,0,2,2,0,3,0,3."""
        return torch.Tensor([[[1, 100, 3, 4, 5, 6, 7, 8],
                              [100, 2, 3, 4, 5, 6, 7, 8],
                              [1, 2, 100, 4, 5, 6, 7, 8],
                              [1, 2, 100, 4, 5, 6, 7, 8],
                              [100, 2, 3, 4, 5, 6, 7, 8],
                              [1, 2, 3, 100, 5, 6, 7, 8],
                              [100, 2, 3, 4, 5, 6, 7, 8],
                              [1, 2, 3, 100, 5, 6, 7, 8]]])

    def test_get_single_prediction(self):
        # The context manager removes the temporary directory even when an
        # assertion below fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            dict_gen = self._create_dictionary(tmp_dir)
            pred_text = LabelData(valid_ratio=1.0)
            data_samples = [TextRecogDataSample(pred_text=pred_text)]
            postprocessor = CTCPostProcessor(
                max_seq_len=None, dictionary=dict_gen)

            # test decode output to index
            dummy_output = self._create_dummy_output()
            index, score = postprocessor.get_single_prediction(
                dummy_output[0], data_samples[0])
            self.assertListEqual(index, [1, 0, 2, 0, 3, 0, 3])
            self.assertListEqual(
                score, [100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0])

            # test that ignored characters are dropped from the result
            postprocessor = CTCPostProcessor(
                max_seq_len=None, dictionary=dict_gen, ignore_chars=['0'])
            index, score = postprocessor.get_single_prediction(
                dummy_output[0], data_samples[0])
            self.assertListEqual(index, [1, 2, 3, 3])
            self.assertListEqual(score, [100.0, 100.0, 100.0, 100.0])

    def test_call(self):
        # The context manager removes the temporary directory even when an
        # assertion below fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            dict_gen = self._create_dictionary(tmp_dir)
            pred_text = LabelData(valid_ratio=1.0)
            data_samples = [TextRecogDataSample(pred_text=pred_text)]
            postprocessor = CTCPostProcessor(
                max_seq_len=None, dictionary=dict_gen)

            # test decode output to index
            dummy_output = self._create_dummy_output()
            data_samples = postprocessor(dummy_output, data_samples)
            self.assertEqual(data_samples[0].pred_text.item, '1020303')
|
Loading…
Reference in New Issue