mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
update build_gallery and add android demo index support
This commit is contained in:
parent
291015f459
commit
9eaa3353af
@ -12,16 +12,14 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import os
|
import os
|
||||||
|
import pickle
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import faiss
|
import faiss
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from paddleclas.deploy.python.predict_rec import RecPredictor
|
||||||
|
from paddleclas.deploy.utils import config, logger
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import pickle
|
|
||||||
|
|
||||||
from paddleclas.deploy.utils import logger, config
|
|
||||||
from paddleclas.deploy.python.predict_rec import RecPredictor
|
|
||||||
from paddleclas.deploy.python.predict_rec import RecPredictor
|
|
||||||
|
|
||||||
|
|
||||||
def split_datafile(data_file, image_root, delimiter="\t"):
|
def split_datafile(data_file, image_root, delimiter="\t"):
|
||||||
@ -53,6 +51,7 @@ class GalleryBuilder(object):
|
|||||||
self.rec_predictor = RecPredictor(config)
|
self.rec_predictor = RecPredictor(config)
|
||||||
assert 'IndexProcess' in config.keys(), "Index config not found ... "
|
assert 'IndexProcess' in config.keys(), "Index config not found ... "
|
||||||
self.build(config['IndexProcess'])
|
self.build(config['IndexProcess'])
|
||||||
|
self.android_demo = config["Global"].get("android_demo", False)
|
||||||
|
|
||||||
def build(self, config):
|
def build(self, config):
|
||||||
'''
|
'''
|
||||||
@ -70,98 +69,52 @@ class GalleryBuilder(object):
|
|||||||
"new", "remove", "append"
|
"new", "remove", "append"
|
||||||
], "Only append, remove and new operation are supported"
|
], "Only append, remove and new operation are supported"
|
||||||
|
|
||||||
|
if self.android_demo:
|
||||||
|
self._create_index_for_android_demo(config, gallery_features, gallery_docs)
|
||||||
|
return
|
||||||
|
|
||||||
# vector.index: faiss index file
|
# vector.index: faiss index file
|
||||||
# id_map.pkl: use this file to map id to image_doc
|
# id_map.pkl: use this file to map id to image_doc
|
||||||
|
index, ids = None, None
|
||||||
if operation_method in ["remove", "append"]:
|
if operation_method in ["remove", "append"]:
|
||||||
# if remove or append, vector.index and id_map.pkl must exist
|
# if remove or append, load vector.index and id_map.pkl
|
||||||
assert os.path.join(
|
index, ids = self._load_index()
|
||||||
config["index_dir"], "vector.index"
|
|
||||||
), "The vector.index dose not exist in {} when 'index_operation' is not None".format(
|
|
||||||
config["index_dir"])
|
|
||||||
assert os.path.join(
|
|
||||||
config["index_dir"], "id_map.pkl"
|
|
||||||
), "The id_map.pkl dose not exist in {} when 'index_operation' is not None".format(
|
|
||||||
config["index_dir"])
|
|
||||||
index = faiss.read_index(
|
|
||||||
os.path.join(config["index_dir"], "vector.index"))
|
|
||||||
with open(os.path.join(config["index_dir"], "id_map.pkl"),
|
|
||||||
'rb') as fd:
|
|
||||||
ids = pickle.load(fd)
|
|
||||||
assert index.ntotal == len(ids.keys(
|
|
||||||
)), "data number in index is not equal in in id_map"
|
|
||||||
else:
|
else:
|
||||||
if not os.path.exists(config["index_dir"]):
|
index_method, index, ids = self._create_index(config)
|
||||||
os.makedirs(config["index_dir"], exist_ok=True)
|
if index_method == "HNSW32":
|
||||||
index_method = config.get("index_method", "HNSW32")
|
|
||||||
|
|
||||||
# if IVF method, cal ivf number automaticlly
|
|
||||||
if index_method == "IVF":
|
|
||||||
index_method = index_method + str(
|
|
||||||
min(int(len(gallery_images) // 8), 65536)) + ",Flat"
|
|
||||||
|
|
||||||
# for binary index, add B at head of index_method
|
|
||||||
if config["dist_type"] == "hamming":
|
|
||||||
index_method = "B" + index_method
|
|
||||||
|
|
||||||
#dist_type
|
|
||||||
dist_type = faiss.METRIC_INNER_PRODUCT if config[
|
|
||||||
"dist_type"] == "IP" else faiss.METRIC_L2
|
|
||||||
|
|
||||||
#build index
|
|
||||||
if config["dist_type"] == "hamming":
|
|
||||||
index = faiss.index_binary_factory(config["embedding_size"],
|
|
||||||
index_method)
|
|
||||||
else:
|
|
||||||
index = faiss.index_factory(config["embedding_size"],
|
|
||||||
index_method, dist_type)
|
|
||||||
index = faiss.IndexIDMap2(index)
|
|
||||||
ids = {}
|
|
||||||
|
|
||||||
if config["index_method"] == "HNSW32":
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"The HNSW32 method dose not support 'remove' operation")
|
"The HNSW32 method dose not support 'remove' operation")
|
||||||
|
|
||||||
if operation_method != "remove":
|
if operation_method != "remove":
|
||||||
# calculate id for new data
|
# calculate id for new data
|
||||||
start_id = max(ids.keys()) + 1 if ids else 0
|
index, ids = self._add_gallery(index, ids, gallery_features, gallery_docs)
|
||||||
ids_now = (
|
|
||||||
np.arange(0, len(gallery_images)) + start_id).astype(np.int64)
|
|
||||||
|
|
||||||
# only train when new index file
|
|
||||||
if operation_method == "new":
|
|
||||||
if config["dist_type"] == "hamming":
|
|
||||||
index.add(gallery_features)
|
|
||||||
else:
|
else:
|
||||||
index.train(gallery_features)
|
if index_method == "HNSW32":
|
||||||
|
|
||||||
if not config["dist_type"] == "hamming":
|
|
||||||
index.add_with_ids(gallery_features, ids_now)
|
|
||||||
|
|
||||||
for i, d in zip(list(ids_now), gallery_docs):
|
|
||||||
ids[i] = d
|
|
||||||
else:
|
|
||||||
if config["index_method"] == "HNSW32":
|
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"The index_method: HNSW32 dose not support 'remove' operation"
|
"The index_method: HNSW32 dose not support 'remove' operation"
|
||||||
)
|
)
|
||||||
# remove ids in id_map, remove index data in faiss index
|
# remove ids in id_map, remove index data in faiss index
|
||||||
remove_ids = list(
|
index, ids = self._rm_id_in_galllery(index, ids, gallery_docs)
|
||||||
filter(lambda k: ids.get(k) in gallery_docs, ids.keys()))
|
|
||||||
remove_ids = np.asarray(remove_ids)
|
|
||||||
index.remove_ids(remove_ids)
|
|
||||||
for k in remove_ids:
|
|
||||||
del ids[k]
|
|
||||||
|
|
||||||
# store faiss index file and id_map file
|
# store faiss index file and id_map file
|
||||||
if config["dist_type"] == "hamming":
|
self._save_gallery(config, index, ids)
|
||||||
faiss.write_index_binary(
|
|
||||||
index, os.path.join(config["index_dir"], "vector.index"))
|
|
||||||
else:
|
|
||||||
faiss.write_index(
|
|
||||||
index, os.path.join(config["index_dir"], "vector.index"))
|
|
||||||
|
|
||||||
with open(os.path.join(config["index_dir"], "id_map.pkl"), 'wb') as fd:
|
def _create_index_for_android_demo(self, config, gallery_features, gallery_docs):
|
||||||
pickle.dump(ids, fd)
|
if not os.path.exists(config["index_dir"]):
|
||||||
|
os.makedirs(config["index_dir"], exist_ok=True)
|
||||||
|
#build index
|
||||||
|
index = faiss.IndexFlatIP(config["embedding_size"])
|
||||||
|
ids = {}
|
||||||
|
|
||||||
|
# calculate id for new data
|
||||||
|
ids_now = (
|
||||||
|
np.arange(0, len(gallery_images))).astype(np.int64)
|
||||||
|
|
||||||
|
index.add(gallery_features)
|
||||||
|
|
||||||
|
for i, d in zip(list(ids_now), gallery_docs):
|
||||||
|
ids[i] = d
|
||||||
|
self._save_gallery(config, index, ids)
|
||||||
|
|
||||||
def _extract_features(self, gallery_images, config):
|
def _extract_features(self, gallery_images, config):
|
||||||
# extract gallery features
|
# extract gallery features
|
||||||
@ -197,6 +150,93 @@ class GalleryBuilder(object):
|
|||||||
|
|
||||||
return gallery_features
|
return gallery_features
|
||||||
|
|
||||||
|
def _load_index(self, config):
    """Load the saved faiss index and id map from ``config["index_dir"]``.

    Args:
        config: mapping that provides ``"index_dir"``, the directory that
            holds ``vector.index`` and ``id_map.pkl``.

    Returns:
        tuple: ``(index, ids)`` -- the object returned by
        ``faiss.read_index`` and the unpickled id map.

    Raises:
        FileNotFoundError: if ``vector.index`` or ``id_map.pkl`` is missing.
        AssertionError: if ``index.ntotal`` differs from the number of
            entries in the id map.
    """
    index_dir = config["index_dir"]
    index_path = os.path.join(index_dir, "vector.index")
    id_map_path = os.path.join(index_dir, "id_map.pkl")

    # os.path.join() returns a non-empty (always truthy) string, so asserting
    # on it never detects a missing file; the paths have to be probed.
    if not os.path.exists(index_path):
        raise FileNotFoundError(
            "The vector.index dose not exist in {} when 'index_operation' is not None".
            format(index_dir))
    if not os.path.exists(id_map_path):
        raise FileNotFoundError(
            "The id_map.pkl dose not exist in {} when 'index_operation' is not None".
            format(index_dir))

    index = faiss.read_index(index_path)
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file -- only load an id_map.pkl that comes from a trusted source.
    with open(id_map_path, 'rb') as fd:
        ids = pickle.load(fd)

    # the index and the id map are written together and must stay in sync
    assert index.ntotal == len(
        ids), "data number in index is not equal in in id_map"
    return index, ids
|
||||||
|
|
||||||
|
def _create_index(self, config, gallery_size=None):
    """Create an empty faiss index described by ``config``.

    Args:
        config: mapping that provides ``"index_dir"``, ``"dist_type"`` and
            ``"embedding_size"``, and optionally ``"index_method"``
            (defaults to ``"HNSW32"``).
        gallery_size: number of gallery items. Only used by the ``"IVF"``
            method, where it determines the number of inverted lists.

    Returns:
        tuple: ``(index_method, index, ids)`` -- the final factory string,
        the new index wrapped in ``faiss.IndexIDMap2``, and an empty id map.

    Raises:
        ValueError: if ``index_method`` is ``"IVF"`` and ``gallery_size``
            is not given.
    """
    # exist_ok=True already tolerates an existing directory
    os.makedirs(config["index_dir"], exist_ok=True)
    index_method = config.get("index_method", "HNSW32")

    # for the IVF method, derive the number of inverted lists automatically
    if index_method == "IVF":
        if gallery_size is None:
            raise ValueError(
                "gallery_size is required to build an 'IVF' index")
        # at least one list, otherwise the factory string would be "IVF0"
        nlist = max(1, min(int(gallery_size) // 8, 65536))
        index_method = index_method + str(nlist) + ",Flat"

    is_binary = config["dist_type"] == "hamming"

    # for a binary index, add B at the head of index_method
    if is_binary:
        index_method = "B" + index_method

    # build index
    if is_binary:
        index = faiss.index_binary_factory(config["embedding_size"],
                                           index_method)
    else:
        # inner product for "IP", L2 for every other dist_type
        metric = faiss.METRIC_INNER_PRODUCT if config[
            "dist_type"] == "IP" else faiss.METRIC_L2
        index = faiss.index_factory(config["embedding_size"], index_method,
                                    metric)
    index = faiss.IndexIDMap2(index)
    ids = {}
    return index_method, index, ids
|
||||||
|
|
||||||
|
def _add_gallery(self,
                 index,
                 ids,
                 gallery_features,
                 gallery_docs,
                 config=None,
                 operation_method=None):
    """Add gallery features to ``index`` and register their docs in ``ids``.

    New ids continue after the largest key already present in ``ids``
    (starting at 0 for an empty map).

    Args:
        index: index object exposing ``train``, ``add`` and ``add_with_ids``.
        ids: dict mapping id -> doc; updated in place.
        gallery_features: features handed to the index unchanged.
        gallery_docs: one doc per feature; its length fixes how many ids
            are generated.
        config: optional mapping; when it has ``"dist_type" == "hamming"``
            the binary code path is used. Omitted means non-binary.
        operation_method: ``"new"`` when the index was just created. When
            omitted, an empty ``ids`` map is taken to mean a new index.

    Returns:
        tuple: ``(index, ids)``.
    """
    start_id = max(ids.keys()) + 1 if ids else 0
    ids_now = (
        np.arange(0, len(gallery_docs)) + start_id).astype(np.int64)

    is_binary = config is not None and config.get("dist_type") == "hamming"
    if operation_method is None:
        is_new = not ids
    else:
        is_new = operation_method == "new"

    if is_binary:
        # NOTE(review): on the binary path nothing is added to the index
        # unless it is new, yet ``ids`` is still extended below -- confirm
        # that this is intended.
        if is_new:
            index.add(gallery_features)
    else:
        # only train when the index is new
        if is_new:
            index.train(gallery_features)
        index.add_with_ids(gallery_features, ids_now)

    for i, d in zip(ids_now, gallery_docs):
        ids[i] = d
    return index, ids
|
||||||
|
|
||||||
|
def _rm_id_in_galllery(self, index, ids, gallery_docs):
    """Remove every entry whose doc appears in ``gallery_docs``.

    Args:
        index: index object exposing ``remove_ids``.
        ids: dict mapping id -> doc; matching entries are deleted in place.
        gallery_docs: docs to remove.

    Returns:
        tuple: ``(index, ids)``.
    """
    remove_ids = [k for k, doc in ids.items() if doc in gallery_docs]
    # pin the dtype: np.asarray([]) would otherwise yield a float64 array
    # when nothing matches
    remove_ids = np.asarray(remove_ids, dtype=np.int64)
    index.remove_ids(remove_ids)
    for k in remove_ids:
        del ids[k]

    return index, ids
|
||||||
|
|
||||||
|
def _save_gallery(self, config, index, ids):
    """Write the index to ``vector.index`` and the id map to ``id_map.pkl``.

    Both files go into ``config["index_dir"]``. The binary faiss writer is
    used when ``config["dist_type"]`` is ``"hamming"``, the regular writer
    otherwise.
    """
    # store the faiss index file with the writer matching the index type
    write_index = (faiss.write_index_binary
                   if config["dist_type"] == "hamming" else
                   faiss.write_index)
    write_index(index, os.path.join(config["index_dir"], "vector.index"))

    # store the id map file next to it
    id_map_path = os.path.join(config["index_dir"], "id_map.pkl")
    with open(id_map_path, 'wb') as id_map_file:
        pickle.dump(ids, id_map_file)
|
||||||
|
|
||||||
|
|
||||||
def main(config):
|
def main(config):
|
||||||
GalleryBuilder(config)
|
GalleryBuilder(config)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user