321 lines
11 KiB
Python
321 lines
11 KiB
Python
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||
|
#
|
||
|
#Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
#you may not use this file except in compliance with the License.
|
||
|
#You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
#Unless required by applicable law or agreed to in writing, software
|
||
|
#distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
#See the License for the specific language governing permissions and
|
||
|
#limitations under the License.
|
||
|
|
||
|
from __future__ import absolute_import
|
||
|
from __future__ import division
|
||
|
from __future__ import print_function
|
||
|
|
||
|
import contextlib
|
||
|
import paddle
|
||
|
import math
|
||
|
|
||
|
import paddle.fluid as fluid
|
||
|
|
||
|
from .model_libs import scope, name_scope
|
||
|
from .model_libs import bn, bn_relu, relu
|
||
|
from .model_libs import conv
|
||
|
from .model_libs import seperate_conv
|
||
|
|
||
|
# Public factory functions exported by this module.
__all__ = [
    'Xception41_deeplab',
    'Xception65_deeplab',
    'Xception71_deeplab',
]
|
||
|
|
||
|
|
||
|
def check_data(data, number):
    """Broadcast an integer to ``number`` copies, or validate a sequence length.

    Args:
        data: an integer, or a sequence that must contain exactly ``number``
            elements.
        number: the expected number of entries.

    Returns:
        ``[data] * number`` when ``data`` is an integer, otherwise ``data``
        itself (unchanged, not copied).

    Raises:
        AssertionError: if ``data`` is a sequence whose length differs from
            ``number``.
    """
    import numbers

    # isinstance against numbers.Integral (instead of ``type(data) == int``)
    # also accepts int subclasses and NumPy integer scalars, which would
    # otherwise fall through to len() and raise TypeError.
    if isinstance(data, numbers.Integral):
        return [data] * number
    assert len(data) == number, "expected {} values, got {}".format(
        number, len(data))
    return data
|
||
|
|
||
|
|
||
|
def check_stride(s, os):
    """Return True when the cumulative stride ``s`` is at most ``os``.

    Callers use this to decide whether a block may still downsample without
    exceeding the requested output stride.
    """
    return bool(s <= os)
|
||
|
|
||
|
|
||
|
def check_points(count, points):
    """Return True if block index ``count`` is selected by ``points``.

    Args:
        count: the current block index.
        points: ``None`` (nothing selected), a single index compared with
            ``==``, or a list/tuple/set of indices tested for membership.

    Returns:
        A plain ``bool``.
    """
    if points is None:
        return False
    # Previously only ``list`` was treated as a collection; a tuple or set of
    # indices was compared with ``==`` and therefore never matched.
    if isinstance(points, (list, tuple, set, frozenset)):
        return count in points
    return bool(count == points)
|
||
|
|
||
|
|
||
|
class Xception():
    """Xception backbone (41/65/71 variants) built with paddle.fluid layers.

    The network has three stages -- "entry_flow", "middle_flow" and
    "exit_flow" -- each a stack of ``xception_block`` calls whose strides and
    channel counts come from ``gen_bottleneck_params``. ``net`` can stop after
    a stage (``end_points``) and records intermediate block outputs for the
    block indices listed in ``decode_points``.
    """

    def __init__(self, backbone="xception_65"):
        self.bottleneck_params = self.gen_bottleneck_params(backbone)
        self.backbone = backbone

    def gen_bottleneck_params(self, backbone='xception_65'):
        """Return the per-stage block configuration for ``backbone``.

        Each stage maps to ``(block_num, strides, channels)``. ``strides`` and
        ``channels`` are either one value shared by every block of the stage
        or one entry per block (see ``check_data``). A fresh dict with fresh
        lists is returned on every call.

        Raises:
            ValueError: if ``backbone`` is not xception_41/65/71.
        """
        # The exit flow is identical for all supported depths.
        exit_flow = (2, [2, 1], [[728, 1024, 1024], [1536, 1536, 2048]])
        configs = {
            'xception_41': {
                "entry_flow": (3, [2, 2, 2], [128, 256, 728]),
                "middle_flow": (8, 1, 728),
                "exit_flow": exit_flow,
            },
            'xception_65': {
                "entry_flow": (3, [2, 2, 2], [128, 256, 728]),
                "middle_flow": (16, 1, 728),
                "exit_flow": exit_flow,
            },
            'xception_71': {
                "entry_flow": (5, [2, 1, 2, 1, 2],
                               [128, 256, 256, 728, 728]),
                "middle_flow": (16, 1, 728),
                "exit_flow": exit_flow,
            },
        }
        if backbone not in configs:
            # ValueError subclasses Exception, so existing handlers still work.
            raise ValueError(
                "xception backbone only supports "
                "xception_41/xception_65/xception_71, got {!r}".format(
                    backbone))
        return configs[backbone]

    def net(self,
            input,
            output_stride=32,
            class_dim=1000,
            end_points=None,
            decode_points=None):
        """Build the backbone graph on ``input``.

        Args:
            input: the input variable.
            output_stride: cap on the cumulative downsampling factor; once a
                block would exceed it, that block runs with stride 1.
            class_dim: output size of the final fully connected layer.
            end_points: block index (or list of indices) at which to stop and
                return features. Only checked at the end of each stage.
            decode_points: block index (or list of indices) whose second
                separable-conv output is stored in ``self.short_cuts``.

        Returns:
            ``(features, self.short_cuts)`` when an end point is reached,
            otherwise the output of the classification ``fc`` layer.
        """
        # The stem's first conv (entry_flow/conv1) already uses stride 2.
        self.stride = 2
        self.block_point = 0
        self.output_stride = output_stride
        self.decode_points = decode_points
        self.short_cuts = dict()
        with scope(self.backbone):
            # Entry flow
            data = self.entry_flow(input)
            if check_points(self.block_point, end_points):
                return data, self.short_cuts

            # Middle flow
            data = self.middle_flow(data)
            if check_points(self.block_point, end_points):
                return data, self.short_cuts

            # Exit flow
            data = self.exit_flow(data)
            if check_points(self.block_point, end_points):
                return data, self.short_cuts

            data = fluid.layers.reduce_mean(data, [2, 3], keep_dim=True)
            data = fluid.layers.dropout(data, 0.5)
            stdv = 1.0 / math.sqrt(data.shape[1] * 1.0)
            with scope("logit"):
                out = fluid.layers.fc(
                    input=data,
                    size=class_dim,
                    param_attr=fluid.param_attr.ParamAttr(
                        name='fc_weights',
                        initializer=fluid.initializer.Uniform(-stdv, stdv)),
                    bias_attr=fluid.param_attr.ParamAttr(name='fc_bias'))

            return out

    def _weights_param_attr(self):
        """Return a TruncatedNormal ParamAttr named after the current scope.

        ``from .model_libs import name_scope`` at the top of the file binds
        the string value as it was at import time. If ``scope`` updates
        ``model_libs.name_scope`` by rebinding it (as its string usage
        suggests), that imported copy never changes, and every attr would
        get the same name. Reading the attribute from the module at call
        time picks up the live value.
        """
        from . import model_libs
        return fluid.ParamAttr(
            name=model_libs.name_scope + 'weights',
            regularizer=None,
            initializer=fluid.initializer.TruncatedNormal(
                loc=0.0, scale=0.09))

    def _run_flow(self, data, flow_name, skip_conv):
        """Apply the Xception blocks of stage ``flow_name`` in sequence.

        Updates ``self.stride`` and ``self.block_point`` and records decode
        features in ``self.short_cuts``.
        """
        block_num, strides, chns = self.bottleneck_params[flow_name]
        strides = check_data(strides, block_num)
        chns = check_data(chns, block_num)

        s = self.stride
        block_point = self.block_point
        output_stride = self.output_stride
        with scope(flow_name):
            for i in range(block_num):
                block_point += 1
                with scope("block" + str(i + 1)):
                    # Fall back to stride 1 once output_stride would be
                    # exceeded; the same capped value is applied to the
                    # block and to the running stride ``s``.
                    stride = strides[i] if check_stride(
                        s * strides[i], output_stride) else 1
                    data, short_cuts = self.xception_block(
                        data, chns[i], [1, 1, stride], skip_conv=skip_conv)
                    s = s * stride
                    if check_points(block_point, self.decode_points):
                        self.short_cuts[block_point] = short_cuts[1]

        self.stride = s
        self.block_point = block_point
        return data

    def entry_flow(self, data):
        """Stem convs (3x3/32 stride 2, 3x3/64 stride 1) plus entry blocks."""
        with scope("entry_flow"):
            with scope("conv1"):
                data = bn_relu(
                    conv(
                        data,
                        32,
                        3,
                        stride=2,
                        padding=1,
                        param_attr=self._weights_param_attr()))
            with scope("conv2"):
                data = bn_relu(
                    conv(
                        data,
                        64,
                        3,
                        stride=1,
                        padding=1,
                        param_attr=self._weights_param_attr()))
        return self._run_flow(data, "entry_flow", skip_conv=True)

    def middle_flow(self, data):
        """Middle-flow blocks with identity (non-conv) residual shortcuts.

        The identity shortcut only matches shapes when the block stride is 1.
        The stride passed to each block is the same output-stride-capped
        value used to update ``self.stride``.
        """
        return self._run_flow(data, "middle_flow", skip_conv=False)

    def exit_flow(self, data):
        """Two exit blocks; the second has no shortcut, uses dilation 2 and
        applies the activation inside the separable convs."""
        block_num, strides, chns = self.bottleneck_params["exit_flow"]
        strides = check_data(strides, block_num)
        chns = check_data(chns, block_num)

        assert (block_num == 2)
        # params to control your flow
        s = self.stride
        block_point = self.block_point
        output_stride = self.output_stride
        with scope("exit_flow"):
            with scope('block1'):
                block_point += 1
                stride = strides[0] if check_stride(s * strides[0],
                                                    output_stride) else 1
                data, short_cuts = self.xception_block(data, chns[0],
                                                       [1, 1, stride])
                s = s * stride
                if check_points(block_point, self.decode_points):
                    self.short_cuts[block_point] = short_cuts[1]
            with scope('block2'):
                block_point += 1
                stride = strides[1] if check_stride(s * strides[1],
                                                    output_stride) else 1
                data, short_cuts = self.xception_block(
                    data,
                    chns[1], [1, 1, stride],
                    dilation=2,
                    has_skip=False,
                    activation_fn_in_separable_conv=True)
                s = s * stride
                if check_points(block_point, self.decode_points):
                    self.short_cuts[block_point] = short_cuts[1]

        self.stride = s
        self.block_point = block_point
        return data

    def xception_block(self,
                       input,
                       channels,
                       strides=1,
                       filters=3,
                       dilation=1,
                       skip_conv=True,
                       has_skip=True,
                       activation_fn_in_separable_conv=False):
        """Three separable convs with an optional residual shortcut.

        Args:
            input: block input variable.
            channels, strides, filters: a single value or three values, one
                per separable conv.
            dilation: dilation passed to every separable conv.
            skip_conv: if True the shortcut is a 1x1 conv + bn using the last
                channel count and last stride; otherwise it is ``input``.
            has_skip: if False no residual is added.
            activation_fn_in_separable_conv: if False, relu is applied before
                each separable conv; if True, ``act=relu`` is passed to it.

        Returns:
            ``(output, results)`` where ``results`` holds the output of each
            of the three separable convs.
        """
        repeat_number = 3
        channels = check_data(channels, repeat_number)
        filters = check_data(filters, repeat_number)
        strides = check_data(strides, repeat_number)
        data = input
        results = []
        for i in range(repeat_number):
            with scope('separable_conv' + str(i + 1)):
                if not activation_fn_in_separable_conv:
                    data = relu(data)
                    data = seperate_conv(
                        data,
                        channels[i],
                        strides[i],
                        filters[i],
                        dilation=dilation)
                else:
                    data = seperate_conv(
                        data,
                        channels[i],
                        strides[i],
                        filters[i],
                        dilation=dilation,
                        act=relu)
                results.append(data)
        if not has_skip:
            return data, results
        if skip_conv:
            with scope('shortcut'):
                # The attr is created inside the shortcut scope so its name
                # is specific to this conv.
                skip = bn(
                    conv(
                        input,
                        channels[-1],
                        1,
                        strides[-1],
                        groups=1,
                        padding=0,
                        param_attr=self._weights_param_attr()))
        else:
            skip = input
        return data + skip, results
|
||
|
|
||
|
|
||
|
def Xception41_deeplab():
    """Create an ``Xception`` backbone configured as xception_41."""
    return Xception("xception_41")
|
||
|
|
||
|
|
||
|
def Xception65_deeplab():
    """Create an ``Xception`` backbone configured as xception_65."""
    return Xception("xception_65")
|
||
|
|
||
|
|
||
|
def Xception71_deeplab():
    """Create an ``Xception`` backbone configured as xception_71."""
    return Xception("xception_71")
|