diff --git a/mmocr/models/textrecog/dictionary/dictionary.py b/mmocr/models/textrecog/dictionary/dictionary.py index cb3916a0..d16dc875 100644 --- a/mmocr/models/textrecog/dictionary/dictionary.py +++ b/mmocr/models/textrecog/dictionary/dictionary.py @@ -1,6 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -import re -from typing import List, Optional, Sequence +from typing import List, Sequence from mmocr.registry import TASK_UTILS from mmocr.utils import list_from_file @@ -50,8 +49,7 @@ class Dictionary: end_token: str = '', start_end_token: str = '', padding_token: str = '', - unknown_token: Optional[str] = '', - **kwargs) -> None: + unknown_token: str = '') -> None: self.with_start = with_start self.with_end = with_end self.same_start_end = same_start_end @@ -75,10 +73,8 @@ class Dictionary: self._dict.append(line) self._char2idx = {char: idx for idx, char in enumerate(self._dict)} - self._contain_uppercase = len(re.findall('[A-Z]', ''.join( - self.dict))) > 0 - self._update_dict(**kwargs) + self._update_dict() assert len(set(self._dict)) == len(self._dict), \ 'Invalid dictionary: Has duplicated characters.' @@ -90,16 +86,10 @@ class Dictionary: @property def dict(self) -> list: - """list: The list of all character to recognize, which Special tokens - are counted.""" + """list: Returns a list of characters to recognize, where special + tokens are counted.""" return self._dict - @property - def contain_uppercase(self) -> bool: - """bool: Whether all the English characters in dict file are in lowercase. - """ - return self._contain_uppercase - def char2idx(self, char: str, strict: bool = True) -> int: """Convert a character to an index via ``Dictionary.dict``. @@ -158,10 +148,12 @@ class Dictionary: assert isinstance(index, (list, tuple)) string = '' for i in index: + assert i < len(self._dict), f'Index: {i} out of range! Index ' \ + f'must be less than {len(self._dict)}' string += self._dict[i] return string - def _update_dict(self, **kwargs): + def _update_dict(self): """Update the dict with tokens according to parameters.""" # BOS/EOS self.start_idx = None diff --git a/tests/test_models/test_kie/test_heads/test_sdmgr_head.py b/tests/test_models/test_kie/test_heads/test_sdmgr_head.py index 7f83d57d..6c1efc3d 100644 --- a/tests/test_models/test_kie/test_heads/test_sdmgr_head.py +++ b/tests/test_models/test_kie/test_heads/test_sdmgr_head.py @@ -40,6 +40,7 @@ class TestSDMGRHead(TestCase): self.assertEqual(edge_cls.shape, torch.Size([4, 2])) # When input image is None + del (dict_cfg['type']) head = SDMGRHead(dictionary=Dictionary(**dict_cfg)) node_cls, edge_cls = head(None, [data_sample]) self.assertEqual(node_cls.shape, torch.Size([2, 26])) diff --git a/tests/test_models/test_textrecog/test_dictionary/test_dictionary.py b/tests/test_models/test_textrecog/test_dictionary/test_dictionary.py index 03827d20..df3a022d 100644 --- a/tests/test_models/test_textrecog/test_dictionary/test_dictionary.py +++ b/tests/test_models/test_textrecog/test_dictionary/test_dictionary.py @@ -89,17 +89,6 @@ class TestDictionary(TestCase): dict_gen = Dictionary(dict_file=dict_file) assert dict_gen.num_classes == 36 - def test_contain_uppercase(self): - with tempfile.TemporaryDirectory() as tmp_dir: - # create dummy data - dict_file = osp.join(tmp_dir, 'fake_chars.txt') - create_dummy_dict_file(dict_file) - dict_gen = Dictionary(dict_file=dict_file) - assert dict_gen.contain_uppercase is False - create_dummy_dict_file(dict_file, chars='abcdABCD') - dict_gen = Dictionary(dict_file=dict_file) - assert dict_gen.contain_uppercase is True - def test_char2idx(self): with tempfile.TemporaryDirectory() as tmp_dir: @@ -150,3 +139,5 @@ class TestDictionary(TestCase): self.assertEqual(dict_gen.idx2str([0, 1, 2, 3, 4]), '01234') with self.assertRaises(AssertionError): dict_gen.idx2str('01234') + with self.assertRaises(AssertionError): + dict_gen.idx2str([40]) diff --git a/tests/test_models/test_textrecog/test_postprocessors/test_base_textrecog_postprocessor.py b/tests/test_models/test_textrecog/test_postprocessors/test_base_textrecog_postprocessor.py index 0426f0a2..a4918b46 100644 --- a/tests/test_models/test_textrecog/test_postprocessors/test_base_textrecog_postprocessor.py +++ b/tests/test_models/test_textrecog/test_postprocessors/test_base_textrecog_postprocessor.py @@ -87,7 +87,7 @@ class TestBaseTextRecogPostprocessor(TestCase): type='Dictionary', dict_file=dict_file, with_unknown=True, - unkown_token=None), + unknown_token=None), ignore_chars=['M']) with self.assertWarnsRegex(Warning,