Merge pull request #183 from wqz960/PaddleClas-dy
add Inception, ResNeXt101_wsl, EfficientNet and other modelspull/218/head
commit
fe302aec12
|
@ -1,172 +1,103 @@
|
|||
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
from paddle.fluid.param_attr import ParamAttr
|
||||
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear, Dropout
|
||||
import math
|
||||
|
||||
__all__ = ['AlexNet']
|
||||
__all__ = ["AlexNet"]
|
||||
|
||||
class ConvPoolLayer(fluid.dygraph.Layer):
|
||||
def __init__(self,
|
||||
inputc_channels,
|
||||
output_channels,
|
||||
filter_size,
|
||||
stride,
|
||||
padding,
|
||||
stdv,
|
||||
groups=1,
|
||||
act=None,
|
||||
name=None):
|
||||
super(ConvPoolLayer, self).__init__()
|
||||
|
||||
self._conv = Conv2D(num_channels=inputc_channels,
|
||||
num_filters=output_channels,
|
||||
filter_size=filter_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
groups=groups,
|
||||
param_attr=ParamAttr(name=name + "_weights",
|
||||
initializer=fluid.initializer.Uniform(-stdv, stdv)),
|
||||
bias_attr=ParamAttr(name=name + "_offset",
|
||||
initializer=fluid.initializer.Uniform(-stdv, stdv)),
|
||||
act=act)
|
||||
self._pool = Pool2D(pool_size=3,
|
||||
pool_stride=2,
|
||||
pool_padding=0,
|
||||
pool_type="max")
|
||||
|
||||
def forward(self, inputs):
|
||||
x = self._conv(inputs)
|
||||
x = self._pool(x)
|
||||
return x
|
||||
|
||||
|
||||
class AlexNet():
|
||||
def __init__(self):
|
||||
pass
|
||||
class AlexNetDY(fluid.dygraph.Layer):
|
||||
def __init__(self, class_dim=1000):
|
||||
super(AlexNetDY, self).__init__()
|
||||
|
||||
def net(self, input, class_dim=1000):
|
||||
stdv = 1.0 / math.sqrt(input.shape[1] * 11 * 11)
|
||||
layer_name = [
|
||||
"conv1", "conv2", "conv3", "conv4", "conv5", "fc6", "fc7", "fc8"
|
||||
]
|
||||
conv1 = fluid.layers.conv2d(
|
||||
input=input,
|
||||
num_filters=64,
|
||||
filter_size=11,
|
||||
stride=4,
|
||||
padding=2,
|
||||
groups=1,
|
||||
act='relu',
|
||||
bias_attr=fluid.param_attr.ParamAttr(
|
||||
initializer=fluid.initializer.Uniform(-stdv, stdv),
|
||||
name=layer_name[0] + "_offset"),
|
||||
param_attr=fluid.param_attr.ParamAttr(
|
||||
initializer=fluid.initializer.Uniform(-stdv, stdv),
|
||||
name=layer_name[0] + "_weights"))
|
||||
pool1 = fluid.layers.pool2d(
|
||||
input=conv1,
|
||||
pool_size=3,
|
||||
pool_stride=2,
|
||||
pool_padding=0,
|
||||
pool_type='max')
|
||||
stdv = 1.0/math.sqrt(3*11*11)
|
||||
self._conv1 = ConvPoolLayer(
|
||||
3, 64, 11, 4, 2, stdv, act="relu", name="conv1")
|
||||
stdv = 1.0/math.sqrt(64*5*5)
|
||||
self._conv2 = ConvPoolLayer(
|
||||
64, 192, 5, 1, 2, stdv, act="relu", name="conv2")
|
||||
stdv = 1.0/math.sqrt(192*3*3)
|
||||
self._conv3 = Conv2D(192, 384, 3, stride=1, padding=1,
|
||||
param_attr=ParamAttr(name="conv3_weights", initializer=fluid.initializer.Uniform(-stdv, stdv)),
|
||||
bias_attr=ParamAttr(name="conv3_offset", initializer=fluid.initializer.Uniform(-stdv, stdv)),
|
||||
act="relu")
|
||||
stdv = 1.0/math.sqrt(384*3*3)
|
||||
self._conv4 = Conv2D(384, 256, 3, stride=1, padding=1,
|
||||
param_attr=ParamAttr(name="conv4_weights", initializer=fluid.initializer.Uniform(-stdv, stdv)),
|
||||
bias_attr=ParamAttr(name="conv4_offset", initializer=fluid.initializer.Uniform(-stdv, stdv)),
|
||||
act="relu")
|
||||
stdv = 1.0/math.sqrt(256*3*3)
|
||||
self._conv5 = ConvPoolLayer(
|
||||
256, 256, 3, 1, 1, stdv, act="relu", name="conv5")
|
||||
stdv = 1.0/math.sqrt(256*6*6)
|
||||
|
||||
stdv = 1.0 / math.sqrt(pool1.shape[1] * 5 * 5)
|
||||
conv2 = fluid.layers.conv2d(
|
||||
input=pool1,
|
||||
num_filters=192,
|
||||
filter_size=5,
|
||||
stride=1,
|
||||
padding=2,
|
||||
groups=1,
|
||||
act='relu',
|
||||
bias_attr=fluid.param_attr.ParamAttr(
|
||||
initializer=fluid.initializer.Uniform(-stdv, stdv),
|
||||
name=layer_name[1] + "_offset"),
|
||||
param_attr=fluid.param_attr.ParamAttr(
|
||||
initializer=fluid.initializer.Uniform(-stdv, stdv),
|
||||
name=layer_name[1] + "_weights"))
|
||||
pool2 = fluid.layers.pool2d(
|
||||
input=conv2,
|
||||
pool_size=3,
|
||||
pool_stride=2,
|
||||
pool_padding=0,
|
||||
pool_type='max')
|
||||
self._drop1 = Dropout(p=0.5)
|
||||
self._fc6 = Linear(input_dim=256*6*6,
|
||||
output_dim=4096,
|
||||
param_attr=ParamAttr(name="fc6_weights", initializer=fluid.initializer.Uniform(-stdv, stdv)),
|
||||
bias_attr=ParamAttr(name="fc6_offset", initializer=fluid.initializer.Uniform(-stdv, stdv)),
|
||||
act="relu")
|
||||
|
||||
stdv = 1.0 / math.sqrt(pool2.shape[1] * 3 * 3)
|
||||
conv3 = fluid.layers.conv2d(
|
||||
input=pool2,
|
||||
num_filters=384,
|
||||
filter_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=1,
|
||||
act='relu',
|
||||
bias_attr=fluid.param_attr.ParamAttr(
|
||||
initializer=fluid.initializer.Uniform(-stdv, stdv),
|
||||
name=layer_name[2] + "_offset"),
|
||||
param_attr=fluid.param_attr.ParamAttr(
|
||||
initializer=fluid.initializer.Uniform(-stdv, stdv),
|
||||
name=layer_name[2] + "_weights"))
|
||||
self._drop2 = Dropout(p=0.5)
|
||||
self._fc7 = Linear(input_dim=4096,
|
||||
output_dim=4096,
|
||||
param_attr=ParamAttr(name="fc7_weights", initializer=fluid.initializer.Uniform(-stdv, stdv)),
|
||||
bias_attr=ParamAttr(name="fc7_offset", initializer=fluid.initializer.Uniform(-stdv, stdv)),
|
||||
act="relu")
|
||||
self._fc8 = Linear(input_dim=4096,
|
||||
output_dim=class_dim,
|
||||
param_attr=ParamAttr(name="fc8_weights", initializer=fluid.initializer.Uniform(-stdv, stdv)),
|
||||
bias_attr=ParamAttr(name="fc8_offset", initializer=fluid.initializer.Uniform(-stdv, stdv)))
|
||||
|
||||
stdv = 1.0 / math.sqrt(conv3.shape[1] * 3 * 3)
|
||||
conv4 = fluid.layers.conv2d(
|
||||
input=conv3,
|
||||
num_filters=256,
|
||||
filter_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=1,
|
||||
act='relu',
|
||||
bias_attr=fluid.param_attr.ParamAttr(
|
||||
initializer=fluid.initializer.Uniform(-stdv, stdv),
|
||||
name=layer_name[3] + "_offset"),
|
||||
param_attr=fluid.param_attr.ParamAttr(
|
||||
initializer=fluid.initializer.Uniform(-stdv, stdv),
|
||||
name=layer_name[3] + "_weights"))
|
||||
def forward(self, inputs):
|
||||
x = self._conv1(inputs)
|
||||
x = self._conv2(x)
|
||||
x = self._conv3(x)
|
||||
x = self._conv4(x)
|
||||
x = self._conv5(x)
|
||||
x = fluid.layers.flatten(x, axis=0)
|
||||
x = self._drop1(x)
|
||||
x = self._fc6(x)
|
||||
x = self._drop2(x)
|
||||
x = self._fc7(x)
|
||||
x = self._fc8(x)
|
||||
return x
|
||||
|
||||
stdv = 1.0 / math.sqrt(conv4.shape[1] * 3 * 3)
|
||||
conv5 = fluid.layers.conv2d(
|
||||
input=conv4,
|
||||
num_filters=256,
|
||||
filter_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=1,
|
||||
act='relu',
|
||||
bias_attr=fluid.param_attr.ParamAttr(
|
||||
initializer=fluid.initializer.Uniform(-stdv, stdv),
|
||||
name=layer_name[4] + "_offset"),
|
||||
param_attr=fluid.param_attr.ParamAttr(
|
||||
initializer=fluid.initializer.Uniform(-stdv, stdv),
|
||||
name=layer_name[4] + "_weights"))
|
||||
pool5 = fluid.layers.pool2d(
|
||||
input=conv5,
|
||||
pool_size=3,
|
||||
pool_stride=2,
|
||||
pool_padding=0,
|
||||
pool_type='max')
|
||||
|
||||
drop6 = fluid.layers.dropout(x=pool5, dropout_prob=0.5)
|
||||
stdv = 1.0 / math.sqrt(drop6.shape[1] * drop6.shape[2] *
|
||||
drop6.shape[3] * 1.0)
|
||||
|
||||
fc6 = fluid.layers.fc(
|
||||
input=drop6,
|
||||
size=4096,
|
||||
act='relu',
|
||||
bias_attr=fluid.param_attr.ParamAttr(
|
||||
initializer=fluid.initializer.Uniform(-stdv, stdv),
|
||||
name=layer_name[5] + "_offset"),
|
||||
param_attr=fluid.param_attr.ParamAttr(
|
||||
initializer=fluid.initializer.Uniform(-stdv, stdv),
|
||||
name=layer_name[5] + "_weights"))
|
||||
|
||||
drop7 = fluid.layers.dropout(x=fc6, dropout_prob=0.5)
|
||||
stdv = 1.0 / math.sqrt(drop7.shape[1] * 1.0)
|
||||
|
||||
fc7 = fluid.layers.fc(
|
||||
input=drop7,
|
||||
size=4096,
|
||||
act='relu',
|
||||
bias_attr=fluid.param_attr.ParamAttr(
|
||||
initializer=fluid.initializer.Uniform(-stdv, stdv),
|
||||
name=layer_name[6] + "_offset"),
|
||||
param_attr=fluid.param_attr.ParamAttr(
|
||||
initializer=fluid.initializer.Uniform(-stdv, stdv),
|
||||
name=layer_name[6] + "_weights"))
|
||||
|
||||
stdv = 1.0 / math.sqrt(fc7.shape[1] * 1.0)
|
||||
out = fluid.layers.fc(
|
||||
input=fc7,
|
||||
size=class_dim,
|
||||
bias_attr=fluid.param_attr.ParamAttr(
|
||||
initializer=fluid.initializer.Uniform(-stdv, stdv),
|
||||
name=layer_name[7] + "_offset"),
|
||||
param_attr=fluid.param_attr.ParamAttr(
|
||||
initializer=fluid.initializer.Uniform(-stdv, stdv),
|
||||
name=layer_name[7] + "_weights"))
|
||||
return out
|
||||
def AlexNet(**args):
|
||||
model = AlexNetDY(**args)
|
||||
return model
|
||||
|
|
|
@ -1,79 +1,25 @@
|
|||
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
from paddle.fluid.param_attr import ParamAttr
|
||||
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
|
||||
import math
|
||||
|
||||
__all__ = ["DarkNet53"]
|
||||
|
||||
|
||||
class DarkNet53():
|
||||
def __init__(self):
|
||||
class ConvBNLayer(fluid.dygraph.Layer):
|
||||
def __init__(self,
|
||||
input_channels,
|
||||
output_channels,
|
||||
filter_size,
|
||||
stride,
|
||||
padding,
|
||||
name=None):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
|
||||
pass
|
||||
|
||||
def net(self, input, class_dim=1000):
|
||||
DarkNet_cfg = {53: ([1, 2, 8, 8, 4], self.basicblock)}
|
||||
stages, block_func = DarkNet_cfg[53]
|
||||
stages = stages[0:5]
|
||||
conv1 = self.conv_bn_layer(
|
||||
input,
|
||||
ch_out=32,
|
||||
filter_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
name="yolo_input")
|
||||
conv = self.downsample(
|
||||
conv1, ch_out=conv1.shape[1] * 2, name="yolo_input.downsample")
|
||||
|
||||
for i, stage in enumerate(stages):
|
||||
conv = self.layer_warp(
|
||||
block_func,
|
||||
conv,
|
||||
32 * (2**i),
|
||||
stage,
|
||||
name="stage.{}".format(i))
|
||||
if i < len(stages) - 1: # do not downsaple in the last stage
|
||||
conv = self.downsample(
|
||||
conv,
|
||||
ch_out=conv.shape[1] * 2,
|
||||
name="stage.{}.downsample".format(i))
|
||||
pool = fluid.layers.pool2d(
|
||||
input=conv, pool_type='avg', global_pooling=True)
|
||||
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
|
||||
out = fluid.layers.fc(
|
||||
input=pool,
|
||||
size=class_dim,
|
||||
param_attr=ParamAttr(
|
||||
initializer=fluid.initializer.Uniform(-stdv, stdv),
|
||||
name='fc_weights'),
|
||||
bias_attr=ParamAttr(name='fc_offset'))
|
||||
return out
|
||||
|
||||
def conv_bn_layer(self,
|
||||
input,
|
||||
ch_out,
|
||||
filter_size,
|
||||
stride,
|
||||
padding,
|
||||
name=None):
|
||||
conv = fluid.layers.conv2d(
|
||||
input=input,
|
||||
num_filters=ch_out,
|
||||
self._conv = Conv2D(
|
||||
num_channels=input_channels,
|
||||
num_filters=output_channels,
|
||||
filter_size=filter_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
|
@ -82,39 +28,133 @@ class DarkNet53():
|
|||
bias_attr=False)
|
||||
|
||||
bn_name = name + ".bn"
|
||||
out = fluid.layers.batch_norm(
|
||||
input=conv,
|
||||
act='relu',
|
||||
param_attr=ParamAttr(name=bn_name + '.scale'),
|
||||
bias_attr=ParamAttr(name=bn_name + '.offset'),
|
||||
moving_mean_name=bn_name + '.mean',
|
||||
moving_variance_name=bn_name + '.var')
|
||||
return out
|
||||
self._bn = BatchNorm(
|
||||
num_channels=output_channels,
|
||||
act="relu",
|
||||
param_attr=ParamAttr(name=bn_name + ".scale"),
|
||||
bias_attr=ParamAttr(name=bn_name + ".offset"),
|
||||
moving_mean_name=bn_name + ".mean",
|
||||
moving_variance_name=bn_name + ".var")
|
||||
|
||||
def downsample(self,
|
||||
input,
|
||||
ch_out,
|
||||
filter_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
name=None):
|
||||
return self.conv_bn_layer(
|
||||
input,
|
||||
ch_out=ch_out,
|
||||
filter_size=filter_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
name=name)
|
||||
def forward(self, inputs):
|
||||
x = self._conv(inputs)
|
||||
x = self._bn(x)
|
||||
return x
|
||||
|
||||
def basicblock(self, input, ch_out, name=None):
|
||||
conv1 = self.conv_bn_layer(input, ch_out, 1, 1, 0, name=name + ".0")
|
||||
conv2 = self.conv_bn_layer(
|
||||
conv1, ch_out * 2, 3, 1, 1, name=name + ".1")
|
||||
out = fluid.layers.elementwise_add(x=input, y=conv2, act=None)
|
||||
return out
|
||||
|
||||
def layer_warp(self, block_func, input, ch_out, count, name=None):
|
||||
res_out = block_func(input, ch_out, name='{}.0'.format(name))
|
||||
for j in range(1, count):
|
||||
res_out = block_func(res_out, ch_out, name='{}.{}'.format(name, j))
|
||||
return res_out
|
||||
class BasicBlock(fluid.dygraph.Layer):
|
||||
def __init__(self, input_channels, output_channels, name=None):
|
||||
super(BasicBlock, self).__init__()
|
||||
|
||||
self._conv1 = ConvBNLayer(
|
||||
input_channels, output_channels, 1, 1, 0, name=name + ".0")
|
||||
self._conv2 = ConvBNLayer(
|
||||
output_channels, output_channels * 2, 3, 1, 1, name=name + ".1")
|
||||
|
||||
def forward(self, inputs):
|
||||
x = self._conv1(inputs)
|
||||
x = self._conv2(x)
|
||||
return fluid.layers.elementwise_add(x=inputs, y=x)
|
||||
|
||||
|
||||
class DarkNet(fluid.dygraph.Layer):
|
||||
def __init__(self, class_dim=1000):
|
||||
super(DarkNet, self).__init__()
|
||||
|
||||
self.stages = [1, 2, 8, 8, 4]
|
||||
self._conv1 = ConvBNLayer(3, 32, 3, 1, 1, name="yolo_input")
|
||||
self._conv2 = ConvBNLayer(
|
||||
32, 64, 3, 2, 1, name="yolo_input.downsample")
|
||||
|
||||
self._basic_block_01 = BasicBlock(64, 32, name="stage.0.0")
|
||||
self._downsample_0 = ConvBNLayer(
|
||||
64, 128, 3, 2, 1, name="stage.0.downsample")
|
||||
|
||||
self._basic_block_11 = BasicBlock(128, 64, name="stage.1.0")
|
||||
self._basic_block_12 = BasicBlock(128, 64, name="stage.1.1")
|
||||
self._downsample_1 = ConvBNLayer(
|
||||
128, 256, 3, 2, 1, name="stage.1.downsample")
|
||||
|
||||
self._basic_block_21 = BasicBlock(256, 128, name="stage.2.0")
|
||||
self._basic_block_22 = BasicBlock(256, 128, name="stage.2.1")
|
||||
self._basic_block_23 = BasicBlock(256, 128, name="stage.2.2")
|
||||
self._basic_block_24 = BasicBlock(256, 128, name="stage.2.3")
|
||||
self._basic_block_25 = BasicBlock(256, 128, name="stage.2.4")
|
||||
self._basic_block_26 = BasicBlock(256, 128, name="stage.2.5")
|
||||
self._basic_block_27 = BasicBlock(256, 128, name="stage.2.6")
|
||||
self._basic_block_28 = BasicBlock(256, 128, name="stage.2.7")
|
||||
self._downsample_2 = ConvBNLayer(
|
||||
256, 512, 3, 2, 1, name="stage.2.downsample")
|
||||
|
||||
self._basic_block_31 = BasicBlock(512, 256, name="stage.3.0")
|
||||
self._basic_block_32 = BasicBlock(512, 256, name="stage.3.1")
|
||||
self._basic_block_33 = BasicBlock(512, 256, name="stage.3.2")
|
||||
self._basic_block_34 = BasicBlock(512, 256, name="stage.3.3")
|
||||
self._basic_block_35 = BasicBlock(512, 256, name="stage.3.4")
|
||||
self._basic_block_36 = BasicBlock(512, 256, name="stage.3.5")
|
||||
self._basic_block_37 = BasicBlock(512, 256, name="stage.3.6")
|
||||
self._basic_block_38 = BasicBlock(512, 256, name="stage.3.7")
|
||||
self._downsample_3 = ConvBNLayer(
|
||||
512, 1024, 3, 2, 1, name="stage.3.downsample")
|
||||
|
||||
self._basic_block_41 = BasicBlock(1024, 512, name="stage.4.0")
|
||||
self._basic_block_42 = BasicBlock(1024, 512, name="stage.4.1")
|
||||
self._basic_block_43 = BasicBlock(1024, 512, name="stage.4.2")
|
||||
self._basic_block_44 = BasicBlock(1024, 512, name="stage.4.3")
|
||||
|
||||
self._pool = Pool2D(pool_type="avg", global_pooling=True)
|
||||
|
||||
stdv = 1.0 / math.sqrt(1024.0)
|
||||
self._out = Linear(
|
||||
input_dim=1024,
|
||||
output_dim=class_dim,
|
||||
param_attr=ParamAttr(
|
||||
name="fc_weights",
|
||||
initializer=fluid.initializer.Uniform(-stdv, stdv)),
|
||||
bias_attr=ParamAttr(name="fc_offset"))
|
||||
|
||||
def forward(self, inputs):
|
||||
x = self._conv1(inputs)
|
||||
x = self._conv2(x)
|
||||
|
||||
x = self._basic_block_01(x)
|
||||
x = self._downsample_0(x)
|
||||
|
||||
x = self._basic_block_11(x)
|
||||
x = self._basic_block_12(x)
|
||||
x = self._downsample_1(x)
|
||||
|
||||
x = self._basic_block_21(x)
|
||||
x = self._basic_block_22(x)
|
||||
x = self._basic_block_23(x)
|
||||
x = self._basic_block_24(x)
|
||||
x = self._basic_block_25(x)
|
||||
x = self._basic_block_26(x)
|
||||
x = self._basic_block_27(x)
|
||||
x = self._basic_block_28(x)
|
||||
x = self._downsample_2(x)
|
||||
|
||||
x = self._basic_block_31(x)
|
||||
x = self._basic_block_32(x)
|
||||
x = self._basic_block_33(x)
|
||||
x = self._basic_block_34(x)
|
||||
x = self._basic_block_35(x)
|
||||
x = self._basic_block_36(x)
|
||||
x = self._basic_block_37(x)
|
||||
x = self._basic_block_38(x)
|
||||
x = self._downsample_3(x)
|
||||
|
||||
x = self._basic_block_41(x)
|
||||
x = self._basic_block_42(x)
|
||||
x = self._basic_block_43(x)
|
||||
x = self._basic_block_44(x)
|
||||
|
||||
x = self._pool(x)
|
||||
x = fluid.layers.squeeze(x, axes=[2, 3])
|
||||
x = self._out(x)
|
||||
return x
|
||||
|
||||
|
||||
def DarkNet53(**args):
|
||||
model = DarkNet(**args)
|
||||
return model
|
File diff suppressed because it is too large
Load Diff
|
@ -1,237 +1,208 @@
|
|||
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
from paddle.fluid.param_attr import ParamAttr
|
||||
from paddle.fluid.layer_helper import LayerHelper
|
||||
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
|
||||
import math
|
||||
|
||||
__all__ = ['GoogLeNet']
|
||||
__all__ = ['GoogLeNet_DY']
|
||||
|
||||
def xavier(channels, filter_size, name):
|
||||
stdv = (3.0 / (filter_size**2 * channels))**0.5
|
||||
param_attr = ParamAttr(
|
||||
initializer=fluid.initializer.Uniform(-stdv, stdv),
|
||||
name=name + "_weights")
|
||||
|
||||
return param_attr
|
||||
|
||||
|
||||
class GoogLeNet():
|
||||
def __init__(self):
|
||||
class ConvLayer(fluid.dygraph.Layer):
|
||||
def __init__(self,
|
||||
num_channels,
|
||||
num_filters,
|
||||
filter_size,
|
||||
stride=1,
|
||||
groups=1,
|
||||
act=None,
|
||||
name=None):
|
||||
super(ConvLayer, self).__init__()
|
||||
|
||||
pass
|
||||
|
||||
def conv_layer(self,
|
||||
input,
|
||||
num_filters,
|
||||
filter_size,
|
||||
stride=1,
|
||||
groups=1,
|
||||
act=None,
|
||||
name=None):
|
||||
channels = input.shape[1]
|
||||
stdv = (3.0 / (filter_size**2 * channels))**0.5
|
||||
param_attr = ParamAttr(
|
||||
initializer=fluid.initializer.Uniform(-stdv, stdv),
|
||||
name=name + "_weights")
|
||||
conv = fluid.layers.conv2d(
|
||||
input=input,
|
||||
self._conv = Conv2D(
|
||||
num_channels=num_channels,
|
||||
num_filters=num_filters,
|
||||
filter_size=filter_size,
|
||||
stride=stride,
|
||||
padding=(filter_size - 1) // 2,
|
||||
groups=groups,
|
||||
act=act,
|
||||
param_attr=param_attr,
|
||||
bias_attr=False,
|
||||
name=name)
|
||||
return conv
|
||||
|
||||
def xavier(self, channels, filter_size, name):
|
||||
stdv = (3.0 / (filter_size**2 * channels))**0.5
|
||||
param_attr = ParamAttr(
|
||||
initializer=fluid.initializer.Uniform(-stdv, stdv),
|
||||
name=name + "_weights")
|
||||
|
||||
return param_attr
|
||||
|
||||
def inception(self,
|
||||
input,
|
||||
channels,
|
||||
filter1,
|
||||
filter3R,
|
||||
filter3,
|
||||
filter5R,
|
||||
filter5,
|
||||
proj,
|
||||
name=None):
|
||||
conv1 = self.conv_layer(
|
||||
input=input,
|
||||
num_filters=filter1,
|
||||
filter_size=1,
|
||||
stride=1,
|
||||
act=None,
|
||||
name="inception_" + name + "_1x1")
|
||||
conv3r = self.conv_layer(
|
||||
input=input,
|
||||
num_filters=filter3R,
|
||||
filter_size=1,
|
||||
stride=1,
|
||||
act=None,
|
||||
name="inception_" + name + "_3x3_reduce")
|
||||
conv3 = self.conv_layer(
|
||||
input=conv3r,
|
||||
num_filters=filter3,
|
||||
filter_size=3,
|
||||
stride=1,
|
||||
act=None,
|
||||
name="inception_" + name + "_3x3")
|
||||
conv5r = self.conv_layer(
|
||||
input=input,
|
||||
num_filters=filter5R,
|
||||
filter_size=1,
|
||||
stride=1,
|
||||
act=None,
|
||||
name="inception_" + name + "_5x5_reduce")
|
||||
conv5 = self.conv_layer(
|
||||
input=conv5r,
|
||||
num_filters=filter5,
|
||||
filter_size=5,
|
||||
stride=1,
|
||||
act=None,
|
||||
name="inception_" + name + "_5x5")
|
||||
pool = fluid.layers.pool2d(
|
||||
input=input,
|
||||
pool_size=3,
|
||||
pool_stride=1,
|
||||
pool_padding=1,
|
||||
pool_type='max')
|
||||
convprj = fluid.layers.conv2d(
|
||||
input=pool,
|
||||
filter_size=1,
|
||||
num_filters=proj,
|
||||
stride=1,
|
||||
padding=0,
|
||||
name="inception_" + name + "_3x3_proj",
|
||||
param_attr=ParamAttr(
|
||||
name="inception_" + name + "_3x3_proj_weights"),
|
||||
param_attr=ParamAttr(name=name + "_weights"),
|
||||
bias_attr=False)
|
||||
cat = fluid.layers.concat(input=[conv1, conv3, conv5, convprj], axis=1)
|
||||
cat = fluid.layers.relu(cat)
|
||||
return cat
|
||||
|
||||
def net(self, input, class_dim=1000):
|
||||
conv = self.conv_layer(
|
||||
input=input,
|
||||
num_filters=64,
|
||||
filter_size=7,
|
||||
stride=2,
|
||||
act=None,
|
||||
name="conv1")
|
||||
pool = fluid.layers.pool2d(
|
||||
input=conv, pool_size=3, pool_type='max', pool_stride=2)
|
||||
def forward(self, inputs):
|
||||
y = self._conv(inputs)
|
||||
return y
|
||||
|
||||
conv = self.conv_layer(
|
||||
input=pool,
|
||||
num_filters=64,
|
||||
filter_size=1,
|
||||
stride=1,
|
||||
act=None,
|
||||
name="conv2_1x1")
|
||||
conv = self.conv_layer(
|
||||
input=conv,
|
||||
num_filters=192,
|
||||
filter_size=3,
|
||||
stride=1,
|
||||
act=None,
|
||||
name="conv2_3x3")
|
||||
pool = fluid.layers.pool2d(
|
||||
input=conv, pool_size=3, pool_type='max', pool_stride=2)
|
||||
|
||||
ince3a = self.inception(pool, 192, 64, 96, 128, 16, 32, 32, "ince3a")
|
||||
ince3b = self.inception(ince3a, 256, 128, 128, 192, 32, 96, 64,
|
||||
"ince3b")
|
||||
pool3 = fluid.layers.pool2d(
|
||||
input=ince3b, pool_size=3, pool_type='max', pool_stride=2)
|
||||
class Inception(fluid.dygraph.Layer):
|
||||
def __init__(self,
|
||||
input_channels,
|
||||
output_channels,
|
||||
filter1,
|
||||
filter3R,
|
||||
filter3,
|
||||
filter5R,
|
||||
filter5,
|
||||
proj,
|
||||
name=None):
|
||||
super(Inception, self).__init__()
|
||||
|
||||
ince4a = self.inception(pool3, 480, 192, 96, 208, 16, 48, 64, "ince4a")
|
||||
ince4b = self.inception(ince4a, 512, 160, 112, 224, 24, 64, 64,
|
||||
"ince4b")
|
||||
ince4c = self.inception(ince4b, 512, 128, 128, 256, 24, 64, 64,
|
||||
"ince4c")
|
||||
ince4d = self.inception(ince4c, 512, 112, 144, 288, 32, 64, 64,
|
||||
"ince4d")
|
||||
ince4e = self.inception(ince4d, 528, 256, 160, 320, 32, 128, 128,
|
||||
"ince4e")
|
||||
pool4 = fluid.layers.pool2d(
|
||||
input=ince4e, pool_size=3, pool_type='max', pool_stride=2)
|
||||
self._conv1 = ConvLayer(
|
||||
input_channels, filter1, 1, name="inception_" + name + "_1x1")
|
||||
self._conv3r = ConvLayer(
|
||||
input_channels,
|
||||
filter3R,
|
||||
1,
|
||||
name="inception_" + name + "_3x3_reduce")
|
||||
self._conv3 = ConvLayer(
|
||||
filter3R, filter3, 3, name="inception_" + name + "_3x3")
|
||||
self._conv5r = ConvLayer(
|
||||
input_channels,
|
||||
filter5R,
|
||||
1,
|
||||
name="inception_" + name + "_5x5_reduce")
|
||||
self._conv5 = ConvLayer(
|
||||
filter5R, filter5, 5, name="inception_" + name + "_5x5")
|
||||
self._pool = Pool2D(
|
||||
pool_size=3, pool_type="max", pool_stride=1, pool_padding=1)
|
||||
self._convprj = ConvLayer(
|
||||
input_channels, proj, 1, name="inception_" + name + "_3x3_proj")
|
||||
|
||||
ince5a = self.inception(pool4, 832, 256, 160, 320, 32, 128, 128,
|
||||
"ince5a")
|
||||
ince5b = self.inception(ince5a, 832, 384, 192, 384, 48, 128, 128,
|
||||
"ince5b")
|
||||
pool5 = fluid.layers.pool2d(
|
||||
input=ince5b, pool_size=7, pool_type='avg', pool_stride=7)
|
||||
dropout = fluid.layers.dropout(x=pool5, dropout_prob=0.4)
|
||||
out = fluid.layers.fc(input=dropout,
|
||||
size=class_dim,
|
||||
act='softmax',
|
||||
param_attr=self.xavier(1024, 1, "out"),
|
||||
name="out",
|
||||
bias_attr=ParamAttr(name="out_offset"))
|
||||
def forward(self, inputs):
|
||||
conv1 = self._conv1(inputs)
|
||||
|
||||
pool_o1 = fluid.layers.pool2d(
|
||||
input=ince4a, pool_size=5, pool_type='avg', pool_stride=3)
|
||||
conv_o1 = self.conv_layer(
|
||||
input=pool_o1,
|
||||
num_filters=128,
|
||||
filter_size=1,
|
||||
stride=1,
|
||||
act=None,
|
||||
name="conv_o1")
|
||||
fc_o1 = fluid.layers.fc(input=conv_o1,
|
||||
size=1024,
|
||||
act='relu',
|
||||
param_attr=self.xavier(2048, 1, "fc_o1"),
|
||||
name="fc_o1",
|
||||
bias_attr=ParamAttr(name="fc_o1_offset"))
|
||||
dropout_o1 = fluid.layers.dropout(x=fc_o1, dropout_prob=0.7)
|
||||
out1 = fluid.layers.fc(input=dropout_o1,
|
||||
size=class_dim,
|
||||
act='softmax',
|
||||
param_attr=self.xavier(1024, 1, "out1"),
|
||||
name="out1",
|
||||
bias_attr=ParamAttr(name="out1_offset"))
|
||||
conv3r = self._conv3r(inputs)
|
||||
conv3 = self._conv3(conv3r)
|
||||
|
||||
pool_o2 = fluid.layers.pool2d(
|
||||
input=ince4d, pool_size=5, pool_type='avg', pool_stride=3)
|
||||
conv_o2 = self.conv_layer(
|
||||
input=pool_o2,
|
||||
num_filters=128,
|
||||
filter_size=1,
|
||||
stride=1,
|
||||
act=None,
|
||||
name="conv_o2")
|
||||
fc_o2 = fluid.layers.fc(input=conv_o2,
|
||||
size=1024,
|
||||
act='relu',
|
||||
param_attr=self.xavier(2048, 1, "fc_o2"),
|
||||
name="fc_o2",
|
||||
bias_attr=ParamAttr(name="fc_o2_offset"))
|
||||
dropout_o2 = fluid.layers.dropout(x=fc_o2, dropout_prob=0.7)
|
||||
out2 = fluid.layers.fc(input=dropout_o2,
|
||||
size=class_dim,
|
||||
act='softmax',
|
||||
param_attr=self.xavier(1024, 1, "out2"),
|
||||
name="out2",
|
||||
bias_attr=ParamAttr(name="out2_offset"))
|
||||
conv5r = self._conv5r(inputs)
|
||||
conv5 = self._conv5(conv5r)
|
||||
|
||||
# last fc layer is "out"
|
||||
pool = self._pool(inputs)
|
||||
convprj = self._convprj(pool)
|
||||
|
||||
cat = fluid.layers.concat([conv1, conv3, conv5, convprj], axis=1)
|
||||
layer_helper = LayerHelper(self.full_name(), act="relu")
|
||||
return layer_helper.append_activation(cat)
|
||||
|
||||
|
||||
class GoogleNetDY(fluid.dygraph.Layer):
|
||||
def __init__(self, class_dim=1000):
|
||||
super(GoogleNetDY, self).__init__()
|
||||
self._conv = ConvLayer(3, 64, 7, 2, name="conv1")
|
||||
self._pool = Pool2D(pool_size=3, pool_type="max", pool_stride=2)
|
||||
self._conv_1 = ConvLayer(64, 64, 1, name="conv2_1x1")
|
||||
self._conv_2 = ConvLayer(64, 192, 3, name="conv2_3x3")
|
||||
|
||||
self._ince3a = Inception(
|
||||
192, 192, 64, 96, 128, 16, 32, 32, name="ince3a")
|
||||
self._ince3b = Inception(
|
||||
256, 256, 128, 128, 192, 32, 96, 64, name="ince3b")
|
||||
|
||||
self._ince4a = Inception(
|
||||
480, 480, 192, 96, 208, 16, 48, 64, name="ince4a")
|
||||
self._ince4b = Inception(
|
||||
512, 512, 160, 112, 224, 24, 64, 64, name="ince4b")
|
||||
self._ince4c = Inception(
|
||||
512, 512, 128, 128, 256, 24, 64, 64, name="ince4c")
|
||||
self._ince4d = Inception(
|
||||
512, 512, 112, 144, 288, 32, 64, 64, name="ince4d")
|
||||
self._ince4e = Inception(
|
||||
528, 528, 256, 160, 320, 32, 128, 128, name="ince4e")
|
||||
|
||||
self._ince5a = Inception(
|
||||
832, 832, 256, 160, 320, 32, 128, 128, name="ince5a")
|
||||
self._ince5b = Inception(
|
||||
832, 832, 384, 192, 384, 48, 128, 128, name="ince5b")
|
||||
|
||||
self._pool_5 = Pool2D(pool_size=7, pool_type='avg', pool_stride=7)
|
||||
|
||||
self._drop = fluid.dygraph.Dropout(p=0.4)
|
||||
self._fc_out = Linear(
|
||||
1024,
|
||||
class_dim,
|
||||
param_attr=xavier(1024, 1, "out"),
|
||||
bias_attr=ParamAttr(name="out_offset"),
|
||||
act="softmax")
|
||||
self._pool_o1 = Pool2D(pool_size=5, pool_stride=3, pool_type="avg")
|
||||
self._conv_o1 = ConvLayer(512, 128, 1, name="conv_o1")
|
||||
self._fc_o1 = Linear(
|
||||
1152,
|
||||
1024,
|
||||
param_attr=xavier(2048, 1, "fc_o1"),
|
||||
bias_attr=ParamAttr(name="fc_o1_offset"),
|
||||
act="relu")
|
||||
self._drop_o1 = fluid.dygraph.Dropout(p=0.7)
|
||||
self._out1 = Linear(
|
||||
1024,
|
||||
class_dim,
|
||||
param_attr=xavier(1024, 1, "out1"),
|
||||
bias_attr=ParamAttr(name="out1_offset"),
|
||||
act="softmax")
|
||||
self._pool_o2 = Pool2D(pool_size=5, pool_stride=3, pool_type='avg')
|
||||
self._conv_o2 = ConvLayer(528, 128, 1, name="conv_o2")
|
||||
self._fc_o2 = Linear(
|
||||
1152,
|
||||
1024,
|
||||
param_attr=xavier(2048, 1, "fc_o2"),
|
||||
bias_attr=ParamAttr(name="fc_o2_offset"))
|
||||
self._drop_o2 = fluid.dygraph.Dropout(p=0.7)
|
||||
self._out2 = Linear(
|
||||
1024,
|
||||
class_dim,
|
||||
param_attr=xavier(1024, 1, "out2"),
|
||||
bias_attr=ParamAttr(name="out2_offset"))
|
||||
|
||||
def forward(self, inputs):
|
||||
x = self._conv(inputs)
|
||||
x = self._pool(x)
|
||||
x = self._conv_1(x)
|
||||
x = self._conv_2(x)
|
||||
x = self._pool(x)
|
||||
|
||||
x = self._ince3a(x)
|
||||
x = self._ince3b(x)
|
||||
x = self._pool(x)
|
||||
|
||||
ince4a = self._ince4a(x)
|
||||
x = self._ince4b(ince4a)
|
||||
x = self._ince4c(x)
|
||||
ince4d = self._ince4d(x)
|
||||
x = self._ince4e(ince4d)
|
||||
x = self._pool(x)
|
||||
|
||||
x = self._ince5a(x)
|
||||
ince5b = self._ince5b(x)
|
||||
|
||||
x = self._pool_5(ince5b)
|
||||
x = self._drop(x)
|
||||
x = fluid.layers.squeeze(x, axes=[2, 3])
|
||||
out = self._fc_out(x)
|
||||
|
||||
x = self._pool_o1(ince4a)
|
||||
x = self._conv_o1(x)
|
||||
x = fluid.layers.flatten(x)
|
||||
x = self._fc_o1(x)
|
||||
x = self._drop_o1(x)
|
||||
out1 = self._out1(x)
|
||||
|
||||
x = self._pool_o2(ince4d)
|
||||
x = self._conv_o2(x)
|
||||
x = fluid.layers.flatten(x)
|
||||
x = self._fc_o2(x)
|
||||
x = self._drop_o2(x)
|
||||
out2 = self._out2(x)
|
||||
return [out, out1, out2]
|
||||
|
||||
|
||||
def GoogLeNet(**args):
|
||||
model = GoogleNetDY(**args)
|
||||
return model
|
|
@ -1,77 +1,25 @@
|
|||
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
from paddle.fluid.param_attr import ParamAttr
|
||||
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear, Dropout
|
||||
import math
|
||||
|
||||
__all__ = ['InceptionV4']
|
||||
__all__ = ["InceptionV4"]
|
||||
|
||||
class ConvBNLayer(fluid.dygraph.Layer):
|
||||
def __init__(self,
|
||||
num_channels,
|
||||
num_filters,
|
||||
filter_size,
|
||||
stride=1,
|
||||
padding=0,
|
||||
groups=1,
|
||||
act='relu',
|
||||
name=None):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
|
||||
class InceptionV4():
|
||||
def __init__(self):
|
||||
|
||||
pass
|
||||
|
||||
def net(self, input, class_dim=1000):
|
||||
x = self.inception_stem(input)
|
||||
|
||||
for i in range(4):
|
||||
x = self.inceptionA(x, name=str(i + 1))
|
||||
x = self.reductionA(x)
|
||||
|
||||
for i in range(7):
|
||||
x = self.inceptionB(x, name=str(i + 1))
|
||||
x = self.reductionB(x)
|
||||
|
||||
for i in range(3):
|
||||
x = self.inceptionC(x, name=str(i + 1))
|
||||
|
||||
pool = fluid.layers.pool2d(
|
||||
input=x, pool_type='avg', global_pooling=True)
|
||||
|
||||
drop = fluid.layers.dropout(x=pool, dropout_prob=0.2)
|
||||
|
||||
stdv = 1.0 / math.sqrt(drop.shape[1] * 1.0)
|
||||
out = fluid.layers.fc(
|
||||
input=drop,
|
||||
size=class_dim,
|
||||
param_attr=ParamAttr(
|
||||
initializer=fluid.initializer.Uniform(-stdv, stdv),
|
||||
name="final_fc_weights"),
|
||||
bias_attr=ParamAttr(
|
||||
initializer=fluid.initializer.Uniform(-stdv, stdv),
|
||||
name="final_fc_offset"))
|
||||
return out
|
||||
|
||||
def conv_bn_layer(self,
|
||||
data,
|
||||
num_filters,
|
||||
filter_size,
|
||||
stride=1,
|
||||
padding=0,
|
||||
groups=1,
|
||||
act='relu',
|
||||
name=None):
|
||||
conv = fluid.layers.conv2d(
|
||||
input=data,
|
||||
self._conv = Conv2D(
|
||||
num_channels=num_channels,
|
||||
num_filters=num_filters,
|
||||
filter_size=filter_size,
|
||||
stride=stride,
|
||||
|
@ -79,276 +27,413 @@ class InceptionV4():
|
|||
groups=groups,
|
||||
act=None,
|
||||
param_attr=ParamAttr(name=name + "_weights"),
|
||||
bias_attr=False,
|
||||
name=name)
|
||||
bias_attr=False)
|
||||
bn_name = name + "_bn"
|
||||
return fluid.layers.batch_norm(
|
||||
input=conv,
|
||||
self._batch_norm = BatchNorm(
|
||||
num_filters,
|
||||
act=act,
|
||||
name=bn_name,
|
||||
param_attr=ParamAttr(name=bn_name + "_scale"),
|
||||
bias_attr=ParamAttr(name=bn_name + "_offset"),
|
||||
moving_mean_name=bn_name + '_mean',
|
||||
moving_variance_name=bn_name + '_variance')
|
||||
|
||||
def inception_stem(self, data, name=None):
|
||||
conv = self.conv_bn_layer(
|
||||
data, 32, 3, stride=2, act='relu', name="conv1_3x3_s2")
|
||||
conv = self.conv_bn_layer(conv, 32, 3, act='relu', name="conv2_3x3_s1")
|
||||
conv = self.conv_bn_layer(
|
||||
conv, 64, 3, padding=1, act='relu', name="conv3_3x3_s1")
|
||||
def forward(self, inputs):
|
||||
y = self._conv(inputs)
|
||||
y = self._batch_norm(y)
|
||||
return y
|
||||
|
||||
pool1 = fluid.layers.pool2d(
|
||||
input=conv, pool_size=3, pool_stride=2, pool_type='max')
|
||||
conv2 = self.conv_bn_layer(
|
||||
conv, 96, 3, stride=2, act='relu', name="inception_stem1_3x3_s2")
|
||||
concat = fluid.layers.concat([pool1, conv2], axis=1)
|
||||
|
||||
conv1 = self.conv_bn_layer(
|
||||
concat, 64, 1, act='relu', name="inception_stem2_3x3_reduce")
|
||||
conv1 = self.conv_bn_layer(
|
||||
conv1, 96, 3, act='relu', name="inception_stem2_3x3")
|
||||
|
||||
conv2 = self.conv_bn_layer(
|
||||
concat, 64, 1, act='relu', name="inception_stem2_1x7_reduce")
|
||||
conv2 = self.conv_bn_layer(
|
||||
conv2,
|
||||
class InceptionStem(fluid.dygraph.Layer):
|
||||
def __init__(self):
|
||||
super(InceptionStem, self).__init__()
|
||||
self._conv_1 = ConvBNLayer(
|
||||
3, 32, 3, stride=2, act="relu", name="conv1_3x3_s2")
|
||||
self._conv_2 = ConvBNLayer(32, 32, 3, act="relu", name="conv2_3x3_s1")
|
||||
self._conv_3 = ConvBNLayer(
|
||||
32, 64, 3, padding=1, act="relu", name="conv3_3x3_s1")
|
||||
self._pool = Pool2D(pool_size=3, pool_type="max", pool_stride=2)
|
||||
self._conv2 = ConvBNLayer(
|
||||
64, 96, 3, stride=2, act="relu", name="inception_stem1_3x3_s2")
|
||||
self._conv1_1 = ConvBNLayer(
|
||||
160, 64, 1, act="relu", name="inception_stem2_3x3_reduce")
|
||||
self._conv1_2 = ConvBNLayer(
|
||||
64, 96, 3, act="relu", name="inception_stem2_3x3")
|
||||
self._conv2_1 = ConvBNLayer(
|
||||
160, 64, 1, act="relu", name="inception_stem2_1x7_reduce")
|
||||
self._conv2_2 = ConvBNLayer(
|
||||
64,
|
||||
64, (7, 1),
|
||||
padding=(3, 0),
|
||||
act='relu',
|
||||
act="relu",
|
||||
name="inception_stem2_1x7")
|
||||
conv2 = self.conv_bn_layer(
|
||||
conv2,
|
||||
self._conv2_3 = ConvBNLayer(
|
||||
64,
|
||||
64, (1, 7),
|
||||
padding=(0, 3),
|
||||
act='relu',
|
||||
act="relu",
|
||||
name="inception_stem2_7x1")
|
||||
conv2 = self.conv_bn_layer(
|
||||
conv2, 96, 3, act='relu', name="inception_stem2_3x3_2")
|
||||
self._conv2_4 = ConvBNLayer(
|
||||
64, 96, 3, act="relu", name="inception_stem2_3x3_2")
|
||||
self._conv3 = ConvBNLayer(
|
||||
192, 192, 3, stride=2, act="relu", name="inception_stem3_3x3_s2")
|
||||
|
||||
def forward(self, inputs):
|
||||
conv = self._conv_1(inputs)
|
||||
conv = self._conv_2(conv)
|
||||
conv = self._conv_3(conv)
|
||||
|
||||
pool1 = self._pool(conv)
|
||||
conv2 = self._conv2(conv)
|
||||
concat = fluid.layers.concat([pool1, conv2], axis=1)
|
||||
|
||||
conv1 = self._conv1_1(concat)
|
||||
conv1 = self._conv1_2(conv1)
|
||||
|
||||
conv2 = self._conv2_1(concat)
|
||||
conv2 = self._conv2_2(conv2)
|
||||
conv2 = self._conv2_3(conv2)
|
||||
conv2 = self._conv2_4(conv2)
|
||||
|
||||
concat = fluid.layers.concat([conv1, conv2], axis=1)
|
||||
|
||||
conv1 = self.conv_bn_layer(
|
||||
concat,
|
||||
192,
|
||||
3,
|
||||
stride=2,
|
||||
act='relu',
|
||||
name="inception_stem3_3x3_s2")
|
||||
pool1 = fluid.layers.pool2d(
|
||||
input=concat, pool_size=3, pool_stride=2, pool_type='max')
|
||||
conv1 = self._conv3(concat)
|
||||
pool1 = self._pool(concat)
|
||||
|
||||
concat = fluid.layers.concat([conv1, pool1], axis=1)
|
||||
|
||||
return concat
|
||||
|
||||
def inceptionA(self, data, name=None):
|
||||
pool1 = fluid.layers.pool2d(
|
||||
input=data, pool_size=3, pool_padding=1, pool_type='avg')
|
||||
conv1 = self.conv_bn_layer(
|
||||
pool1, 96, 1, act='relu', name="inception_a" + name + "_1x1")
|
||||
|
||||
conv2 = self.conv_bn_layer(
|
||||
data, 96, 1, act='relu', name="inception_a" + name + "_1x1_2")
|
||||
|
||||
conv3 = self.conv_bn_layer(
|
||||
data, 64, 1, act='relu', name="inception_a" + name + "_3x3_reduce")
|
||||
conv3 = self.conv_bn_layer(
|
||||
conv3,
|
||||
class InceptionA(fluid.dygraph.Layer):
|
||||
def __init__(self, name):
|
||||
super(InceptionA, self).__init__()
|
||||
self._pool = Pool2D(pool_size=3, pool_type="avg", pool_padding=1)
|
||||
self._conv1 = ConvBNLayer(
|
||||
384, 96, 1, act="relu", name="inception_a" + name + "_1x1")
|
||||
self._conv2 = ConvBNLayer(
|
||||
384, 96, 1, act="relu", name="inception_a" + name + "_1x1_2")
|
||||
self._conv3_1 = ConvBNLayer(
|
||||
384, 64, 1, act="relu", name="inception_a" + name + "_3x3_reduce")
|
||||
self._conv3_2 = ConvBNLayer(
|
||||
64,
|
||||
96,
|
||||
3,
|
||||
padding=1,
|
||||
act='relu',
|
||||
act="relu",
|
||||
name="inception_a" + name + "_3x3")
|
||||
|
||||
conv4 = self.conv_bn_layer(
|
||||
data,
|
||||
self._conv4_1 = ConvBNLayer(
|
||||
384,
|
||||
64,
|
||||
1,
|
||||
act='relu',
|
||||
act="relu",
|
||||
name="inception_a" + name + "_3x3_2_reduce")
|
||||
conv4 = self.conv_bn_layer(
|
||||
conv4,
|
||||
self._conv4_2 = ConvBNLayer(
|
||||
64,
|
||||
96,
|
||||
3,
|
||||
padding=1,
|
||||
act='relu',
|
||||
act="relu",
|
||||
name="inception_a" + name + "_3x3_2")
|
||||
conv4 = self.conv_bn_layer(
|
||||
conv4,
|
||||
self._conv4_3 = ConvBNLayer(
|
||||
96,
|
||||
96,
|
||||
3,
|
||||
padding=1,
|
||||
act='relu',
|
||||
act="relu",
|
||||
name="inception_a" + name + "_3x3_3")
|
||||
|
||||
def forward(self, inputs):
|
||||
pool1 = self._pool(inputs)
|
||||
conv1 = self._conv1(pool1)
|
||||
|
||||
conv2 = self._conv2(inputs)
|
||||
|
||||
conv3 = self._conv3_1(inputs)
|
||||
conv3 = self._conv3_2(conv3)
|
||||
|
||||
conv4 = self._conv4_1(inputs)
|
||||
conv4 = self._conv4_2(conv4)
|
||||
conv4 = self._conv4_3(conv4)
|
||||
|
||||
concat = fluid.layers.concat([conv1, conv2, conv3, conv4], axis=1)
|
||||
|
||||
return concat
|
||||
|
||||
def reductionA(self, data, name=None):
|
||||
pool1 = fluid.layers.pool2d(
|
||||
input=data, pool_size=3, pool_stride=2, pool_type='max')
|
||||
|
||||
conv2 = self.conv_bn_layer(
|
||||
data, 384, 3, stride=2, act='relu', name="reduction_a_3x3")
|
||||
|
||||
conv3 = self.conv_bn_layer(
|
||||
data, 192, 1, act='relu', name="reduction_a_3x3_2_reduce")
|
||||
conv3 = self.conv_bn_layer(
|
||||
conv3, 224, 3, padding=1, act='relu', name="reduction_a_3x3_2")
|
||||
conv3 = self.conv_bn_layer(
|
||||
conv3, 256, 3, stride=2, act='relu', name="reduction_a_3x3_3")
|
||||
class ReductionA(fluid.dygraph.Layer):
|
||||
def __init__(self):
|
||||
super(ReductionA, self).__init__()
|
||||
self._pool = Pool2D(pool_size=3, pool_type="max", pool_stride=2)
|
||||
self._conv2 = ConvBNLayer(
|
||||
384, 384, 3, stride=2, act="relu", name="reduction_a_3x3")
|
||||
self._conv3_1 = ConvBNLayer(
|
||||
384, 192, 1, act="relu", name="reduction_a_3x3_2_reduce")
|
||||
self._conv3_2 = ConvBNLayer(
|
||||
192, 224, 3, padding=1, act="relu", name="reduction_a_3x3_2")
|
||||
self._conv3_3 = ConvBNLayer(
|
||||
224, 256, 3, stride=2, act="relu", name="reduction_a_3x3_3")
|
||||
|
||||
def forward(self, inputs):
|
||||
pool1 = self._pool(inputs)
|
||||
conv2 = self._conv2(inputs)
|
||||
conv3 = self._conv3_1(inputs)
|
||||
conv3 = self._conv3_2(conv3)
|
||||
conv3 = self._conv3_3(conv3)
|
||||
concat = fluid.layers.concat([pool1, conv2, conv3], axis=1)
|
||||
|
||||
return concat
|
||||
|
||||
def inceptionB(self, data, name=None):
|
||||
pool1 = fluid.layers.pool2d(
|
||||
input=data, pool_size=3, pool_padding=1, pool_type='avg')
|
||||
conv1 = self.conv_bn_layer(
|
||||
pool1, 128, 1, act='relu', name="inception_b" + name + "_1x1")
|
||||
|
||||
conv2 = self.conv_bn_layer(
|
||||
data, 384, 1, act='relu', name="inception_b" + name + "_1x1_2")
|
||||
|
||||
conv3 = self.conv_bn_layer(
|
||||
data,
|
||||
class InceptionB(fluid.dygraph.Layer):
|
||||
def __init__(self, name=None):
|
||||
super(InceptionB, self).__init__()
|
||||
self._pool = Pool2D(pool_size=3, pool_type="avg", pool_padding=1)
|
||||
self._conv1 = ConvBNLayer(
|
||||
1024, 128, 1, act="relu", name="inception_b" + name + "_1x1")
|
||||
self._conv2 = ConvBNLayer(
|
||||
1024, 384, 1, act="relu", name="inception_b" + name + "_1x1_2")
|
||||
self._conv3_1 = ConvBNLayer(
|
||||
1024,
|
||||
192,
|
||||
1,
|
||||
act='relu',
|
||||
act="relu",
|
||||
name="inception_b" + name + "_1x7_reduce")
|
||||
conv3 = self.conv_bn_layer(
|
||||
conv3,
|
||||
self._conv3_2 = ConvBNLayer(
|
||||
192,
|
||||
224, (1, 7),
|
||||
padding=(0, 3),
|
||||
act='relu',
|
||||
act="relu",
|
||||
name="inception_b" + name + "_1x7")
|
||||
conv3 = self.conv_bn_layer(
|
||||
conv3,
|
||||
self._conv3_3 = ConvBNLayer(
|
||||
224,
|
||||
256, (7, 1),
|
||||
padding=(3, 0),
|
||||
act='relu',
|
||||
act="relu",
|
||||
name="inception_b" + name + "_7x1")
|
||||
|
||||
conv4 = self.conv_bn_layer(
|
||||
data,
|
||||
self._conv4_1 = ConvBNLayer(
|
||||
1024,
|
||||
192,
|
||||
1,
|
||||
act='relu',
|
||||
act="relu",
|
||||
name="inception_b" + name + "_7x1_2_reduce")
|
||||
conv4 = self.conv_bn_layer(
|
||||
conv4,
|
||||
self._conv4_2 = ConvBNLayer(
|
||||
192,
|
||||
192, (1, 7),
|
||||
padding=(0, 3),
|
||||
act='relu',
|
||||
act="relu",
|
||||
name="inception_b" + name + "_1x7_2")
|
||||
conv4 = self.conv_bn_layer(
|
||||
conv4,
|
||||
self._conv4_3 = ConvBNLayer(
|
||||
192,
|
||||
224, (7, 1),
|
||||
padding=(3, 0),
|
||||
act='relu',
|
||||
act="relu",
|
||||
name="inception_b" + name + "_7x1_2")
|
||||
conv4 = self.conv_bn_layer(
|
||||
conv4,
|
||||
self._conv4_4 = ConvBNLayer(
|
||||
224,
|
||||
224, (1, 7),
|
||||
padding=(0, 3),
|
||||
act='relu',
|
||||
act="relu",
|
||||
name="inception_b" + name + "_1x7_3")
|
||||
conv4 = self.conv_bn_layer(
|
||||
conv4,
|
||||
self._conv4_5 = ConvBNLayer(
|
||||
224,
|
||||
256, (7, 1),
|
||||
padding=(3, 0),
|
||||
act='relu',
|
||||
act="relu",
|
||||
name="inception_b" + name + "_7x1_3")
|
||||
|
||||
concat = fluid.layers.concat([conv1, conv2, conv3, conv4], axis=1)
|
||||
def forward(self, inputs):
|
||||
pool1 = self._pool(inputs)
|
||||
conv1 = self._conv1(pool1)
|
||||
|
||||
conv2 = self._conv2(inputs)
|
||||
|
||||
conv3 = self._conv3_1(inputs)
|
||||
conv3 = self._conv3_2(conv3)
|
||||
conv3 = self._conv3_3(conv3)
|
||||
|
||||
conv4 = self._conv4_1(inputs)
|
||||
conv4 = self._conv4_2(conv4)
|
||||
conv4 = self._conv4_3(conv4)
|
||||
conv4 = self._conv4_4(conv4)
|
||||
conv4 = self._conv4_5(conv4)
|
||||
|
||||
concat = fluid.layers.concat([conv1, conv2, conv3, conv4], axis=1)
|
||||
return concat
|
||||
|
||||
def reductionB(self, data, name=None):
|
||||
pool1 = fluid.layers.pool2d(
|
||||
input=data, pool_size=3, pool_stride=2, pool_type='max')
|
||||
|
||||
conv2 = self.conv_bn_layer(
|
||||
data, 192, 1, act='relu', name="reduction_b_3x3_reduce")
|
||||
conv2 = self.conv_bn_layer(
|
||||
conv2, 192, 3, stride=2, act='relu', name="reduction_b_3x3")
|
||||
|
||||
conv3 = self.conv_bn_layer(
|
||||
data, 256, 1, act='relu', name="reduction_b_1x7_reduce")
|
||||
conv3 = self.conv_bn_layer(
|
||||
conv3,
|
||||
class ReductionB(fluid.dygraph.Layer):
|
||||
def __init__(self):
|
||||
super(ReductionB, self).__init__()
|
||||
self._pool = Pool2D(pool_size=3, pool_type="max", pool_stride=2)
|
||||
self._conv2_1 = ConvBNLayer(
|
||||
1024, 192, 1, act="relu", name="reduction_b_3x3_reduce")
|
||||
self._conv2_2 = ConvBNLayer(
|
||||
192, 192, 3, stride=2, act="relu", name="reduction_b_3x3")
|
||||
self._conv3_1 = ConvBNLayer(
|
||||
1024, 256, 1, act="relu", name="reduction_b_1x7_reduce")
|
||||
self._conv3_2 = ConvBNLayer(
|
||||
256,
|
||||
256, (1, 7),
|
||||
padding=(0, 3),
|
||||
act='relu',
|
||||
act="relu",
|
||||
name="reduction_b_1x7")
|
||||
conv3 = self.conv_bn_layer(
|
||||
conv3,
|
||||
self._conv3_3 = ConvBNLayer(
|
||||
256,
|
||||
320, (7, 1),
|
||||
padding=(3, 0),
|
||||
act='relu',
|
||||
act="relu",
|
||||
name="reduction_b_7x1")
|
||||
conv3 = self.conv_bn_layer(
|
||||
conv3, 320, 3, stride=2, act='relu', name="reduction_b_3x3_2")
|
||||
self._conv3_4 = ConvBNLayer(
|
||||
320, 320, 3, stride=2, act="relu", name="reduction_b_3x3_2")
|
||||
|
||||
def forward(self, inputs):
|
||||
pool1 = self._pool(inputs)
|
||||
|
||||
conv2 = self._conv2_1(inputs)
|
||||
conv2 = self._conv2_2(conv2)
|
||||
|
||||
conv3 = self._conv3_1(inputs)
|
||||
conv3 = self._conv3_2(conv3)
|
||||
conv3 = self._conv3_3(conv3)
|
||||
conv3 = self._conv3_4(conv3)
|
||||
|
||||
concat = fluid.layers.concat([pool1, conv2, conv3], axis=1)
|
||||
|
||||
return concat
|
||||
|
||||
def inceptionC(self, data, name=None):
|
||||
pool1 = fluid.layers.pool2d(
|
||||
input=data, pool_size=3, pool_padding=1, pool_type='avg')
|
||||
conv1 = self.conv_bn_layer(
|
||||
pool1, 256, 1, act='relu', name="inception_c" + name + "_1x1")
|
||||
|
||||
conv2 = self.conv_bn_layer(
|
||||
data, 256, 1, act='relu', name="inception_c" + name + "_1x1_2")
|
||||
|
||||
conv3 = self.conv_bn_layer(
|
||||
data, 384, 1, act='relu', name="inception_c" + name + "_1x1_3")
|
||||
conv3_1 = self.conv_bn_layer(
|
||||
conv3,
|
||||
class InceptionC(fluid.dygraph.Layer):
|
||||
def __init__(self, name=None):
|
||||
super(InceptionC, self).__init__()
|
||||
self._pool = Pool2D(pool_size=3, pool_type="avg", pool_padding=1)
|
||||
self._conv1 = ConvBNLayer(
|
||||
1536, 256, 1, act="relu", name="inception_c" + name + "_1x1")
|
||||
self._conv2 = ConvBNLayer(
|
||||
1536, 256, 1, act="relu", name="inception_c" + name + "_1x1_2")
|
||||
self._conv3_0 = ConvBNLayer(
|
||||
1536, 384, 1, act="relu", name="inception_c" + name + "_1x1_3")
|
||||
self._conv3_1 = ConvBNLayer(
|
||||
384,
|
||||
256, (1, 3),
|
||||
padding=(0, 1),
|
||||
act='relu',
|
||||
act="relu",
|
||||
name="inception_c" + name + "_1x3")
|
||||
conv3_2 = self.conv_bn_layer(
|
||||
conv3,
|
||||
self._conv3_2 = ConvBNLayer(
|
||||
384,
|
||||
256, (3, 1),
|
||||
padding=(1, 0),
|
||||
act='relu',
|
||||
act="relu",
|
||||
name="inception_c" + name + "_3x1")
|
||||
|
||||
conv4 = self.conv_bn_layer(
|
||||
data, 384, 1, act='relu', name="inception_c" + name + "_1x1_4")
|
||||
conv4 = self.conv_bn_layer(
|
||||
conv4,
|
||||
self._conv4_0 = ConvBNLayer(
|
||||
1536, 384, 1, act="relu", name="inception_c" + name + "_1x1_4")
|
||||
self._conv4_00 = ConvBNLayer(
|
||||
384,
|
||||
448, (1, 3),
|
||||
padding=(0, 1),
|
||||
act='relu',
|
||||
act="relu",
|
||||
name="inception_c" + name + "_1x3_2")
|
||||
conv4 = self.conv_bn_layer(
|
||||
conv4,
|
||||
self._conv4_000 = ConvBNLayer(
|
||||
448,
|
||||
512, (3, 1),
|
||||
padding=(1, 0),
|
||||
act='relu',
|
||||
act="relu",
|
||||
name="inception_c" + name + "_3x1_2")
|
||||
conv4_1 = self.conv_bn_layer(
|
||||
conv4,
|
||||
self._conv4_1 = ConvBNLayer(
|
||||
512,
|
||||
256, (1, 3),
|
||||
padding=(0, 1),
|
||||
act='relu',
|
||||
act="relu",
|
||||
name="inception_c" + name + "_1x3_3")
|
||||
conv4_2 = self.conv_bn_layer(
|
||||
conv4,
|
||||
self._conv4_2 = ConvBNLayer(
|
||||
512,
|
||||
256, (3, 1),
|
||||
padding=(1, 0),
|
||||
act='relu',
|
||||
act="relu",
|
||||
name="inception_c" + name + "_3x1_3")
|
||||
|
||||
def forward(self, inputs):
|
||||
pool1 = self._pool(inputs)
|
||||
conv1 = self._conv1(pool1)
|
||||
|
||||
conv2 = self._conv2(inputs)
|
||||
|
||||
conv3 = self._conv3_0(inputs)
|
||||
conv3_1 = self._conv3_1(conv3)
|
||||
conv3_2 = self._conv3_2(conv3)
|
||||
|
||||
conv4 = self._conv4_0(inputs)
|
||||
conv4 = self._conv4_00(conv4)
|
||||
conv4 = self._conv4_000(conv4)
|
||||
conv4_1 = self._conv4_1(conv4)
|
||||
conv4_2 = self._conv4_2(conv4)
|
||||
|
||||
concat = fluid.layers.concat(
|
||||
[conv1, conv2, conv3_1, conv3_2, conv4_1, conv4_2], axis=1)
|
||||
|
||||
return concat
|
||||
|
||||
|
||||
class InceptionV4DY(fluid.dygraph.Layer):
|
||||
def __init__(self, class_dim=1000):
|
||||
super(InceptionV4DY, self).__init__()
|
||||
self._inception_stem = InceptionStem()
|
||||
|
||||
self._inceptionA_1 = InceptionA(name="1")
|
||||
self._inceptionA_2 = InceptionA(name="2")
|
||||
self._inceptionA_3 = InceptionA(name="3")
|
||||
self._inceptionA_4 = InceptionA(name="4")
|
||||
self._reductionA = ReductionA()
|
||||
|
||||
self._inceptionB_1 = InceptionB(name="1")
|
||||
self._inceptionB_2 = InceptionB(name="2")
|
||||
self._inceptionB_3 = InceptionB(name="3")
|
||||
self._inceptionB_4 = InceptionB(name="4")
|
||||
self._inceptionB_5 = InceptionB(name="5")
|
||||
self._inceptionB_6 = InceptionB(name="6")
|
||||
self._inceptionB_7 = InceptionB(name="7")
|
||||
self._reductionB = ReductionB()
|
||||
|
||||
self._inceptionC_1 = InceptionC(name="1")
|
||||
self._inceptionC_2 = InceptionC(name="2")
|
||||
self._inceptionC_3 = InceptionC(name="3")
|
||||
|
||||
self.avg_pool = Pool2D(pool_type='avg', global_pooling=True)
|
||||
self._drop = Dropout(p=0.2)
|
||||
stdv = 1.0 / math.sqrt(1536 * 1.0)
|
||||
self.out = Linear(
|
||||
1536,
|
||||
class_dim,
|
||||
param_attr=ParamAttr(
|
||||
initializer=fluid.initializer.Uniform(-stdv, stdv),
|
||||
name="final_fc_weights"),
|
||||
bias_attr=ParamAttr(name="final_fc_offset"))
|
||||
|
||||
def forward(self, inputs):
|
||||
x = self._inception_stem(inputs)
|
||||
|
||||
x = self._inceptionA_1(x)
|
||||
x = self._inceptionA_2(x)
|
||||
x = self._inceptionA_3(x)
|
||||
x = self._inceptionA_4(x)
|
||||
x = self._reductionA(x)
|
||||
|
||||
x = self._inceptionB_1(x)
|
||||
x = self._inceptionB_2(x)
|
||||
x = self._inceptionB_3(x)
|
||||
x = self._inceptionB_4(x)
|
||||
x = self._inceptionB_5(x)
|
||||
x = self._inceptionB_6(x)
|
||||
x = self._inceptionB_7(x)
|
||||
x = self._reductionB(x)
|
||||
|
||||
x = self._inceptionC_1(x)
|
||||
x = self._inceptionC_2(x)
|
||||
x = self._inceptionC_3(x)
|
||||
|
||||
x = self.avg_pool(x)
|
||||
x = fluid.layers.squeeze(x, axes=[2, 3])
|
||||
x = self._drop(x)
|
||||
x = self.out(x)
|
||||
return x
|
||||
|
||||
|
||||
def InceptionV4(**args):
|
||||
model = InceptionV4DY(**args)
|
||||
return model
|
|
@ -1,182 +1,246 @@
|
|||
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
from paddle.fluid.param_attr import ParamAttr
|
||||
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
|
||||
|
||||
__all__ = [
|
||||
"ResNeXt101_32x8d_wsl", "ResNeXt101_32x16d_wsl", "ResNeXt101_32x32d_wsl",
|
||||
"ResNeXt101_32x48d_wsl", "Fix_ResNeXt101_32x48d_wsl"
|
||||
]
|
||||
__all__ = ["ResNeXt101_32x8d_wsl",
|
||||
"ResNeXt101_wsl_32x16d_wsl",
|
||||
"ResNeXt101_wsl_32x32d_wsl",
|
||||
"ResNeXt101_wsl_32x48d_wsl"]
|
||||
|
||||
class ConvBNLayer(fluid.dygraph.Layer):
|
||||
def __init__(self,
|
||||
input_channels,
|
||||
output_channels,
|
||||
filter_size,
|
||||
stride=1,
|
||||
groups=1,
|
||||
act=None,
|
||||
name=None):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
if "downsample" in name:
|
||||
conv_name = name + ".0"
|
||||
else:
|
||||
conv_name = name
|
||||
self._conv = Conv2D(num_channels=input_channels,
|
||||
num_filters=output_channels,
|
||||
filter_size=filter_size,
|
||||
stride=stride,
|
||||
padding=(filter_size-1)//2,
|
||||
groups=groups,
|
||||
act=None,
|
||||
param_attr=ParamAttr(name=conv_name + ".weight"),
|
||||
bias_attr=False)
|
||||
if "downsample" in name:
|
||||
bn_name = name[:9] + "downsample.1"
|
||||
else:
|
||||
if "conv1" == name:
|
||||
bn_name = "bn" + name[-1]
|
||||
else:
|
||||
bn_name = (name[:10] if name[7:9].isdigit() else name[:9]) + "bn" + name[-1]
|
||||
self._bn = BatchNorm(num_channels=output_channels,
|
||||
act=act,
|
||||
param_attr=ParamAttr(name=bn_name + ".weight"),
|
||||
bias_attr=ParamAttr(name=bn_name + ".bias"),
|
||||
moving_mean_name=bn_name + ".running_mean",
|
||||
moving_variance_name=bn_name + ".running_var")
|
||||
|
||||
def forward(self, inputs):
|
||||
x = self._conv(inputs)
|
||||
x = self._bn(x)
|
||||
return x
|
||||
|
||||
class ShortCut(fluid.dygraph.Layer):
|
||||
def __init__(self, input_channels, output_channels, stride, name=None):
|
||||
super(ShortCut, self).__init__()
|
||||
|
||||
self.input_channels = input_channels
|
||||
self.output_channels = output_channels
|
||||
self.stride = stride
|
||||
if input_channels!=output_channels or stride!=1:
|
||||
self._conv = ConvBNLayer(
|
||||
input_channels, output_channels, filter_size=1, stride=stride, name=name)
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.input_channels!= self.output_channels or self.stride!=1:
|
||||
return self._conv(inputs)
|
||||
return inputs
|
||||
|
||||
class BottleneckBlock(fluid.dygraph.Layer):
|
||||
def __init__(self, input_channels, output_channels, stride, cardinality, width, name):
|
||||
super(BottleneckBlock, self).__init__()
|
||||
|
||||
self._conv0 = ConvBNLayer(
|
||||
input_channels, output_channels, filter_size=1, act="relu", name=name + ".conv1")
|
||||
self._conv1 = ConvBNLayer(
|
||||
output_channels, output_channels, filter_size=3, act="relu", stride=stride, groups=cardinality, name=name + ".conv2")
|
||||
self._conv2 = ConvBNLayer(
|
||||
output_channels, output_channels//(width//8), filter_size=1, act=None, name=name + ".conv3")
|
||||
self._short = ShortCut(
|
||||
input_channels, output_channels//(width//8), stride=stride, name=name + ".downsample")
|
||||
|
||||
def forward(self, inputs):
|
||||
x = self._conv0(inputs)
|
||||
x = self._conv1(x)
|
||||
x = self._conv2(x)
|
||||
y = self._short(inputs)
|
||||
return fluid.layers.elementwise_add(x, y, act="relu")
|
||||
|
||||
class ResNeXt101WSL(fluid.dygraph.Layer):
|
||||
def __init__(self, layers=101, cardinality=32, width=48, class_dim=1000):
|
||||
super(ResNeXt101WSL, self).__init__()
|
||||
|
||||
self.class_dim = class_dim
|
||||
|
||||
class ResNeXt101_wsl():
|
||||
def __init__(self, layers=101, cardinality=32, width=48):
|
||||
self.layers = layers
|
||||
self.cardinality = cardinality
|
||||
self.width = width
|
||||
self.scale = width//8
|
||||
|
||||
def net(self, input, class_dim=1000):
|
||||
layers = self.layers
|
||||
cardinality = self.cardinality
|
||||
width = self.width
|
||||
self.depth = [3, 4, 23, 3]
|
||||
self.base_width = cardinality * width
|
||||
num_filters = [self.base_width*i for i in [1,2,4,8]] #[256, 512, 1024, 2048]
|
||||
self._conv_stem = ConvBNLayer(
|
||||
3, 64, 7, stride=2, act="relu", name="conv1")
|
||||
self._pool = Pool2D(pool_size=3,
|
||||
pool_stride=2,
|
||||
pool_padding=1,
|
||||
pool_type="max")
|
||||
|
||||
depth = [3, 4, 23, 3]
|
||||
base_width = cardinality * width
|
||||
num_filters = [base_width * i for i in [1, 2, 4, 8]]
|
||||
self._conv1_0 = BottleneckBlock(
|
||||
64, num_filters[0], stride=1, cardinality=self.cardinality, width=self.width, name="layer1.0")
|
||||
self._conv1_1 = BottleneckBlock(
|
||||
num_filters[0]//(width//8), num_filters[0], stride=1, cardinality=self.cardinality, width=self.width, name="layer1.1")
|
||||
self._conv1_2 = BottleneckBlock(
|
||||
num_filters[0]//(width//8), num_filters[0], stride=1, cardinality=self.cardinality, width=self.width, name="layer1.2")
|
||||
|
||||
conv = self.conv_bn_layer(
|
||||
input=input,
|
||||
num_filters=64,
|
||||
filter_size=7,
|
||||
stride=2,
|
||||
act='relu',
|
||||
name="conv1") #debug
|
||||
conv = fluid.layers.pool2d(
|
||||
input=conv,
|
||||
pool_size=3,
|
||||
pool_stride=2,
|
||||
pool_padding=1,
|
||||
pool_type='max')
|
||||
self._conv2_0 = BottleneckBlock(
|
||||
num_filters[0]//(width//8), num_filters[1], stride=2, cardinality=self.cardinality, width=self.width, name="layer2.0")
|
||||
self._conv2_1 = BottleneckBlock(
|
||||
num_filters[1]//(width//8), num_filters[1], stride=1, cardinality=self.cardinality, width=self.width, name="layer2.1")
|
||||
self._conv2_2 = BottleneckBlock(
|
||||
num_filters[1]//(width//8), num_filters[1], stride=1, cardinality=self.cardinality, width=self.width, name="layer2.2")
|
||||
self._conv2_3 = BottleneckBlock(
|
||||
num_filters[1]//(width//8), num_filters[1], stride=1, cardinality=self.cardinality, width=self.width, name="layer2.3")
|
||||
|
||||
for block in range(len(depth)):
|
||||
for i in range(depth[block]):
|
||||
conv_name = 'layer' + str(block + 1) + "." + str(i)
|
||||
conv = self.bottleneck_block(
|
||||
input=conv,
|
||||
num_filters=num_filters[block],
|
||||
stride=2 if i == 0 and block != 0 else 1,
|
||||
cardinality=cardinality,
|
||||
name=conv_name)
|
||||
self._conv3_0 = BottleneckBlock(
|
||||
num_filters[1]//(width//8), num_filters[2], stride=2, cardinality=self.cardinality, width=self.width, name="layer3.0")
|
||||
self._conv3_1 = BottleneckBlock(
|
||||
num_filters[2]//(width//8), num_filters[2], stride=1, cardinality=self.cardinality, width=self.width, name="layer3.1")
|
||||
self._conv3_2 = BottleneckBlock(
|
||||
num_filters[2]//(width//8), num_filters[2], stride=1, cardinality=self.cardinality, width=self.width, name="layer3.2")
|
||||
self._conv3_3 = BottleneckBlock(
|
||||
num_filters[2]//(width//8), num_filters[2], stride=1, cardinality=self.cardinality, width=self.width, name="layer3.3")
|
||||
self._conv3_4 = BottleneckBlock(
|
||||
num_filters[2]//(width//8), num_filters[2], stride=1, cardinality=self.cardinality, width=self.width, name="layer3.4")
|
||||
self._conv3_5 = BottleneckBlock(
|
||||
num_filters[2]//(width//8), num_filters[2], stride=1, cardinality=self.cardinality, width=self.width, name="layer3.5")
|
||||
self._conv3_6 = BottleneckBlock(
|
||||
num_filters[2]//(width//8), num_filters[2], stride=1, cardinality=self.cardinality, width=self.width, name="layer3.6")
|
||||
self._conv3_7 = BottleneckBlock(
|
||||
num_filters[2]//(width//8), num_filters[2], stride=1, cardinality=self.cardinality, width=self.width, name="layer3.7")
|
||||
self._conv3_8 = BottleneckBlock(
|
||||
num_filters[2]//(width//8), num_filters[2], stride=1, cardinality=self.cardinality, width=self.width, name="layer3.8")
|
||||
self._conv3_9 = BottleneckBlock(
|
||||
num_filters[2]//(width//8), num_filters[2], stride=1, cardinality=self.cardinality, width=self.width, name="layer3.9")
|
||||
self._conv3_10 = BottleneckBlock(
|
||||
num_filters[2]//(width//8), num_filters[2], stride=1, cardinality=self.cardinality, width=self.width, name="layer3.10")
|
||||
self._conv3_11 = BottleneckBlock(
|
||||
num_filters[2]//(width//8), num_filters[2], stride=1, cardinality=self.cardinality, width=self.width, name="layer3.11")
|
||||
self._conv3_12 = BottleneckBlock(
|
||||
num_filters[2]//(width//8), num_filters[2], stride=1, cardinality=self.cardinality, width=self.width, name="layer3.12")
|
||||
self._conv3_13 = BottleneckBlock(
|
||||
num_filters[2]//(width//8), num_filters[2], stride=1, cardinality=self.cardinality, width=self.width, name="layer3.13")
|
||||
self._conv3_14 = BottleneckBlock(
|
||||
num_filters[2]//(width//8), num_filters[2], stride=1, cardinality=self.cardinality, width=self.width, name="layer3.14")
|
||||
self._conv3_15 = BottleneckBlock(
|
||||
num_filters[2]//(width//8), num_filters[2], stride=1, cardinality=self.cardinality, width=self.width, name="layer3.15")
|
||||
self._conv3_16 = BottleneckBlock(
|
||||
num_filters[2]//(width//8), num_filters[2], stride=1, cardinality=self.cardinality, width=self.width, name="layer3.16")
|
||||
self._conv3_17 = BottleneckBlock(
|
||||
num_filters[2]//(width//8), num_filters[2], stride=1, cardinality=self.cardinality, width=self.width, name="layer3.17")
|
||||
self._conv3_18 = BottleneckBlock(
|
||||
num_filters[2]//(width//8), num_filters[2], stride=1, cardinality=self.cardinality, width=self.width, name="layer3.18")
|
||||
self._conv3_19 = BottleneckBlock(
|
||||
num_filters[2]//(width//8), num_filters[2], stride=1, cardinality=self.cardinality, width=self.width, name="layer3.19")
|
||||
self._conv3_20 = BottleneckBlock(
|
||||
num_filters[2]//(width//8), num_filters[2], stride=1, cardinality=self.cardinality, width=self.width, name="layer3.20")
|
||||
self._conv3_21 = BottleneckBlock(
|
||||
num_filters[2]//(width//8), num_filters[2], stride=1, cardinality=self.cardinality, width=self.width, name="layer3.21")
|
||||
self._conv3_22 = BottleneckBlock(
|
||||
num_filters[2]//(width//8), num_filters[2], stride=1, cardinality=self.cardinality, width=self.width, name="layer3.22")
|
||||
|
||||
pool = fluid.layers.pool2d(
|
||||
input=conv, pool_type='avg', global_pooling=True)
|
||||
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
|
||||
out = fluid.layers.fc(
|
||||
input=pool,
|
||||
size=class_dim,
|
||||
param_attr=fluid.param_attr.ParamAttr(
|
||||
initializer=fluid.initializer.Uniform(-stdv, stdv),
|
||||
name='fc.weight'),
|
||||
bias_attr=fluid.param_attr.ParamAttr(name='fc.bias'))
|
||||
return out
|
||||
self._conv4_0 = BottleneckBlock(
|
||||
num_filters[2]//(width//8), num_filters[3], stride=2, cardinality=self.cardinality, width=self.width, name="layer4.0")
|
||||
self._conv4_1 = BottleneckBlock(
|
||||
num_filters[3]//(width//8), num_filters[3], stride=1, cardinality=self.cardinality, width=self.width, name="layer4.1")
|
||||
self._conv4_2 = BottleneckBlock(
|
||||
num_filters[3]//(width//8), num_filters[3], stride=1, cardinality=self.cardinality, width=self.width, name="layer4.2")
|
||||
|
||||
def conv_bn_layer(self,
|
||||
input,
|
||||
num_filters,
|
||||
filter_size,
|
||||
stride=1,
|
||||
groups=1,
|
||||
act=None,
|
||||
name=None):
|
||||
if "downsample" in name:
|
||||
conv_name = name + '.0'
|
||||
else:
|
||||
conv_name = name
|
||||
conv = fluid.layers.conv2d(
|
||||
input=input,
|
||||
num_filters=num_filters,
|
||||
filter_size=filter_size,
|
||||
stride=stride,
|
||||
padding=(filter_size - 1) // 2,
|
||||
groups=groups,
|
||||
act=None,
|
||||
param_attr=ParamAttr(name=conv_name + ".weight"),
|
||||
bias_attr=False)
|
||||
if "downsample" in name:
|
||||
bn_name = name[:9] + 'downsample' + '.1'
|
||||
else:
|
||||
if "conv1" == name:
|
||||
bn_name = 'bn' + name[-1]
|
||||
else:
|
||||
bn_name = (name[:10] if name[7:9].isdigit() else name[:9]
|
||||
) + 'bn' + name[-1]
|
||||
return fluid.layers.batch_norm(
|
||||
input=conv,
|
||||
act=act,
|
||||
param_attr=ParamAttr(name=bn_name + '.weight'),
|
||||
bias_attr=ParamAttr(bn_name + '.bias'),
|
||||
moving_mean_name=bn_name + '.running_mean',
|
||||
moving_variance_name=bn_name + '.running_var', )
|
||||
self._avg_pool = Pool2D(pool_type="avg", global_pooling=True)
|
||||
self._out = Linear(input_dim=num_filters[3]//(width//8),
|
||||
output_dim=class_dim,
|
||||
param_attr=ParamAttr(name="fc.weight"),
|
||||
bias_attr=ParamAttr(name="fc.bias"))
|
||||
|
||||
def shortcut(self, input, ch_out, stride, name):
|
||||
ch_in = input.shape[1]
|
||||
if ch_in != ch_out or stride != 1:
|
||||
return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
|
||||
else:
|
||||
return input
|
||||
def forward(self, inputs):
|
||||
x = self._conv_stem(inputs)
|
||||
x = self._pool(x)
|
||||
|
||||
def bottleneck_block(self, input, num_filters, stride, cardinality, name):
|
||||
cardinality = self.cardinality
|
||||
width = self.width
|
||||
conv0 = self.conv_bn_layer(
|
||||
input=input,
|
||||
num_filters=num_filters,
|
||||
filter_size=1,
|
||||
act='relu',
|
||||
name=name + ".conv1")
|
||||
conv1 = self.conv_bn_layer(
|
||||
input=conv0,
|
||||
num_filters=num_filters,
|
||||
filter_size=3,
|
||||
stride=stride,
|
||||
groups=cardinality,
|
||||
act='relu',
|
||||
name=name + ".conv2")
|
||||
conv2 = self.conv_bn_layer(
|
||||
input=conv1,
|
||||
num_filters=num_filters // (width // 8),
|
||||
filter_size=1,
|
||||
act=None,
|
||||
name=name + ".conv3")
|
||||
x = self._conv1_0(x)
|
||||
x = self._conv1_1(x)
|
||||
x = self._conv1_2(x)
|
||||
|
||||
short = self.shortcut(
|
||||
input,
|
||||
num_filters // (width // 8),
|
||||
stride,
|
||||
name=name + ".downsample")
|
||||
x = self._conv2_0(x)
|
||||
x = self._conv2_1(x)
|
||||
x = self._conv2_2(x)
|
||||
x = self._conv2_3(x)
|
||||
|
||||
return fluid.layers.elementwise_add(x=short, y=conv2, act='relu')
|
||||
x = self._conv3_0(x)
|
||||
x = self._conv3_1(x)
|
||||
x = self._conv3_2(x)
|
||||
x = self._conv3_3(x)
|
||||
x = self._conv3_4(x)
|
||||
x = self._conv3_5(x)
|
||||
x = self._conv3_6(x)
|
||||
x = self._conv3_7(x)
|
||||
x = self._conv3_8(x)
|
||||
x = self._conv3_9(x)
|
||||
x = self._conv3_10(x)
|
||||
x = self._conv3_11(x)
|
||||
x = self._conv3_12(x)
|
||||
x = self._conv3_13(x)
|
||||
x = self._conv3_14(x)
|
||||
x = self._conv3_15(x)
|
||||
x = self._conv3_16(x)
|
||||
x = self._conv3_17(x)
|
||||
x = self._conv3_18(x)
|
||||
x = self._conv3_19(x)
|
||||
x = self._conv3_20(x)
|
||||
x = self._conv3_21(x)
|
||||
x = self._conv3_22(x)
|
||||
|
||||
x = self._conv4_0(x)
|
||||
x = self._conv4_1(x)
|
||||
x = self._conv4_2(x)
|
||||
|
||||
def ResNeXt101_32x8d_wsl():
|
||||
model = ResNeXt101_wsl(cardinality=32, width=8)
|
||||
x = self._avg_pool(x)
|
||||
x = fluid.layers.squeeze(x, axes=[2, 3])
|
||||
x = self._out(x)
|
||||
return x
|
||||
|
||||
def ResNeXt101_32x8d_wsl(**args):
|
||||
model = ResNeXt101WSL(cardinality=32, width=8, **args)
|
||||
return model
|
||||
|
||||
|
||||
def ResNeXt101_32x16d_wsl():
|
||||
model = ResNeXt101_wsl(cardinality=32, width=16)
|
||||
def ResNeXt101_32x16d_wsl(**args):
|
||||
model = ResNeXt101WSL(cardinality=32, width=16, **args)
|
||||
return model
|
||||
|
||||
|
||||
def ResNeXt101_32x32d_wsl():
|
||||
model = ResNeXt101_wsl(cardinality=32, width=32)
|
||||
def ResNeXt101_32x32d_wsl(**args):
|
||||
model = ResNeXt101WSL(cardinality=32, width=32, **args)
|
||||
return model
|
||||
|
||||
|
||||
def ResNeXt101_32x48d_wsl():
|
||||
model = ResNeXt101_wsl(cardinality=32, width=48)
|
||||
return model
|
||||
|
||||
|
||||
def Fix_ResNeXt101_32x48d_wsl():
|
||||
model = ResNeXt101_wsl(cardinality=32, width=48)
|
||||
def ResNeXt101_32x48d_wsl(**args):
|
||||
model = ResNeXt101WSL(cardinality=32, width=48, **args)
|
||||
return model
|
|
@ -1,133 +1,151 @@
|
|||
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
from paddle.fluid.param_attr import ParamAttr
|
||||
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear, Dropout
|
||||
|
||||
__all__ = ["SqueezeNet", "SqueezeNet1_0", "SqueezeNet1_1"]
|
||||
__all__ = ["SqueezeNet1_0", "SqueezeNet1_1"]
|
||||
|
||||
class MakeFireConv(fluid.dygraph.Layer):
|
||||
def __init__(self,
|
||||
input_channels,
|
||||
output_channels,
|
||||
filter_size,
|
||||
padding=0,
|
||||
name=None):
|
||||
super(MakeFireConv, self).__init__()
|
||||
self._conv = Conv2D(input_channels,
|
||||
output_channels,
|
||||
filter_size,
|
||||
padding=padding,
|
||||
act="relu",
|
||||
param_attr=ParamAttr(name=name + "_weights"),
|
||||
bias_attr=ParamAttr(name=name + "_offset"))
|
||||
|
||||
class SqueezeNet():
|
||||
def __init__(self, version='1.0'):
|
||||
def forward(self, inputs):
|
||||
return self._conv(inputs)
|
||||
|
||||
class MakeFire(fluid.dygraph.Layer):
|
||||
def __init__(self,
|
||||
input_channels,
|
||||
squeeze_channels,
|
||||
expand1x1_channels,
|
||||
expand3x3_channels,
|
||||
name=None):
|
||||
super(MakeFire, self).__init__()
|
||||
self._conv = MakeFireConv(input_channels,
|
||||
squeeze_channels,
|
||||
1,
|
||||
name=name + "_squeeze1x1")
|
||||
self._conv_path1 = MakeFireConv(squeeze_channels,
|
||||
expand1x1_channels,
|
||||
1,
|
||||
name=name + "_expand1x1")
|
||||
self._conv_path2 = MakeFireConv(squeeze_channels,
|
||||
expand3x3_channels,
|
||||
3,
|
||||
padding=1,
|
||||
name=name + "_expand3x3")
|
||||
|
||||
def forward(self, inputs):
|
||||
x = self._conv(inputs)
|
||||
x1 = self._conv_path1(x)
|
||||
x2 = self._conv_path2(x)
|
||||
return fluid.layers.concat([x1, x2], axis=1)
|
||||
|
||||
class SqueezeNet(fluid.dygraph.Layer):
|
||||
def __init__(self, version, class_dim=1000):
|
||||
super(SqueezeNet, self).__init__()
|
||||
self.version = version
|
||||
|
||||
def net(self, input, class_dim=1000):
|
||||
version = self.version
|
||||
assert version in ['1.0', '1.1'], \
|
||||
"supported version are {} but input version is {}".format(['1.0', '1.1'], version)
|
||||
if version == '1.0':
|
||||
conv = fluid.layers.conv2d(
|
||||
input,
|
||||
num_filters=96,
|
||||
filter_size=7,
|
||||
stride=2,
|
||||
act='relu',
|
||||
param_attr=fluid.param_attr.ParamAttr(name="conv1_weights"),
|
||||
bias_attr=ParamAttr(name='conv1_offset'))
|
||||
conv = fluid.layers.pool2d(
|
||||
conv, pool_size=3, pool_stride=2, pool_type='max')
|
||||
conv = self.make_fire(conv, 16, 64, 64, name='fire2')
|
||||
conv = self.make_fire(conv, 16, 64, 64, name='fire3')
|
||||
conv = self.make_fire(conv, 32, 128, 128, name='fire4')
|
||||
conv = fluid.layers.pool2d(
|
||||
conv, pool_size=3, pool_stride=2, pool_type='max')
|
||||
conv = self.make_fire(conv, 32, 128, 128, name='fire5')
|
||||
conv = self.make_fire(conv, 48, 192, 192, name='fire6')
|
||||
conv = self.make_fire(conv, 48, 192, 192, name='fire7')
|
||||
conv = self.make_fire(conv, 64, 256, 256, name='fire8')
|
||||
conv = fluid.layers.pool2d(
|
||||
conv, pool_size=3, pool_stride=2, pool_type='max')
|
||||
conv = self.make_fire(conv, 64, 256, 256, name='fire9')
|
||||
if self.version == "1.0":
|
||||
self._conv = Conv2D(3,
|
||||
96,
|
||||
7,
|
||||
stride=2,
|
||||
act="relu",
|
||||
param_attr=ParamAttr(name="conv1_weights"),
|
||||
bias_attr=ParamAttr(name="conv1_offset"))
|
||||
self._pool = Pool2D(pool_size=3,
|
||||
pool_stride=2,
|
||||
pool_type="max")
|
||||
self._conv1 = MakeFire(96, 16, 64, 64, name="fire2")
|
||||
self._conv2 = MakeFire(128, 16, 64, 64, name="fire3")
|
||||
self._conv3 = MakeFire(128, 32, 128, 128, name="fire4")
|
||||
|
||||
self._conv4 = MakeFire(256, 32, 128, 128, name="fire5")
|
||||
self._conv5 = MakeFire(256, 48, 192, 192, name="fire6")
|
||||
self._conv6 = MakeFire(384, 48, 192, 192, name="fire7")
|
||||
self._conv7 = MakeFire(384, 64, 256, 256, name="fire8")
|
||||
|
||||
self._conv8 = MakeFire(512, 64, 256, 256, name="fire9")
|
||||
else:
|
||||
conv = fluid.layers.conv2d(
|
||||
input,
|
||||
num_filters=64,
|
||||
filter_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act='relu',
|
||||
param_attr=fluid.param_attr.ParamAttr(name="conv1_weights"),
|
||||
bias_attr=ParamAttr(name='conv1_offset'))
|
||||
conv = fluid.layers.pool2d(
|
||||
conv, pool_size=3, pool_stride=2, pool_type='max')
|
||||
conv = self.make_fire(conv, 16, 64, 64, name='fire2')
|
||||
conv = self.make_fire(conv, 16, 64, 64, name='fire3')
|
||||
conv = fluid.layers.pool2d(
|
||||
conv, pool_size=3, pool_stride=2, pool_type='max')
|
||||
conv = self.make_fire(conv, 32, 128, 128, name='fire4')
|
||||
conv = self.make_fire(conv, 32, 128, 128, name='fire5')
|
||||
conv = fluid.layers.pool2d(
|
||||
conv, pool_size=3, pool_stride=2, pool_type='max')
|
||||
conv = self.make_fire(conv, 48, 192, 192, name='fire6')
|
||||
conv = self.make_fire(conv, 48, 192, 192, name='fire7')
|
||||
conv = self.make_fire(conv, 64, 256, 256, name='fire8')
|
||||
conv = self.make_fire(conv, 64, 256, 256, name='fire9')
|
||||
conv = fluid.layers.dropout(conv, dropout_prob=0.5)
|
||||
conv = fluid.layers.conv2d(
|
||||
conv,
|
||||
num_filters=class_dim,
|
||||
filter_size=1,
|
||||
act='relu',
|
||||
param_attr=fluid.param_attr.ParamAttr(name="conv10_weights"),
|
||||
bias_attr=ParamAttr(name='conv10_offset'))
|
||||
conv = fluid.layers.pool2d(conv, pool_type='avg', global_pooling=True)
|
||||
out = fluid.layers.flatten(conv)
|
||||
return out
|
||||
self._conv = Conv2D(3,
|
||||
64,
|
||||
3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act="relu",
|
||||
param_attr=ParamAttr(name="conv1_weights"),
|
||||
bias_attr=ParamAttr(name="conv1_offset"))
|
||||
self._pool = Pool2D(pool_size=3,
|
||||
pool_stride=2,
|
||||
pool_type="max")
|
||||
self._conv1 = MakeFire(64, 16, 64, 64, name="fire2")
|
||||
self._conv2 = MakeFire(128, 16, 64, 64, name="fire3")
|
||||
|
||||
def make_fire_conv(self,
|
||||
input,
|
||||
num_filters,
|
||||
filter_size,
|
||||
padding=0,
|
||||
name=None):
|
||||
conv = fluid.layers.conv2d(
|
||||
input,
|
||||
num_filters=num_filters,
|
||||
filter_size=filter_size,
|
||||
padding=padding,
|
||||
act='relu',
|
||||
param_attr=fluid.param_attr.ParamAttr(name=name + "_weights"),
|
||||
bias_attr=ParamAttr(name=name + '_offset'))
|
||||
return conv
|
||||
self._conv3 = MakeFire(128, 32, 128, 128, name="fire4")
|
||||
self._conv4 = MakeFire(256, 32, 128, 128, name="fire5")
|
||||
|
||||
def make_fire(self,
|
||||
input,
|
||||
squeeze_channels,
|
||||
expand1x1_channels,
|
||||
expand3x3_channels,
|
||||
name=None):
|
||||
conv = self.make_fire_conv(
|
||||
input, squeeze_channels, 1, name=name + '_squeeze1x1')
|
||||
conv_path1 = self.make_fire_conv(
|
||||
conv, expand1x1_channels, 1, name=name + '_expand1x1')
|
||||
conv_path2 = self.make_fire_conv(
|
||||
conv, expand3x3_channels, 3, 1, name=name + '_expand3x3')
|
||||
out = fluid.layers.concat([conv_path1, conv_path2], axis=1)
|
||||
return out
|
||||
self._conv5 = MakeFire(256, 48, 192, 192, name="fire6")
|
||||
self._conv6 = MakeFire(384, 48, 192, 192, name="fire7")
|
||||
self._conv7 = MakeFire(384, 64, 256, 256, name="fire8")
|
||||
self._conv8 = MakeFire(512, 64, 256, 256, name="fire9")
|
||||
|
||||
self._drop = Dropout(p=0.5)
|
||||
self._conv9 = Conv2D(512,
|
||||
class_dim,
|
||||
1,
|
||||
act="relu",
|
||||
param_attr=ParamAttr(name="conv10_weights"),
|
||||
bias_attr=ParamAttr(name="conv10_offset"))
|
||||
self._avg_pool = Pool2D(pool_type="avg",
|
||||
global_pooling=True)
|
||||
|
||||
def SqueezeNet1_0():
|
||||
model = SqueezeNet(version='1.0')
|
||||
def forward(self, inputs):
|
||||
x = self._conv(inputs)
|
||||
x = self._pool(x)
|
||||
if self.version=="1.0":
|
||||
x = self._conv1(x)
|
||||
x = self._conv2(x)
|
||||
x = self._conv3(x)
|
||||
x = self._pool(x)
|
||||
x = self._conv4(x)
|
||||
x = self._conv5(x)
|
||||
x = self._conv6(x)
|
||||
x = self._conv7(x)
|
||||
x = self._pool(x)
|
||||
x = self._conv8(x)
|
||||
else:
|
||||
x = self._conv1(x)
|
||||
x = self._conv2(x)
|
||||
x = self._pool(x)
|
||||
x = self._conv3(x)
|
||||
x = self._conv4(x)
|
||||
x = self._pool(x)
|
||||
x = self._conv5(x)
|
||||
x = self._conv6(x)
|
||||
x = self._conv7(x)
|
||||
x = self._conv8(x)
|
||||
x = self._drop(x)
|
||||
x = self._conv9(x)
|
||||
x = self._avg_pool(x)
|
||||
x = fluid.layers.squeeze(x, axes=[2,3])
|
||||
return x
|
||||
|
||||
def SqueezeNet1_0(**args):
|
||||
model = SqueezeNet(version="1.0", **args)
|
||||
return model
|
||||
|
||||
|
||||
def SqueezeNet1_1():
|
||||
model = SqueezeNet(version='1.1')
|
||||
def SqueezeNet1_1(**args):
|
||||
model = SqueezeNet(version="1.1", **args)
|
||||
return model
|
|
@ -1,108 +1,131 @@
|
|||
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
from paddle.fluid.param_attr import ParamAttr
|
||||
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
|
||||
|
||||
__all__ = ["VGGNet", "VGG11", "VGG13", "VGG16", "VGG19"]
|
||||
__all__ = ["VGG11", "VGG13", "VGG16", "VGG19"]
|
||||
|
||||
class ConvBlock(fluid.dygraph.Layer):
|
||||
def __init__(self,
|
||||
input_channels,
|
||||
output_channels,
|
||||
groups,
|
||||
name=None):
|
||||
super(ConvBlock, self).__init__()
|
||||
|
||||
self.groups = groups
|
||||
self._conv_1 = Conv2D(num_channels=input_channels,
|
||||
num_filters=output_channels,
|
||||
filter_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
act="relu",
|
||||
param_attr=ParamAttr(name=name + "1_weights"),
|
||||
bias_attr=False)
|
||||
if groups == 2 or groups == 3 or groups == 4:
|
||||
self._conv_2 = Conv2D(num_channels=output_channels,
|
||||
num_filters=output_channels,
|
||||
filter_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
act="relu",
|
||||
param_attr=ParamAttr(name=name + "2_weights"),
|
||||
bias_attr=False)
|
||||
if groups == 3 or groups == 4:
|
||||
self._conv_3 = Conv2D(num_channels=output_channels,
|
||||
num_filters=output_channels,
|
||||
filter_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
act="relu",
|
||||
param_attr=ParamAttr(name=name + "3_weights"),
|
||||
bias_attr=False)
|
||||
if groups == 4:
|
||||
self._conv_4 = Conv2D(num_channels=output_channels,
|
||||
num_filters=output_channels,
|
||||
filter_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
act="relu",
|
||||
param_attr=ParamAttr(name=name + "4_weights"),
|
||||
bias_attr=False)
|
||||
self._pool = Pool2D(pool_size=2,
|
||||
pool_type="max",
|
||||
pool_stride=2)
|
||||
|
||||
def forward(self, inputs):
|
||||
x = self._conv_1(inputs)
|
||||
if self.groups == 2 or self.groups == 3 or self.groups == 4:
|
||||
x = self._conv_2(x)
|
||||
if self.groups == 3 or self.groups == 4 :
|
||||
x = self._conv_3(x)
|
||||
if self.groups == 4:
|
||||
x = self._conv_4(x)
|
||||
x = self._pool(x)
|
||||
return x
|
||||
|
||||
class VGGNet(fluid.dygraph.Layer):
|
||||
def __init__(self, layers=11, class_dim=1000):
|
||||
super(VGGNet, self).__init__()
|
||||
|
||||
class VGGNet():
|
||||
def __init__(self, layers=16):
|
||||
self.layers = layers
|
||||
self.vgg_configure = {11: [1, 1, 2, 2, 2],
|
||||
13: [2, 2, 2, 2, 2],
|
||||
16: [2, 2, 3, 3, 3],
|
||||
19: [2, 2, 4, 4, 4]}
|
||||
assert self.layers in self.vgg_configure.keys(), \
|
||||
"supported layers are {} but input layer is {}".format(vgg_configure.keys(), layers)
|
||||
self.groups = self.vgg_configure[self.layers]
|
||||
|
||||
def net(self, input, class_dim=1000):
|
||||
layers = self.layers
|
||||
vgg_spec = {
|
||||
11: ([1, 1, 2, 2, 2]),
|
||||
13: ([2, 2, 2, 2, 2]),
|
||||
16: ([2, 2, 3, 3, 3]),
|
||||
19: ([2, 2, 4, 4, 4])
|
||||
}
|
||||
assert layers in vgg_spec.keys(), \
|
||||
"supported layers are {} but input layer is {}".format(vgg_spec.keys(), layers)
|
||||
self._conv_block_1 = ConvBlock(3, 64, self.groups[0], name="conv1_")
|
||||
self._conv_block_2 = ConvBlock(64, 128, self.groups[1], name="conv2_")
|
||||
self._conv_block_3 = ConvBlock(128, 256, self.groups[2], name="conv3_")
|
||||
self._conv_block_4 = ConvBlock(256, 512, self.groups[3], name="conv4_")
|
||||
self._conv_block_5 = ConvBlock(512, 512, self.groups[4], name="conv5_")
|
||||
|
||||
nums = vgg_spec[layers]
|
||||
conv1 = self.conv_block(input, 64, nums[0], name="conv1_")
|
||||
conv2 = self.conv_block(conv1, 128, nums[1], name="conv2_")
|
||||
conv3 = self.conv_block(conv2, 256, nums[2], name="conv3_")
|
||||
conv4 = self.conv_block(conv3, 512, nums[3], name="conv4_")
|
||||
conv5 = self.conv_block(conv4, 512, nums[4], name="conv5_")
|
||||
self._drop = fluid.dygraph.Dropout(p=0.5)
|
||||
self._fc1 = Linear(input_dim=7*7*512,
|
||||
output_dim=4096,
|
||||
act="relu",
|
||||
param_attr=ParamAttr(name="fc6_weights"),
|
||||
bias_attr=ParamAttr(name="fc6_offset"))
|
||||
self._fc2 = Linear(input_dim=4096,
|
||||
output_dim=4096,
|
||||
act="relu",
|
||||
param_attr=ParamAttr(name="fc7_weights"),
|
||||
bias_attr=ParamAttr(name="fc7_offset"))
|
||||
self._out = Linear(input_dim=4096,
|
||||
output_dim=class_dim,
|
||||
param_attr=ParamAttr(name="fc8_weights"),
|
||||
bias_attr=ParamAttr(name="fc8_offset"))
|
||||
|
||||
fc_dim = 4096
|
||||
fc_name = ["fc6", "fc7", "fc8"]
|
||||
fc1 = fluid.layers.fc(
|
||||
input=conv5,
|
||||
size=fc_dim,
|
||||
act='relu',
|
||||
param_attr=fluid.param_attr.ParamAttr(
|
||||
name=fc_name[0] + "_weights"),
|
||||
bias_attr=fluid.param_attr.ParamAttr(name=fc_name[0] + "_offset"))
|
||||
fc1 = fluid.layers.dropout(x=fc1, dropout_prob=0.5)
|
||||
fc2 = fluid.layers.fc(
|
||||
input=fc1,
|
||||
size=fc_dim,
|
||||
act='relu',
|
||||
param_attr=fluid.param_attr.ParamAttr(
|
||||
name=fc_name[1] + "_weights"),
|
||||
bias_attr=fluid.param_attr.ParamAttr(name=fc_name[1] + "_offset"))
|
||||
fc2 = fluid.layers.dropout(x=fc2, dropout_prob=0.5)
|
||||
out = fluid.layers.fc(
|
||||
input=fc2,
|
||||
size=class_dim,
|
||||
param_attr=fluid.param_attr.ParamAttr(
|
||||
name=fc_name[2] + "_weights"),
|
||||
bias_attr=fluid.param_attr.ParamAttr(name=fc_name[2] + "_offset"))
|
||||
def forward(self, inputs):
|
||||
x = self._conv_block_1(inputs)
|
||||
x = self._conv_block_2(x)
|
||||
x = self._conv_block_3(x)
|
||||
x = self._conv_block_4(x)
|
||||
x = self._conv_block_5(x)
|
||||
|
||||
return out
|
||||
x = fluid.layers.flatten(x, axis=0)
|
||||
x = self._fc1(x)
|
||||
x = self._drop(x)
|
||||
x = self._fc2(x)
|
||||
x = self._drop(x)
|
||||
x = self._out(x)
|
||||
return x
|
||||
|
||||
def conv_block(self, input, num_filter, groups, name=None):
|
||||
conv = input
|
||||
for i in range(groups):
|
||||
conv = fluid.layers.conv2d(
|
||||
input=conv,
|
||||
num_filters=num_filter,
|
||||
filter_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
act='relu',
|
||||
param_attr=fluid.param_attr.ParamAttr(
|
||||
name=name + str(i + 1) + "_weights"),
|
||||
bias_attr=False)
|
||||
return fluid.layers.pool2d(
|
||||
input=conv, pool_size=2, pool_type='max', pool_stride=2)
|
||||
|
||||
|
||||
def VGG11():
|
||||
model = VGGNet(layers=11)
|
||||
def VGG11(**args):
|
||||
model = VGGNet(layers=11, **args)
|
||||
return model
|
||||
|
||||
|
||||
def VGG13():
|
||||
model = VGGNet(layers=13)
|
||||
def VGG13(**args):
|
||||
model = VGGNet(layers=13, **args)
|
||||
return model
|
||||
|
||||
|
||||
def VGG16():
|
||||
model = VGGNet(layers=16)
|
||||
def VGG16(**args):
|
||||
model = VGGNet(layers=16, **args)
|
||||
return model
|
||||
|
||||
|
||||
def VGG19():
|
||||
model = VGGNet(layers=19)
|
||||
def VGG19(**args):
|
||||
model = VGGNet(layers=19, **args)
|
||||
return model
|
|
@ -1,250 +1,26 @@
|
|||
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
import sys
|
||||
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
from paddle.fluid.param_attr import ParamAttr
|
||||
from paddle.fluid.layer_helper import LayerHelper
|
||||
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
|
||||
import math
|
||||
|
||||
__all__ = ['Xception', 'Xception41', 'Xception65', 'Xception71']
|
||||
__all__ = ['Xception41', 'Xception65', 'Xception71']
|
||||
|
||||
|
||||
class Xception(object):
|
||||
"""Xception"""
|
||||
class ConvBNLayer(fluid.dygraph.Layer):
|
||||
def __init__(self,
|
||||
num_channels,
|
||||
num_filters,
|
||||
filter_size,
|
||||
stride=1,
|
||||
groups=1,
|
||||
act=None,
|
||||
name=None):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
|
||||
def __init__(self, entry_flow_block_num=3, middle_flow_block_num=8):
|
||||
self.entry_flow_block_num = entry_flow_block_num
|
||||
self.middle_flow_block_num = middle_flow_block_num
|
||||
return
|
||||
|
||||
def net(self, input, class_dim=1000):
|
||||
conv = self.entry_flow(input, self.entry_flow_block_num)
|
||||
conv = self.middle_flow(conv, self.middle_flow_block_num)
|
||||
conv = self.exit_flow(conv, class_dim)
|
||||
|
||||
return conv
|
||||
|
||||
def entry_flow(self, input, block_num=3):
|
||||
'''xception entry_flow'''
|
||||
name = "entry_flow"
|
||||
conv = self.conv_bn_layer(
|
||||
input=input,
|
||||
num_filters=32,
|
||||
filter_size=3,
|
||||
stride=2,
|
||||
act='relu',
|
||||
name=name + "_conv1")
|
||||
conv = self.conv_bn_layer(
|
||||
input=conv,
|
||||
num_filters=64,
|
||||
filter_size=3,
|
||||
stride=1,
|
||||
act='relu',
|
||||
name=name + "_conv2")
|
||||
|
||||
if block_num == 3:
|
||||
relu_first = [False, True, True]
|
||||
num_filters = [128, 256, 728]
|
||||
stride = [2, 2, 2]
|
||||
elif block_num == 5:
|
||||
relu_first = [False, True, True, True, True]
|
||||
num_filters = [128, 256, 256, 728, 728]
|
||||
stride = [2, 1, 2, 1, 2]
|
||||
else:
|
||||
sys.exit(-1)
|
||||
|
||||
for block in range(block_num):
|
||||
curr_name = "{}_{}".format(name, block)
|
||||
conv = self.entry_flow_bottleneck_block(
|
||||
conv,
|
||||
num_filters=num_filters[block],
|
||||
name=curr_name,
|
||||
stride=stride[block],
|
||||
relu_first=relu_first[block])
|
||||
|
||||
return conv
|
||||
|
||||
def entry_flow_bottleneck_block(self,
|
||||
input,
|
||||
num_filters,
|
||||
name,
|
||||
stride=2,
|
||||
relu_first=False):
|
||||
'''entry_flow_bottleneck_block'''
|
||||
short = fluid.layers.conv2d(
|
||||
input=input,
|
||||
num_filters=num_filters,
|
||||
filter_size=1,
|
||||
stride=stride,
|
||||
padding=0,
|
||||
act=None,
|
||||
param_attr=ParamAttr(name + "_branch1_weights"),
|
||||
bias_attr=False)
|
||||
|
||||
conv0 = input
|
||||
if relu_first:
|
||||
conv0 = fluid.layers.relu(conv0)
|
||||
|
||||
conv1 = self.separable_conv(
|
||||
conv0, num_filters, stride=1, name=name + "_branch2a_weights")
|
||||
|
||||
conv2 = fluid.layers.relu(conv1)
|
||||
conv2 = self.separable_conv(
|
||||
conv2, num_filters, stride=1, name=name + "_branch2b_weights")
|
||||
|
||||
pool = fluid.layers.pool2d(
|
||||
input=conv2,
|
||||
pool_size=3,
|
||||
pool_stride=stride,
|
||||
pool_padding=1,
|
||||
pool_type='max')
|
||||
|
||||
return fluid.layers.elementwise_add(x=short, y=pool)
|
||||
|
||||
def middle_flow(self, input, block_num=8):
|
||||
'''xception middle_flow'''
|
||||
num_filters = 728
|
||||
conv = input
|
||||
for block in range(block_num):
|
||||
name = "middle_flow_{}".format(block)
|
||||
conv = self.middle_flow_bottleneck_block(conv, num_filters, name)
|
||||
|
||||
return conv
|
||||
|
||||
def middle_flow_bottleneck_block(self, input, num_filters, name):
|
||||
'''middle_flow_bottleneck_block'''
|
||||
conv0 = fluid.layers.relu(input)
|
||||
conv0 = self.separable_conv(
|
||||
conv0,
|
||||
num_filters=num_filters,
|
||||
stride=1,
|
||||
name=name + "_branch2a_weights")
|
||||
|
||||
conv1 = fluid.layers.relu(conv0)
|
||||
conv1 = self.separable_conv(
|
||||
conv1,
|
||||
num_filters=num_filters,
|
||||
stride=1,
|
||||
name=name + "_branch2b_weights")
|
||||
|
||||
conv2 = fluid.layers.relu(conv1)
|
||||
conv2 = self.separable_conv(
|
||||
conv2,
|
||||
num_filters=num_filters,
|
||||
stride=1,
|
||||
name=name + "_branch2c_weights")
|
||||
|
||||
return fluid.layers.elementwise_add(x=input, y=conv2)
|
||||
|
||||
def exit_flow(self, input, class_dim):
|
||||
'''xception exit flow'''
|
||||
name = "exit_flow"
|
||||
num_filters1 = 728
|
||||
num_filters2 = 1024
|
||||
conv0 = self.exit_flow_bottleneck_block(
|
||||
input, num_filters1, num_filters2, name=name + "_1")
|
||||
|
||||
conv1 = self.separable_conv(
|
||||
conv0, num_filters=1536, stride=1, name=name + "_2")
|
||||
conv1 = fluid.layers.relu(conv1)
|
||||
|
||||
conv2 = self.separable_conv(
|
||||
conv1, num_filters=2048, stride=1, name=name + "_3")
|
||||
conv2 = fluid.layers.relu(conv2)
|
||||
|
||||
pool = fluid.layers.pool2d(
|
||||
input=conv2, pool_type='avg', global_pooling=True)
|
||||
|
||||
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
|
||||
out = fluid.layers.fc(
|
||||
input=pool,
|
||||
size=class_dim,
|
||||
param_attr=fluid.param_attr.ParamAttr(
|
||||
name='fc_weights',
|
||||
initializer=fluid.initializer.Uniform(-stdv, stdv)),
|
||||
bias_attr=fluid.param_attr.ParamAttr(name='fc_offset'))
|
||||
|
||||
return out
|
||||
|
||||
def exit_flow_bottleneck_block(self, input, num_filters1, num_filters2,
|
||||
name):
|
||||
'''entry_flow_bottleneck_block'''
|
||||
short = fluid.layers.conv2d(
|
||||
input=input,
|
||||
num_filters=num_filters2,
|
||||
filter_size=1,
|
||||
stride=2,
|
||||
padding=0,
|
||||
act=None,
|
||||
param_attr=ParamAttr(name + "_branch1_weights"),
|
||||
bias_attr=False)
|
||||
|
||||
conv0 = fluid.layers.relu(input)
|
||||
conv1 = self.separable_conv(
|
||||
conv0, num_filters1, stride=1, name=name + "_branch2a_weights")
|
||||
|
||||
conv2 = fluid.layers.relu(conv1)
|
||||
conv2 = self.separable_conv(
|
||||
conv2, num_filters2, stride=1, name=name + "_branch2b_weights")
|
||||
|
||||
pool = fluid.layers.pool2d(
|
||||
input=conv2,
|
||||
pool_size=3,
|
||||
pool_stride=2,
|
||||
pool_padding=1,
|
||||
pool_type='max')
|
||||
|
||||
return fluid.layers.elementwise_add(x=short, y=pool)
|
||||
|
||||
def separable_conv(self, input, num_filters, stride=1, name=None):
|
||||
"""separable_conv"""
|
||||
pointwise_conv = self.conv_bn_layer(
|
||||
input=input,
|
||||
filter_size=1,
|
||||
num_filters=num_filters,
|
||||
stride=1,
|
||||
name=name + "_sep")
|
||||
|
||||
depthwise_conv = self.conv_bn_layer(
|
||||
input=pointwise_conv,
|
||||
filter_size=3,
|
||||
num_filters=num_filters,
|
||||
stride=stride,
|
||||
groups=num_filters,
|
||||
use_cudnn=False,
|
||||
name=name + "_dw")
|
||||
|
||||
return depthwise_conv
|
||||
|
||||
def conv_bn_layer(self,
|
||||
input,
|
||||
num_filters,
|
||||
filter_size,
|
||||
stride=1,
|
||||
groups=1,
|
||||
act=None,
|
||||
use_cudnn=True,
|
||||
name=None):
|
||||
"""conv_bn_layer"""
|
||||
conv = fluid.layers.conv2d(
|
||||
input=input,
|
||||
self._conv = Conv2D(
|
||||
num_channels=num_channels,
|
||||
num_filters=num_filters,
|
||||
filter_size=filter_size,
|
||||
stride=stride,
|
||||
|
@ -252,30 +28,325 @@ class Xception(object):
|
|||
groups=groups,
|
||||
act=None,
|
||||
param_attr=ParamAttr(name=name + "_weights"),
|
||||
bias_attr=False,
|
||||
use_cudnn=use_cudnn)
|
||||
|
||||
bias_attr=False)
|
||||
bn_name = "bn_" + name
|
||||
|
||||
return fluid.layers.batch_norm(
|
||||
input=conv,
|
||||
self._batch_norm = BatchNorm(
|
||||
num_filters,
|
||||
act=act,
|
||||
param_attr=ParamAttr(name=bn_name + '_scale'),
|
||||
bias_attr=ParamAttr(bn_name + '_offset'),
|
||||
param_attr=ParamAttr(name=bn_name + "_scale"),
|
||||
bias_attr=ParamAttr(name=bn_name + "_offset"),
|
||||
moving_mean_name=bn_name + '_mean',
|
||||
moving_variance_name=bn_name + '_variance')
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self._conv(inputs)
|
||||
y = self._batch_norm(y)
|
||||
return y
|
||||
|
||||
def Xception41():
|
||||
model = Xception(entry_flow_block_num=3, middle_flow_block_num=8)
|
||||
|
||||
class SeparableConv(fluid.dygraph.Layer):
|
||||
def __init__(self, input_channels, output_channels, stride=1, name=None):
|
||||
super(SeparableConv, self).__init__()
|
||||
|
||||
self._pointwise_conv = ConvBNLayer(
|
||||
input_channels, output_channels, 1, name=name + "_sep")
|
||||
self._depthwise_conv = ConvBNLayer(
|
||||
output_channels,
|
||||
output_channels,
|
||||
3,
|
||||
stride=stride,
|
||||
groups=output_channels,
|
||||
name=name + "_dw")
|
||||
|
||||
def forward(self, inputs):
|
||||
x = self._pointwise_conv(inputs)
|
||||
x = self._depthwise_conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class EntryFlowBottleneckBlock(fluid.dygraph.Layer):
|
||||
def __init__(self,
|
||||
input_channels,
|
||||
output_channels,
|
||||
stride=2,
|
||||
name=None,
|
||||
relu_first=False):
|
||||
super(EntryFlowBottleneckBlock, self).__init__()
|
||||
self.relu_first = relu_first
|
||||
|
||||
self._short = Conv2D(
|
||||
num_channels=input_channels,
|
||||
num_filters=output_channels,
|
||||
filter_size=1,
|
||||
stride=stride,
|
||||
padding=0,
|
||||
act=None,
|
||||
param_attr=ParamAttr(name + "_branch1_weights"),
|
||||
bias_attr=False)
|
||||
self._conv1 = SeparableConv(
|
||||
input_channels,
|
||||
output_channels,
|
||||
stride=1,
|
||||
name=name + "_branch2a_weights")
|
||||
self._conv2 = SeparableConv(
|
||||
output_channels,
|
||||
output_channels,
|
||||
stride=1,
|
||||
name=name + "_branch2b_weights")
|
||||
self._pool = Pool2D(
|
||||
pool_size=3, pool_stride=stride, pool_padding=1, pool_type="max")
|
||||
|
||||
def forward(self, inputs):
|
||||
conv0 = inputs
|
||||
short = self._short(inputs)
|
||||
layer_helper = LayerHelper(self.full_name(), act="relu")
|
||||
if self.relu_first:
|
||||
conv0 = layer_helper.append_activation(conv0)
|
||||
conv1 = self._conv1(conv0)
|
||||
conv2 = layer_helper.append_activation(conv1)
|
||||
conv2 = self._conv2(conv2)
|
||||
pool = self._pool(conv2)
|
||||
return fluid.layers.elementwise_add(x=short, y=pool)
|
||||
|
||||
|
||||
class EntryFlow(fluid.dygraph.Layer):
|
||||
def __init__(self, block_num=3):
|
||||
super(EntryFlow, self).__init__()
|
||||
|
||||
name = "entry_flow"
|
||||
self.block_num = block_num
|
||||
self._conv1 = ConvBNLayer(
|
||||
3, 32, 3, stride=2, act="relu", name=name + "_conv1")
|
||||
self._conv2 = ConvBNLayer(32, 64, 3, act="relu", name=name + "_conv2")
|
||||
if block_num == 3:
|
||||
self._conv_0 = EntryFlowBottleneckBlock(
|
||||
64, 128, stride=2, name=name + "_0", relu_first=False)
|
||||
self._conv_1 = EntryFlowBottleneckBlock(
|
||||
128, 256, stride=2, name=name + "_1", relu_first=True)
|
||||
self._conv_2 = EntryFlowBottleneckBlock(
|
||||
256, 728, stride=2, name=name + "_2", relu_first=True)
|
||||
elif block_num == 5:
|
||||
self._conv_0 = EntryFlowBottleneckBlock(
|
||||
64, 128, stride=2, name=name + "_0", relu_first=False)
|
||||
self._conv_1 = EntryFlowBottleneckBlock(
|
||||
128, 256, stride=1, name=name + "_1", relu_first=True)
|
||||
self._conv_2 = EntryFlowBottleneckBlock(
|
||||
256, 256, stride=2, name=name + "_2", relu_first=True)
|
||||
self._conv_3 = EntryFlowBottleneckBlock(
|
||||
256, 728, stride=1, name=name + "_3", relu_first=True)
|
||||
self._conv_4 = EntryFlowBottleneckBlock(
|
||||
728, 728, stride=2, name=name + "_4", relu_first=True)
|
||||
else:
|
||||
sys.exit(-1)
|
||||
|
||||
def forward(self, inputs):
|
||||
x = self._conv1(inputs)
|
||||
x = self._conv2(x)
|
||||
|
||||
if self.block_num == 3:
|
||||
x = self._conv_0(x)
|
||||
x = self._conv_1(x)
|
||||
x = self._conv_2(x)
|
||||
elif self.block_num == 5:
|
||||
x = self._conv_0(x)
|
||||
x = self._conv_1(x)
|
||||
x = self._conv_2(x)
|
||||
x = self._conv_3(x)
|
||||
x = self._conv_4(x)
|
||||
return x
|
||||
|
||||
|
||||
class MiddleFlowBottleneckBlock(fluid.dygraph.Layer):
|
||||
def __init__(self, input_channels, output_channels, name):
|
||||
super(MiddleFlowBottleneckBlock, self).__init__()
|
||||
|
||||
self._conv_0 = SeparableConv(
|
||||
input_channels,
|
||||
output_channels,
|
||||
stride=1,
|
||||
name=name + "_branch2a_weights")
|
||||
self._conv_1 = SeparableConv(
|
||||
output_channels,
|
||||
output_channels,
|
||||
stride=1,
|
||||
name=name + "_branch2b_weights")
|
||||
self._conv_2 = SeparableConv(
|
||||
output_channels,
|
||||
output_channels,
|
||||
stride=1,
|
||||
name=name + "_branch2c_weights")
|
||||
|
||||
def forward(self, inputs):
|
||||
layer_helper = LayerHelper(self.full_name(), act="relu")
|
||||
conv0 = layer_helper.append_activation(inputs)
|
||||
conv0 = self._conv_0(conv0)
|
||||
conv1 = layer_helper.append_activation(conv0)
|
||||
conv1 = self._conv_1(conv1)
|
||||
conv2 = layer_helper.append_activation(conv1)
|
||||
conv2 = self._conv_2(conv2)
|
||||
return fluid.layers.elementwise_add(x=inputs, y=conv2)
|
||||
|
||||
|
||||
class MiddleFlow(fluid.dygraph.Layer):
|
||||
def __init__(self, block_num=8):
|
||||
super(MiddleFlow, self).__init__()
|
||||
|
||||
self.block_num = block_num
|
||||
self._conv_0 = MiddleFlowBottleneckBlock(
|
||||
728, 728, name="middle_flow_0")
|
||||
self._conv_1 = MiddleFlowBottleneckBlock(
|
||||
728, 728, name="middle_flow_1")
|
||||
self._conv_2 = MiddleFlowBottleneckBlock(
|
||||
728, 728, name="middle_flow_2")
|
||||
self._conv_3 = MiddleFlowBottleneckBlock(
|
||||
728, 728, name="middle_flow_3")
|
||||
self._conv_4 = MiddleFlowBottleneckBlock(
|
||||
728, 728, name="middle_flow_4")
|
||||
self._conv_5 = MiddleFlowBottleneckBlock(
|
||||
728, 728, name="middle_flow_5")
|
||||
self._conv_6 = MiddleFlowBottleneckBlock(
|
||||
728, 728, name="middle_flow_6")
|
||||
self._conv_7 = MiddleFlowBottleneckBlock(
|
||||
728, 728, name="middle_flow_7")
|
||||
if block_num == 16:
|
||||
self._conv_8 = MiddleFlowBottleneckBlock(
|
||||
728, 728, name="middle_flow_8")
|
||||
self._conv_9 = MiddleFlowBottleneckBlock(
|
||||
728, 728, name="middle_flow_9")
|
||||
self._conv_10 = MiddleFlowBottleneckBlock(
|
||||
728, 728, name="middle_flow_10")
|
||||
self._conv_11 = MiddleFlowBottleneckBlock(
|
||||
728, 728, name="middle_flow_11")
|
||||
self._conv_12 = MiddleFlowBottleneckBlock(
|
||||
728, 728, name="middle_flow_12")
|
||||
self._conv_13 = MiddleFlowBottleneckBlock(
|
||||
728, 728, name="middle_flow_13")
|
||||
self._conv_14 = MiddleFlowBottleneckBlock(
|
||||
728, 728, name="middle_flow_14")
|
||||
self._conv_15 = MiddleFlowBottleneckBlock(
|
||||
728, 728, name="middle_flow_15")
|
||||
|
||||
def forward(self, inputs):
|
||||
x = self._conv_0(inputs)
|
||||
x = self._conv_1(x)
|
||||
x = self._conv_2(x)
|
||||
x = self._conv_3(x)
|
||||
x = self._conv_4(x)
|
||||
x = self._conv_5(x)
|
||||
x = self._conv_6(x)
|
||||
x = self._conv_7(x)
|
||||
if self.block_num == 16:
|
||||
x = self._conv_8(x)
|
||||
x = self._conv_9(x)
|
||||
x = self._conv_10(x)
|
||||
x = self._conv_11(x)
|
||||
x = self._conv_12(x)
|
||||
x = self._conv_13(x)
|
||||
x = self._conv_14(x)
|
||||
x = self._conv_15(x)
|
||||
return x
|
||||
|
||||
|
||||
class ExitFlowBottleneckBlock(fluid.dygraph.Layer):
|
||||
def __init__(self, input_channels, output_channels1, output_channels2,
|
||||
name):
|
||||
super(ExitFlowBottleneckBlock, self).__init__()
|
||||
|
||||
self._short = Conv2D(
|
||||
num_channels=input_channels,
|
||||
num_filters=output_channels2,
|
||||
filter_size=1,
|
||||
stride=2,
|
||||
padding=0,
|
||||
act=None,
|
||||
param_attr=ParamAttr(name + "_branch1_weights"),
|
||||
bias_attr=False)
|
||||
self._conv_1 = SeparableConv(
|
||||
input_channels,
|
||||
output_channels1,
|
||||
stride=1,
|
||||
name=name + "_branch2a_weights")
|
||||
self._conv_2 = SeparableConv(
|
||||
output_channels1,
|
||||
output_channels2,
|
||||
stride=1,
|
||||
name=name + "_branch2b_weights")
|
||||
self._pool = Pool2D(
|
||||
pool_size=3, pool_stride=2, pool_padding=1, pool_type="max")
|
||||
|
||||
def forward(self, inputs):
|
||||
short = self._short(inputs)
|
||||
layer_helper = LayerHelper(self.full_name(), act="relu")
|
||||
conv0 = layer_helper.append_activation(inputs)
|
||||
conv1 = self._conv_1(conv0)
|
||||
conv2 = layer_helper.append_activation(conv1)
|
||||
conv2 = self._conv_2(conv2)
|
||||
pool = self._pool(conv2)
|
||||
return fluid.layers.elementwise_add(x=short, y=pool)
|
||||
|
||||
|
||||
class ExitFlow(fluid.dygraph.Layer):
|
||||
def __init__(self, class_dim):
|
||||
super(ExitFlow, self).__init__()
|
||||
|
||||
name = "exit_flow"
|
||||
|
||||
self._conv_0 = ExitFlowBottleneckBlock(
|
||||
728, 728, 1024, name=name + "_1")
|
||||
self._conv_1 = SeparableConv(1024, 1536, stride=1, name=name + "_2")
|
||||
self._conv_2 = SeparableConv(1536, 2048, stride=1, name=name + "_3")
|
||||
self._pool = Pool2D(pool_type="avg", global_pooling=True)
|
||||
stdv = 1.0 / math.sqrt(2048 * 1.0)
|
||||
self._out = Linear(
|
||||
2048,
|
||||
class_dim,
|
||||
param_attr=ParamAttr(
|
||||
name="fc_weights",
|
||||
initializer=fluid.initializer.Uniform(-stdv, stdv)),
|
||||
bias_attr=ParamAttr(name="fc_offset"))
|
||||
|
||||
def forward(self, inputs):
|
||||
layer_helper = LayerHelper(self.full_name(), act="relu")
|
||||
conv0 = self._conv_0(inputs)
|
||||
conv1 = self._conv_1(conv0)
|
||||
conv1 = layer_helper.append_activation(conv1)
|
||||
conv2 = self._conv_2(conv1)
|
||||
conv2 = layer_helper.append_activation(conv2)
|
||||
pool = self._pool(conv2)
|
||||
pool = fluid.layers.reshape(pool, [0, -1])
|
||||
out = self._out(pool)
|
||||
return out
|
||||
|
||||
|
||||
class Xception(fluid.dygraph.Layer):
|
||||
def __init__(self,
|
||||
entry_flow_block_num=3,
|
||||
middle_flow_block_num=8,
|
||||
class_dim=1000):
|
||||
super(Xception, self).__init__()
|
||||
self.entry_flow_block_num = entry_flow_block_num
|
||||
self.middle_flow_block_num = middle_flow_block_num
|
||||
self._entry_flow = EntryFlow(entry_flow_block_num)
|
||||
self._middle_flow = MiddleFlow(middle_flow_block_num)
|
||||
self._exit_flow = ExitFlow(class_dim)
|
||||
|
||||
def forward(self, inputs):
|
||||
x = self._entry_flow(inputs)
|
||||
x = self._middle_flow(x)
|
||||
x = self._exit_flow(x)
|
||||
return x
|
||||
|
||||
|
||||
def Xception41(**args):
|
||||
model = Xception(entry_flow_block_num=3, middle_flow_block_num=8, **args)
|
||||
return model
|
||||
|
||||
|
||||
def Xception65():
|
||||
model = Xception(entry_flow_block_num=3, middle_flow_block_num=16)
|
||||
def Xception65(**args):
|
||||
model = Xception(entry_flow_block_num=3, middle_flow_block_num=16, **args)
|
||||
return model
|
||||
|
||||
|
||||
def Xception71():
|
||||
model = Xception(entry_flow_block_num=5, middle_flow_block_num=16)
|
||||
def Xception71(**args):
|
||||
model = Xception(entry_flow_block_num=5, middle_flow_block_num=16, **args)
|
||||
return model
|
|
@ -1,33 +1,10 @@
|
|||
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import contextlib
|
||||
import paddle
|
||||
import math
|
||||
|
||||
import paddle.fluid as fluid
|
||||
from paddle.fluid.param_attr import ParamAttr
|
||||
from paddle.fluid.layer_helper import LayerHelper
|
||||
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear, Dropout
|
||||
|
||||
from .model_libs import scope, name_scope
|
||||
from .model_libs import bn, bn_relu, relu
|
||||
from .model_libs import conv
|
||||
from .model_libs import seperate_conv
|
||||
|
||||
__all__ = ['Xception41_deeplab', 'Xception65_deeplab', 'Xception71_deeplab']
|
||||
__all__ = ["Xception41_deeplab", "Xception65_deeplab", "Xception71_deeplab"]
|
||||
|
||||
|
||||
def check_data(data, number):
|
||||
|
@ -54,267 +31,355 @@ def check_points(count, points):
|
|||
return (True if count == points else False)
|
||||
|
||||
|
||||
class Xception():
|
||||
def __init__(self, backbone="xception_65"):
|
||||
self.bottleneck_params = self.gen_bottleneck_params(backbone)
|
||||
def gen_bottleneck_params(backbone='xception_65'):
|
||||
if backbone == 'xception_65':
|
||||
bottleneck_params = {
|
||||
"entry_flow": (3, [2, 2, 2], [128, 256, 728]),
|
||||
"middle_flow": (16, 1, 728),
|
||||
"exit_flow": (2, [2, 1], [[728, 1024, 1024], [1536, 1536, 2048]])
|
||||
}
|
||||
elif backbone == 'xception_41':
|
||||
bottleneck_params = {
|
||||
"entry_flow": (3, [2, 2, 2], [128, 256, 728]),
|
||||
"middle_flow": (8, 1, 728),
|
||||
"exit_flow": (2, [2, 1], [[728, 1024, 1024], [1536, 1536, 2048]])
|
||||
}
|
||||
elif backbone == 'xception_71':
|
||||
bottleneck_params = {
|
||||
"entry_flow": (5, [2, 1, 2, 1, 2], [128, 256, 256, 728, 728]),
|
||||
"middle_flow": (16, 1, 728),
|
||||
"exit_flow": (2, [2, 1], [[728, 1024, 1024], [1536, 1536, 2048]])
|
||||
}
|
||||
else:
|
||||
raise Exception(
|
||||
"xception backbont only support xception_41/xception_65/xception_71"
|
||||
)
|
||||
return bottleneck_params
|
||||
|
||||
|
||||
class ConvBNLayer(fluid.dygraph.Layer):
|
||||
def __init__(self,
|
||||
input_channels,
|
||||
output_channels,
|
||||
filter_size,
|
||||
stride=1,
|
||||
padding=0,
|
||||
act=None,
|
||||
name=None):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
|
||||
self._conv = Conv2D(
|
||||
num_channels=input_channels,
|
||||
num_filters=output_channels,
|
||||
filter_size=filter_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
param_attr=ParamAttr(name=name + "/weights"),
|
||||
bias_attr=False)
|
||||
self._bn = BatchNorm(
|
||||
num_channels=output_channels,
|
||||
act=act,
|
||||
epsilon=1e-3,
|
||||
momentum=0.99,
|
||||
param_attr=ParamAttr(name=name + "/BatchNorm/gamma"),
|
||||
bias_attr=ParamAttr(name=name + "/BatchNorm/beta"),
|
||||
moving_mean_name=name + "/BatchNorm/moving_mean",
|
||||
moving_variance_name=name + "/BatchNorm/moving_variance")
|
||||
|
||||
def forward(self, inputs):
|
||||
return self._bn(self._conv(inputs))
|
||||
|
||||
|
||||
class Seperate_Conv(fluid.dygraph.Layer):
|
||||
def __init__(self,
|
||||
input_channels,
|
||||
output_channels,
|
||||
stride,
|
||||
filter,
|
||||
dilation=1,
|
||||
act=None,
|
||||
name=None):
|
||||
super(Seperate_Conv, self).__init__()
|
||||
|
||||
self._conv1 = Conv2D(
|
||||
num_channels=input_channels,
|
||||
num_filters=input_channels,
|
||||
filter_size=filter,
|
||||
stride=stride,
|
||||
groups=input_channels,
|
||||
padding=(filter) // 2 * dilation,
|
||||
dilation=dilation,
|
||||
param_attr=ParamAttr(name=name + "/depthwise/weights"),
|
||||
bias_attr=False)
|
||||
self._bn1 = BatchNorm(
|
||||
input_channels,
|
||||
act=act,
|
||||
epsilon=1e-3,
|
||||
momentum=0.99,
|
||||
param_attr=ParamAttr(name=name + "/depthwise/BatchNorm/gamma"),
|
||||
bias_attr=ParamAttr(name=name + "/depthwise/BatchNorm/beta"),
|
||||
moving_mean_name=name + "/depthwise/BatchNorm/moving_mean",
|
||||
moving_variance_name=name + "/depthwise/BatchNorm/moving_variance")
|
||||
self._conv2 = Conv2D(
|
||||
input_channels,
|
||||
output_channels,
|
||||
1,
|
||||
stride=1,
|
||||
groups=1,
|
||||
padding=0,
|
||||
param_attr=ParamAttr(name=name + "/pointwise/weights"),
|
||||
bias_attr=False)
|
||||
self._bn2 = BatchNorm(
|
||||
output_channels,
|
||||
act=act,
|
||||
epsilon=1e-3,
|
||||
momentum=0.99,
|
||||
param_attr=ParamAttr(name=name + "/pointwise/BatchNorm/gamma"),
|
||||
bias_attr=ParamAttr(name=name + "/pointwise/BatchNorm/beta"),
|
||||
moving_mean_name=name + "/pointwise/BatchNorm/moving_mean",
|
||||
moving_variance_name=name + "/pointwise/BatchNorm/moving_variance")
|
||||
|
||||
def forward(self, inputs):
|
||||
x = self._conv1(inputs)
|
||||
x = self._bn1(x)
|
||||
x = self._conv2(x)
|
||||
x = self._bn2(x)
|
||||
return x
|
||||
|
||||
|
||||
class Xception_Block(fluid.dygraph.Layer):
|
||||
def __init__(self,
|
||||
input_channels,
|
||||
output_channels,
|
||||
strides=1,
|
||||
filter_size=3,
|
||||
dilation=1,
|
||||
skip_conv=True,
|
||||
has_skip=True,
|
||||
activation_fn_in_separable_conv=False,
|
||||
name=None):
|
||||
super(Xception_Block, self).__init__()
|
||||
|
||||
repeat_number = 3
|
||||
output_channels = check_data(output_channels, repeat_number)
|
||||
filter_size = check_data(filter_size, repeat_number)
|
||||
strides = check_data(strides, repeat_number)
|
||||
|
||||
self.has_skip = has_skip
|
||||
self.skip_conv = skip_conv
|
||||
self.activation_fn_in_separable_conv = activation_fn_in_separable_conv
|
||||
if not activation_fn_in_separable_conv:
|
||||
self._conv1 = Seperate_Conv(
|
||||
input_channels,
|
||||
output_channels[0],
|
||||
stride=strides[0],
|
||||
filter=filter_size[0],
|
||||
dilation=dilation,
|
||||
name=name + "/separable_conv1")
|
||||
self._conv2 = Seperate_Conv(
|
||||
output_channels[0],
|
||||
output_channels[1],
|
||||
stride=strides[1],
|
||||
filter=filter_size[1],
|
||||
dilation=dilation,
|
||||
name=name + "/separable_conv2")
|
||||
self._conv3 = Seperate_Conv(
|
||||
output_channels[1],
|
||||
output_channels[2],
|
||||
stride=strides[2],
|
||||
filter=filter_size[2],
|
||||
dilation=dilation,
|
||||
name=name + "/separable_conv3")
|
||||
else:
|
||||
self._conv1 = Seperate_Conv(
|
||||
input_channels,
|
||||
output_channels[0],
|
||||
stride=strides[0],
|
||||
filter=filter_size[0],
|
||||
act="relu",
|
||||
dilation=dilation,
|
||||
name=name + "/separable_conv1")
|
||||
self._conv2 = Seperate_Conv(
|
||||
output_channels[0],
|
||||
output_channels[1],
|
||||
stride=strides[1],
|
||||
filter=filter_size[1],
|
||||
act="relu",
|
||||
dilation=dilation,
|
||||
name=name + "/separable_conv2")
|
||||
self._conv3 = Seperate_Conv(
|
||||
output_channels[1],
|
||||
output_channels[2],
|
||||
stride=strides[2],
|
||||
filter=filter_size[2],
|
||||
act="relu",
|
||||
dilation=dilation,
|
||||
name=name + "/separable_conv3")
|
||||
|
||||
if has_skip and skip_conv:
|
||||
self._short = ConvBNLayer(
|
||||
input_channels,
|
||||
output_channels[-1],
|
||||
1,
|
||||
stride=strides[-1],
|
||||
padding=0,
|
||||
name=name + "/shortcut")
|
||||
|
||||
def forward(self, inputs):
|
||||
layer_helper = LayerHelper(self.full_name(), act='relu')
|
||||
if not self.activation_fn_in_separable_conv:
|
||||
x = layer_helper.append_activation(inputs)
|
||||
x = self._conv1(x)
|
||||
x = layer_helper.append_activation(x)
|
||||
x = self._conv2(x)
|
||||
x = layer_helper.append_activation(x)
|
||||
x = self._conv3(x)
|
||||
else:
|
||||
x = self._conv1(inputs)
|
||||
x = self._conv2(x)
|
||||
x = self._conv3(x)
|
||||
if self.has_skip is False:
|
||||
return x
|
||||
if self.skip_conv:
|
||||
skip = self._short(inputs)
|
||||
else:
|
||||
skip = inputs
|
||||
return fluid.layers.elementwise_add(x, skip)
|
||||
|
||||
|
||||
class XceptionDeeplab(fluid.dygraph.Layer):
|
||||
def __init__(self, backbone, class_dim=1000):
|
||||
super(XceptionDeeplab, self).__init__()
|
||||
|
||||
bottleneck_params = gen_bottleneck_params(backbone)
|
||||
self.backbone = backbone
|
||||
|
||||
def gen_bottleneck_params(self, backbone='xception_65'):
|
||||
if backbone == 'xception_65':
|
||||
bottleneck_params = {
|
||||
"entry_flow": (3, [2, 2, 2], [128, 256, 728]),
|
||||
"middle_flow": (16, 1, 728),
|
||||
"exit_flow":
|
||||
(2, [2, 1], [[728, 1024, 1024], [1536, 1536, 2048]])
|
||||
}
|
||||
elif backbone == 'xception_41':
|
||||
bottleneck_params = {
|
||||
"entry_flow": (3, [2, 2, 2], [128, 256, 728]),
|
||||
"middle_flow": (8, 1, 728),
|
||||
"exit_flow":
|
||||
(2, [2, 1], [[728, 1024, 1024], [1536, 1536, 2048]])
|
||||
}
|
||||
elif backbone == 'xception_71':
|
||||
bottleneck_params = {
|
||||
"entry_flow": (5, [2, 1, 2, 1, 2], [128, 256, 256, 728, 728]),
|
||||
"middle_flow": (16, 1, 728),
|
||||
"exit_flow":
|
||||
(2, [2, 1], [[728, 1024, 1024], [1536, 1536, 2048]])
|
||||
}
|
||||
else:
|
||||
raise Exception(
|
||||
"xception backbont only support xception_41/xception_65/xception_71"
|
||||
)
|
||||
return bottleneck_params
|
||||
self._conv1 = ConvBNLayer(
|
||||
3,
|
||||
32,
|
||||
3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act="relu",
|
||||
name=self.backbone + "/entry_flow/conv1")
|
||||
self._conv2 = ConvBNLayer(
|
||||
32,
|
||||
64,
|
||||
3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
act="relu",
|
||||
name=self.backbone + "/entry_flow/conv2")
|
||||
|
||||
self.block_num = bottleneck_params["entry_flow"][0]
|
||||
self.strides = bottleneck_params["entry_flow"][1]
|
||||
self.chns = bottleneck_params["entry_flow"][2]
|
||||
self.strides = check_data(self.strides, self.block_num)
|
||||
self.chns = check_data(self.chns, self.block_num)
|
||||
|
||||
self.entry_flow = []
|
||||
self.middle_flow = []
|
||||
|
||||
def net(self,
|
||||
input,
|
||||
output_stride=32,
|
||||
class_dim=1000,
|
||||
end_points=None,
|
||||
decode_points=None):
|
||||
self.stride = 2
|
||||
self.block_point = 0
|
||||
self.output_stride = output_stride
|
||||
self.decode_points = decode_points
|
||||
self.short_cuts = dict()
|
||||
with scope(self.backbone):
|
||||
# Entry flow
|
||||
data = self.entry_flow(input)
|
||||
if check_points(self.block_point, end_points):
|
||||
return data, self.short_cuts
|
||||
|
||||
# Middle flow
|
||||
data = self.middle_flow(data)
|
||||
if check_points(self.block_point, end_points):
|
||||
return data, self.short_cuts
|
||||
|
||||
# Exit flow
|
||||
data = self.exit_flow(data)
|
||||
if check_points(self.block_point, end_points):
|
||||
return data, self.short_cuts
|
||||
|
||||
data = fluid.layers.reduce_mean(data, [2, 3], keep_dim=True)
|
||||
data = fluid.layers.dropout(data, 0.5)
|
||||
stdv = 1.0 / math.sqrt(data.shape[1] * 1.0)
|
||||
with scope("logit"):
|
||||
out = fluid.layers.fc(
|
||||
input=data,
|
||||
size=class_dim,
|
||||
param_attr=fluid.param_attr.ParamAttr(
|
||||
name='fc_weights',
|
||||
initializer=fluid.initializer.Uniform(-stdv, stdv)),
|
||||
bias_attr=fluid.param_attr.ParamAttr(name='fc_bias'))
|
||||
|
||||
return out
|
||||
|
||||
def entry_flow(self, data):
|
||||
param_attr = fluid.ParamAttr(
|
||||
name=name_scope + 'weights',
|
||||
regularizer=None,
|
||||
initializer=fluid.initializer.TruncatedNormal(
|
||||
loc=0.0, scale=0.09))
|
||||
with scope("entry_flow"):
|
||||
with scope("conv1"):
|
||||
data = bn_relu(
|
||||
conv(
|
||||
data,
|
||||
32,
|
||||
3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
param_attr=param_attr))
|
||||
with scope("conv2"):
|
||||
data = bn_relu(
|
||||
conv(
|
||||
data,
|
||||
64,
|
||||
3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
param_attr=param_attr))
|
||||
|
||||
# get entry flow params
|
||||
block_num = self.bottleneck_params["entry_flow"][0]
|
||||
strides = self.bottleneck_params["entry_flow"][1]
|
||||
chns = self.bottleneck_params["entry_flow"][2]
|
||||
strides = check_data(strides, block_num)
|
||||
chns = check_data(chns, block_num)
|
||||
|
||||
# params to control your flow
|
||||
self.output_stride = 32
|
||||
s = self.stride
|
||||
block_point = self.block_point
|
||||
output_stride = self.output_stride
|
||||
with scope("entry_flow"):
|
||||
for i in range(block_num):
|
||||
block_point = block_point + 1
|
||||
with scope("block" + str(i + 1)):
|
||||
stride = strides[i] if check_stride(s * strides[i],
|
||||
output_stride) else 1
|
||||
data, short_cuts = self.xception_block(data, chns[i],
|
||||
[1, 1, stride])
|
||||
s = s * stride
|
||||
if check_points(block_point, self.decode_points):
|
||||
self.short_cuts[block_point] = short_cuts[1]
|
||||
|
||||
for i in range(self.block_num):
|
||||
stride = self.strides[i] if check_stride(s * self.strides[i],
|
||||
self.output_stride) else 1
|
||||
xception_block = self.add_sublayer(
|
||||
self.backbone + "/entry_flow/block" + str(i + 1),
|
||||
Xception_Block(
|
||||
input_channels=64 if i == 0 else self.chns[i - 1],
|
||||
output_channels=self.chns[i],
|
||||
strides=[1, 1, self.stride],
|
||||
name=self.backbone + "/entry_flow/block" + str(i + 1)))
|
||||
self.entry_flow.append(xception_block)
|
||||
s = s * stride
|
||||
self.stride = s
|
||||
|
||||
self.block_num = bottleneck_params["middle_flow"][0]
|
||||
self.strides = bottleneck_params["middle_flow"][1]
|
||||
self.chns = bottleneck_params["middle_flow"][2]
|
||||
self.strides = check_data(self.strides, self.block_num)
|
||||
self.chns = check_data(self.chns, self.block_num)
|
||||
s = self.stride
|
||||
|
||||
for i in range(self.block_num):
|
||||
stride = self.strides[i] if check_stride(s * self.strides[i],
|
||||
self.output_stride) else 1
|
||||
xception_block = self.add_sublayer(
|
||||
self.backbone + "/middle_flow/block" + str(i + 1),
|
||||
Xception_Block(
|
||||
input_channels=728,
|
||||
output_channels=728,
|
||||
strides=[1, 1, self.strides[i]],
|
||||
skip_conv=False,
|
||||
name=self.backbone + "/middle_flow/block" + str(i + 1)))
|
||||
self.middle_flow.append(xception_block)
|
||||
s = s * stride
|
||||
self.stride = s
|
||||
|
||||
self.block_num = bottleneck_params["exit_flow"][0]
|
||||
self.strides = bottleneck_params["exit_flow"][1]
|
||||
self.chns = bottleneck_params["exit_flow"][2]
|
||||
self.strides = check_data(self.strides, self.block_num)
|
||||
self.chns = check_data(self.chns, self.block_num)
|
||||
s = self.stride
|
||||
stride = self.strides[0] if check_stride(s * self.strides[0],
|
||||
self.output_stride) else 1
|
||||
self._exit_flow_1 = Xception_Block(
|
||||
728,
|
||||
self.chns[0], [1, 1, stride],
|
||||
name=self.backbone + "/exit_flow/block1")
|
||||
s = s * stride
|
||||
stride = self.strides[1] if check_stride(s * self.strides[1],
|
||||
self.output_stride) else 1
|
||||
self._exit_flow_2 = Xception_Block(
|
||||
self.chns[0][-1],
|
||||
self.chns[1], [1, 1, stride],
|
||||
dilation=2,
|
||||
has_skip=False,
|
||||
activation_fn_in_separable_conv=True,
|
||||
name=self.backbone + "/exit_flow/block2")
|
||||
s = s * stride
|
||||
|
||||
self.stride = s
|
||||
self.block_point = block_point
|
||||
return data
|
||||
|
||||
def middle_flow(self, data):
|
||||
block_num = self.bottleneck_params["middle_flow"][0]
|
||||
strides = self.bottleneck_params["middle_flow"][1]
|
||||
chns = self.bottleneck_params["middle_flow"][2]
|
||||
strides = check_data(strides, block_num)
|
||||
chns = check_data(chns, block_num)
|
||||
self._drop = Dropout(p=0.5)
|
||||
self._pool = Pool2D(pool_type="avg", global_pooling=True)
|
||||
self._fc = Linear(
|
||||
self.chns[1][-1],
|
||||
class_dim,
|
||||
param_attr=ParamAttr(name="fc_weights"),
|
||||
bias_attr=ParamAttr(name="fc_bias"))
|
||||
|
||||
# params to control your flow
|
||||
s = self.stride
|
||||
block_point = self.block_point
|
||||
output_stride = self.output_stride
|
||||
with scope("middle_flow"):
|
||||
for i in range(block_num):
|
||||
block_point = block_point + 1
|
||||
with scope("block" + str(i + 1)):
|
||||
stride = strides[i] if check_stride(s * strides[i],
|
||||
output_stride) else 1
|
||||
data, short_cuts = self.xception_block(
|
||||
data, chns[i], [1, 1, strides[i]], skip_conv=False)
|
||||
s = s * stride
|
||||
if check_points(block_point, self.decode_points):
|
||||
self.short_cuts[block_point] = short_cuts[1]
|
||||
|
||||
self.stride = s
|
||||
self.block_point = block_point
|
||||
return data
|
||||
|
||||
def exit_flow(self, data):
|
||||
block_num = self.bottleneck_params["exit_flow"][0]
|
||||
strides = self.bottleneck_params["exit_flow"][1]
|
||||
chns = self.bottleneck_params["exit_flow"][2]
|
||||
strides = check_data(strides, block_num)
|
||||
chns = check_data(chns, block_num)
|
||||
|
||||
assert (block_num == 2)
|
||||
# params to control your flow
|
||||
s = self.stride
|
||||
block_point = self.block_point
|
||||
output_stride = self.output_stride
|
||||
with scope("exit_flow"):
|
||||
with scope('block1'):
|
||||
block_point += 1
|
||||
stride = strides[0] if check_stride(s * strides[0],
|
||||
output_stride) else 1
|
||||
data, short_cuts = self.xception_block(data, chns[0],
|
||||
[1, 1, stride])
|
||||
s = s * stride
|
||||
if check_points(block_point, self.decode_points):
|
||||
self.short_cuts[block_point] = short_cuts[1]
|
||||
with scope('block2'):
|
||||
block_point += 1
|
||||
stride = strides[1] if check_stride(s * strides[1],
|
||||
output_stride) else 1
|
||||
data, short_cuts = self.xception_block(
|
||||
data,
|
||||
chns[1], [1, 1, stride],
|
||||
dilation=2,
|
||||
has_skip=False,
|
||||
activation_fn_in_separable_conv=True)
|
||||
s = s * stride
|
||||
if check_points(block_point, self.decode_points):
|
||||
self.short_cuts[block_point] = short_cuts[1]
|
||||
|
||||
self.stride = s
|
||||
self.block_point = block_point
|
||||
return data
|
||||
|
||||
def xception_block(self,
|
||||
input,
|
||||
channels,
|
||||
strides=1,
|
||||
filters=3,
|
||||
dilation=1,
|
||||
skip_conv=True,
|
||||
has_skip=True,
|
||||
activation_fn_in_separable_conv=False):
|
||||
repeat_number = 3
|
||||
channels = check_data(channels, repeat_number)
|
||||
filters = check_data(filters, repeat_number)
|
||||
strides = check_data(strides, repeat_number)
|
||||
data = input
|
||||
results = []
|
||||
for i in range(repeat_number):
|
||||
with scope('separable_conv' + str(i + 1)):
|
||||
if not activation_fn_in_separable_conv:
|
||||
data = relu(data)
|
||||
data = seperate_conv(
|
||||
data,
|
||||
channels[i],
|
||||
strides[i],
|
||||
filters[i],
|
||||
dilation=dilation)
|
||||
else:
|
||||
data = seperate_conv(
|
||||
data,
|
||||
channels[i],
|
||||
strides[i],
|
||||
filters[i],
|
||||
dilation=dilation,
|
||||
act=relu)
|
||||
results.append(data)
|
||||
if not has_skip:
|
||||
return data, results
|
||||
if skip_conv:
|
||||
param_attr = fluid.ParamAttr(
|
||||
name=name_scope + 'weights',
|
||||
regularizer=None,
|
||||
initializer=fluid.initializer.TruncatedNormal(
|
||||
loc=0.0, scale=0.09))
|
||||
with scope('shortcut'):
|
||||
skip = bn(
|
||||
conv(
|
||||
input,
|
||||
channels[-1],
|
||||
1,
|
||||
strides[-1],
|
||||
groups=1,
|
||||
padding=0,
|
||||
param_attr=param_attr))
|
||||
else:
|
||||
skip = input
|
||||
return data + skip, results
|
||||
def forward(self, inputs):
|
||||
x = self._conv1(inputs)
|
||||
x = self._conv2(x)
|
||||
for ef in self.entry_flow:
|
||||
x = ef(x)
|
||||
for mf in self.middle_flow:
|
||||
x = mf(x)
|
||||
x = self._exit_flow_1(x)
|
||||
x = self._exit_flow_2(x)
|
||||
x = self._drop(x)
|
||||
x = self._pool(x)
|
||||
x = fluid.layers.squeeze(x, axes=[2, 3])
|
||||
x = self._fc(x)
|
||||
return x
|
||||
|
||||
|
||||
def Xception41_deeplab():
|
||||
model = Xception("xception_41")
|
||||
def Xception41_deeplab(**args):
|
||||
model = XceptionDeeplab('xception_41', **args)
|
||||
return model
|
||||
|
||||
|
||||
def Xception65_deeplab():
|
||||
model = Xception("xception_65")
|
||||
def Xception65_deeplab(**args):
|
||||
model = XceptionDeeplab("xception_65", **args)
|
||||
return model
|
||||
|
||||
|
||||
def Xception71_deeplab():
|
||||
model = Xception("xception_71")
|
||||
def Xception71_deeplab(**args):
|
||||
model = XceptionDeeplab("xception_71", **args)
|
||||
return model
|
Loading…
Reference in New Issue