Update inception_v3.py

pull/746/head
Felix 2021-05-28 16:20:06 +08:00 committed by GitHub
parent c98c4e28d5
commit 4ad209e5fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 70 additions and 52 deletions

View File

@ -1,4 +1,4 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,28 +12,32 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import from __future__ import absolute_import, division, print_function
from __future__ import division
from __future__ import print_function
import paddle import paddle
from paddle import ParamAttr from paddle import ParamAttr
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.nn.initializer import Uniform from paddle.nn.initializer import Uniform
import math import math
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
from ppcls.utils.save_load import load_dygraph_pretrain from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
__all__ = ["InceptionV3"]
# InceptionV3 config MODEL_URLS = {
# key: inception blocks "InceptionV3": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/InceptionV3_pretrained.pdparams",
# value: conv num in different blocks }
__all__ = MODEL_URLS.keys()
'''
InceptionV3 config: dict.
key: inception blocks of InceptionV3.
values: conv num in different blocks.
'''
NET_CONFIG = { NET_CONFIG = {
'inception_a':[[192, 256, 288], [32, 64, 64]], 'inception_a':[[192, 256, 288], [32, 64, 64]],
'inception_b':[288], 'inception_b':[288],
@ -42,7 +46,6 @@ NET_CONFIG = {
'inception_e':[1280,2048] 'inception_e':[1280,2048]
} }
class ConvBNLayer(TheseusLayer): class ConvBNLayer(TheseusLayer):
def __init__(self, def __init__(self,
num_channels, num_channels,
@ -53,7 +56,7 @@ class ConvBNLayer(TheseusLayer):
groups=1, groups=1,
act="relu"): act="relu"):
super(ConvBNLayer, self).__init__() super(ConvBNLayer, self).__init__()
self.act = act
self.conv = Conv2D( self.conv = Conv2D(
in_channels=num_channels, in_channels=num_channels,
out_channels=num_filters, out_channels=num_filters,
@ -63,13 +66,15 @@ class ConvBNLayer(TheseusLayer):
groups=groups, groups=groups,
bias_attr=False) bias_attr=False)
self.batch_norm = BatchNorm( self.batch_norm = BatchNorm(
num_filters, num_filters)
act=act) self.relu = nn.ReLU()
def forward(self, inputs): def forward(self, x):
y = self.conv(inputs) x = self.conv(x)
y = self.batch_norm(y) x = self.batch_norm(x)
return y if self.act:
x = self.relu(x)
return x
class InceptionStem(TheseusLayer): class InceptionStem(TheseusLayer):
def __init__(self): def __init__(self):
@ -100,14 +105,14 @@ class InceptionStem(TheseusLayer):
filter_size=3, filter_size=3,
act="relu") act="relu")
def forward(self, x): def forward(self, x):
y = self.conv_1a_3x3(x) x = self.conv_1a_3x3(x)
y = self.conv_2a_3x3(y) x = self.conv_2a_3x3(x)
y = self.conv_2b_3x3(y) x = self.conv_2b_3x3(x)
y = self.maxpool(y) x = self.maxpool(x)
y = self.conv_3b_1x1(y) x = self.conv_3b_1x1(x)
y = self.conv_4a_3x3(y) x = self.conv_4a_3x3(x)
y = self.maxpool(y) x = self.maxpool(x)
return y return x
class InceptionA(TheseusLayer): class InceptionA(TheseusLayer):
@ -158,8 +163,8 @@ class InceptionA(TheseusLayer):
branch_pool = self.branch_pool(x) branch_pool = self.branch_pool(x)
branch_pool = self.branch_pool_conv(branch_pool) branch_pool = self.branch_pool_conv(branch_pool)
outputs = paddle.concat([branch1x1, branch5x5, branch3x3dbl, branch_pool], axis=1) x = paddle.concat([branch1x1, branch5x5, branch3x3dbl, branch_pool], axis=1)
return outputs return x
class InceptionB(TheseusLayer): class InceptionB(TheseusLayer):
@ -195,9 +200,9 @@ class InceptionB(TheseusLayer):
branch_pool = self.branch_pool(x) branch_pool = self.branch_pool(x)
outputs = paddle.concat([branch3x3, branch3x3dbl, branch_pool], axis=1) x = paddle.concat([branch3x3, branch3x3dbl, branch_pool], axis=1)
return outputs return x
class InceptionC(TheseusLayer): class InceptionC(TheseusLayer):
def __init__(self, num_channels, channels_7x7): def __init__(self, num_channels, channels_7x7):
@ -273,9 +278,9 @@ class InceptionC(TheseusLayer):
branch_pool = self.branch_pool(x) branch_pool = self.branch_pool(x)
branch_pool = self.branch_pool_conv(branch_pool) branch_pool = self.branch_pool_conv(branch_pool)
outputs = paddle.concat([branch1x1, branch7x7, branch7x7dbl, branch_pool], axis=1) x = paddle.concat([branch1x1, branch7x7, branch7x7dbl, branch_pool], axis=1)
return outputs return x
class InceptionD(TheseusLayer): class InceptionD(TheseusLayer):
def __init__(self, num_channels): def __init__(self, num_channels):
@ -321,8 +326,8 @@ class InceptionD(TheseusLayer):
branch_pool = self.branch_pool(x) branch_pool = self.branch_pool(x)
outputs = paddle.concat([branch3x3, branch7x7x3, branch_pool], axis=1) x = paddle.concat([branch3x3, branch7x7x3, branch_pool], axis=1)
return outputs return x
class InceptionE(TheseusLayer): class InceptionE(TheseusLayer):
def __init__(self, num_channels): def __init__(self, num_channels):
@ -391,12 +396,20 @@ class InceptionE(TheseusLayer):
branch_pool = self.branch_pool(x) branch_pool = self.branch_pool(x)
branch_pool = self.branch_pool_conv(branch_pool) branch_pool = self.branch_pool_conv(branch_pool)
outputs = paddle.concat([branch1x1, branch3x3, branch3x3dbl, branch_pool], axis=1) x = paddle.concat([branch1x1, branch3x3, branch3x3dbl, branch_pool], axis=1)
return outputs return x
class Inception_V3(TheseusLayer): class Inception_V3(TheseusLayer):
"""
Inception_V3
Args:
config: dict. config of Inception_V3.
class_num: int=1000. The number of classes.
pretrained: (True or False) or path of pretrained_model. Whether to load the pretrained model.
Returns:
model: nn.Layer. Specific Inception_V3 model depends on args.
"""
def __init__(self, def __init__(self,
config, config,
class_num=1000, class_num=1000,
@ -409,6 +422,7 @@ class Inception_V3(TheseusLayer):
self.inception_b_list = config['inception_b'] self.inception_b_list = config['inception_b']
self.inception_d_list = config['inception_d'] self.inception_d_list = config['inception_d']
self.inception_e_list = config ['inception_e'] self.inception_e_list = config ['inception_e']
self.pretrained = pretrained
self.inception_stem = InceptionStem() self.inception_stem = InceptionStem()
@ -445,20 +459,15 @@ class Inception_V3(TheseusLayer):
initializer=Uniform(-stdv, stdv)), initializer=Uniform(-stdv, stdv)),
bias_attr=ParamAttr()) bias_attr=ParamAttr())
if pretrained is not None:
load_dygraph_pretrain(self, pretrained)
def forward(self, x): def forward(self, x):
y = self.inception_stem(x) x = self.inception_stem(x)
for inception_block in self.inception_block_list: for inception_block in self.inception_block_list:
y = inception_block(y) x = inception_block(x)
y = self.gap(y) x = self.gap(x)
y = paddle.reshape(y, shape=[-1, 2048]) x = paddle.reshape(x, shape=[-1, 2048])
y = self.drop(y) x = self.drop(x)
y = self.out(y) x = self.out(x)
return y return x
def InceptionV3(**kwargs): def InceptionV3(**kwargs):
@ -467,10 +476,19 @@ def InceptionV3(**kwargs):
Args: Args:
kwargs: kwargs:
class_num: int=1000. Output dim of last fc layer. class_num: int=1000. Output dim of last fc layer.
pretrained: str, pretrained model file pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
Returns: Returns:
model: nn.Layer. Specific `InceptionV3` model model: nn.Layer. Specific `InceptionV3` model
""" """
model = Inception_V3(NET_CONFIG, **kwargs) model = Inception_V3(NET_CONFIG, **kwargs)
if isinstance(model.pretrained, bool):
if model.pretrained is True:
load_dygraph_pretrain_from_url(model, MODEL_URLS["InceptionV3"])
elif isinstance(model.pretrained, str):
load_dygraph_pretrain(model, model.pretrained)
else:
raise RuntimeError(
"pretrained type is not available. Please use `string` or `boolean` type")
return model return model