mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
[Refactor] split labelconverter to Dictionary
This commit is contained in:
parent
c47c5711c1
commit
c920edfb3a
@ -1,9 +1,10 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from . import (backbones, convertors, decoders, encoders, fusers, heads,
|
||||
losses, necks, plugins, preprocessor, recognizer)
|
||||
from . import (backbones, convertors, decoders, dictionary, encoders, fusers,
|
||||
heads, losses, necks, plugins, preprocessor, recognizer)
|
||||
from .backbones import * # NOQA
|
||||
from .convertors import * # NOQA
|
||||
from .decoders import * # NOQA
|
||||
from .dictionary import * # NOQA
|
||||
from .encoders import * # NOQA
|
||||
from .fusers import * # NOQA
|
||||
from .heads import * # NOQA
|
||||
@ -17,4 +18,4 @@ __all__ = (
|
||||
backbones.__all__ + convertors.__all__ + decoders.__all__ +
|
||||
encoders.__all__ + heads.__all__ + losses.__all__ + necks.__all__ +
|
||||
preprocessor.__all__ + recognizer.__all__ + fusers.__all__ +
|
||||
plugins.__all__)
|
||||
plugins.__all__ + dictionary.__all__)
|
||||
|
5
mmocr/models/textrecog/dictionary/__init__.py
Normal file
5
mmocr/models/textrecog/dictionary/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
from .dictionary import Dictionary
|
||||
|
||||
__all__ = ['Dictionary']
|
169
mmocr/models/textrecog/dictionary/dictionary.py
Normal file
169
mmocr/models/textrecog/dictionary/dictionary.py
Normal file
@ -0,0 +1,169 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import re
|
||||
from typing import List, Sequence
|
||||
|
||||
from mmocr.registry import TASK_UTILS
|
||||
from mmocr.utils import list_from_file
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
class Dictionary:
    """The class generates a dictionary for recognition. It pre-defines four
    special tokens: ``start_token``, ``end_token``, ``pad_token``, and
    ``unknown_token``, which will be sequentially placed at the end of the
    dictionary when their corresponding flags are True.

    Args:
        dict_file (str): The path of Character dict file which a single
            character must occupy a line.
        with_start (bool): The flag to control whether to include the start
            token. Defaults to False.
        with_end (bool): The flag to control whether to include the end token.
            Defaults to False.
        same_start_end (bool): The flag to control whether the start token and
            end token are the same. It only works when both ``with_start`` and
            ``with_end`` are True. Defaults to False.
        with_padding (bool): The padding token may represent more than a
            padding. It can also represent tokens like the blank token in CTC
            or the background token in SegOCR. Defaults to False.
        with_unknown (bool): The flag to control whether to include the
            unknown token. Defaults to False.
        start_token (str): The start token as a string. Defaults to '<BOS>'.
        end_token (str): The end token as a string. Defaults to '<EOS>'.
        start_end_token (str): The start/end token as a string, if start and
            end are the same. Defaults to '<BOS/EOS>'.
        padding_token (str): The padding token as a string.
            Defaults to '<PAD>'.
        unknown_token (str): The unknown token as a string.
            Defaults to '<UKN>'.
    """

    def __init__(self,
                 dict_file: str,
                 with_start: bool = False,
                 with_end: bool = False,
                 same_start_end: bool = False,
                 with_padding: bool = False,
                 with_unknown: bool = False,
                 start_token: str = '<BOS>',
                 end_token: str = '<EOS>',
                 start_end_token: str = '<BOS/EOS>',
                 padding_token: str = '<PAD>',
                 unknown_token: str = '<UKN>',
                 **kwargs) -> None:
        self.with_start = with_start
        self.with_end = with_end
        self.same_start_end = same_start_end
        self.with_padding = with_padding
        self.with_unknown = with_unknown
        self.start_end_token = start_end_token
        self.start_token = start_token
        self.end_token = end_token
        self.padding_token = padding_token
        self.unknown_token = unknown_token

        assert isinstance(dict_file, str)
        self._dict = []
        for line_num, line in enumerate(list_from_file(dict_file)):
            line = line.strip('\r\n')
            if len(line) > 1:
                raise ValueError('Expect each line has 0 or 1 character, '
                                 f'got {len(line)} characters '
                                 f'at line {line_num + 1}')
            if line != '':
                self._dict.append(line)

        # Uppercase detection must run BEFORE the special tokens are
        # appended, otherwise tokens such as '<BOS>' would make the
        # dictionary falsely count as containing uppercase characters.
        self._contain_uppercase = bool(
            re.search('[A-Z]', ''.join(self._dict)))

        # Appends the requested special tokens and builds self._char2idx.
        self._update_dict(**kwargs)
        assert len(set(self._dict)) == len(self._dict), \
            'Invalid dictionary: Has duplicated characters.'

    @property
    def num_classes(self) -> int:
        """int: Number of output classes. Special tokens are counted.
        """
        return len(self._dict)

    @property
    def dict(self) -> list:
        """list: The list of all characters to recognize. Special tokens
        are counted."""
        return self._dict

    @property
    def contain_uppercase(self) -> bool:
        """bool: Whether the dict file contains any uppercase English
        character (A-Z)."""
        return self._contain_uppercase

    def str2idx(self, string: str) -> List:
        """Convert a string to a list of indexes via ``Dictionary.dict``.

        Args:
            string (str): The string to convert to indexes.

        Return:
            list: The list of indexes of the string.

        Raises:
            Exception: If a character is not in the dictionary and
                ``with_unknown`` is False (``unknown_idx`` is None).
        """
        idx = list()
        for s in string:
            # Falls back to unknown_idx; that is None when with_unknown
            # is False, which turns an unseen character into an error.
            char_idx = self._char2idx.get(s, self.unknown_idx)
            if char_idx is None:
                raise Exception(f'Character: {s} not in dict,'
                                ' please check gt_label and use'
                                ' custom dict file,'
                                ' or set "with_unknown=True"')
            idx.append(char_idx)
        return idx

    def idx2str(self, index: Sequence[int]) -> str:
        """Convert a list of index to string.

        Args:
            index (list[int]): The list of indexes to convert to string.

        Return:
            str: The converted string.
        """
        assert isinstance(index, (list, tuple))
        # str.join is linear; repeated ``+=`` concatenation is quadratic.
        return ''.join(self._dict[i] for i in index)

    def _update_dict(self, **kwargs):
        """Update the dict with tokens according to parameters."""
        # BOS/EOS
        self.start_idx = None
        self.end_idx = None
        if self.with_start and self.with_end and self.same_start_end:
            self._dict.append(self.start_end_token)
            self.start_idx = len(self._dict) - 1
            self.end_idx = self.start_idx
        else:
            if self.with_start:
                self._dict.append(self.start_token)
                self.start_idx = len(self._dict) - 1
            if self.with_end:
                self._dict.append(self.end_token)
                self.end_idx = len(self._dict) - 1

        # padding
        self.padding_idx = None
        if self.with_padding:
            self._dict.append(self.padding_token)
            self.padding_idx = len(self._dict) - 1

        # unknown
        self.unknown_idx = None
        if self.with_unknown:
            self._dict.append(self.unknown_token)
            self.unknown_idx = len(self._dict) - 1

        # Rebuild the character -> index lookup so it covers the special
        # tokens appended above.
        self._char2idx = {char: idx for idx, char in enumerate(self._dict)}
|
132
tests/test_models/test_dictionary/test_dictionary.py
Normal file
132
tests/test_models/test_dictionary/test_dictionary.py
Normal file
@ -0,0 +1,132 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
import tempfile
|
||||
from unittest import TestCase
|
||||
|
||||
from mmocr.models.textrecog import Dictionary
|
||||
|
||||
|
||||
class TestDictionary(TestCase):
    """Unit tests for ``Dictionary``."""

    def _create_dummy_dict_file(
            self,
            dict_file,
            chars='0123456789abcdefghijklmnopqrstuvwxyz'):
        """Write one character (or token) per line to ``dict_file``.

        The default is an immutable string rather than a list to avoid the
        mutable-default-argument pitfall; any iterable of strings works.
        """
        with open(dict_file, 'w') as f:
            for char in chars:
                f.write(char + '\n')

    def test_init(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            # create dummy data
            dict_file = osp.join(tmp_dir, 'fake_chars.txt')
            self._create_dummy_dict_file(dict_file)

            # distinct start/end tokens
            dict_gen = Dictionary(
                dict_file=dict_file,
                with_start=True,
                with_end=True,
                same_start_end=False,
                with_padding=True,
                with_unknown=True)
            self.assertEqual(dict_gen.num_classes, 40)
            self.assertListEqual(
                dict_gen.dict,
                list('0123456789abcdefghijklmnopqrstuvwxyz') +
                ['<BOS>', '<EOS>', '<PAD>', '<UKN>'])

            # shared start/end token
            dict_gen = Dictionary(
                dict_file=dict_file,
                with_start=True,
                with_end=True,
                same_start_end=True,
                with_padding=True,
                with_unknown=True)
            self.assertEqual(dict_gen.num_classes, 39)
            self.assertListEqual(
                dict_gen.dict,
                list('0123456789abcdefghijklmnopqrstuvwxyz') +
                ['<BOS/EOS>', '<PAD>', '<UKN>'])

            # custom token strings
            dict_gen = Dictionary(
                dict_file=dict_file,
                with_start=True,
                with_end=True,
                same_start_end=False,
                with_padding=True,
                with_unknown=True,
                start_token='<STA>',
                end_token='<END>',
                padding_token='<BLK>',
                unknown_token='<NO>')
            self.assertEqual(dict_gen.num_classes, 40)
            self.assertListEqual(dict_gen.dict[-4:],
                                 ['<STA>', '<END>', '<BLK>', '<NO>'])

            # custom shared start/end token
            dict_gen = Dictionary(
                dict_file=dict_file,
                with_start=True,
                with_end=True,
                same_start_end=True,
                with_padding=True,
                with_unknown=True,
                start_end_token='<BE>')
            self.assertEqual(dict_gen.num_classes, 39)
            self.assertListEqual(dict_gen.dict[-3:], ['<BE>', '<PAD>', '<UKN>'])

            # a line with more than one character is rejected
            self._create_dummy_dict_file(dict_file, chars=['12', '3', '4'])
            with self.assertRaises(ValueError):
                Dictionary(dict_file=dict_file)

            # duplicated characters are rejected
            self._create_dummy_dict_file(dict_file, chars=['1', '1', '2'])
            with self.assertRaises(AssertionError):
                Dictionary(dict_file=dict_file)

    def test_num_classes(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            # create dummy data
            dict_file = osp.join(tmp_dir, 'fake_chars.txt')
            self._create_dummy_dict_file(dict_file)
            dict_gen = Dictionary(dict_file=dict_file)
            self.assertEqual(dict_gen.num_classes, 36)

    def test_contain_uppercase(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            # create dummy data
            dict_file = osp.join(tmp_dir, 'fake_chars.txt')
            self._create_dummy_dict_file(dict_file)
            dict_gen = Dictionary(dict_file=dict_file)
            self.assertFalse(dict_gen.contain_uppercase)

            self._create_dummy_dict_file(dict_file, chars='abcdABCD')
            dict_gen = Dictionary(dict_file=dict_file)
            self.assertTrue(dict_gen.contain_uppercase)

    def test_str2idx(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            # create dummy data
            dict_file = osp.join(tmp_dir, 'fake_chars.txt')
            self._create_dummy_dict_file(dict_file)
            dict_gen = Dictionary(dict_file=dict_file)
            self.assertEqual(dict_gen.str2idx('01234'), [0, 1, 2, 3, 4])
            # out-of-dict character without with_unknown raises
            with self.assertRaises(Exception):
                dict_gen.str2idx('H')

            # with_unknown maps the character to unknown_idx instead
            dict_gen = Dictionary(dict_file=dict_file, with_unknown=True)
            self.assertListEqual(dict_gen.str2idx('H'), [dict_gen.unknown_idx])

    def test_idx2str(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            # create dummy data
            dict_file = osp.join(tmp_dir, 'fake_chars.txt')
            self._create_dummy_dict_file(dict_file)
            dict_gen = Dictionary(dict_file=dict_file)
            self.assertEqual(dict_gen.idx2str([0, 1, 2, 3, 4]), '01234')
            # only list/tuple indexes are accepted
            with self.assertRaises(AssertionError):
                dict_gen.idx2str('01234')
|
Loading…
x
Reference in New Issue
Block a user