[Fix] Fix dictionary docstring and remove unnecessary kwargs (#1276)

* [Fix] Fix dictionary docstring and remove unnecessary kwargs

* fix

* fix
pull/1278/head
Tong Gao 2022-08-11 11:14:17 +08:00 committed by GitHub
parent 97f6c1d5d6
commit 7b25b62c21
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 12 additions and 28 deletions

View File

@ -1,6 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import re
from typing import List, Optional, Sequence
from typing import List, Sequence
from mmocr.registry import TASK_UTILS
from mmocr.utils import list_from_file
@ -50,8 +49,7 @@ class Dictionary:
end_token: str = '<EOS>',
start_end_token: str = '<BOS/EOS>',
padding_token: str = '<PAD>',
unknown_token: Optional[str] = '<UKN>',
**kwargs) -> None:
unknown_token: str = '<UKN>') -> None:
self.with_start = with_start
self.with_end = with_end
self.same_start_end = same_start_end
@ -75,10 +73,8 @@ class Dictionary:
self._dict.append(line)
self._char2idx = {char: idx for idx, char in enumerate(self._dict)}
self._contain_uppercase = len(re.findall('[A-Z]', ''.join(
self.dict))) > 0
self._update_dict(**kwargs)
self._update_dict()
assert len(set(self._dict)) == len(self._dict), \
'Invalid dictionary: Has duplicated characters.'
@ -90,16 +86,10 @@ class Dictionary:
@property
def dict(self) -> list:
"""list: The list of all character to recognize, which Special tokens
are counted."""
"""list: Returns a list of characters to recognize, where special
tokens are counted."""
return self._dict
@property
def contain_uppercase(self) -> bool:
"""bool: Whether all the English characters in dict file are in lowercase.
"""
return self._contain_uppercase
def char2idx(self, char: str, strict: bool = True) -> int:
"""Convert a character to an index via ``Dictionary.dict``.
@ -158,10 +148,12 @@ class Dictionary:
assert isinstance(index, (list, tuple))
string = ''
for i in index:
assert i < len(self._dict), f'Index: {i} out of range! Index ' \
f'must be less than {len(self._dict)}'
string += self._dict[i]
return string
def _update_dict(self, **kwargs):
def _update_dict(self):
"""Update the dict with tokens according to parameters."""
# BOS/EOS
self.start_idx = None

View File

@ -40,6 +40,7 @@ class TestSDMGRHead(TestCase):
self.assertEqual(edge_cls.shape, torch.Size([4, 2]))
# When input image is None
del (dict_cfg['type'])
head = SDMGRHead(dictionary=Dictionary(**dict_cfg))
node_cls, edge_cls = head(None, [data_sample])
self.assertEqual(node_cls.shape, torch.Size([2, 26]))

View File

@ -89,17 +89,6 @@ class TestDictionary(TestCase):
dict_gen = Dictionary(dict_file=dict_file)
assert dict_gen.num_classes == 36
def test_contain_uppercase(self):
with tempfile.TemporaryDirectory() as tmp_dir:
# create dummy data
dict_file = osp.join(tmp_dir, 'fake_chars.txt')
create_dummy_dict_file(dict_file)
dict_gen = Dictionary(dict_file=dict_file)
assert dict_gen.contain_uppercase is False
create_dummy_dict_file(dict_file, chars='abcdABCD')
dict_gen = Dictionary(dict_file=dict_file)
assert dict_gen.contain_uppercase is True
def test_char2idx(self):
with tempfile.TemporaryDirectory() as tmp_dir:
@ -150,3 +139,5 @@ class TestDictionary(TestCase):
self.assertEqual(dict_gen.idx2str([0, 1, 2, 3, 4]), '01234')
with self.assertRaises(AssertionError):
dict_gen.idx2str('01234')
with self.assertRaises(AssertionError):
dict_gen.idx2str([40])

View File

@ -87,7 +87,7 @@ class TestBaseTextRecogPostprocessor(TestCase):
type='Dictionary',
dict_file=dict_file,
with_unknown=True,
unkown_token=None),
unknown_token=None),
ignore_chars=['M'])
with self.assertWarnsRegex(Warning,