Mirror of https://github.com/open-mmlab/mmocr.git, synced 2025-06-03 21:54:47 +08:00
* [Feature] Add RepeatAugSampler
* initial commit
* spts inference done
* merge repeat_aug (bug in multi-node?)
* fix inference
* train done
* rm readme
* Revert "merge repeat_aug (bug in multi-node?)" (reverts commit 393506a97cbe6d75ad1f28611ea10eba6b8fa4b3)
* Revert "[Feature] Add RepeatAugSampler" (reverts commit 2089b02b4844157670033766f257b5d1bca452ce)
* remove utils
* readme & conversion script
* update readme
* fix
* optimize
* rename cfg & del compose
* fix
* fix
139 lines
5.5 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence

from mmocr.models.common.dictionary import Dictionary
from mmocr.registry import TASK_UTILS


@TASK_UTILS.register_module()
class SPTSDictionary(Dictionary):
    """The class generates a dictionary for recognition. It pre-defines five
    special tokens: ``start_token``, ``end_token``, ``seq_end_token``,
    ``pad_token``, and ``unknown_token``, which will be sequentially placed
    at the end of the dictionary when their corresponding flags are True.

    Args:
        dict_file (str): The path of the character dict file, in which each
            line contains a single character.
        with_start (bool): The flag to control whether to include the start
            token. Defaults to False.
        with_end (bool): The flag to control whether to include the end token.
            Defaults to False.
        with_seq_end (bool): The flag to control whether to include the
            sequence end token. Defaults to False.
        same_start_end (bool): The flag to control whether the start token and
            end token are the same. It only works when both ``with_start`` and
            ``with_end`` are True. Defaults to False.
        with_padding (bool): The flag to control whether to include the
            padding token. The padding token may represent more than a
            padding; it can also represent tokens like the blank token in CTC
            or the background token in SegOCR. Defaults to False.
        with_unknown (bool): The flag to control whether to include the
            unknown token. Defaults to False.
        start_token (str): The start token as a string. Defaults to '<BOS>'.
        end_token (str): The end token as a string. Defaults to '<EOS>'.
        seq_end_token (str): The sequence end token as a string. Defaults to
            '<SEQEOS>'.
        start_end_token (str): The start/end token as a string, used if the
            start and end tokens are the same. Defaults to '<BOS/EOS>'.
        padding_token (str): The padding token as a string.
            Defaults to '<PAD>'.
        unknown_token (str, optional): The unknown token as a string. If it's
            set to None and ``with_unknown`` is True, the unknown token will
            be skipped when converting string to index. Defaults to '<UKN>'.
    """

    def __init__(
        self,
        dict_file: str,
        with_start: bool = False,
        with_end: bool = False,
        with_seq_end: bool = False,
        same_start_end: bool = False,
        with_padding: bool = False,
        with_unknown: bool = False,
        start_token: str = '<BOS>',
        end_token: str = '<EOS>',
        seq_end_token: str = '<SEQEOS>',
        start_end_token: str = '<BOS/EOS>',
        padding_token: str = '<PAD>',
        unknown_token: str = '<UKN>',
    ) -> None:
        # Store the sequence-end settings before calling the parent
        # constructor, since the parent builds the dictionary via
        # _update_dict(), which reads these attributes.
        self.with_seq_end = with_seq_end
        self.seq_end_token = seq_end_token

        super().__init__(
            dict_file=dict_file,
            with_start=with_start,
            with_end=with_end,
            same_start_end=same_start_end,
            with_padding=with_padding,
            with_unknown=with_unknown,
            start_token=start_token,
            end_token=end_token,
            start_end_token=start_end_token,
            padding_token=padding_token,
            unknown_token=unknown_token)

    def _update_dict(self):
        """Update the dict with tokens according to parameters."""
        # BOS/EOS
        self.start_idx = None
        self.end_idx = None

        # sequence end
        self.seq_end_idx = None

        # unknown
        self.unknown_idx = None
        # TODO: Check if this line in Dictionary is correct and
        # works as expected
        # if self.with_unknown and self.unknown_token is not None:
        if self.with_unknown:
            self._dict.append(self.unknown_token)
            self.unknown_idx = len(self._dict) - 1

        if self.with_start and self.with_end and self.same_start_end:
            # A single token serves as both BOS and EOS.
            self._dict.append(self.start_end_token)
            self.start_idx = len(self._dict) - 1
            self.end_idx = self.start_idx
            if self.with_seq_end:
                self._dict.append(self.seq_end_token)
                self.seq_end_idx = len(self._dict) - 1
        else:
            if self.with_end:
                self._dict.append(self.end_token)
                self.end_idx = len(self._dict) - 1
                if self.with_seq_end:
                    self._dict.append(self.seq_end_token)
                    self.seq_end_idx = len(self._dict) - 1
            if self.with_start:
                self._dict.append(self.start_token)
                self.start_idx = len(self._dict) - 1

        # padding
        self.padding_idx = None
        if self.with_padding:
            self._dict.append(self.padding_token)
            self.padding_idx = len(self._dict) - 1

        # update char2idx
        self._char2idx = {}
        for idx, char in enumerate(self._dict):
            self._char2idx[char] = idx

    def idx2str(self, index: Sequence[int]) -> str:
        """Convert a list of indexes to a string.

        Args:
            index (list[int]): The list of indexes to convert to a string.

        Returns:
            str: The converted string.
        """
        assert isinstance(index, (list, tuple))
        string = ''
        for i in index:
            assert i < len(self._dict), f'Index: {i} out of range! Index ' \
                f'must be less than {len(self._dict)}'
            # TODO: find its difference from ignore_chars
            # in TextRecogPostprocessor
            if self._dict[i] is not None:
                string += self._dict[i]
        return string
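
For context, here is a minimal usage sketch, not part of the original file. It assumes an MMOCR installation; the file name 'example_dict.txt' and its three-character contents are made up for illustration.

# Minimal usage sketch (illustrative; 'example_dict.txt' is hypothetical).
with open('example_dict.txt', 'w') as f:
    f.write('a\nb\nc\n')  # one character per line

dictionary = SPTSDictionary(
    dict_file='example_dict.txt',
    with_start=True,
    with_end=True,
    same_start_end=True,
    with_seq_end=True,
    with_padding=True,
    with_unknown=True)

# Special tokens are appended after the file's characters, so the dict is:
# ['a', 'b', 'c', '<UKN>', '<BOS/EOS>', '<SEQEOS>', '<PAD>']
print(dictionary.idx2str([0, 1, 2]))  # -> 'abc'

In MMOCR configs this class is usually not constructed directly; it is typically built through the registry from a config entry such as dict(type='SPTSDictionary', dict_file=..., ...) via TASK_UTILS.build.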