diff --git a/docs/images/feature_maps/feature_visualization_input.jpg b/docs/images/feature_maps/feature_visualization_input.jpg new file mode 100644 index 000000000..da9d1a756 Binary files /dev/null and b/docs/images/feature_maps/feature_visualization_input.jpg differ diff --git a/docs/images/feature_maps/feature_visualization_output.jpg b/docs/images/feature_maps/feature_visualization_output.jpg new file mode 100644 index 000000000..18b99f96f Binary files /dev/null and b/docs/images/feature_maps/feature_visualization_output.jpg differ diff --git a/docs/zh_CN/feature_visiualization/get_started.md b/docs/zh_CN/feature_visiualization/get_started.md new file mode 100644 index 000000000..f80a2f848 --- /dev/null +++ b/docs/zh_CN/feature_visiualization/get_started.md @@ -0,0 +1,70 @@ +# 特征图可视化指南 + +## 一、概述 + +特征图是输入图片在卷积网络中的特征表达,对特征图的研究可以有利于我们对于模型的理解与设计,所以基于动态图我们使用本工具来可视化特征图。 + +## 二、准备工作 + +首先需要选定研究的模型,本文设定ResNet50作为研究模型,将resnet.py从[模型库](../../../ppcls/modeling/architecture/)拷贝到当前目录下,并下载预训练模型[预训练模型](../../zh_CN/models/models_intro), 复制resnet50的模型链接,使用下列命令下载并解压预训练模型。 + +```bash +wget The Link for Pretrained Model +tar -xf Downloaded Pretrained Model +``` + +以resnet50为例: +```bash +wget https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar +tar -xf ResNet50_pretrained.tar +``` + +## 三、修改模型 + +找到我们所需要的特征图位置,设置self.fm将其fetch出来,本文以resnet50中的stem层之后的特征图为例。 + +在fm_vis.py中修改模型的名字。 + +在ResNet50的__init__函数中定义self.fm +```python +self.fm = None +``` +在ResNet50的forward函数中指定特征图 +```python +def forward(self, inputs): + y = self.conv(inputs) + self.fm = y + y = self.pool2d_max(y) + for bottleneck_block in self.bottleneck_block_list: + y = bottleneck_block(y) + y = self.pool2d_avg(y) + y = fluid.layers.reshape(y, shape=[-1, self.pool2d_avg_output]) + y = self.out(y) + return y, self.fm +``` +执行函数 +```bash +python tools/feature_maps_visualization/fm_vis.py -i the image you want to test \ + -c channel_num -p pretrained model \ + --show whether to show \ + --interpolation interpolation method\ + --save_path where to save \ + --use_gpu whether to use gpu +``` +参数说明: ++ `-i`:待预测的图片文件路径,如 `./test.jpeg` ++ `-c`:特征图维度,如 `./resnet50_vd/model` ++ `-p`:权重文件路径,如 `./ResNet50_pretrained/` ++ `--show`:是否展示图片,默认值 False ++ `--interpolation`: 图像插值方式, 默认值 1 ++ `--save_path`:保存路径,如:`./tools/` ++ `--use_gpu`:是否使用 GPU 预测,默认值:True + +## 四、结果 +输入图片: + +![](../../../tools/feature_maps_visualization/test.jpg) + +输出特征图: + +![](../../../tools/feature_maps_visualization/fm.jpg) diff --git a/tools/feature_maps_visualization/download_resnet50_pretrained.sh b/tools/feature_maps_visualization/download_resnet50_pretrained.sh new file mode 100644 index 000000000..286c2400a --- /dev/null +++ b/tools/feature_maps_visualization/download_resnet50_pretrained.sh @@ -0,0 +1,2 @@ +wget https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar +tar -xf ResNet50_pretrained.tar \ No newline at end of file diff --git a/tools/feature_maps_visualization/fm_vis.py b/tools/feature_maps_visualization/fm_vis.py new file mode 100644 index 000000000..b389d833c --- /dev/null +++ b/tools/feature_maps_visualization/fm_vis.py @@ -0,0 +1,94 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from resnet import ResNet50 +import paddle.fluid as fluid + +import numpy as np +import cv2 +import utils +import argparse + +def parse_args(): + def str2bool(v): + return v.lower() in ("true", "t", "1") + parser = argparse.ArgumentParser() + parser.add_argument("-i", "--image_file", type=str) + parser.add_argument("-c", "--channel_num", type=int) + parser.add_argument("-p", "--pretrained_model", type=str) + parser.add_argument("--show", type=str2bool, default=False) + parser.add_argument("--interpolation", type=int, default=1) + parser.add_argument("--save_path", type=str) + parser.add_argument("--use_gpu", type=str2bool, default=True) + + return parser.parse_args() + +def create_operators(interpolation=1): + size = 224 + img_mean = [0.485, 0.456, 0.406] + img_std = [0.229, 0.224, 0.225] + img_scale = 1.0 / 255.0 + + decode_op = utils.DecodeImage() + resize_op = utils.ResizeImage(resize_short=256, interpolation=interpolation) + crop_op = utils.CropImage(size=(size, size)) + normalize_op = utils.NormalizeImage( + scale=img_scale, mean=img_mean, std=img_std) + totensor_op = utils.ToTensor() + + return [decode_op, resize_op, crop_op, normalize_op, totensor_op] + + +def preprocess(fname, ops): + data = open(fname, 'rb').read() + for op in ops: + data = op(data) + + return data + +def main(): + args = parse_args() + operators = create_operators(args.interpolation) + # assign the place + if args.use_gpu: + gpu_id = fluid.dygraph.parallel.Env().dev_id + place = fluid.CUDAPlace(gpu_id) + else: + place = fluid.CPUPlace() + + #pre_weights_dict = fluid.load_program_state(args.pretrained_model) + with fluid.dygraph.guard(place): + net = ResNet50() + data = preprocess(args.image_file, operators) + data = np.expand_dims(data, axis=0) + data = fluid.dygraph.to_variable(data) + dy_weights_dict = net.state_dict() + pre_weights_dict_new = {} + for key in dy_weights_dict: + weights_name = dy_weights_dict[key].name + pre_weights_dict_new[key] = pre_weights_dict[weights_name] + net.set_dict(pre_weights_dict_new) + net.eval() + _, fm = net(data) + assert args.channel_num >= 0 and args.channel_num <= fm.shape[1], "the channel is out of the range, should be in {} but got {}".format([0, fm.shape[1]], args.channel_num) + fm = (np.squeeze(fm[0][args.channel_num].numpy())*255).astype(np.uint8) + if fm is not None: + if args.save: + cv2.imwrite(args.save_path, fm) + if args.show: + cv2.show(fm) + cv2.waitKey(0) + +if __name__ == "__main__": + main() diff --git a/tools/feature_maps_visualization/resnet.py b/tools/feature_maps_visualization/resnet.py new file mode 100644 index 000000000..d3f230da6 --- /dev/null +++ b/tools/feature_maps_visualization/resnet.py @@ -0,0 +1,215 @@ +import numpy as np +import argparse +import ast +import paddle +import paddle.fluid as fluid +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear +from paddle.fluid.dygraph.base import to_variable + +from paddle.fluid import framework + +import math +import sys +import time + +class ConvBNLayer(fluid.dygraph.Layer): + def __init__(self, + num_channels, + num_filters, + filter_size, + stride=1, + groups=1, + act=None, + name=None): + super(ConvBNLayer, self).__init__() + + self._conv = Conv2D( + num_channels=num_channels, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + act=None, + param_attr=ParamAttr(name=name + "_weights"), + bias_attr=False) + if name == "conv1": + bn_name = "bn_" + name + else: + bn_name = "bn" + name[3:] + self._batch_norm = BatchNorm(num_filters, + act=act, + param_attr=ParamAttr(name=bn_name + '_scale'), + bias_attr=ParamAttr(bn_name + '_offset'), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance') + + def forward(self, inputs): + y = self._conv(inputs) + y = self._batch_norm(y) + return y + + +class BottleneckBlock(fluid.dygraph.Layer): + def __init__(self, + num_channels, + num_filters, + stride, + shortcut=True, + name=None): + super(BottleneckBlock, self).__init__() + + self.conv0 = ConvBNLayer( + num_channels=num_channels, + num_filters=num_filters, + filter_size=1, + act='relu', + name=name+"_branch2a") + self.conv1 = ConvBNLayer( + num_channels=num_filters, + num_filters=num_filters, + filter_size=3, + stride=stride, + act='relu', + name=name+"_branch2b") + self.conv2 = ConvBNLayer( + num_channels=num_filters, + num_filters=num_filters * 4, + filter_size=1, + act=None, + name=name+"_branch2c") + + if not shortcut: + self.short = ConvBNLayer( + num_channels=num_channels, + num_filters=num_filters * 4, + filter_size=1, + stride=stride, + name=name + "_branch1") + + self.shortcut = shortcut + + self._num_channels_out = num_filters * 4 + + def forward(self, inputs): + y = self.conv0(inputs) + conv1 = self.conv1(y) + conv2 = self.conv2(conv1) + + if self.shortcut: + short = inputs + else: + short = self.short(inputs) + + y = fluid.layers.elementwise_add(x=short, y=conv2) + + layer_helper = LayerHelper(self.full_name(), act='relu') + return layer_helper.append_activation(y) + + +class ResNet(fluid.dygraph.Layer): + def __init__(self, layers=50, class_dim=1000): + super(ResNet, self).__init__() + + self.layers = layers + supported_layers = [50, 101, 152] + assert layers in supported_layers, \ + "supported layers are {} but input layer is {}".format(supported_layers, layers) + self.fm = None + + if layers == 50: + depth = [3, 4, 6, 3] + elif layers == 101: + depth = [3, 4, 23, 3] + elif layers == 152: + depth = [3, 8, 36, 3] + num_channels = [64, 256, 512, 1024] + num_filters = [64, 128, 256, 512] + + self.conv = ConvBNLayer( + num_channels=3, + num_filters=64, + filter_size=7, + stride=2, + act='relu', + name="conv1") + self.pool2d_max = Pool2D( + pool_size=3, + pool_stride=2, + pool_padding=1, + pool_type='max') + + self.bottleneck_block_list = [] + for block in range(len(depth)): + shortcut = False + for i in range(depth[block]): + if layers in [101, 152] and block == 2: + if i == 0: + conv_name="res"+str(block+2)+"a" + else: + conv_name="res"+str(block+2)+"b"+str(i) + else: + conv_name="res"+str(block+2)+chr(97+i) + bottleneck_block = self.add_sublayer( + 'bb_%d_%d' % (block, i), + BottleneckBlock( + num_channels=num_channels[block] + if i == 0 else num_filters[block] * 4, + num_filters=num_filters[block], + stride=2 if i == 0 and block != 0 else 1, + shortcut=shortcut, + name=conv_name)) + self.bottleneck_block_list.append(bottleneck_block) + shortcut = True + + self.pool2d_avg = Pool2D( + pool_size=7, pool_type='avg', global_pooling=True) + + self.pool2d_avg_output = num_filters[len(num_filters) - 1] * 4 * 1 * 1 + + stdv = 1.0 / math.sqrt(2048 * 1.0) + + self.out = Linear(self.pool2d_avg_output, + class_dim, + param_attr=ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), name="fc_0.w_0"), + bias_attr=ParamAttr(name="fc_0.b_0")) + + def forward(self, inputs): + y = self.conv(inputs) + y = self.pool2d_max(y) + self.fm = y + for bottleneck_block in self.bottleneck_block_list: + y = bottleneck_block(y) + y = self.pool2d_avg(y) + y = fluid.layers.reshape(y, shape=[-1, self.pool2d_avg_output]) + y = self.out(y) + return y, self.fm + + +def ResNet50(**args): + model = ResNet(layers=50, **args) + return model + + +def ResNet101(**args): + model = ResNet(layers=101, **args) + return model + + +def ResNet152(**args): + model = ResNet(layers=152, **args) + return model + + +if __name__ == "__main__": + import numpy as np + place = fluid.CPUPlace() + with fluid.dygraph.guard(place): + model = ResNet50() + img = np.random.uniform(0, 255, [1, 3, 224, 224]).astype('float32') + img = fluid.dygraph.to_variable(img) + res = model(img) + print(res.shape) diff --git a/tools/feature_maps_visualization/utils.py b/tools/feature_maps_visualization/utils.py new file mode 100644 index 000000000..7c7014932 --- /dev/null +++ b/tools/feature_maps_visualization/utils.py @@ -0,0 +1,85 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import cv2 +import numpy as np + + +class DecodeImage(object): + def __init__(self, to_rgb=True): + self.to_rgb = to_rgb + + def __call__(self, img): + data = np.frombuffer(img, dtype='uint8') + img = cv2.imdecode(data, 1) + if self.to_rgb: + assert img.shape[2] == 3, 'invalid shape of image[%s]' % ( + img.shape) + img = img[:, :, ::-1] + + return img + + +class ResizeImage(object): + def __init__(self, resize_short=None, interpolation=1): + self.resize_short = resize_short + self.interpolation = interpolation + + def __call__(self, img): + img_h, img_w = img.shape[:2] + percent = float(self.resize_short) / min(img_w, img_h) + w = int(round(img_w * percent)) + h = int(round(img_h * percent)) + return cv2.resize(img, (w, h), interpolation=self.interpolation) + + +class CropImage(object): + def __init__(self, size): + if type(size) is int: + self.size = (size, size) + else: + self.size = size + + def __call__(self, img): + w, h = self.size + img_h, img_w = img.shape[:2] + w_start = (img_w - w) // 2 + h_start = (img_h - h) // 2 + + w_end = w_start + w + h_end = h_start + h + return img[h_start:h_end, w_start:w_end, :] + + +class NormalizeImage(object): + def __init__(self, scale=None, mean=None, std=None): + self.scale = np.float32(scale if scale is not None else 1.0 / 255.0) + mean = mean if mean is not None else [0.485, 0.456, 0.406] + std = std if std is not None else [0.229, 0.224, 0.225] + + shape = (1, 1, 3) + self.mean = np.array(mean).reshape(shape).astype('float32') + self.std = np.array(std).reshape(shape).astype('float32') + + def __call__(self, img): + return (img.astype('float32') * self.scale - self.mean) / self.std + + +class ToTensor(object): + def __init__(self): + pass + + def __call__(self, img): + img = img.transpose((2, 0, 1)) + return img