fix feature map visualization (#377)

fix feature map visualization
pull/379/head
littletomatodonkey 2020-11-07 22:07:33 +08:00 committed by GitHub
parent c9f8e8c6e7
commit f921b9bd5b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 239 additions and 139 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 30 KiB

After

Width:  |  Height:  |  Size: 24 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 10 KiB

After

Width:  |  Height:  |  Size: 95 KiB

View File

@ -55,16 +55,30 @@ python tools/feature_maps_visualization/fm_vis.py -i the image you want to test
+ `-i`:待预测的图片文件路径,如 `./test.jpeg`
+ `-c`:特征图维度,如 `./resnet50_vd/model`
+ `-p`:权重文件路径,如 `./ResNet50_pretrained/`
+ `--show`:是否展示图片,默认值 False
+ `--interpolation`: 图像插值方式, 默认值 1
+ `--save_path`:保存路径,如:`./tools/`
+ `--use_gpu`:是否使用 GPU 预测默认值True
## 四、结果
输入图片:
![](../../../tools/feature_maps_visualization/test.jpg)
* 输入图片:
输出特征图:
![](../../../docs/images/feature_maps/feature_visualization_input.jpg)
![](../../../tools/feature_maps_visualization/fm.jpg)
* 运行下面的特征图可视化脚本
```
python tools/feature_maps_visualization/fm_vis.py \
-i ./docs/images/feature_maps/feature_visualization_input.jpg \
-c 5 \
-p pretrained/ResNet50_pretrained/ \
--show=True \
--interpolation=1 \
--save_path="./output.png" \
--use_gpu=False \
--load_static_weights=True
```
* 输出特征图保存为`output.png`,如下所示。
![](../../../docs/images/feature_maps/feature_visualization_output.jpg)

View File

@ -11,84 +11,92 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from resnet import ResNet50
import paddle.fluid as fluid
import numpy as np
import cv2
import utils
import argparse
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import paddle
from paddle.distributed import ParallelEnv
from resnet import ResNet50
from ppcls.utils.save_load import load_dygraph_pretrain
def parse_args():
def str2bool(v):
return v.lower() in ("true", "t", "1")
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--image_file", type=str)
parser.add_argument("-c", "--channel_num", type=int)
parser.add_argument("-p", "--pretrained_model", type=str)
parser.add_argument("--show", type=str2bool, default=False)
parser.add_argument("--interpolation", type=int, default=1)
parser.add_argument("--save_path", type=str)
parser.add_argument("--save_path", type=str, default=None)
parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument(
"--load_static_weights",
type=str2bool,
default=False,
help='Whether to load the pretrained weights saved in static mode')
return parser.parse_args()
def create_operators(interpolation=1):
size = 224
img_mean = [0.485, 0.456, 0.406]
img_std = [0.229, 0.224, 0.225]
img_scale = 1.0 / 255.0
decode_op = utils.DecodeImage()
resize_op = utils.ResizeImage(resize_short=256, interpolation=interpolation)
resize_op = utils.ResizeImage(
resize_short=256, interpolation=interpolation)
crop_op = utils.CropImage(size=(size, size))
normalize_op = utils.NormalizeImage(
scale=img_scale, mean=img_mean, std=img_std)
totensor_op = utils.ToTensor()
return [decode_op, resize_op, crop_op, normalize_op, totensor_op]
return [resize_op, crop_op, normalize_op, totensor_op]
def preprocess(fname, ops):
data = open(fname, 'rb').read()
def preprocess(data, ops):
for op in ops:
data = op(data)
return data
def main():
args = parse_args()
operators = create_operators(args.interpolation)
# assign the place
if args.use_gpu:
gpu_id = fluid.dygraph.parallel.Env().dev_id
place = fluid.CUDAPlace(gpu_id)
else:
place = fluid.CPUPlace()
place = 'gpu:{}'.format(ParallelEnv().dev_id) if args.use_gpu else 'cpu'
place = paddle.set_device(place)
#pre_weights_dict = fluid.load_program_state(args.pretrained_model)
with fluid.dygraph.guard(place):
net = ResNet50()
data = preprocess(args.image_file, operators)
load_dygraph_pretrain(net, args.pretrained_model, args.load_static_weights)
img = cv2.imread(args.image_file, cv2.IMREAD_COLOR)
data = preprocess(img, operators)
data = np.expand_dims(data, axis=0)
data = fluid.dygraph.to_variable(data)
dy_weights_dict = net.state_dict()
pre_weights_dict_new = {}
for key in dy_weights_dict:
weights_name = dy_weights_dict[key].name
pre_weights_dict_new[key] = pre_weights_dict[weights_name]
net.set_dict(pre_weights_dict_new)
data = paddle.to_tensor(data)
net.eval()
_, fm = net(data)
assert args.channel_num >= 0 and args.channel_num <= fm.shape[1], "the channel is out of the range, should be in {} but got {}".format([0, fm.shape[1]], args.channel_num)
fm = (np.squeeze(fm[0][args.channel_num].numpy())*255).astype(np.uint8)
if fm is not None:
if args.save:
assert args.channel_num >= 0 and args.channel_num <= fm.shape[
1], "the channel is out of the range, should be in {} but got {}".format(
[0, fm.shape[1]], args.channel_num)
fm = (np.squeeze(fm[0][args.channel_num].numpy()) * 255).astype(np.uint8)
fm = cv2.resize(fm, (img.shape[1], img.shape[0]))
if args.save_path is not None:
print("the feature map is saved in path: {}".format(args.save_path))
cv2.imwrite(args.save_path, fm)
if args.show:
cv2.show(fm)
cv2.waitKey(0)
if __name__ == "__main__":
main()

View File

@ -1,20 +1,36 @@
import numpy as np
import argparse
import ast
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.base import to_variable
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid import framework
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.nn.initializer import Uniform
import math
import sys
import time
class ConvBNLayer(fluid.dygraph.Layer):
__all__ = ["ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152"]
class ConvBNLayer(nn.Layer):
def __init__(self,
num_channels,
num_filters,
@ -26,25 +42,25 @@ class ConvBNLayer(fluid.dygraph.Layer):
super(ConvBNLayer, self).__init__()
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
in_channels=num_channels,
out_channels=num_filters,
kernel_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
weight_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
self._batch_norm = BatchNorm(num_filters,
self._batch_norm = BatchNorm(
num_filters,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
param_attr=ParamAttr(name=bn_name + "_scale"),
bias_attr=ParamAttr(bn_name + "_offset"),
moving_mean_name=bn_name + "_mean",
moving_variance_name=bn_name + "_variance")
def forward(self, inputs):
y = self._conv(inputs)
@ -52,7 +68,7 @@ class ConvBNLayer(fluid.dygraph.Layer):
return y
class BottleneckBlock(fluid.dygraph.Layer):
class BottleneckBlock(nn.Layer):
def __init__(self,
num_channels,
num_filters,
@ -65,21 +81,21 @@ class BottleneckBlock(fluid.dygraph.Layer):
num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
act='relu',
name=name+"_branch2a")
act="relu",
name=name + "_branch2a")
self.conv1 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu',
name=name+"_branch2b")
act="relu",
name=name + "_branch2b")
self.conv2 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters * 4,
filter_size=1,
act=None,
name=name+"_branch2c")
name=name + "_branch2c")
if not shortcut:
self.short = ConvBNLayer(
@ -103,57 +119,104 @@ class BottleneckBlock(fluid.dygraph.Layer):
else:
short = self.short(inputs)
y = fluid.layers.elementwise_add(x=short, y=conv2)
layer_helper = LayerHelper(self.full_name(), act='relu')
return layer_helper.append_activation(y)
y = paddle.add(x=short, y=conv2)
y = F.relu(y)
return y
class ResNet(fluid.dygraph.Layer):
class BasicBlock(nn.Layer):
def __init__(self,
num_channels,
num_filters,
stride,
shortcut=True,
name=None):
super(BasicBlock, self).__init__()
self.stride = stride
self.conv0 = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=3,
stride=stride,
act="relu",
name=name + "_branch2a")
self.conv1 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
act=None,
name=name + "_branch2b")
if not shortcut:
self.short = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
stride=stride,
name=name + "_branch1")
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = paddle.add(x=short, y=conv1)
y = F.relu(y)
return y
class ResNet(nn.Layer):
def __init__(self, layers=50, class_dim=1000):
super(ResNet, self).__init__()
self.layers = layers
supported_layers = [50, 101, 152]
supported_layers = [18, 34, 50, 101, 152]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, layers)
self.fm = None
"supported layers are {} but input layer is {}".format(
supported_layers, layers)
if layers == 50:
if layers == 18:
depth = [2, 2, 2, 2]
elif layers == 34 or layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
num_channels = [64, 256, 512, 1024]
num_channels = [64, 256, 512,
1024] if layers >= 50 else [64, 64, 128, 256]
num_filters = [64, 128, 256, 512]
self.feature_map = None
self.conv = ConvBNLayer(
num_channels=3,
num_filters=64,
filter_size=7,
stride=2,
act='relu',
act="relu",
name="conv1")
self.pool2d_max = Pool2D(
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
self.pool2d_max = MaxPool2D(kernel_size=3, stride=2, padding=1)
self.bottleneck_block_list = []
self.block_list = []
if layers >= 50:
for block in range(len(depth)):
shortcut = False
for i in range(depth[block]):
if layers in [101, 152] and block == 2:
if i == 0:
conv_name="res"+str(block+2)+"a"
conv_name = "res" + str(block + 2) + "a"
else:
conv_name="res"+str(block+2)+"b"+str(i)
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name="res"+str(block+2)+chr(97+i)
conv_name = "res" + str(block + 2) + chr(97 + i)
bottleneck_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
conv_name,
BottleneckBlock(
num_channels=num_channels[block]
if i == 0 else num_filters[block] * 4,
@ -161,32 +224,58 @@ class ResNet(fluid.dygraph.Layer):
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
name=conv_name))
self.bottleneck_block_list.append(bottleneck_block)
self.block_list.append(bottleneck_block)
shortcut = True
else:
for block in range(len(depth)):
shortcut = False
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
basic_block = self.add_sublayer(
conv_name,
BasicBlock(
num_channels=num_channels[block]
if i == 0 else num_filters[block],
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
name=conv_name))
self.block_list.append(basic_block)
shortcut = True
self.pool2d_avg = Pool2D(
pool_size=7, pool_type='avg', global_pooling=True)
self.pool2d_avg = AdaptiveAvgPool2D(1)
self.pool2d_avg_output = num_filters[len(num_filters) - 1] * 4 * 1 * 1
self.pool2d_avg_channels = num_channels[-1] * 2
stdv = 1.0 / math.sqrt(2048 * 1.0)
stdv = 1.0 / math.sqrt(self.pool2d_avg_channels * 1.0)
self.out = Linear(self.pool2d_avg_output,
self.out = Linear(
self.pool2d_avg_channels,
class_dim,
param_attr=ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv), name="fc_0.w_0"),
weight_attr=ParamAttr(
initializer=Uniform(-stdv, stdv), name="fc_0.w_0"),
bias_attr=ParamAttr(name="fc_0.b_0"))
def forward(self, inputs):
y = self.conv(inputs)
y = self.pool2d_max(y)
self.fm = y
for bottleneck_block in self.bottleneck_block_list:
y = bottleneck_block(y)
self.feature_map = y
for block in self.block_list:
y = block(y)
y = self.pool2d_avg(y)
y = fluid.layers.reshape(y, shape=[-1, self.pool2d_avg_output])
y = paddle.reshape(y, shape=[-1, self.pool2d_avg_channels])
y = self.out(y)
return y, self.fm
return y, self.feature_map
def ResNet18(**args):
model = ResNet(layers=18, **args)
return model
def ResNet34(**args):
model = ResNet(layers=34, **args)
return model
def ResNet50(**args):
@ -202,14 +291,3 @@ def ResNet101(**args):
def ResNet152(**args):
model = ResNet(layers=152, **args)
return model
if __name__ == "__main__":
import numpy as np
place = fluid.CPUPlace()
with fluid.dygraph.guard(place):
model = ResNet50()
img = np.random.uniform(0, 255, [1, 3, 224, 224]).astype('float32')
img = fluid.dygraph.to_variable(img)
res = model(img)
print(res.shape)