modified pr
parent 4a3b874a36
commit cb370419ec

@@ -1217,7 +1217,7 @@ class ABINetLabelEncode(BaseRecLabelEncode):
         dict_character = ['</s>'] + dict_character
         return dict_character
 
-class SPINAttnLabelEncode(BaseRecLabelEncode):
+class SPINAttnLabelEncode(AttnLabelEncode):
     """ Convert between text-label and text-index """
 
     def __init__(self,
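
Side note on the first hunk: switching the base class from BaseRecLabelEncode to AttnLabelEncode lets SPINAttnLabelEncode inherit the attention-style dictionary handling instead of re-implementing it. The lines below are a minimal, self-contained sketch (not code from the repository) of the two special-token layouts visible in this commit's context lines:

# Sketch only: ABINet prepends a single '</s>' end token, while the
# attention-style encoders put "sos"/"eos" at indices 0 and 1.
def abinet_add_special_char(dict_character):
    return ['</s>'] + dict_character

def attn_add_special_char(dict_character, beg_str="sos", end_str="eos"):
    return [beg_str] + [end_str] + dict_character

print(abinet_add_special_char(['a', 'b']))  # ['</s>', 'a', 'b']
print(attn_add_special_char(['a', 'b']))    # ['sos', 'eos', 'a', 'b']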

@@ -12,6 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+"""
+This code is refer from:
+https://github.com/hikopensource/DAVAR-Lab-OCR/davarocr/davar_rcg/models/sequence_heads/att_head.py
+"""
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

@@ -669,7 +669,86 @@ class ABINetLabelDecode(NRTRLabelDecode):
         return dict_character
 
 
-class SPINAttnLabelDecode(BaseRecLabelDecode):
+# class SPINAttnLabelDecode(BaseRecLabelDecode):
+#     """ Convert between text-label and text-index """
+
+#     def __init__(self, character_dict_path=None, use_space_char=False,
+#                  **kwargs):
+#         super(SPINAttnLabelDecode, self).__init__(character_dict_path,
+#                                                   use_space_char)
+
+#     def add_special_char(self, dict_character):
+#         self.beg_str = "sos"
+#         self.end_str = "eos"
+#         dict_character = dict_character
+#         dict_character = [self.beg_str] + [self.end_str] + dict_character
+#         return dict_character
+
+#     def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+#         """ convert text-index into text-label. """
+#         result_list = []
+#         ignored_tokens = self.get_ignored_tokens()
+#         [beg_idx, end_idx] = self.get_ignored_tokens()
+#         batch_size = len(text_index)
+#         for batch_idx in range(batch_size):
+#             char_list = []
+#             conf_list = []
+#             for idx in range(len(text_index[batch_idx])):
+#                 if text_index[batch_idx][idx] == int(beg_idx):
+#                     continue
+#                 if int(text_index[batch_idx][idx]) == int(end_idx):
+#                     break
+#                 if is_remove_duplicate:
+#                     # only for predict
+#                     if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
+#                             batch_idx][idx]:
+#                         continue
+#                 char_list.append(self.character[int(text_index[batch_idx][
+#                     idx])])
+#                 if text_prob is not None:
+#                     conf_list.append(text_prob[batch_idx][idx])
+#                 else:
+#                     conf_list.append(1)
+#             text = ''.join(char_list)
+#             result_list.append((text.lower(), np.mean(conf_list).tolist()))
+#         return result_list
+
+#     def __call__(self, preds, label=None, *args, **kwargs):
+#         """
+#         text = self.decode(text)
+#         if label is None:
+#             return text
+#         else:
+#             label = self.decode(label, is_remove_duplicate=False)
+#         return text, label
+#         """
+#         if isinstance(preds, paddle.Tensor):
+#             preds = preds.numpy()
+
+#         preds_idx = preds.argmax(axis=2)
+#         preds_prob = preds.max(axis=2)
+#         text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
+#         if label is None:
+#             return text
+#         label = self.decode(label, is_remove_duplicate=False)
+#         return text, label
+
+#     def get_ignored_tokens(self):
+#         beg_idx = self.get_beg_end_flag_idx("beg")
+#         end_idx = self.get_beg_end_flag_idx("end")
+#         return [beg_idx, end_idx]
+
+#     def get_beg_end_flag_idx(self, beg_or_end):
+#         if beg_or_end == "beg":
+#             idx = np.array(self.dict[self.beg_str])
+#         elif beg_or_end == "end":
+#             idx = np.array(self.dict[self.end_str])
+#         else:
+#             assert False, "unsupport type %s in get_beg_end_flag_idx" \
+#                 % beg_or_end
+#         return idx
+
+class SPINAttnLabelDecode(AttnLabelDecode):
     """ Convert between text-label and text-index """
 
     def __init__(self, character_dict_path=None, use_space_char=False,
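
The commented-out block preserves the previous decoder verbatim; the intent of the rewrite is that SPINAttnLabelDecode now inherits this behaviour from AttnLabelDecode (assumed here to implement the same attention-style routine). A self-contained sketch of that decode logic, with a toy dictionary standing in for the real character list:

import numpy as np

# Toy dictionary in the attention layout: "sos" at index 0, "eos" at index 1.
character = ["sos", "eos", "a", "b", "c"]
beg_idx, end_idx = 0, 1

def decode(text_index, text_prob=None, is_remove_duplicate=False):
    """Skip the begin token, stop at the end token, optionally drop repeats,
    and average the per-step probabilities into one confidence."""
    result_list = []
    for batch_idx in range(len(text_index)):
        char_list, conf_list = [], []
        for idx, char_id in enumerate(text_index[batch_idx]):
            if int(char_id) == beg_idx:
                continue
            if int(char_id) == end_idx:
                break
            if is_remove_duplicate and idx > 0 and \
                    text_index[batch_idx][idx - 1] == char_id:
                continue
            char_list.append(character[int(char_id)])
            conf_list.append(text_prob[batch_idx][idx]
                             if text_prob is not None else 1)
        result_list.append((''.join(char_list).lower(),
                            np.mean(conf_list).tolist()))
    return result_list

print(decode([[0, 2, 3, 4, 1, 2]]))  # [('abc', 1.0)]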

@@ -682,68 +761,4 @@ class SPINAttnLabelDecode(BaseRecLabelDecode):
         self.end_str = "eos"
         dict_character = dict_character
         dict_character = [self.beg_str] + [self.end_str] + dict_character
         return dict_character
-
-    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
-        """ convert text-index into text-label. """
-        result_list = []
-        ignored_tokens = self.get_ignored_tokens()
-        [beg_idx, end_idx] = self.get_ignored_tokens()
-        batch_size = len(text_index)
-        for batch_idx in range(batch_size):
-            char_list = []
-            conf_list = []
-            for idx in range(len(text_index[batch_idx])):
-                if text_index[batch_idx][idx] == int(beg_idx):
-                    continue
-                if int(text_index[batch_idx][idx]) == int(end_idx):
-                    break
-                if is_remove_duplicate:
-                    # only for predict
-                    if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
-                            batch_idx][idx]:
-                        continue
-                char_list.append(self.character[int(text_index[batch_idx][
-                    idx])])
-                if text_prob is not None:
-                    conf_list.append(text_prob[batch_idx][idx])
-                else:
-                    conf_list.append(1)
-            text = ''.join(char_list)
-            result_list.append((text.lower(), np.mean(conf_list).tolist()))
-        return result_list
-
-    def __call__(self, preds, label=None, *args, **kwargs):
-        """
-        text = self.decode(text)
-        if label is None:
-            return text
-        else:
-            label = self.decode(label, is_remove_duplicate=False)
-        return text, label
-        """
-        if isinstance(preds, paddle.Tensor):
-            preds = preds.numpy()
-
-        preds_idx = preds.argmax(axis=2)
-        preds_prob = preds.max(axis=2)
-        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
-        if label is None:
-            return text
-        label = self.decode(label, is_remove_duplicate=False)
-        return text, label
-
-    def get_ignored_tokens(self):
-        beg_idx = self.get_beg_end_flag_idx("beg")
-        end_idx = self.get_beg_end_flag_idx("end")
-        return [beg_idx, end_idx]
-
-    def get_beg_end_flag_idx(self, beg_or_end):
-        if beg_or_end == "beg":
-            idx = np.array(self.dict[self.beg_str])
-        elif beg_or_end == "end":
-            idx = np.array(self.dict[self.end_str])
-        else:
-            assert False, "unsupport type %s in get_beg_end_flag_idx" \
-                % beg_or_end
-        return idx
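
After the removal above, only __init__ and add_special_char stay in the class; everything deleted is expected to come from the parent. A rough, runnable sketch of the resulting shape, assembled from the context and commented-out lines in this commit (the AttnLabelDecode stub below is a stand-in, not the real parent from the postprocess module):

class AttnLabelDecode:  # stand-in stub; the real parent carries decode/__call__
    def __init__(self, character_dict_path=None, use_space_char=False):
        self.character_dict_path = character_dict_path
        self.use_space_char = use_space_char

class SPINAttnLabelDecode(AttnLabelDecode):
    """ Convert between text-label and text-index """

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        super(SPINAttnLabelDecode, self).__init__(character_dict_path,
                                                  use_space_char)

    def add_special_char(self, dict_character):
        self.beg_str = "sos"
        self.end_str = "eos"
        dict_character = [self.beg_str] + [self.end_str] + dict_character
        return dict_character

print(SPINAttnLabelDecode().add_special_char(['a', 'b']))  # ['sos', 'eos', 'a', 'b']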

@@ -91,7 +91,7 @@ def export_single_model(model,
         ]
         # print([None, 3, 32, 128])
         model = to_static(model, input_spec=other_shape)
-    elif arch_config["algorithm"] == "NRTR" or arch_config["algorithm"] == "SPIN":
+    elif arch_config["algorithm"] in ["NRTR", "SPIN"]:
         other_shape = [
             paddle.static.InputSpec(
                 shape=[None, 1, 32, 100], dtype="float32"),
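
The last hunk only tidies the branch condition (a membership test instead of two equality checks); the export signature it guards is unchanged. For reference, a minimal sketch of that signature using the paddle.static.InputSpec call already present in the context lines:

import paddle

# Dynamic batch size, one (grayscale) channel, 32x100 input, as declared for
# NRTR/SPIN export in the hunk above.
spin_input_spec = [
    paddle.static.InputSpec(shape=[None, 1, 32, 100], dtype="float32"),
]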