unifying data types in the SLAHead (#13276)
parent
153de46b67
commit
024da2a58c
|
@ -368,7 +368,7 @@ class SLAHead(nn.Layer):
|
|||
loc_preds = loc_preds[:, : max_len + 1]
|
||||
else:
|
||||
structure_ids = paddle.zeros(
|
||||
(batch_size, self.max_text_length + 1), dtype=paddle.int64
|
||||
(batch_size, self.max_text_length + 1), dtype="int32"
|
||||
)
|
||||
pre_chars = paddle.zeros(shape=[batch_size], dtype="int32")
|
||||
max_text_length = paddle.to_tensor(self.max_text_length)
|
||||
|
|
Loading…
Reference in New Issue