fix bug for rec_postprocess.py (#11389)
Co-authored-by: xueyadong <xueyadong@baidu.com>pull/11397/head
parent
d3e362a3a0
commit
c708180ce9
|
@ -1135,15 +1135,19 @@ class VLLabelDecode(BaseRecLabelDecode):
|
|||
net_out = paddle.to_tensor(net_out, dtype='float32')
|
||||
net_out = F.softmax(net_out, axis=1)
|
||||
for i in range(0, length.shape[0]):
|
||||
preds_idx = net_out[int(length[:i].sum()):int(length[:i].sum(
|
||||
) + length[i])].topk(1)[1][:, 0].tolist()
|
||||
if i == 0:
|
||||
start_idx = 0
|
||||
end_idx = int(length[i])
|
||||
else:
|
||||
start_idx = int(length[:i].sum())
|
||||
end_idx = int(length[:i].sum() + length[i])
|
||||
preds_idx = net_out[start_idx:end_idx].topk(1)[1][:, 0].tolist()
|
||||
preds_text = ''.join([
|
||||
self.character[idx - 1]
|
||||
if idx > 0 and idx <= len(self.character) else ''
|
||||
for idx in preds_idx
|
||||
])
|
||||
preds_prob = net_out[int(length[:i].sum()):int(length[:i].sum(
|
||||
) + length[i])].topk(1)[0][:, 0]
|
||||
preds_prob = net_out[start_idx:end_idx].topk(1)[0][:, 0]
|
||||
preds_prob = paddle.exp(
|
||||
paddle.log(preds_prob).sum() / (preds_prob.shape[0] + 1e-6))
|
||||
text.append((preds_text, float(preds_prob)))
|
||||
|
|
Loading…
Reference in New Issue