fix nan in ppocrv4 for benchmark (#14072)

* fix nan in ppocrv4 for benchmark

* fix config
pull/14079/head
wangna11BD 2024-10-23 11:55:43 +08:00 committed by GitHub
parent 8327f79b86
commit 661cda1289
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 7 additions and 3 deletions

View File

@ -33,6 +33,7 @@ Architecture:
Head:
name: DBHead
k: 50
fix_nan: True
Loss:
name: DBLoss

View File

@ -33,6 +33,7 @@ Architecture:
name: PFHeadLocal
k: 50
mode: "large"
fix_nan: True
Loss:

View File

@ -32,7 +32,7 @@ def get_bias_attr(k):
class Head(nn.Layer):
def __init__(self, in_channels, kernel_list=[3, 2, 2], **kwargs):
def __init__(self, in_channels, kernel_list=[3, 2, 2], fix_nan=False, **kwargs):
super(Head, self).__init__()
self.conv1 = nn.Conv2D(
@ -73,14 +73,16 @@ class Head(nn.Layer):
bias_attr=get_bias_attr(in_channels // 4),
)
self.fix_nan = fix_nan
def forward(self, x, return_f=False):
x = self.conv1(x)
x = self.conv_bn1(x)
if self.training:
if self.fix_nan and self.training:
x = paddle.where(paddle.isnan(x), paddle.zeros_like(x), x)
x = self.conv2(x)
x = self.conv_bn2(x)
if self.training:
if self.fix_nan and self.training:
x = paddle.where(paddle.isnan(x), paddle.zeros_like(x), x)
if return_f is True:
f = x