# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# reference: https://arxiv.org/abs/1706.05587
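# Xception backbones in the DeepLab style (see the reference above), exported
# here as ImageNet classifiers; pretrained weights are listed in MODEL_URLS.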

import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D

from ....utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url

MODEL_URLS = {
    "Xception41_deeplab":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/Xception41_deeplab_pretrained.pdparams",
    "Xception65_deeplab":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/Xception65_deeplab_pretrained.pdparams"
}

__all__ = list(MODEL_URLS.keys())
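

# Small config helpers: check_data broadcasts an int setting to a per-block
# list, check_stride tells whether the accumulated stride may still grow, and
# check_points tests whether a block index hits a given point (int or list).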
def check_data(data, number):
    if isinstance(data, int):
        return [data] * number
    assert len(data) == number
    return data


def check_stride(s, os):
    return s <= os


def check_points(count, points):
    if points is None:
        return False
    if isinstance(points, list):
        return count in points
    return count == points
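

# Each flow is configured by a (block_count, strides, channels) tuple; lists
# give per-block values, scalars are broadcast by check_data.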
def gen_bottleneck_params(backbone='xception_65'):
    if backbone == 'xception_65':
        bottleneck_params = {
            "entry_flow": (3, [2, 2, 2], [128, 256, 728]),
            "middle_flow": (16, 1, 728),
            "exit_flow": (2, [2, 1], [[728, 1024, 1024], [1536, 1536, 2048]])
        }
    elif backbone == 'xception_41':
        bottleneck_params = {
            "entry_flow": (3, [2, 2, 2], [128, 256, 728]),
            "middle_flow": (8, 1, 728),
            "exit_flow": (2, [2, 1], [[728, 1024, 1024], [1536, 1536, 2048]])
        }
    elif backbone == 'xception_71':
        bottleneck_params = {
            "entry_flow": (5, [2, 1, 2, 1, 2], [128, 256, 256, 728, 728]),
            "middle_flow": (16, 1, 728),
            "exit_flow": (2, [2, 1], [[728, 1024, 1024], [1536, 1536, 2048]])
        }
    else:
        raise ValueError(
            "Xception backbone only supports xception_41/xception_65/xception_71"
        )
    return bottleneck_params
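

# Conv2D followed by BatchNorm (with an optional fused activation); parameter
# names are pinned via ParamAttr so pretrained checkpoints load by name.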
class ConvBNLayer(nn.Layer):
    def __init__(self,
                 input_channels,
                 output_channels,
                 filter_size,
                 stride=1,
                 padding=0,
                 act=None,
                 name=None):
        super(ConvBNLayer, self).__init__()

        self._conv = Conv2D(
            in_channels=input_channels,
            out_channels=output_channels,
            kernel_size=filter_size,
            stride=stride,
            padding=padding,
            weight_attr=ParamAttr(name=name + "/weights"),
            bias_attr=False)
        self._bn = BatchNorm(
            num_channels=output_channels,
            act=act,
            epsilon=1e-3,
            momentum=0.99,
            param_attr=ParamAttr(name=name + "/BatchNorm/gamma"),
            bias_attr=ParamAttr(name=name + "/BatchNorm/beta"),
            moving_mean_name=name + "/BatchNorm/moving_mean",
            moving_variance_name=name + "/BatchNorm/moving_variance")

    def forward(self, inputs):
        return self._bn(self._conv(inputs))
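

# Depthwise separable convolution: a depthwise Conv2D (groups == channels)
# followed by a 1x1 pointwise Conv2D, each with its own BatchNorm.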
class Seperate_Conv(nn.Layer):
    def __init__(self,
                 input_channels,
                 output_channels,
                 stride,
                 filter,
                 dilation=1,
                 act=None,
                 name=None):
        super(Seperate_Conv, self).__init__()

        self._conv1 = Conv2D(
            in_channels=input_channels,
            out_channels=input_channels,
            kernel_size=filter,
            stride=stride,
            groups=input_channels,
            padding=(filter // 2) * dilation,
            dilation=dilation,
            weight_attr=ParamAttr(name=name + "/depthwise/weights"),
            bias_attr=False)
        self._bn1 = BatchNorm(
            input_channels,
            act=act,
            epsilon=1e-3,
            momentum=0.99,
            param_attr=ParamAttr(name=name + "/depthwise/BatchNorm/gamma"),
            bias_attr=ParamAttr(name=name + "/depthwise/BatchNorm/beta"),
            moving_mean_name=name + "/depthwise/BatchNorm/moving_mean",
            moving_variance_name=name + "/depthwise/BatchNorm/moving_variance")
        self._conv2 = Conv2D(
            input_channels,
            output_channels,
            1,
            stride=1,
            groups=1,
            padding=0,
            weight_attr=ParamAttr(name=name + "/pointwise/weights"),
            bias_attr=False)
        self._bn2 = BatchNorm(
            output_channels,
            act=act,
            epsilon=1e-3,
            momentum=0.99,
            param_attr=ParamAttr(name=name + "/pointwise/BatchNorm/gamma"),
            bias_attr=ParamAttr(name=name + "/pointwise/BatchNorm/beta"),
            moving_mean_name=name + "/pointwise/BatchNorm/moving_mean",
            moving_variance_name=name + "/pointwise/BatchNorm/moving_variance")

    def forward(self, inputs):
        x = self._conv1(inputs)
        x = self._bn1(x)
        x = self._conv2(x)
        x = self._bn2(x)
        return x
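

# Three stacked separable convs with an optional residual connection. When
# activation_fn_in_separable_conv is False, ReLU is applied before each
# separable conv in forward(); otherwise each separable conv fuses its own
# ReLU into the BatchNorm layers.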
class Xception_Block(nn.Layer):
    def __init__(self,
                 input_channels,
                 output_channels,
                 strides=1,
                 filter_size=3,
                 dilation=1,
                 skip_conv=True,
                 has_skip=True,
                 activation_fn_in_separable_conv=False,
                 name=None):
        super(Xception_Block, self).__init__()

        repeat_number = 3
        output_channels = check_data(output_channels, repeat_number)
        filter_size = check_data(filter_size, repeat_number)
        strides = check_data(strides, repeat_number)

        self.has_skip = has_skip
        self.skip_conv = skip_conv
        self.activation_fn_in_separable_conv = activation_fn_in_separable_conv
        act = "relu" if activation_fn_in_separable_conv else None
        self._conv1 = Seperate_Conv(
            input_channels,
            output_channels[0],
            stride=strides[0],
            filter=filter_size[0],
            act=act,
            dilation=dilation,
            name=name + "/separable_conv1")
        self._conv2 = Seperate_Conv(
            output_channels[0],
            output_channels[1],
            stride=strides[1],
            filter=filter_size[1],
            act=act,
            dilation=dilation,
            name=name + "/separable_conv2")
        self._conv3 = Seperate_Conv(
            output_channels[1],
            output_channels[2],
            stride=strides[2],
            filter=filter_size[2],
            act=act,
            dilation=dilation,
            name=name + "/separable_conv3")

        if has_skip and skip_conv:
            self._short = ConvBNLayer(
                input_channels,
                output_channels[-1],
                1,
                stride=strides[-1],
                padding=0,
                name=name + "/shortcut")

    def forward(self, inputs):
        if not self.activation_fn_in_separable_conv:
            x = F.relu(inputs)
            x = self._conv1(x)
            x = F.relu(x)
            x = self._conv2(x)
            x = F.relu(x)
            x = self._conv3(x)
        else:
            x = self._conv1(inputs)
            x = self._conv2(x)
            x = self._conv3(x)
        if self.has_skip:
            skip = self._short(inputs) if self.skip_conv else inputs
            return paddle.add(x, skip)
        return x
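

# Full backbone: entry flow -> middle flow -> exit flow, then global pooling
# and an FC head for classification. check_stride caps the product of block
# strides at output_stride (32), mirroring DeepLab's output-stride control.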
class XceptionDeeplab(nn.Layer):
    def __init__(self, backbone, class_num=1000):
        super(XceptionDeeplab, self).__init__()

        bottleneck_params = gen_bottleneck_params(backbone)
        self.backbone = backbone

        self._conv1 = ConvBNLayer(
            3,
            32,
            3,
            stride=2,
            padding=1,
            act="relu",
            name=self.backbone + "/entry_flow/conv1")
        self._conv2 = ConvBNLayer(
            32,
            64,
            3,
            stride=1,
            padding=1,
            act="relu",
            name=self.backbone + "/entry_flow/conv2")

        # entry flow
        self.block_num = bottleneck_params["entry_flow"][0]
        self.strides = bottleneck_params["entry_flow"][1]
        self.chns = bottleneck_params["entry_flow"][2]
        self.strides = check_data(self.strides, self.block_num)
        self.chns = check_data(self.chns, self.block_num)

        self.entry_flow = []
        self.middle_flow = []

        self.stride = 2
        self.output_stride = 32
        s = self.stride

        for i in range(self.block_num):
            # cap the block stride so the accumulated stride stays within
            # output_stride
            stride = self.strides[i] if check_stride(s * self.strides[i],
                                                     self.output_stride) else 1
            xception_block = self.add_sublayer(
                self.backbone + "/entry_flow/block" + str(i + 1),
                Xception_Block(
                    input_channels=64 if i == 0 else self.chns[i - 1],
                    output_channels=self.chns[i],
                    strides=[1, 1, stride],
                    name=self.backbone + "/entry_flow/block" + str(i + 1)))
            self.entry_flow.append(xception_block)
            s = s * stride
        self.stride = s
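
        # Middle flow. As in the entry flow, check_stride caps the running
        # stride product at output_stride; once the cap is reached, later
        # blocks fall back to stride 1.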
        self.block_num = bottleneck_params["middle_flow"][0]
        self.strides = bottleneck_params["middle_flow"][1]
        self.chns = bottleneck_params["middle_flow"][2]
        self.strides = check_data(self.strides, self.block_num)
        self.chns = check_data(self.chns, self.block_num)
        s = self.stride

        for i in range(self.block_num):
            stride = self.strides[i] if check_stride(s * self.strides[i],
                                                     self.output_stride) else 1
            xception_block = self.add_sublayer(
                self.backbone + "/middle_flow/block" + str(i + 1),
                Xception_Block(
                    input_channels=728,
                    output_channels=728,
                    strides=[1, 1, stride],
                    skip_conv=False,
                    name=self.backbone + "/middle_flow/block" + str(i + 1)))
            self.middle_flow.append(xception_block)
            s = s * stride
        self.stride = s

        # exit flow
        self.block_num = bottleneck_params["exit_flow"][0]
        self.strides = bottleneck_params["exit_flow"][1]
        self.chns = bottleneck_params["exit_flow"][2]
        self.strides = check_data(self.strides, self.block_num)
        self.chns = check_data(self.chns, self.block_num)
        s = self.stride
        stride = self.strides[0] if check_stride(s * self.strides[0],
                                                 self.output_stride) else 1
        self._exit_flow_1 = Xception_Block(
            728,
            self.chns[0], [1, 1, stride],
            name=self.backbone + "/exit_flow/block1")
        s = s * stride
        stride = self.strides[1] if check_stride(s * self.strides[1],
                                                 self.output_stride) else 1
        self._exit_flow_2 = Xception_Block(
            self.chns[0][-1],
            self.chns[1], [1, 1, stride],
            dilation=2,
            has_skip=False,
            activation_fn_in_separable_conv=True,
            name=self.backbone + "/exit_flow/block2")
        s = s * stride

        self.stride = s

        # classification head
        self._drop = Dropout(p=0.5, mode="downscale_in_infer")
        self._pool = AdaptiveAvgPool2D(1)
        self._fc = Linear(
            self.chns[1][-1],
            class_num,
            weight_attr=ParamAttr(name="fc_weights"),
            bias_attr=ParamAttr(name="fc_bias"))

    def forward(self, inputs):
        x = self._conv1(inputs)
        x = self._conv2(x)
        for ef in self.entry_flow:
            x = ef(x)
        for mf in self.middle_flow:
            x = mf(x)
        x = self._exit_flow_1(x)
        x = self._exit_flow_2(x)
        x = self._drop(x)
        x = self._pool(x)
        x = paddle.squeeze(x, axis=[2, 3])
        x = self._fc(x)
        return x
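

# Resolve the `pretrained` argument: False -> random init, True -> download
# the official weights from MODEL_URLS, str -> load from a local path.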
def _load_pretrained(pretrained, model, model_url, use_ssld=False):
    if pretrained is False:
        pass
    elif pretrained is True:
        load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
    elif isinstance(pretrained, str):
        load_dygraph_pretrain(model, pretrained)
    else:
        raise RuntimeError(
            "The type of `pretrained` must be bool or str.")


def Xception41_deeplab(pretrained=False, use_ssld=False, **kwargs):
    model = XceptionDeeplab('xception_41', **kwargs)
    _load_pretrained(
        pretrained, model, MODEL_URLS["Xception41_deeplab"], use_ssld=use_ssld)
    return model


def Xception65_deeplab(pretrained=False, use_ssld=False, **kwargs):
    model = XceptionDeeplab("xception_65", **kwargs)
    _load_pretrained(
        pretrained, model, MODEL_URLS["Xception65_deeplab"], use_ssld=use_ssld)
    return model
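

# Usage sketch (illustrative only): build a backbone and run a forward pass.
# Assumes PaddlePaddle is installed and this module is imported from its
# PaddleClas package path; the 224x224 input size is the usual ImageNet
# setting, not something this file enforces.
#
#     import paddle
#     from ppcls.arch.backbone.model_zoo.xception_deeplab import Xception65_deeplab
#
#     model = Xception65_deeplab(pretrained=False, class_num=1000)
#     x = paddle.randn([1, 3, 224, 224])
#     logits = model(x)  # shape: [1, 1000]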