modify some format problem
parent
22132ebd4e
commit
28c9e03276
|
@ -28,6 +28,7 @@ from python.predict_rec import RecPredictor
|
||||||
from utils import logger
|
from utils import logger
|
||||||
from utils import config
|
from utils import config
|
||||||
|
|
||||||
|
|
||||||
def split_datafile(data_file, image_root, delimiter="\t"):
|
def split_datafile(data_file, image_root, delimiter="\t"):
|
||||||
'''
|
'''
|
||||||
data_file: image path and info, which can be splitted by spacer
|
data_file: image path and info, which can be splitted by spacer
|
||||||
|
@ -70,7 +71,6 @@ class GalleryBuilder(object):
|
||||||
# when remove data in index, do not need extract fatures
|
# when remove data in index, do not need extract fatures
|
||||||
if operation_method != "remove":
|
if operation_method != "remove":
|
||||||
gallery_features = self._extract_features(gallery_images, config)
|
gallery_features = self._extract_features(gallery_images, config)
|
||||||
|
|
||||||
assert operation_method in [
|
assert operation_method in [
|
||||||
"new", "remove", "append"
|
"new", "remove", "append"
|
||||||
], "Only append, remove and new operation are supported"
|
], "Only append, remove and new operation are supported"
|
||||||
|
@ -105,7 +105,7 @@ class GalleryBuilder(object):
|
||||||
min(int(len(gallery_images) // 8), 65536)) + ",Flat"
|
min(int(len(gallery_images) // 8), 65536)) + ",Flat"
|
||||||
|
|
||||||
# for binary index, add B at head of index_method
|
# for binary index, add B at head of index_method
|
||||||
if config["dist_type"] == "hamming":
|
if config["dist_type"] == "hamming":
|
||||||
index_method = "B" + index_method
|
index_method = "B" + index_method
|
||||||
|
|
||||||
#dist_type
|
#dist_type
|
||||||
|
@ -113,11 +113,12 @@ class GalleryBuilder(object):
|
||||||
"dist_type"] == "IP" else faiss.METRIC_L2
|
"dist_type"] == "IP" else faiss.METRIC_L2
|
||||||
|
|
||||||
#build index
|
#build index
|
||||||
if config["dist_type"] == "hamming":
|
if config["dist_type"] == "hamming":
|
||||||
index = faiss.index_binary_factory(config["embedding_size"], index_method)
|
index = faiss.index_binary_factory(config["embedding_size"],
|
||||||
|
index_method)
|
||||||
else:
|
else:
|
||||||
index = faiss.index_factory(config["embedding_size"], index_method,
|
index = faiss.index_factory(config["embedding_size"],
|
||||||
dist_type)
|
index_method, dist_type)
|
||||||
index = faiss.IndexIDMap2(index)
|
index = faiss.IndexIDMap2(index)
|
||||||
ids = {}
|
ids = {}
|
||||||
|
|
||||||
|
@ -133,12 +134,12 @@ class GalleryBuilder(object):
|
||||||
|
|
||||||
# only train when new index file
|
# only train when new index file
|
||||||
if operation_method == "new":
|
if operation_method == "new":
|
||||||
if config["dist_type"] == "hamming":
|
if config["dist_type"] == "hamming":
|
||||||
index.add(gallery_features)
|
index.add(gallery_features)
|
||||||
else:
|
else:
|
||||||
index.train(gallery_features)
|
index.train(gallery_features)
|
||||||
|
|
||||||
if not config["dist_type"] == "hamming":
|
if not config["dist_type"] == "hamming":
|
||||||
index.add_with_ids(gallery_features, ids_now)
|
index.add_with_ids(gallery_features, ids_now)
|
||||||
|
|
||||||
for i, d in zip(list(ids_now), gallery_docs):
|
for i, d in zip(list(ids_now), gallery_docs):
|
||||||
|
@ -157,25 +158,26 @@ class GalleryBuilder(object):
|
||||||
del ids[k]
|
del ids[k]
|
||||||
|
|
||||||
# store faiss index file and id_map file
|
# store faiss index file and id_map file
|
||||||
if config["dist_type"] == "hamming":
|
if config["dist_type"] == "hamming":
|
||||||
faiss.write_index_binary(index,
|
faiss.write_index_binary(
|
||||||
os.path.join(config["index_dir"], "vector.index"))
|
index, os.path.join(config["index_dir"], "vector.index"))
|
||||||
else:
|
else:
|
||||||
faiss.write_index(index,
|
faiss.write_index(
|
||||||
os.path.join(config["index_dir"], "vector.index"))
|
index, os.path.join(config["index_dir"], "vector.index"))
|
||||||
|
|
||||||
with open(os.path.join(config["index_dir"], "id_map.pkl"), 'wb') as fd:
|
with open(os.path.join(config["index_dir"], "id_map.pkl"), 'wb') as fd:
|
||||||
pickle.dump(ids, fd)
|
pickle.dump(ids, fd)
|
||||||
|
|
||||||
def _extract_features(self, gallery_images, config):
|
def _extract_features(self, gallery_images, config):
|
||||||
|
|
||||||
# extract gallery features
|
# extract gallery features
|
||||||
if config["dist_type"] == "hamming":
|
if config["dist_type"] == "hamming":
|
||||||
gallery_features = np.zeros(
|
gallery_features = np.zeros(
|
||||||
[len(gallery_images), config['embedding_size'] // 8], dtype=np.uint8)
|
[len(gallery_images), config['embedding_size'] // 8],
|
||||||
|
dtype=np.uint8)
|
||||||
else:
|
else:
|
||||||
gallery_features = np.zeros(
|
gallery_features = np.zeros(
|
||||||
[len(gallery_images), config['embedding_size']], dtype=np.float32)
|
[len(gallery_images), config['embedding_size']],
|
||||||
|
dtype=np.float32)
|
||||||
|
|
||||||
#construct batch imgs and do inference
|
#construct batch imgs and do inference
|
||||||
batch_size = config.get("batch_size", 32)
|
batch_size = config.get("batch_size", 32)
|
||||||
|
|
Loading…
Reference in New Issue