mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
[TODO] Add char2idx
This commit is contained in:
parent
dc84187311
commit
3734527d38
@ -100,6 +100,30 @@ class Dictionary:
|
||||
"""
|
||||
return self._contain_uppercase
|
||||
|
||||
def char2idx(self, char: str, strict: bool = True) -> int:
|
||||
"""Convert a character to an index via ``Dictionary.dict``.
|
||||
|
||||
Args:
|
||||
char (str): The character to convert to index.
|
||||
strict (bool): The flag to control whether to raise an exception
|
||||
when the character is not in the dictionary. Defaults to True.
|
||||
|
||||
Return:
|
||||
int: The index of the character.
|
||||
"""
|
||||
char_idx = self._char2idx.get(char, None)
|
||||
if char_idx is None:
|
||||
if self.with_unknown:
|
||||
return self.unknown_idx
|
||||
elif not strict:
|
||||
return None
|
||||
else:
|
||||
raise Exception(f'Chararcter: {char} not in dict,'
|
||||
' please check gt_label and use'
|
||||
' custom dict file,'
|
||||
' or set "with_unknown=True"')
|
||||
return char_idx
|
||||
|
||||
def str2idx(self, string: str) -> List:
|
||||
"""Convert a string to a list of indexes via ``Dictionary.dict``.
|
||||
|
||||
@ -111,7 +135,7 @@ class Dictionary:
|
||||
"""
|
||||
idx = list()
|
||||
for s in string:
|
||||
char_idx = self._char2idx.get(s, self.unknown_idx)
|
||||
char_idx = self.char2idx(s)
|
||||
if char_idx is None:
|
||||
if self.with_unknown:
|
||||
continue
|
||||
|
@ -28,6 +28,14 @@ class CEModuleLoss(BaseRecogModuleLoss):
|
||||
- lower: Convert gt texts into lowercase characters.
|
||||
Usually, it only works for English characters. Defaults to
|
||||
'unchanged'.
|
||||
pad_with (str): The padding strategy for ``gt_text.padded_indexes``.
|
||||
Defaults to 'auto'. Options are:
|
||||
- 'auto': Use dictionary.padding_idx to pad gt texts, or
|
||||
dictionary.end_idx if dictionary.padding_idx
|
||||
is None.
|
||||
- 'padding': Always use dictionary.padding_idx to pad gt texts.
|
||||
- 'end': Always use dictionary.end_idx to pad gt texts.
|
||||
- 'none': Do not pad gt texts.
|
||||
ignore_char (int or str): Specifies a target value that is
|
||||
ignored and does not contribute to the input gradient.
|
||||
ignore_char can be int or str. If int, it is the index of
|
||||
@ -54,6 +62,7 @@ class CEModuleLoss(BaseRecogModuleLoss):
|
||||
dictionary: Union[Dict, Dictionary],
|
||||
max_seq_len: int = 40,
|
||||
letter_case: str = 'unchanged',
|
||||
pad_with: str = 'auto',
|
||||
ignore_char: Union[int, str] = 'padding',
|
||||
flatten: bool = False,
|
||||
reduction: str = 'none',
|
||||
@ -61,7 +70,8 @@ class CEModuleLoss(BaseRecogModuleLoss):
|
||||
super().__init__(
|
||||
dictionary=dictionary,
|
||||
max_seq_len=max_seq_len,
|
||||
letter_case=letter_case)
|
||||
letter_case=letter_case,
|
||||
pad_with=pad_with)
|
||||
assert isinstance(ignore_char, (int, str))
|
||||
assert isinstance(reduction, str)
|
||||
assert reduction in ['none', 'mean', 'sum']
|
||||
@ -81,10 +91,13 @@ class CEModuleLoss(BaseRecogModuleLoss):
|
||||
'end': self.dictionary.end_idx,
|
||||
'unknown': self.dictionary.unknown_idx,
|
||||
}
|
||||
# TODO add char2id in Dictionary
|
||||
|
||||
ignore_index = mapping_table.get(
|
||||
ignore_char, self.dictionary._char2idx.get(ignore_char, None))
|
||||
if ignore_index is None:
|
||||
ignore_char,
|
||||
self.dictionary.char2idx(ignore_char, strict=False))
|
||||
if ignore_index is None or (ignore_index
|
||||
== self.dictionary.unknown_idx
|
||||
and ignore_char != 'unknown'):
|
||||
warnings.warn(
|
||||
f'{ignore_char} does not exist in the dictionary',
|
||||
UserWarning)
|
||||
|
@ -52,10 +52,11 @@ class BaseTextRecogPostprocessor:
|
||||
raise TypeError('ignore_chars must be list of str')
|
||||
ignore_indexes = list()
|
||||
for ignore_char in ignore_chars:
|
||||
# TODO add char2id in Dictionary
|
||||
index = mapping_table.get(
|
||||
ignore_char, self.dictionary._char2idx.get(ignore_char, None))
|
||||
if index is None:
|
||||
ignore_char,
|
||||
self.dictionary.char2idx(ignore_char, strict=False))
|
||||
if index is None or (index == self.dictionary.unknown_idx
|
||||
and ignore_char != 'unknown'):
|
||||
warnings.warn(
|
||||
f'{ignore_char} does not exist in the dictionary',
|
||||
UserWarning)
|
||||
|
@ -100,6 +100,28 @@ class TestDictionary(TestCase):
|
||||
dict_gen = Dictionary(dict_file=dict_file)
|
||||
assert dict_gen.contain_uppercase is True
|
||||
|
||||
def test_char2idx(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
|
||||
# create dummy data
|
||||
dict_file = osp.join(tmp_dir, 'fake_chars.txt')
|
||||
create_dummy_dict_file(dict_file)
|
||||
dict_gen = Dictionary(dict_file=dict_file, with_unknown=False)
|
||||
self.assertEqual(dict_gen.char2idx('0'), 0)
|
||||
|
||||
dict_gen = Dictionary(dict_file=dict_file, with_unknown=True)
|
||||
self.assertEqual(dict_gen.char2idx('H'), dict_gen.unknown_idx)
|
||||
|
||||
dict_gen = Dictionary(
|
||||
dict_file=dict_file, with_unknown=True, unknown_token=None)
|
||||
self.assertEqual(dict_gen.char2idx('H'), None)
|
||||
|
||||
# Test strict
|
||||
dict_gen = Dictionary(dict_file=dict_file, with_unknown=False)
|
||||
with self.assertRaises(Exception):
|
||||
dict_gen.char2idx('H', strict=True)
|
||||
self.assertEqual(dict_gen.char2idx('H', strict=False), None)
|
||||
|
||||
def test_str2idx(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
|
||||
|
@ -21,9 +21,10 @@ class TestCEModuleLoss(TestCase):
|
||||
self.gt = [data_sample1, data_sample2, data_sample3]
|
||||
|
||||
def test_init(self):
|
||||
dict_file = 'dicts/lower_english_digits.txt'
|
||||
dict_cfg = dict(
|
||||
type='Dictionary',
|
||||
dict_file='dicts/lower_english_digits.txt',
|
||||
dict_file=dict_file,
|
||||
with_start=True,
|
||||
with_end=True,
|
||||
same_start_end=True,
|
||||
@ -47,6 +48,26 @@ class TestCEModuleLoss(TestCase):
|
||||
# with self.assertRaises(ValueError):
|
||||
with self.assertWarns(UserWarning):
|
||||
ce_loss = CEModuleLoss(dict_cfg, ignore_char='ignore')
|
||||
with self.assertWarns(UserWarning):
|
||||
ce_loss = CEModuleLoss(
|
||||
dict(
|
||||
type='Dictionary', dict_file=dict_file, with_unknown=True),
|
||||
ignore_char='M',
|
||||
pad_with='none')
|
||||
with self.assertWarns(UserWarning):
|
||||
ce_loss = CEModuleLoss(
|
||||
dict(
|
||||
type='Dictionary', dict_file=dict_file,
|
||||
with_unknown=False),
|
||||
ignore_char='M',
|
||||
pad_with='none')
|
||||
with self.assertWarns(UserWarning):
|
||||
ce_loss = CEModuleLoss(
|
||||
dict(
|
||||
type='Dictionary', dict_file=dict_file,
|
||||
with_unknown=False),
|
||||
ignore_char='unknown',
|
||||
pad_with='none')
|
||||
ce_loss = CEModuleLoss(dict_cfg, ignore_char='1')
|
||||
self.assertEqual(ce_loss.ignore_index, 1)
|
||||
|
||||
|
@ -46,6 +46,65 @@ class TestBaseTextRecogPostprocessor(TestCase):
|
||||
base_postprocessor = BaseTextRecogPostprocessor(
|
||||
dict_cfg, ignore_chars=['M'])
|
||||
|
||||
base_postprocessor = BaseTextRecogPostprocessor(
|
||||
dict_cfg, ignore_chars=['1', '2', '3'])
|
||||
# test dictionary is invalid type
|
||||
dict_cfg = ['tmp']
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, ('The type of dictionary should be `Dictionary`'
|
||||
' or dict, '
|
||||
f'but got {type(dict_cfg)}')):
|
||||
base_postprocessor = BaseTextRecogPostprocessor(dict_cfg)
|
||||
# test diction cfg with with_unknown=False
|
||||
dict_cfg = dict(
|
||||
type='Dictionary',
|
||||
dict_file=dict_file,
|
||||
with_start=True,
|
||||
with_end=True,
|
||||
same_start_end=False,
|
||||
with_padding=True,
|
||||
with_unknown=False)
|
||||
base_postprocessor = BaseTextRecogPostprocessor(
|
||||
dict_cfg, ignore_chars=['1', '2', '3'])
|
||||
|
||||
self.assertListEqual(base_postprocessor.ignore_indexes, [1, 2, 3])
|
||||
|
||||
# test ignore_chars
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'ignore_chars must be list of str'):
|
||||
base_postprocessor = BaseTextRecogPostprocessor(
|
||||
dict_cfg, ignore_chars=[1, 2, 3])
|
||||
|
||||
with self.assertWarnsRegex(Warning,
|
||||
'M does not exist in the dictionary'):
|
||||
base_postprocessor = BaseTextRecogPostprocessor(
|
||||
dict_cfg, ignore_chars=['M'])
|
||||
|
||||
with self.assertWarnsRegex(Warning,
|
||||
'M does not exist in the dictionary'):
|
||||
base_postprocessor = BaseTextRecogPostprocessor(
|
||||
dict(
|
||||
type='Dictionary',
|
||||
dict_file=dict_file,
|
||||
with_unknown=True,
|
||||
unkown_token=None),
|
||||
ignore_chars=['M'])
|
||||
|
||||
with self.assertWarnsRegex(Warning,
|
||||
'M does not exist in the dictionary'):
|
||||
base_postprocessor = BaseTextRecogPostprocessor(
|
||||
dict(
|
||||
type='Dictionary', dict_file=dict_file, with_unknown=True),
|
||||
ignore_chars=['M'])
|
||||
|
||||
with self.assertWarnsRegex(Warning,
|
||||
'unknown does not exist in the dictionary'):
|
||||
base_postprocessor = BaseTextRecogPostprocessor(
|
||||
dict(
|
||||
type='Dictionary', dict_file=dict_file,
|
||||
with_unknown=False),
|
||||
ignore_chars=['unknown'])
|
||||
|
||||
base_postprocessor = BaseTextRecogPostprocessor(
|
||||
dict_cfg, ignore_chars=['1', '2', '3'])
|
||||
# test dictionary is invalid type
|
||||
|
Loading…
x
Reference in New Issue
Block a user