diff --git a/mmocr/models/textrecog/__init__.py b/mmocr/models/textrecog/__init__.py
index 5a84aa5c..cd96c031 100644
--- a/mmocr/models/textrecog/__init__.py
+++ b/mmocr/models/textrecog/__init__.py
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from . import (backbones, convertors, decoders, dictionary, encoders, fusers,
-               heads, losses, necks, plugins, preprocessor, recognizer)
+               heads, losses, necks, plugins, postprocessor, preprocessor,
+               recognizer)
 from .backbones import *  # NOQA
 from .convertors import *  # NOQA
 from .decoders import *  # NOQA
@@ -11,6 +12,7 @@ from .heads import *  # NOQA
 from .losses import *  # NOQA
 from .necks import *  # NOQA
 from .plugins import *  # NOQA
+from .postprocessor import *  # NOQA
 from .preprocessor import *  # NOQA
 from .recognizer import *  # NOQA
 
@@ -18,4 +20,4 @@ __all__ = (
     backbones.__all__ + convertors.__all__ + decoders.__all__ +
     encoders.__all__ + heads.__all__ + losses.__all__ + necks.__all__ +
     preprocessor.__all__ + recognizer.__all__ + fusers.__all__ +
-    plugins.__all__ + dictionary.__all__)
+    plugins.__all__ + dictionary.__all__ + postprocessor.__all__)
diff --git a/mmocr/models/textrecog/postprocessor/__init__.py b/mmocr/models/textrecog/postprocessor/__init__.py
new file mode 100644
index 00000000..8b725ab3
--- /dev/null
+++ b/mmocr/models/textrecog/postprocessor/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .attn_postprocessor import AttentionPostprocessor
+from .base_textrecog_postprocessor import BaseTextRecogPostprocessor
+from .ctc_postprocessor import CTCPostProcessor
+
+__all__ = [
+    'BaseTextRecogPostprocessor', 'AttentionPostprocessor', 'CTCPostProcessor'
+]
diff --git a/mmocr/models/textrecog/postprocessor/attn_postprocessor.py b/mmocr/models/textrecog/postprocessor/attn_postprocessor.py
new file mode 100644
index 00000000..c5d96b5c
--- /dev/null
+++ b/mmocr/models/textrecog/postprocessor/attn_postprocessor.py
@@ -0,0 +1,41 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Sequence, Tuple
+
+import torch
+
+from mmocr.core.data_structures import TextRecogDataSample
+from mmocr.registry import MODELS
+from .base_textrecog_postprocessor import BaseTextRecogPostprocessor
+
+
+@MODELS.register_module()
+class AttentionPostprocessor(BaseTextRecogPostprocessor):
+    """PostProcessor for seq2seq."""
+
+    def get_single_prediction(
+            self,
+            output: torch.Tensor,
+            data_sample: Optional[TextRecogDataSample] = None,
+    ) -> Tuple[Sequence[int], Sequence[float]]:
+        """Convert the output of a single image to index and score.
+
+        Args:
+            output (torch.Tensor): Single image output.
+            data_sample (TextRecogDataSample, optional): Datasample of an
+                image. Defaults to None.
+
+        Returns:
+            tuple(list[int], list[float]): index and score.
+        """
+        max_value, max_idx = torch.max(output, -1)
+        index, score = [], []
+        output_index = max_idx.cpu().detach().numpy().tolist()
+        output_score = max_value.cpu().detach().numpy().tolist()
+        for char_index, char_score in zip(output_index, output_score):
+            if char_index in self.ignore_indexes:
+                continue
+            if char_index == self.dictionary.end_idx:
+                break
+            index.append(char_index)
+            score.append(char_score)
+        return index, score
diff --git a/mmocr/models/textrecog/postprocessor/base_textrecog_postprocessor.py b/mmocr/models/textrecog/postprocessor/base_textrecog_postprocessor.py
new file mode 100644
index 00000000..311812e3
--- /dev/null
+++ b/mmocr/models/textrecog/postprocessor/base_textrecog_postprocessor.py
@@ -0,0 +1,100 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, Optional, Sequence, Tuple, Union
+
+import mmcv
+import torch
+
+from mmocr.core.data_structures import TextRecogDataSample
+from mmocr.models.textrecog.dictionary import Dictionary
+from mmocr.registry import TASK_UTILS
+
+
+class BaseTextRecogPostprocessor:
+    """Base text recognition postprocessor.
+
+    Args:
+        dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
+            the instance of `Dictionary`.
+        max_seq_len (int): Maximum sequence length. The sequence is usually
+            generated by the decoder. Defaults to 40.
+        ignore_chars (list[str]): A list of characters to be ignored from the
+            final results. Postprocessor will skip over these characters when
+            converting raw indexes to characters. Apart from single characters,
+            each item can be one of the following reserved keywords: 'padding',
+            'end' and 'unknown', which refer to their corresponding special
+            tokens in the dictionary.
+    """
+
+    def __init__(self,
+                 dictionary: Union[Dictionary, Dict],
+                 max_seq_len: int = 40,
+                 ignore_chars: Sequence[str] = ['padding'],
+                 **kwargs) -> None:
+
+        if isinstance(dictionary, dict):
+            self.dictionary = TASK_UTILS.build(dictionary)
+        elif isinstance(dictionary, Dictionary):
+            self.dictionary = dictionary
+        else:
+            raise TypeError(
+                'The type of dictionary should be `Dictionary` or dict, '
+                f'but got {type(dictionary)}')
+        self.max_seq_len = max_seq_len
+
+        mapping_table = {
+            'padding': self.dictionary.padding_idx,
+            'end': self.dictionary.end_idx,
+            'unknown': self.dictionary.unknown_idx,
+        }
+        if not mmcv.is_list_of(ignore_chars, str):
+            raise TypeError('ignore_chars must be list of str')
+        ignore_indexes = list()
+        for ignore_char in ignore_chars:
+            # TODO add char2id in Dictionary
+            index = mapping_table.get(
+                ignore_char, self.dictionary._char2idx.get(ignore_char, None))
+            if index is None:
+                raise ValueError(f'{ignore_char} does not exist in dictionary')
+            ignore_indexes.append(index)
+        self.ignore_indexes = ignore_indexes
+
+    def get_single_prediction(
+            self,
+            output: torch.Tensor,
+            data_sample: Optional[TextRecogDataSample] = None,
+    ) -> Tuple[Sequence[int], Sequence[float]]:
+        """Convert the output of a single image to index and score.
+
+        Args:
+            output (torch.Tensor): Single image output.
+            data_sample (TextRecogDataSample): Datasample of an image.
+
+        Returns:
+            tuple(list[int], list[float]): index and score.
+        """
+        raise NotImplementedError
+
+    def __call__(
+            self, outputs: torch.Tensor,
+            data_samples: Sequence[TextRecogDataSample]
+    ) -> Sequence[TextRecogDataSample]:
+        """Convert outputs to strings and scores.
+
+        Args:
+            outputs (torch.Tensor): The model outputs in size: N * T * C
+            data_samples (list[TextRecogDataSample]): The list of
+                TextRecogDataSample.
+
+        Returns:
+            list(TextRecogDataSample): The list of TextRecogDataSample. It
+                usually contains ``pred_text`` information.
+        """
+        batch_size = outputs.size(0)
+
+        for idx in range(batch_size):
+            index, score = self.get_single_prediction(outputs[idx, :, :],
+                                                      data_samples[idx])
+            text = self.dictionary.idx2str(index)
+            data_samples[idx].pred_text.score = score
+            data_samples[idx].pred_text.item = text
+        return data_samples
diff --git a/mmocr/models/textrecog/postprocessor/ctc_postprocessor.py b/mmocr/models/textrecog/postprocessor/ctc_postprocessor.py
new file mode 100644
index 00000000..e68c8d79
--- /dev/null
+++ b/mmocr/models/textrecog/postprocessor/ctc_postprocessor.py
@@ -0,0 +1,53 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from typing import Sequence, Tuple
+
+import torch
+import torch.nn.functional as F
+
+from mmocr.core.data_structures import TextRecogDataSample
+from mmocr.registry import MODELS
+from .base_textrecog_postprocessor import BaseTextRecogPostprocessor
+
+
+# TODO support beam search
+@MODELS.register_module()
+class CTCPostProcessor(BaseTextRecogPostprocessor):
+    """PostProcessor for CTC."""
+
+    def get_single_prediction(self, output: torch.Tensor,
+                              data_sample: TextRecogDataSample
+                              ) -> Tuple[Sequence[int], Sequence[float]]:
+        """Convert the output of a single image to index and score.
+
+        Args:
+            output (torch.Tensor): Single image output.
+            data_sample (TextRecogDataSample): Datasample of an image.
+
+        Returns:
+            tuple(list[int], list[float]): index and score.
+        """
+        feat_len = output.size(0)
+        max_value, max_idx = torch.max(output, -1)
+        valid_ratio = data_sample.get('valid_ratio', 1)
+        decode_len = min(feat_len, math.ceil(feat_len * valid_ratio))
+        index = []
+        score = []
+
+        prev_idx = self.dictionary.padding_idx
+        for t in range(decode_len):
+            tmp_value = max_idx[t].item()
+            if tmp_value not in (prev_idx, *self.ignore_indexes):
+                index.append(tmp_value)
+                score.append(max_value[t].item())
+            prev_idx = tmp_value
+        return index, score
+
+    def __call__(
+            self, outputs: torch.Tensor,
+            data_samples: Sequence[TextRecogDataSample]
+    ) -> Sequence[TextRecogDataSample]:
+        # TODO move to decoder
+        outputs = F.softmax(outputs, dim=2)
+        outputs = outputs.cpu().detach()
+        return super().__call__(outputs, data_samples)
diff --git a/tests/test_models/test_textrecog/test_postprocessor/test_attn_postprocessor.py b/tests/test_models/test_textrecog/test_postprocessor/test_attn_postprocessor.py
new file mode 100644
index 00000000..71d322f3
--- /dev/null
+++ b/tests/test_models/test_textrecog/test_postprocessor/test_attn_postprocessor.py
@@ -0,0 +1,51 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+import os.path as osp
+import tempfile
+from unittest import TestCase
+
+import torch
+from mmengine.data import LabelData
+
+from mmocr.core.data_structures import TextRecogDataSample
+from mmocr.models.textrecog.dictionary import Dictionary
+from mmocr.models.textrecog.postprocessor.attn_postprocessor import \
+    AttentionPostprocessor
+
+
+class TestAttentionPostprocessor(TestCase):
+
+    def _create_dummy_dict_file(
+            self, dict_file,
+            chars=list('0123456789abcdefghijklmnopqrstuvwxyz')):  # NOQA
+        with open(dict_file, 'w') as f:
+            for char in chars:
+                f.write(char + '\n')
+
+    def test_call(self):
+        tmp_dir = tempfile.TemporaryDirectory()
+        dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
+        self._create_dummy_dict_file(dict_file)
+        dict_gen = Dictionary(
+            dict_file=dict_file,
+            with_start=True,
+            with_end=True,
+            same_start_end=True,
+            with_padding=True,
+            with_unknown=False)
+        pred_text = LabelData(valid_ratio=1.0)
+        data_samples = [TextRecogDataSample(pred_text=pred_text)]
+        postprocessor = AttentionPostprocessor(
+            max_seq_len=None, dictionary=dict_gen, ignore_chars=['0'])
+        dict_gen.end_idx = 3
+        # test decode output to index
+        dummy_output = torch.Tensor([[[1, 100, 3, 4, 5, 6, 7, 8],
+                                      [100, 2, 3, 4, 5, 6, 7, 8],
+                                      [1, 2, 100, 4, 5, 6, 7, 8],
+                                      [1, 2, 100, 4, 5, 6, 7, 8],
+                                      [100, 2, 3, 4, 5, 6, 7, 8],
+                                      [1, 2, 3, 100, 5, 6, 7, 8],
+                                      [100, 2, 3, 4, 5, 6, 7, 8],
+                                      [1, 2, 3, 100, 5, 6, 7, 8]]])
+        data_samples = postprocessor(dummy_output, data_samples)
+        self.assertEqual(data_samples[0].pred_text.item, '122')
diff --git a/tests/test_models/test_textrecog/test_postprocessor/test_base_textrecog_postprocessor.py b/tests/test_models/test_textrecog/test_postprocessor/test_base_textrecog_postprocessor.py
new file mode 100644
index 00000000..0f09a1aa
--- /dev/null
+++ b/tests/test_models/test_textrecog/test_postprocessor/test_base_textrecog_postprocessor.py
@@ -0,0 +1,96 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import tempfile
+from unittest import TestCase, mock
+
+import torch
+from mmengine.data import LabelData
+
+from mmocr.core.data_structures import TextRecogDataSample
+from mmocr.models.textrecog.dictionary import Dictionary
+from mmocr.models.textrecog.postprocessor.base_textrecog_postprocessor import \
+    BaseTextRecogPostprocessor
+
+
+class TestBaseTextRecogPostprocessor(TestCase):
+
+    def _create_dummy_dict_file(
+            self, dict_file,
+            chars=list('0123456789abcdefghijklmnopqrstuvwxyz')):  # NOQA
+        with open(dict_file, 'w') as f:
+            for char in chars:
+                f.write(char + '\n')
+
+    def test_init(self):
+        tmp_dir = tempfile.TemporaryDirectory()
+        dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
+        self._create_dummy_dict_file(dict_file)
+        # test dictionary cfg
+        dict_cfg = dict(
+            type='Dictionary',
+            dict_file=dict_file,
+            with_start=True,
+            with_end=True,
+            same_start_end=False,
+            with_padding=True,
+            with_unknown=True)
+        base_postprocessor = BaseTextRecogPostprocessor(dict_cfg)
+        self.assertIsInstance(base_postprocessor.dictionary, Dictionary)
+        self.assertListEqual(base_postprocessor.ignore_indexes,
+                             [base_postprocessor.dictionary.padding_idx])
+
+        base_postprocessor = BaseTextRecogPostprocessor(
+            dict_cfg, ignore_chars=['1', '2', '3'])
+
+        self.assertListEqual(base_postprocessor.ignore_indexes, [1, 2, 3])
+
+        # test ignore_chars
+        with self.assertRaisesRegex(TypeError,
+                                    'ignore_chars must be list of str'):
+            base_postprocessor = BaseTextRecogPostprocessor(
+                dict_cfg, ignore_chars=[1, 2, 3])
+        with self.assertRaisesRegex(ValueError,
+                                    'M does not exist in dictionary'):
+            base_postprocessor = BaseTextRecogPostprocessor(
+                dict_cfg, ignore_chars=['M'])
+
+        base_postprocessor = BaseTextRecogPostprocessor(
+            dict_cfg, ignore_chars=['1', '2', '3'])
+        # test dictionary is invalid type
+        dict_cfg = ['tmp']
+        with self.assertRaisesRegex(
+                TypeError, ('The type of dictionary should be `Dictionary`'
+                            ' or dict, '
+                            f'but got {type(dict_cfg)}')):
+            base_postprocessor = BaseTextRecogPostprocessor(dict_cfg)
+
+        tmp_dir.cleanup()
+
+    @mock.patch(f'{__name__}.BaseTextRecogPostprocessor.get_single_prediction')
+    def test_call(self, mock_get_single_prediction):
+
+        def mock_func(output, data_sample):
+            return [0, 1, 2], [0.8, 0.7, 0.9]
+
+        tmp_dir = tempfile.TemporaryDirectory()
+        dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
+        self._create_dummy_dict_file(dict_file)
+        dict_cfg = dict(
+            type='Dictionary',
+            dict_file=dict_file,
+            with_start=True,
+            with_end=True,
+            same_start_end=False,
+            with_padding=True,
+            with_unknown=True)
+        mock_get_single_prediction.side_effect = mock_func
+        pred_text = LabelData(valid_ratio=1.0)
+        data_samples = [TextRecogDataSample(pred_text=pred_text)]
+        postprocessor = BaseTextRecogPostprocessor(
+            max_seq_len=None, dictionary=dict_cfg)
+
+        # test decode output to index
+        dummy_output = torch.Tensor([[[1, 100, 3, 4, 5, 6, 7, 8]]])
+        data_samples = postprocessor(dummy_output, data_samples)
+        self.assertEqual(data_samples[0].pred_text.item, '012')
+        tmp_dir.cleanup()
diff --git a/tests/test_models/test_textrecog/test_postprocessor/test_ctc_postprocessor.py b/tests/test_models/test_textrecog/test_postprocessor/test_ctc_postprocessor.py
new file mode 100644
index 00000000..f6e6e765
--- /dev/null
+++ b/tests/test_models/test_textrecog/test_postprocessor/test_ctc_postprocessor.py
@@ -0,0 +1,86 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+import os.path as osp
+import tempfile
+from unittest import TestCase
+
+import torch
+from mmengine.data import LabelData
+
+from mmocr.core.data_structures import TextRecogDataSample
+from mmocr.models.textrecog.dictionary import Dictionary
+from mmocr.models.textrecog.postprocessor.ctc_postprocessor import \
+    CTCPostProcessor
+
+
+class TestCTCPostProcessor(TestCase):
+
+    def _create_dummy_dict_file(
+            self, dict_file,
+            chars=list('0123456789abcdefghijklmnopqrstuvwxyz')):  # NOQA
+        with open(dict_file, 'w') as f:
+            for char in chars:
+                f.write(char + '\n')
+
+    def test_get_single_prediction(self):
+
+        tmp_dir = tempfile.TemporaryDirectory()
+        dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
+        self._create_dummy_dict_file(dict_file)
+        dict_gen = Dictionary(
+            dict_file=dict_file,
+            with_start=False,
+            with_end=False,
+            with_padding=True,
+            with_unknown=False)
+        pred_text = LabelData(valid_ratio=1.0)
+        data_samples = [TextRecogDataSample(pred_text=pred_text)]
+        postprocessor = CTCPostProcessor(max_seq_len=None, dictionary=dict_gen)
+
+        # test decode output to index
+        dummy_output = torch.Tensor([[[1, 100, 3, 4, 5, 6, 7, 8],
+                                      [100, 2, 3, 4, 5, 6, 7, 8],
+                                      [1, 2, 100, 4, 5, 6, 7, 8],
+                                      [1, 2, 100, 4, 5, 6, 7, 8],
+                                      [100, 2, 3, 4, 5, 6, 7, 8],
+                                      [1, 2, 3, 100, 5, 6, 7, 8],
+                                      [100, 2, 3, 4, 5, 6, 7, 8],
+                                      [1, 2, 3, 100, 5, 6, 7, 8]]])
+        index, score = postprocessor.get_single_prediction(
+            dummy_output[0], data_samples[0])
+        self.assertListEqual(index, [1, 0, 2, 0, 3, 0, 3])
+        self.assertListEqual(score,
+                             [100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0])
+        postprocessor = CTCPostProcessor(
+            max_seq_len=None, dictionary=dict_gen, ignore_chars=['0'])
+        index, score = postprocessor.get_single_prediction(
+            dummy_output[0], data_samples[0])
+        self.assertListEqual(index, [1, 2, 3, 3])
+        self.assertListEqual(score, [100.0, 100.0, 100.0, 100.0])
+        tmp_dir.cleanup()
+
+    def test_call(self):
+        tmp_dir = tempfile.TemporaryDirectory()
+        dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
+        self._create_dummy_dict_file(dict_file)
+        dict_gen = Dictionary(
+            dict_file=dict_file,
+            with_start=False,
+            with_end=False,
+            with_padding=True,
+            with_unknown=False)
+        pred_text = LabelData(valid_ratio=1.0)
+        data_samples = [TextRecogDataSample(pred_text=pred_text)]
+        postprocessor = CTCPostProcessor(max_seq_len=None, dictionary=dict_gen)
+
+        # test decode output to index
+        dummy_output = torch.Tensor([[[1, 100, 3, 4, 5, 6, 7, 8],
+                                      [100, 2, 3, 4, 5, 6, 7, 8],
+                                      [1, 2, 100, 4, 5, 6, 7, 8],
+                                      [1, 2, 100, 4, 5, 6, 7, 8],
+                                      [100, 2, 3, 4, 5, 6, 7, 8],
+                                      [1, 2, 3, 100, 5, 6, 7, 8],
+                                      [100, 2, 3, 4, 5, 6, 7, 8],
+                                      [1, 2, 3, 100, 5, 6, 7, 8]]])
+        data_samples = postprocessor(dummy_output, data_samples)
+        self.assertEqual(data_samples[0].pred_text.item, '1020303')