diff --git a/mmocr/models/textrecog/dictionary/dictionary.py b/mmocr/models/textrecog/dictionary/dictionary.py
index 745138b5..cb3916a0 100644
--- a/mmocr/models/textrecog/dictionary/dictionary.py
+++ b/mmocr/models/textrecog/dictionary/dictionary.py
@@ -100,6 +100,30 @@ class Dictionary:
         """
         return self._contain_uppercase
 
+    def char2idx(self, char: str, strict: bool = True) -> int:
+        """Convert a character to an index via ``Dictionary.dict``.
+
+        Args:
+            char (str): The character to convert to index.
+            strict (bool): The flag to control whether to raise an exception
+                when the character is not in the dictionary. Defaults to True.
+
+        Returns:
+            int or None: The character's index, or None if it has no mapping.
+        """
+        char_idx = self._char2idx.get(char, None)
+        if char_idx is None:
+            if self.with_unknown:
+                return self.unknown_idx
+            elif not strict:
+                return None
+            else:
+                raise Exception(f'Character: {char} not in dict,'
+                                ' please check gt_label and use'
+                                ' custom dict file,'
+                                ' or set "with_unknown=True"')
+        return char_idx
+
     def str2idx(self, string: str) -> List:
         """Convert a string to a list of indexes via ``Dictionary.dict``.
 
@@ -111,7 +135,7 @@ class Dictionary:
         """
         idx = list()
         for s in string:
-            char_idx = self._char2idx.get(s, self.unknown_idx)
+            char_idx = self.char2idx(s)
             if char_idx is None:
                 if self.with_unknown:
                     continue
diff --git a/mmocr/models/textrecog/module_losses/ce_module_loss.py b/mmocr/models/textrecog/module_losses/ce_module_loss.py
index 18ae6ded..12c2f0a7 100644
--- a/mmocr/models/textrecog/module_losses/ce_module_loss.py
+++ b/mmocr/models/textrecog/module_losses/ce_module_loss.py
@@ -28,6 +28,14 @@ class CEModuleLoss(BaseRecogModuleLoss):
             - lower: Convert gt texts into lowercase characters.
             Usually, it only works for English characters. Defaults to
             'unchanged'.
+        pad_with (str): The padding strategy for ``gt_text.padded_indexes``.
+            Defaults to 'auto'. Options are:
+            - 'auto': Use dictionary.padding_idx to pad gt texts, or
+              dictionary.end_idx if dictionary.padding_idx
+              is None.
+            - 'padding': Always use dictionary.padding_idx to pad gt texts.
+            - 'end': Always use dictionary.end_idx to pad gt texts.
+            - 'none': Do not pad gt texts.
         ignore_char (int or str): Specifies a target value that is
             ignored and does not contribute to the input gradient.
             ignore_char can be int or str. If int, it is the index of
@@ -54,6 +62,7 @@ class CEModuleLoss(BaseRecogModuleLoss):
                  dictionary: Union[Dict, Dictionary],
                  max_seq_len: int = 40,
                  letter_case: str = 'unchanged',
+                 pad_with: str = 'auto',
                  ignore_char: Union[int, str] = 'padding',
                  flatten: bool = False,
                  reduction: str = 'none',
@@ -61,7 +70,8 @@ class CEModuleLoss(BaseRecogModuleLoss):
         super().__init__(
             dictionary=dictionary,
             max_seq_len=max_seq_len,
-            letter_case=letter_case)
+            letter_case=letter_case,
+            pad_with=pad_with)
         assert isinstance(ignore_char, (int, str))
         assert isinstance(reduction, str)
         assert reduction in ['none', 'mean', 'sum']
@@ -81,10 +91,13 @@ class CEModuleLoss(BaseRecogModuleLoss):
                 'end': self.dictionary.end_idx,
                 'unknown': self.dictionary.unknown_idx,
             }
-            # TODO add char2id in Dictionary
+
             ignore_index = mapping_table.get(
-                ignore_char, self.dictionary._char2idx.get(ignore_char, None))
-            if ignore_index is None:
+                ignore_char,
+                self.dictionary.char2idx(ignore_char, strict=False))
+            if ignore_index is None or (ignore_index
+                                        == self.dictionary.unknown_idx
+                                        and ignore_char != 'unknown'):
                 warnings.warn(
                     f'{ignore_char} does not exist in the dictionary',
                     UserWarning)
diff --git a/mmocr/models/textrecog/postprocessors/base_textrecog_postprocessor.py b/mmocr/models/textrecog/postprocessors/base_textrecog_postprocessor.py
index a3a10353..ba4ee53a 100644
--- a/mmocr/models/textrecog/postprocessors/base_textrecog_postprocessor.py
+++ b/mmocr/models/textrecog/postprocessors/base_textrecog_postprocessor.py
@@ -52,10 +52,11 @@ class BaseTextRecogPostprocessor:
             raise TypeError('ignore_chars must be list of str')
         ignore_indexes = list()
         for ignore_char in ignore_chars:
-            # TODO add char2id in Dictionary
             index = mapping_table.get(
-                ignore_char, self.dictionary._char2idx.get(ignore_char, None))
-            if index is None:
+                ignore_char,
+                self.dictionary.char2idx(ignore_char, strict=False))
+            if index is None or (index == self.dictionary.unknown_idx
+                                 and ignore_char != 'unknown'):
                 warnings.warn(
                     f'{ignore_char} does not exist in the dictionary',
                     UserWarning)
diff --git a/tests/test_models/test_textrecog/test_dictionary/test_dictionary.py b/tests/test_models/test_textrecog/test_dictionary/test_dictionary.py
index 4628d495..03827d20 100644
--- a/tests/test_models/test_textrecog/test_dictionary/test_dictionary.py
+++ b/tests/test_models/test_textrecog/test_dictionary/test_dictionary.py
@@ -100,6 +100,28 @@ class TestDictionary(TestCase):
             dict_gen = Dictionary(dict_file=dict_file)
             assert dict_gen.contain_uppercase is True
 
+    def test_char2idx(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+
+            # create dummy data
+            dict_file = osp.join(tmp_dir, 'fake_chars.txt')
+            create_dummy_dict_file(dict_file)
+            dict_gen = Dictionary(dict_file=dict_file, with_unknown=False)
+            self.assertEqual(dict_gen.char2idx('0'), 0)
+
+            dict_gen = Dictionary(dict_file=dict_file, with_unknown=True)
+            self.assertEqual(dict_gen.char2idx('H'), dict_gen.unknown_idx)
+
+            dict_gen = Dictionary(
+                dict_file=dict_file, with_unknown=True, unknown_token=None)
+            self.assertIsNone(dict_gen.char2idx('H'))
+
+            # Test strict
+            dict_gen = Dictionary(dict_file=dict_file, with_unknown=False)
+            with self.assertRaises(Exception):
+                dict_gen.char2idx('H', strict=True)
+            self.assertIsNone(dict_gen.char2idx('H', strict=False))
+
     def test_str2idx(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
 
diff --git a/tests/test_models/test_textrecog/test_module_losses/test_ce_module_loss.py b/tests/test_models/test_textrecog/test_module_losses/test_ce_module_loss.py
index 99d3f1e3..183131d7 100644
--- a/tests/test_models/test_textrecog/test_module_losses/test_ce_module_loss.py
+++ b/tests/test_models/test_textrecog/test_module_losses/test_ce_module_loss.py
@@ -21,9 +21,10 @@ class TestCEModuleLoss(TestCase):
         self.gt = [data_sample1, data_sample2, data_sample3]
 
     def test_init(self):
+        dict_file = 'dicts/lower_english_digits.txt'
         dict_cfg = dict(
             type='Dictionary',
-            dict_file='dicts/lower_english_digits.txt',
+            dict_file=dict_file,
             with_start=True,
             with_end=True,
             same_start_end=True,
@@ -47,6 +48,26 @@ class TestCEModuleLoss(TestCase):
         # with self.assertRaises(ValueError):
         with self.assertWarns(UserWarning):
             ce_loss = CEModuleLoss(dict_cfg, ignore_char='ignore')
+        with self.assertWarns(UserWarning):
+            ce_loss = CEModuleLoss(
+                dict(
+                    type='Dictionary', dict_file=dict_file, with_unknown=True),
+                ignore_char='M',
+                pad_with='none')
+        with self.assertWarns(UserWarning):
+            ce_loss = CEModuleLoss(
+                dict(
+                    type='Dictionary', dict_file=dict_file,
+                    with_unknown=False),
+                ignore_char='M',
+                pad_with='none')
+        with self.assertWarns(UserWarning):
+            ce_loss = CEModuleLoss(
+                dict(
+                    type='Dictionary', dict_file=dict_file,
+                    with_unknown=False),
+                ignore_char='unknown',
+                pad_with='none')
         ce_loss = CEModuleLoss(dict_cfg, ignore_char='1')
         self.assertEqual(ce_loss.ignore_index, 1)
 
diff --git a/tests/test_models/test_textrecog/test_postprocessors/test_base_textrecog_postprocessor.py b/tests/test_models/test_textrecog/test_postprocessors/test_base_textrecog_postprocessor.py
index 352c6bd6..aad78dd5 100644
--- a/tests/test_models/test_textrecog/test_postprocessors/test_base_textrecog_postprocessor.py
+++ b/tests/test_models/test_textrecog/test_postprocessors/test_base_textrecog_postprocessor.py
@@ -46,6 +46,65 @@ class TestBaseTextRecogPostprocessor(TestCase):
             base_postprocessor = BaseTextRecogPostprocessor(
                 dict_cfg, ignore_chars=['M'])
 
+        base_postprocessor = BaseTextRecogPostprocessor(
+            dict_cfg, ignore_chars=['1', '2', '3'])
+        # test dictionary is invalid type
+        dict_cfg = ['tmp']
+        with self.assertRaisesRegex(
+                TypeError, ('The type of dictionary should be `Dictionary`'
+                            ' or dict, '
+                            f'but got {type(dict_cfg)}')):
+            base_postprocessor = BaseTextRecogPostprocessor(dict_cfg)
+        # test dictionary cfg with with_unknown=False
+        dict_cfg = dict(
+            type='Dictionary',
+            dict_file=dict_file,
+            with_start=True,
+            with_end=True,
+            same_start_end=False,
+            with_padding=True,
+            with_unknown=False)
+        base_postprocessor = BaseTextRecogPostprocessor(
+            dict_cfg, ignore_chars=['1', '2', '3'])
+
+        self.assertListEqual(base_postprocessor.ignore_indexes, [1, 2, 3])
+
+        # test ignore_chars
+        with self.assertRaisesRegex(TypeError,
+                                    'ignore_chars must be list of str'):
+            base_postprocessor = BaseTextRecogPostprocessor(
+                dict_cfg, ignore_chars=[1, 2, 3])
+
+        with self.assertWarnsRegex(Warning,
+                                   'M does not exist in the dictionary'):
+            base_postprocessor = BaseTextRecogPostprocessor(
+                dict_cfg, ignore_chars=['M'])
+
+        with self.assertWarnsRegex(Warning,
+                                   'M does not exist in the dictionary'):
+            base_postprocessor = BaseTextRecogPostprocessor(
+                dict(
+                    type='Dictionary',
+                    dict_file=dict_file,
+                    with_unknown=True,
+                    unknown_token=None),
+                ignore_chars=['M'])
+
+        with self.assertWarnsRegex(Warning,
+                                   'M does not exist in the dictionary'):
+            base_postprocessor = BaseTextRecogPostprocessor(
+                dict(
+                    type='Dictionary', dict_file=dict_file, with_unknown=True),
+                ignore_chars=['M'])
+
+        with self.assertWarnsRegex(Warning,
+                                   'unknown does not exist in the dictionary'):
+            base_postprocessor = BaseTextRecogPostprocessor(
+                dict(
+                    type='Dictionary', dict_file=dict_file,
+                    with_unknown=False),
+                ignore_chars=['unknown'])
+
         base_postprocessor = BaseTextRecogPostprocessor(
             dict_cfg, ignore_chars=['1', '2', '3'])
         # test dictionary is invalid type