fix build_gallery bug

pull/2236/head
dongshuilong 2022-08-29 19:17:13 +08:00
parent 9eaa3353af
commit 52ba23c8b9
1 changed files with 9 additions and 11 deletions
deploy/python

View File

@ -50,8 +50,8 @@ class GalleryBuilder(object):
self.config = config
self.rec_predictor = RecPredictor(config)
assert 'IndexProcess' in config.keys(), "Index config not found ... "
self.build(config['IndexProcess'])
self.android_demo = config["Global"].get("android_demo", False)
self.build(config['IndexProcess'])
def build(self, config):
'''
@ -78,7 +78,8 @@ class GalleryBuilder(object):
index, ids = None, None
if operation_method in ["remove", "append"]:
# if remove or append, load vector.index and id_map.pkl
index, ids = self._load_index()
index, ids = self._load_index(config)
index_method = config.get("index_method", "HNSW32")
else:
index_method, index, ids = self._create_index(config)
if index_method == "HNSW32":
@ -87,7 +88,7 @@ class GalleryBuilder(object):
if operation_method != "remove":
# calculate id for new data
index, ids = self._add_gallery(index, ids, gallery_features, gallery_docs)
index, ids = self._add_gallery(index, ids, gallery_features, gallery_docs, config, operation_method)
else:
if index_method == "HNSW32":
raise RuntimeError(
@ -104,14 +105,11 @@ class GalleryBuilder(object):
os.makedirs(config["index_dir"], exist_ok=True)
#build index
index = faiss.IndexFlatIP(config["embedding_size"])
ids = {}
# calculate id for new data
ids_now = (
np.arange(0, len(gallery_images))).astype(np.int64)
index.add(gallery_features)
# calculate id for data
ids_now = (np.arange(0, len(gallery_docs))).astype(np.int64)
ids = {}
for i, d in zip(list(ids_now), gallery_docs):
ids[i] = d
self._save_gallery(config, index, ids)
@ -197,7 +195,7 @@ class GalleryBuilder(object):
ids = {}
return index_method, index, ids
def _add_gallery(self, index, ids, gallery_features, gallery_docs):
def _add_gallery(self, index, ids, gallery_features, gallery_docs, config, operation_method):
start_id = max(ids.keys()) + 1 if ids else 0
ids_now = (
np.arange(0, len(gallery_docs)) + start_id).astype(np.int64)
@ -216,7 +214,7 @@ class GalleryBuilder(object):
ids[i] = d
return index, ids
def _rm_id_in_galllery(self, index, ids, gallery_docs)
def _rm_id_in_galllery(self, index, ids, gallery_docs):
remove_ids = list(
filter(lambda k: ids.get(k) in gallery_docs, ids.keys()))
remove_ids = np.asarray(remove_ids)