2020-07-04 13:08:48 +08:00
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from resnet import ResNet50
import paddle . fluid as fluid
import numpy as np
import cv2
import utils
import argparse
def parse_args(argv=None):
    """Parse the command-line options of the feature-map visualization script.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, i.e. the original behavior.

    Returns:
        argparse.Namespace with the attributes ``image_file``,
        ``channel_num``, ``pretrained_model``, ``show``, ``interpolation``,
        ``save_path`` and ``use_gpu``.
    """

    def str2bool(v):
        """Convert a command-line string such as "true"/"false" to a bool.

        Unrecognized values raise ``argparse.ArgumentTypeError`` so that
        argparse reports a usage error instead of silently treating the
        value as False.
        """
        value = v.strip().lower()
        if value in ("true", "t", "1", "yes", "y"):
            return True
        if value in ("false", "f", "0", "no", "n"):
            return False
        raise argparse.ArgumentTypeError(
            "expected a boolean value, got {!r}".format(v))

    parser = argparse.ArgumentParser()
    # The image, channel and weights are always consumed by main(); making
    # them required turns a later TypeError on None into a clear usage error.
    parser.add_argument(
        "-i", "--image_file", type=str, required=True,
        help="path of the input image")
    parser.add_argument(
        "-c", "--channel_num", type=int, required=True,
        help="index of the feature-map channel to visualize")
    parser.add_argument(
        "-p", "--pretrained_model", type=str, required=True,
        help="path of the pretrained model weights")
    parser.add_argument(
        "--show", type=str2bool, default=False,
        help="display the feature map in a window")
    # Forwarded to the resize operator; presumably an OpenCV interpolation
    # flag - TODO confirm against utils.ResizeImage.
    parser.add_argument(
        "--interpolation", type=int, default=1,
        help="interpolation flag used when resizing the input image")
    parser.add_argument(
        "--save_path", type=str,
        help="file path to write the feature map image to")
    parser.add_argument(
        "--use_gpu", type=str2bool, default=True,
        help="run on GPU (default) instead of CPU")
    return parser.parse_args(argv)
2020-07-20 12:16:02 +08:00
def create_operators(interpolation=1):
    """Return the ordered list of preprocessing operators for one image.

    Steps, applied in list order: decode the raw bytes, resize with
    ``resize_short=256``, crop to 224x224, normalize (scale by 1/255, then
    the commonly used ImageNet mean/std), and convert to a tensor layout
    via ``utils.ToTensor``.

    Args:
        interpolation: forwarded unchanged to ``utils.ResizeImage``.
    """
    crop_size = 224
    pipeline = [utils.DecodeImage()]
    pipeline.append(
        utils.ResizeImage(resize_short=256, interpolation=interpolation))
    pipeline.append(utils.CropImage(size=(crop_size, crop_size)))
    pipeline.append(
        utils.NormalizeImage(
            scale=1.0 / 255.0,
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]))
    pipeline.append(utils.ToTensor())
    return pipeline
def preprocess(fname, ops):
    """Read ``fname`` as bytes and pass the result through ``ops`` in order.

    Args:
        fname: path of the file to read (opened in binary mode).
        ops: iterable of callables; each receives the previous step's
            output (the raw bytes for the first one).

    Returns:
        The output of the last operator, or the raw bytes if ``ops`` is empty.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(fname, 'rb') as f:
        data = f.read()
    for op in ops:
        data = op(data)
    return data
def main():
    """Visualize one channel of the ResNet50 feature map for a single image.

    Loads pretrained weights from ``--pretrained_model`` into a dygraph
    ResNet50. Preprocesses ``--image_file``, runs the network and takes the
    second output as the feature map. Writes channel ``--channel_num`` as an
    8-bit image to ``--save_path`` (if given) and/or shows it (``--show``).

    Raises:
        ValueError: if ``--channel_num`` is not a valid channel index.
    """
    args = parse_args()
    operators = create_operators(args.interpolation)
    # assign the place
    if args.use_gpu:
        gpu_id = fluid.dygraph.parallel.Env().dev_id
        place = fluid.CUDAPlace(gpu_id)
    else:
        place = fluid.CPUPlace()

    # This load used to be commented out, which left `pre_weights_dict`
    # undefined (NameError) when it is read below.
    pre_weights_dict = fluid.load_program_state(args.pretrained_model)
    with fluid.dygraph.guard(place):
        net = ResNet50()
        data = preprocess(args.image_file, operators)
        data = np.expand_dims(data, axis=0)
        data = fluid.dygraph.to_variable(data)
        # Re-key the pretrained arrays: they are looked up by each
        # parameter's variable name, but set_dict expects state_dict keys.
        dy_weights_dict = net.state_dict()
        pre_weights_dict_new = {}
        for key in dy_weights_dict:
            weights_name = dy_weights_dict[key].name
            pre_weights_dict_new[key] = pre_weights_dict[weights_name]
        net.set_dict(pre_weights_dict_new)
        net.eval()
        _, fm = net(data)
        # NOTE(review): assumes fm is (N, C, H, W), so axis 1 holds the
        # channels - TODO confirm against resnet.py. Valid indices are
        # 0 .. C-1; the old check wrongly accepted C.
        num_channels = fm.shape[1]
        if not 0 <= args.channel_num < num_channels:
            raise ValueError(
                "the channel is out of the range, should be in [0, {}) "
                "but got {}".format(num_channels, args.channel_num))
        fm = np.squeeze(fm[0][args.channel_num].numpy()) * 255
        # Clip before the uint8 cast: activations outside [0, 1] would
        # otherwise overflow/wrap instead of saturating.
        fm = np.clip(fm, 0, 255).astype(np.uint8)

    # The original tested a non-existent `args.save` attribute.
    if args.save_path:
        cv2.imwrite(args.save_path, fm)
    if args.show:
        # cv2 has no `show`; imshow requires a window name.
        cv2.imshow("feature map", fm)
        cv2.waitKey(0)
# Run the visualization only when executed as a script, not on import.
if __name__ == '__main__':
    main()