fix: adapt to release 2.3
parent
677f6aea7e
commit
8ac51f7416
|
@ -6,43 +6,45 @@
|
||||||
|
|
||||||
## 二、准备工作
|
## 二、准备工作
|
||||||
|
|
||||||
首先需要选定研究的模型,本文设定ResNet50作为研究模型,将resnet.py从[模型库](../../../ppcls/arch/architecture/)拷贝到当前目录下,并下载预训练模型[预训练模型](../../zh_CN/models/models_intro), 复制resnet50的模型链接,使用下列命令下载并解压预训练模型。
|
首先需要选定研究的模型,本文设定ResNet50作为研究模型,将模型组网代码[resnet.py](../../../ppcls/arch/backbone/legendary_models/resnet.py)拷贝到[目录](../../../ppcls/utils/feature_maps_visualization/)下,并下载[ResNet50预训练模型](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_pretrained.pdparams),或使用以下命令下载。
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
wget The Link for Pretrained Model
|
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_pretrained.pdparams
|
||||||
tar -xf Downloaded Pretrained Model
|
|
||||||
```
|
```
|
||||||
|
|
||||||
以resnet50为例:
|
其他模型网络结构代码及预训练模型请自行下载:[模型库](../../../ppcls/arch/backbone/),[预训练模型](../models/models_intro.md)。
|
||||||
```bash
|
|
||||||
wget https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar
|
|
||||||
tar -xf ResNet50_pretrained.tar
|
|
||||||
```
|
|
||||||
|
|
||||||
## 三、修改模型
|
## 三、修改模型
|
||||||
|
|
||||||
找到我们所需要的特征图位置,设置self.fm将其fetch出来,本文以resnet50中的stem层之后的特征图为例。
|
找到我们所需要的特征图位置,将其作为模型的额外输出返回,本文以resnet50中的stem层之后的特征图为例。
|
||||||
|
|
||||||
在fm_vis.py中修改模型的名字。
|
在ResNet50的forward函数中指定要可视化的特征图(下例将 `self.stem` 的输出赋给 `fm`,并作为第二个返回值返回):
|
||||||
|
|
||||||
在ResNet50的__init__函数中定义self.fm
|
|
||||||
```python
|
```python
|
||||||
self.fm = None
|
def forward(self, x):
|
||||||
|
with paddle.static.amp.fp16_guard():
|
||||||
|
if self.data_format == "NHWC":
|
||||||
|
x = paddle.transpose(x, [0, 2, 3, 1])
|
||||||
|
x.stop_gradient = True
|
||||||
|
x = self.stem(x)
|
||||||
|
fm = x
|
||||||
|
x = self.max_pool(x)
|
||||||
|
x = self.blocks(x)
|
||||||
|
x = self.avg_pool(x)
|
||||||
|
x = self.flatten(x)
|
||||||
|
x = self.fc(x)
|
||||||
|
return x, fm
|
||||||
```
|
```
|
||||||
在ResNet50的forward函数中指定特征图
|
|
||||||
|
然后修改代码[fm_vis.py](../../../ppcls/utils/feature_maps_visualization/fm_vis.py),引入 `ResNet50`,实例化 `net` 对象:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
def forward(self, inputs):
|
from resnet import ResNet50
|
||||||
y = self.conv(inputs)
|
net = ResNet50()
|
||||||
self.fm = y
|
|
||||||
y = self.pool2d_max(y)
|
|
||||||
for bottleneck_block in self.bottleneck_block_list:
|
|
||||||
y = bottleneck_block(y)
|
|
||||||
y = self.avg_pool(y)
|
|
||||||
y = fluid.layers.reshape(y, shape=[-1, self.pool2d_avg_output])
|
|
||||||
y = self.out(y)
|
|
||||||
return y, self.fm
|
|
||||||
```
|
```
|
||||||
执行函数
|
|
||||||
|
最后执行以下命令运行可视化脚本:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python tools/feature_maps_visualization/fm_vis.py -i the image you want to test \
|
python ppcls/utils/feature_maps_visualization/fm_vis.py -i the image you want to test \
|
||||||
-c channel_num -p pretrained model \
|
-c channel_num -p pretrained model \
|
||||||
|
@ -51,9 +53,10 @@ python tools/feature_maps_visualization/fm_vis.py -i the image you want to test
|
||||||
--save_path where to save \
|
--save_path where to save \
|
||||||
--use_gpu whether to use gpu
|
--use_gpu whether to use gpu
|
||||||
```
|
```
|
||||||
|
|
||||||
参数说明:
|
参数说明:
|
||||||
+ `-i`:待预测的图片文件路径,如 `./test.jpeg`
|
+ `-i`:待预测的图片文件路径,如 `./test.jpeg`
|
||||||
+ `-c`:特征图维度,如 `./resnet50_vd/model`
|
+ `-c`:特征图维度,如 `5`
|
||||||
+ `-p`:权重文件路径,如 `./ResNet50_pretrained/`
|
+ `-p`:权重文件路径,如 `./ResNet50_pretrained/`
|
||||||
+ `--interpolation`: 图像插值方式, 默认值 1
|
+ `--interpolation`: 图像插值方式, 默认值 1
|
||||||
+ `--save_path`:保存路径,如:`./tools/`
|
+ `--save_path`:保存路径,如:`./tools/`
|
||||||
|
@ -63,7 +66,7 @@ python tools/feature_maps_visualization/fm_vis.py -i the image you want to test
|
||||||
|
|
||||||
* 输入图片:
|
* 输入图片:
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
* 运行下面的特征图可视化脚本
|
* 运行下面的特征图可视化脚本
|
||||||
|
|
||||||
|
@ -75,10 +78,9 @@ python tools/feature_maps_visualization/fm_vis.py \
|
||||||
--show=True \
|
--show=True \
|
||||||
--interpolation=1 \
|
--interpolation=1 \
|
||||||
--save_path="./output.png" \
|
--save_path="./output.png" \
|
||||||
--use_gpu=False \
|
--use_gpu=False
|
||||||
--load_static_weights=True
|
|
||||||
```
|
```
|
||||||
|
|
||||||
* 输出特征图保存为`output.png`,如下所示。
|
* 输出特征图保存为`output.png`,如下所示。
|
||||||
|
|
||||||

|

|
||||||
|
|
|
@ -1,2 +0,0 @@
|
||||||
wget https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar
|
|
||||||
tar -xf ResNet50_pretrained.tar
|
|
|
@ -19,7 +19,7 @@ import os
|
||||||
import sys
|
import sys
|
||||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||||
sys.path.append(__dir__)
|
sys.path.append(__dir__)
|
||||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../..')))
|
||||||
|
|
||||||
import paddle
|
import paddle
|
||||||
from paddle.distributed import ParallelEnv
|
from paddle.distributed import ParallelEnv
|
||||||
|
@ -33,18 +33,13 @@ def parse_args():
|
||||||
return v.lower() in ("true", "t", "1")
|
return v.lower() in ("true", "t", "1")
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("-i", "--image_file", type=str)
|
parser.add_argument("-i", "--image_file", required=True, type=str)
|
||||||
parser.add_argument("-c", "--channel_num", type=int)
|
parser.add_argument("-c", "--channel_num", type=int)
|
||||||
parser.add_argument("-p", "--pretrained_model", type=str)
|
parser.add_argument("-p", "--pretrained_model", type=str)
|
||||||
parser.add_argument("--show", type=str2bool, default=False)
|
parser.add_argument("--show", type=str2bool, default=False)
|
||||||
parser.add_argument("--interpolation", type=int, default=1)
|
parser.add_argument("--interpolation", type=int, default=1)
|
||||||
parser.add_argument("--save_path", type=str, default=None)
|
parser.add_argument("--save_path", type=str, default=None)
|
||||||
parser.add_argument("--use_gpu", type=str2bool, default=True)
|
parser.add_argument("--use_gpu", type=str2bool, default=True)
|
||||||
parser.add_argument(
|
|
||||||
"--load_static_weights",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help='Whether to load the pretrained weights saved in static mode')
|
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
@ -79,7 +74,7 @@ def main():
|
||||||
place = paddle.set_device(place)
|
place = paddle.set_device(place)
|
||||||
|
|
||||||
net = ResNet50()
|
net = ResNet50()
|
||||||
load_dygraph_pretrain(net, args.pretrained_model, args.load_static_weights)
|
load_dygraph_pretrain(net, args.pretrained_model)
|
||||||
|
|
||||||
img = cv2.imread(args.image_file, cv2.IMREAD_COLOR)
|
img = cv2.imread(args.image_file, cv2.IMREAD_COLOR)
|
||||||
data = preprocess(img, operators)
|
data = preprocess(img, operators)
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -12,126 +12,204 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import, division, print_function
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import paddle
|
import paddle
|
||||||
from paddle import ParamAttr
|
from paddle import ParamAttr
|
||||||
import paddle.nn as nn
|
import paddle.nn as nn
|
||||||
import paddle.nn.functional as F
|
from paddle.nn import Conv2D, BatchNorm, Linear
|
||||||
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
|
|
||||||
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
|
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
|
||||||
from paddle.nn.initializer import Uniform
|
from paddle.nn.initializer import Uniform
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
|
||||||
__all__ = ["ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152"]
|
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
|
||||||
|
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
|
||||||
|
|
||||||
|
MODEL_URLS = {
|
||||||
|
"ResNet18":
|
||||||
|
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet18_pretrained.pdparams",
|
||||||
|
"ResNet18_vd":
|
||||||
|
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet18_vd_pretrained.pdparams",
|
||||||
|
"ResNet34":
|
||||||
|
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet34_pretrained.pdparams",
|
||||||
|
"ResNet34_vd":
|
||||||
|
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet34_vd_pretrained.pdparams",
|
||||||
|
"ResNet50":
|
||||||
|
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet50_pretrained.pdparams",
|
||||||
|
"ResNet50_vd":
|
||||||
|
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet50_vd_pretrained.pdparams",
|
||||||
|
"ResNet101":
|
||||||
|
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet101_pretrained.pdparams",
|
||||||
|
"ResNet101_vd":
|
||||||
|
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet101_vd_pretrained.pdparams",
|
||||||
|
"ResNet152":
|
||||||
|
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet152_pretrained.pdparams",
|
||||||
|
"ResNet152_vd":
|
||||||
|
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet152_vd_pretrained.pdparams",
|
||||||
|
"ResNet200_vd":
|
||||||
|
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet200_vd_pretrained.pdparams",
|
||||||
|
}
|
||||||
|
|
||||||
|
__all__ = MODEL_URLS.keys()
|
||||||
|
'''
|
||||||
|
ResNet config: dict.
|
||||||
|
key: depth of ResNet.
|
||||||
|
values: config's dict of specific model.
|
||||||
|
keys:
|
||||||
|
block_type: Two different blocks in ResNet, BasicBlock and BottleneckBlock are optional.
|
||||||
|
block_depth: The number of blocks in different stages in ResNet.
|
||||||
|
num_channels: The number of channels to enter the next stage.
|
||||||
|
'''
|
||||||
|
NET_CONFIG = {
|
||||||
|
"18": {
|
||||||
|
"block_type": "BasicBlock",
|
||||||
|
"block_depth": [2, 2, 2, 2],
|
||||||
|
"num_channels": [64, 64, 128, 256]
|
||||||
|
},
|
||||||
|
"34": {
|
||||||
|
"block_type": "BasicBlock",
|
||||||
|
"block_depth": [3, 4, 6, 3],
|
||||||
|
"num_channels": [64, 64, 128, 256]
|
||||||
|
},
|
||||||
|
"50": {
|
||||||
|
"block_type": "BottleneckBlock",
|
||||||
|
"block_depth": [3, 4, 6, 3],
|
||||||
|
"num_channels": [64, 256, 512, 1024]
|
||||||
|
},
|
||||||
|
"101": {
|
||||||
|
"block_type": "BottleneckBlock",
|
||||||
|
"block_depth": [3, 4, 23, 3],
|
||||||
|
"num_channels": [64, 256, 512, 1024]
|
||||||
|
},
|
||||||
|
"152": {
|
||||||
|
"block_type": "BottleneckBlock",
|
||||||
|
"block_depth": [3, 8, 36, 3],
|
||||||
|
"num_channels": [64, 256, 512, 1024]
|
||||||
|
},
|
||||||
|
"200": {
|
||||||
|
"block_type": "BottleneckBlock",
|
||||||
|
"block_depth": [3, 12, 48, 3],
|
||||||
|
"num_channels": [64, 256, 512, 1024]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class ConvBNLayer(nn.Layer):
|
class ConvBNLayer(TheseusLayer):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
num_channels,
|
num_channels,
|
||||||
num_filters,
|
num_filters,
|
||||||
filter_size,
|
filter_size,
|
||||||
stride=1,
|
stride=1,
|
||||||
groups=1,
|
groups=1,
|
||||||
|
is_vd_mode=False,
|
||||||
act=None,
|
act=None,
|
||||||
name=None):
|
lr_mult=1.0,
|
||||||
super(ConvBNLayer, self).__init__()
|
data_format="NCHW"):
|
||||||
|
super().__init__()
|
||||||
self._conv = Conv2D(
|
self.is_vd_mode = is_vd_mode
|
||||||
|
self.act = act
|
||||||
|
self.avg_pool = AvgPool2D(
|
||||||
|
kernel_size=2, stride=2, padding=0, ceil_mode=True)
|
||||||
|
self.conv = Conv2D(
|
||||||
in_channels=num_channels,
|
in_channels=num_channels,
|
||||||
out_channels=num_filters,
|
out_channels=num_filters,
|
||||||
kernel_size=filter_size,
|
kernel_size=filter_size,
|
||||||
stride=stride,
|
stride=stride,
|
||||||
padding=(filter_size - 1) // 2,
|
padding=(filter_size - 1) // 2,
|
||||||
groups=groups,
|
groups=groups,
|
||||||
weight_attr=ParamAttr(name=name + "_weights"),
|
weight_attr=ParamAttr(learning_rate=lr_mult),
|
||||||
bias_attr=False)
|
bias_attr=False,
|
||||||
if name == "conv1":
|
data_format=data_format)
|
||||||
bn_name = "bn_" + name
|
self.bn = BatchNorm(
|
||||||
else:
|
|
||||||
bn_name = "bn" + name[3:]
|
|
||||||
self._batch_norm = BatchNorm(
|
|
||||||
num_filters,
|
num_filters,
|
||||||
act=act,
|
param_attr=ParamAttr(learning_rate=lr_mult),
|
||||||
param_attr=ParamAttr(name=bn_name + "_scale"),
|
bias_attr=ParamAttr(learning_rate=lr_mult),
|
||||||
bias_attr=ParamAttr(bn_name + "_offset"),
|
data_layout=data_format)
|
||||||
moving_mean_name=bn_name + "_mean",
|
self.relu = nn.ReLU()
|
||||||
moving_variance_name=bn_name + "_variance")
|
|
||||||
|
|
||||||
def forward(self, inputs):
|
def forward(self, x):
|
||||||
y = self._conv(inputs)
|
if self.is_vd_mode:
|
||||||
y = self._batch_norm(y)
|
x = self.avg_pool(x)
|
||||||
return y
|
x = self.conv(x)
|
||||||
|
x = self.bn(x)
|
||||||
|
if self.act:
|
||||||
|
x = self.relu(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
class BottleneckBlock(nn.Layer):
|
class BottleneckBlock(TheseusLayer):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
num_channels,
|
num_channels,
|
||||||
num_filters,
|
num_filters,
|
||||||
stride,
|
stride,
|
||||||
shortcut=True,
|
shortcut=True,
|
||||||
name=None):
|
if_first=False,
|
||||||
super(BottleneckBlock, self).__init__()
|
lr_mult=1.0,
|
||||||
|
data_format="NCHW"):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
self.conv0 = ConvBNLayer(
|
self.conv0 = ConvBNLayer(
|
||||||
num_channels=num_channels,
|
num_channels=num_channels,
|
||||||
num_filters=num_filters,
|
num_filters=num_filters,
|
||||||
filter_size=1,
|
filter_size=1,
|
||||||
act="relu",
|
act="relu",
|
||||||
name=name + "_branch2a")
|
lr_mult=lr_mult,
|
||||||
|
data_format=data_format)
|
||||||
self.conv1 = ConvBNLayer(
|
self.conv1 = ConvBNLayer(
|
||||||
num_channels=num_filters,
|
num_channels=num_filters,
|
||||||
num_filters=num_filters,
|
num_filters=num_filters,
|
||||||
filter_size=3,
|
filter_size=3,
|
||||||
stride=stride,
|
stride=stride,
|
||||||
act="relu",
|
act="relu",
|
||||||
name=name + "_branch2b")
|
lr_mult=lr_mult,
|
||||||
|
data_format=data_format)
|
||||||
self.conv2 = ConvBNLayer(
|
self.conv2 = ConvBNLayer(
|
||||||
num_channels=num_filters,
|
num_channels=num_filters,
|
||||||
num_filters=num_filters * 4,
|
num_filters=num_filters * 4,
|
||||||
filter_size=1,
|
filter_size=1,
|
||||||
act=None,
|
act=None,
|
||||||
name=name + "_branch2c")
|
lr_mult=lr_mult,
|
||||||
|
data_format=data_format)
|
||||||
|
|
||||||
if not shortcut:
|
if not shortcut:
|
||||||
self.short = ConvBNLayer(
|
self.short = ConvBNLayer(
|
||||||
num_channels=num_channels,
|
num_channels=num_channels,
|
||||||
num_filters=num_filters * 4,
|
num_filters=num_filters * 4,
|
||||||
filter_size=1,
|
filter_size=1,
|
||||||
stride=stride,
|
stride=stride if if_first else 1,
|
||||||
name=name + "_branch1")
|
is_vd_mode=False if if_first else True,
|
||||||
|
lr_mult=lr_mult,
|
||||||
|
data_format=data_format)
|
||||||
|
self.relu = nn.ReLU()
|
||||||
self.shortcut = shortcut
|
self.shortcut = shortcut
|
||||||
|
|
||||||
self._num_channels_out = num_filters * 4
|
def forward(self, x):
|
||||||
|
identity = x
|
||||||
def forward(self, inputs):
|
x = self.conv0(x)
|
||||||
y = self.conv0(inputs)
|
x = self.conv1(x)
|
||||||
conv1 = self.conv1(y)
|
x = self.conv2(x)
|
||||||
conv2 = self.conv2(conv1)
|
|
||||||
|
|
||||||
if self.shortcut:
|
if self.shortcut:
|
||||||
short = inputs
|
short = identity
|
||||||
else:
|
else:
|
||||||
short = self.short(inputs)
|
short = self.short(identity)
|
||||||
|
x = paddle.add(x=x, y=short)
|
||||||
y = paddle.add(x=short, y=conv2)
|
x = self.relu(x)
|
||||||
y = F.relu(y)
|
return x
|
||||||
return y
|
|
||||||
|
|
||||||
|
|
||||||
class BasicBlock(nn.Layer):
|
class BasicBlock(TheseusLayer):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
num_channels,
|
num_channels,
|
||||||
num_filters,
|
num_filters,
|
||||||
stride,
|
stride,
|
||||||
shortcut=True,
|
shortcut=True,
|
||||||
name=None):
|
if_first=False,
|
||||||
super(BasicBlock, self).__init__()
|
lr_mult=1.0,
|
||||||
|
data_format="NCHW"):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
self.conv0 = ConvBNLayer(
|
self.conv0 = ConvBNLayer(
|
||||||
num_channels=num_channels,
|
num_channels=num_channels,
|
||||||
|
@ -139,155 +217,319 @@ class BasicBlock(nn.Layer):
|
||||||
filter_size=3,
|
filter_size=3,
|
||||||
stride=stride,
|
stride=stride,
|
||||||
act="relu",
|
act="relu",
|
||||||
name=name + "_branch2a")
|
lr_mult=lr_mult,
|
||||||
|
data_format=data_format)
|
||||||
self.conv1 = ConvBNLayer(
|
self.conv1 = ConvBNLayer(
|
||||||
num_channels=num_filters,
|
num_channels=num_filters,
|
||||||
num_filters=num_filters,
|
num_filters=num_filters,
|
||||||
filter_size=3,
|
filter_size=3,
|
||||||
act=None,
|
act=None,
|
||||||
name=name + "_branch2b")
|
lr_mult=lr_mult,
|
||||||
|
data_format=data_format)
|
||||||
if not shortcut:
|
if not shortcut:
|
||||||
self.short = ConvBNLayer(
|
self.short = ConvBNLayer(
|
||||||
num_channels=num_channels,
|
num_channels=num_channels,
|
||||||
num_filters=num_filters,
|
num_filters=num_filters,
|
||||||
filter_size=1,
|
filter_size=1,
|
||||||
stride=stride,
|
stride=stride if if_first else 1,
|
||||||
name=name + "_branch1")
|
is_vd_mode=False if if_first else True,
|
||||||
|
lr_mult=lr_mult,
|
||||||
|
data_format=data_format)
|
||||||
self.shortcut = shortcut
|
self.shortcut = shortcut
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
|
||||||
def forward(self, inputs):
|
def forward(self, x):
|
||||||
y = self.conv0(inputs)
|
identity = x
|
||||||
conv1 = self.conv1(y)
|
x = self.conv0(x)
|
||||||
|
x = self.conv1(x)
|
||||||
if self.shortcut:
|
if self.shortcut:
|
||||||
short = inputs
|
short = identity
|
||||||
else:
|
else:
|
||||||
short = self.short(inputs)
|
short = self.short(identity)
|
||||||
y = paddle.add(x=short, y=conv1)
|
x = paddle.add(x=x, y=short)
|
||||||
y = F.relu(y)
|
x = self.relu(x)
|
||||||
return y
|
return x
|
||||||
|
|
||||||
|
|
||||||
class ResNet(nn.Layer):
|
class ResNet(TheseusLayer):
|
||||||
def __init__(self, layers=50, class_dim=1000):
|
"""
|
||||||
super(ResNet, self).__init__()
|
ResNet
|
||||||
|
Args:
|
||||||
|
config: dict. config of ResNet.
|
||||||
|
version: str="vb". Different version of ResNet, version vd can perform better.
|
||||||
|
class_num: int=1000. The number of classes.
|
||||||
|
lr_mult_list: list. Control the learning rate of different stages.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific ResNet model depends on args.
|
||||||
|
"""
|
||||||
|
|
||||||
self.layers = layers
|
def __init__(self,
|
||||||
supported_layers = [18, 34, 50, 101, 152]
|
config,
|
||||||
assert layers in supported_layers, \
|
version="vb",
|
||||||
"supported layers are {} but input layer is {}".format(
|
class_num=1000,
|
||||||
supported_layers, layers)
|
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
|
||||||
|
data_format="NCHW",
|
||||||
|
input_image_channel=3,
|
||||||
|
return_patterns=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
if layers == 18:
|
self.cfg = config
|
||||||
depth = [2, 2, 2, 2]
|
self.lr_mult_list = lr_mult_list
|
||||||
elif layers == 34 or layers == 50:
|
self.is_vd_mode = version == "vd"
|
||||||
depth = [3, 4, 6, 3]
|
self.class_num = class_num
|
||||||
elif layers == 101:
|
self.num_filters = [64, 128, 256, 512]
|
||||||
depth = [3, 4, 23, 3]
|
self.block_depth = self.cfg["block_depth"]
|
||||||
elif layers == 152:
|
self.block_type = self.cfg["block_type"]
|
||||||
depth = [3, 8, 36, 3]
|
self.num_channels = self.cfg["num_channels"]
|
||||||
num_channels = [64, 256, 512,
|
self.channels_mult = 1 if self.num_channels[-1] == 256 else 4
|
||||||
1024] if layers >= 50 else [64, 64, 128, 256]
|
|
||||||
num_filters = [64, 128, 256, 512]
|
|
||||||
|
|
||||||
self.feature_map = None
|
assert isinstance(self.lr_mult_list, (
|
||||||
|
list, tuple
|
||||||
|
)), "lr_mult_list should be in (list, tuple) but got {}".format(
|
||||||
|
type(self.lr_mult_list))
|
||||||
|
assert len(self.lr_mult_list
|
||||||
|
) == 5, "lr_mult_list length should be 5 but got {}".format(
|
||||||
|
len(self.lr_mult_list))
|
||||||
|
|
||||||
self.conv = ConvBNLayer(
|
self.stem_cfg = {
|
||||||
num_channels=3,
|
#num_channels, num_filters, filter_size, stride
|
||||||
num_filters=64,
|
"vb": [[input_image_channel, 64, 7, 2]],
|
||||||
filter_size=7,
|
"vd":
|
||||||
stride=2,
|
[[input_image_channel, 32, 3, 2], [32, 32, 3, 1], [32, 64, 3, 1]]
|
||||||
|
}
|
||||||
|
|
||||||
|
self.stem = nn.Sequential(* [
|
||||||
|
ConvBNLayer(
|
||||||
|
num_channels=in_c,
|
||||||
|
num_filters=out_c,
|
||||||
|
filter_size=k,
|
||||||
|
stride=s,
|
||||||
act="relu",
|
act="relu",
|
||||||
name="conv1")
|
lr_mult=self.lr_mult_list[0],
|
||||||
self.pool2d_max = MaxPool2D(kernel_size=3, stride=2, padding=1)
|
data_format=data_format)
|
||||||
|
for in_c, out_c, k, s in self.stem_cfg[version]
|
||||||
|
])
|
||||||
|
|
||||||
self.block_list = []
|
self.max_pool = MaxPool2D(
|
||||||
if layers >= 50:
|
kernel_size=3, stride=2, padding=1, data_format=data_format)
|
||||||
for block in range(len(depth)):
|
block_list = []
|
||||||
|
for block_idx in range(len(self.block_depth)):
|
||||||
shortcut = False
|
shortcut = False
|
||||||
for i in range(depth[block]):
|
for i in range(self.block_depth[block_idx]):
|
||||||
if layers in [101, 152] and block == 2:
|
block_list.append(globals()[self.block_type](
|
||||||
if i == 0:
|
num_channels=self.num_channels[block_idx] if i == 0 else
|
||||||
conv_name = "res" + str(block + 2) + "a"
|
self.num_filters[block_idx] * self.channels_mult,
|
||||||
else:
|
num_filters=self.num_filters[block_idx],
|
||||||
conv_name = "res" + str(block + 2) + "b" + str(i)
|
stride=2 if i == 0 and block_idx != 0 else 1,
|
||||||
else:
|
|
||||||
conv_name = "res" + str(block + 2) + chr(97 + i)
|
|
||||||
bottleneck_block = self.add_sublayer(
|
|
||||||
conv_name,
|
|
||||||
BottleneckBlock(
|
|
||||||
num_channels=num_channels[block]
|
|
||||||
if i == 0 else num_filters[block] * 4,
|
|
||||||
num_filters=num_filters[block],
|
|
||||||
stride=2 if i == 0 and block != 0 else 1,
|
|
||||||
shortcut=shortcut,
|
shortcut=shortcut,
|
||||||
name=conv_name))
|
if_first=block_idx == i == 0 if version == "vd" else True,
|
||||||
self.block_list.append(bottleneck_block)
|
lr_mult=self.lr_mult_list[block_idx + 1],
|
||||||
|
data_format=data_format))
|
||||||
shortcut = True
|
shortcut = True
|
||||||
|
self.blocks = nn.Sequential(*block_list)
|
||||||
|
|
||||||
|
self.avg_pool = AdaptiveAvgPool2D(1, data_format=data_format)
|
||||||
|
self.flatten = nn.Flatten()
|
||||||
|
self.avg_pool_channels = self.num_channels[-1] * 2
|
||||||
|
stdv = 1.0 / math.sqrt(self.avg_pool_channels * 1.0)
|
||||||
|
self.fc = Linear(
|
||||||
|
self.avg_pool_channels,
|
||||||
|
self.class_num,
|
||||||
|
weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)))
|
||||||
|
|
||||||
|
self.data_format = data_format
|
||||||
|
if return_patterns is not None:
|
||||||
|
self.update_res(return_patterns)
|
||||||
|
self.register_forward_post_hook(self._return_dict_hook)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
with paddle.static.amp.fp16_guard():
|
||||||
|
if self.data_format == "NHWC":
|
||||||
|
x = paddle.transpose(x, [0, 2, 3, 1])
|
||||||
|
x.stop_gradient = True
|
||||||
|
x = self.stem(x)
|
||||||
|
fm = x
|
||||||
|
x = self.max_pool(x)
|
||||||
|
x = self.blocks(x)
|
||||||
|
x = self.avg_pool(x)
|
||||||
|
x = self.flatten(x)
|
||||||
|
x = self.fc(x)
|
||||||
|
return x, fm
|
||||||
|
|
||||||
|
|
||||||
|
def _load_pretrained(pretrained, model, model_url, use_ssld):
|
||||||
|
if pretrained is False:
|
||||||
|
pass
|
||||||
|
elif pretrained is True:
|
||||||
|
load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
|
||||||
|
elif isinstance(pretrained, str):
|
||||||
|
load_dygraph_pretrain(model, pretrained)
|
||||||
else:
|
else:
|
||||||
for block in range(len(depth)):
|
raise RuntimeError(
|
||||||
shortcut = False
|
"pretrained type is not available. Please use `string` or `boolean` type."
|
||||||
for i in range(depth[block]):
|
)
|
||||||
conv_name = "res" + str(block + 2) + chr(97 + i)
|
|
||||||
basic_block = self.add_sublayer(
|
|
||||||
conv_name,
|
|
||||||
BasicBlock(
|
|
||||||
num_channels=num_channels[block]
|
|
||||||
if i == 0 else num_filters[block],
|
|
||||||
num_filters=num_filters[block],
|
|
||||||
stride=2 if i == 0 and block != 0 else 1,
|
|
||||||
shortcut=shortcut,
|
|
||||||
name=conv_name))
|
|
||||||
self.block_list.append(basic_block)
|
|
||||||
shortcut = True
|
|
||||||
|
|
||||||
self.pool2d_avg = AdaptiveAvgPool2D(1)
|
|
||||||
|
|
||||||
self.pool2d_avg_channels = num_channels[-1] * 2
|
|
||||||
|
|
||||||
stdv = 1.0 / math.sqrt(self.pool2d_avg_channels * 1.0)
|
|
||||||
|
|
||||||
self.out = Linear(
|
|
||||||
self.pool2d_avg_channels,
|
|
||||||
class_dim,
|
|
||||||
weight_attr=ParamAttr(
|
|
||||||
initializer=Uniform(-stdv, stdv), name="fc_0.w_0"),
|
|
||||||
bias_attr=ParamAttr(name="fc_0.b_0"))
|
|
||||||
|
|
||||||
def forward(self, inputs):
|
|
||||||
y = self.conv(inputs)
|
|
||||||
y = self.pool2d_max(y)
|
|
||||||
self.feature_map = y
|
|
||||||
for block in self.block_list:
|
|
||||||
y = block(y)
|
|
||||||
y = self.pool2d_avg(y)
|
|
||||||
y = paddle.reshape(y, shape=[-1, self.pool2d_avg_channels])
|
|
||||||
y = self.out(y)
|
|
||||||
return y, self.feature_map
|
|
||||||
|
|
||||||
|
|
||||||
def ResNet18(**args):
|
def ResNet18(pretrained=False, use_ssld=False, **kwargs):
|
||||||
model = ResNet(layers=18, **args)
|
"""
|
||||||
|
ResNet18
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
If str, means the path of the pretrained model.
|
||||||
|
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `ResNet18` model depends on args.
|
||||||
|
"""
|
||||||
|
model = ResNet(config=NET_CONFIG["18"], version="vb", **kwargs)
|
||||||
|
_load_pretrained(pretrained, model, MODEL_URLS["ResNet18"], use_ssld)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def ResNet34(**args):
|
def ResNet18_vd(pretrained=False, use_ssld=False, **kwargs):
|
||||||
model = ResNet(layers=34, **args)
|
"""
|
||||||
|
ResNet18_vd
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
If str, means the path of the pretrained model.
|
||||||
|
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `ResNet18_vd` model depends on args.
|
||||||
|
"""
|
||||||
|
model = ResNet(config=NET_CONFIG["18"], version="vd", **kwargs)
|
||||||
|
_load_pretrained(pretrained, model, MODEL_URLS["ResNet18_vd"], use_ssld)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def ResNet50(**args):
|
def ResNet34(pretrained=False, use_ssld=False, **kwargs):
|
||||||
model = ResNet(layers=50, **args)
|
"""
|
||||||
|
ResNet34
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
If str, means the path of the pretrained model.
|
||||||
|
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `ResNet34` model depends on args.
|
||||||
|
"""
|
||||||
|
model = ResNet(config=NET_CONFIG["34"], version="vb", **kwargs)
|
||||||
|
_load_pretrained(pretrained, model, MODEL_URLS["ResNet34"], use_ssld)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def ResNet101(**args):
|
def ResNet34_vd(pretrained=False, use_ssld=False, **kwargs):
|
||||||
model = ResNet(layers=101, **args)
|
"""
|
||||||
|
ResNet34_vd
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
If str, means the path of the pretrained model.
|
||||||
|
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `ResNet34_vd` model depends on args.
|
||||||
|
"""
|
||||||
|
model = ResNet(config=NET_CONFIG["34"], version="vd", **kwargs)
|
||||||
|
_load_pretrained(pretrained, model, MODEL_URLS["ResNet34_vd"], use_ssld)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def ResNet152(**args):
|
def ResNet50(pretrained=False, use_ssld=False, **kwargs):
|
||||||
model = ResNet(layers=152, **args)
|
"""
|
||||||
|
ResNet50
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
If str, means the path of the pretrained model.
|
||||||
|
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `ResNet50` model depends on args.
|
||||||
|
"""
|
||||||
|
model = ResNet(config=NET_CONFIG["50"], version="vb", **kwargs)
|
||||||
|
_load_pretrained(pretrained, model, MODEL_URLS["ResNet50"], use_ssld)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def ResNet50_vd(pretrained=False, use_ssld=False, **kwargs):
|
||||||
|
"""
|
||||||
|
ResNet50_vd
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
If str, means the path of the pretrained model.
|
||||||
|
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `ResNet50_vd` model depends on args.
|
||||||
|
"""
|
||||||
|
model = ResNet(config=NET_CONFIG["50"], version="vd", **kwargs)
|
||||||
|
_load_pretrained(pretrained, model, MODEL_URLS["ResNet50_vd"], use_ssld)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def ResNet101(pretrained=False, use_ssld=False, **kwargs):
|
||||||
|
"""
|
||||||
|
ResNet101
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
If str, means the path of the pretrained model.
|
||||||
|
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `ResNet101` model depends on args.
|
||||||
|
"""
|
||||||
|
model = ResNet(config=NET_CONFIG["101"], version="vb", **kwargs)
|
||||||
|
_load_pretrained(pretrained, model, MODEL_URLS["ResNet101"], use_ssld)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def ResNet101_vd(pretrained=False, use_ssld=False, **kwargs):
|
||||||
|
"""
|
||||||
|
ResNet101_vd
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
If str, means the path of the pretrained model.
|
||||||
|
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `ResNet101_vd` model depends on args.
|
||||||
|
"""
|
||||||
|
model = ResNet(config=NET_CONFIG["101"], version="vd", **kwargs)
|
||||||
|
_load_pretrained(pretrained, model, MODEL_URLS["ResNet101_vd"], use_ssld)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def ResNet152(pretrained=False, use_ssld=False, **kwargs):
|
||||||
|
"""
|
||||||
|
ResNet152
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
If str, means the path of the pretrained model.
|
||||||
|
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `ResNet152` model depends on args.
|
||||||
|
"""
|
||||||
|
model = ResNet(config=NET_CONFIG["152"], version="vb", **kwargs)
|
||||||
|
_load_pretrained(pretrained, model, MODEL_URLS["ResNet152"], use_ssld)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def ResNet152_vd(pretrained=False, use_ssld=False, **kwargs):
|
||||||
|
"""
|
||||||
|
ResNet152_vd
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
If str, means the path of the pretrained model.
|
||||||
|
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `ResNet152_vd` model depends on args.
|
||||||
|
"""
|
||||||
|
model = ResNet(config=NET_CONFIG["152"], version="vd", **kwargs)
|
||||||
|
_load_pretrained(pretrained, model, MODEL_URLS["ResNet152_vd"], use_ssld)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def ResNet200_vd(pretrained=False, use_ssld=False, **kwargs):
|
||||||
|
"""
|
||||||
|
ResNet200_vd
|
||||||
|
Args:
|
||||||
|
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||||
|
If str, means the path of the pretrained model.
|
||||||
|
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||||
|
Returns:
|
||||||
|
model: nn.Layer. Specific `ResNet200_vd` model depends on args.
|
||||||
|
"""
|
||||||
|
model = ResNet(config=NET_CONFIG["200"], version="vd", **kwargs)
|
||||||
|
_load_pretrained(pretrained, model, MODEL_URLS["ResNet200_vd"], use_ssld)
|
||||||
return model
|
return model
|
||||||
|
|
Loading…
Reference in New Issue