fix bug for rec_postprocess.py (#11389)

Co-authored-by: xueyadong <xueyadong@baidu.com>
pull/11397/head
Xue Yadong 2023-12-19 11:06:25 +08:00 committed by GitHub
parent d3e362a3a0
commit c708180ce9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 8 additions and 4 deletions

View File

@ -1135,15 +1135,19 @@ class VLLabelDecode(BaseRecLabelDecode):
net_out = paddle.to_tensor(net_out, dtype='float32')
net_out = F.softmax(net_out, axis=1)
for i in range(0, length.shape[0]):
preds_idx = net_out[int(length[:i].sum()):int(length[:i].sum(
) + length[i])].topk(1)[1][:, 0].tolist()
if i == 0:
start_idx = 0
end_idx = int(length[i])
else:
start_idx = int(length[:i].sum())
end_idx = int(length[:i].sum() + length[i])
preds_idx = net_out[start_idx:end_idx].topk(1)[1][:, 0].tolist()
preds_text = ''.join([
self.character[idx - 1]
if idx > 0 and idx <= len(self.character) else ''
for idx in preds_idx
])
preds_prob = net_out[int(length[:i].sum()):int(length[:i].sum(
) + length[i])].topk(1)[0][:, 0]
preds_prob = net_out[start_idx:end_idx].topk(1)[0][:, 0]
preds_prob = paddle.exp(
paddle.log(preds_prob).sum() / (preds_prob.shape[0] + 1e-6))
text.append((preds_text, float(preds_prob)))