[TODO] Add char2idx

This commit is contained in:
jiangqing.vendor 2022-07-14 09:19:20 +00:00 committed by gaotongxiao
parent dc84187311
commit 3734527d38
6 changed files with 149 additions and 9 deletions

View File

@ -100,6 +100,30 @@ class Dictionary:
"""
return self._contain_uppercase
def char2idx(self, char: str, strict: bool = True) -> int:
"""Convert a character to an index via ``Dictionary.dict``.
Args:
char (str): The character to convert to index.
strict (bool): The flag to control whether to raise an exception
when the character is not in the dictionary. Defaults to True.
Return:
int: The index of the character.
"""
char_idx = self._char2idx.get(char, None)
if char_idx is None:
if self.with_unknown:
return self.unknown_idx
elif not strict:
return None
else:
raise Exception(f'Chararcter: {char} not in dict,'
' please check gt_label and use'
' custom dict file,'
' or set "with_unknown=True"')
return char_idx
def str2idx(self, string: str) -> List:
"""Convert a string to a list of indexes via ``Dictionary.dict``.
@ -111,7 +135,7 @@ class Dictionary:
"""
idx = list()
for s in string:
char_idx = self._char2idx.get(s, self.unknown_idx)
char_idx = self.char2idx(s)
if char_idx is None:
if self.with_unknown:
continue

View File

@ -28,6 +28,14 @@ class CEModuleLoss(BaseRecogModuleLoss):
- lower: Convert gt texts into lowercase characters.
Usually, it only works for English characters. Defaults to
'unchanged'.
pad_with (str): The padding strategy for ``gt_text.padded_indexes``.
Defaults to 'auto'. Options are:
- 'auto': Use dictionary.padding_idx to pad gt texts, or
dictionary.end_idx if dictionary.padding_idx
is None.
- 'padding': Always use dictionary.padding_idx to pad gt texts.
- 'end': Always use dictionary.end_idx to pad gt texts.
- 'none': Do not pad gt texts.
ignore_char (int or str): Specifies a target value that is
ignored and does not contribute to the input gradient.
ignore_char can be int or str. If int, it is the index of
@ -54,6 +62,7 @@ class CEModuleLoss(BaseRecogModuleLoss):
dictionary: Union[Dict, Dictionary],
max_seq_len: int = 40,
letter_case: str = 'unchanged',
pad_with: str = 'auto',
ignore_char: Union[int, str] = 'padding',
flatten: bool = False,
reduction: str = 'none',
@ -61,7 +70,8 @@ class CEModuleLoss(BaseRecogModuleLoss):
super().__init__(
dictionary=dictionary,
max_seq_len=max_seq_len,
letter_case=letter_case)
letter_case=letter_case,
pad_with=pad_with)
assert isinstance(ignore_char, (int, str))
assert isinstance(reduction, str)
assert reduction in ['none', 'mean', 'sum']
@ -81,10 +91,13 @@ class CEModuleLoss(BaseRecogModuleLoss):
'end': self.dictionary.end_idx,
'unknown': self.dictionary.unknown_idx,
}
# TODO add char2id in Dictionary
ignore_index = mapping_table.get(
ignore_char, self.dictionary._char2idx.get(ignore_char, None))
if ignore_index is None:
ignore_char,
self.dictionary.char2idx(ignore_char, strict=False))
if ignore_index is None or (ignore_index
== self.dictionary.unknown_idx
and ignore_char != 'unknown'):
warnings.warn(
f'{ignore_char} does not exist in the dictionary',
UserWarning)

View File

@ -52,10 +52,11 @@ class BaseTextRecogPostprocessor:
raise TypeError('ignore_chars must be list of str')
ignore_indexes = list()
for ignore_char in ignore_chars:
# TODO add char2id in Dictionary
index = mapping_table.get(
ignore_char, self.dictionary._char2idx.get(ignore_char, None))
if index is None:
ignore_char,
self.dictionary.char2idx(ignore_char, strict=False))
if index is None or (index == self.dictionary.unknown_idx
and ignore_char != 'unknown'):
warnings.warn(
f'{ignore_char} does not exist in the dictionary',
UserWarning)

View File

@ -100,6 +100,28 @@ class TestDictionary(TestCase):
dict_gen = Dictionary(dict_file=dict_file)
assert dict_gen.contain_uppercase is True
def test_char2idx(self):
with tempfile.TemporaryDirectory() as tmp_dir:
# create dummy data
dict_file = osp.join(tmp_dir, 'fake_chars.txt')
create_dummy_dict_file(dict_file)
dict_gen = Dictionary(dict_file=dict_file, with_unknown=False)
self.assertEqual(dict_gen.char2idx('0'), 0)
dict_gen = Dictionary(dict_file=dict_file, with_unknown=True)
self.assertEqual(dict_gen.char2idx('H'), dict_gen.unknown_idx)
dict_gen = Dictionary(
dict_file=dict_file, with_unknown=True, unknown_token=None)
self.assertEqual(dict_gen.char2idx('H'), None)
# Test strict
dict_gen = Dictionary(dict_file=dict_file, with_unknown=False)
with self.assertRaises(Exception):
dict_gen.char2idx('H', strict=True)
self.assertEqual(dict_gen.char2idx('H', strict=False), None)
def test_str2idx(self):
with tempfile.TemporaryDirectory() as tmp_dir:

View File

@ -21,9 +21,10 @@ class TestCEModuleLoss(TestCase):
self.gt = [data_sample1, data_sample2, data_sample3]
def test_init(self):
dict_file = 'dicts/lower_english_digits.txt'
dict_cfg = dict(
type='Dictionary',
dict_file='dicts/lower_english_digits.txt',
dict_file=dict_file,
with_start=True,
with_end=True,
same_start_end=True,
@ -47,6 +48,26 @@ class TestCEModuleLoss(TestCase):
# with self.assertRaises(ValueError):
with self.assertWarns(UserWarning):
ce_loss = CEModuleLoss(dict_cfg, ignore_char='ignore')
with self.assertWarns(UserWarning):
ce_loss = CEModuleLoss(
dict(
type='Dictionary', dict_file=dict_file, with_unknown=True),
ignore_char='M',
pad_with='none')
with self.assertWarns(UserWarning):
ce_loss = CEModuleLoss(
dict(
type='Dictionary', dict_file=dict_file,
with_unknown=False),
ignore_char='M',
pad_with='none')
with self.assertWarns(UserWarning):
ce_loss = CEModuleLoss(
dict(
type='Dictionary', dict_file=dict_file,
with_unknown=False),
ignore_char='unknown',
pad_with='none')
ce_loss = CEModuleLoss(dict_cfg, ignore_char='1')
self.assertEqual(ce_loss.ignore_index, 1)

View File

@ -46,6 +46,65 @@ class TestBaseTextRecogPostprocessor(TestCase):
base_postprocessor = BaseTextRecogPostprocessor(
dict_cfg, ignore_chars=['M'])
base_postprocessor = BaseTextRecogPostprocessor(
dict_cfg, ignore_chars=['1', '2', '3'])
# test dictionary is invalid type
dict_cfg = ['tmp']
with self.assertRaisesRegex(
TypeError, ('The type of dictionary should be `Dictionary`'
' or dict, '
f'but got {type(dict_cfg)}')):
base_postprocessor = BaseTextRecogPostprocessor(dict_cfg)
# test diction cfg with with_unknown=False
dict_cfg = dict(
type='Dictionary',
dict_file=dict_file,
with_start=True,
with_end=True,
same_start_end=False,
with_padding=True,
with_unknown=False)
base_postprocessor = BaseTextRecogPostprocessor(
dict_cfg, ignore_chars=['1', '2', '3'])
self.assertListEqual(base_postprocessor.ignore_indexes, [1, 2, 3])
# test ignore_chars
with self.assertRaisesRegex(TypeError,
'ignore_chars must be list of str'):
base_postprocessor = BaseTextRecogPostprocessor(
dict_cfg, ignore_chars=[1, 2, 3])
with self.assertWarnsRegex(Warning,
'M does not exist in the dictionary'):
base_postprocessor = BaseTextRecogPostprocessor(
dict_cfg, ignore_chars=['M'])
with self.assertWarnsRegex(Warning,
'M does not exist in the dictionary'):
base_postprocessor = BaseTextRecogPostprocessor(
dict(
type='Dictionary',
dict_file=dict_file,
with_unknown=True,
unkown_token=None),
ignore_chars=['M'])
with self.assertWarnsRegex(Warning,
'M does not exist in the dictionary'):
base_postprocessor = BaseTextRecogPostprocessor(
dict(
type='Dictionary', dict_file=dict_file, with_unknown=True),
ignore_chars=['M'])
with self.assertWarnsRegex(Warning,
'unknown does not exist in the dictionary'):
base_postprocessor = BaseTextRecogPostprocessor(
dict(
type='Dictionary', dict_file=dict_file,
with_unknown=False),
ignore_chars=['unknown'])
base_postprocessor = BaseTextRecogPostprocessor(
dict_cfg, ignore_chars=['1', '2', '3'])
# test dictionary is invalid type