charry pick nrtr_postprocess and modify data type to adaptive

pull/4462/head
Topdu 2021-10-27 14:24:29 +00:00
parent 1edadf199b
commit 345b1510ab
2 changed files with 1 additions and 9 deletions

View File

@ -22,7 +22,7 @@ class NRTRLoss(nn.Layer):
log_prb = F.log_softmax(pred, axis=1)
non_pad_mask = paddle.not_equal(
tgt, paddle.zeros(
tgt.shape, dtype='int32'))
tgt.shape, dtype=tgt.dtype))
loss = -(one_hot * log_prb).sum(axis=1)
loss = loss.masked_select(non_pad_mask).mean()
else:

View File

@ -168,14 +168,6 @@ class NRTRLabelDecode(BaseRecLabelDecode):
character_type, use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
if preds.dtype == paddle.int64:
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
if preds[0][0]==2:
preds_idx = preds[:,1:]
else:
preds_idx = preds
if len(preds) == 2:
preds_id = preds[0]
preds_prob = preds[1]