Mirror of https://github.com/open-mmlab/mmocr.git
[Refactor] split labelconverter to Dictionary
commit c920edfb3a (parent c47c5711c1)
mmocr/models/textrecog/__init__.py
@@ -1,9 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from . import (backbones, convertors, decoders, encoders, fusers, heads,
-               losses, necks, plugins, preprocessor, recognizer)
+from . import (backbones, convertors, decoders, dictionary, encoders, fusers,
+               heads, losses, necks, plugins, preprocessor, recognizer)
 from .backbones import *  # NOQA
 from .convertors import *  # NOQA
 from .decoders import *  # NOQA
+from .dictionary import *  # NOQA
 from .encoders import *  # NOQA
 from .fusers import *  # NOQA
 from .heads import *  # NOQA
@@ -17,4 +18,4 @@ __all__ = (
     backbones.__all__ + convertors.__all__ + decoders.__all__ +
     encoders.__all__ + heads.__all__ + losses.__all__ + necks.__all__ +
     preprocessor.__all__ + recognizer.__all__ + fusers.__all__ +
-    plugins.__all__)
+    plugins.__all__ + dictionary.__all__)
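
With the re-export above, Dictionary becomes importable from the package root, and since the class is registered in TASK_UTILS (see dictionary.py below) it can also be built from a config dict. A minimal sketch, assuming the mmengine-style registry build that OpenMMLab repos use; the dict file path is a placeholder, not part of this commit:

    from mmocr.models.textrecog import Dictionary  # direct import
    from mmocr.registry import TASK_UTILS          # registry-based build

    # 'dicts/lower_english_digits.txt' is a placeholder path for illustration.
    dictionary = TASK_UTILS.build(
        dict(type='Dictionary', dict_file='dicts/lower_english_digits.txt'))
    assert isinstance(dictionary, Dictionary)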
mmocr/models/textrecog/dictionary/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.

from .dictionary import Dictionary

__all__ = ['Dictionary']
mmocr/models/textrecog/dictionary/dictionary.py (new file, 169 lines)
@@ -0,0 +1,169 @@
# Copyright (c) OpenMMLab. All rights reserved.
import re
from typing import List, Sequence

from mmocr.registry import TASK_UTILS
from mmocr.utils import list_from_file


@TASK_UTILS.register_module()
class Dictionary:
    """The class generates a dictionary for recognition. It pre-defines four
    special tokens: ``start_token``, ``end_token``, ``pad_token``, and
    ``unknown_token``, which will be sequentially placed at the end of the
    dictionary when their corresponding flags are True.

    Args:
        dict_file (str): The path of the character dict file, in which each
            line contains a single character.
        with_start (bool): The flag to control whether to include the start
            token. Defaults to False.
        with_end (bool): The flag to control whether to include the end token.
            Defaults to False.
        same_start_end (bool): The flag to control whether the start token and
            end token are the same. It only works when both ``with_start`` and
            ``with_end`` are True. Defaults to False.
        with_padding (bool): The flag to control whether to include the
            padding token. The padding token may represent more than padding:
            it can also represent tokens like the blank token in CTC or the
            background token in SegOCR. Defaults to False.
        with_unknown (bool): The flag to control whether to include the
            unknown token. Defaults to False.
        start_token (str): The start token as a string. Defaults to '<BOS>'.
        end_token (str): The end token as a string. Defaults to '<EOS>'.
        start_end_token (str): The start/end token as a string, used when the
            start and end tokens are the same. Defaults to '<BOS/EOS>'.
        padding_token (str): The padding token as a string.
            Defaults to '<PAD>'.
        unknown_token (str): The unknown token as a string.
            Defaults to '<UKN>'.
    """

    def __init__(self,
                 dict_file: str,
                 with_start: bool = False,
                 with_end: bool = False,
                 same_start_end: bool = False,
                 with_padding: bool = False,
                 with_unknown: bool = False,
                 start_token: str = '<BOS>',
                 end_token: str = '<EOS>',
                 start_end_token: str = '<BOS/EOS>',
                 padding_token: str = '<PAD>',
                 unknown_token: str = '<UKN>',
                 **kwargs) -> None:
        self.with_start = with_start
        self.with_end = with_end
        self.same_start_end = same_start_end
        self.with_padding = with_padding
        self.with_unknown = with_unknown
        self.start_end_token = start_end_token
        self.start_token = start_token
        self.end_token = end_token
        self.padding_token = padding_token
        self.unknown_token = unknown_token

        assert isinstance(dict_file, str)
        self._dict = []
        for line_num, line in enumerate(list_from_file(dict_file)):
            line = line.strip('\r\n')
            if len(line) > 1:
                raise ValueError('Expect each line has 0 or 1 character, '
                                 f'got {len(line)} characters '
                                 f'at line {line_num + 1}')
            if line != '':
                self._dict.append(line)

        self._char2idx = {char: idx for idx, char in enumerate(self._dict)}
        self._contain_uppercase = len(re.findall('[A-Z]', ''.join(
            self.dict))) > 0

        self._update_dict(**kwargs)
        assert len(set(self._dict)) == len(self._dict), \
            'Invalid dictionary: Has duplicated characters.'

    @property
    def num_classes(self) -> int:
        """int: Number of output classes. Special tokens are counted."""
        return len(self._dict)

    @property
    def dict(self) -> list:
        """list: The list of all characters to recognize, with special tokens
        counted."""
        return self._dict

    @property
    def contain_uppercase(self) -> bool:
        """bool: Whether the dictionary contains uppercase English
        characters."""
        return self._contain_uppercase

    def str2idx(self, string: str) -> List:
        """Convert a string to a list of indexes via ``Dictionary.dict``.

        Args:
            string (str): The string to convert to indexes.

        Returns:
            list: The list of indexes of the string.
        """
        idx = list()
        for s in string:
            char_idx = self._char2idx.get(s, self.unknown_idx)
            if char_idx is None:
                raise Exception(f'Character: {s} not in dict,'
                                ' please check gt_label and use'
                                ' custom dict file,'
                                ' or set "with_unknown=True"')
            idx.append(char_idx)
        return idx

    def idx2str(self, index: Sequence[int]) -> str:
        """Convert a list of index to string.

        Args:
            index (list[int]): The list of indexes to convert to string.

        Returns:
            str: The converted string.
        """
        assert isinstance(index, (list, tuple))
        string = ''
        for i in index:
            string += self._dict[i]
        return string

    def _update_dict(self, **kwargs):
        """Update the dict with tokens according to parameters."""
        # BOS/EOS
        self.start_idx = None
        self.end_idx = None
        if self.with_start and self.with_end and self.same_start_end:
            self._dict.append(self.start_end_token)
            self.start_idx = len(self._dict) - 1
            self.end_idx = self.start_idx
        else:
            if self.with_start:
                self._dict.append(self.start_token)
                self.start_idx = len(self._dict) - 1
            if self.with_end:
                self._dict.append(self.end_token)
                self.end_idx = len(self._dict) - 1

        # padding
        self.padding_idx = None
        if self.with_padding:
            self._dict.append(self.padding_token)
            self.padding_idx = len(self._dict) - 1

        # unknown
        self.unknown_idx = None
        if self.with_unknown:
            self._dict.append(self.unknown_token)
            self.unknown_idx = len(self._dict) - 1

        # update char2idx
        self._char2idx = {}
        for idx, char in enumerate(self._dict):
            self._char2idx[char] = idx
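
To make the token bookkeeping concrete, here is a small usage sketch (illustrative only: the three-character dict file is made up for the example):

    import tempfile

    from mmocr.models.textrecog import Dictionary

    # Write a throwaway dict file with one character per line.
    with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as f:
        f.write('a\nb\nc\n')
        dict_file = f.name

    dictionary = Dictionary(
        dict_file=dict_file,
        with_start=True,
        with_end=True,
        same_start_end=True,  # '<BOS/EOS>' is appended once and shared
        with_padding=True,
        with_unknown=True)

    # _update_dict appends the special tokens after 'a', 'b', 'c':
    print(dictionary.dict)         # ['a', 'b', 'c', '<BOS/EOS>', '<PAD>', '<UKN>']
    print(dictionary.num_classes)  # 6
    print(dictionary.start_idx == dictionary.end_idx)  # True
    print(dictionary.str2idx('cab'))      # [2, 0, 1]
    print(dictionary.idx2str([0, 1, 2]))  # 'abc'
    print(dictionary.str2idx('z'))        # [5], i.e. [dictionary.unknown_idx]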
tests/test_models/test_dictionary/test_dictionary.py (new file, 132 lines)
@@ -0,0 +1,132 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from unittest import TestCase

from mmocr.models.textrecog import Dictionary


class TestDictionary(TestCase):

    def _create_dummy_dict_file(
            self, dict_file,
            chars=list('0123456789abcdefghijklmnopqrstuvwxyz')):  # NOQA
        with open(dict_file, 'w') as f:
            for char in chars:
                f.write(char + '\n')

    def test_init(self):
        tmp_dir = tempfile.TemporaryDirectory()

        # create dummy data
        dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
        self._create_dummy_dict_file(dict_file)
        # with start
        dict_gen = Dictionary(
            dict_file=dict_file,
            with_start=True,
            with_end=True,
            same_start_end=False,
            with_padding=True,
            with_unknown=True)
        self.assertEqual(dict_gen.num_classes, 40)
        self.assertListEqual(
            dict_gen.dict,
            list('0123456789abcdefghijklmnopqrstuvwxyz') +
            ['<BOS>', '<EOS>', '<PAD>', '<UKN>'])
        dict_gen = Dictionary(
            dict_file=dict_file,
            with_start=True,
            with_end=True,
            same_start_end=True,
            with_padding=True,
            with_unknown=True)
        self.assertEqual(dict_gen.num_classes, 39)
        self.assertListEqual(
            dict_gen.dict,
            list('0123456789abcdefghijklmnopqrstuvwxyz') +
            ['<BOS/EOS>', '<PAD>', '<UKN>'])
        dict_gen = Dictionary(
            dict_file=dict_file,
            with_start=True,
            with_end=True,
            same_start_end=False,
            with_padding=True,
            with_unknown=True,
            start_token='<STA>',
            end_token='<END>',
            padding_token='<BLK>',
            unknown_token='<NO>')
        self.assertEqual(dict_gen.num_classes, 40)
        self.assertListEqual(dict_gen.dict[-4:],
                             ['<STA>', '<END>', '<BLK>', '<NO>'])
        dict_gen = Dictionary(
            dict_file=dict_file,
            with_start=True,
            with_end=True,
            same_start_end=True,
            with_padding=True,
            with_unknown=True,
            start_end_token='<BE>')
        self.assertEqual(dict_gen.num_classes, 39)
        self.assertListEqual(dict_gen.dict[-3:], ['<BE>', '<PAD>', '<UKN>'])
        # test len(line) > 1
        self._create_dummy_dict_file(dict_file, chars=['12', '3', '4'])
        with self.assertRaises(ValueError):
            dict_gen = Dictionary(dict_file=dict_file)

        # test duplicated dict
        self._create_dummy_dict_file(dict_file, chars=['1', '1', '2'])
        with self.assertRaises(AssertionError):
            dict_gen = Dictionary(dict_file=dict_file)

        tmp_dir.cleanup()

    def test_num_classes(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            # create dummy data
            dict_file = osp.join(tmp_dir, 'fake_chars.txt')
            self._create_dummy_dict_file(dict_file)
            dict_gen = Dictionary(dict_file=dict_file)
            self.assertEqual(dict_gen.num_classes, 36)

    def test_contain_uppercase(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            # create dummy data
            dict_file = osp.join(tmp_dir, 'fake_chars.txt')
            self._create_dummy_dict_file(dict_file)
            dict_gen = Dictionary(dict_file=dict_file)
            self.assertFalse(dict_gen.contain_uppercase)
            self._create_dummy_dict_file(dict_file, chars='abcdABCD')
            dict_gen = Dictionary(dict_file=dict_file)
            self.assertTrue(dict_gen.contain_uppercase)

    def test_str2idx(self):
        with tempfile.TemporaryDirectory() as tmp_dir:

            # create dummy data
            dict_file = osp.join(tmp_dir, 'fake_chars.txt')
            self._create_dummy_dict_file(dict_file)
            dict_gen = Dictionary(dict_file=dict_file)
            self.assertEqual(dict_gen.str2idx('01234'), [0, 1, 2, 3, 4])
            with self.assertRaises(Exception):
                dict_gen.str2idx('H')

            dict_gen = Dictionary(dict_file=dict_file, with_unknown=True)
            self.assertListEqual(dict_gen.str2idx('H'), [dict_gen.unknown_idx])

    def test_idx2str(self):
        with tempfile.TemporaryDirectory() as tmp_dir:

            # create dummy data
            dict_file = osp.join(tmp_dir, 'fake_chars.txt')
            self._create_dummy_dict_file(dict_file)
            dict_gen = Dictionary(dict_file=dict_file)
            self.assertEqual(dict_gen.idx2str([0, 1, 2, 3, 4]), '01234')
            with self.assertRaises(AssertionError):
                dict_gen.idx2str('01234')