mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
Merge pull request #220 from wqz960/PaddleClas_resnest
add ResNeSt, config and docs
This commit is contained in:
commit
47fbfe7d37
78
configs/ResNeSt/ResNeSt50.yaml
Normal file
78
configs/ResNeSt/ResNeSt50.yaml
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
mode: 'train'
|
||||||
|
ARCHITECTURE:
|
||||||
|
name: 'ResNeSt50'
|
||||||
|
|
||||||
|
pretrained_model: ""
|
||||||
|
model_save_dir: "./output/"
|
||||||
|
classes_num: 1000
|
||||||
|
total_images: 1281167
|
||||||
|
save_interval: 1
|
||||||
|
validate: True
|
||||||
|
valid_interval: 1
|
||||||
|
epochs: 300
|
||||||
|
topk: 5
|
||||||
|
image_shape: [3, 224, 224]
|
||||||
|
|
||||||
|
use_mix: True
|
||||||
|
ls_epsilon: 0.1
|
||||||
|
|
||||||
|
LEARNING_RATE:
|
||||||
|
function: 'CosineWarmup'
|
||||||
|
params:
|
||||||
|
lr: 0.1
|
||||||
|
|
||||||
|
OPTIMIZER:
|
||||||
|
function: 'Momentum'
|
||||||
|
params:
|
||||||
|
momentum: 0.9
|
||||||
|
regularizer:
|
||||||
|
function: 'L2'
|
||||||
|
factor: 0.000070
|
||||||
|
|
||||||
|
TRAIN:
|
||||||
|
batch_size: 256
|
||||||
|
num_workers: 4
|
||||||
|
file_list: "./dataset/ILSVRC2012/train_list.txt"
|
||||||
|
data_dir: "./dataset/ILSVRC2012/"
|
||||||
|
shuffle_seed: 0
|
||||||
|
transforms:
|
||||||
|
- DecodeImage:
|
||||||
|
to_rgb: True
|
||||||
|
to_np: False
|
||||||
|
channel_first: False
|
||||||
|
- RandCropImage:
|
||||||
|
size: 224
|
||||||
|
- RandFlipImage:
|
||||||
|
flip_code: 1
|
||||||
|
- AutoAugment:
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1./255.
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: ''
|
||||||
|
- ToCHWImage:
|
||||||
|
mix:
|
||||||
|
- CutmixOperator:
|
||||||
|
alpha: 0.2
|
||||||
|
|
||||||
|
VALID:
|
||||||
|
batch_size: 64
|
||||||
|
num_workers: 4
|
||||||
|
file_list: "./dataset/ILSVRC2012/val_list.txt"
|
||||||
|
data_dir: "./dataset/ILSVRC2012/"
|
||||||
|
shuffle_seed: 0
|
||||||
|
transforms:
|
||||||
|
- DecodeImage:
|
||||||
|
to_rgb: True
|
||||||
|
to_np: False
|
||||||
|
channel_first: False
|
||||||
|
- ResizeImage:
|
||||||
|
resize_short: 256
|
||||||
|
- CropImage:
|
||||||
|
size: 224
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1.0/255.0
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: ''
|
||||||
|
- ToCHWImage:
|
9
docs/en/models/ResNeSt_RegNet_en.md
Normal file
9
docs/en/models/ResNeSt_RegNet_en.md
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
## Overview
|
||||||
|
|
||||||
|
The ResNeSt series was proposed in 2020. The original resnet network structure has been improved by introducing K groups and adding an attention module similar to SEBlock in different groups, the accuracy is greater than that of the basic model ResNet, but the parameter amount and flops are almost the same as the basic ResNet.
|
||||||
|
|
||||||
|
## Accuracy, FLOPs and Parameters
|
||||||
|
|
||||||
|
| Models | Top1 | Top5 | Reference<br>top1 | Reference<br>top5 | FLOPS<br>(G) | Parameters<br>(M) |
|
||||||
|
|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
|
||||||
|
| ResNeSt50 | 0.8102 | 0.9542| 0.8113 | -|5.39 | 27.5 |
|
14
docs/zh_CN/models/ResNeSt_RegNet.md
Normal file
14
docs/zh_CN/models/ResNeSt_RegNet.md
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
# ResNeSt以及RegNet网络
|
||||||
|
|
||||||
|
## 概述
|
||||||
|
|
||||||
|
ResNeSt系列模型是在2020年提出的,在原有的resnet网络结构上做了改进,通过引入K个Group和在不同Group中加入类似于SEBlock的attention模块,使得精度相比于基础模型ResNet有了大幅度的提高,且参数量和flops与基础的ResNet基本保持一致。
|
||||||
|
|
||||||
|
|
||||||
|
## 精度、FLOPS和参数量
|
||||||
|
|
||||||
|
| Models | Top1 | Top5 | Reference<br>top1 | Reference<br>top5 | FLOPS<br>(G) | Parameters<br>(M) |
|
||||||
|
|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
|
||||||
|
| ResNeSt50 | 0.8102 | 0.9542| 0.8113 | -|5.39 | 27.5 |
|
||||||
|
|
||||||
|
|
@ -47,6 +47,7 @@ from .hrnet import HRNet_W18_C, HRNet_W30_C, HRNet_W32_C, HRNet_W40_C, HRNet_W44
|
|||||||
from .darts_gs import DARTS_GS_6M, DARTS_GS_4M
|
from .darts_gs import DARTS_GS_6M, DARTS_GS_4M
|
||||||
from .resnet_acnet import ResNet18_ACNet, ResNet34_ACNet, ResNet50_ACNet, ResNet101_ACNet, ResNet152_ACNet
|
from .resnet_acnet import ResNet18_ACNet, ResNet34_ACNet, ResNet50_ACNet, ResNet101_ACNet, ResNet152_ACNet
|
||||||
from .ghostnet import GhostNet_x0_5, GhostNet_x1_0, GhostNet_x1_3
|
from .ghostnet import GhostNet_x0_5, GhostNet_x1_0, GhostNet_x1_3
|
||||||
|
from .resnest import ResNeSt50, ResNeSt101, ResNeSt200, ResNeSt269, ResNeSt50_fast_1s1x64d, ResNeSt50_fast_2s1x64d, ResNeSt50_fast_4s1x64d, ResNeSt50_fast_1s2x40d, ResNeSt50_fast_2s2x40d, ResNeSt50_fast_2s2x40d, ResNeSt50_fast_4s2x40d, ResNeSt50_fast_1s4x24d
|
||||||
|
|
||||||
# distillation model
|
# distillation model
|
||||||
from .distillation_models import ResNet50_vd_distill_MobileNetV3_large_x1_0, ResNeXt101_32x16d_wsl_distill_ResNet50_vd
|
from .distillation_models import ResNet50_vd_distill_MobileNetV3_large_x1_0, ResNeXt101_32x16d_wsl_distill_ResNet50_vd
|
||||||
|
648
ppcls/modeling/architectures/resnest.py
Normal file
648
ppcls/modeling/architectures/resnest.py
Normal file
@ -0,0 +1,648 @@
|
|||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import paddle.fluid as fluid
|
||||||
|
from paddle.fluid.initializer import MSRA, ConstantInitializer
|
||||||
|
from paddle.fluid.param_attr import ParamAttr
|
||||||
|
from paddle.fluid.regularizer import L2DecayRegularizer
|
||||||
|
import math
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'ResNeSt50', 'ResNeSt101', 'ResNeSt200', 'ResNeSt269',
|
||||||
|
'ResNeSt50_fast_1s1x64d', 'ResNeSt50_fast_2s1x64d',
|
||||||
|
'ResNeSt50_fast_4s1x64d', 'ResNeSt50_fast_1s2x40d',
|
||||||
|
'ResNeSt50_fast_2s2x40d', 'ResNeSt50_fast_2s2x40d',
|
||||||
|
'ResNeSt50_fast_4s2x40d', 'ResNeSt50_fast_1s4x24d'
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class ResNeSt():
|
||||||
|
def __init__(self,
|
||||||
|
layers,
|
||||||
|
radix=1,
|
||||||
|
groups=1,
|
||||||
|
bottleneck_width=64,
|
||||||
|
dilated=False,
|
||||||
|
dilation=1,
|
||||||
|
deep_stem=False,
|
||||||
|
stem_width=64,
|
||||||
|
avg_down=False,
|
||||||
|
rectify_avg=False,
|
||||||
|
avd=False,
|
||||||
|
avd_first=False,
|
||||||
|
final_drop=0.0,
|
||||||
|
last_gamma=False,
|
||||||
|
bn_decay=0.0):
|
||||||
|
self.cardinality = groups
|
||||||
|
self.bottleneck_width = bottleneck_width
|
||||||
|
# ResNet-D params
|
||||||
|
self.inplanes = stem_width * 2 if deep_stem else 64
|
||||||
|
self.avg_down = avg_down
|
||||||
|
self.last_gamma = last_gamma
|
||||||
|
# ResNeSt params
|
||||||
|
self.radix = radix
|
||||||
|
self.avd = avd
|
||||||
|
self.avd_first = avd_first
|
||||||
|
|
||||||
|
self.deep_stem = deep_stem
|
||||||
|
self.stem_width = stem_width
|
||||||
|
self.layers = layers
|
||||||
|
self.final_drop = final_drop
|
||||||
|
self.dilated = dilated
|
||||||
|
self.dilation = dilation
|
||||||
|
self.bn_decay = bn_decay
|
||||||
|
|
||||||
|
self.rectify_avg = rectify_avg
|
||||||
|
|
||||||
|
def net(self, input, class_dim=1000):
|
||||||
|
if self.deep_stem:
|
||||||
|
x = self.conv_bn_layer(
|
||||||
|
x=input,
|
||||||
|
num_filters=self.stem_width,
|
||||||
|
filters_size=3,
|
||||||
|
stride=2,
|
||||||
|
groups=1,
|
||||||
|
act="relu",
|
||||||
|
name="conv1")
|
||||||
|
x = self.conv_bn_layer(
|
||||||
|
x=x,
|
||||||
|
num_filters=self.stem_width,
|
||||||
|
filters_size=3,
|
||||||
|
stride=1,
|
||||||
|
groups=1,
|
||||||
|
act="relu",
|
||||||
|
name="conv2")
|
||||||
|
x = self.conv_bn_layer(
|
||||||
|
x=x,
|
||||||
|
num_filters=self.stem_width * 2,
|
||||||
|
filters_size=3,
|
||||||
|
stride=1,
|
||||||
|
groups=1,
|
||||||
|
act="relu",
|
||||||
|
name="conv3")
|
||||||
|
else:
|
||||||
|
x = self.conv_bn_layer(
|
||||||
|
x=input,
|
||||||
|
num_filters=64,
|
||||||
|
filters_size=7,
|
||||||
|
stride=2,
|
||||||
|
act="relu",
|
||||||
|
name="conv1")
|
||||||
|
|
||||||
|
x = fluid.layers.pool2d(
|
||||||
|
input=x,
|
||||||
|
pool_size=3,
|
||||||
|
pool_type="max",
|
||||||
|
pool_stride=2,
|
||||||
|
pool_padding=1)
|
||||||
|
|
||||||
|
x = self.resnest_layer(
|
||||||
|
x=x,
|
||||||
|
planes=64,
|
||||||
|
blocks=self.layers[0],
|
||||||
|
is_first=False,
|
||||||
|
name="layer1")
|
||||||
|
x = self.resnest_layer(
|
||||||
|
x=x,
|
||||||
|
planes=128,
|
||||||
|
blocks=self.layers[1],
|
||||||
|
stride=2,
|
||||||
|
name="layer2")
|
||||||
|
if self.dilated or self.dilation == 4:
|
||||||
|
x = self.resnest_layer(
|
||||||
|
x=x,
|
||||||
|
planes=256,
|
||||||
|
blocks=self.layers[2],
|
||||||
|
stride=1,
|
||||||
|
dilation=2,
|
||||||
|
name="layer3")
|
||||||
|
x = self.resnest_layer(
|
||||||
|
x=x,
|
||||||
|
planes=512,
|
||||||
|
blocks=self.layers[3],
|
||||||
|
stride=1,
|
||||||
|
dilation=4,
|
||||||
|
name="layer4")
|
||||||
|
elif self.dilation == 2:
|
||||||
|
x = self.resnest_layer(
|
||||||
|
x=x,
|
||||||
|
planes=256,
|
||||||
|
blocks=self.layers[2],
|
||||||
|
stride=2,
|
||||||
|
dilation=1,
|
||||||
|
name="layer3")
|
||||||
|
x = self.resnest_layer(
|
||||||
|
x=x,
|
||||||
|
planes=512,
|
||||||
|
blocks=self.layers[3],
|
||||||
|
stride=1,
|
||||||
|
dilation=2,
|
||||||
|
name="layer4")
|
||||||
|
else:
|
||||||
|
x = self.resnest_layer(
|
||||||
|
x=x,
|
||||||
|
planes=256,
|
||||||
|
blocks=self.layers[2],
|
||||||
|
stride=2,
|
||||||
|
name="layer3")
|
||||||
|
x = self.resnest_layer(
|
||||||
|
x=x,
|
||||||
|
planes=512,
|
||||||
|
blocks=self.layers[3],
|
||||||
|
stride=2,
|
||||||
|
name="layer4")
|
||||||
|
x = fluid.layers.pool2d(
|
||||||
|
input=x, pool_type="avg", global_pooling=True)
|
||||||
|
x = fluid.layers.dropout(
|
||||||
|
x=x, dropout_prob=self.final_drop)
|
||||||
|
stdv = 1.0 / math.sqrt(x.shape[1] * 1.0)
|
||||||
|
x = fluid.layers.fc(
|
||||||
|
input=x,
|
||||||
|
size=class_dim,
|
||||||
|
param_attr=ParamAttr(
|
||||||
|
name="fc_weights",
|
||||||
|
initializer=fluid.initializer.Uniform(-stdv, stdv)),
|
||||||
|
bias_attr=ParamAttr(name="fc_offset"))
|
||||||
|
return x
|
||||||
|
|
||||||
|
def conv_bn_layer(self,
|
||||||
|
x,
|
||||||
|
num_filters,
|
||||||
|
filters_size,
|
||||||
|
stride=1,
|
||||||
|
groups=1,
|
||||||
|
act=None,
|
||||||
|
name=None):
|
||||||
|
x = fluid.layers.conv2d(
|
||||||
|
input=x,
|
||||||
|
num_filters=num_filters,
|
||||||
|
filter_size=filters_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=(filters_size - 1) // 2,
|
||||||
|
groups=groups,
|
||||||
|
act=None,
|
||||||
|
param_attr=ParamAttr(
|
||||||
|
initializer=MSRA(), name=name + "_weight"),
|
||||||
|
bias_attr=False)
|
||||||
|
x = fluid.layers.batch_norm(
|
||||||
|
input=x,
|
||||||
|
act=act,
|
||||||
|
param_attr=ParamAttr(
|
||||||
|
name=name + "_scale",
|
||||||
|
regularizer=L2DecayRegularizer(
|
||||||
|
regularization_coeff=self.bn_decay)),
|
||||||
|
bias_attr=ParamAttr(
|
||||||
|
name=name + "_offset",
|
||||||
|
regularizer=L2DecayRegularizer(
|
||||||
|
regularization_coeff=self.bn_decay)),
|
||||||
|
moving_mean_name=name + "_mean",
|
||||||
|
moving_variance_name=name + "_variance")
|
||||||
|
return x
|
||||||
|
|
||||||
|
def rsoftmax(self, x, radix, cardinality):
|
||||||
|
batch, r, h, w = x.shape
|
||||||
|
if radix > 1:
|
||||||
|
x = fluid.layers.reshape(
|
||||||
|
x=x,
|
||||||
|
shape=[
|
||||||
|
0, cardinality, radix, int(r * h * w / cardinality / radix)
|
||||||
|
])
|
||||||
|
x = fluid.layers.transpose(x=x, perm=[0, 2, 1, 3])
|
||||||
|
x = fluid.layers.softmax(input=x, axis=1)
|
||||||
|
x = fluid.layers.reshape(x=x, shape=[0, r * h * w])
|
||||||
|
else:
|
||||||
|
x = fluid.layers.sigmoid(x=x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def splat_conv(self,
|
||||||
|
x,
|
||||||
|
in_channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
dilation=1,
|
||||||
|
groups=1,
|
||||||
|
bias=True,
|
||||||
|
radix=2,
|
||||||
|
reduction_factor=4,
|
||||||
|
rectify_avg=False,
|
||||||
|
name=None):
|
||||||
|
x = self.conv_bn_layer(
|
||||||
|
x=x,
|
||||||
|
num_filters=channels * radix,
|
||||||
|
filters_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
groups=groups * radix,
|
||||||
|
act="relu",
|
||||||
|
name=name + "_splat1")
|
||||||
|
|
||||||
|
batch, rchannel = x.shape[:2]
|
||||||
|
if radix > 1:
|
||||||
|
splited = fluid.layers.split(input=x, num_or_sections=radix, dim=1)
|
||||||
|
gap = fluid.layers.sum(x=splited)
|
||||||
|
else:
|
||||||
|
gap = x
|
||||||
|
gap = fluid.layers.pool2d(
|
||||||
|
input=gap, pool_type="avg", global_pooling=True)
|
||||||
|
inter_channels = int(max(in_channels * radix // reduction_factor, 32))
|
||||||
|
gap = self.conv_bn_layer(
|
||||||
|
x=gap,
|
||||||
|
num_filters=inter_channels,
|
||||||
|
filters_size=1,
|
||||||
|
groups=groups,
|
||||||
|
act="relu",
|
||||||
|
name=name + "_splat2")
|
||||||
|
|
||||||
|
atten = fluid.layers.conv2d(
|
||||||
|
input=gap,
|
||||||
|
num_filters=channels * radix,
|
||||||
|
filter_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
groups=groups,
|
||||||
|
act=None,
|
||||||
|
param_attr=ParamAttr(
|
||||||
|
name=name + "_splat_weights", initializer=MSRA()),
|
||||||
|
bias_attr=False)
|
||||||
|
atten = self.rsoftmax(
|
||||||
|
x=atten, radix=radix, cardinality=groups)
|
||||||
|
atten = fluid.layers.reshape(x=atten, shape=[-1, atten.shape[1], 1, 1])
|
||||||
|
|
||||||
|
if radix > 1:
|
||||||
|
attens = fluid.layers.split(
|
||||||
|
input=atten, num_or_sections=radix, dim=1)
|
||||||
|
out = fluid.layers.sum([
|
||||||
|
fluid.layers.elementwise_mul(
|
||||||
|
x=att, y=split) for (att, split) in zip(attens, splited)
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
out = fluid.layers.elementwise_mul(atten, x)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def bottleneck(self,
|
||||||
|
x,
|
||||||
|
inplanes,
|
||||||
|
planes,
|
||||||
|
stride=1,
|
||||||
|
radix=1,
|
||||||
|
cardinality=1,
|
||||||
|
bottleneck_width=64,
|
||||||
|
avd=False,
|
||||||
|
avd_first=False,
|
||||||
|
dilation=1,
|
||||||
|
is_first=False,
|
||||||
|
rectify_avg=False,
|
||||||
|
last_gamma=False,
|
||||||
|
name=None):
|
||||||
|
|
||||||
|
short = x
|
||||||
|
|
||||||
|
group_width = int(planes * (bottleneck_width / 64.)) * cardinality
|
||||||
|
x = self.conv_bn_layer(
|
||||||
|
x=x,
|
||||||
|
num_filters=group_width,
|
||||||
|
filters_size=1,
|
||||||
|
stride=1,
|
||||||
|
groups=1,
|
||||||
|
act="relu",
|
||||||
|
name=name + "_conv1")
|
||||||
|
if avd and avd_first and (stride > 1 or is_first):
|
||||||
|
x = fluid.layers.pool2d(
|
||||||
|
input=x,
|
||||||
|
pool_size=3,
|
||||||
|
pool_type="avg",
|
||||||
|
pool_stride=stride,
|
||||||
|
pool_padding=1)
|
||||||
|
if radix >= 1:
|
||||||
|
x = self.splat_conv(
|
||||||
|
x=x,
|
||||||
|
in_channels=group_width,
|
||||||
|
channels=group_width,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=dilation,
|
||||||
|
dilation=dilation,
|
||||||
|
groups=cardinality,
|
||||||
|
bias=False,
|
||||||
|
radix=radix,
|
||||||
|
rectify_avg=rectify_avg,
|
||||||
|
name=name + "_splatconv")
|
||||||
|
else:
|
||||||
|
x = self.conv_bn_layer(
|
||||||
|
x=x,
|
||||||
|
num_filters=group_width,
|
||||||
|
filters_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=dilation,
|
||||||
|
dilation=dialtion,
|
||||||
|
groups=cardinality,
|
||||||
|
act="relu",
|
||||||
|
name=name + "_conv2")
|
||||||
|
|
||||||
|
if avd and avd_first == False and (stride > 1 or is_first):
|
||||||
|
x = fluid.layers.pool2d(
|
||||||
|
input=x,
|
||||||
|
pool_size=3,
|
||||||
|
pool_type="avg",
|
||||||
|
pool_stride=stride,
|
||||||
|
pool_padding=1)
|
||||||
|
x = self.conv_bn_layer(
|
||||||
|
x=x,
|
||||||
|
num_filters=planes * 4,
|
||||||
|
filters_size=1,
|
||||||
|
stride=1,
|
||||||
|
groups=1,
|
||||||
|
act=None,
|
||||||
|
name=name + "_conv3")
|
||||||
|
|
||||||
|
if stride != 1 or self.inplanes != planes * 4:
|
||||||
|
if self.avg_down:
|
||||||
|
if dilation == 1:
|
||||||
|
short = fluid.layers.pool2d(
|
||||||
|
input=short,
|
||||||
|
pool_size=stride,
|
||||||
|
pool_type="avg",
|
||||||
|
pool_stride=stride,
|
||||||
|
ceil_mode=True)
|
||||||
|
else:
|
||||||
|
short = fluid.layers.pool2d(
|
||||||
|
input=short,
|
||||||
|
pool_size=1,
|
||||||
|
pool_type="avg",
|
||||||
|
pool_stride=1,
|
||||||
|
ceil_mode=True)
|
||||||
|
short = fluid.layers.conv2d(
|
||||||
|
input=short,
|
||||||
|
num_filters=planes * 4,
|
||||||
|
filter_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
groups=1,
|
||||||
|
act=None,
|
||||||
|
param_attr=ParamAttr(
|
||||||
|
name=name + "_weights", initializer=MSRA()),
|
||||||
|
bias_attr=False)
|
||||||
|
else:
|
||||||
|
short = fluid.layers.conv2d(
|
||||||
|
input=short,
|
||||||
|
num_filters=planes * 4,
|
||||||
|
filter_size=1,
|
||||||
|
stride=stride,
|
||||||
|
param_attr=ParamAttr(
|
||||||
|
name=name + "_shortcut_weights", initializer=MSRA()),
|
||||||
|
bias_attr=False)
|
||||||
|
|
||||||
|
short = fluid.layers.batch_norm(
|
||||||
|
input=short,
|
||||||
|
act=None,
|
||||||
|
param_attr=ParamAttr(
|
||||||
|
name=name + "_shortcut_scale",
|
||||||
|
regularizer=L2DecayRegularizer(
|
||||||
|
regularization_coeff=self.bn_decay)),
|
||||||
|
bias_attr=ParamAttr(
|
||||||
|
name=name + "_shortcut_offset",
|
||||||
|
regularizer=L2DecayRegularizer(
|
||||||
|
regularization_coeff=self.bn_decay)),
|
||||||
|
moving_mean_name=name + "_shortcut_mean",
|
||||||
|
moving_variance_name=name + "_shortcut_variance")
|
||||||
|
|
||||||
|
return fluid.layers.elementwise_add(x=short, y=x, act="relu")
|
||||||
|
|
||||||
|
def resnest_layer(self,
|
||||||
|
x,
|
||||||
|
planes,
|
||||||
|
blocks,
|
||||||
|
stride=1,
|
||||||
|
dilation=1,
|
||||||
|
is_first=True,
|
||||||
|
name=None):
|
||||||
|
if dilation == 1 or dilation == 2:
|
||||||
|
x = self.bottleneck(
|
||||||
|
x=x,
|
||||||
|
inplanes=self.inplanes,
|
||||||
|
planes=planes,
|
||||||
|
stride=stride,
|
||||||
|
radix=self.radix,
|
||||||
|
cardinality=self.cardinality,
|
||||||
|
bottleneck_width=self.bottleneck_width,
|
||||||
|
avd=self.avd,
|
||||||
|
avd_first=self.avd_first,
|
||||||
|
dilation=1,
|
||||||
|
is_first=is_first,
|
||||||
|
rectify_avg=self.rectify_avg,
|
||||||
|
last_gamma=self.last_gamma,
|
||||||
|
name=name + "_bottleneck_0")
|
||||||
|
elif dilation == 4:
|
||||||
|
x = self.bottleneck(
|
||||||
|
x=x,
|
||||||
|
inplanes=self.inplanes,
|
||||||
|
planes=planes,
|
||||||
|
stride=stride,
|
||||||
|
radix=self.radix,
|
||||||
|
cardinality=self.cardinality,
|
||||||
|
bottleneck_width=self.bottleneck_width,
|
||||||
|
avd=self.avd,
|
||||||
|
avd_first=self.avd_first,
|
||||||
|
dilation=2,
|
||||||
|
is_first=is_first,
|
||||||
|
rectify_avg=self.rectify_avg,
|
||||||
|
last_gamma=self.last_gamma,
|
||||||
|
name=name + "_bottleneck_0")
|
||||||
|
else:
|
||||||
|
raise RuntimeError("=>unknown dilation size")
|
||||||
|
|
||||||
|
self.inplanes = planes * 4
|
||||||
|
for i in range(1, blocks):
|
||||||
|
name = name + "_bottleneck_" + str(i)
|
||||||
|
x = self.bottleneck(
|
||||||
|
x=x,
|
||||||
|
inplanes=self.inplanes,
|
||||||
|
planes=planes,
|
||||||
|
radix=self.radix,
|
||||||
|
cardinality=self.cardinality,
|
||||||
|
bottleneck_width=self.bottleneck_width,
|
||||||
|
avd=self.avd,
|
||||||
|
avd_first=self.avd_first,
|
||||||
|
dilation=dilation,
|
||||||
|
rectify_avg=self.rectify_avg,
|
||||||
|
last_gamma=self.last_gamma,
|
||||||
|
name=name)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def ResNeSt50(**args):
|
||||||
|
model = ResNeSt(
|
||||||
|
layers=[3, 4, 6, 3],
|
||||||
|
radix=2,
|
||||||
|
groups=1,
|
||||||
|
bottleneck_width=64,
|
||||||
|
deep_stem=True,
|
||||||
|
stem_width=32,
|
||||||
|
avg_down=True,
|
||||||
|
avd=True,
|
||||||
|
avd_first=False,
|
||||||
|
final_drop=0.0,
|
||||||
|
**args)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def ResNeSt101(**args):
|
||||||
|
model = ResNeSt(
|
||||||
|
layers=[3, 4, 23, 3],
|
||||||
|
radix=2,
|
||||||
|
groups=1,
|
||||||
|
bottleneck_width=64,
|
||||||
|
deep_stem=True,
|
||||||
|
stem_width=64,
|
||||||
|
avg_down=True,
|
||||||
|
avd=True,
|
||||||
|
avd_first=False,
|
||||||
|
final_drop=0.0,
|
||||||
|
**args)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def ResNeSt200(**args):
|
||||||
|
model = ResNeSt(
|
||||||
|
layers=[3, 24, 36, 3],
|
||||||
|
radix=2,
|
||||||
|
groups=1,
|
||||||
|
bottleneck_width=64,
|
||||||
|
deep_stem=True,
|
||||||
|
stem_width=64,
|
||||||
|
avg_down=True,
|
||||||
|
avd=True,
|
||||||
|
avd_first=False,
|
||||||
|
final_drop=0.2,
|
||||||
|
**args)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def ResNeSt269(**args):
|
||||||
|
model = ResNeSt(
|
||||||
|
layers=[3, 30, 48, 8],
|
||||||
|
radix=2,
|
||||||
|
groups=1,
|
||||||
|
bottleneck_width=64,
|
||||||
|
deep_stem=True,
|
||||||
|
stem_width=64,
|
||||||
|
avg_down=True,
|
||||||
|
avd=True,
|
||||||
|
avd_first=False,
|
||||||
|
final_drop=0.2,
|
||||||
|
**args)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def ResNeSt50_fast_1s1x64d(**args):
|
||||||
|
model = ResNeSt(
|
||||||
|
layers=[3, 4, 6, 3],
|
||||||
|
radix=1,
|
||||||
|
groups=1,
|
||||||
|
bottleneck_width=64,
|
||||||
|
deep_stem=True,
|
||||||
|
stem_width=32,
|
||||||
|
avg_down=True,
|
||||||
|
avd=True,
|
||||||
|
avd_first=True,
|
||||||
|
final_drop=0.0,
|
||||||
|
**args)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def ResNeSt50_fast_2s1x64d(**args):
|
||||||
|
model = ResNeSt(
|
||||||
|
layers=[3, 4, 6, 3],
|
||||||
|
radix=2,
|
||||||
|
groups=1,
|
||||||
|
bottleneck_width=64,
|
||||||
|
deep_stem=True,
|
||||||
|
stem_width=32,
|
||||||
|
avg_down=True,
|
||||||
|
avd=True,
|
||||||
|
avd_first=True,
|
||||||
|
final_drop=0.0,
|
||||||
|
**args)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def ResNeSt50_fast_4s1x64d(**args):
|
||||||
|
model = ResNeSt(
|
||||||
|
layers=[3, 4, 6, 3],
|
||||||
|
radix=2,
|
||||||
|
groups=1,
|
||||||
|
bottleneck_width=64,
|
||||||
|
deep_stem=True,
|
||||||
|
stem_width=32,
|
||||||
|
avg_down=True,
|
||||||
|
avd=True,
|
||||||
|
avd_first=True,
|
||||||
|
final_drop=0.0,
|
||||||
|
**args)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def ResNeSt50_fast_1s2x40d(**args):
|
||||||
|
model = ResNeSt(
|
||||||
|
layers=[3, 4, 6, 3],
|
||||||
|
radix=1,
|
||||||
|
groups=2,
|
||||||
|
bottleneck_width=40,
|
||||||
|
deep_stem=True,
|
||||||
|
stem_width=32,
|
||||||
|
avg_down=True,
|
||||||
|
avd=True,
|
||||||
|
avd_first=True,
|
||||||
|
final_drop=0.0,
|
||||||
|
**args)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def ResNeSt50_fast_2s2x40d(**args):
|
||||||
|
model = ResNeSt(
|
||||||
|
layers=[3, 4, 6, 3],
|
||||||
|
radix=2,
|
||||||
|
groups=2,
|
||||||
|
bottleneck_width=40,
|
||||||
|
deep_stem=True,
|
||||||
|
stem_width=32,
|
||||||
|
avg_down=True,
|
||||||
|
avd=True,
|
||||||
|
avd_first=True,
|
||||||
|
final_drop=0.0,
|
||||||
|
**args)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def ResNeSt50_fast_4s2x40d(**args):
|
||||||
|
model = ResNeSt(
|
||||||
|
layers=[3, 4, 6, 3],
|
||||||
|
radix=4,
|
||||||
|
groups=2,
|
||||||
|
bottleneck_width=40,
|
||||||
|
deep_stem=True,
|
||||||
|
stem_width=32,
|
||||||
|
avg_down=True,
|
||||||
|
avd=True,
|
||||||
|
avd_first=True,
|
||||||
|
final_drop=0.0,
|
||||||
|
**args)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def ResNeSt50_fast_1s4x24d(**args):
|
||||||
|
model = ResNeSt(
|
||||||
|
layers=[3, 4, 6, 3],
|
||||||
|
radix=1,
|
||||||
|
groups=4,
|
||||||
|
bottleneck_width=24,
|
||||||
|
deep_stem=True,
|
||||||
|
stem_width=32,
|
||||||
|
avg_down=True,
|
||||||
|
avd=True,
|
||||||
|
avd_first=True,
|
||||||
|
final_drop=0.0,
|
||||||
|
**args)
|
||||||
|
return model
|
Loading…
x
Reference in New Issue
Block a user