legendary models v0.1
parent
4d246c20e4
commit
2fa808517d
|
@ -0,0 +1,6 @@
|
|||
from .resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, ResNet18_vd, ResNet34_vd, ResNet50_vd, ResNet101_vd, ResNet152_vd
|
||||
from .hrnet import HRNet_W18_C, HRNet_W30_C, HRNet_W32_C, HRNet_W40_C, HRNet_W44_C, HRNet_W48_C, HRNet_W64_C
|
||||
from .mobilenet_v1 import MobileNetV1_x0_25, MobileNetV1_x0_5, MobileNetV1_x0_75, MobileNetV1
|
||||
from .mobilenet_v3 import MobileNetV3_small_x0_35, MobileNetV3_small_x0_5, MobileNetV3_small_x0_75, MobileNetV3_small_x1_0, MobileNetV3_small_x1_25, MobileNetV3_large_x0_35, MobileNetV3_large_x0_5, MobileNetV3_large_x0_75, MobileNetV3_large_x1_0, MobileNetV3_large_x1_25
|
||||
from .inception_v3 import InceptionV3
|
||||
from .vgg import VGG11, VGG13, VGG16, VGG19
|
|
@ -24,29 +24,40 @@ from paddle.nn.functional import upsample
|
|||
from paddle.nn.initializer import Uniform
|
||||
|
||||
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer, Identity
|
||||
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
|
||||
|
||||
MODEL_URLS = {
|
||||
"HRNet_W18_C": "",
|
||||
"HRNet_W30_C": "",
|
||||
"HRNet_W32_C": "",
|
||||
"HRNet_W40_C": "",
|
||||
"HRNet_W44_C": "",
|
||||
"HRNet_W48_C": "",
|
||||
"HRNet_W60_C": "",
|
||||
"HRNet_W64_C": "",
|
||||
"SE_HRNet_W18_C": "",
|
||||
"SE_HRNet_W30_C": "",
|
||||
"SE_HRNet_W32_C": "",
|
||||
"SE_HRNet_W40_C": "",
|
||||
"SE_HRNet_W44_C": "",
|
||||
"SE_HRNet_W48_C": "",
|
||||
"SE_HRNet_W60_C": "",
|
||||
"SE_HRNet_W64_C": "",
|
||||
"HRNet_W18_C":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/HRNet_W18_C_pretrained.pdparams",
|
||||
"HRNet_W30_C":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/HRNet_W30_C_pretrained.pdparams",
|
||||
"HRNet_W32_C":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/HRNet_W32_C_pretrained.pdparams",
|
||||
"HRNet_W40_C":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/HRNet_W40_C_pretrained.pdparams",
|
||||
"HRNet_W44_C":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/HRNet_W44_C_pretrained.pdparams",
|
||||
"HRNet_W48_C":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/HRNet_W48_C_pretrained.pdparams",
|
||||
"HRNet_W64_C":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/HRNet_W64_C_pretrained.pdparams"
|
||||
}
|
||||
|
||||
__all__ = list(MODEL_URLS.keys())
|
||||
|
||||
|
||||
def _create_act(act):
|
||||
if act == "hardswish":
|
||||
return nn.Hardswish()
|
||||
elif act == "relu":
|
||||
return nn.ReLU()
|
||||
elif act is None:
|
||||
return Identity()
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"The activation function is not supported: {}".format(act))
|
||||
|
||||
|
||||
class ConvBNLayer(TheseusLayer):
|
||||
def __init__(self,
|
||||
num_channels,
|
||||
|
@ -55,7 +66,7 @@ class ConvBNLayer(TheseusLayer):
|
|||
stride=1,
|
||||
groups=1,
|
||||
act="relu"):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
self.conv = nn.Conv2D(
|
||||
in_channels=num_channels,
|
||||
|
@ -65,10 +76,8 @@ class ConvBNLayer(TheseusLayer):
|
|||
padding=(filter_size - 1) // 2,
|
||||
groups=groups,
|
||||
bias_attr=False)
|
||||
self.bn = nn.BatchNorm(
|
||||
num_filters,
|
||||
act=None)
|
||||
self.act = create_act(act)
|
||||
self.bn = nn.BatchNorm(num_filters, act=None)
|
||||
self.act = _create_act(act)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
|
@ -77,18 +86,6 @@ class ConvBNLayer(TheseusLayer):
|
|||
return x
|
||||
|
||||
|
||||
def create_act(act):
|
||||
if act == 'hardswish':
|
||||
return nn.Hardswish()
|
||||
elif act == 'relu':
|
||||
return nn.ReLU()
|
||||
elif act is None:
|
||||
return Identity()
|
||||
else:
|
||||
raise RuntimeError(
|
||||
'The activation function is not supported: {}'.format(act))
|
||||
|
||||
|
||||
class BottleneckBlock(TheseusLayer):
|
||||
def __init__(self,
|
||||
num_channels,
|
||||
|
@ -96,7 +93,7 @@ class BottleneckBlock(TheseusLayer):
|
|||
has_se,
|
||||
stride=1,
|
||||
downsample=False):
|
||||
super(BottleneckBlock, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
self.has_se = has_se
|
||||
self.downsample = downsample
|
||||
|
@ -147,11 +144,8 @@ class BottleneckBlock(TheseusLayer):
|
|||
|
||||
|
||||
class BasicBlock(nn.Layer):
|
||||
def __init__(self,
|
||||
num_channels,
|
||||
num_filters,
|
||||
has_se=False):
|
||||
super(BasicBlock, self).__init__()
|
||||
def __init__(self, num_channels, num_filters, has_se=False):
|
||||
super().__init__()
|
||||
|
||||
self.has_se = has_se
|
||||
|
||||
|
@ -190,9 +184,9 @@ class BasicBlock(nn.Layer):
|
|||
|
||||
class SELayer(TheseusLayer):
|
||||
def __init__(self, num_channels, num_filters, reduction_ratio):
|
||||
super(SELayer, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
self.pool2d_gap = nn.AdaptiveAvgPool2D(1)
|
||||
self.avg_pool = nn.AdaptiveAvgPool2D(1)
|
||||
|
||||
self._num_channels = num_channels
|
||||
|
||||
|
@ -201,8 +195,7 @@ class SELayer(TheseusLayer):
|
|||
self.fc_squeeze = nn.Linear(
|
||||
num_channels,
|
||||
med_ch,
|
||||
weight_attr=ParamAttr(
|
||||
initializer=Uniform(-stdv, stdv)))
|
||||
weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)))
|
||||
self.relu = nn.ReLU()
|
||||
stdv = 1.0 / math.sqrt(med_ch * 1.0)
|
||||
self.fc_excitation = nn.Linear(
|
||||
|
@ -213,7 +206,7 @@ class SELayer(TheseusLayer):
|
|||
|
||||
def forward(self, x, res_dict=None):
|
||||
residual = x
|
||||
x = self.pool2d_gap(x)
|
||||
x = self.avg_pool(x)
|
||||
x = paddle.squeeze(x, axis=[2, 3])
|
||||
x = self.fc_squeeze(x)
|
||||
x = self.relu(x)
|
||||
|
@ -225,11 +218,8 @@ class SELayer(TheseusLayer):
|
|||
|
||||
|
||||
class Stage(TheseusLayer):
|
||||
def __init__(self,
|
||||
num_modules,
|
||||
num_filters,
|
||||
has_se=False):
|
||||
super(Stage, self).__init__()
|
||||
def __init__(self, num_modules, num_filters, has_se=False):
|
||||
super().__init__()
|
||||
|
||||
self._num_modules = num_modules
|
||||
|
||||
|
@ -237,8 +227,7 @@ class Stage(TheseusLayer):
|
|||
for i in range(num_modules):
|
||||
self.stage_func_list.append(
|
||||
HighResolutionModule(
|
||||
num_filters=num_filters,
|
||||
has_se=has_se))
|
||||
num_filters=num_filters, has_se=has_se))
|
||||
|
||||
def forward(self, x, res_dict=None):
|
||||
x = x
|
||||
|
@ -248,10 +237,8 @@ class Stage(TheseusLayer):
|
|||
|
||||
|
||||
class HighResolutionModule(TheseusLayer):
|
||||
def __init__(self,
|
||||
num_filters,
|
||||
has_se=False):
|
||||
super(HighResolutionModule, self).__init__()
|
||||
def __init__(self, num_filters, has_se=False):
|
||||
super().__init__()
|
||||
|
||||
self.basic_block_list = nn.LayerList()
|
||||
|
||||
|
@ -261,11 +248,11 @@ class HighResolutionModule(TheseusLayer):
|
|||
BasicBlock(
|
||||
num_channels=num_filters[i],
|
||||
num_filters=num_filters[i],
|
||||
has_se=has_se) for j in range(4)]))
|
||||
has_se=has_se) for j in range(4)
|
||||
]))
|
||||
|
||||
self.fuse_func = FuseLayers(
|
||||
in_channels=num_filters,
|
||||
out_channels=num_filters)
|
||||
in_channels=num_filters, out_channels=num_filters)
|
||||
|
||||
def forward(self, x, res_dict=None):
|
||||
out = []
|
||||
|
@ -279,10 +266,8 @@ class HighResolutionModule(TheseusLayer):
|
|||
|
||||
|
||||
class FuseLayers(TheseusLayer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels):
|
||||
super(FuseLayers, self).__init__()
|
||||
def __init__(self, in_channels, out_channels):
|
||||
super().__init__()
|
||||
|
||||
self._actual_ch = len(in_channels)
|
||||
self._in_channels = in_channels
|
||||
|
@ -352,7 +337,7 @@ class LastClsOut(TheseusLayer):
|
|||
num_channel_list,
|
||||
has_se,
|
||||
num_filters_list=[32, 64, 128, 256]):
|
||||
super(LastClsOut, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
self.func_list = nn.LayerList()
|
||||
for idx in range(len(num_channel_list)):
|
||||
|
@ -378,9 +363,12 @@ class HRNet(TheseusLayer):
|
|||
width: int=18. Base channel number of HRNet.
|
||||
has_se: bool=False. If 'True', add se module to HRNet.
|
||||
class_num: int=1000. Output num of last fc layer.
|
||||
Returns:
|
||||
model: nn.Layer. Specific HRNet model depends on args.
|
||||
"""
|
||||
|
||||
def __init__(self, width=18, has_se=False, class_num=1000):
|
||||
super(HRNet, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
self.width = width
|
||||
self.has_se = has_se
|
||||
|
@ -388,21 +376,23 @@ class HRNet(TheseusLayer):
|
|||
|
||||
channels_2 = [self.width, self.width * 2]
|
||||
channels_3 = [self.width, self.width * 2, self.width * 4]
|
||||
channels_4 = [self.width, self.width * 2, self.width * 4, self.width * 8]
|
||||
channels_4 = [
|
||||
self.width, self.width * 2, self.width * 4, self.width * 8
|
||||
]
|
||||
|
||||
self.conv_layer1_1 = ConvBNLayer(
|
||||
num_channels=3,
|
||||
num_filters=64,
|
||||
filter_size=3,
|
||||
stride=2,
|
||||
act='relu')
|
||||
act="relu")
|
||||
|
||||
self.conv_layer1_2 = ConvBNLayer(
|
||||
num_channels=64,
|
||||
num_filters=64,
|
||||
filter_size=3,
|
||||
stride=2,
|
||||
act='relu')
|
||||
act="relu")
|
||||
|
||||
self.layer1 = nn.Sequential(*[
|
||||
BottleneckBlock(
|
||||
|
@ -410,48 +400,33 @@ class HRNet(TheseusLayer):
|
|||
num_filters=64,
|
||||
has_se=has_se,
|
||||
stride=1,
|
||||
downsample=True if i == 0 else False)
|
||||
for i in range(4)
|
||||
downsample=True if i == 0 else False) for i in range(4)
|
||||
])
|
||||
|
||||
self.conv_tr1_1 = ConvBNLayer(
|
||||
num_channels=256,
|
||||
num_filters=width,
|
||||
filter_size=3)
|
||||
num_channels=256, num_filters=width, filter_size=3)
|
||||
self.conv_tr1_2 = ConvBNLayer(
|
||||
num_channels=256,
|
||||
num_filters=width * 2,
|
||||
filter_size=3,
|
||||
stride=2
|
||||
)
|
||||
num_channels=256, num_filters=width * 2, filter_size=3, stride=2)
|
||||
|
||||
self.st2 = Stage(
|
||||
num_modules=1,
|
||||
num_filters=channels_2,
|
||||
has_se=self.has_se)
|
||||
num_modules=1, num_filters=channels_2, has_se=self.has_se)
|
||||
|
||||
self.conv_tr2 = ConvBNLayer(
|
||||
num_channels=width * 2,
|
||||
num_filters=width * 4,
|
||||
filter_size=3,
|
||||
stride=2
|
||||
)
|
||||
stride=2)
|
||||
self.st3 = Stage(
|
||||
num_modules=4,
|
||||
num_filters=channels_3,
|
||||
has_se=self.has_se)
|
||||
num_modules=4, num_filters=channels_3, has_se=self.has_se)
|
||||
|
||||
self.conv_tr3 = ConvBNLayer(
|
||||
num_channels=width * 4,
|
||||
num_filters=width * 8,
|
||||
filter_size=3,
|
||||
stride=2
|
||||
)
|
||||
stride=2)
|
||||
|
||||
self.st4 = Stage(
|
||||
num_modules=3,
|
||||
num_filters=channels_4,
|
||||
has_se=self.has_se)
|
||||
num_modules=3, num_filters=channels_4, has_se=self.has_se)
|
||||
|
||||
# classification
|
||||
num_filters_list = [32, 64, 128, 256]
|
||||
|
@ -464,17 +439,14 @@ class HRNet(TheseusLayer):
|
|||
self.cls_head_conv_list = nn.LayerList()
|
||||
for idx in range(3):
|
||||
self.cls_head_conv_list.append(
|
||||
ConvBNLayer(
|
||||
num_channels=num_filters_list[idx] * 4,
|
||||
num_filters=last_num_filters[idx],
|
||||
filter_size=3,
|
||||
stride=2))
|
||||
ConvBNLayer(
|
||||
num_channels=num_filters_list[idx] * 4,
|
||||
num_filters=last_num_filters[idx],
|
||||
filter_size=3,
|
||||
stride=2))
|
||||
|
||||
self.conv_last = ConvBNLayer(
|
||||
num_channels=1024,
|
||||
num_filters=2048,
|
||||
filter_size=1,
|
||||
stride=1)
|
||||
num_channels=1024, num_filters=2048, filter_size=1, stride=1)
|
||||
|
||||
self.avg_pool = nn.AdaptiveAvgPool2D(1)
|
||||
|
||||
|
@ -516,81 +488,254 @@ class HRNet(TheseusLayer):
|
|||
return y
|
||||
|
||||
|
||||
def HRNet_W18_C(**args):
|
||||
model = HRNet(width=18, **args)
|
||||
def _load_pretrained(pretrained, model, model_url, use_ssld):
|
||||
if pretrained is False:
|
||||
pass
|
||||
elif pretrained is True:
|
||||
load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
|
||||
elif isinstance(pretrained, str):
|
||||
load_dygraph_pretrain(model, pretrained)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"pretrained type is not available. Please use `string` or `boolean` type."
|
||||
)
|
||||
|
||||
|
||||
def HRNet_W18_C(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
HRNet_W18_C
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `HRNet_W18_C` model depends on args.
|
||||
"""
|
||||
model = HRNet(width=18, **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["HRNet_W18_C"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def HRNet_W30_C(**args):
|
||||
model = HRNet(width=30, **args)
|
||||
def HRNet_W30_C(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
HRNet_W30_C
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `HRNet_W30_C` model depends on args.
|
||||
"""
|
||||
model = HRNet(width=30, **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["HRNet_W30_C"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def HRNet_W32_C(**args):
|
||||
model = HRNet(width=32, **args)
|
||||
def HRNet_W32_C(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
HRNet_W32_C
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `HRNet_W32_C` model depends on args.
|
||||
"""
|
||||
model = HRNet(width=32, **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["HRNet_W32_C"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def HRNet_W40_C(**args):
|
||||
model = HRNet(width=40, **args)
|
||||
def HRNet_W40_C(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
HRNet_W40_C
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `HRNet_W40_C` model depends on args.
|
||||
"""
|
||||
model = HRNet(width=40, **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["HRNet_W40_C"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def HRNet_W44_C(**args):
|
||||
model = HRNet(width=44, **args)
|
||||
def HRNet_W44_C(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
HRNet_W44_C
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `HRNet_W44_C` model depends on args.
|
||||
"""
|
||||
model = HRNet(width=44, **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["HRNet_W44_C"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def HRNet_W48_C(**args):
|
||||
model = HRNet(width=48, **args)
|
||||
def HRNet_W48_C(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
HRNet_W48_C
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `HRNet_W48_C` model depends on args.
|
||||
"""
|
||||
model = HRNet(width=48, **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["HRNet_W48_C"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def HRNet_W60_C(**args):
|
||||
model = HRNet(width=60, **args)
|
||||
def HRNet_W60_C(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
HRNet_W60_C
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `HRNet_W60_C` model depends on args.
|
||||
"""
|
||||
model = HRNet(width=60, **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["HRNet_W60_C"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def HRNet_W64_C(**args):
|
||||
model = HRNet(width=64, **args)
|
||||
def HRNet_W64_C(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
HRNet_W64_C
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `HRNet_W64_C` model depends on args.
|
||||
"""
|
||||
model = HRNet(width=64, **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["HRNet_W64_C"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def SE_HRNet_W18_C(**args):
|
||||
model = HRNet(width=18, has_se=True, **args)
|
||||
def SE_HRNet_W18_C(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
SE_HRNet_W18_C
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `SE_HRNet_W18_C` model depends on args.
|
||||
"""
|
||||
model = HRNet(width=18, has_se=True, **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["SE_HRNet_W18_C"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def SE_HRNet_W30_C(**args):
|
||||
model = HRNet(width=30, has_se=True, **args)
|
||||
def SE_HRNet_W30_C(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
SE_HRNet_W30_C
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `SE_HRNet_W30_C` model depends on args.
|
||||
"""
|
||||
model = HRNet(width=30, has_se=True, **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["SE_HRNet_W30_C"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def SE_HRNet_W32_C(**args):
|
||||
model = HRNet(width=32, has_se=True, **args)
|
||||
def SE_HRNet_W32_C(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
SE_HRNet_W32_C
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `SE_HRNet_W32_C` model depends on args.
|
||||
"""
|
||||
model = HRNet(width=32, has_se=True, **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["SE_HRNet_W32_C"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def SE_HRNet_W40_C(**args):
|
||||
model = HRNet(width=40, has_se=True, **args)
|
||||
def SE_HRNet_W40_C(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
SE_HRNet_W40_C
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `SE_HRNet_W40_C` model depends on args.
|
||||
"""
|
||||
model = HRNet(width=40, has_se=True, **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["SE_HRNet_W40_C"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def SE_HRNet_W44_C(**args):
|
||||
model = HRNet(width=44, has_se=True, **args)
|
||||
def SE_HRNet_W44_C(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
SE_HRNet_W44_C
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `SE_HRNet_W44_C` model depends on args.
|
||||
"""
|
||||
model = HRNet(width=44, has_se=True, **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["SE_HRNet_W44_C"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def SE_HRNet_W48_C(**args):
|
||||
model = HRNet(width=48, has_se=True, **args)
|
||||
def SE_HRNet_W48_C(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
SE_HRNet_W48_C
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `SE_HRNet_W48_C` model depends on args.
|
||||
"""
|
||||
model = HRNet(width=48, has_se=True, **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["SE_HRNet_W48_C"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def SE_HRNet_W60_C(**args):
|
||||
model = HRNet(width=60, has_se=True, **args)
|
||||
def SE_HRNet_W60_C(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
SE_HRNet_W60_C
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `SE_HRNet_W60_C` model depends on args.
|
||||
"""
|
||||
model = HRNet(width=60, has_se=True, **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["SE_HRNet_W60_C"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def SE_HRNet_W64_C(**args):
|
||||
model = HRNet(width=64, has_se=True, **args)
|
||||
def SE_HRNet_W64_C(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
SE_HRNet_W64_C
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `SE_HRNet_W64_C` model depends on args.
|
||||
"""
|
||||
model = HRNet(width=64, has_se=True, **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["SE_HRNet_W64_C"], use_ssld)
|
||||
return model
|
||||
|
|
|
@ -13,39 +13,37 @@
|
|||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import math
|
||||
import paddle
|
||||
from paddle import ParamAttr
|
||||
import paddle.nn as nn
|
||||
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
|
||||
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
|
||||
from paddle.nn.initializer import Uniform
|
||||
import math
|
||||
|
||||
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
|
||||
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
|
||||
|
||||
|
||||
MODEL_URLS = {
|
||||
"InceptionV3": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/InceptionV3_pretrained.pdparams",
|
||||
"InceptionV3":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/InceptionV3_pretrained.pdparams"
|
||||
}
|
||||
|
||||
|
||||
__all__ = MODEL_URLS.keys()
|
||||
|
||||
'''
|
||||
InceptionV3 config: dict.
|
||||
key: inception blocks of InceptionV3.
|
||||
values: conv num in different blocks.
|
||||
'''
|
||||
NET_CONFIG = {
|
||||
'inception_a':[[192, 256, 288], [32, 64, 64]],
|
||||
'inception_b':[288],
|
||||
'inception_c':[[768, 768, 768, 768], [128, 160, 160, 192]],
|
||||
'inception_d':[768],
|
||||
'inception_e':[1280,2048]
|
||||
"inception_a": [[192, 256, 288], [32, 64, 64]],
|
||||
"inception_b": [288],
|
||||
"inception_c": [[768, 768, 768, 768], [128, 160, 160, 192]],
|
||||
"inception_d": [768],
|
||||
"inception_e": [1280, 2048]
|
||||
}
|
||||
|
||||
|
||||
class ConvBNLayer(TheseusLayer):
|
||||
def __init__(self,
|
||||
num_channels,
|
||||
|
@ -55,7 +53,7 @@ class ConvBNLayer(TheseusLayer):
|
|||
padding=0,
|
||||
groups=1,
|
||||
act="relu"):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
super().__init__()
|
||||
self.act = act
|
||||
self.conv = Conv2D(
|
||||
in_channels=num_channels,
|
||||
|
@ -65,92 +63,100 @@ class ConvBNLayer(TheseusLayer):
|
|||
padding=padding,
|
||||
groups=groups,
|
||||
bias_attr=False)
|
||||
self.batch_norm = BatchNorm(
|
||||
num_filters)
|
||||
self.bn = BatchNorm(num_filters)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.batch_norm(x)
|
||||
x = self.bn(x)
|
||||
if self.act:
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
|
||||
class InceptionStem(TheseusLayer):
|
||||
def __init__(self):
|
||||
super(InceptionStem, self).__init__()
|
||||
self.conv_1a_3x3 = ConvBNLayer(num_channels=3,
|
||||
num_filters=32,
|
||||
filter_size=3,
|
||||
stride=2,
|
||||
act="relu")
|
||||
self.conv_2a_3x3 = ConvBNLayer(num_channels=32,
|
||||
num_filters=32,
|
||||
filter_size=3,
|
||||
stride=1,
|
||||
act="relu")
|
||||
self.conv_2b_3x3 = ConvBNLayer(num_channels=32,
|
||||
num_filters=64,
|
||||
filter_size=3,
|
||||
padding=1,
|
||||
act="relu")
|
||||
super().__init__()
|
||||
self.conv_1a_3x3 = ConvBNLayer(
|
||||
num_channels=3,
|
||||
num_filters=32,
|
||||
filter_size=3,
|
||||
stride=2,
|
||||
act="relu")
|
||||
self.conv_2a_3x3 = ConvBNLayer(
|
||||
num_channels=32,
|
||||
num_filters=32,
|
||||
filter_size=3,
|
||||
stride=1,
|
||||
act="relu")
|
||||
self.conv_2b_3x3 = ConvBNLayer(
|
||||
num_channels=32,
|
||||
num_filters=64,
|
||||
filter_size=3,
|
||||
padding=1,
|
||||
act="relu")
|
||||
|
||||
self.max_pool = MaxPool2D(kernel_size=3, stride=2, padding=0)
|
||||
self.conv_3b_1x1 = ConvBNLayer(
|
||||
num_channels=64, num_filters=80, filter_size=1, act="relu")
|
||||
self.conv_4a_3x3 = ConvBNLayer(
|
||||
num_channels=80, num_filters=192, filter_size=3, act="relu")
|
||||
|
||||
self.maxpool = MaxPool2D(kernel_size=3, stride=2, padding=0)
|
||||
self.conv_3b_1x1 = ConvBNLayer(num_channels=64,
|
||||
num_filters=80,
|
||||
filter_size=1,
|
||||
act="relu")
|
||||
self.conv_4a_3x3 = ConvBNLayer(num_channels=80,
|
||||
num_filters=192,
|
||||
filter_size=3,
|
||||
act="relu")
|
||||
def forward(self, x):
|
||||
x = self.conv_1a_3x3(x)
|
||||
x = self.conv_2a_3x3(x)
|
||||
x = self.conv_2b_3x3(x)
|
||||
x = self.maxpool(x)
|
||||
x = self.max_pool(x)
|
||||
x = self.conv_3b_1x1(x)
|
||||
x = self.conv_4a_3x3(x)
|
||||
x = self.maxpool(x)
|
||||
x = self.max_pool(x)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class InceptionA(TheseusLayer):
|
||||
def __init__(self, num_channels, pool_features):
|
||||
super(InceptionA, self).__init__()
|
||||
self.branch1x1 = ConvBNLayer(num_channels=num_channels,
|
||||
num_filters=64,
|
||||
filter_size=1,
|
||||
act="relu")
|
||||
self.branch5x5_1 = ConvBNLayer(num_channels=num_channels,
|
||||
num_filters=48,
|
||||
filter_size=1,
|
||||
act="relu")
|
||||
self.branch5x5_2 = ConvBNLayer(num_channels=48,
|
||||
num_filters=64,
|
||||
filter_size=5,
|
||||
padding=2,
|
||||
act="relu")
|
||||
super().__init__()
|
||||
self.branch1x1 = ConvBNLayer(
|
||||
num_channels=num_channels,
|
||||
num_filters=64,
|
||||
filter_size=1,
|
||||
act="relu")
|
||||
self.branch5x5_1 = ConvBNLayer(
|
||||
num_channels=num_channels,
|
||||
num_filters=48,
|
||||
filter_size=1,
|
||||
act="relu")
|
||||
self.branch5x5_2 = ConvBNLayer(
|
||||
num_channels=48,
|
||||
num_filters=64,
|
||||
filter_size=5,
|
||||
padding=2,
|
||||
act="relu")
|
||||
|
||||
self.branch3x3dbl_1 = ConvBNLayer(num_channels=num_channels,
|
||||
num_filters=64,
|
||||
filter_size=1,
|
||||
act="relu")
|
||||
self.branch3x3dbl_2 = ConvBNLayer(num_channels=64,
|
||||
num_filters=96,
|
||||
filter_size=3,
|
||||
padding=1,
|
||||
act="relu")
|
||||
self.branch3x3dbl_3 = ConvBNLayer(num_channels=96,
|
||||
num_filters=96,
|
||||
filter_size=3,
|
||||
padding=1,
|
||||
act="relu")
|
||||
self.branch_pool = AvgPool2D(kernel_size=3, stride=1, padding=1, exclusive=False)
|
||||
self.branch_pool_conv = ConvBNLayer(num_channels=num_channels,
|
||||
num_filters=pool_features,
|
||||
filter_size=1,
|
||||
act="relu")
|
||||
self.branch3x3dbl_1 = ConvBNLayer(
|
||||
num_channels=num_channels,
|
||||
num_filters=64,
|
||||
filter_size=1,
|
||||
act="relu")
|
||||
self.branch3x3dbl_2 = ConvBNLayer(
|
||||
num_channels=64,
|
||||
num_filters=96,
|
||||
filter_size=3,
|
||||
padding=1,
|
||||
act="relu")
|
||||
self.branch3x3dbl_3 = ConvBNLayer(
|
||||
num_channels=96,
|
||||
num_filters=96,
|
||||
filter_size=3,
|
||||
padding=1,
|
||||
act="relu")
|
||||
self.branch_pool = AvgPool2D(
|
||||
kernel_size=3, stride=1, padding=1, exclusive=False)
|
||||
self.branch_pool_conv = ConvBNLayer(
|
||||
num_channels=num_channels,
|
||||
num_filters=pool_features,
|
||||
filter_size=1,
|
||||
act="relu")
|
||||
|
||||
def forward(self, x):
|
||||
branch1x1 = self.branch1x1(x)
|
||||
|
@ -163,34 +169,39 @@ class InceptionA(TheseusLayer):
|
|||
|
||||
branch_pool = self.branch_pool(x)
|
||||
branch_pool = self.branch_pool_conv(branch_pool)
|
||||
x = paddle.concat([branch1x1, branch5x5, branch3x3dbl, branch_pool], axis=1)
|
||||
x = paddle.concat(
|
||||
[branch1x1, branch5x5, branch3x3dbl, branch_pool], axis=1)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class InceptionB(TheseusLayer):
|
||||
def __init__(self, num_channels):
|
||||
super(InceptionB, self).__init__()
|
||||
self.branch3x3 = ConvBNLayer(num_channels=num_channels,
|
||||
num_filters=384,
|
||||
filter_size=3,
|
||||
stride=2,
|
||||
act="relu")
|
||||
self.branch3x3dbl_1 = ConvBNLayer(num_channels=num_channels,
|
||||
num_filters=64,
|
||||
filter_size=1,
|
||||
act="relu")
|
||||
self.branch3x3dbl_2 = ConvBNLayer(num_channels=64,
|
||||
num_filters=96,
|
||||
filter_size=3,
|
||||
padding=1,
|
||||
act="relu")
|
||||
self.branch3x3dbl_3 = ConvBNLayer(num_channels=96,
|
||||
num_filters=96,
|
||||
filter_size=3,
|
||||
stride=2,
|
||||
act="relu")
|
||||
super().__init__()
|
||||
self.branch3x3 = ConvBNLayer(
|
||||
num_channels=num_channels,
|
||||
num_filters=384,
|
||||
filter_size=3,
|
||||
stride=2,
|
||||
act="relu")
|
||||
self.branch3x3dbl_1 = ConvBNLayer(
|
||||
num_channels=num_channels,
|
||||
num_filters=64,
|
||||
filter_size=1,
|
||||
act="relu")
|
||||
self.branch3x3dbl_2 = ConvBNLayer(
|
||||
num_channels=64,
|
||||
num_filters=96,
|
||||
filter_size=3,
|
||||
padding=1,
|
||||
act="relu")
|
||||
self.branch3x3dbl_3 = ConvBNLayer(
|
||||
num_channels=96,
|
||||
num_filters=96,
|
||||
filter_size=3,
|
||||
stride=2,
|
||||
act="relu")
|
||||
self.branch_pool = MaxPool2D(kernel_size=3, stride=2)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
branch3x3 = self.branch3x3(x)
|
||||
|
||||
|
@ -204,64 +215,75 @@ class InceptionB(TheseusLayer):
|
|||
|
||||
return x
|
||||
|
||||
|
||||
class InceptionC(TheseusLayer):
|
||||
def __init__(self, num_channels, channels_7x7):
|
||||
super(InceptionC, self).__init__()
|
||||
self.branch1x1 = ConvBNLayer(num_channels=num_channels,
|
||||
num_filters=192,
|
||||
filter_size=1,
|
||||
act="relu")
|
||||
super().__init__()
|
||||
self.branch1x1 = ConvBNLayer(
|
||||
num_channels=num_channels,
|
||||
num_filters=192,
|
||||
filter_size=1,
|
||||
act="relu")
|
||||
|
||||
self.branch7x7_1 = ConvBNLayer(
|
||||
num_channels=num_channels,
|
||||
num_filters=channels_7x7,
|
||||
filter_size=1,
|
||||
stride=1,
|
||||
act="relu")
|
||||
self.branch7x7_2 = ConvBNLayer(
|
||||
num_channels=channels_7x7,
|
||||
num_filters=channels_7x7,
|
||||
filter_size=(1, 7),
|
||||
stride=1,
|
||||
padding=(0, 3),
|
||||
act="relu")
|
||||
self.branch7x7_3 = ConvBNLayer(
|
||||
num_channels=channels_7x7,
|
||||
num_filters=192,
|
||||
filter_size=(7, 1),
|
||||
stride=1,
|
||||
padding=(3, 0),
|
||||
act="relu")
|
||||
|
||||
self.branch7x7dbl_1 = ConvBNLayer(
|
||||
num_channels=num_channels,
|
||||
num_filters=channels_7x7,
|
||||
filter_size=1,
|
||||
act="relu")
|
||||
self.branch7x7dbl_2 = ConvBNLayer(
|
||||
num_channels=channels_7x7,
|
||||
num_filters=channels_7x7,
|
||||
filter_size=(7, 1),
|
||||
padding=(3, 0),
|
||||
act="relu")
|
||||
self.branch7x7dbl_3 = ConvBNLayer(
|
||||
num_channels=channels_7x7,
|
||||
num_filters=channels_7x7,
|
||||
filter_size=(1, 7),
|
||||
padding=(0, 3),
|
||||
act="relu")
|
||||
self.branch7x7dbl_4 = ConvBNLayer(
|
||||
num_channels=channels_7x7,
|
||||
num_filters=channels_7x7,
|
||||
filter_size=(7, 1),
|
||||
padding=(3, 0),
|
||||
act="relu")
|
||||
self.branch7x7dbl_5 = ConvBNLayer(
|
||||
num_channels=channels_7x7,
|
||||
num_filters=192,
|
||||
filter_size=(1, 7),
|
||||
padding=(0, 3),
|
||||
act="relu")
|
||||
|
||||
self.branch_pool = AvgPool2D(
|
||||
kernel_size=3, stride=1, padding=1, exclusive=False)
|
||||
self.branch_pool_conv = ConvBNLayer(
|
||||
num_channels=num_channels,
|
||||
num_filters=192,
|
||||
filter_size=1,
|
||||
act="relu")
|
||||
|
||||
self.branch7x7_1 = ConvBNLayer(num_channels=num_channels,
|
||||
num_filters=channels_7x7,
|
||||
filter_size=1,
|
||||
stride=1,
|
||||
act="relu")
|
||||
self.branch7x7_2 = ConvBNLayer(num_channels=channels_7x7,
|
||||
num_filters=channels_7x7,
|
||||
filter_size=(1, 7),
|
||||
stride=1,
|
||||
padding=(0, 3),
|
||||
act="relu")
|
||||
self.branch7x7_3 = ConvBNLayer(num_channels=channels_7x7,
|
||||
num_filters=192,
|
||||
filter_size=(7, 1),
|
||||
stride=1,
|
||||
padding=(3, 0),
|
||||
act="relu")
|
||||
|
||||
self.branch7x7dbl_1 = ConvBNLayer(num_channels=num_channels,
|
||||
num_filters=channels_7x7,
|
||||
filter_size=1,
|
||||
act="relu")
|
||||
self.branch7x7dbl_2 = ConvBNLayer(num_channels=channels_7x7,
|
||||
num_filters=channels_7x7,
|
||||
filter_size=(7, 1),
|
||||
padding = (3, 0),
|
||||
act="relu")
|
||||
self.branch7x7dbl_3 = ConvBNLayer(num_channels=channels_7x7,
|
||||
num_filters=channels_7x7,
|
||||
filter_size=(1, 7),
|
||||
padding = (0, 3),
|
||||
act="relu")
|
||||
self.branch7x7dbl_4 = ConvBNLayer(num_channels=channels_7x7,
|
||||
num_filters=channels_7x7,
|
||||
filter_size=(7, 1),
|
||||
padding = (3, 0),
|
||||
act="relu")
|
||||
self.branch7x7dbl_5 = ConvBNLayer(num_channels=channels_7x7,
|
||||
num_filters=192,
|
||||
filter_size=(1, 7),
|
||||
padding = (0, 3),
|
||||
act="relu")
|
||||
|
||||
self.branch_pool = AvgPool2D(kernel_size=3, stride=1, padding=1, exclusive=False)
|
||||
self.branch_pool_conv = ConvBNLayer(num_channels=num_channels,
|
||||
num_filters=192,
|
||||
filter_size=1,
|
||||
act="relu")
|
||||
|
||||
def forward(self, x):
|
||||
branch1x1 = self.branch1x1(x)
|
||||
|
||||
|
@ -278,41 +300,49 @@ class InceptionC(TheseusLayer):
|
|||
branch_pool = self.branch_pool(x)
|
||||
branch_pool = self.branch_pool_conv(branch_pool)
|
||||
|
||||
x = paddle.concat([branch1x1, branch7x7, branch7x7dbl, branch_pool], axis=1)
|
||||
|
||||
x = paddle.concat(
|
||||
[branch1x1, branch7x7, branch7x7dbl, branch_pool], axis=1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class InceptionD(TheseusLayer):
|
||||
def __init__(self, num_channels):
|
||||
super(InceptionD, self).__init__()
|
||||
self.branch3x3_1 = ConvBNLayer(num_channels=num_channels,
|
||||
num_filters=192,
|
||||
filter_size=1,
|
||||
act="relu")
|
||||
self.branch3x3_2 = ConvBNLayer(num_channels=192,
|
||||
num_filters=320,
|
||||
filter_size=3,
|
||||
stride=2,
|
||||
act="relu")
|
||||
self.branch7x7x3_1 = ConvBNLayer(num_channels=num_channels,
|
||||
num_filters=192,
|
||||
filter_size=1,
|
||||
act="relu")
|
||||
self.branch7x7x3_2 = ConvBNLayer(num_channels=192,
|
||||
num_filters=192,
|
||||
filter_size=(1, 7),
|
||||
padding=(0, 3),
|
||||
act="relu")
|
||||
self.branch7x7x3_3 = ConvBNLayer(num_channels=192,
|
||||
num_filters=192,
|
||||
filter_size=(7, 1),
|
||||
padding=(3, 0),
|
||||
act="relu")
|
||||
self.branch7x7x3_4 = ConvBNLayer(num_channels=192,
|
||||
num_filters=192,
|
||||
filter_size=3,
|
||||
stride=2,
|
||||
act="relu")
|
||||
super().__init__()
|
||||
self.branch3x3_1 = ConvBNLayer(
|
||||
num_channels=num_channels,
|
||||
num_filters=192,
|
||||
filter_size=1,
|
||||
act="relu")
|
||||
self.branch3x3_2 = ConvBNLayer(
|
||||
num_channels=192,
|
||||
num_filters=320,
|
||||
filter_size=3,
|
||||
stride=2,
|
||||
act="relu")
|
||||
self.branch7x7x3_1 = ConvBNLayer(
|
||||
num_channels=num_channels,
|
||||
num_filters=192,
|
||||
filter_size=1,
|
||||
act="relu")
|
||||
self.branch7x7x3_2 = ConvBNLayer(
|
||||
num_channels=192,
|
||||
num_filters=192,
|
||||
filter_size=(1, 7),
|
||||
padding=(0, 3),
|
||||
act="relu")
|
||||
self.branch7x7x3_3 = ConvBNLayer(
|
||||
num_channels=192,
|
||||
num_filters=192,
|
||||
filter_size=(7, 1),
|
||||
padding=(3, 0),
|
||||
act="relu")
|
||||
self.branch7x7x3_4 = ConvBNLayer(
|
||||
num_channels=192,
|
||||
num_filters=192,
|
||||
filter_size=3,
|
||||
stride=2,
|
||||
act="relu")
|
||||
self.branch_pool = MaxPool2D(kernel_size=3, stride=2)
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -325,56 +355,68 @@ class InceptionD(TheseusLayer):
|
|||
branch7x7x3 = self.branch7x7x3_4(branch7x7x3)
|
||||
|
||||
branch_pool = self.branch_pool(x)
|
||||
|
||||
|
||||
x = paddle.concat([branch3x3, branch7x7x3, branch_pool], axis=1)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class InceptionE(TheseusLayer):
|
||||
def __init__(self, num_channels):
|
||||
super(InceptionE, self).__init__()
|
||||
self.branch1x1 = ConvBNLayer(num_channels=num_channels,
|
||||
num_filters=320,
|
||||
filter_size=1,
|
||||
act="relu")
|
||||
self.branch3x3_1 = ConvBNLayer(num_channels=num_channels,
|
||||
num_filters=384,
|
||||
filter_size=1,
|
||||
act="relu")
|
||||
self.branch3x3_2a = ConvBNLayer(num_channels=384,
|
||||
num_filters=384,
|
||||
filter_size=(1, 3),
|
||||
padding=(0, 1),
|
||||
act="relu")
|
||||
self.branch3x3_2b = ConvBNLayer(num_channels=384,
|
||||
num_filters=384,
|
||||
filter_size=(3, 1),
|
||||
padding=(1, 0),
|
||||
act="relu")
|
||||
|
||||
self.branch3x3dbl_1 = ConvBNLayer(num_channels=num_channels,
|
||||
num_filters=448,
|
||||
filter_size=1,
|
||||
act="relu")
|
||||
self.branch3x3dbl_2 = ConvBNLayer(num_channels=448,
|
||||
num_filters=384,
|
||||
filter_size=3,
|
||||
padding=1,
|
||||
act="relu")
|
||||
self.branch3x3dbl_3a = ConvBNLayer(num_channels=384,
|
||||
num_filters=384,
|
||||
filter_size=(1, 3),
|
||||
padding=(0, 1),
|
||||
act="relu")
|
||||
self.branch3x3dbl_3b = ConvBNLayer(num_channels=384,
|
||||
num_filters=384,
|
||||
filter_size=(3, 1),
|
||||
padding=(1, 0),
|
||||
act="relu")
|
||||
self.branch_pool = AvgPool2D(kernel_size=3, stride=1, padding=1, exclusive=False)
|
||||
self.branch_pool_conv = ConvBNLayer(num_channels=num_channels,
|
||||
num_filters=192,
|
||||
filter_size=1,
|
||||
act="relu")
|
||||
super().__init__()
|
||||
self.branch1x1 = ConvBNLayer(
|
||||
num_channels=num_channels,
|
||||
num_filters=320,
|
||||
filter_size=1,
|
||||
act="relu")
|
||||
self.branch3x3_1 = ConvBNLayer(
|
||||
num_channels=num_channels,
|
||||
num_filters=384,
|
||||
filter_size=1,
|
||||
act="relu")
|
||||
self.branch3x3_2a = ConvBNLayer(
|
||||
num_channels=384,
|
||||
num_filters=384,
|
||||
filter_size=(1, 3),
|
||||
padding=(0, 1),
|
||||
act="relu")
|
||||
self.branch3x3_2b = ConvBNLayer(
|
||||
num_channels=384,
|
||||
num_filters=384,
|
||||
filter_size=(3, 1),
|
||||
padding=(1, 0),
|
||||
act="relu")
|
||||
|
||||
self.branch3x3dbl_1 = ConvBNLayer(
|
||||
num_channels=num_channels,
|
||||
num_filters=448,
|
||||
filter_size=1,
|
||||
act="relu")
|
||||
self.branch3x3dbl_2 = ConvBNLayer(
|
||||
num_channels=448,
|
||||
num_filters=384,
|
||||
filter_size=3,
|
||||
padding=1,
|
||||
act="relu")
|
||||
self.branch3x3dbl_3a = ConvBNLayer(
|
||||
num_channels=384,
|
||||
num_filters=384,
|
||||
filter_size=(1, 3),
|
||||
padding=(0, 1),
|
||||
act="relu")
|
||||
self.branch3x3dbl_3b = ConvBNLayer(
|
||||
num_channels=384,
|
||||
num_filters=384,
|
||||
filter_size=(3, 1),
|
||||
padding=(1, 0),
|
||||
act="relu")
|
||||
self.branch_pool = AvgPool2D(
|
||||
kernel_size=3, stride=1, padding=1, exclusive=False)
|
||||
self.branch_pool_conv = ConvBNLayer(
|
||||
num_channels=num_channels,
|
||||
num_filters=192,
|
||||
filter_size=1,
|
||||
act="relu")
|
||||
|
||||
def forward(self, x):
|
||||
branch1x1 = self.branch1x1(x)
|
||||
|
||||
|
@ -396,8 +438,9 @@ class InceptionE(TheseusLayer):
|
|||
branch_pool = self.branch_pool(x)
|
||||
branch_pool = self.branch_pool_conv(branch_pool)
|
||||
|
||||
x = paddle.concat([branch1x1, branch3x3, branch3x3dbl, branch_pool], axis=1)
|
||||
return x
|
||||
x = paddle.concat(
|
||||
[branch1x1, branch3x3, branch3x3dbl, branch_pool], axis=1)
|
||||
return x
|
||||
|
||||
|
||||
class Inception_V3(TheseusLayer):
|
||||
|
@ -410,25 +453,21 @@ class Inception_V3(TheseusLayer):
|
|||
Returns:
|
||||
model: nn.Layer. Specific Inception_V3 model depends on args.
|
||||
"""
|
||||
def __init__(self,
|
||||
config,
|
||||
class_num=1000,
|
||||
pretrained=False,
|
||||
**kwargs):
|
||||
super(Inception_V3, self).__init__()
|
||||
|
||||
self.inception_a_list = config['inception_a']
|
||||
self.inception_c_list = config['inception_c']
|
||||
self.inception_b_list = config['inception_b']
|
||||
self.inception_d_list = config['inception_d']
|
||||
self.inception_e_list = config ['inception_e']
|
||||
self.pretrained = pretrained
|
||||
def __init__(self, config, class_num=1000):
|
||||
super().__init__()
|
||||
|
||||
self.inception_a_list = config["inception_a"]
|
||||
self.inception_c_list = config["inception_c"]
|
||||
self.inception_b_list = config["inception_b"]
|
||||
self.inception_d_list = config["inception_d"]
|
||||
self.inception_e_list = config["inception_e"]
|
||||
|
||||
self.inception_stem = InceptionStem()
|
||||
|
||||
self.inception_block_list = nn.LayerList()
|
||||
for i in range(len(self.inception_a_list[0])):
|
||||
inception_a = InceptionA(self.inception_a_list[0][i],
|
||||
inception_a = InceptionA(self.inception_a_list[0][i],
|
||||
self.inception_a_list[1][i])
|
||||
self.inception_block_list.append(inception_a)
|
||||
|
||||
|
@ -437,7 +476,7 @@ class Inception_V3(TheseusLayer):
|
|||
self.inception_block_list.append(inception_b)
|
||||
|
||||
for i in range(len(self.inception_c_list[0])):
|
||||
inception_c = InceptionC(self.inception_c_list[0][i],
|
||||
inception_c = InceptionC(self.inception_c_list[0][i],
|
||||
self.inception_c_list[1][i])
|
||||
self.inception_block_list.append(inception_c)
|
||||
|
||||
|
@ -448,21 +487,20 @@ class Inception_V3(TheseusLayer):
|
|||
for i in range(len(self.inception_e_list)):
|
||||
inception_e = InceptionE(self.inception_e_list[i])
|
||||
self.inception_block_list.append(inception_e)
|
||||
|
||||
|
||||
self.avg_pool = AdaptiveAvgPool2D(1)
|
||||
self.dropout = Dropout(p=0.2, mode="downscale_in_infer")
|
||||
stdv = 1.0 / math.sqrt(2048 * 1.0)
|
||||
self.fc = Linear(
|
||||
2048,
|
||||
class_num,
|
||||
weight_attr=ParamAttr(
|
||||
initializer=Uniform(-stdv, stdv)),
|
||||
weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)),
|
||||
bias_attr=ParamAttr())
|
||||
|
||||
def forward(self, x):
|
||||
x = self.inception_stem(x)
|
||||
for inception_block in self.inception_block_list:
|
||||
x = inception_block(x)
|
||||
x = inception_block(x)
|
||||
x = self.avg_pool(x)
|
||||
x = paddle.reshape(x, shape=[-1, 2048])
|
||||
x = self.dropout(x)
|
||||
|
@ -470,25 +508,29 @@ class Inception_V3(TheseusLayer):
|
|||
return x
|
||||
|
||||
|
||||
def InceptionV3(**kwargs):
|
||||
def _load_pretrained(pretrained, model, model_url, use_ssld):
|
||||
if pretrained is False:
|
||||
pass
|
||||
elif pretrained is True:
|
||||
load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
|
||||
elif isinstance(pretrained, str):
|
||||
load_dygraph_pretrain(model, pretrained)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"pretrained type is not available. Please use `string` or `boolean` type."
|
||||
)
|
||||
|
||||
|
||||
def InceptionV3(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
InceptionV3
|
||||
Args:
|
||||
kwargs:
|
||||
class_num: int=1000. Output dim of last fc layer.
|
||||
pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
|
||||
pretrained: bool=false or str. if `true` load pretrained parameters, `false` otherwise.
|
||||
if str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `InceptionV3` model
|
||||
"""
|
||||
model = Inception_V3(NET_CONFIG, **kwargs)
|
||||
|
||||
if isinstance(model.pretrained, bool):
|
||||
if model.pretrained is True:
|
||||
load_dygraph_pretrain_from_url(model, MODEL_URLS["InceptionV3"])
|
||||
elif isinstance(model.pretrained, str):
|
||||
load_dygraph_pretrain(model, model.pretrained)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"pretrained type is not available. Please use `string` or `boolean` type")
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["InceptionV3"], use_ssld)
|
||||
return model
|
||||
|
||||
|
|
|
@ -14,8 +14,6 @@
|
|||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import ParamAttr
|
||||
import paddle.nn as nn
|
||||
from paddle.nn import Conv2D, BatchNorm, Linear, ReLU, Flatten
|
||||
|
@ -23,19 +21,22 @@ from paddle.nn import AdaptiveAvgPool2D
|
|||
from paddle.nn.initializer import KaimingNormal
|
||||
|
||||
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
|
||||
from ppcls.utils.save_load import load_dygraph_pretrain_from, load_dygraph_pretrain_from_url
|
||||
|
||||
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
|
||||
|
||||
MODEL_URLS = {
|
||||
"MobileNetV1_x0_25": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV1_x0_25_pretrained.pdparams",
|
||||
"MobileNetV1_x0_5": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV1_x0_5_pretrained.pdparams",
|
||||
"MobileNetV1_x0_75": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV1_x0_75_pretrained.pdparams",
|
||||
"MobileNetV1": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV1_pretrained.pdparams",
|
||||
"MobileNetV1_x0_25":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV1_x0_25_pretrained.pdparams",
|
||||
"MobileNetV1_x0_5":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV1_x0_5_pretrained.pdparams",
|
||||
"MobileNetV1_x0_75":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV1_x0_75_pretrained.pdparams",
|
||||
"MobileNetV1":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV1_pretrained.pdparams"
|
||||
}
|
||||
|
||||
__all__ = MODEL_URLS.keys()
|
||||
|
||||
|
||||
|
||||
|
||||
class ConvBNLayer(TheseusLayer):
|
||||
def __init__(self,
|
||||
num_channels,
|
||||
|
@ -44,7 +45,7 @@ class ConvBNLayer(TheseusLayer):
|
|||
stride,
|
||||
padding,
|
||||
num_groups=1):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
self.conv = Conv2D(
|
||||
in_channels=num_channels,
|
||||
|
@ -55,9 +56,7 @@ class ConvBNLayer(TheseusLayer):
|
|||
groups=num_groups,
|
||||
weight_attr=ParamAttr(initializer=KaimingNormal()),
|
||||
bias_attr=False)
|
||||
|
||||
self.bn = BatchNorm(num_filters)
|
||||
|
||||
self.relu = ReLU()
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -68,14 +67,9 @@ class ConvBNLayer(TheseusLayer):
|
|||
|
||||
|
||||
class DepthwiseSeparable(TheseusLayer):
|
||||
def __init__(self,
|
||||
num_channels,
|
||||
num_filters1,
|
||||
num_filters2,
|
||||
num_groups,
|
||||
stride,
|
||||
scale):
|
||||
super(DepthwiseSeparable, self).__init__()
|
||||
def __init__(self, num_channels, num_filters1, num_filters2, num_groups,
|
||||
stride, scale):
|
||||
super().__init__()
|
||||
|
||||
self.depthwise_conv = ConvBNLayer(
|
||||
num_channels=num_channels,
|
||||
|
@ -99,10 +93,18 @@ class DepthwiseSeparable(TheseusLayer):
|
|||
|
||||
|
||||
class MobileNet(TheseusLayer):
|
||||
def __init__(self, scale=1.0, class_num=1000, pretrained=False):
|
||||
super(MobileNet, self).__init__()
|
||||
"""
|
||||
MobileNet
|
||||
Args:
|
||||
scale: float=1.0. The coefficient that controls the size of network parameters.
|
||||
class_num: int=1000. The number of classes.
|
||||
Returns:
|
||||
model: nn.Layer. Specific MobileNet model depends on args.
|
||||
"""
|
||||
|
||||
def __init__(self, scale=1.0, class_num=1000):
|
||||
super().__init__()
|
||||
self.scale = scale
|
||||
self.pretrained = pretrained
|
||||
|
||||
self.conv = ConvBNLayer(
|
||||
num_channels=3,
|
||||
|
@ -110,30 +112,31 @@ class MobileNet(TheseusLayer):
|
|||
num_filters=int(32 * scale),
|
||||
stride=2,
|
||||
padding=1)
|
||||
|
||||
|
||||
#num_channels, num_filters1, num_filters2, num_groups, stride
|
||||
self.cfg = [[int(32 * scale), 32, 64, 32, 1],
|
||||
[int(64 * scale), 64, 128, 64, 2],
|
||||
[int(128 * scale), 128, 128, 128, 1],
|
||||
[int(128 * scale), 128, 256, 128, 2],
|
||||
[int(256 * scale), 256, 256, 256, 1],
|
||||
[int(256 * scale), 256, 512, 256, 2],
|
||||
[int(512 * scale), 512, 512, 512, 1],
|
||||
[int(512 * scale), 512, 512, 512, 1],
|
||||
[int(512 * scale), 512, 512, 512, 1],
|
||||
[int(512 * scale), 512, 512, 512, 1],
|
||||
[int(512 * scale), 512, 512, 512, 1],
|
||||
[int(512 * scale), 512, 1024, 512, 2],
|
||||
self.cfg = [[int(32 * scale), 32, 64, 32, 1],
|
||||
[int(64 * scale), 64, 128, 64, 2],
|
||||
[int(128 * scale), 128, 128, 128, 1],
|
||||
[int(128 * scale), 128, 256, 128, 2],
|
||||
[int(256 * scale), 256, 256, 256, 1],
|
||||
[int(256 * scale), 256, 512, 256, 2],
|
||||
[int(512 * scale), 512, 512, 512, 1],
|
||||
[int(512 * scale), 512, 512, 512, 1],
|
||||
[int(512 * scale), 512, 512, 512, 1],
|
||||
[int(512 * scale), 512, 512, 512, 1],
|
||||
[int(512 * scale), 512, 512, 512, 1],
|
||||
[int(512 * scale), 512, 1024, 512, 2],
|
||||
[int(1024 * scale), 1024, 1024, 1024, 1]]
|
||||
|
||||
|
||||
self.blocks = nn.Sequential(*[
|
||||
DepthwiseSeparable(
|
||||
num_channels=params[0],
|
||||
num_filters1=params[1],
|
||||
num_filters2=params[2],
|
||||
num_groups=params[3],
|
||||
stride=params[4],
|
||||
scale=scale) for params in self.cfg])
|
||||
DepthwiseSeparable(
|
||||
num_channels=params[0],
|
||||
num_filters1=params[1],
|
||||
num_filters2=params[2],
|
||||
num_groups=params[3],
|
||||
stride=params[4],
|
||||
scale=scale) for params in self.cfg
|
||||
])
|
||||
|
||||
self.avg_pool = AdaptiveAvgPool2D(1)
|
||||
self.flatten = Flatten(start_axis=1, stop_axis=-1)
|
||||
|
@ -142,7 +145,7 @@ class MobileNet(TheseusLayer):
|
|||
int(1024 * scale),
|
||||
class_num,
|
||||
weight_attr=ParamAttr(initializer=KaimingNormal()))
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.blocks(x)
|
||||
|
@ -152,91 +155,77 @@ class MobileNet(TheseusLayer):
|
|||
return x
|
||||
|
||||
|
||||
def MobileNetV1_x0_25(**args):
|
||||
"""
|
||||
MobileNetV1_x0_25
|
||||
Args:
|
||||
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||
kwargs:
|
||||
class_num: int=1000. Output dim of last fc layer.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `MobileNetV1_x0_25` model depends on args.
|
||||
"""
|
||||
model = MobileNet(scale=0.25, **args)
|
||||
if isinstance(model.pretrained, bool):
|
||||
if model.pretrained is True:
|
||||
load_dygraph_pretrain_from_url(model, MODEL_URLS["MobileNetV1_x0_25"])
|
||||
elif isinstance(model.pretrained, str):
|
||||
load_dygraph_pretrain(model, model.pretrained)
|
||||
def _load_pretrained(pretrained, model, model_url, use_ssld):
|
||||
if pretrained is False:
|
||||
pass
|
||||
elif pretrained is True:
|
||||
load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
|
||||
elif isinstance(pretrained, str):
|
||||
load_dygraph_pretrain(model, pretrained)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"pretrained type is not available. Please use `string` or `boolean` type")
|
||||
"pretrained type is not available. Please use `string` or `boolean` type."
|
||||
)
|
||||
|
||||
|
||||
def MobileNetV1_x0_25(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
MobileNetV1_x0_25
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `MobileNetV1_x0_25` model depends on args.
|
||||
"""
|
||||
model = MobileNet(scale=0.25, **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["MobileNetV1_x0_25"],
|
||||
use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def MobileNetV1_x0_5(**args):
|
||||
def MobileNetV1_x0_5(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
MobileNetV1_x0_5
|
||||
Args:
|
||||
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||
kwargs:
|
||||
class_num: int=1000. Output dim of last fc layer.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `MobileNetV1_x0_5` model depends on args.
|
||||
MobileNetV1_x0_5
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `MobileNetV1_x0_5` model depends on args.
|
||||
"""
|
||||
model = MobileNet(scale=0.5, **args)
|
||||
if isinstance(model.pretrained, bool):
|
||||
if model.pretrained is True:
|
||||
load_dygraph_pretrain_from_url(model, MODEL_URLS["MobileNetV1_x0_5"])
|
||||
elif isinstance(model.pretrained, str):
|
||||
load_dygraph_pretrain(model, model.pretrained)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"pretrained type is not available. Please use `string` or `boolean` type")
|
||||
model = MobileNet(scale=0.5, **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["MobileNetV1_x0_5"],
|
||||
use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def MobileNetV1_x0_75(**args):
|
||||
def MobileNetV1_x0_75(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
MobileNetV1_x0_75
|
||||
Args:
|
||||
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||
kwargs:
|
||||
class_num: int=1000. Output dim of last fc layer.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `MobileNetV1_x0_75` model depends on args.
|
||||
MobileNetV1_x0_75
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `MobileNetV1_x0_75` model depends on args.
|
||||
"""
|
||||
model = MobileNet(scale=0.75, **args)
|
||||
if isinstance(model.pretrained, bool):
|
||||
if model.pretrained is True:
|
||||
load_dygraph_pretrain_from_url(model, MODEL_URLS["MobileNetV1_x0_75"])
|
||||
elif isinstance(model.pretrained, str):
|
||||
load_dygraph_pretrain(model, model.pretrained)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"pretrained type is not available. Please use `string` or `boolean` type")
|
||||
model = MobileNet(scale=0.75, **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["MobileNetV1_x0_75"],
|
||||
use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def MobileNetV1(**args):
|
||||
def MobileNetV1(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
MobileNetV1
|
||||
Args:
|
||||
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||
kwargs:
|
||||
class_num: int=1000. Output dim of last fc layer.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `MobileNetV1` model depends on args.
|
||||
MobileNetV1
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `MobileNetV1` model depends on args.
|
||||
"""
|
||||
model = MobileNet(scale=1.0, **args)
|
||||
if isinstance(model.pretrained, bool):
|
||||
if model.pretrained is True:
|
||||
load_dygraph_pretrain_from_url(model, MODEL_URLS["MobileNetV1"])
|
||||
elif isinstance(model.pretrained, str):
|
||||
load_dygraph_pretrain(model, model.pretrained)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"pretrained type is not available. Please use `string` or `boolean` type")
|
||||
model = MobileNet(scale=1.0, **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["MobileNetV1"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,557 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
from paddle import ParamAttr
|
||||
from paddle.nn import AdaptiveAvgPool2D, BatchNorm, Conv2D, Dropout, Linear
|
||||
from paddle.regularizer import L2Decay
|
||||
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
|
||||
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
|
||||
|
||||
MODEL_URLS = {
|
||||
"MobileNetV3_small_x0_35":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_small_x0_35_pretrained.pdparams",
|
||||
"MobileNetV3_small_x0_5":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_small_x0_5_pretrained.pdparams",
|
||||
"MobileNetV3_small_x0_75":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_small_x0_75_pretrained.pdparams",
|
||||
"MobileNetV3_small_x1_0":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_small_x1_0_pretrained.pdparams",
|
||||
"MobileNetV3_small_x1_25":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_small_x1_25_pretrained.pdparams",
|
||||
"MobileNetV3_large_x0_35":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_large_x0_35_pretrained.pdparams",
|
||||
"MobileNetV3_large_x0_5":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_large_x0_5_pretrained.pdparams",
|
||||
"MobileNetV3_large_x0_75":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_large_x0_75_pretrained.pdparams",
|
||||
"MobileNetV3_large_x1_0":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_large_x1_0_pretrained.pdparams",
|
||||
"MobileNetV3_large_x1_25":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_large_x1_25_pretrained.pdparams",
|
||||
}
|
||||
|
||||
__all__ = MODEL_URLS.keys()
|
||||
|
||||
# "large", "small" is just for MobinetV3_large, MobileNetV3_small respectively.
|
||||
# The type of "large" or "small" config is a list. Each element(list) represents a depthwise block, which is composed of k, exp, se, act, s.
|
||||
# k: kernel_size
|
||||
# exp: middle channel number in depthwise block
|
||||
# c: output channel number in depthwise block
|
||||
# se: whether to use SE block
|
||||
# act: which activation to use
|
||||
# s: stride in depthwise block
|
||||
NET_CONFIG = {
|
||||
"large": [
|
||||
# k, exp, c, se, act, s
|
||||
[3, 16, 16, False, "relu", 1],
|
||||
[3, 64, 24, False, "relu", 2],
|
||||
[3, 72, 24, False, "relu", 1],
|
||||
[5, 72, 40, True, "relu", 2],
|
||||
[5, 120, 40, True, "relu", 1],
|
||||
[5, 120, 40, True, "relu", 1],
|
||||
[3, 240, 80, False, "hardswish", 2],
|
||||
[3, 200, 80, False, "hardswish", 1],
|
||||
[3, 184, 80, False, "hardswish", 1],
|
||||
[3, 184, 80, False, "hardswish", 1],
|
||||
[3, 480, 112, True, "hardswish", 1],
|
||||
[3, 672, 112, True, "hardswish", 1],
|
||||
[5, 672, 160, True, "hardswish", 2],
|
||||
[5, 960, 160, True, "hardswish", 1],
|
||||
[5, 960, 160, True, "hardswish", 1],
|
||||
],
|
||||
"small": [
|
||||
# k, exp, c, se, act, s
|
||||
[3, 16, 16, True, "relu", 2],
|
||||
[3, 72, 24, False, "relu", 2],
|
||||
[3, 88, 24, False, "relu", 1],
|
||||
[5, 96, 40, True, "hardswish", 2],
|
||||
[5, 240, 40, True, "hardswish", 1],
|
||||
[5, 240, 40, True, "hardswish", 1],
|
||||
[5, 120, 48, True, "hardswish", 1],
|
||||
[5, 144, 48, True, "hardswish", 1],
|
||||
[5, 288, 96, True, "hardswish", 2],
|
||||
[5, 576, 96, True, "hardswish", 1],
|
||||
[5, 576, 96, True, "hardswish", 1],
|
||||
]
|
||||
}
|
||||
# first conv output channel number in MobileNetV3
|
||||
STEM_CONV_NUMBER = 16
|
||||
# last second conv output channel for "small"
|
||||
LAST_SECOND_CONV_SMALL = 576
|
||||
# last second conv output channel for "large"
|
||||
LAST_SECOND_CONV_LARGE = 960
|
||||
# last conv output channel number for "large" and "small"
|
||||
LAST_CONV = 1280
|
||||
|
||||
|
||||
def _make_divisible(v, divisor=8, min_value=None):
|
||||
if min_value is None:
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
def _create_act(act):
|
||||
if act == "hardswish":
|
||||
return nn.Hardswish()
|
||||
elif act == "relu":
|
||||
return nn.ReLU()
|
||||
elif act is None:
|
||||
return None
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"The activation function is not supported: {}".format(act))
|
||||
|
||||
|
||||
class MobileNetV3(TheseusLayer):
|
||||
"""
|
||||
MobileNetV3
|
||||
Args:
|
||||
config: list. MobileNetV3 depthwise blocks config.
|
||||
scale: float=1.0. The coefficient that controls the size of network parameters.
|
||||
class_num: int=1000. The number of classes.
|
||||
inplanes: int=16. The output channel number of first convolution layer.
|
||||
class_squeeze: int=960. The output channel number of penultimate convolution layer.
|
||||
class_expand: int=1280. The output channel number of last convolution layer.
|
||||
dropout_prob: float=0.2. Probability of setting units to zero.
|
||||
Returns:
|
||||
model: nn.Layer. Specific MobileNetV3 model depends on args.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
scale=1.0,
|
||||
class_num=1000,
|
||||
inplanes=STEM_CONV_NUMBER,
|
||||
class_squeeze=LAST_SECOND_CONV_LARGE,
|
||||
class_expand=LAST_CONV,
|
||||
dropout_prob=0.2):
|
||||
super().__init__()
|
||||
|
||||
self.cfg = config
|
||||
self.scale = scale
|
||||
self.inplanes = inplanes
|
||||
self.class_squeeze = class_squeeze
|
||||
self.class_expand = class_expand
|
||||
self.class_num = class_num
|
||||
|
||||
self.conv = ConvBNLayer(
|
||||
in_c=3,
|
||||
out_c=_make_divisible(self.inplanes * self.scale),
|
||||
filter_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
num_groups=1,
|
||||
if_act=True,
|
||||
act="hardswish")
|
||||
|
||||
self.blocks = nn.Sequential(*[
|
||||
ResidualUnit(
|
||||
in_c=_make_divisible(self.inplanes * self.scale if i == 0 else
|
||||
self.cfg[i - 1][2] * self.scale),
|
||||
mid_c=_make_divisible(self.scale * exp),
|
||||
out_c=_make_divisible(self.scale * c),
|
||||
filter_size=k,
|
||||
stride=s,
|
||||
use_se=se,
|
||||
act=act) for i, (k, exp, c, se, act, s) in enumerate(self.cfg)
|
||||
])
|
||||
|
||||
self.last_second_conv = ConvBNLayer(
|
||||
in_c=_make_divisible(self.cfg[-1][2] * self.scale),
|
||||
out_c=_make_divisible(self.scale * self.class_squeeze),
|
||||
filter_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
num_groups=1,
|
||||
if_act=True,
|
||||
act="hardswish")
|
||||
|
||||
self.avg_pool = AdaptiveAvgPool2D(1)
|
||||
|
||||
self.last_conv = Conv2D(
|
||||
in_channels=_make_divisible(self.scale * self.class_squeeze),
|
||||
out_channels=self.class_expand,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias_attr=False)
|
||||
|
||||
self.hardswish = nn.Hardswish()
|
||||
self.dropout = Dropout(p=dropout_prob, mode="downscale_in_infer")
|
||||
self.flatten = nn.Flatten(start_axis=1, stop_axis=-1)
|
||||
|
||||
self.fc = Linear(self.class_expand, class_num)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.blocks(x)
|
||||
x = self.last_second_conv(x)
|
||||
x = self.avg_pool(x)
|
||||
x = self.last_conv(x)
|
||||
x = self.hardswish(x)
|
||||
x = self.dropout(x)
|
||||
x = self.flatten(x)
|
||||
x = self.fc(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class ConvBNLayer(TheseusLayer):
|
||||
def __init__(self,
|
||||
in_c,
|
||||
out_c,
|
||||
filter_size,
|
||||
stride,
|
||||
padding,
|
||||
num_groups=1,
|
||||
if_act=True,
|
||||
act=None):
|
||||
super().__init__()
|
||||
|
||||
self.conv = Conv2D(
|
||||
in_channels=in_c,
|
||||
out_channels=out_c,
|
||||
kernel_size=filter_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
groups=num_groups,
|
||||
bias_attr=False)
|
||||
self.bn = BatchNorm(
|
||||
num_channels=out_c,
|
||||
act=None,
|
||||
param_attr=ParamAttr(regularizer=L2Decay(0.0)),
|
||||
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
|
||||
self.if_act = if_act
|
||||
self.act = _create_act(act)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
if self.if_act:
|
||||
x = self.act(x)
|
||||
return x
|
||||
|
||||
|
||||
class ResidualUnit(TheseusLayer):
|
||||
def __init__(self,
|
||||
in_c,
|
||||
mid_c,
|
||||
out_c,
|
||||
filter_size,
|
||||
stride,
|
||||
use_se,
|
||||
act=None):
|
||||
super().__init__()
|
||||
self.if_shortcut = stride == 1 and in_c == out_c
|
||||
self.if_se = use_se
|
||||
|
||||
self.expand_conv = ConvBNLayer(
|
||||
in_c=in_c,
|
||||
out_c=mid_c,
|
||||
filter_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
if_act=True,
|
||||
act=act)
|
||||
self.bottleneck_conv = ConvBNLayer(
|
||||
in_c=mid_c,
|
||||
out_c=mid_c,
|
||||
filter_size=filter_size,
|
||||
stride=stride,
|
||||
padding=int((filter_size - 1) // 2),
|
||||
num_groups=mid_c,
|
||||
if_act=True,
|
||||
act=act)
|
||||
if self.if_se:
|
||||
self.mid_se = SEModule(mid_c)
|
||||
self.linear_conv = ConvBNLayer(
|
||||
in_c=mid_c,
|
||||
out_c=out_c,
|
||||
filter_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
if_act=False,
|
||||
act=None)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
x = self.expand_conv(x)
|
||||
x = self.bottleneck_conv(x)
|
||||
if self.if_se:
|
||||
x = self.mid_se(x)
|
||||
x = self.linear_conv(x)
|
||||
if self.if_shortcut:
|
||||
x = paddle.add(identity, x)
|
||||
return x
|
||||
|
||||
|
||||
# nn.Hardsigmoid can't transfer "slope" and "offset" in nn.functional.hardsigmoid
|
||||
class Hardsigmoid(TheseusLayer):
|
||||
def __init__(self, slope=0.2, offset=0.5):
|
||||
super().__init__()
|
||||
self.slope = slope
|
||||
self.offset = offset
|
||||
|
||||
def forward(self, x):
|
||||
return nn.functional.hardsigmoid(
|
||||
x, slope=self.slope, offset=self.offset)
|
||||
|
||||
|
||||
class SEModule(TheseusLayer):
|
||||
def __init__(self, channel, reduction=4):
|
||||
super().__init__()
|
||||
self.avg_pool = AdaptiveAvgPool2D(1)
|
||||
self.conv1 = Conv2D(
|
||||
in_channels=channel,
|
||||
out_channels=channel // reduction,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.relu = nn.ReLU()
|
||||
self.conv2 = Conv2D(
|
||||
in_channels=channel // reduction,
|
||||
out_channels=channel,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.hardsigmoid = Hardsigmoid(slope=0.2, offset=0.5)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
x = self.avg_pool(x)
|
||||
x = self.conv1(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv2(x)
|
||||
x = self.hardsigmoid(x)
|
||||
return paddle.multiply(x=identity, y=x)
|
||||
|
||||
|
||||
def _load_pretrained(pretrained, model, model_url, use_ssld):
|
||||
if pretrained is False:
|
||||
pass
|
||||
elif pretrained is True:
|
||||
load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
|
||||
elif isinstance(pretrained, str):
|
||||
load_dygraph_pretrain(model, pretrained)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"pretrained type is not available. Please use `string` or `boolean` type."
|
||||
)
|
||||
|
||||
|
||||
def MobileNetV3_small_x0_35(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
MobileNetV3_small_x0_35
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `MobileNetV3_small_x0_35` model depends on args.
|
||||
"""
|
||||
model = MobileNetV3(
|
||||
config=NET_CONFIG["small"],
|
||||
scale=0.35,
|
||||
class_squeeze=LAST_SECOND_CONV_SMALL,
|
||||
**kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_small_x0_35"],
|
||||
use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def MobileNetV3_small_x0_5(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
MobileNetV3_small_x0_5
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `MobileNetV3_small_x0_5` model depends on args.
|
||||
"""
|
||||
model = MobileNetV3(
|
||||
config=NET_CONFIG["small"],
|
||||
scale=0.5,
|
||||
class_squeeze=LAST_SECOND_CONV_SMALL,
|
||||
**kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_small_x0_5"],
|
||||
use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def MobileNetV3_small_x0_75(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
MobileNetV3_small_x0_75
|
||||
Args:
|
||||
pretrained: bool=false or str. if `true` load pretrained parameters, `false` otherwise.
|
||||
if str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `MobileNetV3_small_x0_75` model depends on args.
|
||||
"""
|
||||
model = MobileNetV3(
|
||||
config=NET_CONFIG["small"],
|
||||
scale=0.75,
|
||||
class_squeeze=LAST_SECOND_CONV_SMALL,
|
||||
**kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_small_x0_75"],
|
||||
use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def MobileNetV3_small_x1_0(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
MobileNetV3_small_x1_0
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `MobileNetV3_small_x1_0` model depends on args.
|
||||
"""
|
||||
model = MobileNetV3(
|
||||
config=NET_CONFIG["small"],
|
||||
scale=1.0,
|
||||
class_squeeze=LAST_SECOND_CONV_SMALL,
|
||||
**kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_small_x1_0"],
|
||||
use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def MobileNetV3_small_x1_25(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
MobileNetV3_small_x1_25
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `MobileNetV3_small_x1_25` model depends on args.
|
||||
"""
|
||||
model = MobileNetV3(
|
||||
config=NET_CONFIG["small"],
|
||||
scale=1.25,
|
||||
class_squeeze=LAST_SECOND_CONV_SMALL,
|
||||
**kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_small_x1_25"],
|
||||
use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def MobileNetV3_large_x0_35(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
MobileNetV3_large_x0_35
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `MobileNetV3_large_x0_35` model depends on args.
|
||||
"""
|
||||
model = MobileNetV3(
|
||||
config=NET_CONFIG["large"],
|
||||
scale=0.35,
|
||||
class_squeeze=LAST_SECOND_CONV_LARGE,
|
||||
**kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_large_x0_35"],
|
||||
use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def MobileNetV3_large_x0_5(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
MobileNetV3_large_x0_5
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `MobileNetV3_large_x0_5` model depends on args.
|
||||
"""
|
||||
model = MobileNetV3(
|
||||
config=NET_CONFIG["large"],
|
||||
scale=0.5,
|
||||
class_squeeze=LAST_SECOND_CONV_LARGE,
|
||||
**kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_large_x0_5"],
|
||||
use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def MobileNetV3_large_x0_75(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
MobileNetV3_large_x0_75
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `MobileNetV3_large_x0_75` model depends on args.
|
||||
"""
|
||||
model = MobileNetV3(
|
||||
config=NET_CONFIG["large"],
|
||||
scale=0.75,
|
||||
class_squeeze=LAST_SECOND_CONV_LARGE,
|
||||
**kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_large_x0_75"],
|
||||
use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def MobileNetV3_large_x1_0(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
MobileNetV3_large_x1_0
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `MobileNetV3_large_x1_0` model depends on args.
|
||||
"""
|
||||
model = MobileNetV3(
|
||||
config=NET_CONFIG["large"],
|
||||
scale=1.0,
|
||||
class_squeeze=LAST_SECOND_CONV_LARGE,
|
||||
**kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_large_x1_0"],
|
||||
use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def MobileNetV3_large_x1_25(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
MobileNetV3_large_x1_25
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `MobileNetV3_large_x1_25` model depends on args.
|
||||
"""
|
||||
model = MobileNetV3(
|
||||
config=NET_CONFIG["large"],
|
||||
scale=1.25,
|
||||
class_squeeze=LAST_SECOND_CONV_LARGE,
|
||||
**kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_large_x1_25"],
|
||||
use_ssld)
|
||||
return model
|
|
@ -24,26 +24,34 @@ from paddle.nn.initializer import Uniform
|
|||
import math
|
||||
|
||||
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
|
||||
from ppcls.utils.save_load import load_dygraph_pretrain_from, load_dygraph_pretrain_from_url
|
||||
|
||||
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
|
||||
|
||||
MODEL_URLS = {
|
||||
"ResNet18": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet18_pretrained.pdparams",
|
||||
"ResNet18_vd": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet18_vd_pretrained.pdparams",
|
||||
"ResNet34": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet34_pretrained.pdparams",
|
||||
"ResNet34_vd": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet34_vd_pretrained.pdparams",
|
||||
"ResNet50": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_pretrained.pdparams",
|
||||
"ResNet50_vd": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_pretrained.pdparams",
|
||||
"ResNet101": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet101_pretrained.pdparams",
|
||||
"ResNet101_vd": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet101_vd_pretrained.pdparams",
|
||||
"ResNet152": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet152_pretrained.pdparams",
|
||||
"ResNet152_vd": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet152_vd_pretrained.pdparams",
|
||||
"ResNet200_vd": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet200_vd_pretrained.pdparams",
|
||||
"ResNet18":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet18_pretrained.pdparams",
|
||||
"ResNet18_vd":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet18_vd_pretrained.pdparams",
|
||||
"ResNet34":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet34_pretrained.pdparams",
|
||||
"ResNet34_vd":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet34_vd_pretrained.pdparams",
|
||||
"ResNet50":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet50_pretrained.pdparams",
|
||||
"ResNet50_vd":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet50_vd_pretrained.pdparams",
|
||||
"ResNet101":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet101_pretrained.pdparams",
|
||||
"ResNet101_vd":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet101_vd_pretrained.pdparams",
|
||||
"ResNet152":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet152_pretrained.pdparams",
|
||||
"ResNet152_vd":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet152_vd_pretrained.pdparams",
|
||||
"ResNet200_vd":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet200_vd_pretrained.pdparams",
|
||||
}
|
||||
|
||||
__all__ = MODEL_URLS.keys()
|
||||
|
||||
|
||||
'''
|
||||
ResNet config: dict.
|
||||
key: depth of ResNet.
|
||||
|
@ -55,17 +63,35 @@ ResNet config: dict.
|
|||
'''
|
||||
NET_CONFIG = {
|
||||
"18": {
|
||||
"block_type": "BasicBlock", "block_depth": [2, 2, 2, 2], "num_channels": [64, 64, 128, 256]},
|
||||
"block_type": "BasicBlock",
|
||||
"block_depth": [2, 2, 2, 2],
|
||||
"num_channels": [64, 64, 128, 256]
|
||||
},
|
||||
"34": {
|
||||
"block_type": "BasicBlock", "block_depth": [3, 4, 6, 3], "num_channels": [64, 64, 128, 256]},
|
||||
"block_type": "BasicBlock",
|
||||
"block_depth": [3, 4, 6, 3],
|
||||
"num_channels": [64, 64, 128, 256]
|
||||
},
|
||||
"50": {
|
||||
"block_type": "BottleneckBlock", "block_depth": [3, 4, 6, 3], "num_channels": [64, 256, 512, 1024]},
|
||||
"block_type": "BottleneckBlock",
|
||||
"block_depth": [3, 4, 6, 3],
|
||||
"num_channels": [64, 256, 512, 1024]
|
||||
},
|
||||
"101": {
|
||||
"block_type": "BottleneckBlock", "block_depth": [3, 4, 23, 3], "num_channels": [64, 256, 512, 1024]},
|
||||
"block_type": "BottleneckBlock",
|
||||
"block_depth": [3, 4, 23, 3],
|
||||
"num_channels": [64, 256, 512, 1024]
|
||||
},
|
||||
"152": {
|
||||
"block_type": "BottleneckBlock", "block_depth": [3, 8, 36, 3], "num_channels": [64, 256, 512, 1024]},
|
||||
"block_type": "BottleneckBlock",
|
||||
"block_depth": [3, 8, 36, 3],
|
||||
"num_channels": [64, 256, 512, 1024]
|
||||
},
|
||||
"200": {
|
||||
"block_type": "BottleneckBlock", "block_depth": [3, 12, 48, 3], "num_channels": [64, 256, 512, 1024]},
|
||||
"block_type": "BottleneckBlock",
|
||||
"block_depth": [3, 12, 48, 3],
|
||||
"num_channels": [64, 256, 512, 1024]
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
@ -110,14 +136,14 @@ class ConvBNLayer(TheseusLayer):
|
|||
|
||||
|
||||
class BottleneckBlock(TheseusLayer):
|
||||
def __init__(self,
|
||||
num_channels,
|
||||
num_filters,
|
||||
stride,
|
||||
shortcut=True,
|
||||
if_first=False,
|
||||
lr_mult=1.0,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
num_channels,
|
||||
num_filters,
|
||||
stride,
|
||||
shortcut=True,
|
||||
if_first=False,
|
||||
lr_mult=1.0, ):
|
||||
super().__init__()
|
||||
|
||||
self.conv0 = ConvBNLayer(
|
||||
|
@ -222,16 +248,15 @@ class ResNet(TheseusLayer):
|
|||
version: str="vb". Different version of ResNet, version vd can perform better.
|
||||
class_num: int=1000. The number of classes.
|
||||
lr_mult_list: list. Control the learning rate of different stages.
|
||||
pretrained: (True or False) or path of pretrained_model. Whether to load the pretrained model.
|
||||
Returns:
|
||||
model: nn.Layer. Specific ResNet model depends on args.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
version="vb",
|
||||
class_num=1000,
|
||||
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
|
||||
pretrained=False):
|
||||
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0]):
|
||||
super().__init__()
|
||||
|
||||
self.cfg = config
|
||||
|
@ -243,51 +268,46 @@ class ResNet(TheseusLayer):
|
|||
self.block_type = self.cfg["block_type"]
|
||||
self.num_channels = self.cfg["num_channels"]
|
||||
self.channels_mult = 1 if self.num_channels[-1] == 256 else 4
|
||||
self.pretrained = pretrained
|
||||
|
||||
|
||||
assert isinstance(self.lr_mult_list, (
|
||||
list, tuple
|
||||
)), "lr_mult_list should be in (list, tuple) but got {}".format(
|
||||
type(self.lr_mult_list))
|
||||
assert len(
|
||||
self.lr_mult_list
|
||||
) == 5, "lr_mult_list length should be 5 but got {}".format(
|
||||
len(self.lr_mult_list))
|
||||
|
||||
assert len(self.lr_mult_list
|
||||
) == 5, "lr_mult_list length should be 5 but got {}".format(
|
||||
len(self.lr_mult_list))
|
||||
|
||||
self.stem_cfg = {
|
||||
#num_channels, num_filters, filter_size, stride
|
||||
"vb": [[3, 64, 7, 2]],
|
||||
"vd": [[3, 32, 3, 2],
|
||||
[32, 32, 3, 1],
|
||||
[32, 64, 3, 1]]}
|
||||
|
||||
"vd": [[3, 32, 3, 2], [32, 32, 3, 1], [32, 64, 3, 1]]
|
||||
}
|
||||
|
||||
self.stem = nn.Sequential(*[
|
||||
ConvBNLayer(
|
||||
num_channels=in_c,
|
||||
num_filters=out_c,
|
||||
filter_size=k,
|
||||
stride=s,
|
||||
act="relu",
|
||||
lr_mult=self.lr_mult_list[0])
|
||||
num_channels=in_c,
|
||||
num_filters=out_c,
|
||||
filter_size=k,
|
||||
stride=s,
|
||||
act="relu",
|
||||
lr_mult=self.lr_mult_list[0])
|
||||
for in_c, out_c, k, s in self.stem_cfg[version]
|
||||
])
|
||||
|
||||
|
||||
self.max_pool = MaxPool2D(kernel_size=3, stride=2, padding=1)
|
||||
block_list = []
|
||||
for block_idx in range(len(self.block_depth)):
|
||||
shortcut = False
|
||||
for i in range(self.block_depth[block_idx]):
|
||||
block_list.append(
|
||||
globals()[self.block_type](
|
||||
num_channels=self.num_channels[block_idx]
|
||||
if i == 0 else self.num_filters[block_idx] * self.channels_mult,
|
||||
block_list.append(globals()[self.block_type](
|
||||
num_channels=self.num_channels[block_idx] if i == 0 else
|
||||
self.num_filters[block_idx] * self.channels_mult,
|
||||
num_filters=self.num_filters[block_idx],
|
||||
stride=2 if i == 0 and block_idx != 0 else 1,
|
||||
shortcut=shortcut,
|
||||
if_first=block_idx == i == 0 if version == "vd" else True,
|
||||
lr_mult=self.lr_mult_list[block_idx + 1]))
|
||||
shortcut = True
|
||||
shortcut = True
|
||||
self.blocks = nn.Sequential(*block_list)
|
||||
|
||||
self.avg_pool = AdaptiveAvgPool2D(1)
|
||||
|
@ -297,8 +317,7 @@ class ResNet(TheseusLayer):
|
|||
self.fc = Linear(
|
||||
self.avg_pool_channels,
|
||||
self.class_num,
|
||||
weight_attr=ParamAttr(
|
||||
initializer=Uniform(-stdv, stdv)))
|
||||
weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.stem(x)
|
||||
|
@ -310,254 +329,179 @@ class ResNet(TheseusLayer):
|
|||
return x
|
||||
|
||||
|
||||
def ResNet18(**args):
|
||||
def _load_pretrained(pretrained, model, model_url, use_ssld):
|
||||
if pretrained is False:
|
||||
pass
|
||||
elif pretrained is True:
|
||||
load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
|
||||
elif isinstance(pretrained, str):
|
||||
load_dygraph_pretrain(model, pretrained)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"pretrained type is not available. Please use `string` or `boolean` type."
|
||||
)
|
||||
|
||||
|
||||
def ResNet18(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
ResNet18
|
||||
Args:
|
||||
kwargs:
|
||||
class_num: int=1000. Output dim of last fc layer.
|
||||
lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
|
||||
pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `ResNet18` model depends on args.
|
||||
"""
|
||||
model = ResNet(config=NET_CONFIG["18"], version="vb", **args)
|
||||
if isinstance(model.pretrained, bool):
|
||||
if model.pretrained is True:
|
||||
load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet18"])
|
||||
elif isinstance(model.pretrained, str):
|
||||
load_dygraph_pretrain(model, model.pretrained)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"pretrained type is not available. Please use `string` or `boolean` type")
|
||||
model = ResNet(config=NET_CONFIG["18"], version="vb", **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["ResNet18"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def ResNet18_vd(**args):
|
||||
def ResNet18_vd(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
ResNet18_vd
|
||||
Args:
|
||||
kwargs:
|
||||
class_num: int=1000. Output dim of last fc layer.
|
||||
lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
|
||||
pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `ResNet18_vd` model depends on args.
|
||||
"""
|
||||
model = ResNet(config=NET_CONFIG["18"], version="vd", **args)
|
||||
if isinstance(model.pretrained, bool):
|
||||
if model.pretrained is True:
|
||||
load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet18_vd"])
|
||||
elif isinstance(model.pretrained, str):
|
||||
load_dygraph_pretrain(model, model.pretrained)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"pretrained type is not available. Please use `string` or `boolean` type")
|
||||
model = ResNet(config=NET_CONFIG["18"], version="vd", **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["ResNet18_vd"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def ResNet34(**args):
|
||||
def ResNet34(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
ResNet34
|
||||
Args:
|
||||
kwargs:
|
||||
class_num: int=1000. Output dim of last fc layer.
|
||||
lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
|
||||
pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `ResNet34` model depends on args.
|
||||
"""
|
||||
model = ResNet(config=NET_CONFIG["34"], version="vb", **args)
|
||||
if isinstance(model.pretrained, bool):
|
||||
if model.pretrained is True:
|
||||
load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet34"])
|
||||
elif isinstance(model.pretrained, str):
|
||||
load_dygraph_pretrain(model, model.pretrained)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"pretrained type is not available. Please use `string` or `boolean` type")
|
||||
model = ResNet(config=NET_CONFIG["34"], version="vb", **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["ResNet34"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def ResNet34_vd(**args):
|
||||
def ResNet34_vd(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
ResNet34_vd
|
||||
Args:
|
||||
kwargs:
|
||||
class_num: int=1000. Output dim of last fc layer.
|
||||
lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
|
||||
pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `ResNet34_vd` model depends on args.
|
||||
"""
|
||||
model = ResNet(config=NET_CONFIG["34"], version="vd", **args)
|
||||
if isinstance(model.pretrained, bool):
|
||||
if model.pretrained is True:
|
||||
load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet34_vd"], use_ssld=True)
|
||||
elif isinstance(model.pretrained, str):
|
||||
load_dygraph_pretrain(model, model.pretrained)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"pretrained type is not available. Please use `string` or `boolean` type")
|
||||
model = ResNet(config=NET_CONFIG["34"], version="vd", **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["ResNet34_vd"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def ResNet50(**args):
|
||||
def ResNet50(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
ResNet50
|
||||
Args:
|
||||
kwargs:
|
||||
class_num: int=1000. Output dim of last fc layer.
|
||||
lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
|
||||
pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `ResNet50` model depends on args.
|
||||
"""
|
||||
model = ResNet(config=NET_CONFIG["50"], version="vb", **args)
|
||||
if isinstance(model.pretrained, bool):
|
||||
if model.pretrained is True:
|
||||
load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet50"])
|
||||
elif isinstance(model.pretrained, str):
|
||||
load_dygraph_pretrain(model, model.pretrained)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"pretrained type is not available. Please use `string` or `boolean` type")
|
||||
model = ResNet(config=NET_CONFIG["50"], version="vb", **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["ResNet50"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def ResNet50_vd(**args):
|
||||
def ResNet50_vd(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
ResNet50_vd
|
||||
Args:
|
||||
kwargs:
|
||||
class_num: int=1000. Output dim of last fc layer.
|
||||
lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
|
||||
pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `ResNet50_vd` model depends on args.
|
||||
"""
|
||||
model = ResNet(config=NET_CONFIG["50"], version="vd", **args)
|
||||
if isinstance(model.pretrained, bool):
|
||||
if model.pretrained is True:
|
||||
load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet50_vd"], use_ssld=True)
|
||||
elif isinstance(model.pretrained, str):
|
||||
load_dygraph_pretrain(model, model.pretrained)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"pretrained type is not available. Please use `string` or `boolean` type")
|
||||
model = ResNet(config=NET_CONFIG["50"], version="vd", **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["ResNet50_vd"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def ResNet101(**args):
|
||||
def ResNet101(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
ResNet101
|
||||
Args:
|
||||
kwargs:
|
||||
class_num: int=1000. Output dim of last fc layer.
|
||||
lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
|
||||
pretrained: bool=False. Whether to load the pretrained model.
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `ResNet101` model depends on args.
|
||||
"""
|
||||
model = ResNet(config=NET_CONFIG["101"], version="vb", **args)
|
||||
if isinstance(model.pretrained, bool):
|
||||
if model.pretrained is True:
|
||||
load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet101"])
|
||||
elif isinstance(model.pretrained, str):
|
||||
load_dygraph_pretrain(model, model.pretrained)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"pretrained type is not available. Please use `string` or `boolean` type")
|
||||
model = ResNet(config=NET_CONFIG["101"], version="vb", **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["ResNet101"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def ResNet101_vd(**args):
|
||||
def ResNet101_vd(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
ResNet101_vd
|
||||
Args:
|
||||
kwargs:
|
||||
class_num: int=1000. Output dim of last fc layer.
|
||||
lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
|
||||
pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `ResNet101_vd` model depends on args.
|
||||
"""
|
||||
model = ResNet(config=NET_CONFIG["101"], version="vd", **args)
|
||||
if isinstance(model.pretrained, bool):
|
||||
if model.pretrained is True:
|
||||
load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet101_vd"], use_ssld=True)
|
||||
elif isinstance(model.pretrained, str):
|
||||
load_dygraph_pretrain(model, model.pretrained)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"pretrained type is not available. Please use `string` or `boolean` type")
|
||||
model = ResNet(config=NET_CONFIG["101"], version="vd", **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["ResNet101_vd"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def ResNet152(**args):
|
||||
def ResNet152(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
ResNet152
|
||||
Args:
|
||||
kwargs:
|
||||
class_num: int=1000. Output dim of last fc layer.
|
||||
lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
|
||||
pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `ResNet152` model depends on args.
|
||||
"""
|
||||
model = ResNet(config=NET_CONFIG["152"], version="vb", **args)
|
||||
if isinstance(model.pretrained, bool):
|
||||
if model.pretrained is True:
|
||||
load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet152"])
|
||||
elif isinstance(model.pretrained, str):
|
||||
load_dygraph_pretrain(model, model.pretrained)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"pretrained type is not available. Please use `string` or `boolean` type")
|
||||
model = ResNet(config=NET_CONFIG["152"], version="vb", **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["ResNet152"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def ResNet152_vd(**args):
|
||||
def ResNet152_vd(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
ResNet152_vd
|
||||
Args:
|
||||
kwargs:
|
||||
class_num: int=1000. Output dim of last fc layer.
|
||||
lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
|
||||
pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `ResNet152_vd` model depends on args.
|
||||
"""
|
||||
model = ResNet(config=NET_CONFIG["152"], version="vd", **args)
|
||||
if isinstance(model.pretrained, bool):
|
||||
if model.pretrained is True:
|
||||
load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet152_vd"])
|
||||
elif isinstance(model.pretrained, str):
|
||||
load_dygraph_pretrain(model, model.pretrained)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"pretrained type is not available. Please use `string` or `boolean` type")
|
||||
model = ResNet(config=NET_CONFIG["152"], version="vd", **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["ResNet152_vd"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def ResNet200_vd(**args):
|
||||
def ResNet200_vd(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
ResNet200_vd
|
||||
Args:
|
||||
kwargs:
|
||||
class_num: int=1000. Output dim of last fc layer.
|
||||
lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
|
||||
pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `ResNet200_vd` model depends on args.
|
||||
"""
|
||||
model = ResNet(config=NET_CONFIG["200"], version="vd", **args)
|
||||
if isinstance(model.pretrained, bool):
|
||||
if model.pretrained is True:
|
||||
load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet200_vd"], use_ssld=True)
|
||||
elif isinstance(model.pretrained, str):
|
||||
load_dygraph_pretrain(model, model.pretrained)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"pretrained type is not available. Please use `string` or `boolean` type")
|
||||
model = ResNet(config=NET_CONFIG["200"], version="vd", **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["ResNet200_vd"], use_ssld)
|
||||
return model
|
||||
|
|
|
@ -14,16 +14,24 @@
|
|||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import paddle
|
||||
from paddle import ParamAttr
|
||||
import paddle.nn as nn
|
||||
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
|
||||
from paddle.nn import MaxPool2D
|
||||
|
||||
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
|
||||
from ppcls.utils.save_load import load_dygraph_pretrain
|
||||
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
|
||||
|
||||
__all__ = ["VGG11", "VGG13", "VGG16", "VGG19"]
|
||||
MODEL_URLS = {
|
||||
"VGG11":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/VGG11_pretrained.pdparams",
|
||||
"VGG13":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/VGG13_pretrained.pdparams",
|
||||
"VGG16":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/VGG16_pretrained.pdparams",
|
||||
"VGG19":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/VGG19_pretrained.pdparams",
|
||||
}
|
||||
__all__ = MODEL_URLS.keys()
|
||||
|
||||
# VGG config
|
||||
# key: VGG network depth
|
||||
|
@ -36,68 +44,12 @@ NET_CONFIG = {
|
|||
}
|
||||
|
||||
|
||||
def VGG11(**args):
|
||||
"""
|
||||
VGG11
|
||||
Args:
|
||||
kwargs:
|
||||
class_num: int=1000. Output dim of last fc layer.
|
||||
stop_grad_layers: int=0. The parameters in blocks which index larger than `stop_grad_layers`, will be set `param.trainable=False`
|
||||
Returns:
|
||||
model: nn.Layer. Specific `VGG11` model depends on args.
|
||||
"""
|
||||
model = VGGNet(config=NET_CONFIG[11], **args)
|
||||
return model
|
||||
|
||||
|
||||
def VGG13(**args):
|
||||
"""
|
||||
VGG13
|
||||
Args:
|
||||
kwargs:
|
||||
class_num: int=1000. Output dim of last fc layer.
|
||||
stop_grad_layers: int=0. The parameters in blocks which index larger than `stop_grad_layers`, will be set `param.trainable=False`
|
||||
Returns:
|
||||
model: nn.Layer. Specific `VGG11` model depends on args.
|
||||
"""
|
||||
model = VGGNet(config=NET_CONFIG[13], **args)
|
||||
return model
|
||||
|
||||
|
||||
def VGG16(**args):
|
||||
"""
|
||||
VGG16
|
||||
Args:
|
||||
kwargs:
|
||||
class_num: int=1000. Output dim of last fc layer.
|
||||
stop_grad_layers: int=0. The parameters in blocks which index larger than `stop_grad_layers`, will be set `param.trainable=False`
|
||||
Returns:
|
||||
model: nn.Layer. Specific `VGG11` model depends on args.
|
||||
"""
|
||||
model = VGGNet(config=NET_CONFIG[16], **args)
|
||||
return model
|
||||
|
||||
|
||||
def VGG19(**args):
|
||||
"""
|
||||
VGG19
|
||||
Args:
|
||||
kwargs:
|
||||
class_num: int=1000. Output dim of last fc layer.
|
||||
stop_grad_layers: int=0. The parameters in blocks which index larger than `stop_grad_layers`, will be set `param.trainable=False`
|
||||
Returns:
|
||||
model: nn.Layer. Specific `VGG11` model depends on args.
|
||||
"""
|
||||
model = VGGNet(config=NET_CONFIG[19], **args)
|
||||
return model
|
||||
|
||||
|
||||
class ConvBlock(TheseusLayer):
|
||||
def __init__(self, input_channels, output_channels, groups):
|
||||
super(ConvBlock, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
self.groups = groups
|
||||
self._conv_1 = Conv2D(
|
||||
self.conv1 = Conv2D(
|
||||
in_channels=input_channels,
|
||||
out_channels=output_channels,
|
||||
kernel_size=3,
|
||||
|
@ -105,7 +57,7 @@ class ConvBlock(TheseusLayer):
|
|||
padding=1,
|
||||
bias_attr=False)
|
||||
if groups == 2 or groups == 3 or groups == 4:
|
||||
self._conv_2 = Conv2D(
|
||||
self.conv2 = Conv2D(
|
||||
in_channels=output_channels,
|
||||
out_channels=output_channels,
|
||||
kernel_size=3,
|
||||
|
@ -113,7 +65,7 @@ class ConvBlock(TheseusLayer):
|
|||
padding=1,
|
||||
bias_attr=False)
|
||||
if groups == 3 or groups == 4:
|
||||
self._conv_3 = Conv2D(
|
||||
self.conv3 = Conv2D(
|
||||
in_channels=output_channels,
|
||||
out_channels=output_channels,
|
||||
kernel_size=3,
|
||||
|
@ -121,7 +73,7 @@ class ConvBlock(TheseusLayer):
|
|||
padding=1,
|
||||
bias_attr=False)
|
||||
if groups == 4:
|
||||
self._conv_4 = Conv2D(
|
||||
self.conv4 = Conv2D(
|
||||
in_channels=output_channels,
|
||||
out_channels=output_channels,
|
||||
kernel_size=3,
|
||||
|
@ -129,73 +81,148 @@ class ConvBlock(TheseusLayer):
|
|||
padding=1,
|
||||
bias_attr=False)
|
||||
|
||||
self._pool = MaxPool2D(kernel_size=2, stride=2, padding=0)
|
||||
self._relu = nn.ReLU()
|
||||
self.max_pool = MaxPool2D(kernel_size=2, stride=2, padding=0)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def forward(self, inputs):
|
||||
x = self._conv_1(inputs)
|
||||
x = self._relu(x)
|
||||
x = self.conv1(inputs)
|
||||
x = self.relu(x)
|
||||
if self.groups == 2 or self.groups == 3 or self.groups == 4:
|
||||
x = self._conv_2(x)
|
||||
x = self._relu(x)
|
||||
x = self.conv2(x)
|
||||
x = self.relu(x)
|
||||
if self.groups == 3 or self.groups == 4:
|
||||
x = self._conv_3(x)
|
||||
x = self._relu(x)
|
||||
x = self.conv3(x)
|
||||
x = self.relu(x)
|
||||
if self.groups == 4:
|
||||
x = self._conv_4(x)
|
||||
x = self._relu(x)
|
||||
x = self._pool(x)
|
||||
x = self.conv4(x)
|
||||
x = self.relu(x)
|
||||
x = self.max_pool(x)
|
||||
return x
|
||||
|
||||
|
||||
class VGGNet(TheseusLayer):
|
||||
def __init__(self,
|
||||
config,
|
||||
stop_grad_layers=0,
|
||||
class_num=1000,
|
||||
pretrained=False,
|
||||
**args):
|
||||
"""
|
||||
VGGNet
|
||||
Args:
|
||||
config: list. VGGNet config.
|
||||
stop_grad_layers: int=0. The parameters in blocks which index larger than `stop_grad_layers`, will be set `param.trainable=False`
|
||||
class_num: int=1000. The number of classes.
|
||||
Returns:
|
||||
model: nn.Layer. Specific VGG model depends on args.
|
||||
"""
|
||||
|
||||
def __init__(self, config, stop_grad_layers=0, class_num=1000):
|
||||
super().__init__()
|
||||
|
||||
self.stop_grad_layers = stop_grad_layers
|
||||
|
||||
self._conv_block_1 = ConvBlock(3, 64, config[0])
|
||||
self._conv_block_2 = ConvBlock(64, 128, config[1])
|
||||
self._conv_block_3 = ConvBlock(128, 256, config[2])
|
||||
self._conv_block_4 = ConvBlock(256, 512, config[3])
|
||||
self._conv_block_5 = ConvBlock(512, 512, config[4])
|
||||
self.conv_block_1 = ConvBlock(3, 64, config[0])
|
||||
self.conv_block_2 = ConvBlock(64, 128, config[1])
|
||||
self.conv_block_3 = ConvBlock(128, 256, config[2])
|
||||
self.conv_block_4 = ConvBlock(256, 512, config[3])
|
||||
self.conv_block_5 = ConvBlock(512, 512, config[4])
|
||||
|
||||
self._relu = nn.ReLU()
|
||||
self._flatten = nn.Flatten(start_axis=1, stop_axis=-1)
|
||||
self.relu = nn.ReLU()
|
||||
self.flatten = nn.Flatten(start_axis=1, stop_axis=-1)
|
||||
|
||||
for idx, block in enumerate([
|
||||
self._conv_block_1, self._conv_block_2, self._conv_block_3,
|
||||
self._conv_block_4, self._conv_block_5
|
||||
self.conv_block_1, self.conv_block_2, self.conv_block_3,
|
||||
self.conv_block_4, self.conv_block_5
|
||||
]):
|
||||
if self.stop_grad_layers >= idx + 1:
|
||||
for param in block.parameters():
|
||||
param.trainable = False
|
||||
|
||||
self._drop = Dropout(p=0.5, mode="downscale_in_infer")
|
||||
self._fc1 = Linear(7 * 7 * 512, 4096)
|
||||
self._fc2 = Linear(4096, 4096)
|
||||
self._out = Linear(4096, class_num)
|
||||
|
||||
if pretrained is not None:
|
||||
load_dygraph_pretrain(self, pretrained)
|
||||
self.drop = Dropout(p=0.5, mode="downscale_in_infer")
|
||||
self.fc1 = Linear(7 * 7 * 512, 4096)
|
||||
self.fc2 = Linear(4096, 4096)
|
||||
self.fc3 = Linear(4096, class_num)
|
||||
|
||||
def forward(self, inputs):
|
||||
x = self._conv_block_1(inputs)
|
||||
x = self._conv_block_2(x)
|
||||
x = self._conv_block_3(x)
|
||||
x = self._conv_block_4(x)
|
||||
x = self._conv_block_5(x)
|
||||
x = self._flatten(x)
|
||||
x = self._fc1(x)
|
||||
x = self._relu(x)
|
||||
x = self._drop(x)
|
||||
x = self._fc2(x)
|
||||
x = self._relu(x)
|
||||
x = self._drop(x)
|
||||
x = self._out(x)
|
||||
x = self.conv_block_1(inputs)
|
||||
x = self.conv_block_2(x)
|
||||
x = self.conv_block_3(x)
|
||||
x = self.conv_block_4(x)
|
||||
x = self.conv_block_5(x)
|
||||
x = self.flatten(x)
|
||||
x = self.fc1(x)
|
||||
x = self.relu(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.relu(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc3(x)
|
||||
return x
|
||||
|
||||
|
||||
def _load_pretrained(pretrained, model, model_url, use_ssld):
|
||||
if pretrained is False:
|
||||
pass
|
||||
elif pretrained is True:
|
||||
load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
|
||||
elif isinstance(pretrained, str):
|
||||
load_dygraph_pretrain(model, pretrained)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"pretrained type is not available. Please use `string` or `boolean` type."
|
||||
)
|
||||
|
||||
|
||||
def VGG11(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
VGG11
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `VGG11` model depends on args.
|
||||
"""
|
||||
model = VGGNet(config=NET_CONFIG[11], **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["VGG11"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def VGG13(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
VGG13
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `VGG13` model depends on args.
|
||||
"""
|
||||
model = VGGNet(config=NET_CONFIG[13], **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["VGG13"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def VGG16(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
VGG16
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `VGG16` model depends on args.
|
||||
"""
|
||||
model = VGGNet(config=NET_CONFIG[16], **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["VGG16"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def VGG19(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
VGG19
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `VGG19` model depends on args.
|
||||
"""
|
||||
model = VGGNet(config=NET_CONFIG[19], **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["VGG19"], use_ssld)
|
||||
return model
|
||||
|
|
Loading…
Reference in New Issue