mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
Merge pull request #198 from wqz960/PaddleClas_74
add feature maps visualization
This commit is contained in:
commit
184bdd76db
BIN
docs/images/feature_maps/feature_visualization_input.jpg
Normal file
BIN
docs/images/feature_maps/feature_visualization_input.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 30 KiB |
BIN
docs/images/feature_maps/feature_visualization_output.jpg
Normal file
BIN
docs/images/feature_maps/feature_visualization_output.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 10 KiB |
70
docs/zh_CN/feature_visiualization/get_started.md
Normal file
70
docs/zh_CN/feature_visiualization/get_started.md
Normal file
@ -0,0 +1,70 @@
|
||||
# 特征图可视化指南
|
||||
|
||||
## 一、概述
|
||||
|
||||
特征图是输入图片在卷积网络中的特征表达,对特征图的研究可以有利于我们对于模型的理解与设计,所以基于动态图我们使用本工具来可视化特征图。
|
||||
|
||||
## 二、准备工作
|
||||
|
||||
首先需要选定研究的模型,本文设定ResNet50作为研究模型,将resnet.py从[模型库](../../../ppcls/modeling/architecture/)拷贝到当前目录下,并下载预训练模型[预训练模型](../../zh_CN/models/models_intro), 复制resnet50的模型链接,使用下列命令下载并解压预训练模型。
|
||||
|
||||
```bash
|
||||
wget The Link for Pretrained Model
|
||||
tar -xf Downloaded Pretrained Model
|
||||
```
|
||||
|
||||
以resnet50为例:
|
||||
```bash
|
||||
wget https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar
|
||||
tar -xf ResNet50_pretrained.tar
|
||||
```
|
||||
|
||||
## 三、修改模型
|
||||
|
||||
找到我们所需要的特征图位置,设置self.fm将其fetch出来,本文以resnet50中的stem层之后的特征图为例。
|
||||
|
||||
在fm_vis.py中修改模型的名字。
|
||||
|
||||
在ResNet50的__init__函数中定义self.fm
|
||||
```python
|
||||
self.fm = None
|
||||
```
|
||||
在ResNet50的forward函数中指定特征图
|
||||
```python
|
||||
def forward(self, inputs):
|
||||
y = self.conv(inputs)
|
||||
self.fm = y
|
||||
y = self.pool2d_max(y)
|
||||
for bottleneck_block in self.bottleneck_block_list:
|
||||
y = bottleneck_block(y)
|
||||
y = self.pool2d_avg(y)
|
||||
y = fluid.layers.reshape(y, shape=[-1, self.pool2d_avg_output])
|
||||
y = self.out(y)
|
||||
return y, self.fm
|
||||
```
|
||||
执行函数
|
||||
```bash
|
||||
python tools/feature_maps_visualization/fm_vis.py -i the image you want to test \
|
||||
-c channel_num -p pretrained model \
|
||||
--show whether to show \
|
||||
--interpolation interpolation method\
|
||||
--save_path where to save \
|
||||
--use_gpu whether to use gpu
|
||||
```
|
||||
参数说明:
|
||||
+ `-i`:待预测的图片文件路径,如 `./test.jpeg`
|
||||
+ `-c`:特征图维度,如 `./resnet50_vd/model`
|
||||
+ `-p`:权重文件路径,如 `./ResNet50_pretrained/`
|
||||
+ `--show`:是否展示图片,默认值 False
|
||||
+ `--interpolation`: 图像插值方式, 默认值 1
|
||||
+ `--save_path`:保存路径,如:`./tools/`
|
||||
+ `--use_gpu`:是否使用 GPU 预测,默认值:True
|
||||
|
||||
## 四、结果
|
||||
输入图片:
|
||||
|
||||

|
||||
|
||||
输出特征图:
|
||||
|
||||

|
@ -0,0 +1,2 @@
|
||||
wget https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar
|
||||
tar -xf ResNet50_pretrained.tar
|
94
tools/feature_maps_visualization/fm_vis.py
Normal file
94
tools/feature_maps_visualization/fm_vis.py
Normal file
@ -0,0 +1,94 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from resnet import ResNet50
|
||||
import paddle.fluid as fluid
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
import utils
|
||||
import argparse
|
||||
|
||||
def parse_args():
|
||||
def str2bool(v):
|
||||
return v.lower() in ("true", "t", "1")
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-i", "--image_file", type=str)
|
||||
parser.add_argument("-c", "--channel_num", type=int)
|
||||
parser.add_argument("-p", "--pretrained_model", type=str)
|
||||
parser.add_argument("--show", type=str2bool, default=False)
|
||||
parser.add_argument("--interpolation", type=int, default=1)
|
||||
parser.add_argument("--save_path", type=str)
|
||||
parser.add_argument("--use_gpu", type=str2bool, default=True)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
def create_operators(interpolation=1):
|
||||
size = 224
|
||||
img_mean = [0.485, 0.456, 0.406]
|
||||
img_std = [0.229, 0.224, 0.225]
|
||||
img_scale = 1.0 / 255.0
|
||||
|
||||
decode_op = utils.DecodeImage()
|
||||
resize_op = utils.ResizeImage(resize_short=256, interpolation=interpolation)
|
||||
crop_op = utils.CropImage(size=(size, size))
|
||||
normalize_op = utils.NormalizeImage(
|
||||
scale=img_scale, mean=img_mean, std=img_std)
|
||||
totensor_op = utils.ToTensor()
|
||||
|
||||
return [decode_op, resize_op, crop_op, normalize_op, totensor_op]
|
||||
|
||||
|
||||
def preprocess(fname, ops):
|
||||
data = open(fname, 'rb').read()
|
||||
for op in ops:
|
||||
data = op(data)
|
||||
|
||||
return data
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
operators = create_operators(args.interpolation)
|
||||
# assign the place
|
||||
if args.use_gpu:
|
||||
gpu_id = fluid.dygraph.parallel.Env().dev_id
|
||||
place = fluid.CUDAPlace(gpu_id)
|
||||
else:
|
||||
place = fluid.CPUPlace()
|
||||
|
||||
#pre_weights_dict = fluid.load_program_state(args.pretrained_model)
|
||||
with fluid.dygraph.guard(place):
|
||||
net = ResNet50()
|
||||
data = preprocess(args.image_file, operators)
|
||||
data = np.expand_dims(data, axis=0)
|
||||
data = fluid.dygraph.to_variable(data)
|
||||
dy_weights_dict = net.state_dict()
|
||||
pre_weights_dict_new = {}
|
||||
for key in dy_weights_dict:
|
||||
weights_name = dy_weights_dict[key].name
|
||||
pre_weights_dict_new[key] = pre_weights_dict[weights_name]
|
||||
net.set_dict(pre_weights_dict_new)
|
||||
net.eval()
|
||||
_, fm = net(data)
|
||||
assert args.channel_num >= 0 and args.channel_num <= fm.shape[1], "the channel is out of the range, should be in {} but got {}".format([0, fm.shape[1]], args.channel_num)
|
||||
fm = (np.squeeze(fm[0][args.channel_num].numpy())*255).astype(np.uint8)
|
||||
if fm is not None:
|
||||
if args.save:
|
||||
cv2.imwrite(args.save_path, fm)
|
||||
if args.show:
|
||||
cv2.show(fm)
|
||||
cv2.waitKey(0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
215
tools/feature_maps_visualization/resnet.py
Normal file
215
tools/feature_maps_visualization/resnet.py
Normal file
@ -0,0 +1,215 @@
|
||||
import numpy as np
|
||||
import argparse
|
||||
import ast
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
from paddle.fluid.param_attr import ParamAttr
|
||||
from paddle.fluid.layer_helper import LayerHelper
|
||||
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
|
||||
from paddle.fluid.dygraph.base import to_variable
|
||||
|
||||
from paddle.fluid import framework
|
||||
|
||||
import math
|
||||
import sys
|
||||
import time
|
||||
|
||||
class ConvBNLayer(fluid.dygraph.Layer):
|
||||
def __init__(self,
|
||||
num_channels,
|
||||
num_filters,
|
||||
filter_size,
|
||||
stride=1,
|
||||
groups=1,
|
||||
act=None,
|
||||
name=None):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
|
||||
self._conv = Conv2D(
|
||||
num_channels=num_channels,
|
||||
num_filters=num_filters,
|
||||
filter_size=filter_size,
|
||||
stride=stride,
|
||||
padding=(filter_size - 1) // 2,
|
||||
groups=groups,
|
||||
act=None,
|
||||
param_attr=ParamAttr(name=name + "_weights"),
|
||||
bias_attr=False)
|
||||
if name == "conv1":
|
||||
bn_name = "bn_" + name
|
||||
else:
|
||||
bn_name = "bn" + name[3:]
|
||||
self._batch_norm = BatchNorm(num_filters,
|
||||
act=act,
|
||||
param_attr=ParamAttr(name=bn_name + '_scale'),
|
||||
bias_attr=ParamAttr(bn_name + '_offset'),
|
||||
moving_mean_name=bn_name + '_mean',
|
||||
moving_variance_name=bn_name + '_variance')
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self._conv(inputs)
|
||||
y = self._batch_norm(y)
|
||||
return y
|
||||
|
||||
|
||||
class BottleneckBlock(fluid.dygraph.Layer):
|
||||
def __init__(self,
|
||||
num_channels,
|
||||
num_filters,
|
||||
stride,
|
||||
shortcut=True,
|
||||
name=None):
|
||||
super(BottleneckBlock, self).__init__()
|
||||
|
||||
self.conv0 = ConvBNLayer(
|
||||
num_channels=num_channels,
|
||||
num_filters=num_filters,
|
||||
filter_size=1,
|
||||
act='relu',
|
||||
name=name+"_branch2a")
|
||||
self.conv1 = ConvBNLayer(
|
||||
num_channels=num_filters,
|
||||
num_filters=num_filters,
|
||||
filter_size=3,
|
||||
stride=stride,
|
||||
act='relu',
|
||||
name=name+"_branch2b")
|
||||
self.conv2 = ConvBNLayer(
|
||||
num_channels=num_filters,
|
||||
num_filters=num_filters * 4,
|
||||
filter_size=1,
|
||||
act=None,
|
||||
name=name+"_branch2c")
|
||||
|
||||
if not shortcut:
|
||||
self.short = ConvBNLayer(
|
||||
num_channels=num_channels,
|
||||
num_filters=num_filters * 4,
|
||||
filter_size=1,
|
||||
stride=stride,
|
||||
name=name + "_branch1")
|
||||
|
||||
self.shortcut = shortcut
|
||||
|
||||
self._num_channels_out = num_filters * 4
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self.conv0(inputs)
|
||||
conv1 = self.conv1(y)
|
||||
conv2 = self.conv2(conv1)
|
||||
|
||||
if self.shortcut:
|
||||
short = inputs
|
||||
else:
|
||||
short = self.short(inputs)
|
||||
|
||||
y = fluid.layers.elementwise_add(x=short, y=conv2)
|
||||
|
||||
layer_helper = LayerHelper(self.full_name(), act='relu')
|
||||
return layer_helper.append_activation(y)
|
||||
|
||||
|
||||
class ResNet(fluid.dygraph.Layer):
|
||||
def __init__(self, layers=50, class_dim=1000):
|
||||
super(ResNet, self).__init__()
|
||||
|
||||
self.layers = layers
|
||||
supported_layers = [50, 101, 152]
|
||||
assert layers in supported_layers, \
|
||||
"supported layers are {} but input layer is {}".format(supported_layers, layers)
|
||||
self.fm = None
|
||||
|
||||
if layers == 50:
|
||||
depth = [3, 4, 6, 3]
|
||||
elif layers == 101:
|
||||
depth = [3, 4, 23, 3]
|
||||
elif layers == 152:
|
||||
depth = [3, 8, 36, 3]
|
||||
num_channels = [64, 256, 512, 1024]
|
||||
num_filters = [64, 128, 256, 512]
|
||||
|
||||
self.conv = ConvBNLayer(
|
||||
num_channels=3,
|
||||
num_filters=64,
|
||||
filter_size=7,
|
||||
stride=2,
|
||||
act='relu',
|
||||
name="conv1")
|
||||
self.pool2d_max = Pool2D(
|
||||
pool_size=3,
|
||||
pool_stride=2,
|
||||
pool_padding=1,
|
||||
pool_type='max')
|
||||
|
||||
self.bottleneck_block_list = []
|
||||
for block in range(len(depth)):
|
||||
shortcut = False
|
||||
for i in range(depth[block]):
|
||||
if layers in [101, 152] and block == 2:
|
||||
if i == 0:
|
||||
conv_name="res"+str(block+2)+"a"
|
||||
else:
|
||||
conv_name="res"+str(block+2)+"b"+str(i)
|
||||
else:
|
||||
conv_name="res"+str(block+2)+chr(97+i)
|
||||
bottleneck_block = self.add_sublayer(
|
||||
'bb_%d_%d' % (block, i),
|
||||
BottleneckBlock(
|
||||
num_channels=num_channels[block]
|
||||
if i == 0 else num_filters[block] * 4,
|
||||
num_filters=num_filters[block],
|
||||
stride=2 if i == 0 and block != 0 else 1,
|
||||
shortcut=shortcut,
|
||||
name=conv_name))
|
||||
self.bottleneck_block_list.append(bottleneck_block)
|
||||
shortcut = True
|
||||
|
||||
self.pool2d_avg = Pool2D(
|
||||
pool_size=7, pool_type='avg', global_pooling=True)
|
||||
|
||||
self.pool2d_avg_output = num_filters[len(num_filters) - 1] * 4 * 1 * 1
|
||||
|
||||
stdv = 1.0 / math.sqrt(2048 * 1.0)
|
||||
|
||||
self.out = Linear(self.pool2d_avg_output,
|
||||
class_dim,
|
||||
param_attr=ParamAttr(
|
||||
initializer=fluid.initializer.Uniform(-stdv, stdv), name="fc_0.w_0"),
|
||||
bias_attr=ParamAttr(name="fc_0.b_0"))
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self.conv(inputs)
|
||||
y = self.pool2d_max(y)
|
||||
self.fm = y
|
||||
for bottleneck_block in self.bottleneck_block_list:
|
||||
y = bottleneck_block(y)
|
||||
y = self.pool2d_avg(y)
|
||||
y = fluid.layers.reshape(y, shape=[-1, self.pool2d_avg_output])
|
||||
y = self.out(y)
|
||||
return y, self.fm
|
||||
|
||||
|
||||
def ResNet50(**args):
|
||||
model = ResNet(layers=50, **args)
|
||||
return model
|
||||
|
||||
|
||||
def ResNet101(**args):
|
||||
model = ResNet(layers=101, **args)
|
||||
return model
|
||||
|
||||
|
||||
def ResNet152(**args):
|
||||
model = ResNet(layers=152, **args)
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import numpy as np
|
||||
place = fluid.CPUPlace()
|
||||
with fluid.dygraph.guard(place):
|
||||
model = ResNet50()
|
||||
img = np.random.uniform(0, 255, [1, 3, 224, 224]).astype('float32')
|
||||
img = fluid.dygraph.to_variable(img)
|
||||
res = model(img)
|
||||
print(res.shape)
|
85
tools/feature_maps_visualization/utils.py
Normal file
85
tools/feature_maps_visualization/utils.py
Normal file
@ -0,0 +1,85 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
class DecodeImage(object):
|
||||
def __init__(self, to_rgb=True):
|
||||
self.to_rgb = to_rgb
|
||||
|
||||
def __call__(self, img):
|
||||
data = np.frombuffer(img, dtype='uint8')
|
||||
img = cv2.imdecode(data, 1)
|
||||
if self.to_rgb:
|
||||
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (
|
||||
img.shape)
|
||||
img = img[:, :, ::-1]
|
||||
|
||||
return img
|
||||
|
||||
|
||||
class ResizeImage(object):
|
||||
def __init__(self, resize_short=None, interpolation=1):
|
||||
self.resize_short = resize_short
|
||||
self.interpolation = interpolation
|
||||
|
||||
def __call__(self, img):
|
||||
img_h, img_w = img.shape[:2]
|
||||
percent = float(self.resize_short) / min(img_w, img_h)
|
||||
w = int(round(img_w * percent))
|
||||
h = int(round(img_h * percent))
|
||||
return cv2.resize(img, (w, h), interpolation=self.interpolation)
|
||||
|
||||
|
||||
class CropImage(object):
|
||||
def __init__(self, size):
|
||||
if type(size) is int:
|
||||
self.size = (size, size)
|
||||
else:
|
||||
self.size = size
|
||||
|
||||
def __call__(self, img):
|
||||
w, h = self.size
|
||||
img_h, img_w = img.shape[:2]
|
||||
w_start = (img_w - w) // 2
|
||||
h_start = (img_h - h) // 2
|
||||
|
||||
w_end = w_start + w
|
||||
h_end = h_start + h
|
||||
return img[h_start:h_end, w_start:w_end, :]
|
||||
|
||||
|
||||
class NormalizeImage(object):
|
||||
def __init__(self, scale=None, mean=None, std=None):
|
||||
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
|
||||
mean = mean if mean is not None else [0.485, 0.456, 0.406]
|
||||
std = std if std is not None else [0.229, 0.224, 0.225]
|
||||
|
||||
shape = (1, 1, 3)
|
||||
self.mean = np.array(mean).reshape(shape).astype('float32')
|
||||
self.std = np.array(std).reshape(shape).astype('float32')
|
||||
|
||||
def __call__(self, img):
|
||||
return (img.astype('float32') * self.scale - self.mean) / self.std
|
||||
|
||||
|
||||
class ToTensor(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, img):
|
||||
img = img.transpose((2, 0, 1))
|
||||
return img
|
Loading…
x
Reference in New Issue
Block a user