legendary models v0.1

pull/756/head
dongshuilong 2021-05-31 16:01:02 +08:00
parent 4d246c20e4
commit 2fa808517d
7 changed files with 1544 additions and 834 deletions

View File

@ -0,0 +1,6 @@
from .resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, ResNet18_vd, ResNet34_vd, ResNet50_vd, ResNet101_vd, ResNet152_vd
from .hrnet import HRNet_W18_C, HRNet_W30_C, HRNet_W32_C, HRNet_W40_C, HRNet_W44_C, HRNet_W48_C, HRNet_W64_C
from .mobilenet_v1 import MobileNetV1_x0_25, MobileNetV1_x0_5, MobileNetV1_x0_75, MobileNetV1
from .mobilenet_v3 import MobileNetV3_small_x0_35, MobileNetV3_small_x0_5, MobileNetV3_small_x0_75, MobileNetV3_small_x1_0, MobileNetV3_small_x1_25, MobileNetV3_large_x0_35, MobileNetV3_large_x0_5, MobileNetV3_large_x0_75, MobileNetV3_large_x1_0, MobileNetV3_large_x1_25
from .inception_v3 import InceptionV3
from .vgg import VGG11, VGG13, VGG16, VGG19
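Assuming this new file is the legendary_models package __init__, every factory above becomes importable from one place; a minimal sketch (package path assumed from the repository layout, and the factories are assumed to default to pretrained=False like the ones in this PR):

from ppcls.arch.backbone.legendary_models import ResNet50, MobileNetV3_large_x1_0

model = ResNet50()                    # random initialization
model = MobileNetV3_large_x1_0()      # same pattern for any exported factory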

View File

@ -24,29 +24,40 @@ from paddle.nn.functional import upsample
from paddle.nn.initializer import Uniform
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer, Identity
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
MODEL_URLS = {
"HRNet_W18_C": "",
"HRNet_W30_C": "",
"HRNet_W32_C": "",
"HRNet_W40_C": "",
"HRNet_W44_C": "",
"HRNet_W48_C": "",
"HRNet_W60_C": "",
"HRNet_W64_C": "",
"SE_HRNet_W18_C": "",
"SE_HRNet_W30_C": "",
"SE_HRNet_W32_C": "",
"SE_HRNet_W40_C": "",
"SE_HRNet_W44_C": "",
"SE_HRNet_W48_C": "",
"SE_HRNet_W60_C": "",
"SE_HRNet_W64_C": "",
"HRNet_W18_C":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/HRNet_W18_C_pretrained.pdparams",
"HRNet_W30_C":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/HRNet_W30_C_pretrained.pdparams",
"HRNet_W32_C":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/HRNet_W32_C_pretrained.pdparams",
"HRNet_W40_C":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/HRNet_W40_C_pretrained.pdparams",
"HRNet_W44_C":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/HRNet_W44_C_pretrained.pdparams",
"HRNet_W48_C":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/HRNet_W48_C_pretrained.pdparams",
"HRNet_W64_C":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/HRNet_W64_C_pretrained.pdparams"
}
__all__ = list(MODEL_URLS.keys())
def _create_act(act):
if act == "hardswish":
return nn.Hardswish()
elif act == "relu":
return nn.ReLU()
elif act is None:
return Identity()
else:
raise RuntimeError(
"The activation function is not supported: {}".format(act))
class ConvBNLayer(TheseusLayer):
def __init__(self,
num_channels,
@ -55,7 +66,7 @@ class ConvBNLayer(TheseusLayer):
stride=1,
groups=1,
act="relu"):
super().__init__()
self.conv = nn.Conv2D(
in_channels=num_channels,
@ -65,10 +76,8 @@ class ConvBNLayer(TheseusLayer):
padding=(filter_size - 1) // 2,
groups=groups,
bias_attr=False)
self.bn = nn.BatchNorm(num_filters, act=None)
self.act = _create_act(act)
def forward(self, x):
x = self.conv(x)
@ -77,18 +86,6 @@ class ConvBNLayer(TheseusLayer):
return x
class BottleneckBlock(TheseusLayer):
def __init__(self,
num_channels,
@ -96,7 +93,7 @@ class BottleneckBlock(TheseusLayer):
has_se,
stride=1,
downsample=False):
super().__init__()
self.has_se = has_se
self.downsample = downsample
@ -147,11 +144,8 @@ class BottleneckBlock(TheseusLayer):
class BasicBlock(nn.Layer):
def __init__(self, num_channels, num_filters, has_se=False):
super().__init__()
self.has_se = has_se
@ -190,9 +184,9 @@ class BasicBlock(nn.Layer):
class SELayer(TheseusLayer):
def __init__(self, num_channels, num_filters, reduction_ratio):
super().__init__()
self.avg_pool = nn.AdaptiveAvgPool2D(1)
self._num_channels = num_channels
@ -201,8 +195,7 @@ class SELayer(TheseusLayer):
self.fc_squeeze = nn.Linear(
num_channels,
med_ch,
weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)))
self.relu = nn.ReLU()
stdv = 1.0 / math.sqrt(med_ch * 1.0)
self.fc_excitation = nn.Linear(
@ -213,7 +206,7 @@ class SELayer(TheseusLayer):
def forward(self, x, res_dict=None):
residual = x
x = self.avg_pool(x)
x = paddle.squeeze(x, axis=[2, 3])
x = self.fc_squeeze(x)
x = self.relu(x)
@ -225,11 +218,8 @@ class SELayer(TheseusLayer):
class Stage(TheseusLayer):
def __init__(self, num_modules, num_filters, has_se=False):
super().__init__()
self._num_modules = num_modules
@ -237,8 +227,7 @@ class Stage(TheseusLayer):
for i in range(num_modules):
self.stage_func_list.append(
HighResolutionModule(
num_filters=num_filters, has_se=has_se))
def forward(self, x, res_dict=None):
x = x
@ -248,10 +237,8 @@ class Stage(TheseusLayer):
class HighResolutionModule(TheseusLayer):
def __init__(self, num_filters, has_se=False):
super().__init__()
self.basic_block_list = nn.LayerList()
@ -261,11 +248,11 @@ class HighResolutionModule(TheseusLayer):
BasicBlock(
num_channels=num_filters[i],
num_filters=num_filters[i],
has_se=has_se) for j in range(4)
]))
self.fuse_func = FuseLayers(
in_channels=num_filters, out_channels=num_filters)
def forward(self, x, res_dict=None):
out = []
@ -279,10 +266,8 @@ class HighResolutionModule(TheseusLayer):
class FuseLayers(TheseusLayer):
def __init__(self, in_channels, out_channels):
super().__init__()
self._actual_ch = len(in_channels)
self._in_channels = in_channels
@ -352,7 +337,7 @@ class LastClsOut(TheseusLayer):
num_channel_list,
has_se,
num_filters_list=[32, 64, 128, 256]):
super().__init__()
self.func_list = nn.LayerList()
for idx in range(len(num_channel_list)):
@ -378,9 +363,12 @@ class HRNet(TheseusLayer):
width: int=18. Base channel number of HRNet.
has_se: bool=False. If 'True', add se module to HRNet.
class_num: int=1000. Output num of last fc layer.
Returns:
model: nn.Layer. Specific HRNet model depends on args.
"""
def __init__(self, width=18, has_se=False, class_num=1000):
super().__init__()
self.width = width
self.has_se = has_se
@ -388,21 +376,23 @@ class HRNet(TheseusLayer):
channels_2 = [self.width, self.width * 2]
channels_3 = [self.width, self.width * 2, self.width * 4]
channels_4 = [
self.width, self.width * 2, self.width * 4, self.width * 8
]
self.conv_layer1_1 = ConvBNLayer(
num_channels=3,
num_filters=64,
filter_size=3,
stride=2,
act="relu")
self.conv_layer1_2 = ConvBNLayer(
num_channels=64,
num_filters=64,
filter_size=3,
stride=2,
act="relu")
self.layer1 = nn.Sequential(*[
BottleneckBlock(
@ -410,48 +400,33 @@ class HRNet(TheseusLayer):
num_filters=64,
has_se=has_se,
stride=1,
downsample=True if i == 0 else False) for i in range(4)
])
self.conv_tr1_1 = ConvBNLayer(
num_channels=256, num_filters=width, filter_size=3)
self.conv_tr1_2 = ConvBNLayer(
num_channels=256, num_filters=width * 2, filter_size=3, stride=2)
self.st2 = Stage(
num_modules=1, num_filters=channels_2, has_se=self.has_se)
self.conv_tr2 = ConvBNLayer(
num_channels=width * 2,
num_filters=width * 4,
filter_size=3,
stride=2)
self.st3 = Stage(
num_modules=4, num_filters=channels_3, has_se=self.has_se)
self.conv_tr3 = ConvBNLayer(
num_channels=width * 4,
num_filters=width * 8,
filter_size=3,
stride=2)
self.st4 = Stage(
num_modules=3, num_filters=channels_4, has_se=self.has_se)
# classification
num_filters_list = [32, 64, 128, 256]
@ -464,17 +439,14 @@ class HRNet(TheseusLayer):
self.cls_head_conv_list = nn.LayerList()
for idx in range(3):
self.cls_head_conv_list.append(
ConvBNLayer(
num_channels=num_filters_list[idx] * 4,
num_filters=last_num_filters[idx],
filter_size=3,
stride=2))
self.conv_last = ConvBNLayer(
num_channels=1024, num_filters=2048, filter_size=1, stride=1)
self.avg_pool = nn.AdaptiveAvgPool2D(1)
@ -516,81 +488,254 @@ class HRNet(TheseusLayer):
return y
def _load_pretrained(pretrained, model, model_url, use_ssld):
if pretrained is False:
pass
elif pretrained is True:
load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
elif isinstance(pretrained, str):
load_dygraph_pretrain(model, pretrained)
else:
raise RuntimeError(
"pretrained type is not available. Please use `string` or `boolean` type."
)
def HRNet_W18_C(pretrained=False, use_ssld=False, **kwargs):
"""
HRNet_W18_C
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `HRNet_W18_C` model depends on args.
"""
model = HRNet(width=18, **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["HRNet_W18_C"], use_ssld)
return model
def HRNet_W30_C(pretrained=False, use_ssld=False, **kwargs):
"""
HRNet_W30_C
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `HRNet_W30_C` model depends on args.
"""
model = HRNet(width=30, **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["HRNet_W30_C"], use_ssld)
return model
def HRNet_W32_C(pretrained=False, use_ssld=False, **kwargs):
"""
HRNet_W32_C
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `HRNet_W32_C` model depends on args.
"""
model = HRNet(width=32, **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["HRNet_W32_C"], use_ssld)
return model
def HRNet_W40_C(pretrained=False, use_ssld=False, **kwargs):
"""
HRNet_W40_C
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `HRNet_W40_C` model depends on args.
"""
model = HRNet(width=40, **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["HRNet_W40_C"], use_ssld)
return model
def HRNet_W44_C(pretrained=False, use_ssld=False, **kwargs):
"""
HRNet_W44_C
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `HRNet_W44_C` model depends on args.
"""
model = HRNet(width=44, **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["HRNet_W44_C"], use_ssld)
return model
def HRNet_W48_C(pretrained=False, use_ssld=False, **kwargs):
"""
HRNet_W48_C
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `HRNet_W48_C` model depends on args.
"""
model = HRNet(width=48, **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["HRNet_W48_C"], use_ssld)
return model
def HRNet_W60_C(pretrained=False, use_ssld=False, **kwargs):
"""
HRNet_W60_C
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `HRNet_W60_C` model depends on args.
"""
model = HRNet(width=60, **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["HRNet_W60_C"], use_ssld)
return model
def HRNet_W64_C(pretrained=False, use_ssld=False, **kwargs):
"""
HRNet_W64_C
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `HRNet_W64_C` model depends on args.
"""
model = HRNet(width=64, **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["HRNet_W64_C"], use_ssld)
return model
def SE_HRNet_W18_C(pretrained=False, use_ssld=False, **kwargs):
"""
SE_HRNet_W18_C
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `SE_HRNet_W18_C` model depends on args.
"""
model = HRNet(width=18, has_se=True, **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["SE_HRNet_W18_C"], use_ssld)
return model
def SE_HRNet_W30_C(pretrained=False, use_ssld=False, **kwargs):
"""
SE_HRNet_W30_C
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `SE_HRNet_W30_C` model depends on args.
"""
model = HRNet(width=30, has_se=True, **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["SE_HRNet_W30_C"], use_ssld)
return model
def SE_HRNet_W32_C(pretrained=False, use_ssld=False, **kwargs):
"""
SE_HRNet_W32_C
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `SE_HRNet_W32_C` model depends on args.
"""
model = HRNet(width=32, has_se=True, **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["SE_HRNet_W32_C"], use_ssld)
return model
def SE_HRNet_W40_C(pretrained=False, use_ssld=False, **kwargs):
"""
SE_HRNet_W40_C
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `SE_HRNet_W40_C` model depends on args.
"""
model = HRNet(width=40, has_se=True, **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["SE_HRNet_W40_C"], use_ssld)
return model
def SE_HRNet_W44_C(pretrained=False, use_ssld=False, **kwargs):
"""
SE_HRNet_W44_C
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `SE_HRNet_W44_C` model depends on args.
"""
model = HRNet(width=44, has_se=True, **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["SE_HRNet_W44_C"], use_ssld)
return model
def SE_HRNet_W48_C(pretrained=False, use_ssld=False, **kwargs):
"""
SE_HRNet_W48_C
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `SE_HRNet_W48_C` model depends on args.
"""
model = HRNet(width=48, has_se=True, **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["SE_HRNet_W48_C"], use_ssld)
return model
def SE_HRNet_W60_C(pretrained=False, use_ssld=False, **kwargs):
"""
SE_HRNet_W60_C
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `SE_HRNet_W60_C` model depends on args.
"""
model = HRNet(width=60, has_se=True, **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["SE_HRNet_W60_C"], use_ssld)
return model
def SE_HRNet_W64_C(pretrained=False, use_ssld=False, **kwargs):
"""
SE_HRNet_W64_C
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `SE_HRNet_W64_C` model depends on args.
"""
model = HRNet(width=64, has_se=True, **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["SE_HRNet_W64_C"], use_ssld)
return model
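A minimal usage sketch for the factories above (import path assumed from the repository layout; with pretrained=True the matching MODEL_URLS entry must be reachable, and the local .pdparams path shown is hypothetical):

import paddle
from ppcls.arch.backbone.legendary_models.hrnet import HRNet_W18_C, HRNet_W48_C

model = HRNet_W18_C(pretrained=True)   # download and load ImageNet weights
# model = HRNet_W48_C(pretrained="./HRNet_W48_C_pretrained.pdparams")  # or a local checkpoint
x = paddle.rand([1, 3, 224, 224])      # NCHW input batch
y = model(x)                           # class logits, shape [1, 1000]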

View File

@ -13,39 +13,37 @@
# limitations under the License.
from __future__ import absolute_import, division, print_function
import math
import paddle
from paddle import ParamAttr
import paddle.nn as nn
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.nn.initializer import Uniform
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
MODEL_URLS = {
"InceptionV3": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/InceptionV3_pretrained.pdparams",
"InceptionV3":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/InceptionV3_pretrained.pdparams"
}
__all__ = list(MODEL_URLS.keys())
'''
InceptionV3 config: dict.
key: inception blocks of InceptionV3.
values: channel number settings of the convs in each block.
'''
NET_CONFIG = {
"inception_a": [[192, 256, 288], [32, 64, 64]],
"inception_b": [288],
"inception_c": [[768, 768, 768, 768], [128, 160, 160, 192]],
"inception_d": [768],
"inception_e": [1280, 2048]
}
class ConvBNLayer(TheseusLayer):
def __init__(self,
num_channels,
@ -55,7 +53,7 @@ class ConvBNLayer(TheseusLayer):
padding=0,
groups=1,
act="relu"):
super().__init__()
self.act = act
self.conv = Conv2D(
in_channels=num_channels,
@ -65,92 +63,100 @@ class ConvBNLayer(TheseusLayer):
padding=padding,
groups=groups,
bias_attr=False)
self.bn = BatchNorm(num_filters)
self.relu = nn.ReLU()
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.act:
x = self.relu(x)
return x
class InceptionStem(TheseusLayer):
def __init__(self):
super().__init__()
self.conv_1a_3x3 = ConvBNLayer(
num_channels=3,
num_filters=32,
filter_size=3,
stride=2,
act="relu")
self.conv_2a_3x3 = ConvBNLayer(
num_channels=32,
num_filters=32,
filter_size=3,
stride=1,
act="relu")
self.conv_2b_3x3 = ConvBNLayer(
num_channels=32,
num_filters=64,
filter_size=3,
padding=1,
act="relu")
self.max_pool = MaxPool2D(kernel_size=3, stride=2, padding=0)
self.conv_3b_1x1 = ConvBNLayer(
num_channels=64, num_filters=80, filter_size=1, act="relu")
self.conv_4a_3x3 = ConvBNLayer(
num_channels=80, num_filters=192, filter_size=3, act="relu")
def forward(self, x):
x = self.conv_1a_3x3(x)
x = self.conv_2a_3x3(x)
x = self.conv_2b_3x3(x)
x = self.max_pool(x)
x = self.conv_3b_1x1(x)
x = self.conv_4a_3x3(x)
x = self.max_pool(x)
return x
class InceptionA(TheseusLayer):
def __init__(self, num_channels, pool_features):
super().__init__()
self.branch1x1 = ConvBNLayer(
num_channels=num_channels,
num_filters=64,
filter_size=1,
act="relu")
self.branch5x5_1 = ConvBNLayer(
num_channels=num_channels,
num_filters=48,
filter_size=1,
act="relu")
self.branch5x5_2 = ConvBNLayer(
num_channels=48,
num_filters=64,
filter_size=5,
padding=2,
act="relu")
self.branch3x3dbl_1 = ConvBNLayer(
num_channels=num_channels,
num_filters=64,
filter_size=1,
act="relu")
self.branch3x3dbl_2 = ConvBNLayer(
num_channels=64,
num_filters=96,
filter_size=3,
padding=1,
act="relu")
self.branch3x3dbl_3 = ConvBNLayer(
num_channels=96,
num_filters=96,
filter_size=3,
padding=1,
act="relu")
self.branch_pool = AvgPool2D(
kernel_size=3, stride=1, padding=1, exclusive=False)
self.branch_pool_conv = ConvBNLayer(
num_channels=num_channels,
num_filters=pool_features,
filter_size=1,
act="relu")
def forward(self, x):
branch1x1 = self.branch1x1(x)
@ -163,34 +169,39 @@ class InceptionA(TheseusLayer):
branch_pool = self.branch_pool(x)
branch_pool = self.branch_pool_conv(branch_pool)
x = paddle.concat(
[branch1x1, branch5x5, branch3x3dbl, branch_pool], axis=1)
return x
class InceptionB(TheseusLayer):
def __init__(self, num_channels):
super().__init__()
self.branch3x3 = ConvBNLayer(
num_channels=num_channels,
num_filters=384,
filter_size=3,
stride=2,
act="relu")
self.branch3x3dbl_1 = ConvBNLayer(
num_channels=num_channels,
num_filters=64,
filter_size=1,
act="relu")
self.branch3x3dbl_2 = ConvBNLayer(
num_channels=64,
num_filters=96,
filter_size=3,
padding=1,
act="relu")
self.branch3x3dbl_3 = ConvBNLayer(
num_channels=96,
num_filters=96,
filter_size=3,
stride=2,
act="relu")
self.branch_pool = MaxPool2D(kernel_size=3, stride=2)
def forward(self, x):
branch3x3 = self.branch3x3(x)
@ -204,64 +215,75 @@ class InceptionB(TheseusLayer):
return x
class InceptionC(TheseusLayer):
def __init__(self, num_channels, channels_7x7):
super().__init__()
self.branch1x1 = ConvBNLayer(
num_channels=num_channels,
num_filters=192,
filter_size=1,
act="relu")
self.branch7x7_1 = ConvBNLayer(
num_channels=num_channels,
num_filters=channels_7x7,
filter_size=1,
stride=1,
act="relu")
self.branch7x7_2 = ConvBNLayer(
num_channels=channels_7x7,
num_filters=channels_7x7,
filter_size=(1, 7),
stride=1,
padding=(0, 3),
act="relu")
self.branch7x7_3 = ConvBNLayer(
num_channels=channels_7x7,
num_filters=192,
filter_size=(7, 1),
stride=1,
padding=(3, 0),
act="relu")
self.branch7x7dbl_1 = ConvBNLayer(
num_channels=num_channels,
num_filters=channels_7x7,
filter_size=1,
act="relu")
self.branch7x7dbl_2 = ConvBNLayer(
num_channels=channels_7x7,
num_filters=channels_7x7,
filter_size=(7, 1),
padding=(3, 0),
act="relu")
self.branch7x7dbl_3 = ConvBNLayer(
num_channels=channels_7x7,
num_filters=channels_7x7,
filter_size=(1, 7),
padding=(0, 3),
act="relu")
self.branch7x7dbl_4 = ConvBNLayer(
num_channels=channels_7x7,
num_filters=channels_7x7,
filter_size=(7, 1),
padding=(3, 0),
act="relu")
self.branch7x7dbl_5 = ConvBNLayer(
num_channels=channels_7x7,
num_filters=192,
filter_size=(1, 7),
padding=(0, 3),
act="relu")
self.branch_pool = AvgPool2D(
kernel_size=3, stride=1, padding=1, exclusive=False)
self.branch_pool_conv = ConvBNLayer(
num_channels=num_channels,
num_filters=192,
filter_size=1,
act="relu")
def forward(self, x):
branch1x1 = self.branch1x1(x)
@ -278,41 +300,49 @@ class InceptionC(TheseusLayer):
branch_pool = self.branch_pool(x)
branch_pool = self.branch_pool_conv(branch_pool)
x = paddle.concat(
[branch1x1, branch7x7, branch7x7dbl, branch_pool], axis=1)
return x
class InceptionD(TheseusLayer):
def __init__(self, num_channels):
super().__init__()
self.branch3x3_1 = ConvBNLayer(
num_channels=num_channels,
num_filters=192,
filter_size=1,
act="relu")
self.branch3x3_2 = ConvBNLayer(
num_channels=192,
num_filters=320,
filter_size=3,
stride=2,
act="relu")
self.branch7x7x3_1 = ConvBNLayer(
num_channels=num_channels,
num_filters=192,
filter_size=1,
act="relu")
self.branch7x7x3_2 = ConvBNLayer(
num_channels=192,
num_filters=192,
filter_size=(1, 7),
padding=(0, 3),
act="relu")
self.branch7x7x3_3 = ConvBNLayer(
num_channels=192,
num_filters=192,
filter_size=(7, 1),
padding=(3, 0),
act="relu")
self.branch7x7x3_4 = ConvBNLayer(
num_channels=192,
num_filters=192,
filter_size=3,
stride=2,
act="relu")
self.branch_pool = MaxPool2D(kernel_size=3, stride=2)
def forward(self, x):
@ -325,56 +355,68 @@ class InceptionD(TheseusLayer):
branch7x7x3 = self.branch7x7x3_4(branch7x7x3)
branch_pool = self.branch_pool(x)
x = paddle.concat([branch3x3, branch7x7x3, branch_pool], axis=1)
return x
class InceptionE(TheseusLayer):
def __init__(self, num_channels):
super().__init__()
self.branch1x1 = ConvBNLayer(
num_channels=num_channels,
num_filters=320,
filter_size=1,
act="relu")
self.branch3x3_1 = ConvBNLayer(
num_channels=num_channels,
num_filters=384,
filter_size=1,
act="relu")
self.branch3x3_2a = ConvBNLayer(
num_channels=384,
num_filters=384,
filter_size=(1, 3),
padding=(0, 1),
act="relu")
self.branch3x3_2b = ConvBNLayer(
num_channels=384,
num_filters=384,
filter_size=(3, 1),
padding=(1, 0),
act="relu")
self.branch3x3dbl_1 = ConvBNLayer(
num_channels=num_channels,
num_filters=448,
filter_size=1,
act="relu")
self.branch3x3dbl_2 = ConvBNLayer(
num_channels=448,
num_filters=384,
filter_size=3,
padding=1,
act="relu")
self.branch3x3dbl_3a = ConvBNLayer(
num_channels=384,
num_filters=384,
filter_size=(1, 3),
padding=(0, 1),
act="relu")
self.branch3x3dbl_3b = ConvBNLayer(
num_channels=384,
num_filters=384,
filter_size=(3, 1),
padding=(1, 0),
act="relu")
self.branch_pool = AvgPool2D(
kernel_size=3, stride=1, padding=1, exclusive=False)
self.branch_pool_conv = ConvBNLayer(
num_channels=num_channels,
num_filters=192,
filter_size=1,
act="relu")
def forward(self, x):
branch1x1 = self.branch1x1(x)
@ -396,8 +438,9 @@ class InceptionE(TheseusLayer):
branch_pool = self.branch_pool(x)
branch_pool = self.branch_pool_conv(branch_pool)
x = paddle.concat(
[branch1x1, branch3x3, branch3x3dbl, branch_pool], axis=1)
return x
class Inception_V3(TheseusLayer):
@ -410,25 +453,21 @@ class Inception_V3(TheseusLayer):
Returns:
model: nn.Layer. Specific Inception_V3 model depends on args.
"""
def __init__(self, config, class_num=1000):
super().__init__()
self.inception_a_list = config["inception_a"]
self.inception_c_list = config["inception_c"]
self.inception_b_list = config["inception_b"]
self.inception_d_list = config["inception_d"]
self.inception_e_list = config["inception_e"]
self.inception_stem = InceptionStem()
self.inception_block_list = nn.LayerList()
for i in range(len(self.inception_a_list[0])):
inception_a = InceptionA(self.inception_a_list[0][i],
self.inception_a_list[1][i])
self.inception_block_list.append(inception_a)
@ -437,7 +476,7 @@ class Inception_V3(TheseusLayer):
self.inception_block_list.append(inception_b)
for i in range(len(self.inception_c_list[0])):
inception_c = InceptionC(self.inception_c_list[0][i],
self.inception_c_list[1][i])
self.inception_block_list.append(inception_c)
@ -448,21 +487,20 @@ class Inception_V3(TheseusLayer):
for i in range(len(self.inception_e_list)):
inception_e = InceptionE(self.inception_e_list[i])
self.inception_block_list.append(inception_e)
self.avg_pool = AdaptiveAvgPool2D(1)
self.dropout = Dropout(p=0.2, mode="downscale_in_infer")
stdv = 1.0 / math.sqrt(2048 * 1.0)
self.fc = Linear(
2048,
class_num,
weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)),
bias_attr=ParamAttr())
def forward(self, x):
x = self.inception_stem(x)
for inception_block in self.inception_block_list:
x = inception_block(x)
x = self.avg_pool(x)
x = paddle.reshape(x, shape=[-1, 2048])
x = self.dropout(x)
@ -470,25 +508,29 @@ class Inception_V3(TheseusLayer):
return x
def _load_pretrained(pretrained, model, model_url, use_ssld):
if pretrained is False:
pass
elif pretrained is True:
load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
elif isinstance(pretrained, str):
load_dygraph_pretrain(model, pretrained)
else:
raise RuntimeError(
"pretrained type is not available. Please use `string` or `boolean` type."
)
def InceptionV3(pretrained=False, use_ssld=False, **kwargs):
"""
InceptionV3
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `InceptionV3` model depends on args.
"""
model = Inception_V3(NET_CONFIG, **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["InceptionV3"], use_ssld)
return model
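The same calling convention applies here; a short sketch (module path assumed from the __init__ imports; InceptionV3 is conventionally fed 299x299 crops, though the adaptive pooling also accepts other sizes):

import paddle
from ppcls.arch.backbone.legendary_models.inception_v3 import InceptionV3

model = InceptionV3(pretrained=False, class_num=1000)
x = paddle.rand([1, 3, 299, 299])
y = model(x)   # logits, shape [1, 1000]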

View File

@ -14,8 +14,6 @@
from __future__ import absolute_import, division, print_function
from paddle import ParamAttr
import paddle.nn as nn
from paddle.nn import Conv2D, BatchNorm, Linear, ReLU, Flatten
@ -23,19 +21,22 @@ from paddle.nn import AdaptiveAvgPool2D
from paddle.nn.initializer import KaimingNormal
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
MODEL_URLS = {
"MobileNetV1_x0_25": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV1_x0_25_pretrained.pdparams",
"MobileNetV1_x0_5": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV1_x0_5_pretrained.pdparams",
"MobileNetV1_x0_75": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV1_x0_75_pretrained.pdparams",
"MobileNetV1": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV1_pretrained.pdparams",
"MobileNetV1_x0_25":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV1_x0_25_pretrained.pdparams",
"MobileNetV1_x0_5":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV1_x0_5_pretrained.pdparams",
"MobileNetV1_x0_75":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV1_x0_75_pretrained.pdparams",
"MobileNetV1":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV1_pretrained.pdparams"
}
__all__ = list(MODEL_URLS.keys())
class ConvBNLayer(TheseusLayer):
def __init__(self,
num_channels,
@ -44,7 +45,7 @@ class ConvBNLayer(TheseusLayer):
stride,
padding,
num_groups=1):
super().__init__()
self.conv = Conv2D(
in_channels=num_channels,
@ -55,9 +56,7 @@ class ConvBNLayer(TheseusLayer):
groups=num_groups,
weight_attr=ParamAttr(initializer=KaimingNormal()),
bias_attr=False)
self.bn = BatchNorm(num_filters)
self.relu = ReLU()
def forward(self, x):
@ -68,14 +67,9 @@ class ConvBNLayer(TheseusLayer):
class DepthwiseSeparable(TheseusLayer):
def __init__(self, num_channels, num_filters1, num_filters2, num_groups,
stride, scale):
super().__init__()
self.depthwise_conv = ConvBNLayer(
num_channels=num_channels,
@ -99,10 +93,18 @@ class DepthwiseSeparable(TheseusLayer):
class MobileNet(TheseusLayer):
"""
MobileNet
Args:
scale: float=1.0. The coefficient that controls the size of network parameters.
class_num: int=1000. The number of classes.
Returns:
model: nn.Layer. Specific MobileNet model depends on args.
"""
def __init__(self, scale=1.0, class_num=1000):
super().__init__()
self.scale = scale
self.conv = ConvBNLayer(
num_channels=3,
@ -110,30 +112,31 @@ class MobileNet(TheseusLayer):
num_filters=int(32 * scale),
stride=2,
padding=1)
#num_channels, num_filters1, num_filters2, num_groups, stride
self.cfg = [[int(32 * scale), 32, 64, 32, 1],
[int(64 * scale), 64, 128, 64, 2],
[int(128 * scale), 128, 128, 128, 1],
[int(128 * scale), 128, 256, 128, 2],
[int(256 * scale), 256, 256, 256, 1],
[int(256 * scale), 256, 512, 256, 2],
[int(512 * scale), 512, 512, 512, 1],
[int(512 * scale), 512, 512, 512, 1],
[int(512 * scale), 512, 512, 512, 1],
[int(512 * scale), 512, 512, 512, 1],
[int(512 * scale), 512, 512, 512, 1],
[int(512 * scale), 512, 1024, 512, 2],
[int(1024 * scale), 1024, 1024, 1024, 1]]
self.blocks = nn.Sequential(*[
DepthwiseSeparable(
num_channels=params[0],
num_filters1=params[1],
num_filters2=params[2],
num_groups=params[3],
stride=params[4],
scale=scale) for params in self.cfg
])
self.avg_pool = AdaptiveAvgPool2D(1)
self.flatten = Flatten(start_axis=1, stop_axis=-1)
@ -142,7 +145,7 @@ class MobileNet(TheseusLayer):
int(1024 * scale),
class_num,
weight_attr=ParamAttr(initializer=KaimingNormal()))
def forward(self, x):
x = self.conv(x)
x = self.blocks(x)
@ -152,91 +155,77 @@ class MobileNet(TheseusLayer):
return x
def _load_pretrained(pretrained, model, model_url, use_ssld):
if pretrained is False:
pass
elif pretrained is True:
load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
elif isinstance(pretrained, str):
load_dygraph_pretrain(model, pretrained)
else:
raise RuntimeError(
"pretrained type is not available. Please use `string` or `boolean` type")
"pretrained type is not available. Please use `string` or `boolean` type."
)
def MobileNetV1_x0_25(pretrained=False, use_ssld=False, **kwargs):
"""
MobileNetV1_x0_25
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `MobileNetV1_x0_25` model depends on args.
"""
model = MobileNet(scale=0.25, **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["MobileNetV1_x0_25"],
use_ssld)
return model
def MobileNetV1_x0_5(pretrained=False, use_ssld=False, **kwargs):
"""
MobileNetV1_x0_5
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `MobileNetV1_x0_5` model depends on args.
"""
model = MobileNet(scale=0.5, **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["MobileNetV1_x0_5"],
use_ssld)
return model
def MobileNetV1_x0_75(pretrained=False, use_ssld=False, **kwargs):
"""
MobileNetV1_x0_75
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `MobileNetV1_x0_75` model depends on args.
"""
model = MobileNet(scale=0.75, **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["MobileNetV1_x0_75"],
use_ssld)
return model
def MobileNetV1(pretrained=False, use_ssld=False, **kwargs):
"""
MobileNetV1
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `MobileNetV1` model depends on args.
"""
model = MobileNet(scale=1.0, **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["MobileNetV1"], use_ssld)
return model
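A sketch of the three values _load_pretrained accepts for `pretrained` (the local checkpoint path is hypothetical):

model = MobileNetV1(pretrained=True)                                 # fetch weights from MODEL_URLS["MobileNetV1"]
model = MobileNetV1(pretrained="./MobileNetV1_pretrained.pdparams")  # load a local checkpoint
model = MobileNetV1(pretrained=False, class_num=100)                 # random init with a resized head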

View File

@ -0,0 +1,557 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function
import paddle
import paddle.nn as nn
from paddle import ParamAttr
from paddle.nn import AdaptiveAvgPool2D, BatchNorm, Conv2D, Dropout, Linear
from paddle.regularizer import L2Decay
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
MODEL_URLS = {
"MobileNetV3_small_x0_35":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_small_x0_35_pretrained.pdparams",
"MobileNetV3_small_x0_5":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_small_x0_5_pretrained.pdparams",
"MobileNetV3_small_x0_75":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_small_x0_75_pretrained.pdparams",
"MobileNetV3_small_x1_0":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_small_x1_0_pretrained.pdparams",
"MobileNetV3_small_x1_25":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_small_x1_25_pretrained.pdparams",
"MobileNetV3_large_x0_35":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_large_x0_35_pretrained.pdparams",
"MobileNetV3_large_x0_5":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_large_x0_5_pretrained.pdparams",
"MobileNetV3_large_x0_75":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_large_x0_75_pretrained.pdparams",
"MobileNetV3_large_x1_0":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_large_x1_0_pretrained.pdparams",
"MobileNetV3_large_x1_25":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_large_x1_25_pretrained.pdparams",
}
__all__ = list(MODEL_URLS.keys())
# "large", "small" is just for MobinetV3_large, MobileNetV3_small respectively.
# The type of "large" or "small" config is a list. Each element(list) represents a depthwise block, which is composed of k, exp, se, act, s.
# k: kernel_size
# exp: middle channel number in depthwise block
# c: output channel number in depthwise block
# se: whether to use SE block
# act: which activation to use
# s: stride in depthwise block
NET_CONFIG = {
"large": [
# k, exp, c, se, act, s
[3, 16, 16, False, "relu", 1],
[3, 64, 24, False, "relu", 2],
[3, 72, 24, False, "relu", 1],
[5, 72, 40, True, "relu", 2],
[5, 120, 40, True, "relu", 1],
[5, 120, 40, True, "relu", 1],
[3, 240, 80, False, "hardswish", 2],
[3, 200, 80, False, "hardswish", 1],
[3, 184, 80, False, "hardswish", 1],
[3, 184, 80, False, "hardswish", 1],
[3, 480, 112, True, "hardswish", 1],
[3, 672, 112, True, "hardswish", 1],
[5, 672, 160, True, "hardswish", 2],
[5, 960, 160, True, "hardswish", 1],
[5, 960, 160, True, "hardswish", 1],
],
"small": [
# k, exp, c, se, act, s
[3, 16, 16, True, "relu", 2],
[3, 72, 24, False, "relu", 2],
[3, 88, 24, False, "relu", 1],
[5, 96, 40, True, "hardswish", 2],
[5, 240, 40, True, "hardswish", 1],
[5, 240, 40, True, "hardswish", 1],
[5, 120, 48, True, "hardswish", 1],
[5, 144, 48, True, "hardswish", 1],
[5, 288, 96, True, "hardswish", 2],
[5, 576, 96, True, "hardswish", 1],
[5, 576, 96, True, "hardswish", 1],
]
}
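Reading one row as a sketch: the fourth "large" entry [5, 72, 40, True, "relu", 2] is a 5x5 depthwise block that expands to 72 channels, projects to 40 output channels, applies an SE block, uses ReLU, and has stride 2. The loop below only pretty-prints a few rows:

for k, exp, c, se, act, s in NET_CONFIG["large"][:4]:
    print(f"dw {k}x{k}  exp={exp}  out={c}  SE={se}  act={act}  stride={s}")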
# first conv output channel number in MobileNetV3
STEM_CONV_NUMBER = 16
# last second conv output channel for "small"
LAST_SECOND_CONV_SMALL = 576
# last second conv output channel for "large"
LAST_SECOND_CONV_LARGE = 960
# last conv output channel number for "large" and "small"
LAST_CONV = 1280
def _make_divisible(v, divisor=8, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
if new_v < 0.9 * v:
new_v += divisor
return new_v
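A few illustrative values for the channel rounding above (divisor 8): the result is the nearest multiple of 8, floored at min_value, and bumped one step up whenever plain rounding would drop more than 10% of the channels:

assert _make_divisible(16 * 0.35) == 8    # 5.6 rounds below the floor of 8
assert _make_divisible(24) == 24          # already a multiple of 8
assert _make_divisible(30) == 32          # nearest multiple of 8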
def _create_act(act):
if act == "hardswish":
return nn.Hardswish()
elif act == "relu":
return nn.ReLU()
elif act is None:
return None
else:
raise RuntimeError(
"The activation function is not supported: {}".format(act))
class MobileNetV3(TheseusLayer):
"""
MobileNetV3
Args:
config: list. MobileNetV3 depthwise blocks config.
scale: float=1.0. The coefficient that controls the size of network parameters.
class_num: int=1000. The number of classes.
inplanes: int=16. The output channel number of first convolution layer.
class_squeeze: int=960. The output channel number of penultimate convolution layer.
class_expand: int=1280. The output channel number of last convolution layer.
dropout_prob: float=0.2. Probability of setting units to zero.
Returns:
model: nn.Layer. Specific MobileNetV3 model depends on args.
"""
def __init__(self,
config,
scale=1.0,
class_num=1000,
inplanes=STEM_CONV_NUMBER,
class_squeeze=LAST_SECOND_CONV_LARGE,
class_expand=LAST_CONV,
dropout_prob=0.2):
super().__init__()
self.cfg = config
self.scale = scale
self.inplanes = inplanes
self.class_squeeze = class_squeeze
self.class_expand = class_expand
self.class_num = class_num
self.conv = ConvBNLayer(
in_c=3,
out_c=_make_divisible(self.inplanes * self.scale),
filter_size=3,
stride=2,
padding=1,
num_groups=1,
if_act=True,
act="hardswish")
self.blocks = nn.Sequential(*[
ResidualUnit(
in_c=_make_divisible(self.inplanes * self.scale if i == 0 else
self.cfg[i - 1][2] * self.scale),
mid_c=_make_divisible(self.scale * exp),
out_c=_make_divisible(self.scale * c),
filter_size=k,
stride=s,
use_se=se,
act=act) for i, (k, exp, c, se, act, s) in enumerate(self.cfg)
])
self.last_second_conv = ConvBNLayer(
in_c=_make_divisible(self.cfg[-1][2] * self.scale),
out_c=_make_divisible(self.scale * self.class_squeeze),
filter_size=1,
stride=1,
padding=0,
num_groups=1,
if_act=True,
act="hardswish")
self.avg_pool = AdaptiveAvgPool2D(1)
self.last_conv = Conv2D(
in_channels=_make_divisible(self.scale * self.class_squeeze),
out_channels=self.class_expand,
kernel_size=1,
stride=1,
padding=0,
bias_attr=False)
self.hardswish = nn.Hardswish()
self.dropout = Dropout(p=dropout_prob, mode="downscale_in_infer")
self.flatten = nn.Flatten(start_axis=1, stop_axis=-1)
self.fc = Linear(self.class_expand, class_num)
def forward(self, x):
x = self.conv(x)
x = self.blocks(x)
x = self.last_second_conv(x)
x = self.avg_pool(x)
x = self.last_conv(x)
x = self.hardswish(x)
x = self.dropout(x)
x = self.flatten(x)
x = self.fc(x)
return x
class ConvBNLayer(TheseusLayer):
def __init__(self,
in_c,
out_c,
filter_size,
stride,
padding,
num_groups=1,
if_act=True,
act=None):
super().__init__()
self.conv = Conv2D(
in_channels=in_c,
out_channels=out_c,
kernel_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
bias_attr=False)
self.bn = BatchNorm(
num_channels=out_c,
act=None,
param_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
self.if_act = if_act
self.act = _create_act(act)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.if_act:
x = self.act(x)
return x
class ResidualUnit(TheseusLayer):
def __init__(self,
in_c,
mid_c,
out_c,
filter_size,
stride,
use_se,
act=None):
super().__init__()
self.if_shortcut = stride == 1 and in_c == out_c
self.if_se = use_se
self.expand_conv = ConvBNLayer(
in_c=in_c,
out_c=mid_c,
filter_size=1,
stride=1,
padding=0,
if_act=True,
act=act)
self.bottleneck_conv = ConvBNLayer(
in_c=mid_c,
out_c=mid_c,
filter_size=filter_size,
stride=stride,
padding=int((filter_size - 1) // 2),
num_groups=mid_c,
if_act=True,
act=act)
if self.if_se:
self.mid_se = SEModule(mid_c)
self.linear_conv = ConvBNLayer(
in_c=mid_c,
out_c=out_c,
filter_size=1,
stride=1,
padding=0,
if_act=False,
act=None)
def forward(self, x):
identity = x
x = self.expand_conv(x)
x = self.bottleneck_conv(x)
if self.if_se:
x = self.mid_se(x)
x = self.linear_conv(x)
if self.if_shortcut:
x = paddle.add(identity, x)
return x
# nn.Hardsigmoid does not expose the "slope" and "offset" arguments of nn.functional.hardsigmoid, hence this custom layer.
class Hardsigmoid(TheseusLayer):
def __init__(self, slope=0.2, offset=0.5):
super().__init__()
self.slope = slope
self.offset = offset
def forward(self, x):
return nn.functional.hardsigmoid(
x, slope=self.slope, offset=self.offset)
class SEModule(TheseusLayer):
def __init__(self, channel, reduction=4):
super().__init__()
self.avg_pool = AdaptiveAvgPool2D(1)
self.conv1 = Conv2D(
in_channels=channel,
out_channels=channel // reduction,
kernel_size=1,
stride=1,
padding=0)
self.relu = nn.ReLU()
self.conv2 = Conv2D(
in_channels=channel // reduction,
out_channels=channel,
kernel_size=1,
stride=1,
padding=0)
self.hardsigmoid = Hardsigmoid(slope=0.2, offset=0.5)
def forward(self, x):
identity = x
x = self.avg_pool(x)
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.hardsigmoid(x)
return paddle.multiply(x=identity, y=x)
def _load_pretrained(pretrained, model, model_url, use_ssld):
if pretrained is False:
pass
elif pretrained is True:
load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
elif isinstance(pretrained, str):
load_dygraph_pretrain(model, pretrained)
else:
raise RuntimeError(
"pretrained type is not available. Please use `string` or `boolean` type."
)
def MobileNetV3_small_x0_35(pretrained=False, use_ssld=False, **kwargs):
"""
MobileNetV3_small_x0_35
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `MobileNetV3_small_x0_35` model depends on args.
"""
model = MobileNetV3(
config=NET_CONFIG["small"],
scale=0.35,
class_squeeze=LAST_SECOND_CONV_SMALL,
**kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_small_x0_35"],
use_ssld)
return model
def MobileNetV3_small_x0_5(pretrained=False, use_ssld=False, **kwargs):
"""
MobileNetV3_small_x0_5
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `MobileNetV3_small_x0_5` model depends on args.
"""
model = MobileNetV3(
config=NET_CONFIG["small"],
scale=0.5,
class_squeeze=LAST_SECOND_CONV_SMALL,
**kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_small_x0_5"],
use_ssld)
return model
def MobileNetV3_small_x0_75(pretrained=False, use_ssld=False, **kwargs):
"""
MobileNetV3_small_x0_75
Args:
        pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
                    If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `MobileNetV3_small_x0_75` model depends on args.
"""
model = MobileNetV3(
config=NET_CONFIG["small"],
scale=0.75,
class_squeeze=LAST_SECOND_CONV_SMALL,
**kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_small_x0_75"],
use_ssld)
return model
def MobileNetV3_small_x1_0(pretrained=False, use_ssld=False, **kwargs):
"""
MobileNetV3_small_x1_0
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `MobileNetV3_small_x1_0` model depends on args.
"""
model = MobileNetV3(
config=NET_CONFIG["small"],
scale=1.0,
class_squeeze=LAST_SECOND_CONV_SMALL,
**kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_small_x1_0"],
use_ssld)
return model
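# Usage sketch (assumes MobileNetV3 accepts `class_num` through **kwargs, as
# the other backbones in this change do; `_demo_*` is illustrative only):
def _demo_mobilenet_v3_small():
    import paddle
    model = MobileNetV3_small_x1_0(pretrained=False, class_num=1000)
    logits = model(paddle.randn([1, 3, 224, 224]))
    assert logits.shape == [1, 1000]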
def MobileNetV3_small_x1_25(pretrained=False, use_ssld=False, **kwargs):
"""
MobileNetV3_small_x1_25
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `MobileNetV3_small_x1_25` model depends on args.
"""
model = MobileNetV3(
config=NET_CONFIG["small"],
scale=1.25,
class_squeeze=LAST_SECOND_CONV_SMALL,
**kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_small_x1_25"],
use_ssld)
return model
def MobileNetV3_large_x0_35(pretrained=False, use_ssld=False, **kwargs):
"""
MobileNetV3_large_x0_35
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `MobileNetV3_large_x0_35` model depends on args.
"""
model = MobileNetV3(
config=NET_CONFIG["large"],
scale=0.35,
class_squeeze=LAST_SECOND_CONV_LARGE,
**kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_large_x0_35"],
use_ssld)
return model
def MobileNetV3_large_x0_5(pretrained=False, use_ssld=False, **kwargs):
"""
MobileNetV3_large_x0_5
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `MobileNetV3_large_x0_5` model depends on args.
"""
model = MobileNetV3(
config=NET_CONFIG["large"],
scale=0.5,
class_squeeze=LAST_SECOND_CONV_LARGE,
**kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_large_x0_5"],
use_ssld)
return model
def MobileNetV3_large_x0_75(pretrained=False, use_ssld=False, **kwargs):
"""
MobileNetV3_large_x0_75
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `MobileNetV3_large_x0_75` model depends on args.
"""
model = MobileNetV3(
config=NET_CONFIG["large"],
scale=0.75,
class_squeeze=LAST_SECOND_CONV_LARGE,
**kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_large_x0_75"],
use_ssld)
return model
def MobileNetV3_large_x1_0(pretrained=False, use_ssld=False, **kwargs):
"""
MobileNetV3_large_x1_0
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `MobileNetV3_large_x1_0` model depends on args.
"""
model = MobileNetV3(
config=NET_CONFIG["large"],
scale=1.0,
class_squeeze=LAST_SECOND_CONV_LARGE,
**kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_large_x1_0"],
use_ssld)
return model
def MobileNetV3_large_x1_25(pretrained=False, use_ssld=False, **kwargs):
"""
MobileNetV3_large_x1_25
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `MobileNetV3_large_x1_25` model depends on args.
"""
model = MobileNetV3(
config=NET_CONFIG["large"],
scale=1.25,
class_squeeze=LAST_SECOND_CONV_LARGE,
**kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_large_x1_25"],
use_ssld)
return model


@ -24,26 +24,34 @@ from paddle.nn.initializer import Uniform
import math
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
from ppcls.utils.save_load import load_dygraph_pretrain_from, load_dygraph_pretrain_from_url
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
MODEL_URLS = {
"ResNet18": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet18_pretrained.pdparams",
"ResNet18_vd": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet18_vd_pretrained.pdparams",
"ResNet34": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet34_pretrained.pdparams",
"ResNet34_vd": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet34_vd_pretrained.pdparams",
"ResNet50": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_pretrained.pdparams",
"ResNet50_vd": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_pretrained.pdparams",
"ResNet101": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet101_pretrained.pdparams",
"ResNet101_vd": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet101_vd_pretrained.pdparams",
"ResNet152": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet152_pretrained.pdparams",
"ResNet152_vd": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet152_vd_pretrained.pdparams",
"ResNet200_vd": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet200_vd_pretrained.pdparams",
"ResNet18":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet18_pretrained.pdparams",
"ResNet18_vd":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet18_vd_pretrained.pdparams",
"ResNet34":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet34_pretrained.pdparams",
"ResNet34_vd":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet34_vd_pretrained.pdparams",
"ResNet50":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet50_pretrained.pdparams",
"ResNet50_vd":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet50_vd_pretrained.pdparams",
"ResNet101":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet101_pretrained.pdparams",
"ResNet101_vd":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet101_vd_pretrained.pdparams",
"ResNet152":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet152_pretrained.pdparams",
"ResNet152_vd":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet152_vd_pretrained.pdparams",
"ResNet200_vd":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet200_vd_pretrained.pdparams",
}
__all__ = list(MODEL_URLS.keys())
'''
ResNet config: dict.
key: depth of ResNet.
@ -55,17 +63,35 @@ ResNet config: dict.
'''
NET_CONFIG = {
"18": {
"block_type": "BasicBlock", "block_depth": [2, 2, 2, 2], "num_channels": [64, 64, 128, 256]},
"block_type": "BasicBlock",
"block_depth": [2, 2, 2, 2],
"num_channels": [64, 64, 128, 256]
},
"34": {
"block_type": "BasicBlock", "block_depth": [3, 4, 6, 3], "num_channels": [64, 64, 128, 256]},
"block_type": "BasicBlock",
"block_depth": [3, 4, 6, 3],
"num_channels": [64, 64, 128, 256]
},
"50": {
"block_type": "BottleneckBlock", "block_depth": [3, 4, 6, 3], "num_channels": [64, 256, 512, 1024]},
"block_type": "BottleneckBlock",
"block_depth": [3, 4, 6, 3],
"num_channels": [64, 256, 512, 1024]
},
"101": {
"block_type": "BottleneckBlock", "block_depth": [3, 4, 23, 3], "num_channels": [64, 256, 512, 1024]},
"block_type": "BottleneckBlock",
"block_depth": [3, 4, 23, 3],
"num_channels": [64, 256, 512, 1024]
},
"152": {
"block_type": "BottleneckBlock", "block_depth": [3, 8, 36, 3], "num_channels": [64, 256, 512, 1024]},
"block_type": "BottleneckBlock",
"block_depth": [3, 8, 36, 3],
"num_channels": [64, 256, 512, 1024]
},
"200": {
"block_type": "BottleneckBlock", "block_depth": [3, 12, 48, 3], "num_channels": [64, 256, 512, 1024]},
"block_type": "BottleneckBlock",
"block_depth": [3, 12, 48, 3],
"num_channels": [64, 256, 512, 1024]
},
}
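# Reading the config (sketch): for ResNet50,
#   cfg = NET_CONFIG["50"]           # BottleneckBlock, block_depth [3, 4, 6, 3]
#   sum(cfg["block_depth"]) * 3 + 2  # 16 bottlenecks * 3 convs + stem + fc = 50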
@ -110,14 +136,14 @@ class ConvBNLayer(TheseusLayer):
class BottleneckBlock(TheseusLayer):
def __init__(self,
num_channels,
num_filters,
stride,
shortcut=True,
if_first=False,
lr_mult=1.0,
):
def __init__(
self,
num_channels,
num_filters,
stride,
shortcut=True,
if_first=False,
lr_mult=1.0, ):
super().__init__()
self.conv0 = ConvBNLayer(
@ -222,16 +248,15 @@ class ResNet(TheseusLayer):
version: str="vb". Different version of ResNet, version vd can perform better.
class_num: int=1000. The number of classes.
lr_mult_list: list. Control the learning rate of different stages.
pretrained: (True or False) or path of pretrained_model. Whether to load the pretrained model.
Returns:
model: nn.Layer. Specific ResNet model depends on args.
"""
def __init__(self,
config,
version="vb",
class_num=1000,
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
pretrained=False):
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0]):
super().__init__()
self.cfg = config
@ -243,51 +268,46 @@ class ResNet(TheseusLayer):
self.block_type = self.cfg["block_type"]
self.num_channels = self.cfg["num_channels"]
self.channels_mult = 1 if self.num_channels[-1] == 256 else 4
self.pretrained = pretrained
assert isinstance(self.lr_mult_list, (
list, tuple
)), "lr_mult_list should be in (list, tuple) but got {}".format(
type(self.lr_mult_list))
assert len(
self.lr_mult_list
) == 5, "lr_mult_list length should be 5 but got {}".format(
len(self.lr_mult_list))
assert len(self.lr_mult_list
) == 5, "lr_mult_list length should be 5 but got {}".format(
len(self.lr_mult_list))
self.stem_cfg = {
#num_channels, num_filters, filter_size, stride
"vb": [[3, 64, 7, 2]],
"vd": [[3, 32, 3, 2],
[32, 32, 3, 1],
[32, 64, 3, 1]]}
"vd": [[3, 32, 3, 2], [32, 32, 3, 1], [32, 64, 3, 1]]
}
self.stem = nn.Sequential(*[
ConvBNLayer(
num_channels=in_c,
num_filters=out_c,
filter_size=k,
stride=s,
act="relu",
lr_mult=self.lr_mult_list[0])
num_channels=in_c,
num_filters=out_c,
filter_size=k,
stride=s,
act="relu",
lr_mult=self.lr_mult_list[0])
for in_c, out_c, k, s in self.stem_cfg[version]
])
self.max_pool = MaxPool2D(kernel_size=3, stride=2, padding=1)
block_list = []
for block_idx in range(len(self.block_depth)):
shortcut = False
for i in range(self.block_depth[block_idx]):
block_list.append(
globals()[self.block_type](
num_channels=self.num_channels[block_idx]
if i == 0 else self.num_filters[block_idx] * self.channels_mult,
block_list.append(globals()[self.block_type](
num_channels=self.num_channels[block_idx] if i == 0 else
self.num_filters[block_idx] * self.channels_mult,
num_filters=self.num_filters[block_idx],
stride=2 if i == 0 and block_idx != 0 else 1,
shortcut=shortcut,
if_first=block_idx == i == 0 if version == "vd" else True,
lr_mult=self.lr_mult_list[block_idx + 1]))
shortcut = True
shortcut = True
self.blocks = nn.Sequential(*block_list)
self.avg_pool = AdaptiveAvgPool2D(1)
@ -297,8 +317,7 @@ class ResNet(TheseusLayer):
self.fc = Linear(
self.avg_pool_channels,
self.class_num,
weight_attr=ParamAttr(
initializer=Uniform(-stdv, stdv)))
weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)))
def forward(self, x):
x = self.stem(x)
@ -310,254 +329,179 @@ class ResNet(TheseusLayer):
return x
def ResNet18(**args):
def _load_pretrained(pretrained, model, model_url, use_ssld):
if pretrained is False:
pass
elif pretrained is True:
load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
elif isinstance(pretrained, str):
load_dygraph_pretrain(model, pretrained)
else:
raise RuntimeError(
"pretrained type is not available. Please use `string` or `boolean` type."
)
def ResNet18(pretrained=False, use_ssld=False, **kwargs):
"""
ResNet18
Args:
kwargs:
class_num: int=1000. Output dim of last fc layer.
lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `ResNet18` model depends on args.
"""
model = ResNet(config=NET_CONFIG["18"], version="vb", **args)
if isinstance(model.pretrained, bool):
if model.pretrained is True:
load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet18"])
elif isinstance(model.pretrained, str):
load_dygraph_pretrain(model, model.pretrained)
else:
raise RuntimeError(
"pretrained type is not available. Please use `string` or `boolean` type")
model = ResNet(config=NET_CONFIG["18"], version="vb", **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet18"], use_ssld)
return model
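# Usage sketch for the refactored factories (hypothetical `_demo_*` helper;
# `class_num` is forwarded through **kwargs to ResNet):
def _demo_resnet18():
    import paddle
    model = ResNet18(pretrained=False, class_num=1000)
    out = model(paddle.randn([1, 3, 224, 224]))
    assert out.shape == [1, 1000]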
def ResNet18_vd(**args):
def ResNet18_vd(pretrained=False, use_ssld=False, **kwargs):
"""
ResNet18_vd
Args:
kwargs:
class_num: int=1000. Output dim of last fc layer.
lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `ResNet18_vd` model depends on args.
"""
model = ResNet(config=NET_CONFIG["18"], version="vd", **args)
if isinstance(model.pretrained, bool):
if model.pretrained is True:
load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet18_vd"])
elif isinstance(model.pretrained, str):
load_dygraph_pretrain(model, model.pretrained)
else:
raise RuntimeError(
"pretrained type is not available. Please use `string` or `boolean` type")
model = ResNet(config=NET_CONFIG["18"], version="vd", **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet18_vd"], use_ssld)
return model
def ResNet34(**args):
def ResNet34(pretrained=False, use_ssld=False, **kwargs):
"""
ResNet34
Args:
kwargs:
class_num: int=1000. Output dim of last fc layer.
lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `ResNet34` model depends on args.
"""
model = ResNet(config=NET_CONFIG["34"], version="vb", **args)
if isinstance(model.pretrained, bool):
if model.pretrained is True:
load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet34"])
elif isinstance(model.pretrained, str):
load_dygraph_pretrain(model, model.pretrained)
else:
raise RuntimeError(
"pretrained type is not available. Please use `string` or `boolean` type")
model = ResNet(config=NET_CONFIG["34"], version="vb", **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet34"], use_ssld)
return model
def ResNet34_vd(**args):
def ResNet34_vd(pretrained=False, use_ssld=False, **kwargs):
"""
ResNet34_vd
Args:
kwargs:
class_num: int=1000. Output dim of last fc layer.
lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `ResNet34_vd` model depends on args.
"""
model = ResNet(config=NET_CONFIG["34"], version="vd", **args)
if isinstance(model.pretrained, bool):
if model.pretrained is True:
load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet34_vd"], use_ssld=True)
elif isinstance(model.pretrained, str):
load_dygraph_pretrain(model, model.pretrained)
else:
raise RuntimeError(
"pretrained type is not available. Please use `string` or `boolean` type")
model = ResNet(config=NET_CONFIG["34"], version="vd", **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet34_vd"], use_ssld)
return model
def ResNet50(**args):
def ResNet50(pretrained=False, use_ssld=False, **kwargs):
"""
ResNet50
Args:
kwargs:
class_num: int=1000. Output dim of last fc layer.
lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `ResNet50` model depends on args.
"""
model = ResNet(config=NET_CONFIG["50"], version="vb", **args)
if isinstance(model.pretrained, bool):
if model.pretrained is True:
load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet50"])
elif isinstance(model.pretrained, str):
load_dygraph_pretrain(model, model.pretrained)
else:
raise RuntimeError(
"pretrained type is not available. Please use `string` or `boolean` type")
model = ResNet(config=NET_CONFIG["50"], version="vb", **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet50"], use_ssld)
return model
def ResNet50_vd(**args):
def ResNet50_vd(pretrained=False, use_ssld=False, **kwargs):
"""
ResNet50_vd
Args:
kwargs:
class_num: int=1000. Output dim of last fc layer.
lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `ResNet50_vd` model depends on args.
"""
model = ResNet(config=NET_CONFIG["50"], version="vd", **args)
if isinstance(model.pretrained, bool):
if model.pretrained is True:
load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet50_vd"], use_ssld=True)
elif isinstance(model.pretrained, str):
load_dygraph_pretrain(model, model.pretrained)
else:
raise RuntimeError(
"pretrained type is not available. Please use `string` or `boolean` type")
model = ResNet(config=NET_CONFIG["50"], version="vd", **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet50_vd"], use_ssld)
return model
def ResNet101(**args):
def ResNet101(pretrained=False, use_ssld=False, **kwargs):
"""
ResNet101
Args:
kwargs:
class_num: int=1000. Output dim of last fc layer.
lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
pretrained: bool=False. Whether to load the pretrained model.
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `ResNet101` model depends on args.
"""
model = ResNet(config=NET_CONFIG["101"], version="vb", **args)
if isinstance(model.pretrained, bool):
if model.pretrained is True:
load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet101"])
elif isinstance(model.pretrained, str):
load_dygraph_pretrain(model, model.pretrained)
else:
raise RuntimeError(
"pretrained type is not available. Please use `string` or `boolean` type")
model = ResNet(config=NET_CONFIG["101"], version="vb", **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet101"], use_ssld)
return model
def ResNet101_vd(**args):
def ResNet101_vd(pretrained=False, use_ssld=False, **kwargs):
"""
ResNet101_vd
Args:
kwargs:
class_num: int=1000. Output dim of last fc layer.
lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `ResNet101_vd` model depends on args.
"""
model = ResNet(config=NET_CONFIG["101"], version="vd", **args)
if isinstance(model.pretrained, bool):
if model.pretrained is True:
load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet101_vd"], use_ssld=True)
elif isinstance(model.pretrained, str):
load_dygraph_pretrain(model, model.pretrained)
else:
raise RuntimeError(
"pretrained type is not available. Please use `string` or `boolean` type")
model = ResNet(config=NET_CONFIG["101"], version="vd", **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet101_vd"], use_ssld)
return model
def ResNet152(**args):
def ResNet152(pretrained=False, use_ssld=False, **kwargs):
"""
ResNet152
Args:
kwargs:
class_num: int=1000. Output dim of last fc layer.
lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `ResNet152` model depends on args.
"""
model = ResNet(config=NET_CONFIG["152"], version="vb", **args)
if isinstance(model.pretrained, bool):
if model.pretrained is True:
load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet152"])
elif isinstance(model.pretrained, str):
load_dygraph_pretrain(model, model.pretrained)
else:
raise RuntimeError(
"pretrained type is not available. Please use `string` or `boolean` type")
model = ResNet(config=NET_CONFIG["152"], version="vb", **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet152"], use_ssld)
return model
def ResNet152_vd(**args):
def ResNet152_vd(pretrained=False, use_ssld=False, **kwargs):
"""
ResNet152_vd
Args:
kwargs:
class_num: int=1000. Output dim of last fc layer.
lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `ResNet152_vd` model depends on args.
"""
model = ResNet(config=NET_CONFIG["152"], version="vd", **args)
if isinstance(model.pretrained, bool):
if model.pretrained is True:
load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet152_vd"])
elif isinstance(model.pretrained, str):
load_dygraph_pretrain(model, model.pretrained)
else:
raise RuntimeError(
"pretrained type is not available. Please use `string` or `boolean` type")
model = ResNet(config=NET_CONFIG["152"], version="vd", **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet152_vd"], use_ssld)
return model
def ResNet200_vd(**args):
def ResNet200_vd(pretrained=False, use_ssld=False, **kwargs):
"""
ResNet200_vd
Args:
kwargs:
class_num: int=1000. Output dim of last fc layer.
lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `ResNet200_vd` model depends on args.
"""
model = ResNet(config=NET_CONFIG["200"], version="vd", **args)
if isinstance(model.pretrained, bool):
if model.pretrained is True:
load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet200_vd"], use_ssld=True)
elif isinstance(model.pretrained, str):
load_dygraph_pretrain(model, model.pretrained)
else:
raise RuntimeError(
"pretrained type is not available. Please use `string` or `boolean` type")
model = ResNet(config=NET_CONFIG["200"], version="vd", **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet200_vd"], use_ssld)
return model


@ -14,16 +14,24 @@
from __future__ import absolute_import, division, print_function
import paddle
from paddle import ParamAttr
import paddle.nn as nn
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import MaxPool2D
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
from ppcls.utils.save_load import load_dygraph_pretrain
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
__all__ = ["VGG11", "VGG13", "VGG16", "VGG19"]
MODEL_URLS = {
"VGG11":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/VGG11_pretrained.pdparams",
"VGG13":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/VGG13_pretrained.pdparams",
"VGG16":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/VGG16_pretrained.pdparams",
"VGG19":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/VGG19_pretrained.pdparams",
}
__all__ = list(MODEL_URLS.keys())
# VGG config
# key: VGG network depth
@ -36,68 +44,12 @@ NET_CONFIG = {
}
def VGG11(**args):
"""
VGG11
Args:
kwargs:
class_num: int=1000. Output dim of last fc layer.
stop_grad_layers: int=0. The parameters in blocks which index larger than `stop_grad_layers`, will be set `param.trainable=False`
Returns:
model: nn.Layer. Specific `VGG11` model depends on args.
"""
model = VGGNet(config=NET_CONFIG[11], **args)
return model
def VGG13(**args):
"""
VGG13
Args:
kwargs:
class_num: int=1000. Output dim of last fc layer.
stop_grad_layers: int=0. The parameters in blocks which index larger than `stop_grad_layers`, will be set `param.trainable=False`
Returns:
        model: nn.Layer. Specific `VGG13` model depends on args.
"""
model = VGGNet(config=NET_CONFIG[13], **args)
return model
def VGG16(**args):
"""
VGG16
Args:
kwargs:
class_num: int=1000. Output dim of last fc layer.
stop_grad_layers: int=0. The parameters in blocks which index larger than `stop_grad_layers`, will be set `param.trainable=False`
Returns:
        model: nn.Layer. Specific `VGG16` model depends on args.
"""
model = VGGNet(config=NET_CONFIG[16], **args)
return model
def VGG19(**args):
"""
VGG19
Args:
kwargs:
class_num: int=1000. Output dim of last fc layer.
stop_grad_layers: int=0. The parameters in blocks which index larger than `stop_grad_layers`, will be set `param.trainable=False`
Returns:
        model: nn.Layer. Specific `VGG19` model depends on args.
"""
model = VGGNet(config=NET_CONFIG[19], **args)
return model
class ConvBlock(TheseusLayer):
def __init__(self, input_channels, output_channels, groups):
super(ConvBlock, self).__init__()
super().__init__()
self.groups = groups
self._conv_1 = Conv2D(
self.conv1 = Conv2D(
in_channels=input_channels,
out_channels=output_channels,
kernel_size=3,
@ -105,7 +57,7 @@ class ConvBlock(TheseusLayer):
padding=1,
bias_attr=False)
if groups == 2 or groups == 3 or groups == 4:
self._conv_2 = Conv2D(
self.conv2 = Conv2D(
in_channels=output_channels,
out_channels=output_channels,
kernel_size=3,
@ -113,7 +65,7 @@ class ConvBlock(TheseusLayer):
padding=1,
bias_attr=False)
if groups == 3 or groups == 4:
self._conv_3 = Conv2D(
self.conv3 = Conv2D(
in_channels=output_channels,
out_channels=output_channels,
kernel_size=3,
@ -121,7 +73,7 @@ class ConvBlock(TheseusLayer):
padding=1,
bias_attr=False)
if groups == 4:
self._conv_4 = Conv2D(
self.conv4 = Conv2D(
in_channels=output_channels,
out_channels=output_channels,
kernel_size=3,
@ -129,73 +81,148 @@ class ConvBlock(TheseusLayer):
padding=1,
bias_attr=False)
self._pool = MaxPool2D(kernel_size=2, stride=2, padding=0)
self._relu = nn.ReLU()
self.max_pool = MaxPool2D(kernel_size=2, stride=2, padding=0)
self.relu = nn.ReLU()
def forward(self, inputs):
x = self._conv_1(inputs)
x = self._relu(x)
x = self.conv1(inputs)
x = self.relu(x)
if self.groups == 2 or self.groups == 3 or self.groups == 4:
x = self._conv_2(x)
x = self._relu(x)
x = self.conv2(x)
x = self.relu(x)
if self.groups == 3 or self.groups == 4:
x = self._conv_3(x)
x = self._relu(x)
x = self.conv3(x)
x = self.relu(x)
if self.groups == 4:
x = self._conv_4(x)
x = self._relu(x)
x = self._pool(x)
x = self.conv4(x)
x = self.relu(x)
x = self.max_pool(x)
return x
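# Sketch (hypothetical `_demo_*` helper): `groups` is the number of stacked
# 3x3 convs applied before the 2x2 max-pool, so every ConvBlock halves the
# spatial resolution.
def _demo_conv_block():
    import paddle
    block = ConvBlock(input_channels=3, output_channels=64, groups=2)
    y = block(paddle.randn([1, 3, 224, 224]))
    assert y.shape == [1, 64, 112, 112]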
class VGGNet(TheseusLayer):
def __init__(self,
config,
stop_grad_layers=0,
class_num=1000,
pretrained=False,
**args):
"""
VGGNet
Args:
config: list. VGGNet config.
        stop_grad_layers: int=0. The parameters of the first `stop_grad_layers` conv blocks are frozen (`param.trainable=False`).
class_num: int=1000. The number of classes.
Returns:
model: nn.Layer. Specific VGG model depends on args.
"""
def __init__(self, config, stop_grad_layers=0, class_num=1000):
super().__init__()
self.stop_grad_layers = stop_grad_layers
self._conv_block_1 = ConvBlock(3, 64, config[0])
self._conv_block_2 = ConvBlock(64, 128, config[1])
self._conv_block_3 = ConvBlock(128, 256, config[2])
self._conv_block_4 = ConvBlock(256, 512, config[3])
self._conv_block_5 = ConvBlock(512, 512, config[4])
self.conv_block_1 = ConvBlock(3, 64, config[0])
self.conv_block_2 = ConvBlock(64, 128, config[1])
self.conv_block_3 = ConvBlock(128, 256, config[2])
self.conv_block_4 = ConvBlock(256, 512, config[3])
self.conv_block_5 = ConvBlock(512, 512, config[4])
self._relu = nn.ReLU()
self._flatten = nn.Flatten(start_axis=1, stop_axis=-1)
self.relu = nn.ReLU()
self.flatten = nn.Flatten(start_axis=1, stop_axis=-1)
for idx, block in enumerate([
self._conv_block_1, self._conv_block_2, self._conv_block_3,
self._conv_block_4, self._conv_block_5
self.conv_block_1, self.conv_block_2, self.conv_block_3,
self.conv_block_4, self.conv_block_5
]):
if self.stop_grad_layers >= idx + 1:
for param in block.parameters():
param.trainable = False
self._drop = Dropout(p=0.5, mode="downscale_in_infer")
self._fc1 = Linear(7 * 7 * 512, 4096)
self._fc2 = Linear(4096, 4096)
self._out = Linear(4096, class_num)
if pretrained is not None:
load_dygraph_pretrain(self, pretrained)
self.drop = Dropout(p=0.5, mode="downscale_in_infer")
self.fc1 = Linear(7 * 7 * 512, 4096)
self.fc2 = Linear(4096, 4096)
self.fc3 = Linear(4096, class_num)
def forward(self, inputs):
x = self._conv_block_1(inputs)
x = self._conv_block_2(x)
x = self._conv_block_3(x)
x = self._conv_block_4(x)
x = self._conv_block_5(x)
x = self._flatten(x)
x = self._fc1(x)
x = self._relu(x)
x = self._drop(x)
x = self._fc2(x)
x = self._relu(x)
x = self._drop(x)
x = self._out(x)
x = self.conv_block_1(inputs)
x = self.conv_block_2(x)
x = self.conv_block_3(x)
x = self.conv_block_4(x)
x = self.conv_block_5(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.drop(x)
x = self.fc2(x)
x = self.relu(x)
x = self.drop(x)
x = self.fc3(x)
return x
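# Shape walk-through (224x224 input): the five pooled stages give
# 224 -> 112 -> 56 -> 28 -> 14 -> 7, so flatten yields 512 * 7 * 7 = 25088
# features; that is why fc1 is Linear(7 * 7 * 512, 4096) and the input size is fixed.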
def _load_pretrained(pretrained, model, model_url, use_ssld):
if pretrained is False:
pass
elif pretrained is True:
load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
elif isinstance(pretrained, str):
load_dygraph_pretrain(model, pretrained)
else:
raise RuntimeError(
"pretrained type is not available. Please use `string` or `boolean` type."
)
def VGG11(pretrained=False, use_ssld=False, **kwargs):
"""
VGG11
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `VGG11` model depends on args.
"""
model = VGGNet(config=NET_CONFIG[11], **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["VGG11"], use_ssld)
return model
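# Usage sketch (hypothetical `_demo_*` helper; `class_num` is forwarded
# through **kwargs to VGGNet):
def _demo_vgg11():
    import paddle
    model = VGG11(pretrained=False, class_num=10)
    out = model(paddle.randn([2, 3, 224, 224]))
    assert out.shape == [2, 10]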
def VGG13(pretrained=False, use_ssld=False, **kwargs):
"""
VGG13
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `VGG13` model depends on args.
"""
model = VGGNet(config=NET_CONFIG[13], **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["VGG13"], use_ssld)
return model
def VGG16(pretrained=False, use_ssld=False, **kwargs):
"""
VGG16
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `VGG16` model depends on args.
"""
model = VGGNet(config=NET_CONFIG[16], **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["VGG16"], use_ssld)
return model
def VGG19(pretrained=False, use_ssld=False, **kwargs):
"""
VGG19
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `VGG19` model depends on args.
"""
model = VGGNet(config=NET_CONFIG[19], **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["VGG19"], use_ssld)
return model