charry pick nrtr_postprocess and modify data type to adaptive
parent
1edadf199b
commit
345b1510ab
|
@ -22,7 +22,7 @@ class NRTRLoss(nn.Layer):
|
|||
log_prb = F.log_softmax(pred, axis=1)
|
||||
non_pad_mask = paddle.not_equal(
|
||||
tgt, paddle.zeros(
|
||||
tgt.shape, dtype='int32'))
|
||||
tgt.shape, dtype=tgt.dtype))
|
||||
loss = -(one_hot * log_prb).sum(axis=1)
|
||||
loss = loss.masked_select(non_pad_mask).mean()
|
||||
else:
|
||||
|
|
|
@ -168,14 +168,6 @@ class NRTRLabelDecode(BaseRecLabelDecode):
|
|||
character_type, use_space_char)
|
||||
|
||||
def __call__(self, preds, label=None, *args, **kwargs):
|
||||
if preds.dtype == paddle.int64:
|
||||
if isinstance(preds, paddle.Tensor):
|
||||
preds = preds.numpy()
|
||||
if preds[0][0]==2:
|
||||
preds_idx = preds[:,1:]
|
||||
else:
|
||||
preds_idx = preds
|
||||
|
||||
if len(preds) == 2:
|
||||
preds_id = preds[0]
|
||||
preds_prob = preds[1]
|
||||
|
|
Loading…
Reference in New Issue