mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-06-03 21:53:39 +08:00
fix bug for paddlepaddle3.0 (#13568)
This commit is contained in:
parent
6c9bae667b
commit
9c19e6dffe
@ -69,7 +69,7 @@ class SLALoss(nn.Layer):
|
||||
def forward(self, predicts, batch):
|
||||
structure_probs = predicts["structure_probs"]
|
||||
structure_targets = batch[1].astype("int64")
|
||||
max_len = batch[-2].max()
|
||||
max_len = batch[-2].max().astype("int32")
|
||||
structure_targets = structure_targets[:, 1 : max_len + 2]
|
||||
|
||||
structure_loss = self.loss_func(structure_probs, structure_targets)
|
||||
|
@ -357,7 +357,7 @@ class SLAHead(nn.Layer):
|
||||
|
||||
if self.training and targets is not None:
|
||||
structure = targets[0]
|
||||
max_len = targets[-2].max()
|
||||
max_len = targets[-2].max().astype("int32")
|
||||
for i in range(max_len + 1):
|
||||
hidden, structure_step, loc_step = self._decode(
|
||||
structure[:, i], fea, hidden
|
||||
|
Loading…
x
Reference in New Issue
Block a user