# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
# Make this directory and its parent importable so that `ppcls` resolves
# when the script is launched directly.
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))

from ppcls.utils import logger
import cv2
import time
import requests
import json
import base64
# NOTE: `imghdr` is deprecated since Python 3.11 and removed in 3.13.
import imghdr


def get_image_file_list(img_file):
    """Collect the image files designated by ``img_file``.

    Args:
        img_file: Path to a single image file or to a directory whose
            direct children are inspected (no recursion).

    Returns:
        list[str]: Paths whose content ``imghdr`` recognises as one of the
        accepted image formats.

    Raises:
        Exception: If ``img_file`` is ``None``, does not exist, or no image
            file is found.
    """
    if img_file is None or not os.path.exists(img_file):
        raise Exception("not found any img file in {}".format(img_file))

    img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'GIF'}
    imgs_lists = []
    if os.path.isfile(img_file):
        if imghdr.what(img_file) in img_end:
            imgs_lists.append(img_file)
    elif os.path.isdir(img_file):
        for single_file in os.listdir(img_file):
            file_path = os.path.join(img_file, single_file)
            # imghdr.what() opens the path, so sub-directories must be
            # skipped or it raises instead of returning None.
            if not os.path.isfile(file_path):
                continue
            if imghdr.what(file_path) in img_end:
                imgs_lists.append(file_path)

    if len(imgs_lists) == 0:
        raise Exception("not found any img file in {}".format(img_file))
    return imgs_lists


def cv2_to_base64(image):
    """Return the base64 encoding of the bytes-like ``image`` as a str."""
    return base64.b64encode(image).decode('utf8')


def main(url, image_path, top_k=1):
    """Send every image under ``image_path`` to the server and log results.

    Each image is posted as a JSON body with the keys ``images`` (a list
    holding the base64-encoded file content) and ``top_k``. The reply is
    read as JSON with a ``status`` field, a ``msg`` field used when the
    status is not ``'000'``, and a ``results`` list whose first item unpacks
    into ``(classes, scores, elapse)``.

    Images that cannot be read, requests that fail, and replies with a
    non-``'000'`` status are logged and skipped.

    Args:
        url: Address of the prediction service.
        image_path: Image file or directory of images to classify.
        top_k: Value forwarded to the server as ``top_k``.

    Raises:
        Exception: Propagated from ``get_image_file_list`` when no image is
            found.
    """
    image_file_list = get_image_file_list(image_path)
    headers = {"Content-type": "application/json"}
    cnt = 0
    total_time = 0
    all_acc = 0.0

    for image_file in image_file_list:
        file_str = os.path.basename(image_file)

        # Read inside a context manager so the handle is always closed; an
        # unreadable file is reported and skipped rather than aborting.
        try:
            with open(image_file, 'rb') as f:
                img = f.read()
        except OSError:
            logger.error("Loading image:{} failed".format(image_file))
            continue

        data = {'images': [cv2_to_base64(img)], 'top_k': top_k}
        try:
            r = requests.post(url=url, headers=headers, data=json.dumps(data))
            r.raise_for_status()
            # Decode once, inside the try, so a non-JSON reply is logged
            # like any other request failure.
            response = r.json()
        except Exception as e:
            logger.error("File:{}, {}".format(file_str, e))
            continue

        if response['status'] != '000':
            logger.error(
                "File:{}, The parameters returned by the server are: {}".
                format(file_str, response['msg']))
            continue

        classes, scores, elapse = response["results"][0]
        all_acc += scores[0]
        total_time += elapse
        cnt += 1

        rounded_scores = [round(score, 5) for score in scores]
        results = dict(zip(classes, rounded_scores))
        message = "No.{}, File:{}, The top-{} result(s):{}, Time cost:{:.3f}".format(
            cnt, file_str, top_k, results, elapse)
        logger.info(message)

    # Without this guard the averages below divide by zero when every
    # image was skipped.
    if cnt == 0:
        logger.error("No image was processed successfully.")
        return

    logger.info("The average time cost: {}".format(float(total_time) / cnt))
    logger.info("The average top-1 score: {}".format(float(all_acc) / cnt))


if __name__ == '__main__':
    if len(sys.argv) != 3 and len(sys.argv) != 4:
        logger.info("Usage: %s server_url image_path [top_k]" % sys.argv[0])
    else:
        server_url = sys.argv[1]
        image_path = sys.argv[2]
        top_k = int(sys.argv[3]) if len(sys.argv) == 4 else 1
        main(server_url, image_path, top_k)