Fix some formatting problems

pull/1218/head
stephon 2021-09-13 09:06:21 +00:00
parent 22132ebd4e
commit 28c9e03276
1 changed files with 21 additions and 19 deletions

View File

@ -28,6 +28,7 @@ from python.predict_rec import RecPredictor
from utils import logger from utils import logger
from utils import config from utils import config
def split_datafile(data_file, image_root, delimiter="\t"): def split_datafile(data_file, image_root, delimiter="\t"):
''' '''
data_file: image path and info, which can be split by the delimiter data_file: image path and info, which can be split by the delimiter
@ -70,7 +71,6 @@ class GalleryBuilder(object):
# when removing data from the index, there is no need to extract features # when removing data from the index, there is no need to extract features
if operation_method != "remove": if operation_method != "remove":
gallery_features = self._extract_features(gallery_images, config) gallery_features = self._extract_features(gallery_images, config)
assert operation_method in [ assert operation_method in [
"new", "remove", "append" "new", "remove", "append"
], "Only append, remove and new operation are supported" ], "Only append, remove and new operation are supported"
@ -105,7 +105,7 @@ class GalleryBuilder(object):
min(int(len(gallery_images) // 8), 65536)) + ",Flat" min(int(len(gallery_images) // 8), 65536)) + ",Flat"
# for binary index, add B at head of index_method # for binary index, add B at head of index_method
if config["dist_type"] == "hamming": if config["dist_type"] == "hamming":
index_method = "B" + index_method index_method = "B" + index_method
#dist_type #dist_type
@ -113,11 +113,12 @@ class GalleryBuilder(object):
"dist_type"] == "IP" else faiss.METRIC_L2 "dist_type"] == "IP" else faiss.METRIC_L2
#build index #build index
if config["dist_type"] == "hamming": if config["dist_type"] == "hamming":
index = faiss.index_binary_factory(config["embedding_size"], index_method) index = faiss.index_binary_factory(config["embedding_size"],
index_method)
else: else:
index = faiss.index_factory(config["embedding_size"], index_method, index = faiss.index_factory(config["embedding_size"],
dist_type) index_method, dist_type)
index = faiss.IndexIDMap2(index) index = faiss.IndexIDMap2(index)
ids = {} ids = {}
@ -133,12 +134,12 @@ class GalleryBuilder(object):
# only train when new index file # only train when new index file
if operation_method == "new": if operation_method == "new":
if config["dist_type"] == "hamming": if config["dist_type"] == "hamming":
index.add(gallery_features) index.add(gallery_features)
else: else:
index.train(gallery_features) index.train(gallery_features)
if not config["dist_type"] == "hamming": if not config["dist_type"] == "hamming":
index.add_with_ids(gallery_features, ids_now) index.add_with_ids(gallery_features, ids_now)
for i, d in zip(list(ids_now), gallery_docs): for i, d in zip(list(ids_now), gallery_docs):
@ -157,25 +158,26 @@ class GalleryBuilder(object):
del ids[k] del ids[k]
# store faiss index file and id_map file # store faiss index file and id_map file
if config["dist_type"] == "hamming": if config["dist_type"] == "hamming":
faiss.write_index_binary(index, faiss.write_index_binary(
os.path.join(config["index_dir"], "vector.index")) index, os.path.join(config["index_dir"], "vector.index"))
else: else:
faiss.write_index(index, faiss.write_index(
os.path.join(config["index_dir"], "vector.index")) index, os.path.join(config["index_dir"], "vector.index"))
with open(os.path.join(config["index_dir"], "id_map.pkl"), 'wb') as fd: with open(os.path.join(config["index_dir"], "id_map.pkl"), 'wb') as fd:
pickle.dump(ids, fd) pickle.dump(ids, fd)
def _extract_features(self, gallery_images, config): def _extract_features(self, gallery_images, config):
# extract gallery features # extract gallery features
if config["dist_type"] == "hamming": if config["dist_type"] == "hamming":
gallery_features = np.zeros( gallery_features = np.zeros(
[len(gallery_images), config['embedding_size'] // 8], dtype=np.uint8) [len(gallery_images), config['embedding_size'] // 8],
dtype=np.uint8)
else: else:
gallery_features = np.zeros( gallery_features = np.zeros(
[len(gallery_images), config['embedding_size']], dtype=np.float32) [len(gallery_images), config['embedding_size']],
dtype=np.float32)
#construct batch imgs and do inference #construct batch imgs and do inference
batch_size = config.get("batch_size", 32) batch_size = config.get("batch_size", 32)