Merge pull request #4462 from Topdu/release/2.3
charry pick nrtr_postprocess and modify data type to adaptivepull/4512/head
commit
4cbb0ee86f
|
@ -22,7 +22,7 @@ class NRTRLoss(nn.Layer):
|
||||||
log_prb = F.log_softmax(pred, axis=1)
|
log_prb = F.log_softmax(pred, axis=1)
|
||||||
non_pad_mask = paddle.not_equal(
|
non_pad_mask = paddle.not_equal(
|
||||||
tgt, paddle.zeros(
|
tgt, paddle.zeros(
|
||||||
tgt.shape, dtype='int64'))
|
tgt.shape, dtype=tgt.dtype))
|
||||||
loss = -(one_hot * log_prb).sum(axis=1)
|
loss = -(one_hot * log_prb).sum(axis=1)
|
||||||
loss = loss.masked_select(non_pad_mask).mean()
|
loss = loss.masked_select(non_pad_mask).mean()
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -168,14 +168,6 @@ class NRTRLabelDecode(BaseRecLabelDecode):
|
||||||
character_type, use_space_char)
|
character_type, use_space_char)
|
||||||
|
|
||||||
def __call__(self, preds, label=None, *args, **kwargs):
|
def __call__(self, preds, label=None, *args, **kwargs):
|
||||||
if preds.dtype == paddle.int64:
|
|
||||||
if isinstance(preds, paddle.Tensor):
|
|
||||||
preds = preds.numpy()
|
|
||||||
if preds[0][0]==2:
|
|
||||||
preds_idx = preds[:,1:]
|
|
||||||
else:
|
|
||||||
preds_idx = preds
|
|
||||||
|
|
||||||
if len(preds) == 2:
|
if len(preds) == 2:
|
||||||
preds_id = preds[0]
|
preds_id = preds[0]
|
||||||
preds_prob = preds[1]
|
preds_prob = preds[1]
|
||||||
|
|
Loading…
Reference in New Issue