fix build_gallery bug
parent
9eaa3353af
commit
52ba23c8b9
|
@ -50,8 +50,8 @@ class GalleryBuilder(object):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.rec_predictor = RecPredictor(config)
|
self.rec_predictor = RecPredictor(config)
|
||||||
assert 'IndexProcess' in config.keys(), "Index config not found ... "
|
assert 'IndexProcess' in config.keys(), "Index config not found ... "
|
||||||
self.build(config['IndexProcess'])
|
|
||||||
self.android_demo = config["Global"].get("android_demo", False)
|
self.android_demo = config["Global"].get("android_demo", False)
|
||||||
|
self.build(config['IndexProcess'])
|
||||||
|
|
||||||
def build(self, config):
|
def build(self, config):
|
||||||
'''
|
'''
|
||||||
|
@ -78,7 +78,8 @@ class GalleryBuilder(object):
|
||||||
index, ids = None, None
|
index, ids = None, None
|
||||||
if operation_method in ["remove", "append"]:
|
if operation_method in ["remove", "append"]:
|
||||||
# if remove or append, load vector.index and id_map.pkl
|
# if remove or append, load vector.index and id_map.pkl
|
||||||
index, ids = self._load_index()
|
index, ids = self._load_index(config)
|
||||||
|
index_method = config.get("index_method", "HNSW32")
|
||||||
else:
|
else:
|
||||||
index_method, index, ids = self._create_index(config)
|
index_method, index, ids = self._create_index(config)
|
||||||
if index_method == "HNSW32":
|
if index_method == "HNSW32":
|
||||||
|
@ -87,7 +88,7 @@ class GalleryBuilder(object):
|
||||||
|
|
||||||
if operation_method != "remove":
|
if operation_method != "remove":
|
||||||
# calculate id for new data
|
# calculate id for new data
|
||||||
index, ids = self._add_gallery(index, ids, gallery_features, gallery_docs)
|
index, ids = self._add_gallery(index, ids, gallery_features, gallery_docs, config, operation_method)
|
||||||
else:
|
else:
|
||||||
if index_method == "HNSW32":
|
if index_method == "HNSW32":
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
@ -104,14 +105,11 @@ class GalleryBuilder(object):
|
||||||
os.makedirs(config["index_dir"], exist_ok=True)
|
os.makedirs(config["index_dir"], exist_ok=True)
|
||||||
#build index
|
#build index
|
||||||
index = faiss.IndexFlatIP(config["embedding_size"])
|
index = faiss.IndexFlatIP(config["embedding_size"])
|
||||||
ids = {}
|
|
||||||
|
|
||||||
# calculate id for new data
|
|
||||||
ids_now = (
|
|
||||||
np.arange(0, len(gallery_images))).astype(np.int64)
|
|
||||||
|
|
||||||
index.add(gallery_features)
|
index.add(gallery_features)
|
||||||
|
|
||||||
|
# calculate id for data
|
||||||
|
ids_now = (np.arange(0, len(gallery_docs))).astype(np.int64)
|
||||||
|
ids = {}
|
||||||
for i, d in zip(list(ids_now), gallery_docs):
|
for i, d in zip(list(ids_now), gallery_docs):
|
||||||
ids[i] = d
|
ids[i] = d
|
||||||
self._save_gallery(config, index, ids)
|
self._save_gallery(config, index, ids)
|
||||||
|
@ -197,7 +195,7 @@ class GalleryBuilder(object):
|
||||||
ids = {}
|
ids = {}
|
||||||
return index_method, index, ids
|
return index_method, index, ids
|
||||||
|
|
||||||
def _add_gallery(self, index, ids, gallery_features, gallery_docs):
|
def _add_gallery(self, index, ids, gallery_features, gallery_docs, config, operation_method):
|
||||||
start_id = max(ids.keys()) + 1 if ids else 0
|
start_id = max(ids.keys()) + 1 if ids else 0
|
||||||
ids_now = (
|
ids_now = (
|
||||||
np.arange(0, len(gallery_docs)) + start_id).astype(np.int64)
|
np.arange(0, len(gallery_docs)) + start_id).astype(np.int64)
|
||||||
|
@ -216,7 +214,7 @@ class GalleryBuilder(object):
|
||||||
ids[i] = d
|
ids[i] = d
|
||||||
return index, ids
|
return index, ids
|
||||||
|
|
||||||
def _rm_id_in_galllery(self, index, ids, gallery_docs)
|
def _rm_id_in_galllery(self, index, ids, gallery_docs):
|
||||||
remove_ids = list(
|
remove_ids = list(
|
||||||
filter(lambda k: ids.get(k) in gallery_docs, ids.keys()))
|
filter(lambda k: ids.get(k) in gallery_docs, ids.keys()))
|
||||||
remove_ids = np.asarray(remove_ids)
|
remove_ids = np.asarray(remove_ids)
|
||||||
|
|
Loading…
Reference in New Issue