mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-06-03 21:53:39 +08:00
fix code style
This commit is contained in:
parent
8123688a09
commit
ae09ef607f
@ -166,21 +166,21 @@ class NRTRLabelDecode(BaseRecLabelDecode):
|
|||||||
use_space_char=True,
|
use_space_char=True,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
super(NRTRLabelDecode, self).__init__(character_dict_path,
|
super(NRTRLabelDecode, self).__init__(character_dict_path,
|
||||||
character_type, use_space_char)
|
character_type, use_space_char)
|
||||||
|
|
||||||
def __call__(self, preds, label=None, *args, **kwargs):
|
def __call__(self, preds, label=None, *args, **kwargs):
|
||||||
if preds.dtype == paddle.int64:
|
if preds.dtype == paddle.int64:
|
||||||
if isinstance(preds, paddle.Tensor):
|
if isinstance(preds, paddle.Tensor):
|
||||||
preds = preds.numpy()
|
preds = preds.numpy()
|
||||||
if preds[0][0]==2:
|
if preds[0][0] == 2:
|
||||||
preds_idx = preds[:,1:]
|
preds_idx = preds[:, 1:]
|
||||||
else:
|
else:
|
||||||
preds_idx = preds
|
preds_idx = preds
|
||||||
|
|
||||||
text = self.decode(preds_idx)
|
text = self.decode(preds_idx)
|
||||||
if label is None:
|
if label is None:
|
||||||
return text
|
return text
|
||||||
label = self.decode(label[:,1:])
|
label = self.decode(label[:, 1:])
|
||||||
else:
|
else:
|
||||||
if isinstance(preds, paddle.Tensor):
|
if isinstance(preds, paddle.Tensor):
|
||||||
preds = preds.numpy()
|
preds = preds.numpy()
|
||||||
@ -189,13 +189,13 @@ class NRTRLabelDecode(BaseRecLabelDecode):
|
|||||||
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
||||||
if label is None:
|
if label is None:
|
||||||
return text
|
return text
|
||||||
label = self.decode(label[:,1:])
|
label = self.decode(label[:, 1:])
|
||||||
return text, label
|
return text, label
|
||||||
|
|
||||||
def add_special_char(self, dict_character):
|
def add_special_char(self, dict_character):
|
||||||
dict_character = ['blank','<unk>','<s>','</s>'] + dict_character
|
dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
|
||||||
return dict_character
|
return dict_character
|
||||||
|
|
||||||
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
||||||
""" convert text-index into text-label. """
|
""" convert text-index into text-label. """
|
||||||
result_list = []
|
result_list = []
|
||||||
@ -204,10 +204,11 @@ class NRTRLabelDecode(BaseRecLabelDecode):
|
|||||||
char_list = []
|
char_list = []
|
||||||
conf_list = []
|
conf_list = []
|
||||||
for idx in range(len(text_index[batch_idx])):
|
for idx in range(len(text_index[batch_idx])):
|
||||||
if text_index[batch_idx][idx] == 3: # end
|
if text_index[batch_idx][idx] == 3: # end
|
||||||
break
|
break
|
||||||
try:
|
try:
|
||||||
char_list.append(self.character[int(text_index[batch_idx][idx])])
|
char_list.append(self.character[int(text_index[batch_idx][
|
||||||
|
idx])])
|
||||||
except:
|
except:
|
||||||
continue
|
continue
|
||||||
if text_prob is not None:
|
if text_prob is not None:
|
||||||
@ -219,7 +220,6 @@ class NRTRLabelDecode(BaseRecLabelDecode):
|
|||||||
return result_list
|
return result_list
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class AttnLabelDecode(BaseRecLabelDecode):
|
class AttnLabelDecode(BaseRecLabelDecode):
|
||||||
""" Convert between text-label and text-index """
|
""" Convert between text-label and text-index """
|
||||||
|
|
||||||
@ -257,7 +257,8 @@ class AttnLabelDecode(BaseRecLabelDecode):
|
|||||||
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
|
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
|
||||||
batch_idx][idx]:
|
batch_idx][idx]:
|
||||||
continue
|
continue
|
||||||
char_list.append(self.character[int(text_index[batch_idx][idx])])
|
char_list.append(self.character[int(text_index[batch_idx][
|
||||||
|
idx])])
|
||||||
if text_prob is not None:
|
if text_prob is not None:
|
||||||
conf_list.append(text_prob[batch_idx][idx])
|
conf_list.append(text_prob[batch_idx][idx])
|
||||||
else:
|
else:
|
||||||
@ -387,10 +388,9 @@ class SRNLabelDecode(BaseRecLabelDecode):
|
|||||||
class TableLabelDecode(object):
|
class TableLabelDecode(object):
|
||||||
""" """
|
""" """
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self, character_dict_path, **kwargs):
|
||||||
character_dict_path,
|
list_character, list_elem = self.load_char_elem_dict(
|
||||||
**kwargs):
|
character_dict_path)
|
||||||
list_character, list_elem = self.load_char_elem_dict(character_dict_path)
|
|
||||||
list_character = self.add_special_char(list_character)
|
list_character = self.add_special_char(list_character)
|
||||||
list_elem = self.add_special_char(list_elem)
|
list_elem = self.add_special_char(list_elem)
|
||||||
self.dict_character = {}
|
self.dict_character = {}
|
||||||
@ -409,7 +409,8 @@ class TableLabelDecode(object):
|
|||||||
list_elem = []
|
list_elem = []
|
||||||
with open(character_dict_path, "rb") as fin:
|
with open(character_dict_path, "rb") as fin:
|
||||||
lines = fin.readlines()
|
lines = fin.readlines()
|
||||||
substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split("\t")
|
substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split(
|
||||||
|
"\t")
|
||||||
character_num = int(substr[0])
|
character_num = int(substr[0])
|
||||||
elem_num = int(substr[1])
|
elem_num = int(substr[1])
|
||||||
for cno in range(1, 1 + character_num):
|
for cno in range(1, 1 + character_num):
|
||||||
@ -429,14 +430,14 @@ class TableLabelDecode(object):
|
|||||||
def __call__(self, preds):
|
def __call__(self, preds):
|
||||||
structure_probs = preds['structure_probs']
|
structure_probs = preds['structure_probs']
|
||||||
loc_preds = preds['loc_preds']
|
loc_preds = preds['loc_preds']
|
||||||
if isinstance(structure_probs,paddle.Tensor):
|
if isinstance(structure_probs, paddle.Tensor):
|
||||||
structure_probs = structure_probs.numpy()
|
structure_probs = structure_probs.numpy()
|
||||||
if isinstance(loc_preds,paddle.Tensor):
|
if isinstance(loc_preds, paddle.Tensor):
|
||||||
loc_preds = loc_preds.numpy()
|
loc_preds = loc_preds.numpy()
|
||||||
structure_idx = structure_probs.argmax(axis=2)
|
structure_idx = structure_probs.argmax(axis=2)
|
||||||
structure_probs = structure_probs.max(axis=2)
|
structure_probs = structure_probs.max(axis=2)
|
||||||
structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(structure_idx,
|
structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(
|
||||||
structure_probs, 'elem')
|
structure_idx, structure_probs, 'elem')
|
||||||
res_html_code_list = []
|
res_html_code_list = []
|
||||||
res_loc_list = []
|
res_loc_list = []
|
||||||
batch_num = len(structure_str)
|
batch_num = len(structure_str)
|
||||||
@ -451,8 +452,13 @@ class TableLabelDecode(object):
|
|||||||
res_loc = np.array(res_loc)
|
res_loc = np.array(res_loc)
|
||||||
res_html_code_list.append(res_html_code)
|
res_html_code_list.append(res_html_code)
|
||||||
res_loc_list.append(res_loc)
|
res_loc_list.append(res_loc)
|
||||||
return {'res_html_code': res_html_code_list, 'res_loc': res_loc_list, 'res_score_list': result_score_list,
|
return {
|
||||||
'res_elem_idx_list': result_elem_idx_list,'structure_str_list':structure_str}
|
'res_html_code': res_html_code_list,
|
||||||
|
'res_loc': res_loc_list,
|
||||||
|
'res_score_list': result_score_list,
|
||||||
|
'res_elem_idx_list': result_elem_idx_list,
|
||||||
|
'structure_str_list': structure_str
|
||||||
|
}
|
||||||
|
|
||||||
def decode(self, text_index, structure_probs, char_or_elem):
|
def decode(self, text_index, structure_probs, char_or_elem):
|
||||||
"""convert text-label into text-index.
|
"""convert text-label into text-index.
|
||||||
@ -528,9 +534,9 @@ class SARLabelDecode(BaseRecLabelDecode):
|
|||||||
use_space_char=False,
|
use_space_char=False,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
super(SARLabelDecode, self).__init__(character_dict_path,
|
super(SARLabelDecode, self).__init__(character_dict_path,
|
||||||
character_type, use_space_char)
|
character_type, use_space_char)
|
||||||
|
|
||||||
self.rm_symbol = kwargs.get('rm_symbol', True)
|
self.rm_symbol = kwargs.get('rm_symbol', True)
|
||||||
|
|
||||||
def add_special_char(self, dict_character):
|
def add_special_char(self, dict_character):
|
||||||
beg_end_str = "<BOS/EOS>"
|
beg_end_str = "<BOS/EOS>"
|
||||||
@ -549,7 +555,7 @@ class SARLabelDecode(BaseRecLabelDecode):
|
|||||||
""" convert text-index into text-label. """
|
""" convert text-index into text-label. """
|
||||||
result_list = []
|
result_list = []
|
||||||
ignored_tokens = self.get_ignored_tokens()
|
ignored_tokens = self.get_ignored_tokens()
|
||||||
|
|
||||||
batch_size = len(text_index)
|
batch_size = len(text_index)
|
||||||
for batch_idx in range(batch_size):
|
for batch_idx in range(batch_size):
|
||||||
char_list = []
|
char_list = []
|
||||||
@ -558,7 +564,7 @@ class SARLabelDecode(BaseRecLabelDecode):
|
|||||||
if text_index[batch_idx][idx] in ignored_tokens:
|
if text_index[batch_idx][idx] in ignored_tokens:
|
||||||
continue
|
continue
|
||||||
if int(text_index[batch_idx][idx]) == int(self.end_idx):
|
if int(text_index[batch_idx][idx]) == int(self.end_idx):
|
||||||
if text_prob is None and idx ==0:
|
if text_prob is None and idx == 0:
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
@ -586,7 +592,7 @@ class SARLabelDecode(BaseRecLabelDecode):
|
|||||||
preds = preds.numpy()
|
preds = preds.numpy()
|
||||||
preds_idx = preds.argmax(axis=2)
|
preds_idx = preds.argmax(axis=2)
|
||||||
preds_prob = preds.max(axis=2)
|
preds_prob = preds.max(axis=2)
|
||||||
|
|
||||||
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
||||||
|
|
||||||
if label is None:
|
if label is None:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user