mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-06-03 21:53:39 +08:00
fix bug for paddlepaddle3.0 (#13568)
This commit is contained in:
parent
6c9bae667b
commit
9c19e6dffe
@ -69,7 +69,7 @@ class SLALoss(nn.Layer):
|
|||||||
def forward(self, predicts, batch):
|
def forward(self, predicts, batch):
|
||||||
structure_probs = predicts["structure_probs"]
|
structure_probs = predicts["structure_probs"]
|
||||||
structure_targets = batch[1].astype("int64")
|
structure_targets = batch[1].astype("int64")
|
||||||
max_len = batch[-2].max()
|
max_len = batch[-2].max().astype("int32")
|
||||||
structure_targets = structure_targets[:, 1 : max_len + 2]
|
structure_targets = structure_targets[:, 1 : max_len + 2]
|
||||||
|
|
||||||
structure_loss = self.loss_func(structure_probs, structure_targets)
|
structure_loss = self.loss_func(structure_probs, structure_targets)
|
||||||
|
@ -357,7 +357,7 @@ class SLAHead(nn.Layer):
|
|||||||
|
|
||||||
if self.training and targets is not None:
|
if self.training and targets is not None:
|
||||||
structure = targets[0]
|
structure = targets[0]
|
||||||
max_len = targets[-2].max()
|
max_len = targets[-2].max().astype("int32")
|
||||||
for i in range(max_len + 1):
|
for i in range(max_len + 1):
|
||||||
hidden, structure_step, loc_step = self._decode(
|
hidden, structure_step, loc_step = self._decode(
|
||||||
structure[:, i], fea, hidden
|
structure[:, i], fea, hidden
|
||||||
|
Loading…
x
Reference in New Issue
Block a user