Merge pull request #4462 from Topdu/release/2.3

charry pick nrtr_postprocess and modify data type to adaptive
pull/4512/head
xiaoting 2021-10-28 13:31:08 +08:00 committed by GitHub
commit 4cbb0ee86f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 1 additions and 9 deletions

View File

@ -22,7 +22,7 @@ class NRTRLoss(nn.Layer):
log_prb = F.log_softmax(pred, axis=1) log_prb = F.log_softmax(pred, axis=1)
non_pad_mask = paddle.not_equal( non_pad_mask = paddle.not_equal(
tgt, paddle.zeros( tgt, paddle.zeros(
tgt.shape, dtype='int64')) tgt.shape, dtype=tgt.dtype))
loss = -(one_hot * log_prb).sum(axis=1) loss = -(one_hot * log_prb).sum(axis=1)
loss = loss.masked_select(non_pad_mask).mean() loss = loss.masked_select(non_pad_mask).mean()
else: else:

View File

@ -168,14 +168,6 @@ class NRTRLabelDecode(BaseRecLabelDecode):
character_type, use_space_char) character_type, use_space_char)
def __call__(self, preds, label=None, *args, **kwargs): def __call__(self, preds, label=None, *args, **kwargs):
if preds.dtype == paddle.int64:
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
if preds[0][0]==2:
preds_idx = preds[:,1:]
else:
preds_idx = preds
if len(preds) == 2: if len(preds) == 2:
preds_id = preds[0] preds_id = preds[0]
preds_prob = preds[1] preds_prob = preds[1]