[Refactor] split label converter into Dictionary

This commit is contained in:
liukuikun 2022-05-12 12:01:03 +00:00 committed by gaotongxiao
parent c47c5711c1
commit c920edfb3a
4 changed files with 310 additions and 3 deletions

View File

@ -1,9 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import (backbones, convertors, decoders, encoders, fusers, heads,
losses, necks, plugins, preprocessor, recognizer)
from . import (backbones, convertors, decoders, dictionary, encoders, fusers,
heads, losses, necks, plugins, preprocessor, recognizer)
from .backbones import * # NOQA
from .convertors import * # NOQA
from .decoders import * # NOQA
from .dictionary import * # NOQA
from .encoders import * # NOQA
from .fusers import * # NOQA
from .heads import * # NOQA
@ -17,4 +18,4 @@ __all__ = (
backbones.__all__ + convertors.__all__ + decoders.__all__ +
encoders.__all__ + heads.__all__ + losses.__all__ + necks.__all__ +
preprocessor.__all__ + recognizer.__all__ + fusers.__all__ +
plugins.__all__)
plugins.__all__ + dictionary.__all__)

View File

@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dictionary import Dictionary
__all__ = ['Dictionary']

View File

@ -0,0 +1,169 @@
# Copyright (c) OpenMMLab. All rights reserved.
import re
from typing import List, Sequence
from mmocr.registry import TASK_UTILS
from mmocr.utils import list_from_file
@TASK_UTILS.register_module()
class Dictionary:
    """The class generates a dictionary for recognition. It pre-defines four
    special tokens: ``start_token``, ``end_token``, ``pad_token``, and
    ``unknown_token``, which will be sequentially placed at the end of the
    dictionary when their corresponding flags are True.

    Args:
        dict_file (str): The path of the character dict file, where each
            line contains at most one character.
        with_start (bool): The flag to control whether to include the start
            token. Defaults to False.
        with_end (bool): The flag to control whether to include the end token.
            Defaults to False.
        same_start_end (bool): The flag to control whether the start token and
            end token are the same. It only works when both ``with_start`` and
            ``with_end`` are True. Defaults to False.
        with_padding (bool):The padding token may represent more than a
            padding. It can also represent tokens like the blank token in CTC
            or the background token in SegOCR. Defaults to False.
        with_unknown (bool): The flag to control whether to include the
            unknown token. Defaults to False.
        start_token (str): The start token as a string. Defaults to '<BOS>'.
        end_token (str): The end token as a string. Defaults to '<EOS>'.
        start_end_token (str): The start/end token as a string. if start and
            end is the same. Defaults to '<BOS/EOS>'.
        padding_token (str): The padding token as a string.
            Defaults to '<PAD>'.
        unknown_token (str): The unknown token as a string.
            Defaults to '<UKN>'.

    Raises:
        ValueError: If any line of ``dict_file`` contains more than one
            character.
        AssertionError: If the resulting dictionary contains duplicated
            entries.
    """

    def __init__(self,
                 dict_file: str,
                 with_start: bool = False,
                 with_end: bool = False,
                 same_start_end: bool = False,
                 with_padding: bool = False,
                 with_unknown: bool = False,
                 start_token: str = '<BOS>',
                 end_token: str = '<EOS>',
                 start_end_token: str = '<BOS/EOS>',
                 padding_token: str = '<PAD>',
                 unknown_token: str = '<UKN>',
                 **kwargs) -> None:
        self.with_start = with_start
        self.with_end = with_end
        self.same_start_end = same_start_end
        self.with_padding = with_padding
        self.with_unknown = with_unknown
        self.start_end_token = start_end_token
        self.start_token = start_token
        self.end_token = end_token
        self.padding_token = padding_token
        self.unknown_token = unknown_token

        assert isinstance(dict_file, str)
        self._dict = []
        for line_num, line in enumerate(list_from_file(dict_file)):
            line = line.strip('\r\n')
            if len(line) > 1:
                raise ValueError('Expect each line has 0 or 1 character, '
                                 f'got {len(line)} characters '
                                 f'at line {line_num + 1}')
            if line != '':
                self._dict.append(line)

        # NOTE: _update_dict() rebuilds this mapping after the special
        # tokens are appended; this first build only covers plain chars.
        self._char2idx = {char: idx for idx, char in enumerate(self._dict)}

        # Evaluated before the special tokens are appended so that token
        # names such as '<BOS>' do not count as uppercase characters.
        self._contain_uppercase = re.search('[A-Z]',
                                            ''.join(self._dict)) is not None

        self._update_dict(**kwargs)
        assert len(set(self._dict)) == len(self._dict), \
            'Invalid dictionary: Has duplicated characters.'

    @property
    def num_classes(self) -> int:
        """int: Number of output classes. Special tokens are counted.
        """
        return len(self._dict)

    @property
    def dict(self) -> list:
        """list: The list of all characters to recognize. Special tokens
        are counted."""
        return self._dict

    @property
    def contain_uppercase(self) -> bool:
        """bool: Whether the dict file contains any uppercase English
        character."""
        return self._contain_uppercase

    def str2idx(self, string: str) -> List:
        """Convert a string to a list of indexes via ``Dictionary.dict``.

        Args:
            string (str): The string to convert to indexes.

        Return:
            list: The list of indexes of the string.

        Raises:
            Exception: If a character is absent from the dictionary and
                ``with_unknown`` is False.
        """
        idx = list()
        for s in string:
            # Falls back to the unknown token's index; ``unknown_idx`` is
            # None when with_unknown=False, which triggers the error below.
            char_idx = self._char2idx.get(s, self.unknown_idx)
            if char_idx is None:
                raise Exception(f'Character: {s} not in dict,'
                                f' please check gt_label and use'
                                f' custom dict file,'
                                f' or set "with_unknown=True"')
            idx.append(char_idx)
        return idx

    def idx2str(self, index: Sequence[int]) -> str:
        """Convert a list of index to string.

        Args:
            index (list[int]): The list of indexes to convert to string.

        Return:
            str: The converted string.

        Raises:
            AssertionError: If ``index`` is not a list or tuple.
        """
        assert isinstance(index, (list, tuple))
        string = ''
        for i in index:
            string += self._dict[i]
        return string

    def _update_dict(self, **kwargs):
        """Append the enabled special tokens to the dict and record their
        indexes; disabled tokens get index ``None``."""
        # BOS/EOS
        self.start_idx = None
        self.end_idx = None
        if self.with_start and self.with_end and self.same_start_end:
            self._dict.append(self.start_end_token)
            self.start_idx = len(self._dict) - 1
            self.end_idx = self.start_idx
        else:
            if self.with_start:
                self._dict.append(self.start_token)
                self.start_idx = len(self._dict) - 1
            if self.with_end:
                self._dict.append(self.end_token)
                self.end_idx = len(self._dict) - 1

        # padding
        self.padding_idx = None
        if self.with_padding:
            self._dict.append(self.padding_token)
            self.padding_idx = len(self._dict) - 1

        # unknown
        self.unknown_idx = None
        if self.with_unknown:
            self._dict.append(self.unknown_token)
            self.unknown_idx = len(self._dict) - 1

        # update char2idx to include the special tokens
        self._char2idx = {}
        for idx, char in enumerate(self._dict):
            self._char2idx[char] = idx

View File

@ -0,0 +1,132 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from unittest import TestCase
from mmocr.models.textrecog import Dictionary
class TestDictionary(TestCase):
    """Tests for ``Dictionary``: construction with special tokens,
    class counting, uppercase detection, and str/index conversions."""

    def _create_dummy_dict_file(
            self, dict_file,
            chars=list('0123456789abcdefghijklmnopqrstuvwxyz')):  # NOQA
        # Write one character per line, as Dictionary expects.
        with open(dict_file, 'w') as f:
            for char in chars:
                f.write(char + '\n')

    def test_init(self):
        # A context manager guarantees cleanup even when an assertion fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            # create dummy data
            dict_file = osp.join(tmp_dir, 'fake_chars.txt')
            self._create_dummy_dict_file(dict_file)

            # with start/end/padding/unknown tokens, distinct start and end
            dict_gen = Dictionary(
                dict_file=dict_file,
                with_start=True,
                with_end=True,
                same_start_end=False,
                with_padding=True,
                with_unknown=True)
            self.assertEqual(dict_gen.num_classes, 40)
            self.assertListEqual(
                dict_gen.dict,
                list('0123456789abcdefghijklmnopqrstuvwxyz') +
                ['<BOS>', '<EOS>', '<PAD>', '<UKN>'])

            # shared start/end token collapses two tokens into one
            dict_gen = Dictionary(
                dict_file=dict_file,
                with_start=True,
                with_end=True,
                same_start_end=True,
                with_padding=True,
                with_unknown=True)
            self.assertEqual(dict_gen.num_classes, 39)
            self.assertListEqual(
                dict_gen.dict,
                list('0123456789abcdefghijklmnopqrstuvwxyz') +
                ['<BOS/EOS>', '<PAD>', '<UKN>'])

            # custom token strings
            dict_gen = Dictionary(
                dict_file=dict_file,
                with_start=True,
                with_end=True,
                same_start_end=False,
                with_padding=True,
                with_unknown=True,
                start_token='<STA>',
                end_token='<END>',
                padding_token='<BLK>',
                unknown_token='<NO>')
            self.assertEqual(dict_gen.num_classes, 40)
            self.assertListEqual(dict_gen.dict[-4:],
                                 ['<STA>', '<END>', '<BLK>', '<NO>'])

            # custom shared start/end token
            dict_gen = Dictionary(
                dict_file=dict_file,
                with_start=True,
                with_end=True,
                same_start_end=True,
                with_padding=True,
                with_unknown=True,
                start_end_token='<BE>')
            self.assertEqual(dict_gen.num_classes, 39)
            self.assertListEqual(dict_gen.dict[-3:],
                                 ['<BE>', '<PAD>', '<UKN>'])

            # test len(line) > 1
            self._create_dummy_dict_file(dict_file, chars=['12', '3', '4'])
            with self.assertRaises(ValueError):
                dict_gen = Dictionary(dict_file=dict_file)

            # test duplicated dict
            self._create_dummy_dict_file(dict_file, chars=['1', '1', '2'])
            with self.assertRaises(AssertionError):
                dict_gen = Dictionary(dict_file=dict_file)

    def test_num_classes(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            # create dummy data
            dict_file = osp.join(tmp_dir, 'fake_chars.txt')
            self._create_dummy_dict_file(dict_file)
            dict_gen = Dictionary(dict_file=dict_file)
            self.assertEqual(dict_gen.num_classes, 36)

    def test_contain_uppercase(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            # create dummy data
            dict_file = osp.join(tmp_dir, 'fake_chars.txt')
            self._create_dummy_dict_file(dict_file)
            dict_gen = Dictionary(dict_file=dict_file)
            self.assertFalse(dict_gen.contain_uppercase)

            self._create_dummy_dict_file(dict_file, chars='abcdABCD')
            dict_gen = Dictionary(dict_file=dict_file)
            self.assertTrue(dict_gen.contain_uppercase)

    def test_str2idx(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            # create dummy data
            dict_file = osp.join(tmp_dir, 'fake_chars.txt')
            self._create_dummy_dict_file(dict_file)
            dict_gen = Dictionary(dict_file=dict_file)
            self.assertEqual(dict_gen.str2idx('01234'), [0, 1, 2, 3, 4])

            # unseen character without with_unknown raises
            with self.assertRaises(Exception):
                dict_gen.str2idx('H')

            # unseen character maps to unknown_idx with with_unknown
            dict_gen = Dictionary(dict_file=dict_file, with_unknown=True)
            self.assertListEqual(dict_gen.str2idx('H'),
                                 [dict_gen.unknown_idx])

    def test_idx2str(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            # create dummy data
            dict_file = osp.join(tmp_dir, 'fake_chars.txt')
            self._create_dummy_dict_file(dict_file)
            dict_gen = Dictionary(dict_file=dict_file)
            self.assertEqual(dict_gen.idx2str([0, 1, 2, 3, 4]), '01234')

            # a non-list/tuple index argument is rejected
            with self.assertRaises(AssertionError):
                dict_gen.idx2str('01234')