replace vector_search with faiss
parent
800487e416
commit
687c13522e
|
@ -28,11 +28,11 @@ RecPostProcess: null
|
|||
|
||||
# indexing engine config
|
||||
IndexProcess:
|
||||
index_path: "./recognition_demo_data_v1.0/gallery_cartoon/index/"
|
||||
index_method: "HNSW32" # supported: HNSW32, IVF, Flat
|
||||
index_dir: "./recognition_demo_data_v1.0/gallery_cartoon/index/"
|
||||
image_root: "./recognition_demo_data_v1.0/gallery_cartoon/"
|
||||
data_file: "./recognition_demo_data_v1.0/gallery_cartoon/data_file.txt"
|
||||
append_index: False
|
||||
index_operation: "new" # suported: "append", "remove", "new"
|
||||
delimiter: "\t"
|
||||
dist_type: "IP"
|
||||
pq_size: 100
|
||||
embedding_size: 2048
|
||||
|
|
|
@ -26,11 +26,11 @@ RecPostProcess: null
|
|||
|
||||
# indexing engine config
|
||||
IndexProcess:
|
||||
index_path: "./recognition_demo_data_v1.0/gallery_logo/index/"
|
||||
index_method: "HNSW32" # supported: HNSW32, IVF, Flat
|
||||
index_dir: "./recognition_demo_data_v1.0/gallery_logo/index/"
|
||||
image_root: "./recognition_demo_data_v1.0/gallery_logo/"
|
||||
data_file: "./recognition_demo_data_v1.0/gallery_logo/data_file.txt"
|
||||
append_index: False
|
||||
index_operation: "new" # suported: "append", "remove", "new"
|
||||
delimiter: "\t"
|
||||
dist_type: "IP"
|
||||
pq_size: 100
|
||||
embedding_size: 512
|
||||
|
|
|
@ -26,11 +26,11 @@ RecPostProcess: null
|
|||
|
||||
# indexing engine config
|
||||
IndexProcess:
|
||||
index_path: "./recognition_demo_data_v1.0/gallery_product/index"
|
||||
index_method: "HNSW32" # supported: HNSW32, IVF, Flat
|
||||
index_dir: "./recognition_demo_data_v1.0/gallery_product/index"
|
||||
image_root: "./recognition_demo_data_v1.0/gallery_product/"
|
||||
data_file: "./recognition_demo_data_v1.0/gallery_product/data_file.txt"
|
||||
append_index: False
|
||||
index_operation: "new" # suported: "append", "remove", "new"
|
||||
delimiter: "\t"
|
||||
dist_type: "IP"
|
||||
pq_size: 100
|
||||
embedding_size: 512
|
||||
|
|
|
@ -26,11 +26,11 @@ RecPostProcess: null
|
|||
|
||||
# indexing engine config
|
||||
IndexProcess:
|
||||
index_path: "./recognition_demo_data_v1.0/gallery_vehicle/index/"
|
||||
index_method: "HNSW32" # supported: HNSW32, IVF, Flat
|
||||
index_dir: "./recognition_demo_data_v1.0/gallery_vehicle/index/"
|
||||
image_root: "./recognition_demo_data_v1.0/gallery_vehicle/"
|
||||
data_file: "./recognition_demo_data_v1.0/gallery_vehicle/data_file.txt"
|
||||
append_index: False
|
||||
index_operation: "new" # suported: "append", "remove", "new"
|
||||
delimiter: "\t"
|
||||
dist_type: "IP"
|
||||
pq_size: 100
|
||||
embedding_size: 512
|
||||
|
|
|
@ -51,8 +51,6 @@ RecPreProcess:
|
|||
RecPostProcess: null
|
||||
|
||||
IndexProcess:
|
||||
index_path: "./recognition_demo_data_v1.0/gallery_cartoon/index/"
|
||||
search_budget: 100
|
||||
index_dir: "./recognition_demo_data_v1.0/gallery_cartoon/index/"
|
||||
return_k: 5
|
||||
dist_type: "IP"
|
||||
score_thres: 0.5
|
||||
|
|
|
@ -50,8 +50,6 @@ RecPostProcess: null
|
|||
|
||||
# indexing engine config
|
||||
IndexProcess:
|
||||
index_path: "./recognition_demo_data_v1.0/gallery_logo/index/"
|
||||
search_budget: 100
|
||||
index_dir: "./recognition_demo_data_v1.0/gallery_logo/index/"
|
||||
return_k: 5
|
||||
dist_type: "IP"
|
||||
score_thres: 0.5
|
||||
|
|
|
@ -50,8 +50,6 @@ RecPostProcess: null
|
|||
|
||||
# indexing engine config
|
||||
IndexProcess:
|
||||
index_path: "./recognition_demo_data_v1.0/gallery_product/index"
|
||||
search_budget: 100
|
||||
index_dir: "./recognition_demo_data_v1.0/gallery_product/index"
|
||||
return_k: 5
|
||||
dist_type: "IP"
|
||||
score_thres: 0.5
|
||||
|
|
|
@ -52,8 +52,6 @@ RecPostProcess: null
|
|||
|
||||
# indexing engine config
|
||||
IndexProcess:
|
||||
index_path: "./recognition_demo_data_v1.0/gallery_vehicle/index/"
|
||||
search_budget: 100
|
||||
index_dir: "./recognition_demo_data_v1.0/gallery_vehicle/index/"
|
||||
return_k: 5
|
||||
dist_type: "IP"
|
||||
score_thres: 0.5
|
||||
|
|
|
@ -17,13 +17,13 @@ import sys
|
|||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
|
||||
|
||||
import copy
|
||||
import cv2
|
||||
import faiss
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
import pickle
|
||||
|
||||
from python.predict_rec import RecPredictor
|
||||
from vector_search import Graph_Index
|
||||
|
||||
from utils import logger
|
||||
from utils import config
|
||||
|
@ -31,9 +31,9 @@ from utils import config
|
|||
|
||||
def split_datafile(data_file, image_root, delimiter="\t"):
|
||||
'''
|
||||
data_file: image path and info, which can be splitted by spacer
|
||||
data_file: image path and info, which can be splitted by spacer
|
||||
image_root: image path root
|
||||
delimiter: delimiter
|
||||
delimiter: delimiter
|
||||
'''
|
||||
gallery_images = []
|
||||
gallery_docs = []
|
||||
|
@ -45,9 +45,8 @@ def split_datafile(data_file, image_root, delimiter="\t"):
|
|||
assert text_num >= 2, f"line({ori_line}) must be splitted into at least 2 parts, but got {text_num}"
|
||||
image_file = os.path.join(image_root, line[0])
|
||||
|
||||
image_doc = line[1]
|
||||
gallery_images.append(image_file)
|
||||
gallery_docs.append(image_doc)
|
||||
gallery_docs.append(ori_line.strip())
|
||||
|
||||
return gallery_images, gallery_docs
|
||||
|
||||
|
@ -64,9 +63,77 @@ class GalleryBuilder(object):
|
|||
'''
|
||||
build index from scratch
|
||||
'''
|
||||
operation_method = config.get("index_operation", "new").lower()
|
||||
|
||||
gallery_images, gallery_docs = split_datafile(
|
||||
config['data_file'], config['image_root'], config['delimiter'])
|
||||
if operation_method != "remove":
|
||||
gallery_features = self._extract_features(gallery_images, config)
|
||||
|
||||
assert operation_method in [
|
||||
"new", "remove", "append"
|
||||
], "Only append, remove and new operation are supported"
|
||||
if operation_method in ["remove", "append"]:
|
||||
assert os.path.join(
|
||||
config["index_dir"], "vector.index"
|
||||
), "The vector.index dose not exist in {} when 'index_operation' is not None".format(
|
||||
config["index_dir"])
|
||||
assert os.path.join(
|
||||
config["index_dir"], "id_map.pkl"
|
||||
), "The id_map.pkl dose not exist in {} when 'index_operation' is not None".format(
|
||||
config["index_dir"])
|
||||
index = faiss.read_index(
|
||||
os.path.join(config["index_dir"], "vector.index"))
|
||||
with open(os.path.join(config["index_dir"], "id_map.pkl"),
|
||||
'rb') as fd:
|
||||
ids = pickle.load(fd)
|
||||
assert index.ntotal == len(ids.keys(
|
||||
)), "data number in index is not equal in in id_map"
|
||||
else:
|
||||
if not os.path.exists(config["index_dir"]):
|
||||
os.makedirs(config["index_dir"], exist_ok=True)
|
||||
index_method = config.get("index_method", "HNSW32")
|
||||
if index_method == "IVF":
|
||||
index_method = index_method + str(
|
||||
min(int(len(gallery_images) // 8), 65536)) + ",Flat"
|
||||
dist_type = faiss.METRIC_INNER_PRODUCT if config[
|
||||
"dist_type"] == "IP" else faiss.METRIC_L2
|
||||
index = faiss.index_factory(config["embedding_size"], index_method,
|
||||
dist_type)
|
||||
index = faiss.IndexIDMap2(index)
|
||||
ids = {}
|
||||
|
||||
if config["index_method"] == "HNSW32":
|
||||
logger.warning(
|
||||
"The HNSW32 method dose not support 'remove' operation")
|
||||
|
||||
if operation_method != "remove":
|
||||
start_id = max(ids.keys()) + 1 if ids else 0
|
||||
ids_now = np.arange(0, len(gallery_images)) + start_id
|
||||
if operation_method == "new":
|
||||
index.train(gallery_features)
|
||||
index.add_with_ids(gallery_features, ids_now)
|
||||
|
||||
for i, d in zip(list(ids_now), gallery_docs):
|
||||
ids[i] = d
|
||||
else:
|
||||
if config["index_method"] == "HNSW32":
|
||||
raise RuntimeError(
|
||||
"The index_method: HNSW32 dose not support 'remove' operation"
|
||||
)
|
||||
remove_ids = list(
|
||||
filter(lambda k: ids.get(k) in gallery_docs, ids.keys()))
|
||||
remove_ids = np.asarray(remove_ids)
|
||||
index.remove_ids(remove_ids)
|
||||
for k in remove_ids:
|
||||
del ids[k]
|
||||
|
||||
faiss.write_index(index,
|
||||
os.path.join(config["index_dir"], "vector.index"))
|
||||
with open(os.path.join(config["index_dir"], "id_map.pkl"), 'wb') as fd:
|
||||
pickle.dump(ids, fd)
|
||||
|
||||
def _extract_features(self, gallery_images, config):
|
||||
# extract gallery features
|
||||
gallery_features = np.zeros(
|
||||
[len(gallery_images), config['embedding_size']], dtype=np.float32)
|
||||
|
@ -91,19 +158,11 @@ class GalleryBuilder(object):
|
|||
rec_feat = self.rec_predictor.predict(batch_img)
|
||||
gallery_features[-len(batch_img):, :] = rec_feat
|
||||
batch_img = []
|
||||
|
||||
# train index
|
||||
self.Searcher = Graph_Index(dist_type=config['dist_type'])
|
||||
self.Searcher.build(
|
||||
gallery_vectors=gallery_features,
|
||||
gallery_docs=gallery_docs,
|
||||
pq_size=config['pq_size'],
|
||||
index_path=config['index_path'],
|
||||
append_index=config["append_index"])
|
||||
return gallery_features
|
||||
|
||||
|
||||
def main(config):
|
||||
system_builder = GalleryBuilder(config)
|
||||
GalleryBuilder(config)
|
||||
return
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
|
@ -20,10 +20,11 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
|
|||
import copy
|
||||
import cv2
|
||||
import numpy as np
|
||||
import faiss
|
||||
import pickle
|
||||
|
||||
from python.predict_rec import RecPredictor
|
||||
from python.predict_det import DetPredictor
|
||||
from vector_search import Graph_Index
|
||||
|
||||
from utils import logger
|
||||
from utils import config
|
||||
|
@ -40,11 +41,16 @@ class SystemPredictor(object):
|
|||
|
||||
assert 'IndexProcess' in config.keys(), "Index config not found ... "
|
||||
self.return_k = self.config['IndexProcess']['return_k']
|
||||
self.search_budget = self.config['IndexProcess']['search_budget']
|
||||
|
||||
self.Searcher = Graph_Index(
|
||||
dist_type=config['IndexProcess']['dist_type'])
|
||||
self.Searcher.load(config['IndexProcess']['index_path'])
|
||||
index_dir = self.config["IndexProcess"]["index_dir"]
|
||||
assert os.path.exists(os.path.join(
|
||||
index_dir, "vector.index")), "vector.index not found ..."
|
||||
assert os.path.exists(os.path.join(
|
||||
index_dir, "id_map.pkl")), "id_map.pkl not found ... "
|
||||
self.Searcher = faiss.read_index(
|
||||
os.path.join(index_dir, "vector.index"))
|
||||
with open(os.path.join(index_dir, "id_map.pkl"), "rb") as fd:
|
||||
self.id_map = pickle.load(fd)
|
||||
|
||||
def append_self(self, results, shape):
|
||||
results.append({
|
||||
|
@ -98,14 +104,11 @@ class SystemPredictor(object):
|
|||
crop_img = img[ymin:ymax, xmin:xmax, :].copy()
|
||||
rec_results = self.rec_predictor.predict(crop_img)
|
||||
preds["bbox"] = [xmin, ymin, xmax, ymax]
|
||||
scores, docs = self.Searcher.search(
|
||||
query=rec_results,
|
||||
return_k=self.return_k,
|
||||
search_budget=self.search_budget)
|
||||
scores, docs = self.Searcher.search(rec_results, self.return_k)
|
||||
# just top-1 result will be returned for the final
|
||||
if scores[0] >= self.config["IndexProcess"]["score_thres"]:
|
||||
preds["rec_docs"] = docs[0]
|
||||
preds["rec_scores"] = scores[0]
|
||||
if scores[0][0] >= self.config["IndexProcess"]["score_thres"]:
|
||||
preds["rec_docs"] = self.id_map[docs[0][0]].split()[1]
|
||||
preds["rec_scores"] = scores[0][0]
|
||||
output.append(preds)
|
||||
|
||||
# st5: nms to the final results to avoid fetching duplicate results
|
||||
|
|
Loading…
Reference in New Issue