mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
fix squeezenet, vgg and shufflenet
This commit is contained in:
parent
251e47c170
commit
9ecc07df10
@ -35,5 +35,8 @@ from .alexnet import AlexNet
|
||||
from .inception_v4 import InceptionV4
|
||||
from .xception_deeplab import Xception41_deeplab, Xception65_deeplab, Xception71_deeplab
|
||||
from .resnext101_wsl import ResNeXt101_32x8d_wsl, ResNeXt101_32x16d_wsl, ResNeXt101_32x32d_wsl, ResNeXt101_32x48d_wsl
|
||||
from .shufflenet_v2 import ShuffleNetV2_x0_25, ShuffleNetV2_x0_33, ShuffleNetV2_x0_5, ShuffleNetV2, ShuffleNetV2_x1_5, ShuffleNetV2_x2_0, ShuffleNetV2_swish
|
||||
from .squeezenet import SqueezeNet1_0, SqueezeNet1_1
|
||||
from .vgg import VGG11, VGG13, VGG16, VGG19
|
||||
|
||||
from .distillation_models import ResNet50_vd_distill_MobileNetV3_large_x1_0
|
||||
|
@ -18,15 +18,17 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
from paddle.fluid.param_attr import ParamAttr
|
||||
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear, Dropout
|
||||
from paddle.fluid.initializer import MSRA
|
||||
from paddle import ParamAttr
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle.nn import Conv2d, BatchNorm, Linear, Dropout
|
||||
from paddle.nn import AdaptiveAvgPool2d, MaxPool2d, AvgPool2d
|
||||
from paddle.nn.initializer import MSRA
|
||||
import math
|
||||
|
||||
__all__ = [
|
||||
"ShuffleNetV2_x0_25", "ShuffleNetV2_x0_33", "ShuffleNetV2_x0_5",
|
||||
"ShuffleNetV2_x1_0", "ShuffleNetV2_x1_5", "ShuffleNetV2_x2_0",
|
||||
"ShuffleNetV2", "ShuffleNetV2_x1_5", "ShuffleNetV2_x2_0",
|
||||
"ShuffleNetV2_swish"
|
||||
]
|
||||
|
||||
@ -37,17 +39,16 @@ def channel_shuffle(x, groups):
|
||||
channels_per_group = num_channels // groups
|
||||
|
||||
# reshape
|
||||
x = fluid.layers.reshape(
|
||||
x = paddle.reshape(
|
||||
x=x, shape=[batchsize, groups, channels_per_group, height, width])
|
||||
|
||||
x = fluid.layers.transpose(x=x, perm=[0, 2, 1, 3, 4])
|
||||
x = paddle.transpose(x=x, perm=[0, 2, 1, 3, 4])
|
||||
# flatten
|
||||
x = fluid.layers.reshape(
|
||||
x=x, shape=[batchsize, num_channels, height, width])
|
||||
x = paddle.reshape(x=x, shape=[batchsize, num_channels, height, width])
|
||||
return x
|
||||
|
||||
|
||||
class ConvBNLayer(fluid.dygraph.Layer):
|
||||
class ConvBNLayer(nn.Layer):
|
||||
def __init__(self,
|
||||
num_channels,
|
||||
filter_size,
|
||||
@ -58,24 +59,21 @@ class ConvBNLayer(fluid.dygraph.Layer):
|
||||
num_groups=1,
|
||||
if_act=True,
|
||||
act='relu',
|
||||
name=None,
|
||||
use_cudnn=True):
|
||||
name=None):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
self._if_act = if_act
|
||||
assert act in ['relu', 'swish'], \
|
||||
"supported act are {} but your act is {}".format(
|
||||
['relu', 'swish'], act)
|
||||
self._act = act
|
||||
self._conv = Conv2D(
|
||||
num_channels=num_channels,
|
||||
num_filters=num_filters,
|
||||
filter_size=filter_size,
|
||||
self._conv = Conv2d(
|
||||
in_channels=num_channels,
|
||||
out_channels=num_filters,
|
||||
kernel_size=filter_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
groups=num_groups,
|
||||
act=None,
|
||||
use_cudnn=use_cudnn,
|
||||
param_attr=ParamAttr(
|
||||
weight_attr=ParamAttr(
|
||||
initializer=MSRA(), name=name + "_weights"),
|
||||
bias_attr=False)
|
||||
|
||||
@ -90,12 +88,11 @@ class ConvBNLayer(fluid.dygraph.Layer):
|
||||
y = self._conv(inputs)
|
||||
y = self._batch_norm(y)
|
||||
if self._if_act:
|
||||
y = fluid.layers.relu(
|
||||
y) if self._act == 'relu' else fluid.layers.swish(y)
|
||||
y = F.relu(y) if self._act == 'relu' else F.swish(y)
|
||||
return y
|
||||
|
||||
|
||||
class InvertedResidualUnit(fluid.dygraph.Layer):
|
||||
class InvertedResidualUnit(nn.Layer):
|
||||
def __init__(self,
|
||||
num_channels,
|
||||
num_filters,
|
||||
@ -130,7 +127,6 @@ class InvertedResidualUnit(fluid.dygraph.Layer):
|
||||
num_groups=oup_inc,
|
||||
if_act=False,
|
||||
act=act,
|
||||
use_cudnn=False,
|
||||
name='stage_' + name + '_conv2')
|
||||
self._conv_linear = ConvBNLayer(
|
||||
num_channels=oup_inc,
|
||||
@ -153,7 +149,6 @@ class InvertedResidualUnit(fluid.dygraph.Layer):
|
||||
num_groups=inp,
|
||||
if_act=False,
|
||||
act=act,
|
||||
use_cudnn=False,
|
||||
name='stage_' + name + '_conv4')
|
||||
self._conv_linear_1 = ConvBNLayer(
|
||||
num_channels=inp,
|
||||
@ -185,7 +180,6 @@ class InvertedResidualUnit(fluid.dygraph.Layer):
|
||||
num_groups=oup_inc,
|
||||
if_act=False,
|
||||
act=act,
|
||||
use_cudnn=False,
|
||||
name='stage_' + name + '_conv2')
|
||||
self._conv_linear_2 = ConvBNLayer(
|
||||
num_channels=oup_inc,
|
||||
@ -200,14 +194,14 @@ class InvertedResidualUnit(fluid.dygraph.Layer):
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.benchmodel == 1:
|
||||
x1, x2 = fluid.layers.split(
|
||||
x1, x2 = paddle.split(
|
||||
inputs,
|
||||
num_or_sections=[inputs.shape[1] // 2, inputs.shape[1] // 2],
|
||||
dim=1)
|
||||
axis=1)
|
||||
x2 = self._conv_pw(x2)
|
||||
x2 = self._conv_dw(x2)
|
||||
x2 = self._conv_linear(x2)
|
||||
out = fluid.layers.concat([x1, x2], axis=1)
|
||||
out = paddle.concat([x1, x2], axis=1)
|
||||
else:
|
||||
x1 = self._conv_dw_1(inputs)
|
||||
x1 = self._conv_linear_1(x1)
|
||||
@ -215,12 +209,12 @@ class InvertedResidualUnit(fluid.dygraph.Layer):
|
||||
x2 = self._conv_pw_2(inputs)
|
||||
x2 = self._conv_dw_2(x2)
|
||||
x2 = self._conv_linear_2(x2)
|
||||
out = fluid.layers.concat([x1, x2], axis=1)
|
||||
out = paddle.concat([x1, x2], axis=1)
|
||||
|
||||
return channel_shuffle(out, 2)
|
||||
|
||||
|
||||
class ShuffleNet(fluid.dygraph.Layer):
|
||||
class ShuffleNet(nn.Layer):
|
||||
def __init__(self, class_dim=1000, scale=1.0, act='relu'):
|
||||
super(ShuffleNet, self).__init__()
|
||||
self.scale = scale
|
||||
@ -252,8 +246,7 @@ class ShuffleNet(fluid.dygraph.Layer):
|
||||
if_act=True,
|
||||
act=act,
|
||||
name='stage1_conv')
|
||||
self._max_pool = Pool2D(
|
||||
pool_type='max', pool_size=3, pool_stride=2, pool_padding=1)
|
||||
self._max_pool = MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
|
||||
# 2. bottleneck sequences
|
||||
self._block_list = []
|
||||
@ -298,13 +291,13 @@ class ShuffleNet(fluid.dygraph.Layer):
|
||||
name='conv5')
|
||||
|
||||
# 4. pool
|
||||
self._pool2d_avg = Pool2D(pool_type='avg', global_pooling=True)
|
||||
self._pool2d_avg = AdaptiveAvgPool2d(1)
|
||||
self._out_c = stage_out_channels[-1]
|
||||
# 5. fc
|
||||
self._fc = Linear(
|
||||
stage_out_channels[-1],
|
||||
class_dim,
|
||||
param_attr=ParamAttr(name='fc6_weights'),
|
||||
weight_attr=ParamAttr(name='fc6_weights'),
|
||||
bias_attr=ParamAttr(name='fc6_offset'))
|
||||
|
||||
def forward(self, inputs):
|
||||
@ -314,7 +307,7 @@ class ShuffleNet(fluid.dygraph.Layer):
|
||||
y = inv(y)
|
||||
y = self._last_conv(y)
|
||||
y = self._pool2d_avg(y)
|
||||
y = fluid.layers.reshape(y, shape=[-1, self._out_c])
|
||||
y = paddle.reshape(y, shape=[-1, self._out_c])
|
||||
y = self._fc(y)
|
||||
return y
|
||||
|
||||
|
@ -1,73 +1,75 @@
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
from paddle.fluid.param_attr import ParamAttr
|
||||
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear, Dropout
|
||||
from paddle import ParamAttr
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle.nn import Conv2d, BatchNorm, Linear, Dropout
|
||||
from paddle.nn import AdaptiveAvgPool2d, MaxPool2d, AvgPool2d
|
||||
|
||||
__all__ = ["SqueezeNet1_0", "SqueezeNet1_1"]
|
||||
|
||||
class MakeFireConv(fluid.dygraph.Layer):
|
||||
def __init__(self,
|
||||
input_channels,
|
||||
output_channels,
|
||||
filter_size,
|
||||
padding=0,
|
||||
name=None):
|
||||
super(MakeFireConv, self).__init__()
|
||||
self._conv = Conv2D(input_channels,
|
||||
output_channels,
|
||||
filter_size,
|
||||
padding=padding,
|
||||
act="relu",
|
||||
param_attr=ParamAttr(name=name + "_weights"),
|
||||
bias_attr=ParamAttr(name=name + "_offset"))
|
||||
|
||||
def forward(self, inputs):
|
||||
return self._conv(inputs)
|
||||
|
||||
class MakeFire(fluid.dygraph.Layer):
|
||||
class MakeFireConv(nn.Layer):
|
||||
def __init__(self,
|
||||
input_channels,
|
||||
squeeze_channels,
|
||||
expand1x1_channels,
|
||||
expand3x3_channels,
|
||||
name=None):
|
||||
input_channels,
|
||||
output_channels,
|
||||
filter_size,
|
||||
padding=0,
|
||||
name=None):
|
||||
super(MakeFireConv, self).__init__()
|
||||
self._conv = Conv2d(
|
||||
input_channels,
|
||||
output_channels,
|
||||
filter_size,
|
||||
padding=padding,
|
||||
weight_attr=ParamAttr(name=name + "_weights"),
|
||||
bias_attr=ParamAttr(name=name + "_offset"))
|
||||
|
||||
def forward(self, x):
|
||||
x = self._conv(x)
|
||||
x = F.relu(x)
|
||||
return x
|
||||
|
||||
|
||||
class MakeFire(nn.Layer):
|
||||
def __init__(self,
|
||||
input_channels,
|
||||
squeeze_channels,
|
||||
expand1x1_channels,
|
||||
expand3x3_channels,
|
||||
name=None):
|
||||
super(MakeFire, self).__init__()
|
||||
self._conv = MakeFireConv(input_channels,
|
||||
squeeze_channels,
|
||||
1,
|
||||
name=name + "_squeeze1x1")
|
||||
self._conv_path1 = MakeFireConv(squeeze_channels,
|
||||
expand1x1_channels,
|
||||
1,
|
||||
name=name + "_expand1x1")
|
||||
self._conv_path2 = MakeFireConv(squeeze_channels,
|
||||
expand3x3_channels,
|
||||
3,
|
||||
padding=1,
|
||||
name=name + "_expand3x3")
|
||||
self._conv = MakeFireConv(
|
||||
input_channels, squeeze_channels, 1, name=name + "_squeeze1x1")
|
||||
self._conv_path1 = MakeFireConv(
|
||||
squeeze_channels, expand1x1_channels, 1, name=name + "_expand1x1")
|
||||
self._conv_path2 = MakeFireConv(
|
||||
squeeze_channels,
|
||||
expand3x3_channels,
|
||||
3,
|
||||
padding=1,
|
||||
name=name + "_expand3x3")
|
||||
|
||||
def forward(self, inputs):
|
||||
x = self._conv(inputs)
|
||||
x1 = self._conv_path1(x)
|
||||
x2 = self._conv_path2(x)
|
||||
return fluid.layers.concat([x1, x2], axis=1)
|
||||
return paddle.concat([x1, x2], axis=1)
|
||||
|
||||
class SqueezeNet(fluid.dygraph.Layer):
|
||||
|
||||
class SqueezeNet(nn.Layer):
|
||||
def __init__(self, version, class_dim=1000):
|
||||
super(SqueezeNet, self).__init__()
|
||||
self.version = version
|
||||
|
||||
if self.version == "1.0":
|
||||
self._conv = Conv2D(3,
|
||||
96,
|
||||
7,
|
||||
stride=2,
|
||||
act="relu",
|
||||
param_attr=ParamAttr(name="conv1_weights"),
|
||||
bias_attr=ParamAttr(name="conv1_offset"))
|
||||
self._pool = Pool2D(pool_size=3,
|
||||
pool_stride=2,
|
||||
pool_type="max")
|
||||
self._conv = Conv2d(
|
||||
3,
|
||||
96,
|
||||
7,
|
||||
stride=2,
|
||||
weight_attr=ParamAttr(name="conv1_weights"),
|
||||
bias_attr=ParamAttr(name="conv1_offset"))
|
||||
self._pool = MaxPool2d(kernel_size=3, stride=2, padding=0)
|
||||
self._conv1 = MakeFire(96, 16, 64, 64, name="fire2")
|
||||
self._conv2 = MakeFire(128, 16, 64, 64, name="fire3")
|
||||
self._conv3 = MakeFire(128, 32, 128, 128, name="fire4")
|
||||
@ -79,17 +81,15 @@ class SqueezeNet(fluid.dygraph.Layer):
|
||||
|
||||
self._conv8 = MakeFire(512, 64, 256, 256, name="fire9")
|
||||
else:
|
||||
self._conv = Conv2D(3,
|
||||
64,
|
||||
3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act="relu",
|
||||
param_attr=ParamAttr(name="conv1_weights"),
|
||||
bias_attr=ParamAttr(name="conv1_offset"))
|
||||
self._pool = Pool2D(pool_size=3,
|
||||
pool_stride=2,
|
||||
pool_type="max")
|
||||
self._conv = Conv2d(
|
||||
3,
|
||||
64,
|
||||
3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
weight_attr=ParamAttr(name="conv1_weights"),
|
||||
bias_attr=ParamAttr(name="conv1_offset"))
|
||||
self._pool = MaxPool2d(kernel_size=3, stride=2, padding=0)
|
||||
self._conv1 = MakeFire(64, 16, 64, 64, name="fire2")
|
||||
self._conv2 = MakeFire(128, 16, 64, 64, name="fire3")
|
||||
|
||||
@ -102,19 +102,19 @@ class SqueezeNet(fluid.dygraph.Layer):
|
||||
self._conv8 = MakeFire(512, 64, 256, 256, name="fire9")
|
||||
|
||||
self._drop = Dropout(p=0.5)
|
||||
self._conv9 = Conv2D(512,
|
||||
class_dim,
|
||||
1,
|
||||
act="relu",
|
||||
param_attr=ParamAttr(name="conv10_weights"),
|
||||
bias_attr=ParamAttr(name="conv10_offset"))
|
||||
self._avg_pool = Pool2D(pool_type="avg",
|
||||
global_pooling=True)
|
||||
self._conv9 = Conv2d(
|
||||
512,
|
||||
class_dim,
|
||||
1,
|
||||
weight_attr=ParamAttr(name="conv10_weights"),
|
||||
bias_attr=ParamAttr(name="conv10_offset"))
|
||||
self._avg_pool = AdaptiveAvgPool2d(1)
|
||||
|
||||
def forward(self, inputs):
|
||||
x = self._conv(inputs)
|
||||
x = F.relu(x)
|
||||
x = self._pool(x)
|
||||
if self.version=="1.0":
|
||||
if self.version == "1.0":
|
||||
x = self._conv1(x)
|
||||
x = self._conv2(x)
|
||||
x = self._conv3(x)
|
||||
@ -138,14 +138,17 @@ class SqueezeNet(fluid.dygraph.Layer):
|
||||
x = self._conv8(x)
|
||||
x = self._drop(x)
|
||||
x = self._conv9(x)
|
||||
x = F.relu(x)
|
||||
x = self._avg_pool(x)
|
||||
x = fluid.layers.squeeze(x, axes=[2,3])
|
||||
x = paddle.squeeze(x, axis=[2, 3])
|
||||
return x
|
||||
|
||||
|
||||
def SqueezeNet1_0(**args):
|
||||
model = SqueezeNet(version="1.0", **args)
|
||||
return model
|
||||
return model
|
||||
|
||||
|
||||
def SqueezeNet1_1(**args):
|
||||
model = SqueezeNet(version="1.1", **args)
|
||||
return model
|
||||
return model
|
||||
|
@ -1,80 +1,86 @@
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
from paddle.fluid.param_attr import ParamAttr
|
||||
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
|
||||
from paddle import ParamAttr
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle.nn import Conv2d, BatchNorm, Linear, Dropout
|
||||
from paddle.nn import AdaptiveAvgPool2d, MaxPool2d, AvgPool2d
|
||||
|
||||
__all__ = ["VGG11", "VGG13", "VGG16", "VGG19"]
|
||||
|
||||
class ConvBlock(fluid.dygraph.Layer):
|
||||
def __init__(self,
|
||||
input_channels,
|
||||
output_channels,
|
||||
groups,
|
||||
name=None):
|
||||
|
||||
class ConvBlock(nn.Layer):
|
||||
def __init__(self, input_channels, output_channels, groups, name=None):
|
||||
super(ConvBlock, self).__init__()
|
||||
|
||||
self.groups = groups
|
||||
self._conv_1 = Conv2D(num_channels=input_channels,
|
||||
num_filters=output_channels,
|
||||
filter_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
act="relu",
|
||||
param_attr=ParamAttr(name=name + "1_weights"),
|
||||
bias_attr=False)
|
||||
self._conv_1 = Conv2d(
|
||||
in_channels=input_channels,
|
||||
out_channels=output_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
weight_attr=ParamAttr(name=name + "1_weights"),
|
||||
bias_attr=False)
|
||||
if groups == 2 or groups == 3 or groups == 4:
|
||||
self._conv_2 = Conv2D(num_channels=output_channels,
|
||||
num_filters=output_channels,
|
||||
filter_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
act="relu",
|
||||
param_attr=ParamAttr(name=name + "2_weights"),
|
||||
bias_attr=False)
|
||||
self._conv_2 = Conv2d(
|
||||
in_channels=output_channels,
|
||||
out_channels=output_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
weight_attr=ParamAttr(name=name + "2_weights"),
|
||||
bias_attr=False)
|
||||
if groups == 3 or groups == 4:
|
||||
self._conv_3 = Conv2D(num_channels=output_channels,
|
||||
num_filters=output_channels,
|
||||
filter_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
act="relu",
|
||||
param_attr=ParamAttr(name=name + "3_weights"),
|
||||
bias_attr=False)
|
||||
self._conv_3 = Conv2d(
|
||||
in_channels=output_channels,
|
||||
out_channels=output_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
weight_attr=ParamAttr(name=name + "3_weights"),
|
||||
bias_attr=False)
|
||||
if groups == 4:
|
||||
self._conv_4 = Conv2D(num_channels=output_channels,
|
||||
num_filters=output_channels,
|
||||
filter_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
act="relu",
|
||||
param_attr=ParamAttr(name=name + "4_weights"),
|
||||
bias_attr=False)
|
||||
self._pool = Pool2D(pool_size=2,
|
||||
pool_type="max",
|
||||
pool_stride=2)
|
||||
self._conv_4 = Conv2d(
|
||||
in_channels=output_channels,
|
||||
out_channels=output_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
weight_attr=ParamAttr(name=name + "4_weights"),
|
||||
bias_attr=False)
|
||||
|
||||
self._pool = MaxPool2d(kernel_size=2, stride=2, padding=0)
|
||||
|
||||
def forward(self, inputs):
|
||||
x = self._conv_1(inputs)
|
||||
x = F.relu(x)
|
||||
if self.groups == 2 or self.groups == 3 or self.groups == 4:
|
||||
x = self._conv_2(x)
|
||||
if self.groups == 3 or self.groups == 4 :
|
||||
x = F.relu(x)
|
||||
if self.groups == 3 or self.groups == 4:
|
||||
x = self._conv_3(x)
|
||||
x = F.relu(x)
|
||||
if self.groups == 4:
|
||||
x = self._conv_4(x)
|
||||
x = F.relu(x)
|
||||
x = self._pool(x)
|
||||
return x
|
||||
|
||||
class VGGNet(fluid.dygraph.Layer):
|
||||
|
||||
class VGGNet(nn.Layer):
|
||||
def __init__(self, layers=11, class_dim=1000):
|
||||
super(VGGNet, self).__init__()
|
||||
|
||||
self.layers = layers
|
||||
self.vgg_configure = {11: [1, 1, 2, 2, 2],
|
||||
13: [2, 2, 2, 2, 2],
|
||||
16: [2, 2, 3, 3, 3],
|
||||
19: [2, 2, 4, 4, 4]}
|
||||
self.vgg_configure = {
|
||||
11: [1, 1, 2, 2, 2],
|
||||
13: [2, 2, 2, 2, 2],
|
||||
16: [2, 2, 3, 3, 3],
|
||||
19: [2, 2, 4, 4, 4]
|
||||
}
|
||||
assert self.layers in self.vgg_configure.keys(), \
|
||||
"supported layers are {} but input layer is {}".format(vgg_configure.keys(), layers)
|
||||
"supported layers are {} but input layer is {}".format(
|
||||
vgg_configure.keys(), layers)
|
||||
self.groups = self.vgg_configure[self.layers]
|
||||
|
||||
self._conv_block_1 = ConvBlock(3, 64, self.groups[0], name="conv1_")
|
||||
@ -83,21 +89,22 @@ class VGGNet(fluid.dygraph.Layer):
|
||||
self._conv_block_4 = ConvBlock(256, 512, self.groups[3], name="conv4_")
|
||||
self._conv_block_5 = ConvBlock(512, 512, self.groups[4], name="conv5_")
|
||||
|
||||
self._drop = fluid.dygraph.Dropout(p=0.5)
|
||||
self._fc1 = Linear(input_dim=7*7*512,
|
||||
output_dim=4096,
|
||||
act="relu",
|
||||
param_attr=ParamAttr(name="fc6_weights"),
|
||||
bias_attr=ParamAttr(name="fc6_offset"))
|
||||
self._fc2 = Linear(input_dim=4096,
|
||||
output_dim=4096,
|
||||
act="relu",
|
||||
param_attr=ParamAttr(name="fc7_weights"),
|
||||
bias_attr=ParamAttr(name="fc7_offset"))
|
||||
self._out = Linear(input_dim=4096,
|
||||
output_dim=class_dim,
|
||||
param_attr=ParamAttr(name="fc8_weights"),
|
||||
bias_attr=ParamAttr(name="fc8_offset"))
|
||||
self._drop = Dropout(p=0.5)
|
||||
self._fc1 = Linear(
|
||||
7 * 7 * 512,
|
||||
4096,
|
||||
weight_attr=ParamAttr(name="fc6_weights"),
|
||||
bias_attr=ParamAttr(name="fc6_offset"))
|
||||
self._fc2 = Linear(
|
||||
4096,
|
||||
4096,
|
||||
weight_attr=ParamAttr(name="fc7_weights"),
|
||||
bias_attr=ParamAttr(name="fc7_offset"))
|
||||
self._out = Linear(
|
||||
4096,
|
||||
class_dim,
|
||||
weight_attr=ParamAttr(name="fc8_weights"),
|
||||
bias_attr=ParamAttr(name="fc8_offset"))
|
||||
|
||||
def forward(self, inputs):
|
||||
x = self._conv_block_1(inputs)
|
||||
@ -106,26 +113,32 @@ class VGGNet(fluid.dygraph.Layer):
|
||||
x = self._conv_block_4(x)
|
||||
x = self._conv_block_5(x)
|
||||
|
||||
x = fluid.layers.reshape(x, [0,-1])
|
||||
x = paddle.reshape(x, [0, -1])
|
||||
x = self._fc1(x)
|
||||
x = F.relu(x)
|
||||
x = self._drop(x)
|
||||
x = self._fc2(x)
|
||||
x = F.relu(x)
|
||||
x = self._drop(x)
|
||||
x = self._out(x)
|
||||
return x
|
||||
|
||||
|
||||
def VGG11(**args):
|
||||
model = VGGNet(layers=11, **args)
|
||||
return model
|
||||
return model
|
||||
|
||||
|
||||
def VGG13(**args):
|
||||
model = VGGNet(layers=13, **args)
|
||||
return model
|
||||
|
||||
|
||||
def VGG16(**args):
|
||||
model = VGGNet(layers=16, **args)
|
||||
return model
|
||||
return model
|
||||
|
||||
|
||||
def VGG19(**args):
|
||||
model = VGGNet(layers=19, **args)
|
||||
return model
|
||||
return model
|
||||
|
Loading…
x
Reference in New Issue
Block a user