mirror of https://github.com/open-mmlab/mmocr.git
[Fix] Fix dictionary docstr and remove unncessary kwargs (#1276)
* [Fix] Fix dictionary docstr and remove unncessary kwargs * fix * fixpull/1278/head
parent
97f6c1d5d6
commit
7b25b62c21
|
@ -1,6 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import re
|
||||
from typing import List, Optional, Sequence
|
||||
from typing import List, Sequence
|
||||
|
||||
from mmocr.registry import TASK_UTILS
|
||||
from mmocr.utils import list_from_file
|
||||
|
@ -50,8 +49,7 @@ class Dictionary:
|
|||
end_token: str = '<EOS>',
|
||||
start_end_token: str = '<BOS/EOS>',
|
||||
padding_token: str = '<PAD>',
|
||||
unknown_token: Optional[str] = '<UKN>',
|
||||
**kwargs) -> None:
|
||||
unknown_token: str = '<UKN>') -> None:
|
||||
self.with_start = with_start
|
||||
self.with_end = with_end
|
||||
self.same_start_end = same_start_end
|
||||
|
@ -75,10 +73,8 @@ class Dictionary:
|
|||
self._dict.append(line)
|
||||
|
||||
self._char2idx = {char: idx for idx, char in enumerate(self._dict)}
|
||||
self._contain_uppercase = len(re.findall('[A-Z]', ''.join(
|
||||
self.dict))) > 0
|
||||
|
||||
self._update_dict(**kwargs)
|
||||
self._update_dict()
|
||||
assert len(set(self._dict)) == len(self._dict), \
|
||||
'Invalid dictionary: Has duplicated characters.'
|
||||
|
||||
|
@ -90,16 +86,10 @@ class Dictionary:
|
|||
|
||||
@property
|
||||
def dict(self) -> list:
|
||||
"""list: The list of all character to recognize, which Special tokens
|
||||
are counted."""
|
||||
"""list: Returns a list of characters to recognize, where special
|
||||
tokens are counted."""
|
||||
return self._dict
|
||||
|
||||
@property
|
||||
def contain_uppercase(self) -> bool:
|
||||
"""bool: Whether all the English characters in dict file are in lowercase.
|
||||
"""
|
||||
return self._contain_uppercase
|
||||
|
||||
def char2idx(self, char: str, strict: bool = True) -> int:
|
||||
"""Convert a character to an index via ``Dictionary.dict``.
|
||||
|
||||
|
@ -158,10 +148,12 @@ class Dictionary:
|
|||
assert isinstance(index, (list, tuple))
|
||||
string = ''
|
||||
for i in index:
|
||||
assert i < len(self._dict), f'Index: {i} out of range! Index ' \
|
||||
f'must be less than {len(self._dict)}'
|
||||
string += self._dict[i]
|
||||
return string
|
||||
|
||||
def _update_dict(self, **kwargs):
|
||||
def _update_dict(self):
|
||||
"""Update the dict with tokens according to parameters."""
|
||||
# BOS/EOS
|
||||
self.start_idx = None
|
||||
|
|
|
@ -40,6 +40,7 @@ class TestSDMGRHead(TestCase):
|
|||
self.assertEqual(edge_cls.shape, torch.Size([4, 2]))
|
||||
|
||||
# When input image is None
|
||||
del (dict_cfg['type'])
|
||||
head = SDMGRHead(dictionary=Dictionary(**dict_cfg))
|
||||
node_cls, edge_cls = head(None, [data_sample])
|
||||
self.assertEqual(node_cls.shape, torch.Size([2, 26]))
|
||||
|
|
|
@ -89,17 +89,6 @@ class TestDictionary(TestCase):
|
|||
dict_gen = Dictionary(dict_file=dict_file)
|
||||
assert dict_gen.num_classes == 36
|
||||
|
||||
def test_contain_uppercase(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# create dummy data
|
||||
dict_file = osp.join(tmp_dir, 'fake_chars.txt')
|
||||
create_dummy_dict_file(dict_file)
|
||||
dict_gen = Dictionary(dict_file=dict_file)
|
||||
assert dict_gen.contain_uppercase is False
|
||||
create_dummy_dict_file(dict_file, chars='abcdABCD')
|
||||
dict_gen = Dictionary(dict_file=dict_file)
|
||||
assert dict_gen.contain_uppercase is True
|
||||
|
||||
def test_char2idx(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
|
||||
|
@ -150,3 +139,5 @@ class TestDictionary(TestCase):
|
|||
self.assertEqual(dict_gen.idx2str([0, 1, 2, 3, 4]), '01234')
|
||||
with self.assertRaises(AssertionError):
|
||||
dict_gen.idx2str('01234')
|
||||
with self.assertRaises(AssertionError):
|
||||
dict_gen.idx2str([40])
|
||||
|
|
|
@ -87,7 +87,7 @@ class TestBaseTextRecogPostprocessor(TestCase):
|
|||
type='Dictionary',
|
||||
dict_file=dict_file,
|
||||
with_unknown=True,
|
||||
unkown_token=None),
|
||||
unknown_token=None),
|
||||
ignore_chars=['M'])
|
||||
|
||||
with self.assertWarnsRegex(Warning,
|
||||
|
|
Loading…
Reference in New Issue