commit
81c9a6d3ea
configs/CSPNet
ppcls
modeling/architectures
tools
|
@ -0,0 +1,76 @@
|
|||
mode: 'train'
|
||||
ARCHITECTURE:
|
||||
name: 'CSPResNet50_leaky'
|
||||
|
||||
pretrained_model: ""
|
||||
model_save_dir: "./output/"
|
||||
classes_num: 1000
|
||||
total_images: 1281167
|
||||
save_interval: 1
|
||||
validate: True
|
||||
valid_interval: 1
|
||||
epochs: 120
|
||||
topk: 5
|
||||
image_shape: [3, 256, 256]
|
||||
|
||||
use_mix: False
|
||||
ls_epsilon: -1
|
||||
|
||||
LEARNING_RATE:
|
||||
function: 'Piecewise'
|
||||
params:
|
||||
lr: 0.1
|
||||
decay_epochs: [30, 60, 90]
|
||||
gamma: 0.1
|
||||
|
||||
OPTIMIZER:
|
||||
function: 'Momentum'
|
||||
params:
|
||||
momentum: 0.9
|
||||
regularizer:
|
||||
function: 'L2'
|
||||
factor: 0.000100
|
||||
|
||||
TRAIN:
|
||||
batch_size: 256
|
||||
num_workers: 4
|
||||
file_list: "./dataset/ILSVRC2012/train_list.txt"
|
||||
data_dir: "./dataset/ILSVRC2012/"
|
||||
shuffle_seed: 0
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
to_np: False
|
||||
channel_first: False
|
||||
- RandCropImage:
|
||||
size: 256
|
||||
- RandFlipImage:
|
||||
flip_code: 1
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
- ToCHWImage:
|
||||
|
||||
VALID:
|
||||
batch_size: 64
|
||||
num_workers: 4
|
||||
file_list: "./dataset/ILSVRC2012/val_list.txt"
|
||||
data_dir: "./dataset/ILSVRC2012/"
|
||||
shuffle_seed: 0
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
to_np: False
|
||||
channel_first: False
|
||||
- ResizeImage:
|
||||
resize_short: 256
|
||||
- CropImage:
|
||||
size: 256
|
||||
- NormalizeImage:
|
||||
scale: 1.0/255.0
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
- ToCHWImage:
|
|
@ -1,16 +1,16 @@
|
|||
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .alexnet import AlexNet
|
||||
from .mobilenet_v1 import MobileNetV1_x0_25, MobileNetV1_x0_5, MobileNetV1_x1_0, MobileNetV1_x0_75, MobileNetV1
|
||||
|
@ -45,3 +45,5 @@ from .resnet_acnet import ResNet18_ACNet, ResNet34_ACNet, ResNet50_ACNet, ResNet
|
|||
|
||||
# distillation model
|
||||
from .distillation_models import ResNet50_vd_distill_MobileNetV3_large_x1_0, ResNeXt101_32x16d_wsl_distill_ResNet50_vd
|
||||
|
||||
from .csp_resnet import CSPResNet50_leaky
|
|
@ -0,0 +1,258 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
import paddle.fluid as fluid
|
||||
from paddle.fluid.param_attr import ParamAttr
|
||||
|
||||
__all__ = [
|
||||
"CSPResNet50_leaky", "CSPResNet50_mish", "CSPResNet101_leaky",
|
||||
"CSPResNet101_mish"
|
||||
]
|
||||
|
||||
|
||||
class CSPResNet():
    """CSPResNet classification network assembled from ``paddle.fluid`` layers.

    Args:
        layers: Network depth; ``net`` accepts 50 or 101.
        act: Name of the activation applied by ``conv_bn_layer`` for the stem,
            downsample, route and transition convolutions. ``"relu"``,
            ``"leaky_relu"`` and ``"mish"`` are recognised; any other value
            leaves the batch-norm output without an activation.
    """

    def __init__(self, layers=50, act="leaky_relu"):
        self.layers = layers
        self.act = act

    @staticmethod
    def _channel_axis(data_format):
        """Return the channel axis of a 4-D feature map.

        ``"NCHW"`` maps to axis 1; every other value is treated as a
        channels-last layout (axis 3), mirroring the ``data_format`` check
        that ``shortcut`` has always performed.
        """
        return 1 if data_format == "NCHW" else 3

    def net(self, input, class_dim=1000, data_format="NCHW"):
        """Build the network graph and return the output of the final fc layer.

        Args:
            input: Input variable fed to the first convolution; presumably an
                image batch laid out according to ``data_format``.
            class_dim: Output size of the final fully connected layer.
            data_format: ``"NCHW"`` or a channels-last layout; forwarded to
                every conv / pool / batch-norm layer and used to pick the
                channel axis for the concat and the fc initializer.

        Returns:
            The variable produced by the final ``fluid.layers.fc`` call.

        Raises:
            AssertionError: If ``self.layers`` is not 50 or 101.
        """
        layers = self.layers
        supported_layers = [50, 101]
        assert layers in supported_layers, \
            "supported layers are {} but input layer is {}".format(
                supported_layers, layers)

        if layers == 50:
            depth = [3, 3, 5, 2]
        elif layers == 101:
            depth = [3, 3, 22, 2]

        num_filters = [64, 128, 256, 512]
        # The original code hard-coded axis 1 below, which is only the channel
        # axis for NCHW; resolve it once from data_format instead.
        channel_axis = self._channel_axis(data_format)

        conv = self.conv_bn_layer(
            input=input,
            num_filters=64,
            filter_size=7,
            stride=2,
            act=self.act,
            name="conv1",
            data_format=data_format)
        conv = fluid.layers.pool2d(
            input=conv,
            pool_size=2,
            pool_stride=2,
            pool_padding=0,
            pool_type='max',
            data_format=data_format)

        for block in range(len(depth)):
            # chr(97) == 'a': stage-level layers are named after the first
            # unit of the stage until the inner loop rebinds conv_name.
            conv_name = "res" + str(block + 2) + chr(97)
            if block != 0:
                conv = self.conv_bn_layer(
                    input=conv,
                    num_filters=num_filters[block],
                    filter_size=3,
                    stride=2,
                    act=self.act,
                    name=conv_name + "_downsample",
                    data_format=data_format)

            # split
            left = conv
            right = conv
            if block == 0:
                ch = num_filters[block]
            else:
                ch = num_filters[block] * 2
            right = self.conv_bn_layer(
                input=right,
                num_filters=ch,
                filter_size=1,
                act=self.act,
                name=conv_name + "_right_first_route",
                data_format=data_format)

            for i in range(depth[block]):
                conv_name = "res" + str(block + 2) + chr(97 + i)

                right = self.bottleneck_block(
                    input=right,
                    num_filters=num_filters[block],
                    stride=1,
                    name=conv_name,
                    data_format=data_format)

            # route
            # NOTE: conv_name now holds the name of the LAST bottleneck of the
            # stage, so the three layers below are named after that unit.
            # Parameter names depend on this; do not "fix" it.
            left = self.conv_bn_layer(
                input=left,
                num_filters=num_filters[block] * 2,
                filter_size=1,
                act=self.act,
                name=conv_name + "_left_route",
                data_format=data_format)
            right = self.conv_bn_layer(
                input=right,
                num_filters=num_filters[block] * 2,
                filter_size=1,
                act=self.act,
                name=conv_name + "_right_route",
                data_format=data_format)
            # Merge the two branches along the channel dimension.
            conv = fluid.layers.concat([left, right], axis=channel_axis)

            conv = self.conv_bn_layer(
                input=conv,
                num_filters=num_filters[block] * 2,
                filter_size=1,
                stride=1,
                act=self.act,
                name=conv_name + "_merged_transition",
                data_format=data_format)

        pool = fluid.layers.pool2d(
            input=conv,
            pool_type='avg',
            global_pooling=True,
            data_format=data_format)
        # Uniform init bound derived from the number of pooled channels.
        stdv = 1.0 / math.sqrt(pool.shape[channel_axis] * 1.0)
        out = fluid.layers.fc(
            input=pool,
            size=class_dim,
            param_attr=fluid.param_attr.ParamAttr(
                name="fc_0.w_0",
                initializer=fluid.initializer.Uniform(-stdv, stdv)),
            bias_attr=ParamAttr(name="fc_0.b_0"))
        return out

    def conv_bn_layer(self,
                      input,
                      num_filters,
                      filter_size,
                      stride=1,
                      groups=1,
                      act=None,
                      name=None,
                      data_format='NCHW'):
        """Apply a bias-free conv2d, batch norm, and an optional activation.

        Args:
            input: Input variable for the convolution.
            num_filters: Number of convolution output filters.
            filter_size: Kernel size; padding is ``(filter_size - 1) // 2``.
            stride: Convolution stride.
            groups: Convolution groups.
            act: ``"relu"``, ``"leaky_relu"``, ``"mish"`` or anything else for
                no activation.
            name: Base name for the layer's parameters. Required in practice:
                it is concatenated with string suffixes, so the ``None``
                default raises ``TypeError``.
            data_format: Layout forwarded to conv2d and batch_norm.

        Returns:
            The batch-norm output, after the selected activation if any.
        """
        conv = fluid.layers.conv2d(
            input=input,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=stride,
            padding=(filter_size - 1) // 2,
            groups=groups,
            act=None,
            param_attr=ParamAttr(name=name + "_weights"),
            bias_attr=False,
            name=name + '.conv2d.output.1',
            data_format=data_format)

        if name == "conv1":
            bn_name = "bn_" + name
        else:
            # Drop the first three characters of the name (the "res" prefix
            # for the names generated in this class) and prepend "bn".
            bn_name = "bn" + name[3:]
        bn = fluid.layers.batch_norm(
            input=conv,
            act=None,
            name=bn_name + '.output.1',
            param_attr=ParamAttr(name=bn_name + '_scale'),
            bias_attr=ParamAttr(bn_name + '_offset'),
            moving_mean_name=bn_name + '_mean',
            moving_variance_name=bn_name + '_variance',
            data_layout=data_format)
        if act == "relu":
            bn = fluid.layers.relu(bn)
        elif act == "leaky_relu":
            # NOTE(review): uses the library's default slope, whereas
            # bottleneck_block ends with alpha=0.1 - confirm this is intended.
            bn = fluid.layers.leaky_relu(bn)
        elif act == "mish":
            bn = self._mish(bn)
        return bn

    def _mish(self, input):
        """Return ``input * tanh(softplus(input))``."""
        return input * fluid.layers.tanh(self._softplus(input))

    def _softplus(self, input):
        """Return ``log(1 + exp(x))`` with ``x`` clipped to ``[-200, 50]``."""
        # Clipping bounds the argument passed to exp.
        expf = fluid.layers.exp(fluid.layers.clip(input, -200, 50))
        return fluid.layers.log(1 + expf)

    def shortcut(self, input, ch_out, stride, is_first, name, data_format):
        """Return the residual shortcut for ``input``.

        The input is returned unchanged when its channel count already equals
        ``ch_out``, ``stride`` is 1 and ``is_first`` is not ``True``; otherwise
        it is projected with a 1x1 conv + batch norm (no activation).
        """
        ch_in = input.shape[self._channel_axis(data_format)]
        if ch_in != ch_out or stride != 1 or is_first is True:
            return self.conv_bn_layer(
                input, ch_out, 1, stride, name=name, data_format=data_format)
        else:
            return input

    def bottleneck_block(self, input, num_filters, stride, name, data_format):
        """Build one residual bottleneck unit (1x1 -> 3x3 -> 1x1 plus shortcut).

        Args:
            input: Input variable.
            num_filters: Width of the first two convolutions; the last
                convolution and the shortcut use ``num_filters * 2``.
            stride: Stride of the 3x3 convolution and of the shortcut.
            name: Base name; ``_branch2a/b/c`` and ``_branch1`` are appended.
            data_format: Layout forwarded to every layer.

        Returns:
            ``leaky_relu(shortcut + conv2, alpha=0.1)``.
        """
        # NOTE(review): the activations in this block are fixed to leaky_relu
        # and do not follow self.act (so the "mish" variants still use
        # leaky_relu here) - confirm whether that is intentional.
        conv0 = self.conv_bn_layer(
            input=input,
            num_filters=num_filters,
            filter_size=1,
            act="leaky_relu",
            name=name + "_branch2a",
            data_format=data_format)
        conv1 = self.conv_bn_layer(
            input=conv0,
            num_filters=num_filters,
            filter_size=3,
            stride=stride,
            act="leaky_relu",
            name=name + "_branch2b",
            data_format=data_format)
        conv2 = self.conv_bn_layer(
            input=conv1,
            num_filters=num_filters * 2,
            filter_size=1,
            act=None,
            name=name + "_branch2c",
            data_format=data_format)

        short = self.shortcut(
            input,
            num_filters * 2,
            stride,
            is_first=False,
            name=name + "_branch1",
            data_format=data_format)

        ret = short + conv2
        ret = fluid.layers.leaky_relu(ret, alpha=0.1)
        return ret
|
||||
|
||||
|
||||
def CSPResNet50_leaky():
    """Return a ``CSPResNet`` configured with 50 layers and leaky_relu."""
    return CSPResNet(layers=50, act="leaky_relu")
|
||||
|
||||
|
||||
def CSPResNet50_mish():
    """Return a ``CSPResNet`` configured with 50 layers and mish."""
    return CSPResNet(layers=50, act="mish")
|
||||
|
||||
|
||||
def CSPResNet101_leaky():
    """Return a ``CSPResNet`` configured with 101 layers and leaky_relu."""
    return CSPResNet(layers=101, act="leaky_relu")
|
||||
|
||||
|
||||
def CSPResNet101_mish():
    """Return a ``CSPResNet`` configured with 101 layers and mish."""
    return CSPResNet(layers=101, act="mish")
|
|
@ -58,9 +58,9 @@ class RetryError(Exception):
|
|||
super(RetryError, self).__init__(message)
|
||||
|
||||
|
||||
def _get_url(architecture):
|
||||
def _get_url(architecture, postfix="tar"):
|
||||
prefix = "https://paddle-imagenet-models-name.bj.bcebos.com/"
|
||||
fname = architecture + "_pretrained.tar"
|
||||
fname = architecture + "_pretrained." + postfix
|
||||
return prefix + fname
|
||||
|
||||
|
||||
|
@ -193,13 +193,13 @@ def list_models():
|
|||
return
|
||||
|
||||
|
||||
def get(architecture, path, decompress=True):
|
||||
def get(architecture, path, decompress=True, postfix="tar"):
|
||||
"""
|
||||
Get the pretrained model.
|
||||
"""
|
||||
_check_pretrained_name(architecture)
|
||||
url = _get_url(architecture)
|
||||
url = _get_url(architecture, postfix=postfix)
|
||||
fname = _download(url, path)
|
||||
if decompress:
|
||||
if postfix == "tar" and decompress:
|
||||
_decompress(fname)
|
||||
logger.info("download {} finished ".format(fname))
|
||||
|
|
|
@ -116,3 +116,4 @@ VGG16
|
|||
VGG19
|
||||
DarkNet53_ImageNet1k
|
||||
ResNet50_ACNet_deploy
|
||||
CSPResNet50_leaky
|
||||
|
|
|
@ -24,6 +24,7 @@ def parse_args():
|
|||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-a', '--architecture', type=str, default='ResNet50')
|
||||
parser.add_argument('-p', '--path', type=str, default='./pretrained/')
|
||||
parser.add_argument('--postfix', type=str, default="tar")
|
||||
parser.add_argument('-d', '--decompress', type=str2bool, default=True)
|
||||
parser.add_argument('-l', '--list', type=str2bool, default=False)
|
||||
|
||||
|
@ -36,7 +37,8 @@ def main():
|
|||
if args.list:
|
||||
model_zoo.list_models()
|
||||
else:
|
||||
model_zoo.get(args.architecture, args.path, args.decompress)
|
||||
model_zoo.get(args.architecture, args.path, args.decompress,
|
||||
args.postfix)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue