fix bug for paddlepaddle3.0 (#13568)

This commit is contained in:
changdazhou 2024-08-01 22:50:44 +08:00 committed by GitHub
parent 6c9bae667b
commit 9c19e6dffe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 2 additions and 2 deletions

View File

@ -69,7 +69,7 @@ class SLALoss(nn.Layer):
def forward(self, predicts, batch):
structure_probs = predicts["structure_probs"]
structure_targets = batch[1].astype("int64")
max_len = batch[-2].max()
max_len = batch[-2].max().astype("int32")
structure_targets = structure_targets[:, 1 : max_len + 2]
structure_loss = self.loss_func(structure_probs, structure_targets)

View File

@ -357,7 +357,7 @@ class SLAHead(nn.Layer):
if self.training and targets is not None:
structure = targets[0]
max_len = targets[-2].max()
max_len = targets[-2].max().astype("int32")
for i in range(max_len + 1):
hidden, structure_step, loc_step = self._decode(
structure[:, i], fea, hidden