From c920edfb3a97c45fd3548b60e7708d8445eed1a0 Mon Sep 17 00:00:00 2001 From: liukuikun <641417025@qq.com> Date: Thu, 12 May 2022 12:01:03 +0000 Subject: [PATCH] [Refactor] split labelconverter to Dictionary --- mmocr/models/textrecog/__init__.py | 7 +- mmocr/models/textrecog/dictionary/__init__.py | 5 + .../models/textrecog/dictionary/dictionary.py | 169 ++++++++++++++++++ .../test_dictionary/test_dictionary.py | 132 ++++++++++++++ 4 files changed, 310 insertions(+), 3 deletions(-) create mode 100644 mmocr/models/textrecog/dictionary/__init__.py create mode 100644 mmocr/models/textrecog/dictionary/dictionary.py create mode 100644 tests/test_models/test_dictionary/test_dictionary.py diff --git a/mmocr/models/textrecog/__init__.py b/mmocr/models/textrecog/__init__.py index 9a813067..5a84aa5c 100644 --- a/mmocr/models/textrecog/__init__.py +++ b/mmocr/models/textrecog/__init__.py @@ -1,9 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. -from . import (backbones, convertors, decoders, encoders, fusers, heads, - losses, necks, plugins, preprocessor, recognizer) +from . import (backbones, convertors, decoders, dictionary, encoders, fusers, + heads, losses, necks, plugins, preprocessor, recognizer) from .backbones import * # NOQA from .convertors import * # NOQA from .decoders import * # NOQA +from .dictionary import * # NOQA from .encoders import * # NOQA from .fusers import * # NOQA from .heads import * # NOQA @@ -17,4 +18,4 @@ __all__ = ( backbones.__all__ + convertors.__all__ + decoders.__all__ + encoders.__all__ + heads.__all__ + losses.__all__ + necks.__all__ + preprocessor.__all__ + recognizer.__all__ + fusers.__all__ + - plugins.__all__) + plugins.__all__ + dictionary.__all__) diff --git a/mmocr/models/textrecog/dictionary/__init__.py b/mmocr/models/textrecog/dictionary/__init__.py new file mode 100644 index 00000000..9ad0ab30 --- /dev/null +++ b/mmocr/models/textrecog/dictionary/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
# Copyright (c) OpenMMLab. All rights reserved.
import re
from typing import List, Sequence

from mmocr.registry import TASK_UTILS
from mmocr.utils import list_from_file


@TASK_UTILS.register_module()
class Dictionary:
    """The class generates a dictionary for recognition. It pre-defines four
    special tokens: ``start_token``, ``end_token``, ``pad_token``, and
    ``unknown_token``, which will be sequentially placed at the end of the
    dictionary when their corresponding flags are True.

    Args:
        dict_file (str): The path of a character dict file in which each
            line contains a single character.
        with_start (bool): Whether to include the start token.
            Defaults to False.
        with_end (bool): Whether to include the end token.
            Defaults to False.
        same_start_end (bool): Whether the start token and the end token
            are the same. Only effective when both ``with_start`` and
            ``with_end`` are True. Defaults to False.
        with_padding (bool): The padding token may represent more than
            padding. It can also represent tokens like the blank token in
            CTC or the background token in SegOCR. Defaults to False.
        with_unknown (bool): Whether to include the unknown token.
            Defaults to False.
        start_token (str): The start token as a string.
            Defaults to '<BOS>'.
        end_token (str): The end token as a string. Defaults to '<EOS>'.
        start_end_token (str): The start/end token as a string, used when
            the start token and the end token are the same.
            Defaults to '<BOS/EOS>'.
        padding_token (str): The padding token as a string.
            Defaults to '<PAD>'.
        unknown_token (str): The unknown token as a string.
            Defaults to '<UKN>'.
    """

    def __init__(self,
                 dict_file: str,
                 with_start: bool = False,
                 with_end: bool = False,
                 same_start_end: bool = False,
                 with_padding: bool = False,
                 with_unknown: bool = False,
                 start_token: str = '<BOS>',
                 end_token: str = '<EOS>',
                 start_end_token: str = '<BOS/EOS>',
                 padding_token: str = '<PAD>',
                 unknown_token: str = '<UKN>',
                 **kwargs) -> None:
        self.with_start = with_start
        self.with_end = with_end
        self.same_start_end = same_start_end
        self.with_padding = with_padding
        self.with_unknown = with_unknown
        self.start_end_token = start_end_token
        self.start_token = start_token
        self.end_token = end_token
        self.padding_token = padding_token
        self.unknown_token = unknown_token

        assert isinstance(dict_file, str)
        self._dict = []
        for line_num, line in enumerate(list_from_file(dict_file)):
            line = line.strip('\r\n')
            # Each line must hold at most one character; empty lines are
            # silently skipped.
            if len(line) > 1:
                raise ValueError('Expect each line has 0 or 1 character, '
                                 f'got {len(line)} characters '
                                 f'at line {line_num + 1}')
            if line != '':
                self._dict.append(line)

        self._char2idx = {char: idx for idx, char in enumerate(self._dict)}
        # Computed before special tokens are appended so that tokens such
        # as '<BOS>' do not count as uppercase characters.
        self._contain_uppercase = len(
            re.findall('[A-Z]', ''.join(self._dict))) > 0

        self._update_dict(**kwargs)
        assert len(set(self._dict)) == len(self._dict), \
            'Invalid dictionary: Has duplicated characters.'

    @property
    def num_classes(self) -> int:
        """int: Number of output classes. Special tokens are counted.
        """
        return len(self._dict)

    @property
    def dict(self) -> list:
        """list: The list of all characters to recognize. Special tokens
        are counted."""
        return self._dict

    @property
    def contain_uppercase(self) -> bool:
        """bool: Whether the dict file contains uppercase English
        characters (special tokens excluded).
        """
        return self._contain_uppercase

    def str2idx(self, string: str) -> List:
        """Convert a string to a list of indexes via ``Dictionary.dict``.

        Args:
            string (str): The string to convert to indexes.

        Return:
            list: The list of indexes of the string.

        Raises:
            Exception: If a character is not in the dictionary and
                ``with_unknown`` is False (``unknown_idx`` is None).
        """
        idx = list()
        for s in string:
            # Falls back to unknown_idx, which is None when
            # with_unknown=False; in that case we raise.
            char_idx = self._char2idx.get(s, self.unknown_idx)
            if char_idx is None:
                raise Exception(f'Character: {s} not in dict,'
                                f' please check gt_label and use'
                                f' custom dict file,'
                                f' or set "with_unknown=True"')
            idx.append(char_idx)
        return idx

    def idx2str(self, index: Sequence[int]) -> str:
        """Convert a list of index to string.

        Args:
            index (list[int]): The list of indexes to convert to string.

        Return:
            str: The converted string.
        """
        assert isinstance(index, (list, tuple))
        string = ''
        for i in index:
            string += self._dict[i]
        return string

    def _update_dict(self, **kwargs):
        """Append special tokens to the dict according to the flags and
        record their indexes."""
        # BOS/EOS
        self.start_idx = None
        self.end_idx = None
        if self.with_start and self.with_end and self.same_start_end:
            # A single shared token serves as both start and end.
            self._dict.append(self.start_end_token)
            self.start_idx = len(self._dict) - 1
            self.end_idx = self.start_idx
        else:
            if self.with_start:
                self._dict.append(self.start_token)
                self.start_idx = len(self._dict) - 1
            if self.with_end:
                self._dict.append(self.end_token)
                self.end_idx = len(self._dict) - 1

        # padding
        self.padding_idx = None
        if self.with_padding:
            self._dict.append(self.padding_token)
            self.padding_idx = len(self._dict) - 1

        # unknown
        self.unknown_idx = None
        if self.with_unknown:
            self._dict.append(self.unknown_token)
            self.unknown_idx = len(self._dict) - 1

        # Rebuild char2idx so special tokens are also resolvable.
        self._char2idx = {}
        for idx, char in enumerate(self._dict):
            self._char2idx[char] = idx
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from unittest import TestCase

from mmocr.models.textrecog import Dictionary


class TestDictionary(TestCase):
    """Unit tests for ``Dictionary``.

    NOTE(review): the expected token strings below were reconstructed —
    the original patch's angle-bracket literals were stripped in transit,
    which made them duplicate empty strings that the Dictionary's own
    duplicate-character assertion would reject.
    """

    def _create_dummy_dict_file(
            self, dict_file,
            chars=list('0123456789abcdefghijklmnopqrstuvwxyz')):  # NOQA
        # Write one character per line, as Dictionary expects.
        with open(dict_file, 'w') as f:
            for char in chars:
                f.write(char + '\n')

    def test_init(self):
        tmp_dir = tempfile.TemporaryDirectory()

        # create dummy data
        dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
        self._create_dummy_dict_file(dict_file)
        # with start/end/padding/unknown, distinct start and end tokens
        dict_gen = Dictionary(
            dict_file=dict_file,
            with_start=True,
            with_end=True,
            same_start_end=False,
            with_padding=True,
            with_unknown=True)
        self.assertEqual(dict_gen.num_classes, 40)
        self.assertListEqual(
            dict_gen.dict,
            list('0123456789abcdefghijklmnopqrstuvwxyz') +
            ['<BOS>', '<EOS>', '<PAD>', '<UKN>'])
        # shared start/end token saves one class
        dict_gen = Dictionary(
            dict_file=dict_file,
            with_start=True,
            with_end=True,
            same_start_end=True,
            with_padding=True,
            with_unknown=True)
        self.assertEqual(dict_gen.num_classes, 39)
        self.assertListEqual(
            dict_gen.dict,
            list('0123456789abcdefghijklmnopqrstuvwxyz') +
            ['<BOS/EOS>', '<PAD>', '<UKN>'])
        # custom token strings are appended in BOS, EOS, PAD, UKN order
        dict_gen = Dictionary(
            dict_file=dict_file,
            with_start=True,
            with_end=True,
            same_start_end=False,
            with_padding=True,
            with_unknown=True,
            start_token='<STA>',
            end_token='<END>',
            padding_token='<BLK>',
            unknown_token='<NO>')
        self.assertEqual(dict_gen.num_classes, 40)
        self.assertListEqual(dict_gen.dict[-4:],
                             ['<STA>', '<END>', '<BLK>', '<NO>'])
        # custom shared start/end token
        dict_gen = Dictionary(
            dict_file=dict_file,
            with_start=True,
            with_end=True,
            same_start_end=True,
            with_padding=True,
            with_unknown=True,
            start_end_token='<S/E>')
        self.assertEqual(dict_gen.num_classes, 39)
        self.assertListEqual(dict_gen.dict[-3:],
                             ['<S/E>', '<PAD>', '<UKN>'])
        # test len(line) > 1
        self._create_dummy_dict_file(dict_file, chars=['12', '3', '4'])
        with self.assertRaises(ValueError):
            dict_gen = Dictionary(dict_file=dict_file)

        # test duplicated dict
        self._create_dummy_dict_file(dict_file, chars=['1', '1', '2'])
        with self.assertRaises(AssertionError):
            dict_gen = Dictionary(dict_file=dict_file)

        tmp_dir.cleanup()

    def test_num_classes(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            # create dummy data
            dict_file = osp.join(tmp_dir, 'fake_chars.txt')
            self._create_dummy_dict_file(dict_file)
            dict_gen = Dictionary(dict_file=dict_file)
            self.assertEqual(dict_gen.num_classes, 36)

    def test_contain_uppercase(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            # create dummy data
            dict_file = osp.join(tmp_dir, 'fake_chars.txt')
            self._create_dummy_dict_file(dict_file)
            dict_gen = Dictionary(dict_file=dict_file)
            self.assertFalse(dict_gen.contain_uppercase)
            self._create_dummy_dict_file(dict_file, chars='abcdABCD')
            dict_gen = Dictionary(dict_file=dict_file)
            self.assertTrue(dict_gen.contain_uppercase)

    def test_str2idx(self):
        with tempfile.TemporaryDirectory() as tmp_dir:

            # create dummy data
            dict_file = osp.join(tmp_dir, 'fake_chars.txt')
            self._create_dummy_dict_file(dict_file)
            dict_gen = Dictionary(dict_file=dict_file)
            self.assertEqual(dict_gen.str2idx('01234'), [0, 1, 2, 3, 4])
            # 'H' is absent and with_unknown=False -> raises
            with self.assertRaises(Exception):
                dict_gen.str2idx('H')

            # with_unknown=True maps unseen chars to unknown_idx
            dict_gen = Dictionary(dict_file=dict_file, with_unknown=True)
            self.assertListEqual(dict_gen.str2idx('H'),
                                 [dict_gen.unknown_idx])

    def test_idx2str(self):
        with tempfile.TemporaryDirectory() as tmp_dir:

            # create dummy data
            dict_file = osp.join(tmp_dir, 'fake_chars.txt')
            self._create_dummy_dict_file(dict_file)
            dict_gen = Dictionary(dict_file=dict_file)
            self.assertEqual(dict_gen.idx2str([0, 1, 2, 3, 4]), '01234')
            # a string is not a valid index sequence
            with self.assertRaises(AssertionError):
                dict_gen.idx2str('01234')