modify int64 to int32 for windows train

pull/4463/head
topduke 2021-10-27 15:53:46 +08:00 committed by GitHub
parent 1008376c3f
commit b995249965
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -22,7 +22,7 @@ class NRTRLoss(nn.Layer):
log_prb = F.log_softmax(pred, axis=1)
non_pad_mask = paddle.not_equal(
tgt, paddle.zeros(
tgt.shape, dtype='int64'))
tgt.shape, dtype='int32'))
loss = -(one_hot * log_prb).sum(axis=1)
loss = loss.masked_select(non_pad_mask).mean()
else: