This commit is contained in:
LDOUBLEV 2022-04-26 17:29:38 +08:00
parent d222ff0309
commit 62a01f7e30
3 changed files with 7 additions and 7 deletions

View File

@ -32,7 +32,7 @@ NetWorks:
model_name: large model_name: large
disable_se: true disable_se: true
Neck: Neck:
name: CAPAN name: CAFPN
out_channels: 96 out_channels: 96
shortcut: True shortcut: True
Head: Head:
@ -48,7 +48,7 @@ NetWorks:
model_name: large model_name: large
disable_se: true disable_se: true
Neck: Neck:
name: CAPAN name: CAFPN
out_channels: 96 out_channels: 96
shortcut: True shortcut: True
Head: Head:

View File

@ -28,7 +28,7 @@ Architecture:
model_name: large model_name: large
disable_se: True disable_se: True
Neck: Neck:
name: CAPAN name: CAFPN
out_channels: 96 out_channels: 96
shortcut: True shortcut: True
Head: Head:

View File

@ -37,8 +37,8 @@ class Head(nn.Layer):
self.conv1 = nn.Conv2D( self.conv1 = nn.Conv2D(
in_channels=in_channels, in_channels=in_channels,
out_channels=in_channels // 4, out_channels=in_channels // 4,
kernel_size=kernel_size[0], kernel_size=kernel_list[0],
padding=int(kernel_size[0] // 2), padding=int(kernel_list[0] // 2),
weight_attr=ParamAttr(), weight_attr=ParamAttr(),
bias_attr=False) bias_attr=False)
self.conv_bn1 = nn.BatchNorm( self.conv_bn1 = nn.BatchNorm(
@ -51,7 +51,7 @@ class Head(nn.Layer):
self.conv2 = nn.Conv2DTranspose( self.conv2 = nn.Conv2DTranspose(
in_channels=in_channels // 4, in_channels=in_channels // 4,
out_channels=in_channels // 4, out_channels=in_channels // 4,
kernel_size=kernel_size[1], kernel_size=kernel_list[1],
stride=2, stride=2,
weight_attr=ParamAttr( weight_attr=ParamAttr(
initializer=paddle.nn.initializer.KaimingUniform()), initializer=paddle.nn.initializer.KaimingUniform()),
@ -66,7 +66,7 @@ class Head(nn.Layer):
self.conv3 = nn.Conv2DTranspose( self.conv3 = nn.Conv2DTranspose(
in_channels=in_channels // 4, in_channels=in_channels // 4,
out_channels=1, out_channels=1,
kernel_size=kernel_size[2], kernel_size=kernel_list[2],
stride=2, stride=2,
weight_attr=ParamAttr( weight_attr=ParamAttr(
initializer=paddle.nn.initializer.KaimingUniform()), initializer=paddle.nn.initializer.KaimingUniform()),