dbg ghostnet
parent 7c9e695f8c
commit 7064136058

@@ -27,6 +27,7 @@ from .hrnet import HRNet_W18_C
 from .efficientnet import EfficientNetB0
 from .resnest import ResNeSt50_fast_1s1x64d, ResNeSt50
 from .googlenet import GoogLeNet
+from .ghostnet import GhostNet_x0_5, GhostNet_x1_0, GhostNet_x1_3
 from .mobilenet_v1 import MobileNetV1_x0_25, MobileNetV1_x0_5, MobileNetV1_x0_75, MobileNetV1
 from .mobilenet_v2 import MobileNetV2_x0_25, MobileNetV2_x0_5, MobileNetV2_x0_75, MobileNetV2, MobileNetV2_x1_5, MobileNetV2_x2_0
 from .mobilenet_v3 import MobileNetV3_small_x0_35, MobileNetV3_small_x0_5, MobileNetV3_small_x0_75, MobileNetV3_small_x1_0, MobileNetV3_small_x1_25, MobileNetV3_large_x0_35, MobileNetV3_large_x0_5, MobileNetV3_large_x0_75, MobileNetV3_large_x1_0, MobileNetV3_large_x1_25
@@ -20,7 +20,6 @@ import paddle.nn.functional as F
 from paddle.nn import Conv2d, BatchNorm, AdaptiveAvgPool2d, Linear
-from paddle.fluid.regularizer import L2DecayRegularizer
 from paddle.nn.initializer import Uniform
 from paddle import fluid


 class ConvBNLayer(nn.Layer):
@@ -42,9 +41,12 @@ class ConvBNLayer(nn.Layer):
             stride=stride,
             padding=(filter_size - 1) // 2,
             groups=groups,
-            weight_attr=ParamAttr(name=name + "_weights"),
-            bias_attr=False)
+            weight_attr=ParamAttr(initializer=nn.initializer.MSRA(), name=name + "_weights"),
+            bias_attr=False
+        )
         bn_name = name + "_bn"
+
+        # In the old version, moving_variance_name was name + "_variance"
         self._batch_norm = BatchNorm(
             num_filters,
             act=act,
@@ -104,7 +106,7 @@ class SEBlock(nn.Layer):
         squeeze = self.squeeze(pool)
         squeeze = F.relu(squeeze)
         excitation = self.excitation(squeeze)
-        excitation = F.sigmoid(excitation)
+        excitation = paddle.fluid.layers.clip(x=excitation, min=0, max=1)
         excitation = paddle.reshape(
             excitation,
             shape=[-1, self._num_channels, 1, 1]
@@ -138,7 +140,7 @@ class GhostModule(nn.Layer):
             name=name + "_primary_conv"
         )
         self.cheap_operation = ConvBNLayer(
-            num_channels=num_channels,
+            num_channels=init_channels,
             num_filters=new_channels,
             filter_size=dw_size,
             stride=1,
@@ -186,7 +188,7 @@ class GhostBottleneck(nn.Layer):
                 stride=stride,
                 groups=hidden_dim,
                 act=None,
-                name=name+"_depthwise"
+                name=name+"_depthwise" # In the old version, name was name + "_depthwise_depthwise"
             )
         if use_se:
             self.se_block = SEBlock(
@@ -194,7 +196,7 @@ class GhostBottleneck(nn.Layer):
                 name=name + "_se"
             )
         self.ghost_module_2 = GhostModule(
-            num_channels=num_channels,
+            num_channels=hidden_dim,
             output_channels=output_channels,
             kernel_size=1,
             relu=False,
@@ -208,7 +210,7 @@ class GhostBottleneck(nn.Layer):
                 stride=stride,
                 groups=num_channels,
                 act=None,
-                name=name + "_shotcut_depthwise"
+                name=name + "_shortcut_depthwise" # In the old version, name was name + "_shortcut_depthwise_depthwise"
             )
             self.shortcut_conv = ConvBNLayer(
                 num_channels=num_channels,
@@ -217,11 +219,11 @@ class GhostBottleneck(nn.Layer):
                 stride=1,
                 groups=1,
                 act=None,
-                name=name + "_shotcut_conv"
+                name=name + "_shortcut_conv"
             )

     def forward(self, inputs):
-        x = self.ghost_module(inputs)
+        x = self.ghost_module_1(inputs)
         if self._stride == 2:
             x = self.depthwise_conv(x)
         if self._use_se:
@@ -275,14 +277,17 @@ class GhostNet(nn.Layer):
             num_channels = output_channels
             output_channels = int(self._make_divisible(c * self.scale, 4))
             hidden_dim = int(self._make_divisible(exp_size, self.scale, 4))
-            ghost_bottleneck = GhostBottleneck(
-                num_channels=num_channels,
-                hidden_dim=hidden_dim,
-                output_channels=output_channels,
-                kernel_size=k,
-                stride=s,
-                use_se=use_se,
-                name="_ghostbottleneck" + str(idx)
+            ghost_bottleneck = self.add_sublayer(
+                name="_ghostbottleneck_" + str(idx),
+                sublayer=GhostBottleneck(
+                    num_channels=num_channels,
+                    hidden_dim=hidden_dim,
+                    output_channels=output_channels,
+                    kernel_size=k,
+                    stride=s,
+                    use_se=use_se,
+                    name="_ghostbottleneck_" + str(idx)
+                )
             )
             self.ghost_bottleneck_list.append(ghost_bottleneck)
             idx += 1
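Note on the hunk above: add_sublayer is what makes the loop-built bottlenecks visible to Paddle. Layers that are only appended to a plain Python list are not registered on the parent nn.Layer, so their parameters would be missing from parameters() and state_dict(). Below is a minimal sketch of that registration pattern, assuming a Paddle 2.x dygraph environment; the Stack class and block names are made up for illustration and are not part of this commit.

    import paddle.nn as nn

    class Stack(nn.Layer):
        def __init__(self, depth=3):
            super(Stack, self).__init__()
            self.blocks = []
            for i in range(depth):
                # add_sublayer registers the child under the given name and returns it;
                # appending a bare nn.Linear to the list would leave its weights untracked.
                block = self.add_sublayer("block_" + str(i), nn.Linear(8, 8))
                self.blocks.append(block)

        def forward(self, x):
            for block in self.blocks:
                x = block(x)
            return x

    print(len(Stack().parameters()))  # 6: weight and bias for each of the 3 registered Linears
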
@@ -300,24 +305,26 @@ class GhostNet(nn.Layer):
         )
         self.pool2d_gap = AdaptiveAvgPool2d(1)
         num_channels = output_channels
-        output_channels = 1280
+        self._num_channels = num_channels
+        self._fc0_output_channels = 1280
         self.fc_0 = ConvBNLayer(
             num_channels=num_channels,
-            num_filters=output_channels,
+            num_filters=self._fc0_output_channels,
             filter_size=1,
             stride=1,
             act="relu",
             name="fc_0"
         )
         self.dropout = nn.Dropout(p=0.2)
-        stdv = 1.0 / math.sqrt(output_channels * 1.0)
+        stdv = 1.0 / math.sqrt(self._fc0_output_channels * 1.0)
         self.fc_1 = Linear(
-            output_channels,
+            self._fc0_output_channels,
             class_dim,
-            param_attr=ParamAttr(
+            weight_attr=ParamAttr(
                 name="fc_1_weights",
                 initializer=Uniform(-stdv, stdv)
-            )
+            ),
+            bias_attr=ParamAttr(name="fc_1_offset")
         )

     def forward(self, inputs):
@@ -328,6 +335,7 @@ class GhostNet(nn.Layer):
         x = self.pool2d_gap(x)
         x = self.fc_0(x)
         x = self.dropout(x)
+        x = paddle.reshape(x, shape=[-1, self._fc0_output_channels])
         x = self.fc_1(x)
         return x

@@ -345,3 +353,18 @@ class GhostNet(nn.Layer):
         if new_v < 0.9 * v:
             new_v += divisor
         return new_v
+
+
+def GhostNet_x0_5():
+    model = GhostNet(scale=0.5)
+    return model
+
+
+def GhostNet_x1_0():
+    model = GhostNet(scale=1.0)
+    return model
+
+
+def GhostNet_x1_3():
+    model = GhostNet(scale=1.3)
+    return model
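For reference, a quick smoke test of the constructors added at the end of the diff, assuming a Paddle 2.x dygraph environment. The module path, the input shape, and the class_dim of 1000 are assumptions for illustration, not something the commit specifies (the commit only shows a relative ".ghostnet" import).

    import paddle
    from ghostnet import GhostNet_x1_0  # assumed local module name

    model = GhostNet_x1_0()
    model.eval()

    # Dummy ImageNet-sized batch; the shape is assumed, not part of the commit.
    x = paddle.randn([1, 3, 224, 224], dtype="float32")
    y = model(x)
    print(y.shape)  # expected [1, 1000] if class_dim defaults to 1000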