diff --git a/docs/images/feature_maps/feature_visualization_input.jpg b/docs/images/feature_maps/feature_visualization_input.jpg index da9d1a756..6eb5b7143 100644 Binary files a/docs/images/feature_maps/feature_visualization_input.jpg and b/docs/images/feature_maps/feature_visualization_input.jpg differ diff --git a/docs/images/feature_maps/feature_visualization_output.jpg b/docs/images/feature_maps/feature_visualization_output.jpg index 18b99f96f..6d7d2a6e8 100644 Binary files a/docs/images/feature_maps/feature_visualization_output.jpg and b/docs/images/feature_maps/feature_visualization_output.jpg differ diff --git a/docs/zh_CN/feature_visiualization/get_started.md b/docs/zh_CN/feature_visiualization/get_started.md index f80a2f848..e63c08b84 100644 --- a/docs/zh_CN/feature_visiualization/get_started.md +++ b/docs/zh_CN/feature_visiualization/get_started.md @@ -55,16 +55,30 @@ python tools/feature_maps_visualization/fm_vis.py -i the image you want to test + `-i`:待预测的图片文件路径,如 `./test.jpeg` + `-c`:特征图维度,如 `./resnet50_vd/model` + `-p`:权重文件路径,如 `./ResNet50_pretrained/` -+ `--show`:是否展示图片,默认值 False + `--interpolation`: 图像插值方式, 默认值 1 + `--save_path`:保存路径,如:`./tools/` + `--use_gpu`:是否使用 GPU 预测,默认值:True ## 四、结果 -输入图片: -![](../../../tools/feature_maps_visualization/test.jpg) +* 输入图片: -输出特征图: +![](../../../docs/images/feature_maps/feature_visualization_input.jpg) -![](../../../tools/feature_maps_visualization/fm.jpg) +* 运行下面的特征图可视化脚本 + +``` +python tools/feature_maps_visualization/fm_vis.py \ + -i ./docs/images/feature_maps/feature_visualization_input.jpg \ + -c 5 \ + -p pretrained/ResNet50_pretrained/ \ + --show=True \ + --interpolation=1 \ + --save_path="./output.png" \ + --use_gpu=False \ + --load_static_weights=True +``` + +* 输出特征图保存为`output.png`,如下所示。 + +![](../../../docs/images/feature_maps/feature_visualization_output.jpg) diff --git a/tools/feature_maps_visualization/fm_vis.py b/tools/feature_maps_visualization/fm_vis.py index b389d833c..d8ee125f0 100644 --- a/tools/feature_maps_visualization/fm_vis.py +++ b/tools/feature_maps_visualization/fm_vis.py @@ -11,84 +11,92 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from resnet import ResNet50 -import paddle.fluid as fluid - -import numpy as np -import cv2 +import numpy as np +import cv2 import utils import argparse +import os +import sys +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) + +import paddle +from paddle.distributed import ParallelEnv + +from resnet import ResNet50 +from ppcls.utils.save_load import load_dygraph_pretrain + def parse_args(): def str2bool(v): return v.lower() in ("true", "t", "1") + parser = argparse.ArgumentParser() parser.add_argument("-i", "--image_file", type=str) parser.add_argument("-c", "--channel_num", type=int) parser.add_argument("-p", "--pretrained_model", type=str) parser.add_argument("--show", type=str2bool, default=False) parser.add_argument("--interpolation", type=int, default=1) - parser.add_argument("--save_path", type=str) + parser.add_argument("--save_path", type=str, default=None) parser.add_argument("--use_gpu", type=str2bool, default=True) - + parser.add_argument( + "--load_static_weights", + type=str2bool, + default=False, + help='Whether to load the pretrained weights saved in static mode') + return parser.parse_args() + def create_operators(interpolation=1): size = 224 img_mean = [0.485, 0.456, 0.406] img_std = [0.229, 0.224, 0.225] img_scale = 1.0 / 255.0 - decode_op = utils.DecodeImage() - resize_op = utils.ResizeImage(resize_short=256, interpolation=interpolation) + resize_op = utils.ResizeImage( + resize_short=256, interpolation=interpolation) crop_op = utils.CropImage(size=(size, size)) normalize_op = utils.NormalizeImage( scale=img_scale, mean=img_mean, std=img_std) totensor_op = utils.ToTensor() - return [decode_op, resize_op, crop_op, normalize_op, totensor_op] + return [resize_op, crop_op, normalize_op, totensor_op] -def preprocess(fname, ops): - data = open(fname, 'rb').read() +def preprocess(data, ops): for op in ops: data = op(data) - return data + def main(): args = parse_args() operators = create_operators(args.interpolation) # assign the place - if args.use_gpu: - gpu_id = fluid.dygraph.parallel.Env().dev_id - place = fluid.CUDAPlace(gpu_id) - else: - place = fluid.CPUPlace() + place = 'gpu:{}'.format(ParallelEnv().dev_id) if args.use_gpu else 'cpu' + place = paddle.set_device(place) + + net = ResNet50() + load_dygraph_pretrain(net, args.pretrained_model, args.load_static_weights) + + img = cv2.imread(args.image_file, cv2.IMREAD_COLOR) + data = preprocess(img, operators) + data = np.expand_dims(data, axis=0) + data = paddle.to_tensor(data) + net.eval() + _, fm = net(data) + assert args.channel_num >= 0 and args.channel_num <= fm.shape[ + 1], "the channel is out of the range, should be in {} but got {}".format( + [0, fm.shape[1]], args.channel_num) + + fm = (np.squeeze(fm[0][args.channel_num].numpy()) * 255).astype(np.uint8) + fm = cv2.resize(fm, (img.shape[1], img.shape[0])) + if args.save_path is not None: + print("the feature map is saved in path: {}".format(args.save_path)) + cv2.imwrite(args.save_path, fm) + - #pre_weights_dict = fluid.load_program_state(args.pretrained_model) - with fluid.dygraph.guard(place): - net = ResNet50() - data = preprocess(args.image_file, operators) - data = np.expand_dims(data, axis=0) - data = fluid.dygraph.to_variable(data) - dy_weights_dict = net.state_dict() - pre_weights_dict_new = {} - for key in dy_weights_dict: - weights_name = dy_weights_dict[key].name - pre_weights_dict_new[key] = pre_weights_dict[weights_name] - net.set_dict(pre_weights_dict_new) - net.eval() - _, fm = net(data) - assert args.channel_num >= 0 and args.channel_num <= fm.shape[1], "the channel is out of the range, should be in {} but got {}".format([0, fm.shape[1]], args.channel_num) - fm = (np.squeeze(fm[0][args.channel_num].numpy())*255).astype(np.uint8) - if fm is not None: - if args.save: - cv2.imwrite(args.save_path, fm) - if args.show: - cv2.show(fm) - cv2.waitKey(0) - if __name__ == "__main__": main() diff --git a/tools/feature_maps_visualization/resnet.py b/tools/feature_maps_visualization/resnet.py index d3f230da6..b0171f849 100644 --- a/tools/feature_maps_visualization/resnet.py +++ b/tools/feature_maps_visualization/resnet.py @@ -1,20 +1,36 @@ -import numpy as np -import argparse -import ast -import paddle -import paddle.fluid as fluid -from paddle.fluid.param_attr import ParamAttr -from paddle.fluid.layer_helper import LayerHelper -from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear -from paddle.fluid.dygraph.base import to_variable +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -from paddle.fluid import framework +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import paddle +from paddle import ParamAttr +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.nn import Conv2D, BatchNorm, Linear, Dropout +from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D +from paddle.nn.initializer import Uniform import math -import sys -import time -class ConvBNLayer(fluid.dygraph.Layer): +__all__ = ["ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152"] + + +class ConvBNLayer(nn.Layer): def __init__(self, num_channels, num_filters, @@ -26,25 +42,25 @@ class ConvBNLayer(fluid.dygraph.Layer): super(ConvBNLayer, self).__init__() self._conv = Conv2D( - num_channels=num_channels, - num_filters=num_filters, - filter_size=filter_size, + in_channels=num_channels, + out_channels=num_filters, + kernel_size=filter_size, stride=stride, padding=(filter_size - 1) // 2, groups=groups, - act=None, - param_attr=ParamAttr(name=name + "_weights"), + weight_attr=ParamAttr(name=name + "_weights"), bias_attr=False) if name == "conv1": bn_name = "bn_" + name else: bn_name = "bn" + name[3:] - self._batch_norm = BatchNorm(num_filters, - act=act, - param_attr=ParamAttr(name=bn_name + '_scale'), - bias_attr=ParamAttr(bn_name + '_offset'), - moving_mean_name=bn_name + '_mean', - moving_variance_name=bn_name + '_variance') + self._batch_norm = BatchNorm( + num_filters, + act=act, + param_attr=ParamAttr(name=bn_name + "_scale"), + bias_attr=ParamAttr(bn_name + "_offset"), + moving_mean_name=bn_name + "_mean", + moving_variance_name=bn_name + "_variance") def forward(self, inputs): y = self._conv(inputs) @@ -52,7 +68,7 @@ class ConvBNLayer(fluid.dygraph.Layer): return y -class BottleneckBlock(fluid.dygraph.Layer): +class BottleneckBlock(nn.Layer): def __init__(self, num_channels, num_filters, @@ -65,21 +81,21 @@ class BottleneckBlock(fluid.dygraph.Layer): num_channels=num_channels, num_filters=num_filters, filter_size=1, - act='relu', - name=name+"_branch2a") + act="relu", + name=name + "_branch2a") self.conv1 = ConvBNLayer( num_channels=num_filters, num_filters=num_filters, filter_size=3, stride=stride, - act='relu', - name=name+"_branch2b") + act="relu", + name=name + "_branch2b") self.conv2 = ConvBNLayer( num_channels=num_filters, num_filters=num_filters * 4, filter_size=1, act=None, - name=name+"_branch2c") + name=name + "_branch2c") if not shortcut: self.short = ConvBNLayer( @@ -103,90 +119,163 @@ class BottleneckBlock(fluid.dygraph.Layer): else: short = self.short(inputs) - y = fluid.layers.elementwise_add(x=short, y=conv2) - - layer_helper = LayerHelper(self.full_name(), act='relu') - return layer_helper.append_activation(y) + y = paddle.add(x=short, y=conv2) + y = F.relu(y) + return y -class ResNet(fluid.dygraph.Layer): +class BasicBlock(nn.Layer): + def __init__(self, + num_channels, + num_filters, + stride, + shortcut=True, + name=None): + super(BasicBlock, self).__init__() + self.stride = stride + self.conv0 = ConvBNLayer( + num_channels=num_channels, + num_filters=num_filters, + filter_size=3, + stride=stride, + act="relu", + name=name + "_branch2a") + self.conv1 = ConvBNLayer( + num_channels=num_filters, + num_filters=num_filters, + filter_size=3, + act=None, + name=name + "_branch2b") + + if not shortcut: + self.short = ConvBNLayer( + num_channels=num_channels, + num_filters=num_filters, + filter_size=1, + stride=stride, + name=name + "_branch1") + + self.shortcut = shortcut + + def forward(self, inputs): + y = self.conv0(inputs) + conv1 = self.conv1(y) + + if self.shortcut: + short = inputs + else: + short = self.short(inputs) + y = paddle.add(x=short, y=conv1) + y = F.relu(y) + return y + + +class ResNet(nn.Layer): def __init__(self, layers=50, class_dim=1000): super(ResNet, self).__init__() self.layers = layers - supported_layers = [50, 101, 152] + supported_layers = [18, 34, 50, 101, 152] assert layers in supported_layers, \ - "supported layers are {} but input layer is {}".format(supported_layers, layers) - self.fm = None + "supported layers are {} but input layer is {}".format( + supported_layers, layers) - if layers == 50: + if layers == 18: + depth = [2, 2, 2, 2] + elif layers == 34 or layers == 50: depth = [3, 4, 6, 3] elif layers == 101: depth = [3, 4, 23, 3] elif layers == 152: depth = [3, 8, 36, 3] - num_channels = [64, 256, 512, 1024] + num_channels = [64, 256, 512, + 1024] if layers >= 50 else [64, 64, 128, 256] num_filters = [64, 128, 256, 512] + self.feature_map = None + self.conv = ConvBNLayer( num_channels=3, num_filters=64, filter_size=7, stride=2, - act='relu', + act="relu", name="conv1") - self.pool2d_max = Pool2D( - pool_size=3, - pool_stride=2, - pool_padding=1, - pool_type='max') + self.pool2d_max = MaxPool2D(kernel_size=3, stride=2, padding=1) - self.bottleneck_block_list = [] - for block in range(len(depth)): - shortcut = False - for i in range(depth[block]): - if layers in [101, 152] and block == 2: - if i == 0: - conv_name="res"+str(block+2)+"a" + self.block_list = [] + if layers >= 50: + for block in range(len(depth)): + shortcut = False + for i in range(depth[block]): + if layers in [101, 152] and block == 2: + if i == 0: + conv_name = "res" + str(block + 2) + "a" + else: + conv_name = "res" + str(block + 2) + "b" + str(i) else: - conv_name="res"+str(block+2)+"b"+str(i) - else: - conv_name="res"+str(block+2)+chr(97+i) - bottleneck_block = self.add_sublayer( - 'bb_%d_%d' % (block, i), - BottleneckBlock( - num_channels=num_channels[block] - if i == 0 else num_filters[block] * 4, - num_filters=num_filters[block], - stride=2 if i == 0 and block != 0 else 1, - shortcut=shortcut, - name=conv_name)) - self.bottleneck_block_list.append(bottleneck_block) - shortcut = True + conv_name = "res" + str(block + 2) + chr(97 + i) + bottleneck_block = self.add_sublayer( + conv_name, + BottleneckBlock( + num_channels=num_channels[block] + if i == 0 else num_filters[block] * 4, + num_filters=num_filters[block], + stride=2 if i == 0 and block != 0 else 1, + shortcut=shortcut, + name=conv_name)) + self.block_list.append(bottleneck_block) + shortcut = True + else: + for block in range(len(depth)): + shortcut = False + for i in range(depth[block]): + conv_name = "res" + str(block + 2) + chr(97 + i) + basic_block = self.add_sublayer( + conv_name, + BasicBlock( + num_channels=num_channels[block] + if i == 0 else num_filters[block], + num_filters=num_filters[block], + stride=2 if i == 0 and block != 0 else 1, + shortcut=shortcut, + name=conv_name)) + self.block_list.append(basic_block) + shortcut = True - self.pool2d_avg = Pool2D( - pool_size=7, pool_type='avg', global_pooling=True) + self.pool2d_avg = AdaptiveAvgPool2D(1) - self.pool2d_avg_output = num_filters[len(num_filters) - 1] * 4 * 1 * 1 + self.pool2d_avg_channels = num_channels[-1] * 2 - stdv = 1.0 / math.sqrt(2048 * 1.0) + stdv = 1.0 / math.sqrt(self.pool2d_avg_channels * 1.0) - self.out = Linear(self.pool2d_avg_output, - class_dim, - param_attr=ParamAttr( - initializer=fluid.initializer.Uniform(-stdv, stdv), name="fc_0.w_0"), - bias_attr=ParamAttr(name="fc_0.b_0")) + self.out = Linear( + self.pool2d_avg_channels, + class_dim, + weight_attr=ParamAttr( + initializer=Uniform(-stdv, stdv), name="fc_0.w_0"), + bias_attr=ParamAttr(name="fc_0.b_0")) def forward(self, inputs): y = self.conv(inputs) y = self.pool2d_max(y) - self.fm = y - for bottleneck_block in self.bottleneck_block_list: - y = bottleneck_block(y) + self.feature_map = y + for block in self.block_list: + y = block(y) y = self.pool2d_avg(y) - y = fluid.layers.reshape(y, shape=[-1, self.pool2d_avg_output]) + y = paddle.reshape(y, shape=[-1, self.pool2d_avg_channels]) y = self.out(y) - return y, self.fm + return y, self.feature_map + + +def ResNet18(**args): + model = ResNet(layers=18, **args) + return model + + +def ResNet34(**args): + model = ResNet(layers=34, **args) + return model def ResNet50(**args): @@ -202,14 +291,3 @@ def ResNet101(**args): def ResNet152(**args): model = ResNet(layers=152, **args) return model - - -if __name__ == "__main__": - import numpy as np - place = fluid.CPUPlace() - with fluid.dygraph.guard(place): - model = ResNet50() - img = np.random.uniform(0, 255, [1, 3, 224, 224]).astype('float32') - img = fluid.dygraph.to_variable(img) - res = model(img) - print(res.shape)