fix feature map visualization (#377)

fix feature map visualization
pull/379/head
littletomatodonkey 2020-11-07 22:07:33 +08:00 committed by GitHub
parent c9f8e8c6e7
commit f921b9bd5b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 239 additions and 139 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 30 KiB

After

Width:  |  Height:  |  Size: 24 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 10 KiB

After

Width:  |  Height:  |  Size: 95 KiB

View File

@ -55,16 +55,30 @@ python tools/feature_maps_visualization/fm_vis.py -i the image you want to test
+ `-i`:待预测的图片文件路径,如 `./test.jpeg` + `-i`:待预测的图片文件路径,如 `./test.jpeg`
+ `-c`:待可视化的特征图通道索引,如 `5` + `-c`:待可视化的特征图通道索引,如 `5`
+ `-p`:权重文件路径,如 `./ResNet50_pretrained/` + `-p`:权重文件路径,如 `./ResNet50_pretrained/`
+ `--show`:是否展示图片,默认值 False
+ `--interpolation`: 图像插值方式, 默认值 1 + `--interpolation`: 图像插值方式, 默认值 1
+ `--save_path`:特征图保存路径(图片文件名),如:`./output.png` + `--save_path`:特征图保存路径(图片文件名),如:`./output.png`
+ `--use_gpu`:是否使用 GPU 预测,默认值 True + `--use_gpu`:是否使用 GPU 预测,默认值 True
## 四、结果 ## 四、结果
输入图片:
![](../../../tools/feature_maps_visualization/test.jpg) * 输入图片:
输出特征图: ![](../../../docs/images/feature_maps/feature_visualization_input.jpg)
![](../../../tools/feature_maps_visualization/fm.jpg) * 运行下面的特征图可视化脚本
```
python tools/feature_maps_visualization/fm_vis.py \
-i ./docs/images/feature_maps/feature_visualization_input.jpg \
-c 5 \
-p pretrained/ResNet50_pretrained/ \
--show=True \
--interpolation=1 \
--save_path="./output.png" \
--use_gpu=False \
--load_static_weights=True
```
* 输出特征图保存为`output.png`,如下所示。
![](../../../docs/images/feature_maps/feature_visualization_output.jpg)

View File

@ -11,84 +11,92 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import numpy as np
from resnet import ResNet50 import cv2
import paddle.fluid as fluid
import numpy as np
import cv2
import utils import utils
import argparse import argparse
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import paddle
from paddle.distributed import ParallelEnv
from resnet import ResNet50
from ppcls.utils.save_load import load_dygraph_pretrain
def parse_args(): def parse_args():
def str2bool(v): def str2bool(v):
return v.lower() in ("true", "t", "1") return v.lower() in ("true", "t", "1")
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("-i", "--image_file", type=str) parser.add_argument("-i", "--image_file", type=str)
parser.add_argument("-c", "--channel_num", type=int) parser.add_argument("-c", "--channel_num", type=int)
parser.add_argument("-p", "--pretrained_model", type=str) parser.add_argument("-p", "--pretrained_model", type=str)
parser.add_argument("--show", type=str2bool, default=False) parser.add_argument("--show", type=str2bool, default=False)
parser.add_argument("--interpolation", type=int, default=1) parser.add_argument("--interpolation", type=int, default=1)
parser.add_argument("--save_path", type=str) parser.add_argument("--save_path", type=str, default=None)
parser.add_argument("--use_gpu", type=str2bool, default=True) parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument(
"--load_static_weights",
type=str2bool,
default=False,
help='Whether to load the pretrained weights saved in static mode')
return parser.parse_args() return parser.parse_args()
def create_operators(interpolation=1): def create_operators(interpolation=1):
size = 224 size = 224
img_mean = [0.485, 0.456, 0.406] img_mean = [0.485, 0.456, 0.406]
img_std = [0.229, 0.224, 0.225] img_std = [0.229, 0.224, 0.225]
img_scale = 1.0 / 255.0 img_scale = 1.0 / 255.0
decode_op = utils.DecodeImage() resize_op = utils.ResizeImage(
resize_op = utils.ResizeImage(resize_short=256, interpolation=interpolation) resize_short=256, interpolation=interpolation)
crop_op = utils.CropImage(size=(size, size)) crop_op = utils.CropImage(size=(size, size))
normalize_op = utils.NormalizeImage( normalize_op = utils.NormalizeImage(
scale=img_scale, mean=img_mean, std=img_std) scale=img_scale, mean=img_mean, std=img_std)
totensor_op = utils.ToTensor() totensor_op = utils.ToTensor()
return [decode_op, resize_op, crop_op, normalize_op, totensor_op] return [resize_op, crop_op, normalize_op, totensor_op]
def preprocess(fname, ops): def preprocess(data, ops):
data = open(fname, 'rb').read()
for op in ops: for op in ops:
data = op(data) data = op(data)
return data return data
def main(): def main():
args = parse_args() args = parse_args()
operators = create_operators(args.interpolation) operators = create_operators(args.interpolation)
# assign the place # assign the place
if args.use_gpu: place = 'gpu:{}'.format(ParallelEnv().dev_id) if args.use_gpu else 'cpu'
gpu_id = fluid.dygraph.parallel.Env().dev_id place = paddle.set_device(place)
place = fluid.CUDAPlace(gpu_id)
else: net = ResNet50()
place = fluid.CPUPlace() load_dygraph_pretrain(net, args.pretrained_model, args.load_static_weights)
img = cv2.imread(args.image_file, cv2.IMREAD_COLOR)
data = preprocess(img, operators)
data = np.expand_dims(data, axis=0)
data = paddle.to_tensor(data)
net.eval()
_, fm = net(data)
assert args.channel_num >= 0 and args.channel_num <= fm.shape[
1], "the channel is out of the range, should be in {} but got {}".format(
[0, fm.shape[1]], args.channel_num)
fm = (np.squeeze(fm[0][args.channel_num].numpy()) * 255).astype(np.uint8)
fm = cv2.resize(fm, (img.shape[1], img.shape[0]))
if args.save_path is not None:
print("the feature map is saved in path: {}".format(args.save_path))
cv2.imwrite(args.save_path, fm)
#pre_weights_dict = fluid.load_program_state(args.pretrained_model)
with fluid.dygraph.guard(place):
net = ResNet50()
data = preprocess(args.image_file, operators)
data = np.expand_dims(data, axis=0)
data = fluid.dygraph.to_variable(data)
dy_weights_dict = net.state_dict()
pre_weights_dict_new = {}
for key in dy_weights_dict:
weights_name = dy_weights_dict[key].name
pre_weights_dict_new[key] = pre_weights_dict[weights_name]
net.set_dict(pre_weights_dict_new)
net.eval()
_, fm = net(data)
assert args.channel_num >= 0 and args.channel_num <= fm.shape[1], "the channel is out of the range, should be in {} but got {}".format([0, fm.shape[1]], args.channel_num)
fm = (np.squeeze(fm[0][args.channel_num].numpy())*255).astype(np.uint8)
if fm is not None:
if args.save:
cv2.imwrite(args.save_path, fm)
if args.show:
cv2.show(fm)
cv2.waitKey(0)
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -1,20 +1,36 @@
import numpy as np # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
import argparse #
import ast # Licensed under the Apache License, Version 2.0 (the "License");
import paddle # you may not use this file except in compliance with the License.
import paddle.fluid as fluid # You may obtain a copy of the License at
from paddle.fluid.param_attr import ParamAttr #
from paddle.fluid.layer_helper import LayerHelper # http://www.apache.org/licenses/LICENSE-2.0
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear #
from paddle.fluid.dygraph.base import to_variable # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid import framework from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.nn.initializer import Uniform
import math import math
import sys
import time
class ConvBNLayer(fluid.dygraph.Layer): __all__ = ["ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152"]
class ConvBNLayer(nn.Layer):
def __init__(self, def __init__(self,
num_channels, num_channels,
num_filters, num_filters,
@ -26,25 +42,25 @@ class ConvBNLayer(fluid.dygraph.Layer):
super(ConvBNLayer, self).__init__() super(ConvBNLayer, self).__init__()
self._conv = Conv2D( self._conv = Conv2D(
num_channels=num_channels, in_channels=num_channels,
num_filters=num_filters, out_channels=num_filters,
filter_size=filter_size, kernel_size=filter_size,
stride=stride, stride=stride,
padding=(filter_size - 1) // 2, padding=(filter_size - 1) // 2,
groups=groups, groups=groups,
act=None, weight_attr=ParamAttr(name=name + "_weights"),
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False) bias_attr=False)
if name == "conv1": if name == "conv1":
bn_name = "bn_" + name bn_name = "bn_" + name
else: else:
bn_name = "bn" + name[3:] bn_name = "bn" + name[3:]
self._batch_norm = BatchNorm(num_filters, self._batch_norm = BatchNorm(
act=act, num_filters,
param_attr=ParamAttr(name=bn_name + '_scale'), act=act,
bias_attr=ParamAttr(bn_name + '_offset'), param_attr=ParamAttr(name=bn_name + "_scale"),
moving_mean_name=bn_name + '_mean', bias_attr=ParamAttr(bn_name + "_offset"),
moving_variance_name=bn_name + '_variance') moving_mean_name=bn_name + "_mean",
moving_variance_name=bn_name + "_variance")
def forward(self, inputs): def forward(self, inputs):
y = self._conv(inputs) y = self._conv(inputs)
@ -52,7 +68,7 @@ class ConvBNLayer(fluid.dygraph.Layer):
return y return y
class BottleneckBlock(fluid.dygraph.Layer): class BottleneckBlock(nn.Layer):
def __init__(self, def __init__(self,
num_channels, num_channels,
num_filters, num_filters,
@ -65,21 +81,21 @@ class BottleneckBlock(fluid.dygraph.Layer):
num_channels=num_channels, num_channels=num_channels,
num_filters=num_filters, num_filters=num_filters,
filter_size=1, filter_size=1,
act='relu', act="relu",
name=name+"_branch2a") name=name + "_branch2a")
self.conv1 = ConvBNLayer( self.conv1 = ConvBNLayer(
num_channels=num_filters, num_channels=num_filters,
num_filters=num_filters, num_filters=num_filters,
filter_size=3, filter_size=3,
stride=stride, stride=stride,
act='relu', act="relu",
name=name+"_branch2b") name=name + "_branch2b")
self.conv2 = ConvBNLayer( self.conv2 = ConvBNLayer(
num_channels=num_filters, num_channels=num_filters,
num_filters=num_filters * 4, num_filters=num_filters * 4,
filter_size=1, filter_size=1,
act=None, act=None,
name=name+"_branch2c") name=name + "_branch2c")
if not shortcut: if not shortcut:
self.short = ConvBNLayer( self.short = ConvBNLayer(
@ -103,90 +119,163 @@ class BottleneckBlock(fluid.dygraph.Layer):
else: else:
short = self.short(inputs) short = self.short(inputs)
y = fluid.layers.elementwise_add(x=short, y=conv2) y = paddle.add(x=short, y=conv2)
y = F.relu(y)
layer_helper = LayerHelper(self.full_name(), act='relu') return y
return layer_helper.append_activation(y)
class ResNet(fluid.dygraph.Layer): class BasicBlock(nn.Layer):
def __init__(self,
             num_channels,
             num_filters,
             stride,
             shortcut=True,
             name=None):
    """Create the two stacked ConvBNLayer stages and the optional projection.

    Args:
        num_channels: Input channel count of the first convolution (and of
            the projection branch, when one is created).
        num_filters: Output channel count of both convolutions.
        stride: Stride applied by the first convolution and the projection.
        shortcut: When true, no projection layer is created; otherwise a
            1x1 ``ConvBNLayer`` is stored as ``self.short``.
        name: Prefix for the sub-layer names. It is concatenated with
            string suffixes below, so the ``None`` default is not usable
            in practice.
    """
    super(BasicBlock, self).__init__()
    self.stride = stride
    self.shortcut = shortcut

    # First 3x3 stage carries the stride and the ReLU activation.
    self.conv0 = ConvBNLayer(
        num_channels=num_channels,
        num_filters=num_filters,
        filter_size=3,
        stride=stride,
        act="relu",
        name=name + "_branch2a")
    # Second 3x3 stage has no activation of its own.
    self.conv1 = ConvBNLayer(
        num_channels=num_filters,
        num_filters=num_filters,
        filter_size=3,
        act=None,
        name=name + "_branch2b")

    if not shortcut:
        # 1x1 projection of the input for the residual branch.
        self.short = ConvBNLayer(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=1,
            stride=stride,
            name=name + "_branch1")
def forward(self, inputs):
    """Run both conv stages, add the residual branch and apply ReLU.

    The residual is ``inputs`` itself when ``self.shortcut`` is true,
    otherwise ``self.short(inputs)``.
    """
    main_branch = self.conv1(self.conv0(inputs))
    residual = inputs if self.shortcut else self.short(inputs)
    return F.relu(paddle.add(x=residual, y=main_branch))
class ResNet(nn.Layer):
def __init__(self, layers=50, class_dim=1000): def __init__(self, layers=50, class_dim=1000):
super(ResNet, self).__init__() super(ResNet, self).__init__()
self.layers = layers self.layers = layers
supported_layers = [50, 101, 152] supported_layers = [18, 34, 50, 101, 152]
assert layers in supported_layers, \ assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, layers) "supported layers are {} but input layer is {}".format(
self.fm = None supported_layers, layers)
if layers == 50: if layers == 18:
depth = [2, 2, 2, 2]
elif layers == 34 or layers == 50:
depth = [3, 4, 6, 3] depth = [3, 4, 6, 3]
elif layers == 101: elif layers == 101:
depth = [3, 4, 23, 3] depth = [3, 4, 23, 3]
elif layers == 152: elif layers == 152:
depth = [3, 8, 36, 3] depth = [3, 8, 36, 3]
num_channels = [64, 256, 512, 1024] num_channels = [64, 256, 512,
1024] if layers >= 50 else [64, 64, 128, 256]
num_filters = [64, 128, 256, 512] num_filters = [64, 128, 256, 512]
self.feature_map = None
self.conv = ConvBNLayer( self.conv = ConvBNLayer(
num_channels=3, num_channels=3,
num_filters=64, num_filters=64,
filter_size=7, filter_size=7,
stride=2, stride=2,
act='relu', act="relu",
name="conv1") name="conv1")
self.pool2d_max = Pool2D( self.pool2d_max = MaxPool2D(kernel_size=3, stride=2, padding=1)
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
self.bottleneck_block_list = [] self.block_list = []
for block in range(len(depth)): if layers >= 50:
shortcut = False for block in range(len(depth)):
for i in range(depth[block]): shortcut = False
if layers in [101, 152] and block == 2: for i in range(depth[block]):
if i == 0: if layers in [101, 152] and block == 2:
conv_name="res"+str(block+2)+"a" if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else: else:
conv_name="res"+str(block+2)+"b"+str(i) conv_name = "res" + str(block + 2) + chr(97 + i)
else: bottleneck_block = self.add_sublayer(
conv_name="res"+str(block+2)+chr(97+i) conv_name,
bottleneck_block = self.add_sublayer( BottleneckBlock(
'bb_%d_%d' % (block, i), num_channels=num_channels[block]
BottleneckBlock( if i == 0 else num_filters[block] * 4,
num_channels=num_channels[block] num_filters=num_filters[block],
if i == 0 else num_filters[block] * 4, stride=2 if i == 0 and block != 0 else 1,
num_filters=num_filters[block], shortcut=shortcut,
stride=2 if i == 0 and block != 0 else 1, name=conv_name))
shortcut=shortcut, self.block_list.append(bottleneck_block)
name=conv_name)) shortcut = True
self.bottleneck_block_list.append(bottleneck_block) else:
shortcut = True for block in range(len(depth)):
shortcut = False
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
basic_block = self.add_sublayer(
conv_name,
BasicBlock(
num_channels=num_channels[block]
if i == 0 else num_filters[block],
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
name=conv_name))
self.block_list.append(basic_block)
shortcut = True
self.pool2d_avg = Pool2D( self.pool2d_avg = AdaptiveAvgPool2D(1)
pool_size=7, pool_type='avg', global_pooling=True)
self.pool2d_avg_output = num_filters[len(num_filters) - 1] * 4 * 1 * 1 self.pool2d_avg_channels = num_channels[-1] * 2
stdv = 1.0 / math.sqrt(2048 * 1.0) stdv = 1.0 / math.sqrt(self.pool2d_avg_channels * 1.0)
self.out = Linear(self.pool2d_avg_output, self.out = Linear(
class_dim, self.pool2d_avg_channels,
param_attr=ParamAttr( class_dim,
initializer=fluid.initializer.Uniform(-stdv, stdv), name="fc_0.w_0"), weight_attr=ParamAttr(
bias_attr=ParamAttr(name="fc_0.b_0")) initializer=Uniform(-stdv, stdv), name="fc_0.w_0"),
bias_attr=ParamAttr(name="fc_0.b_0"))
def forward(self, inputs): def forward(self, inputs):
y = self.conv(inputs) y = self.conv(inputs)
y = self.pool2d_max(y) y = self.pool2d_max(y)
self.fm = y self.feature_map = y
for bottleneck_block in self.bottleneck_block_list: for block in self.block_list:
y = bottleneck_block(y) y = block(y)
y = self.pool2d_avg(y) y = self.pool2d_avg(y)
y = fluid.layers.reshape(y, shape=[-1, self.pool2d_avg_output]) y = paddle.reshape(y, shape=[-1, self.pool2d_avg_channels])
y = self.out(y) y = self.out(y)
return y, self.fm return y, self.feature_map
def ResNet18(**args):
    """Return a ``ResNet`` built with ``layers=18``.

    Any keyword arguments are forwarded unchanged to ``ResNet``.
    """
    return ResNet(layers=18, **args)
def ResNet34(**args):
    """Return a ``ResNet`` built with ``layers=34``.

    Any keyword arguments are forwarded unchanged to ``ResNet``.
    """
    return ResNet(layers=34, **args)
def ResNet50(**args): def ResNet50(**args):
@ -202,14 +291,3 @@ def ResNet101(**args):
def ResNet152(**args): def ResNet152(**args):
model = ResNet(layers=152, **args) model = ResNet(layers=152, **args)
return model return model
if __name__ == "__main__":
    # Smoke test: push one random image-shaped batch through ResNet50 on CPU.
    import numpy as np

    place = fluid.CPUPlace()
    with fluid.dygraph.guard(place):
        model = ResNet50()
        img = np.random.uniform(0, 255, [1, 3, 224, 224]).astype('float32')
        img = fluid.dygraph.to_variable(img)
        # forward() returns a (logits, feature_map) tuple; a tuple has no
        # ``.shape``, so unpack it before printing.
        logits, feature_map = model(img)
        print(logits.shape, feature_map.shape)