137 lines
3.6 KiB
Python
137 lines
3.6 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
"""
|
|
Common functions to load datasets and compute their ground-truth
|
|
"""
|
|
|
|
import time
|
|
import numpy as np
|
|
import faiss
|
|
|
|
from faiss.contrib import datasets as faiss_datasets
|
|
|
|
# Print the file that faiss.contrib.datasets was loaded from, so the log shows
# which copy of the module is in use.
print("path:", faiss_datasets.__file__)
|
|
|
|
# Override the module-level `dataset_basedir` of faiss.contrib.datasets with a
# hard-coded directory (presumably where its dataset classes look for files).
# NOTE(review): the path looks site-specific — confirm before running elsewhere.
faiss_datasets.dataset_basedir = '/checkpoint/matthijs/simsearch/'
|
|
|
|
def sanitize(x):
    """Convert ``x`` to a contiguous float32 numpy array.

    Thin wrapper around ``np.ascontiguousarray`` with the dtype fixed to
    float32.
    """
    contiguous_f32 = np.ascontiguousarray(x, dtype=np.float32)
    return contiguous_f32
|
|
|
|
|
|
#################################################################
|
|
# Dataset
|
|
#################################################################
|
|
|
|
class DatasetCentroids(faiss_datasets.Dataset):
    """Dataset whose database/training vectors are the centroids of an index.

    The dimension, metric and queries are copied from an existing dataset
    object; the database vectors are read from the quantizer of an index
    stored on disk.
    """

    def __init__(self, ds, indexfile):
        """Build the dataset.

        Args:
            ds: source dataset; its ``d``, ``metric``, ``nq`` attributes and
                ``get_queries()`` result are copied.
            indexfile: path of a faiss index file whose ``quantizer`` holds
                the centroids to use as database vectors.
        """
        self.d = ds.d
        self.metric = ds.metric
        self.nq = ds.nq
        self.xq = ds.get_queries()

        # get the xb set
        src_index = faiss.read_index(indexfile)
        src_quant = faiss.downcast_index(src_index.quantizer)
        # NOTE(review): relies on the quantizer exposing its vectors as `.xb`
        # — confirm this holds for the faiss version in use.
        centroids = faiss.vector_to_array(src_quant.xb)
        # the flat vector is reshaped into one row of size d per centroid
        self.xb = centroids.reshape(-1, self.d)
        # the same vectors serve as database and as training set
        self.nb = self.nt = len(self.xb)

    def get_queries(self):
        """Return the query vectors copied from the source dataset."""
        return self.xq

    def get_database(self):
        """Return the centroid vectors."""
        return self.xb

    def get_train(self, maxtrain=None):
        """Return the centroid vectors as training set.

        Args:
            maxtrain: if not None, return at most this many vectors
                (the first ``maxtrain`` rows).
        """
        if maxtrain is None:
            return self.xb
        return self.xb[:maxtrain]

    def get_groundtruth(self, k=100):
        """Compute the ``k`` nearest centroids of every query by brute force.

        Uses ``faiss.knn`` with the L2 metric when ``self.metric == 'L2'``
        and with inner product for any other value. Only the second element
        of the ``faiss.knn`` result (the neighbor indices) is returned.
        """
        if self.metric == 'L2':
            metric_type = faiss.METRIC_L2
        else:
            metric_type = faiss.METRIC_INNER_PRODUCT
        return faiss.knn(self.xq, self.xb, k, metric_type)[1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_dataset(dataset='deep1M', compute_gt=False, download=False):
    """Instantiate the dataset object designated by the name ``dataset``.

    Recognized names:
      - ``sift1M``
      - ``bigann1B`` or ``bigann<N>M`` (``N`` = database size in millions;
        the last character of the name is dropped without being checked)
      - ``deep_centroids_<ncent>``
      - ``deep<N>k``, ``deep<N>M`` or ``deep1B``
      - ``music-100``
      - ``glove``

    Args:
        dataset: name of the dataset, see above.
        compute_gt: not used by this function; kept so existing callers
            that pass it keep working.
        download: forwarded to ``DatasetGlove`` only.

    Returns:
        The dataset object built by the matching ``faiss_datasets`` class
        (or a ``DatasetCentroids`` for the ``deep_centroids_`` names).

    Raises:
        AssertionError: if the name or its size suffix is not recognized.
        ValueError: if the numeric part of the name is not an integer.
    """
    print("load data", dataset)

    if dataset == 'sift1M':
        return faiss_datasets.DatasetSIFT1M()

    elif dataset.startswith('bigann'):
        # strip the "bigann" prefix and the trailing unit character
        dbsize = 1000 if dataset == "bigann1B" else int(dataset[6:-1])
        return faiss_datasets.DatasetBigANN(nb_M=dbsize)

    elif dataset.startswith("deep_centroids_"):
        # must be tested before the generic "deep" prefix below
        ncent = int(dataset[len("deep_centroids_"):])
        centdir = "/checkpoint/matthijs/bench_all_ivf/precomputed_clusters"
        return DatasetCentroids(
            faiss_datasets.DatasetDeep1B(nb=1000000),
            f"{centdir}/clustering.dbdeep1M.IVF{ncent}.faissindex"
        )

    elif dataset.startswith("deep"):
        szsuf = dataset[4:]
        # endswith() rather than szsuf[-1]: an empty suffix (dataset == "deep")
        # then reaches the explicit error below instead of an IndexError
        if szsuf.endswith('M'):
            dbsize = 10 ** 6 * int(szsuf[:-1])
        elif szsuf == '1B':
            dbsize = 10 ** 9
        elif szsuf.endswith('k'):
            dbsize = 1000 * int(szsuf[:-1])
        else:
            # raised explicitly (not via `assert`) so the check also runs
            # under `python -O`; the exception type is unchanged
            raise AssertionError("did not recognize suffix " + szsuf)
        return faiss_datasets.DatasetDeep1B(nb=dbsize)

    elif dataset == "music-100":
        return faiss_datasets.DatasetMusic100()

    elif dataset == "glove":
        return faiss_datasets.DatasetGlove(download=download)

    else:
        # explicit raise: with `assert False` an unknown name would silently
        # return None when assertions are disabled
        raise AssertionError(f"did not recognize dataset {dataset!r}")
|
|
|
|
|
|
#################################################################
|
|
# Evaluation
|
|
#################################################################
|
|
|
|
|
|
def evaluate_DI(D, I, gt):
    """Print R@1, R@10, R@100, ... for a set of search results.

    For each cutoff (powers of ten up to the number of columns of ``I``),
    counts the entries among the first ``cutoff`` columns of ``I`` that equal
    the first column of ``gt``, divides by the number of rows of ``gt`` and
    prints the result on the current line (no trailing newline).

    Args:
        D: not used.
        I: 2D array of result ids, one row per query.
        gt: 2D array of ground-truth ids; only the first column is used.
    """
    num_queries = float(gt.shape[0])
    depth = I.shape[1]

    # cutoffs 1, 10, 100, ... that fit in the result width
    cutoffs = []
    cutoff = 1
    while cutoff <= depth:
        cutoffs.append(cutoff)
        cutoff *= 10

    for cutoff in cutoffs:
        n_hits = (I[:, :cutoff] == gt[:, :1]).sum()
        print("R@%d: %.4f" % (cutoff, n_hits / num_queries), end=' ')
|
|
|
|
|
|
def evaluate(xq, gt, index, k=100, endl=True):
    """Search ``xq`` in ``index``, print timing and recall, return the results.

    Prints the average search time per query in milliseconds, followed by
    R@1, R@10, R@100, ... (powers of ten up to ``k``). R@rank is the number of
    entries among the first ``rank`` result columns that equal the first
    column of ``gt``, divided by the number of rows of ``xq``.

    Args:
        xq: 2D array of query vectors, one row per query.
        gt: 2D array of ground-truth ids; only the first column is used.
        index: object with a ``search(xq, k)`` method returning ``(D, I)``.
        k: number of results requested per query.
        endl: if true, terminate the printed line with a newline.

    Returns:
        The ``(D, I)`` pair returned by ``index.search``.
    """
    # perf_counter is monotonic and high-resolution, unlike the wall clock
    # returned by time.time(), so it is the appropriate clock for benchmarks
    t0 = time.perf_counter()
    D, I = index.search(xq, k)
    t1 = time.perf_counter()
    nq = xq.shape[0]
    print("\t %8.4f ms per query, " % (
        (t1 - t0) * 1000.0 / nq), end=' ')
    rank = 1
    while rank <= k:
        recall = (I[:, :rank] == gt[:, :1]).sum() / float(nq)
        print("R@%d: %.4f" % (rank, recall), end=' ')
        rank *= 10
    if endl:
        print()
    return D, I
|