unifying data types in the SLAHead (#13276)

pull/13308/head
Wang Xin 2024-07-06 22:01:06 +08:00 committed by GitHub
parent 153de46b67
commit 024da2a58c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 1 additions and 1 deletions

View File

@ -368,7 +368,7 @@ class SLAHead(nn.Layer):
loc_preds = loc_preds[:, : max_len + 1]
else:
structure_ids = paddle.zeros(
(batch_size, self.max_text_length + 1), dtype=paddle.int64
(batch_size, self.max_text_length + 1), dtype="int32"
)
pre_chars = paddle.zeros(shape=[batch_size], dtype="int32")
max_text_length = paddle.to_tensor(self.max_text_length)