diff --git a/deploy/configs/inference_rec.yaml b/deploy/configs/inference_rec.yaml index 2d864493e..ba64b6d8e 100644 --- a/deploy/configs/inference_rec.yaml +++ b/deploy/configs/inference_rec.yaml @@ -49,4 +49,21 @@ RecPreProcess: order: '' - ToCHWImage: -RecPostProcess: null \ No newline at end of file +RecPostProcess: null + + +# indexing engine config +IndexProcess: + build: + enable: True + index_path: "./logo_index/" + image_root: "dataset/LogoDet-3K-crop/train" + data_file: "dataset/LogoDet-3K-crop/LogoDet-3K+train.txt" + spacer: " " + dist_type: "IP" + pq_size: 100 + embedding_size: 1000 + infer: + index_path: "./logo_index/" + search_budget: 100 + return_k: 10 diff --git a/deploy/python/predict_system.py b/deploy/python/predict_system.py index 84e60e806..0460a08ea 100644 --- a/deploy/python/predict_system.py +++ b/deploy/python/predict_system.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,28 +23,78 @@ import numpy as np from python.predict_rec import RecPredictor from python.predict_det import DetPredictor +from vector_search import Graph_Index from utils import logger from utils import config from utils.get_image_list import get_image_list +def split_datafile(data_file, image_root): + gallery_images = [] + gallery_docs = [] + with open(datafile) as f: + lines = f.readlines() + for i, line in enumerate(lines): + line = line.strip().split("\t") + if line[0] == 'image_id': + continue + image_file = os.path.join(image_root, line[3]) + image_doc = line[1] + gallery_images.append(image_file) + gallery_docs.append(image_doc) + return gallery_images, gallery_docs + + class SystemPredictor(object): def __init__(self, config): + + self.config = config self.rec_predictor = RecPredictor(config) self.det_predictor = DetPredictor(config) + assert 'IndexProcess' in config.keys(), "Index config not found ... " + self.indexer(config['IndexProcess']) + self.return_k = self.config['IndexProcess']['infer']['return_k'] + self.search_budget = self.config['IndexProcess']['infer']['search_budget'] + + def indexer(self, config): + if 'build' in config.keys() and config['build']['enable']: # build the index from scratch + with open(config['build']['datafile']) as f: + lines = f.readlines() + gallery_images, gallery_docs = split_datafile(config['build']['data_file'], config['build']['image_root']) + # extract gallery features + gallery_features = np.zeros([len(gallery_images), config['build']['embedding_size']], dtype=np.float32) + for i, image_file in enumerate(gallery_images): + img = cv2.imread(image_file)[:, :, ::-1] + rec_feat = self.rec_predictor.predict(img) + gallery_features[i,:] = rec_feat + # train index + self.Searcher = Graph_Index(dist_type=config['build']['dist_type']) + self.Searcher.build(gallery_vectors=gallery_features, gallery_docs=gallery_docs, + pq_size=config['build']['pq_size'], index_path=config['build']['index_path']) + + else: # load local index + self.Searcher = Graph_Index(dist_type=config['build']['dist_type']) + self.Searcher.load(config['infer']['index_path']) + def predict(self, img): output = [] results = self.det_predictor.predict(img) for result in results: - print(result) + #print(result) xmin, xmax, ymin, ymax = result["bbox"].astype("int") crop_img = img[xmin:xmax, ymin:ymax, :].copy() rec_results = self.rec_predictor.predict(crop_img) - result["feature"] = rec_results + result["featrue"] = rec_results + + scores, docs = self.Searcher.search(query=rec_results, return_k=self.return_k, search_budget=self.search_budget) + result["ret_docs"] = docs + result["ret_scores"] = scores + output.append(result) return output + def main(config): @@ -55,7 +105,7 @@ def main(config): for idx, image_file in enumerate(image_list): img = cv2.imread(image_file)[:, :, ::-1] output = system_predictor.predict(img) - print(output) + #print(output) return diff --git a/deploy/shell/predict.sh b/deploy/shell/predict.sh index 6bbacabaf..990a5703e 100644 --- a/deploy/shell/predict.sh +++ b/deploy/shell/predict.sh @@ -7,5 +7,5 @@ python3.7 python/predict_cls.py -c configs/inference_cls.yaml # detection # python3.7 python/predict_det.py -c configs/inference_rec.yaml -# mainbody detection + feature extractor -# python3.7 python/predict_system.py -c configs/inference_rec.yaml \ No newline at end of file +# mainbody detection + feature extractor + retrieval +# python3.7 python/predict_system.py -c configs/inference_rec.yaml diff --git a/deploy/vector_search/interface.py b/deploy/vector_search/interface.py index bbb460b9d..51ef7bdd2 100644 --- a/deploy/vector_search/interface.py +++ b/deploy/vector_search/interface.py @@ -22,7 +22,9 @@ import json from ctypes import * from numpy.ctypeslib import ndpointer -lib = ctypes.cdll.LoadLibrary("./index.so") +__dir__ = os.path.dirname(os.path.abspath(__file__)) +so_path = os.path.join(__dir__, "index.so") +lib = ctypes.cdll.LoadLibrary(so_path) class IndexContext(Structure): _fields_=[("graph",c_void_p),