commit
f96797e106
|
@ -22,7 +22,7 @@ class NRTRLoss(nn.Layer):
|
||||||
log_prb = F.log_softmax(pred, axis=1)
|
log_prb = F.log_softmax(pred, axis=1)
|
||||||
non_pad_mask = paddle.not_equal(
|
non_pad_mask = paddle.not_equal(
|
||||||
tgt, paddle.zeros(
|
tgt, paddle.zeros(
|
||||||
tgt.shape, dtype='int64'))
|
tgt.shape, dtype=tgt.dtype))
|
||||||
loss = -(one_hot * log_prb).sum(axis=1)
|
loss = -(one_hot * log_prb).sum(axis=1)
|
||||||
loss = loss.masked_select(non_pad_mask).mean()
|
loss = loss.masked_select(non_pad_mask).mean()
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue