benchmark refactor

Summary:
1. Support for index construction parameters outside of the factory string (arbitrary depth of quantizers).
2. Refactor that provides an index wrapper, which is a prerequisite for the optimizer; the optimizer will generate indices from pre-optimized components (particularly quantizers).

Reviewed By: mdouze

Differential Revision: D51427452

fbshipit-source-id: 014d05dd798d856360f2546963e7cad64c2fcaeb

pull/3164/head
parent a5b03cb9f6
commit 9519a19f42
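For orientation, a minimal sketch of what the new descriptor-level construction parameters look like in use. The field names match the IndexDescriptor changes in this diff; the factory string, the parameter values, and the bench_fw import path are illustrative assumptions, not taken from this commit:

    from bench_fw.descriptors import IndexDescriptor  # assumed import path

    # construction_params is a list with one parameter dict per level
    # (level 0 is the index itself, deeper levels are quantizers),
    # applied outside the factory string.
    desc = IndexDescriptor(
        factory="OPQ32,IVF4096,PQ32",
        construction_params=[{"parallel_mode": 2}],
        search_params={"nprobe": 64},
    )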
benchmark.py
@@ -1,23 +1,20 @@
-# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
 
-from contextlib import contextmanager
-import json
 import logging
 from dataclasses import dataclass
-from multiprocessing.pool import ThreadPool
 from operator import itemgetter
 from statistics import median, mean
-from time import perf_counter
-from typing import Any, List, Optional
+from typing import Any, Dict, List, Optional
 
+from .index import Index, IndexFromCodec, IndexFromFactory
 from .descriptors import DatasetDescriptor, IndexDescriptor
 
 import faiss  # @manual=//faiss/python:pyfaiss_gpu
 from faiss.contrib.evaluation import (  # @manual=//faiss/contrib:faiss_contrib_gpu
     knn_intersection_measure,
     OperatingPointsWithRanges,
 )
-from faiss.contrib.ivf_tools import (  # @manual=//faiss/contrib:faiss_contrib_gpu
-    add_preassigned,
-)
 
 import numpy as np
@@ -27,56 +24,21 @@ from scipy.optimize import curve_fit
 logger = logging.getLogger(__name__)
 
 
-@contextmanager
-def timer(name) -> float:
-    logger.info(f"Measuring {name}")
-    t1 = t2 = perf_counter()
-    yield lambda: t2 - t1
-    t2 = perf_counter()
-    logger.info(f"Time for {name}: {t2 - t1:.3f} seconds")
-
-
-def refine_distances_knn(
-    D: np.ndarray, I: np.ndarray, xq: np.ndarray, xb: np.ndarray, metric
-):
-    return np.where(
-        I >= 0,
-        np.square(np.linalg.norm(xq[:, None] - xb[I], axis=2))
-        if metric == faiss.METRIC_L2
-        else np.einsum("qd,qkd->qk", xq, xb[I]),
-        D,
-    )
-
-
-def refine_distances_range(
-    lims: np.ndarray,
-    D: np.ndarray,
-    I: np.ndarray,
-    xq: np.ndarray,
-    xb: np.ndarray,
-    metric,
-):
-    with ThreadPool(32) as pool:
-        R = pool.map(
-            lambda i: (
-                np.sum(np.square(xq[i] - xb[I[lims[i]:lims[i + 1]]]), axis=1)
-                if metric == faiss.METRIC_L2
-                else np.tensordot(
-                    xq[i], xb[I[lims[i]:lims[i + 1]]], axes=(0, 1)
-                )
-            )
-            if lims[i + 1] > lims[i]
-            else [],
-            range(len(lims) - 1),
-        )
-    return np.hstack(R)
-
-
 def range_search_pr_curve(
     dist_ann: np.ndarray, metric_score: np.ndarray, gt_rsm: float
 ):
     assert dist_ann.shape == metric_score.shape
     assert dist_ann.ndim == 1
+    l = len(dist_ann)
+    if l == 0:
+        return {
+            "dist_ann": [],
+            "metric_score_sample": [],
+            "cum_score": [],
+            "precision": [],
+            "recall": [],
+            "unique_key": [],
+        }
     sort_by_dist_ann = dist_ann.argsort()
     dist_ann = dist_ann[sort_by_dist_ann]
     metric_score = metric_score[sort_by_dist_ann]
@@ -87,7 +49,7 @@ def range_search_pr_curve(
     tbl = np.vstack(
         [dist_ann, metric_score, cum_score, precision, recall, unique_key]
     )
-    group_by_dist_max_cum_score = np.empty(len(dist_ann), bool)
+    group_by_dist_max_cum_score = np.empty(l, bool)
     group_by_dist_max_cum_score[-1] = True
     group_by_dist_max_cum_score[:-1] = dist_ann[1:] != dist_ann[:-1]
     tbl = tbl[:, group_by_dist_max_cum_score]
@@ -105,39 +67,7 @@
     }
 
 
-def set_index_parameter(index, name, val):
-    index = faiss.downcast_index(index)
-
-    if isinstance(index, faiss.IndexPreTransform):
-        set_index_parameter(index.index, name, val)
-    elif name.startswith("quantizer_"):
-        index_ivf = faiss.extract_index_ivf(index)
-        set_index_parameter(
-            index_ivf.quantizer, name[name.find("_") + 1:], val
-        )
-    elif name == "efSearch":
-        index.hnsw.efSearch
-        index.hnsw.efSearch = int(val)
-    elif name == "nprobe":
-        index_ivf = faiss.extract_index_ivf(index)
-        index_ivf.nprobe
-        index_ivf.nprobe = int(val)
-    elif name == "noop":
-        pass
-    else:
-        raise RuntimeError(f"could not set param {name} on {index}")
-
-
-def optimizer(codec, search, cost_metric, perf_metric):
-    op = OperatingPointsWithRanges()
-    op.add_range("noop", [0])
-    codec_ivf = faiss.try_extract_index_ivf(codec)
-    if codec_ivf is not None:
-        op.add_range(
-            "nprobe",
-            [2**i for i in range(12) if 2**i < codec_ivf.nlist * 0.1],
-        )
-
+def optimizer(op, search, cost_metric, perf_metric):
     totex = op.num_experiments()
     rs = np.random.RandomState(123)
     if totex > 1:
@@ -243,7 +173,7 @@ def get_range_search_metric_function(range_metric, D, R):
             cutoff,
             lambda x: np.where(x < cutoff, sigmoid(x, *popt), 0),
             popt.tolist(),
-            list(zip(aradius, ascore, aradius_from, aradius_to, strict=True))
+            list(zip(aradius, ascore, aradius_from, aradius_to, strict=True)),
         )
     else:
         # Assuming that the range_metric is a float,
@@ -265,21 +195,20 @@ def get_range_search_metric_function(range_metric, D, R):
 @dataclass
 class Benchmark:
     training_vectors: Optional[DatasetDescriptor] = None
-    db_vectors: Optional[DatasetDescriptor] = None
+    database_vectors: Optional[DatasetDescriptor] = None
     query_vectors: Optional[DatasetDescriptor] = None
     index_descs: Optional[List[IndexDescriptor]] = None
     range_ref_index_desc: Optional[str] = None
     k: Optional[int] = None
-    distance_metric: str = "METRIC_L2"
+    distance_metric: str = "L2"
 
     def __post_init__(self):
-        if self.distance_metric == "METRIC_INNER_PRODUCT":
+        if self.distance_metric == "IP":
             self.distance_metric_type = faiss.METRIC_INNER_PRODUCT
-        elif self.distance_metric == "METRIC_L2":
+        elif self.distance_metric == "L2":
             self.distance_metric_type = faiss.METRIC_L2
         else:
             raise ValueError
-        self.cached_index_key = None
 
     def set_io(self, benchmark_io):
         self.io = benchmark_io
@@ -292,54 +221,24 @@ class Benchmark:
             return desc
         return None
 
-    def get_index(self, index_desc: IndexDescriptor):
-        if self.cached_index_key != index_desc.factory:
-            xb = self.io.get_dataset(self.db_vectors)
-            index = faiss.clone_index(
-                self.io.get_codec(index_desc, xb.shape[1])
-            )
-            assert index.ntotal == 0
-            logger.info("Adding vectors to index")
-            index_ivf = faiss.try_extract_index_ivf(index)
-            if index_ivf is not None:
-                QD, QI, _, QP = self.knn_search(
-                    index_desc,
-                    parameters=None,
-                    db_vectors=None,
-                    query_vectors=self.db_vectors,
-                    k=1,
-                    index=index_ivf.quantizer,
-                    level=1,
-                )
-                print(f"{QI.ravel().shape=}")
-                add_preassigned(index_ivf, xb, QI.ravel())
-            else:
-                index.add(xb)
-            assert index.ntotal == xb.shape[0]
-            logger.info("Added vectors to index")
-            self.cached_index_key = index_desc.factory
-            self.cached_index = index
-        return self.cached_index
-
-    def range_search_reference(self, index_desc, range_metric):
+    def range_search_reference(self, index, parameters, range_metric):
         logger.info("range_search_reference: begin")
         if isinstance(range_metric, list):
             assert len(range_metric) > 0
-            ri = len(range_metric[0]) - 1
             m_radius = (
-                max(rm[ri] for rm in range_metric)
+                max(rm[-2] for rm in range_metric)
                 if self.distance_metric_type == faiss.METRIC_L2
-                else min(rm[ri] for rm in range_metric)
+                else min(rm[-2] for rm in range_metric)
             )
         else:
             m_radius = range_metric
 
         lims, D, I, R, P = self.range_search(
-            index_desc,
-            index_desc.parameters,
+            index,
+            parameters,
             radius=m_radius,
         )
-        flat = index_desc.factory == "Flat"
+        flat = index.factory == "Flat"
         (
             gt_radius,
             range_search_metric_function,
@@ -351,111 +250,61 @@ class Benchmark:
             R if not flat else None,
         )
         logger.info("range_search_reference: end")
-        return gt_radius, range_search_metric_function, coefficients, coefficients_training_data
+        return (
+            gt_radius,
+            range_search_metric_function,
+            coefficients,
+            coefficients_training_data,
+        )
 
-    def estimate_range(self, index_desc, parameters, range_scoring_radius):
-        D, I, R, P = self.knn_search(
-            index_desc, parameters, self.db_vectors, self.query_vectors
+    def estimate_range(self, index, parameters, range_scoring_radius):
+        D, I, R, P = index.knn_search(
+            parameters,
+            self.query_vectors,
+            self.k,
         )
         samples = []
         for i, j in np.argwhere(R < range_scoring_radius):
             samples.append((R[i, j].item(), D[i, j].item()))
-        samples.sort(key=itemgetter(0))
-        return median(r for _, r in samples[-3:])
+        if len(samples) > 0:  # estimate range
+            samples.sort(key=itemgetter(0))
+            return median(r for _, r in samples[-3:])
+        else:  # ensure at least one result
+            i, j = np.argwhere(R.min() == R)[0]
+            return D[i, j].item()
 
     def range_search(
         self,
-        index_desc: IndexDescriptor,
-        parameters: Optional[dict[str, int]],
+        index: Index,
+        search_parameters: Optional[Dict[str, int]],
         radius: Optional[float] = None,
         gt_radius: Optional[float] = None,
     ):
         logger.info("range_search: begin")
-        flat = index_desc.factory == "Flat"
         if radius is None:
             assert gt_radius is not None
             radius = (
                 gt_radius
-                if flat
-                else self.estimate_range(index_desc, parameters, gt_radius)
+                if index.is_flat()
+                else self.estimate_range(
+                    index,
+                    search_parameters,
+                    gt_radius,
+                )
             )
         logger.info(f"Radius={radius}")
-        filename = self.io.get_filename_range_search(
-            factory=index_desc.factory,
-            parameters=parameters,
-            level=0,
-            db_vectors=self.db_vectors,
+        return index.range_search(
+            search_parameters=search_parameters,
             query_vectors=self.query_vectors,
-            r=radius,
+            radius=radius,
         )
-        if self.io.file_exist(filename):
-            logger.info(f"Using cached results for {index_desc.factory}")
-            lims, D, I, R, P = self.io.read_file(
-                filename, ["lims", "D", "I", "R", "P"]
-            )
-        else:
-            xq = self.io.get_dataset(self.query_vectors)
-            index = self.get_index(index_desc)
-            if parameters:
-                for name, val in parameters.items():
-                    set_index_parameter(index, name, val)
-
-            index_ivf = faiss.try_extract_index_ivf(index)
-            if index_ivf is not None:
-                QD, QI, _, QP = self.knn_search(
-                    index_desc,
-                    parameters=None,
-                    db_vectors=None,
-                    query_vectors=self.query_vectors,
-                    k=index.nprobe,
-                    index=index_ivf.quantizer,
-                    level=1,
-                )
-                # QD = QD[:, :index.nprobe]
-                # QI = QI[:, :index.nprobe]
-                faiss.cvar.indexIVF_stats.reset()
-                with timer("range_search_preassigned") as t:
-                    lims, D, I = index.range_search_preassigned(xq, radius, QI, QD)
-            else:
-                with timer("range_search") as t:
-                    lims, D, I = index.range_search(xq, radius)
-            if flat:
-                R = D
-            else:
-                xb = self.io.get_dataset(self.db_vectors)
-                R = refine_distances_range(
-                    lims, D, I, xq, xb, self.distance_metric_type
-                )
-            P = {
-                "time": t(),
-                "radius": radius,
-                "count": lims[-1].item(),
-                "parameters": parameters,
-                "index": index_desc.factory,
-            }
-            if index_ivf is not None:
-                stats = faiss.cvar.indexIVF_stats
-                P |= {
-                    "quantizer": QP,
-                    "nq": stats.nq,
-                    "nlist": stats.nlist,
-                    "ndis": stats.ndis,
-                    "nheap_updates": stats.nheap_updates,
-                    "quantization_time": stats.quantization_time,
-                    "search_time": stats.search_time,
-                }
-            self.io.write_file(
-                filename, ["lims", "D", "I", "R", "P"], [lims, D, I, R, P]
-            )
-        logger.info("range_seach: end")
-        return lims, D, I, R, P
-
     def range_ground_truth(self, gt_radius, range_search_metric_function):
         logger.info("range_ground_truth: begin")
         flat_desc = self.get_index_desc("Flat")
         lims, D, I, R, P = self.range_search(
-            flat_desc,
-            flat_desc.parameters,
+            flat_desc.index,
+            search_parameters=None,
             radius=gt_radius,
         )
         gt_rsm = np.sum(range_search_metric_function(R)).item()
@@ -464,37 +313,32 @@ class Benchmark:
 
     def range_search_benchmark(
         self,
-        results: dict[str, Any],
-        index_desc: IndexDescriptor,
+        results: Dict[str, Any],
+        index: Index,
         metric_key: str,
+        radius: float,
         gt_radius: float,
         range_search_metric_function,
         gt_rsm: float,
     ):
-        logger.info(f"range_search_benchmark: begin {index_desc.factory=}")
-        xq = self.io.get_dataset(self.query_vectors)
-        (nq, d) = xq.shape
-        logger.info(
-            f"Searching {index_desc.factory} with {nq} vectors of dimension {d}"
-        )
-        codec = self.io.get_codec(index_desc, d)
-        faiss.omp_set_num_threads(16)
+        logger.info(f"range_search_benchmark: begin {index.get_index_name()}")
 
         def experiment(parameters, cost_metric, perf_metric):
             nonlocal results
-            key = self.io.get_filename_evaluation_name(
-                factory=index_desc.factory,
-                parameters=parameters,
-                level=0,
-                db_vectors=self.db_vectors,
+            key = index.get_range_search_name(
+                search_parameters=parameters,
                 query_vectors=self.query_vectors,
-                evaluation_name=metric_key,
+                radius=radius,
             )
+            key += metric_key
             if key in results["experiments"]:
                 metrics = results["experiments"][key]
             else:
                 lims, D, I, R, P = self.range_search(
-                    index_desc, parameters, gt_radius=gt_radius
+                    index,
+                    parameters,
+                    radius=radius,
+                    gt_radius=gt_radius,
                 )
                 range_search_metric = range_search_metric_function(R)
                 range_search_pr = range_search_pr_curve(
@@ -511,8 +355,9 @@ class Benchmark:
 
         for cost_metric in ["time"]:
             for perf_metric in ["range_score_max_recall"]:
+                op = index.get_operating_points()
                 optimizer(
-                    codec,
+                    op,
                     experiment,
                     cost_metric,
                     perf_metric,
@@ -520,134 +365,33 @@ class Benchmark:
         logger.info("range_search_benchmark: end")
         return results
 
-    def knn_search(
-        self,
-        index_desc: IndexDescriptor,
-        parameters: Optional[dict[str, int]],
-        db_vectors: Optional[DatasetDescriptor],
-        query_vectors: DatasetDescriptor,
-        k: Optional[int] = None,
-        index: Optional[faiss.Index] = None,
-        level: int = 0,
-    ):
-        assert level >= 0
-        if level == 0:
-            assert index is None
-            assert db_vectors is not None
-        else:
-            assert index is not None  # quantizer
-            assert db_vectors is None
-        logger.info("knn_seach: begin")
-        k = k if k is not None else self.k
-        flat = index_desc.factory == "Flat"
-        filename = self.io.get_filename_knn_search(
-            factory=index_desc.factory,
-            parameters=parameters,
-            level=level,
-            db_vectors=db_vectors,
-            query_vectors=query_vectors,
-            k=k,
-        )
-        if self.io.file_exist(filename):
-            logger.info(f"Using cached results for {index_desc.factory}")
-            D, I, R, P = self.io.read_file(filename, ["D", "I", "R", "P"])
-        else:
-            xq = self.io.get_dataset(query_vectors)
-            if index is None:
-                index = self.get_index(index_desc)
-                if parameters:
-                    for name, val in parameters.items():
-                        set_index_parameter(index, name, val)
-
-            index_ivf = faiss.try_extract_index_ivf(index)
-            if index_ivf is not None:
-                QD, QI, _, QP = self.knn_search(
-                    index_desc,
-                    parameters=None,
-                    db_vectors=None,
-                    query_vectors=query_vectors,
-                    k=index.nprobe,
-                    index=index_ivf.quantizer,
-                    level=level + 1,
-                )
-                # QD = QD[:, :index.nprobe]
-                # QI = QI[:, :index.nprobe]
-                faiss.cvar.indexIVF_stats.reset()
-                with timer("knn search_preassigned") as t:
-                    D, I = index.search_preassigned(xq, k, QI, QD)
-            else:
-                with timer("knn search") as t:
-                    D, I = index.search(xq, k)
-            if flat or level > 0:
-                R = D
-            else:
-                xb = self.io.get_dataset(db_vectors)
-                R = refine_distances_knn(
-                    D, I, xq, xb, self.distance_metric_type
-                )
-            P = {
-                "time": t(),
-                "parameters": parameters,
-                "index": index_desc.factory,
-                "level": level,
-            }
-            if index_ivf is not None:
-                stats = faiss.cvar.indexIVF_stats
-                P |= {
-                    "quantizer": QP,
-                    "nq": stats.nq,
-                    "nlist": stats.nlist,
-                    "ndis": stats.ndis,
-                    "nheap_updates": stats.nheap_updates,
-                    "quantization_time": stats.quantization_time,
-                    "search_time": stats.search_time,
-                }
-            self.io.write_file(filename, ["D", "I", "R", "P"], [D, I, R, P])
-        logger.info("knn_seach: end")
-        return D, I, R, P
-
     def knn_ground_truth(self):
         logger.info("knn_ground_truth: begin")
         flat_desc = self.get_index_desc("Flat")
-        self.gt_knn_D, self.gt_knn_I, _, _ = self.knn_search(
-            flat_desc,
-            flat_desc.parameters,
-            self.db_vectors,
-            self.query_vectors,
+        self.gt_knn_D, self.gt_knn_I, _, _ = flat_desc.index.knn_search(
+            search_parameters=None,
+            query_vectors=self.query_vectors,
+            k=self.k,
         )
         logger.info("knn_ground_truth: end")
 
-    def knn_search_benchmark(
-        self, results: dict[str, Any], index_desc: IndexDescriptor
-    ):
-        logger.info(f"knn_search_benchmark: begin {index_desc.factory=}")
-        xq = self.io.get_dataset(self.query_vectors)
-        (nq, d) = xq.shape
-        logger.info(
-            f"Searching {index_desc.factory} with {nq} vectors of dimension {d}"
-        )
-        codec = self.io.get_codec(index_desc, d)
-        codec_ivf = faiss.try_extract_index_ivf(codec)
-        if codec_ivf is not None:
-            results["indices"][index_desc.factory] = {"nlist": codec_ivf.nlist}
-
-        faiss.omp_set_num_threads(16)
+    def knn_search_benchmark(self, results: Dict[str, Any], index: Index):
+        index_name = index.get_index_name()
+        logger.info(f"knn_search_benchmark: begin {index_name}")
 
         def experiment(parameters, cost_metric, perf_metric):
             nonlocal results
-            key = self.io.get_filename_evaluation_name(
-                factory=index_desc.factory,
-                parameters=parameters,
-                level=0,
-                db_vectors=self.db_vectors,
-                query_vectors=self.query_vectors,
-                evaluation_name="knn",
+            key = index.get_knn_search_name(
+                parameters,
+                self.query_vectors,
+                self.k,
             )
+            key += "knn"
             if key in results["experiments"]:
                 metrics = results["experiments"][key]
             else:
-                D, I, R, P = self.knn_search(
-                    index_desc, parameters, self.db_vectors, self.query_vectors
+                D, I, R, P = index.knn_search(
+                    parameters, self.query_vectors, self.k
                 )
                 metrics = P | {
                     "knn_intersection": knn_intersection_measure(
@@ -662,8 +406,9 @@ class Benchmark:
 
         for cost_metric in ["time"]:
             for perf_metric in ["knn_intersection", "distance_ratio"]:
+                op = index.get_operating_points()
                 optimizer(
-                    codec,
+                    op,
                     experiment,
                     cost_metric,
                     perf_metric,
@@ -671,18 +416,61 @@ class Benchmark:
         logger.info("knn_search_benchmark: end")
         return results
 
-    def benchmark(self) -> str:
-        logger.info("begin evaluate")
-        results = {"indices": {}, "experiments": {}}
+    def train(self, results):
+        xq = self.io.get_dataset(self.query_vectors)
+        self.d = xq.shape[1]
         if self.get_index_desc("Flat") is None:
             self.index_descs.append(IndexDescriptor(factory="Flat"))
+        for index_desc in self.index_descs:
+            if index_desc.factory is not None:
+                index = IndexFromFactory(
+                    d=self.d,
+                    metric=self.distance_metric,
+                    database_vectors=self.database_vectors,
+                    search_params=index_desc.search_params,
+                    construction_params=index_desc.construction_params,
+                    factory=index_desc.factory,
+                    training_vectors=self.training_vectors,
+                )
+                index.set_io(self.io)
+                index.train()
+                index_desc.index = index
+                results["indices"][index.get_codec_name()] = {
+                    "code_size": index.get_code_size()
+                }
+            else:
+                index = IndexFromCodec(
+                    d=self.d,
+                    metric=self.distance_metric,
+                    database_vectors=self.database_vectors,
+                    search_params=index_desc.search_params,
+                    construction_params=index_desc.construction_params,
+                    path=index_desc.path,
+                    bucket=index_desc.bucket,
+                )
+                index.set_io(self.io)
+                index_desc.index = index
+                results["indices"][index.get_codec_name()] = {
+                    "code_size": index.get_code_size()
+                }
+        return results
+
+    def benchmark(self, result_file=None):
+        logger.info("begin evaluate")
+
+        faiss.omp_set_num_threads(24)
+        results = {"indices": {}, "experiments": {}}
+        results = self.train(results)
 
         # knn search
         self.knn_ground_truth()
         for index_desc in self.index_descs:
             results = self.knn_search_benchmark(
                 results=results,
-                index_desc=index_desc,
+                index=index_desc.index,
             )
 
         # range search
         if self.range_ref_index_desc is not None:
             index_desc = self.get_index_desc(self.range_ref_index_desc)
             if index_desc is None:
@@ -700,7 +488,9 @@ class Benchmark:
                 range_search_metric_function,
                 coefficients,
                 coefficients_training_data,
-            ) = self.range_search_reference(index_desc, range_metric)
+            ) = self.range_search_reference(
+                index_desc.index, index_desc.search_params, range_metric
+            )
             results["metrics"][metric_key] = {
                 "coefficients": coefficients,
                 "training_data": coefficients_training_data,
@@ -709,14 +499,18 @@ class Benchmark:
                 gt_radius, range_search_metric_function
             )
             for index_desc in self.index_descs:
+                if not index_desc.index.supports_range_search():
+                    continue
                 results = self.range_search_benchmark(
                     results=results,
-                    index_desc=index_desc,
+                    index=index_desc.index,
                     metric_key=metric_key,
+                    radius=index_desc.radius,
                     gt_radius=gt_radius,
                     range_search_metric_function=range_search_metric_function,
                     gt_rsm=gt_rsm,
                 )
-        self.io.write_json(results, "result.json", overwrite=True)
+        if result_file is not None:
+            self.io.write_json(results, result_file, overwrite=True)
         logger.info("end evaluate")
-        return json.dumps(results)
+        return results
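With train() factored out of benchmark(), a driver now wires the pieces roughly as follows (a sketch assuming the package is importable as bench_fw; the dataset names, cache path, and factory string are placeholders):

    from bench_fw.benchmark import Benchmark
    from bench_fw.benchmark_io import BenchmarkIO
    from bench_fw.descriptors import DatasetDescriptor, IndexDescriptor

    benchmark = Benchmark(
        training_vectors=DatasetDescriptor(namespace="std_t", tablename="bigann1M"),
        database_vectors=DatasetDescriptor(namespace="std_d", tablename="bigann1M"),
        query_vectors=DatasetDescriptor(namespace="std_q", tablename="bigann1M"),
        index_descs=[
            IndexDescriptor(factory="IVF1024,Flat"),
        ],
        k=10,
        distance_metric="L2",
    )
    benchmark.set_io(BenchmarkIO(path="/tmp/bench_fw"))  # local cache directory
    results = benchmark.benchmark(result_file="result.json")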
benchmark_io.py
@@ -1,7 +1,14 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import hashlib
 import io
 import json
 import logging
 import os
+import pickle
 from dataclasses import dataclass
 from typing import Any, List, Optional
 from zipfile import ZipFile
@@ -9,115 +16,45 @@ from zipfile import ZipFile
 import faiss  # @manual=//faiss/python:pyfaiss_gpu
 
 import numpy as np
 
 from .descriptors import DatasetDescriptor, IndexDescriptor
+from faiss.contrib.datasets import (  # @manual=//faiss/contrib:faiss_contrib_gpu
+    dataset_from_name,
+)
 
 logger = logging.getLogger(__name__)
 
 
+# merge RCQ coarse quantizer and ITQ encoder to one Faiss index
+def merge_rcq_itq(
+    # pyre-ignore[11]: `faiss.ResidualCoarseQuantizer` is not defined as a type
+    rcq_coarse_quantizer: faiss.ResidualCoarseQuantizer,
+    itq_encoder: faiss.IndexPreTransform,
+    # pyre-ignore[11]: `faiss.IndexIVFSpectralHash` is not defined as a type.
+) -> faiss.IndexIVFSpectralHash:
+    # pyre-ignore[16]: `faiss` has no attribute `IndexIVFSpectralHash`.
+    index = faiss.IndexIVFSpectralHash(
+        rcq_coarse_quantizer,
+        rcq_coarse_quantizer.d,
+        rcq_coarse_quantizer.ntotal,
+        itq_encoder.sa_code_size() * 8,
+        1000000,  # larger than the magnitude of the vectors
+    )
+    index.replace_vt(itq_encoder)
+    return index
+
+
 @dataclass
 class BenchmarkIO:
     path: str
 
     def __post_init__(self):
         self.cached_ds = {}
-        self.cached_codec_key = None
 
-    def get_filename_search(
-        self,
-        factory: str,
-        parameters: Optional[dict[str, int]],
-        level: int,
-        db_vectors: DatasetDescriptor,
-        query_vectors: DatasetDescriptor,
-        k: Optional[int] = None,
-        r: Optional[float] = None,
-        evaluation_name: Optional[str] = None,
-    ):
-        assert factory is not None
-        assert level is not None
-        assert self.distance_metric is not None
-        assert query_vectors is not None
-        assert self.distance_metric is not None
-        filename = f"{factory.lower().replace(',', '_')}."
-        if level > 0:
-            filename += f"l_{level}."
-        if db_vectors is not None:
-            filename += db_vectors.get_filename("d")
-        filename += query_vectors.get_filename("q")
-        filename += self.distance_metric.upper() + "."
-        if k is not None:
-            filename += f"k_{k}."
-        if r is not None:
-            filename += f"r_{int(r * 1000)}."
-        if parameters is not None:
-            for name, val in parameters.items():
-                if name != "noop":
-                    filename += f"{name}_{val}."
-        if evaluation_name is None:
-            filename += "zip"
-        else:
-            filename += evaluation_name
-        return filename
-
-    def get_filename_knn_search(
-        self,
-        factory: str,
-        parameters: Optional[dict[str, int]],
-        level: int,
-        db_vectors: DatasetDescriptor,
-        query_vectors: DatasetDescriptor,
-        k: int,
-    ):
-        assert k is not None
-        return self.get_filename_search(
-            factory=factory,
-            parameters=parameters,
-            level=level,
-            db_vectors=db_vectors,
-            query_vectors=query_vectors,
-            k=k,
-        )
-
-    def get_filename_range_search(
-        self,
-        factory: str,
-        parameters: Optional[dict[str, int]],
-        level: int,
-        db_vectors: DatasetDescriptor,
-        query_vectors: DatasetDescriptor,
-        r: float,
-    ):
-        assert r is not None
-        return self.get_filename_search(
-            factory=factory,
-            parameters=parameters,
-            level=level,
-            db_vectors=db_vectors,
-            query_vectors=query_vectors,
-            r=r,
-        )
-
-    def get_filename_evaluation_name(
-        self,
-        factory: str,
-        parameters: Optional[dict[str, int]],
-        level: int,
-        db_vectors: DatasetDescriptor,
-        query_vectors: DatasetDescriptor,
-        evaluation_name: str,
-    ):
-        assert evaluation_name is not None
-        return self.get_filename_search(
-            factory=factory,
-            parameters=parameters,
-            level=level,
-            db_vectors=db_vectors,
-            query_vectors=query_vectors,
-            evaluation_name=evaluation_name,
-        )
-
     def get_local_filename(self, filename):
+        if len(filename) > 184:
+            fn, ext = os.path.splitext(filename)
+            filename = (
+                fn[:184] + hashlib.sha256(filename.encode()).hexdigest() + ext
+            )
         return os.path.join(self.path, filename)
 
     def download_file_from_blobstore(
@@ -143,22 +80,6 @@ class BenchmarkIO:
         logger.info(f"{filename} {exists=}")
         return exists
 
-    def get_codec(self, index_desc: IndexDescriptor, d: int):
-        if index_desc.factory == "Flat":
-            return faiss.IndexFlat(d, self.distance_metric_type)
-        else:
-            if self.cached_codec_key != index_desc.factory:
-                codec = faiss.read_index(
-                    self.get_local_filename(index_desc.path)
-                )
-                assert (
-                    codec.metric_type == self.distance_metric_type
-                ), f"{codec.metric_type=} != {self.distance_metric_type=}"
-                logger.info(f"Loaded codec from {index_desc.path}")
-                self.cached_codec_key = index_desc.factory
-                self.cached_codec = codec
-            return self.cached_codec
-
     def read_file(self, filename: str, keys: List[str]):
         fn = self.download_file_from_blobstore(filename)
         logger.info(f"Loading file {fn}")
@@ -196,19 +117,50 @@ class BenchmarkIO:
         self.upload_file_to_blobstore(filename, overwrite=overwrite)
 
     def get_dataset(self, dataset):
-        if dataset not in self.cached_ds:
-            self.cached_ds[dataset] = self.read_nparray(
-                os.path.join(self.path, dataset.tablename)
-            )
+        if dataset.namespace is not None and dataset.namespace[:4] == "std_":
+            if dataset.tablename not in self.cached_ds:
+                self.cached_ds[dataset.tablename] = dataset_from_name(
+                    dataset.tablename,
+                )
+            p = dataset.namespace[4]
+            if p == "t":
+                return self.cached_ds[dataset.tablename].get_train()
+            elif p == "d":
+                return self.cached_ds[dataset.tablename].get_database()
+            elif p == "q":
+                return self.cached_ds[dataset.tablename].get_queries()
+            else:
+                raise ValueError
+        elif dataset not in self.cached_ds:
+            if dataset.namespace == "syn":
+                d, seed = dataset.tablename.split("_")
+                d = int(d)
+                seed = int(seed)
+                n = dataset.num_vectors
+                # based on faiss.contrib.datasets.SyntheticDataset
+                d1 = 10
+                rs = np.random.RandomState(seed)
+                x = rs.normal(size=(n, d1))
+                x = np.dot(x, rs.rand(d1, d))
+                x = x * (rs.rand(d) * 4 + 0.1)
+                x = np.sin(x)
+                x = x.astype(np.float32)
+                self.cached_ds[dataset] = x
+            else:
+                self.cached_ds[dataset] = self.read_nparray(
+                    os.path.join(self.path, dataset.tablename),
+                    mmap_mode="r",
+                )[: dataset.num_vectors].copy()
         return self.cached_ds[dataset]
 
     def read_nparray(
         self,
         filename: str,
+        mmap_mode: Optional[str] = None,
     ):
         fn = self.download_file_from_blobstore(filename)
         logger.info(f"Loading nparray from {fn}")
-        nparray = np.load(fn)
+        nparray = np.load(fn, mmap_mode=mmap_mode)
         logger.info(f"Loaded nparray {nparray.shape} from {fn}")
         return nparray
@@ -244,3 +196,32 @@ class BenchmarkIO:
         with open(fn, "w") as fp:
             json.dump(json_dict, fp)
         self.upload_file_to_blobstore(filename, overwrite=overwrite)
+
+    def read_index(
+        self,
+        filename: str,
+        bucket: Optional[str] = None,
+        path: Optional[str] = None,
+    ):
+        fn = self.download_file_from_blobstore(filename, bucket, path)
+        logger.info(f"Loading index {fn}")
+        ext = os.path.splitext(fn)[1]
+        if ext in [".faiss", ".codec"]:
+            index = faiss.read_index(fn)
+        elif ext == ".pkl":
+            with open(fn, "rb") as model_file:
+                model = pickle.load(model_file)
+                rcq_coarse_quantizer, itq_encoder = model["model"]
+                index = merge_rcq_itq(rcq_coarse_quantizer, itq_encoder)
+        logger.info(f"Loaded index from {fn}")
+        return index
+
+    def write_index(
+        self,
+        index: faiss.Index,
+        filename: str,
+    ):
+        fn = self.get_local_filename(filename)
+        logger.info(f"Saving index to {fn}")
+        faiss.write_index(index, fn)
+        self.upload_file_to_blobstore(filename)
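get_dataset now dispatches on the descriptor's namespace, so the three sourcing modes introduced above can be exercised like this (a sketch; the cache path, sizes, and file name are made up):

    io = BenchmarkIO(path="/tmp/bench_fw")
    # standard dataset via faiss.contrib.datasets.dataset_from_name(), train split
    xt = io.get_dataset(DatasetDescriptor(namespace="std_t", tablename="bigann1M"))
    # synthetic data: 128 dimensions, random seed 1234
    xs = io.get_dataset(DatasetDescriptor(namespace="syn", tablename="128_1234", num_vectors=10000))
    # local .npy file under io.path, loaded memory-mapped and truncated to num_vectors
    xb = io.get_dataset(DatasetDescriptor(tablename="my_vectors.npy", num_vectors=100000))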
descriptors.py
@@ -1,15 +1,21 @@
-# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
-from typing import Any, List, Optional
+from typing import Any, Dict, List, Optional
 
 
 @dataclass
 class IndexDescriptor:
-    factory: str
     bucket: Optional[str] = None
+    # either path or factory should be set,
+    # but not both at the same time.
     path: Optional[str] = None
-    parameters: Optional[dict[str, int]] = None
+    factory: Optional[str] = None
+    construction_params: Optional[List[Dict[str, int]]] = None
+    search_params: Optional[Dict[str, int]] = None
     # range metric definitions
     # key: name
     # value: one of the following:
@@ -25,14 +31,38 @@ class IndexDescriptor:
     # [[radius1_from, radius1_to, score1], ...]
     # [radius1_from, radius1_to) -> score1,
     # [radius2_from, radius2_to) -> score2
-    range_metrics: Optional[dict[str, Any]] = None
+    range_metrics: Optional[Dict[str, Any]] = None
+    radius: Optional[float] = None
 
 
 @dataclass
 class DatasetDescriptor:
+    # namespace possible values:
+    # 1. a hive namespace
+    # 2. 'std_t', 'std_d', 'std_q' for the standard datasets
+    #    via faiss.contrib.datasets.dataset_from_name()
+    #    t - training, d - database, q - queries
+    #    eg. "std_t"
+    # 3. 'syn' for synthetic data
+    # 4. None for local files
     namespace: Optional[str] = None
+
+    # tablename possible values, corresponding to the
+    # namespace value above:
+    # 1. a hive table name
+    # 2. name of the standard dataset as recognized
+    #    by faiss.contrib.datasets.dataset_from_name()
+    #    eg. "bigann1M"
+    # 3. d_seed, eg. 128_1234 for 128 dimensional vectors
+    #    with seed 1234
+    # 4. a local file name (relative to benchmark_io.path)
     tablename: Optional[str] = None
+
+    # partition names and values for hive
+    # eg. ["ds=2021-09-01"]
     partitions: Optional[List[str]] = None
+
+    # number of vectors to load from the dataset
+    num_vectors: Optional[int] = None
 
     def __hash__(self):
@@ -40,9 +70,11 @@ class DatasetDescriptor:
 
     def get_filename(
         self,
-        prefix: str = "v",
+        prefix: str = None,
     ) -> str:
-        filename = prefix + "_"
+        filename = ""
+        if prefix is not None:
+            filename += prefix + "_"
         if self.namespace is not None:
             filename += self.namespace + "_"
         assert self.tablename is not None
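The descriptor now identifies an index in one of two mutually exclusive ways, matching the two wrapper classes in the new index.py below: a factory string (IndexFromFactory) or a path to a pre-trained codec (IndexFromCodec). A sketch of the two styles (the bucket, path, and radius values are made up):

    # built and trained from scratch via the factory string
    from_factory = IndexDescriptor(
        factory="IVF1024,PQ32",
        search_params={"nprobe": 32},
    )
    # loaded as a pre-trained codec from blob storage
    from_codec = IndexDescriptor(
        bucket="my-bucket",
        path="codecs/ivf1024_pq32.codec",
        radius=1.5,
    )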
@ -0,0 +1,785 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from time import perf_counter
|
||||
from typing import ClassVar, Dict, List, Optional
|
||||
|
||||
import faiss # @manual=//faiss/python:pyfaiss_gpu
|
||||
|
||||
import numpy as np
|
||||
from faiss.contrib.evaluation import ( # @manual=//faiss/contrib:faiss_contrib_gpu
|
||||
OperatingPointsWithRanges,
|
||||
)
|
||||
|
||||
from faiss.contrib.factory_tools import ( # @manual=//faiss/contrib:faiss_contrib_gpu
|
||||
reverse_index_factory,
|
||||
)
|
||||
from faiss.contrib.ivf_tools import ( # @manual=//faiss/contrib:faiss_contrib_gpu
|
||||
add_preassigned,
|
||||
replace_ivf_quantizer,
|
||||
)
|
||||
|
||||
from .descriptors import DatasetDescriptor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def timer(name, func, once=False) -> float:
|
||||
logger.info(f"Measuring {name}")
|
||||
t1 = perf_counter()
|
||||
res = func()
|
||||
t2 = perf_counter()
|
||||
t = t2 - t1
|
||||
repeat = 1
|
||||
if not once and t < 1.0:
|
||||
repeat = int(2.0 // t)
|
||||
logger.info(
|
||||
f"Time for {name}: {t:.3f} seconds, repeating {repeat} times"
|
||||
)
|
||||
t1 = perf_counter()
|
||||
for _ in range(repeat):
|
||||
res = func()
|
||||
t2 = perf_counter()
|
||||
t = (t2 - t1) / repeat
|
||||
logger.info(f"Time for {name}: {t:.3f} seconds")
|
||||
return res, t, repeat
|
||||
|
||||
|
||||
def refine_distances_knn(
|
||||
D: np.ndarray, I: np.ndarray, xq: np.ndarray, xb: np.ndarray, metric
|
||||
):
|
||||
return np.where(
|
||||
I >= 0,
|
||||
np.square(np.linalg.norm(xq[:, None] - xb[I], axis=2))
|
||||
if metric == faiss.METRIC_L2
|
||||
else np.einsum("qd,qkd->qk", xq, xb[I]),
|
||||
D,
|
||||
)
|
||||
|
||||
|
||||
def refine_distances_range(
|
||||
lims: np.ndarray,
|
||||
D: np.ndarray,
|
||||
I: np.ndarray,
|
||||
xq: np.ndarray,
|
||||
xb: np.ndarray,
|
||||
metric,
|
||||
):
|
||||
with ThreadPool(32) as pool:
|
||||
R = pool.map(
|
||||
lambda i: (
|
||||
np.sum(np.square(xq[i] - xb[I[lims[i]:lims[i + 1]]]), axis=1)
|
||||
if metric == faiss.METRIC_L2
|
||||
else np.tensordot(
|
||||
xq[i], xb[I[lims[i]:lims[i + 1]]], axes=(0, 1)
|
||||
)
|
||||
)
|
||||
if lims[i + 1] > lims[i]
|
||||
else [],
|
||||
range(len(lims) - 1),
|
||||
)
|
||||
return np.hstack(R)
|
||||
|
||||
|
||||
# The classes below are wrappers around Faiss indices, with different
|
||||
# implementations for the case when we start with an already trained
|
||||
# index (IndexFromCodec) vs factory strings (IndexFromFactory).
|
||||
# In both cases the classes have operations for adding to an index
|
||||
# and searching it, and outputs are cached on disk.
|
||||
# IndexFromFactory also decomposes the index (pretransform and quantizer)
|
||||
# and trains sub-indices independently.
|
||||
class IndexBase:
|
||||
def set_io(self, benchmark_io):
|
||||
self.io = benchmark_io
|
||||
|
||||
@staticmethod
|
||||
def param_dict_list_to_name(param_dict_list):
|
||||
if not param_dict_list:
|
||||
return ""
|
||||
l = 0
|
||||
n = ""
|
||||
for param_dict in param_dict_list:
|
||||
n += IndexBase.param_dict_to_name(param_dict, f"cp{l}")
|
||||
return n
|
||||
|
||||
@staticmethod
|
||||
def param_dict_to_name(param_dict, prefix="sp"):
|
||||
if not param_dict:
|
||||
return ""
|
||||
n = prefix
|
||||
for name, val in param_dict.items():
|
||||
if name != "noop":
|
||||
n += f"_{name}_{val}"
|
||||
if n == prefix:
|
||||
return ""
|
||||
n += "."
|
||||
return n
|
||||
|
||||
@staticmethod
|
||||
def set_index_param_dict_list(index, param_dict_list):
|
||||
if not param_dict_list:
|
||||
return
|
||||
index = faiss.downcast_index(index)
|
||||
for param_dict in param_dict_list:
|
||||
assert index is not None
|
||||
IndexBase.set_index_param_dict(index, param_dict)
|
||||
index = faiss.try_extract_index_ivf(index)
|
||||
|
||||
@staticmethod
|
||||
def set_index_param_dict(index, param_dict):
|
||||
if not param_dict:
|
||||
return
|
||||
for name, val in param_dict.items():
|
||||
IndexBase.set_index_param(index, name, val)
|
||||
|
||||
@staticmethod
|
||||
def set_index_param(index, name, val):
|
||||
index = faiss.downcast_index(index)
|
||||
|
||||
if isinstance(index, faiss.IndexPreTransform):
|
||||
Index.set_index_param(index.index, name, val)
|
||||
elif name == "efSearch":
|
||||
index.hnsw.efSearch
|
||||
index.hnsw.efSearch = int(val)
|
||||
elif name == "efConstruction":
|
||||
index.hnsw.efConstruction
|
||||
index.hnsw.efConstruction = int(val)
|
||||
elif name == "nprobe":
|
||||
index_ivf = faiss.extract_index_ivf(index)
|
||||
index_ivf.nprobe
|
||||
index_ivf.nprobe = int(val)
|
||||
elif name == "k_factor":
|
||||
index.k_factor
|
||||
index.k_factor = int(val)
|
||||
elif name == "parallel_mode":
|
||||
index_ivf = faiss.extract_index_ivf(index)
|
||||
index_ivf.parallel_mode
|
||||
index_ivf.parallel_mode = int(val)
|
||||
elif name == "noop":
|
||||
pass
|
||||
else:
|
||||
raise RuntimeError(f"could not set param {name} on {index}")
|
||||
|
||||
def is_flat(self):
|
||||
codec = faiss.downcast_index(self.get_model())
|
||||
return isinstance(codec, faiss.IndexFlat)
|
||||
|
||||
def is_ivf(self):
|
||||
codec = self.get_model()
|
||||
return faiss.try_extract_index_ivf(codec) is not None
|
||||
|
||||
def is_pretransform(self):
|
||||
codec = self.get_model()
|
||||
if isinstance(codec, faiss.IndexRefine):
|
||||
codec = faiss.downcast_index(codec.base_index)
|
||||
return isinstance(codec, faiss.IndexPreTransform)
|
||||
|
||||
# index is a codec + database vectors
|
||||
# in other words: a trained Faiss index
|
||||
# that contains database vectors
|
||||
def get_index_name(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_index(self):
|
||||
raise NotImplementedError
|
||||
|
||||
# codec is a trained model
|
||||
# in other words: a trained Faiss index
|
||||
# without any database vectors
|
||||
def get_codec_name(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_codec(self):
|
||||
raise NotImplementedError
|
||||
|
||||
# model is an untrained Faiss index
|
||||
# it can be used for training (see codec)
|
||||
# or to inspect its structure
|
||||
def get_model_name(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_model(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def transform(self, vectors):
|
||||
transformed_vectors = DatasetDescriptor(
|
||||
tablename=f"{vectors.get_filename()}{self.get_codec_name()}transform.npy"
|
||||
)
|
||||
if not self.io.file_exist(transformed_vectors.tablename):
|
||||
codec = self.fetch_codec()
|
||||
assert isinstance(codec, faiss.IndexPreTransform)
|
||||
transform = faiss.downcast_VectorTransform(codec.chain.at(0))
|
||||
x = self.io.get_dataset(vectors)
|
||||
xt = transform.apply(x)
|
||||
self.io.write_nparray(xt, transformed_vectors.tablename)
|
||||
return transformed_vectors
|
||||
|
||||
def knn_search_quantizer(self, index, query_vectors, k):
|
||||
if self.is_pretransform():
|
||||
pretransform = self.get_pretransform()
|
||||
quantizer_query_vectors = pretransform.transform(query_vectors)
|
||||
else:
|
||||
pretransform = None
|
||||
quantizer_query_vectors = query_vectors
|
||||
|
||||
QD, QI, _, QP = self.get_quantizer(pretransform).knn_search(
|
||||
search_parameters=None,
|
||||
query_vectors=quantizer_query_vectors,
|
||||
k=k,
|
||||
)
|
||||
xqt = self.io.get_dataset(quantizer_query_vectors)
|
||||
return xqt, QD, QI, QP
|
||||
|
||||
def get_knn_search_name(
|
||||
self,
|
||||
search_parameters: Optional[Dict[str, int]],
|
||||
query_vectors: DatasetDescriptor,
|
||||
k: int,
|
||||
):
|
||||
name = self.get_index_name()
|
||||
name += Index.param_dict_to_name(search_parameters)
|
||||
name += query_vectors.get_filename("q")
|
||||
name += f"k_{k}."
|
||||
return name
|
||||
|
||||
def knn_search(
|
||||
self,
|
||||
search_parameters: Optional[Dict[str, int]],
|
||||
query_vectors: DatasetDescriptor,
|
||||
k: int,
|
||||
):
|
||||
logger.info("knn_seach: begin")
|
||||
filename = (
|
||||
self.get_knn_search_name(search_parameters, query_vectors, k)
|
||||
+ "zip"
|
||||
)
|
||||
if self.io.file_exist(filename):
|
||||
logger.info(f"Using cached results for {filename}")
|
||||
D, I, R, P = self.io.read_file(filename, ["D", "I", "R", "P"])
|
||||
else:
|
||||
xq = self.io.get_dataset(query_vectors)
|
||||
index = self.get_index()
|
||||
Index.set_index_param_dict(index, search_parameters)
|
||||
|
||||
if self.is_ivf():
|
||||
xqt, QD, QI, QP = self.knn_search_quantizer(
|
||||
index, query_vectors, search_parameters["nprobe"]
|
||||
)
|
||||
index_ivf = faiss.extract_index_ivf(index)
|
||||
if index_ivf.parallel_mode != 2:
|
||||
logger.info("Setting IVF parallel mode")
|
||||
index_ivf.parallel_mode = 2
|
||||
|
||||
(D, I), t, repeat = timer(
|
||||
"knn_search_preassigned",
|
||||
lambda: index_ivf.search_preassigned(xqt, k, QI, QD),
|
||||
)
|
||||
else:
|
||||
(D, I), t, _ = timer("knn_search", lambda: index.search(xq, k))
|
||||
if self.is_flat() or not hasattr(self, "database_vectors"): # TODO
|
||||
R = D
|
||||
else:
|
||||
xb = self.io.get_dataset(self.database_vectors)
|
||||
R = refine_distances_knn(D, I, xq, xb, self.metric_type)
|
||||
P = {
|
||||
"time": t,
|
||||
"index": self.get_index_name(),
|
||||
"codec": self.get_codec_name(),
|
||||
"factory": self.factory if hasattr(self, "factory") else "",
|
||||
"search_params": search_parameters,
|
||||
"k": k,
|
||||
}
|
||||
if self.is_ivf():
|
||||
stats = faiss.cvar.indexIVF_stats
|
||||
P |= {
|
||||
"quantizer": QP,
|
||||
"nq": int(stats.nq // repeat),
|
||||
"nlist": int(stats.nlist // repeat),
|
||||
"ndis": int(stats.ndis // repeat),
|
||||
"nheap_updates": int(stats.nheap_updates // repeat),
|
||||
"quantization_time": int(
|
||||
stats.quantization_time // repeat
|
||||
),
|
||||
"search_time": int(stats.search_time // repeat),
|
||||
}
|
||||
self.io.write_file(filename, ["D", "I", "R", "P"], [D, I, R, P])
|
||||
logger.info("knn_seach: end")
|
||||
return D, I, R, P
|
||||
|
||||
def range_search(
|
||||
self,
|
||||
search_parameters: Optional[Dict[str, int]],
|
||||
query_vectors: DatasetDescriptor,
|
||||
radius: Optional[float] = None,
|
||||
):
|
||||
logger.info("range_search: begin")
|
||||
filename = (
|
||||
self.get_range_search_name(
|
||||
search_parameters, query_vectors, radius
|
||||
)
|
||||
+ "zip"
|
||||
)
|
||||
if self.io.file_exist(filename):
|
||||
logger.info(f"Using cached results for {filename}")
|
||||
lims, D, I, R, P = self.io.read_file(
|
||||
filename, ["lims", "D", "I", "R", "P"]
|
||||
)
|
||||
else:
|
||||
xq = self.io.get_dataset(query_vectors)
|
||||
index = self.get_index()
|
||||
Index.set_index_param_dict(index, search_parameters)
|
||||
|
||||
if self.is_ivf():
|
||||
xqt, QD, QI, QP = self.knn_search_quantizer(
|
||||
index, query_vectors, search_parameters["nprobe"]
|
||||
)
|
||||
index_ivf = faiss.extract_index_ivf(index)
|
||||
if index_ivf.parallel_mode != 2:
|
||||
logger.info("Setting IVF parallel mode")
|
||||
index_ivf.parallel_mode = 2
|
||||
|
||||
(lims, D, I), t, repeat = timer(
|
||||
"range_search_preassigned",
|
||||
lambda: index_ivf.range_search_preassigned(
|
||||
xqt, radius, QI, QD
|
||||
),
|
||||
)
|
||||
else:
|
||||
(lims, D, I), t, _ = timer(
|
||||
"range_search", lambda: index.range_search(xq, radius)
|
||||
)
|
||||
if self.is_flat():
|
||||
R = D
|
||||
else:
|
||||
xb = self.io.get_dataset(self.database_vectors)
|
||||
R = refine_distances_range(
|
||||
lims, D, I, xq, xb, self.metric_type
|
||||
)
|
||||
P = {
|
||||
"time": t,
|
||||
"index": self.get_codec_name(),
|
||||
"codec": self.get_codec_name(),
|
||||
"search_params": search_parameters,
|
||||
"radius": radius,
|
||||
"count": len(I),
|
||||
}
|
||||
if self.is_ivf():
|
||||
stats = faiss.cvar.indexIVF_stats
|
||||
P |= {
|
||||
"quantizer": QP,
|
||||
"nq": int(stats.nq // repeat),
|
||||
"nlist": int(stats.nlist // repeat),
|
||||
"ndis": int(stats.ndis // repeat),
|
||||
"nheap_updates": int(stats.nheap_updates // repeat),
|
||||
"quantization_time": int(
|
||||
stats.quantization_time // repeat
|
||||
),
|
||||
"search_time": int(stats.search_time // repeat),
|
||||
}
|
||||
self.io.write_file(
|
||||
filename, ["lims", "D", "I", "R", "P"], [lims, D, I, R, P]
|
||||
)
|
||||
logger.info("range_seach: end")
|
||||
return lims, D, I, R, P
|
||||
|
||||
|
||||
# Common base for IndexFromCodec and IndexFromFactory,
|
||||
# but not for the sub-indices of codec-based indices
|
||||
# IndexFromQuantizer and IndexFromPreTransform, because
|
||||
# they share the configuration of their parent IndexFromCodec
|
||||
@dataclass
|
||||
class Index(IndexBase):
|
||||
d: int
|
||||
metric: str
|
||||
database_vectors: DatasetDescriptor
|
||||
construction_params: List[Dict[str, int]]
|
||||
search_params: Dict[str, int]
|
||||
|
||||
cached_codec_name: ClassVar[str] = None
|
||||
cached_codec: ClassVar[faiss.Index] = None
|
||||
cached_index_name: ClassVar[str] = None
|
||||
cached_index: ClassVar[faiss.Index] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.metric, str):
|
||||
if self.metric == "IP":
|
||||
self.metric_type = faiss.METRIC_INNER_PRODUCT
|
||||
elif self.metric == "L2":
|
||||
self.metric_type = faiss.METRIC_L2
|
||||
else:
|
||||
raise ValueError
|
||||
elif isinstance(self.metric, int):
|
||||
self.metric_type = self.metric
|
||||
if self.metric_type == faiss.METRIC_INNER_PRODUCT:
|
||||
self.metric = "IP"
|
||||
elif self.metric_type == faiss.METRIC_L2:
|
||||
self.metric = "L2"
|
||||
else:
|
||||
raise ValueError
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
def supports_range_search(self):
|
||||
codec = self.get_codec()
|
||||
return not type(codec) in [
|
||||
faiss.IndexHNSWFlat,
|
||||
faiss.IndexIVFFastScan,
|
||||
faiss.IndexRefine,
|
||||
faiss.IndexPQ,
|
||||
]
|
||||
|
||||
def fetch_codec(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def train(self):
|
||||
# get triggers a train, if necessary
|
||||
self.get_codec()
|
||||
|
||||
def get_codec(self):
|
||||
codec_name = self.get_codec_name()
|
||||
if Index.cached_codec_name != codec_name:
|
||||
Index.cached_codec = self.fetch_codec()
|
||||
Index.cached_codec_name = codec_name
|
||||
return Index.cached_codec
|
||||
|
||||
def get_index_name(self):
|
||||
name = self.get_codec_name()
|
||||
assert self.database_vectors is not None
|
||||
name += self.database_vectors.get_filename("xb")
|
||||
return name
|
||||
|
||||
def fetch_index(self):
|
||||
index = faiss.clone_index(self.get_codec())
|
||||
assert index.ntotal == 0
|
||||
logger.info("Adding vectors to index")
|
||||
xb = self.io.get_dataset(self.database_vectors)
|
||||
|
||||
if self.is_ivf():
|
||||
xbt, QD, QI, QP = self.knn_search_quantizer(
|
||||
index, self.database_vectors, 1
|
||||
)
|
||||
index_ivf = faiss.extract_index_ivf(index)
|
||||
if index_ivf.parallel_mode != 2:
|
||||
logger.info("Setting IVF parallel mode")
|
||||
index_ivf.parallel_mode = 2
|
||||
|
||||
_, t, _ = timer(
|
||||
"add_preassigned",
|
||||
lambda: add_preassigned(index_ivf, xbt, QI.ravel()),
|
||||
once=True,
|
||||
)
|
||||
else:
|
||||
_, t, _ = timer(
|
||||
"add",
|
||||
lambda: index.add(xb),
|
||||
once=True,
|
||||
)
|
||||
assert index.ntotal == xb.shape[0] or index_ivf.ntotal == xb.shape[0]
|
||||
logger.info("Added vectors to index")
|
||||
return index
|
||||
|
||||
def get_index(self):
|
||||
index_name = self.get_index_name()
|
||||
if Index.cached_index_name != index_name:
|
||||
Index.cached_index = self.fetch_index()
|
||||
Index.cached_index_name = index_name
|
||||
return Index.cached_index
|
||||
|
||||
def get_code_size(self):
|
||||
def get_index_code_size(index):
|
||||
index = faiss.downcast_index(index)
|
||||
if isinstance(index, faiss.IndexPreTransform):
|
||||
return get_index_code_size(index.index)
|
||||
elif isinstance(index, faiss.IndexHNSWFlat):
|
||||
return index.d * 4 # TODO
|
||||
elif type(index) in [faiss.IndexRefine, faiss.IndexRefineFlat]:
|
||||
return get_index_code_size(
|
||||
index.base_index
|
||||
) + get_index_code_size(index.refine_index)
|
||||
else:
|
||||
return index.code_size
|
||||
|
||||
codec = self.get_codec()
|
||||
return get_index_code_size(codec)
|
||||
|
||||
def get_operating_points(self):
|
||||
op = OperatingPointsWithRanges()
|
||||
|
||||
def add_range_or_val(name, range):
|
||||
op.add_range(
|
||||
name,
|
||||
[self.search_params[name]]
|
||||
if self.search_params and name in self.search_params
|
||||
else range,
|
||||
)
|
||||
|
||||
op.add_range("noop", [0])
|
||||
codec = faiss.downcast_index(self.get_codec())
|
||||
codec_ivf = faiss.try_extract_index_ivf(codec)
|
||||
if codec_ivf is not None:
|
||||
add_range_or_val(
|
||||
"nprobe",
|
||||
[
|
||||
2**i
|
||||
for i in range(12)
|
||||
if 2**i <= codec_ivf.nlist * 0.25
|
||||
],
|
||||
)
|
||||
if isinstance(codec, faiss.IndexRefine):
|
||||
add_range_or_val(
|
||||
"k_factor",
|
||||
[2**i for i in range(11)],
|
||||
)
|
||||
if isinstance(codec, faiss.IndexHNSWFlat):
|
||||
add_range_or_val(
|
||||
"efSearch",
|
||||
[2**i for i in range(3, 11)],
|
||||
)
|
||||
return op
|
||||
|
||||
def get_range_search_name(
|
||||
self,
|
||||
search_parameters: Optional[Dict[str, int]],
|
||||
query_vectors: DatasetDescriptor,
|
||||
radius: Optional[float] = None,
|
||||
):
|
||||
name = self.get_index_name()
|
||||
name += Index.param_dict_to_name(search_parameters)
|
||||
name += query_vectors.get_filename("q")
|
||||
if radius is not None:
|
||||
name += f"r_{int(radius * 1000)}."
|
||||
else:
|
||||
name += "r_auto."
|
||||
return name
|
||||
|
||||
|
||||
# IndexFromCodec, IndexFromQuantizer and IndexFromPreTransform
|
||||
# are used to wrap pre-trained Faiss indices (codecs)
|
||||
@dataclass
|
||||
class IndexFromCodec(Index):
|
||||
path: str
|
||||
bucket: Optional[str] = None
|
||||
|
||||
def get_quantizer(self):
|
||||
if not self.is_ivf():
|
||||
raise ValueError("Not an IVF index")
|
||||
quantizer = IndexFromQuantizer(self)
|
||||
quantizer.set_io(self.io)
|
||||
return quantizer
|
||||
|
||||
def get_pretransform(self):
|
||||
if not self.is_ivf():
|
||||
raise ValueError("Not an IVF index")
|
||||
quantizer = IndexFromPreTransform(self)
|
||||
quantizer.set_io(self.io)
|
||||
return quantizer
|
||||
|
||||
def get_codec_name(self):
|
||||
assert self.path is not None
|
||||
name = os.path.basename(self.path)
|
||||
name += Index.param_dict_list_to_name(self.construction_params)
|
||||
return name
|
||||
|
||||
def fetch_codec(self):
|
||||
codec = self.io.read_index(
|
||||
os.path.basename(self.path),
|
||||
self.bucket,
|
||||
os.path.dirname(self.path),
|
||||
)
|
||||
assert self.d == codec.d
|
||||
assert self.metric_type == codec.metric_type
|
||||
Index.set_index_param_dict_list(codec, self.construction_params)
|
||||
return codec
|
||||
|
||||
def get_model(self):
|
||||
return self.get_codec()
|
||||
|
||||
|
||||
class IndexFromQuantizer(IndexBase):
|
||||
ivf_index: Index
|
||||
|
||||
def __init__(self, ivf_index: Index):
|
||||
self.ivf_index = ivf_index
|
||||
super().__init__()
|
||||
|
||||
def get_codec_name(self):
|
||||
return self.get_index_name()
|
||||
|
||||
def get_codec(self):
|
||||
return self.get_index()
|
||||
|
||||
def get_index_name(self):
|
||||
ivf_codec_name = self.ivf_index.get_codec_name()
|
||||
return f"{ivf_codec_name}quantizer."
|
||||
|
||||
def get_index(self):
|
||||
ivf_codec = faiss.extract_index_ivf(self.ivf_index.get_codec())
|
||||
return ivf_codec.quantizer
|
||||
|
||||
|
||||
class IndexFromPreTransform(IndexBase):
|
||||
pre_transform_index: Index
|
||||
|
||||
def __init__(self, pre_transform_index: Index):
|
||||
self.pre_transform_index = pre_transform_index
|
||||
super().__init__()
|
||||
|
||||
def get_codec_name(self):
|
||||
pre_transform_codec_name = self.pre_transform_index.get_codec_name()
|
||||
return f"{pre_transform_codec_name}pretransform."
|
||||
|
||||
def get_codec(self):
|
||||
return self.get_codec()


# IndexFromFactory is for creating and training indices from scratch
@dataclass
class IndexFromFactory(Index):
    factory: str
    training_vectors: DatasetDescriptor

    def get_codec_name(self):
        assert self.factory is not None
        name = f"{self.factory.replace(',', '_')}."
        assert self.d is not None
        assert self.metric is not None
        name += f"d_{self.d}.{self.metric.upper()}."
        if self.factory != "Flat":
            assert self.training_vectors is not None
            name += self.training_vectors.get_filename("xt")
        name += Index.param_dict_list_to_name(self.construction_params)
        return name

    def fetch_codec(self):
        codec_filename = self.get_codec_name() + "codec"
        if self.io.file_exist(codec_filename):
            codec = self.io.read_index(codec_filename)
            assert self.d == codec.d
            assert self.metric_type == codec.metric_type
        else:
            codec = self.assemble()
            if self.factory != "Flat":
                self.io.write_index(codec, codec_filename)
        return codec
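
    # Hedged naming sketch (the exact filename components depend on
    # DatasetDescriptor.get_filename): for factory="OPQ32_128,IVF512,PQ32",
    # d=128 and metric="L2", the cached codec name starts with
    # "OPQ32_128_IVF512_PQ32.d_128.L2." followed by the training-set and
    # construction-parameter parts, so identical recipes share one cache
    # entry across runs.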

    def get_model(self):
        model = faiss.index_factory(self.d, self.factory, self.metric_type)
        Index.set_index_param_dict_list(model, self.construction_params)
        return model

    def get_pretransform(self):
        model = faiss.index_factory(self.d, self.factory, self.metric_type)
        assert isinstance(model, faiss.IndexPreTransform)
        sub_index = faiss.downcast_index(model.index)
        if isinstance(sub_index, faiss.IndexFlat):
            return self
        # replace the sub-index with Flat
        codec = faiss.clone_index(model)
        codec.index = faiss.IndexFlat(codec.index.d, codec.index.metric_type)
        pretransform = IndexFromFactory(
            d=codec.d,
            metric=codec.metric_type,
            database_vectors=self.database_vectors,
            construction_params=self.construction_params,
            search_params=self.search_params,
            factory=reverse_index_factory(codec),
            training_vectors=self.training_vectors,
        )
        pretransform.set_io(self.io)
        return pretransform
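
    # Standalone illustration (assumes reverse_index_factory comes from
    # faiss.contrib.factory_tools, which this diff extends): substituting a
    # Flat sub-index turns "OPQ32_64,IVF256,PQ32" into the trainable
    # pre-transform recipe "OPQ32_64,Flat".
    #
    #     model = faiss.index_factory(128, "OPQ32_64,IVF256,PQ32", faiss.METRIC_L2)
    #     codec = faiss.clone_index(model)
    #     codec.index = faiss.IndexFlat(codec.index.d, codec.index.metric_type)
    #     reverse_index_factory(codec)  # -> "OPQ32_64,Flat"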

    def get_quantizer(self, pretransform=None):
        model = self.get_model()
        model_ivf = faiss.extract_index_ivf(model)
        assert isinstance(model_ivf, faiss.IndexIVF)
        assert ord(model_ivf.quantizer_trains_alone) in [0, 2]
        if pretransform is None:
            training_vectors = self.training_vectors
        else:
            training_vectors = pretransform.transform(self.training_vectors)
        centroids = self.k_means(training_vectors, model_ivf.nlist)
        quantizer = IndexFromFactory(
            d=model_ivf.quantizer.d,
            metric=model_ivf.quantizer.metric_type,
            database_vectors=centroids,
            construction_params=None,  # self.construction_params[1:],
            search_params=None,  # self.construction_params[0],  # TODO: verify
            factory=reverse_index_factory(model_ivf.quantizer),
            training_vectors=centroids,
        )
        quantizer.set_io(self.io)
        return quantizer
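
    # Note: the k-means centroids serve as both the training and the
    # database vectors of the recursively constructed quantizer, which is
    # what enables the arbitrary nesting depth of quantizers mentioned in
    # the commit summary -- each level is just another IndexFromFactory
    # built over the centroids of the level below.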

    def k_means(self, vectors, k):
        kmeans_vectors = DatasetDescriptor(
            tablename=f"{vectors.get_filename()}kmeans_{k}.npy"
        )
        if not self.io.file_exist(kmeans_vectors.tablename):
            x = self.io.get_dataset(vectors)
            kmeans = faiss.Kmeans(d=x.shape[1], k=k, gpu=True)
            kmeans.train(x)
            self.io.write_nparray(kmeans.centroids, kmeans_vectors.tablename)
        return kmeans_vectors
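
    # Standalone illustration of the clustering step (synthetic data, CPU):
    # train k-means and keep the (k, d) centroid matrix, which is exactly
    # what gets cached as the quantizer's dataset above.
    #
    #     import numpy as np
    #     import faiss
    #
    #     x = np.random.rand(10000, 64).astype("float32")
    #     km = faiss.Kmeans(d=64, k=256)  # the wrapped version passes gpu=True
    #     km.train(x)
    #     assert km.centroids.shape == (256, 64)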

    def assemble(self):
        model = self.get_model()
        codec = faiss.clone_index(model)
        if isinstance(model, faiss.IndexPreTransform):
            sub_index = faiss.downcast_index(model.index)
            if not isinstance(sub_index, faiss.IndexFlat):
                # replace the sub-index with Flat and fetch pre-trained
                pretransform = self.get_pretransform()
                codec = pretransform.fetch_codec()
                assert codec.is_trained
                transformed_training_vectors = pretransform.transform(
                    self.training_vectors
                )
                transformed_database_vectors = pretransform.transform(
                    self.database_vectors
                )
                # replace the Flat index with the required sub-index
                wrapper = IndexFromFactory(
                    d=sub_index.d,
                    metric=sub_index.metric_type,
                    database_vectors=transformed_database_vectors,
                    construction_params=self.construction_params,
                    search_params=self.search_params,
                    factory=reverse_index_factory(sub_index),
                    training_vectors=transformed_training_vectors,
                )
                wrapper.set_io(self.io)
                codec.index = wrapper.fetch_codec()
                assert codec.index.is_trained
        elif isinstance(model, faiss.IndexIVF):
            # replace the quantizer
            quantizer = self.get_quantizer()
            replace_ivf_quantizer(codec, quantizer.fetch_index())
            assert codec.quantizer.is_trained
            assert codec.nlist == codec.quantizer.ntotal
        elif isinstance(model, (faiss.IndexRefine, faiss.IndexRefineFlat)):
            # replace base_index
            wrapper = IndexFromFactory(
                d=model.base_index.d,
                metric=model.base_index.metric_type,
                database_vectors=self.database_vectors,
                construction_params=self.construction_params,
                search_params=self.search_params,
                factory=reverse_index_factory(model.base_index),
                training_vectors=self.training_vectors,
            )
            wrapper.set_io(self.io)
            codec.base_index = wrapper.fetch_codec()
            assert codec.base_index.is_trained

        xt = self.io.get_dataset(self.training_vectors)
        codec.train(xt)
        return codec
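
# Standalone illustration of the invariant assemble() relies on for the IVF
# branch (synthetic data, run on its own): when an IVF index is built around
# an already-populated quantizer with ntotal == nlist, train() only has to
# fit the fine codec, reusing the centroids as-is.
import numpy as np
import faiss

xt = np.random.rand(20000, 64).astype("float32")
km = faiss.Kmeans(d=64, k=256)
km.train(xt)
coarse = faiss.IndexFlatL2(64)
coarse.add(km.centroids)
index = faiss.IndexIVFPQ(coarse, 64, 256, 8, 8)
assert index.nlist == coarse.ntotal
index.train(xt)  # trains only the PQ codebooks; the quantizer is kept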

@@ -0,0 +1,37 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

from bench_fw.benchmark import Benchmark
from bench_fw.benchmark_io import BenchmarkIO
from bench_fw.descriptors import DatasetDescriptor, IndexDescriptor

logging.basicConfig(level=logging.INFO)

benchmark = Benchmark(
    training_vectors=DatasetDescriptor(
        namespace="std_d", tablename="sift1M"
    ),
    database_vectors=DatasetDescriptor(
        namespace="std_d", tablename="sift1M"
    ),
    query_vectors=DatasetDescriptor(
        namespace="std_q", tablename="sift1M"
    ),
    index_descs=[
        IndexDescriptor(
            factory=f"IVF{2 ** nlist},Flat",
        )
        for nlist in range(8, 15)
    ],
    k=1,
    distance_metric="L2",
)
io = BenchmarkIO(
    path="/checkpoint",
)
benchmark.set_io(io)
print(benchmark.benchmark("result.json"))
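
# The sweep above expands to IVF256, IVF512, ..., IVF16384 (2**8 through
# 2**14): seven coarse-quantizer sizes benchmarked on the same SIFT1M data.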

@@ -0,0 +1,61 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

from bench_fw.benchmark import Benchmark
from bench_fw.benchmark_io import BenchmarkIO
from bench_fw.descriptors import DatasetDescriptor, IndexDescriptor

logging.basicConfig(level=logging.INFO)

benchmark = Benchmark(
    training_vectors=DatasetDescriptor(
        tablename="training.npy", num_vectors=200000
    ),
    database_vectors=DatasetDescriptor(
        tablename="database.npy", num_vectors=200000
    ),
    query_vectors=DatasetDescriptor(tablename="query.npy", num_vectors=2000),
    index_descs=[
        IndexDescriptor(
            factory="Flat",
            range_metrics={
                "weighted": [
                    [0.1, 0.928],
                    [0.2, 0.865],
                    [0.3, 0.788],
                    [0.4, 0.689],
                    [0.5, 0.49],
                    [0.6, 0.308],
                    [0.7, 0.193],
                    [0.8, 0.0],
                ]
            },
        ),
        IndexDescriptor(
            factory="OPQ32_128,IVF512,PQ32",
        ),
        IndexDescriptor(
            factory="OPQ32_256,IVF512,PQ32",
        ),
        IndexDescriptor(
            factory="HNSW32",
            construction_params=[
                {
                    "efConstruction": 64,
                }
            ],
        ),
    ],
    k=10,
    distance_metric="L2",
    range_ref_index_desc="Flat",
)
io = BenchmarkIO(
    path="/checkpoint",
)
benchmark.set_io(io)
print(benchmark.benchmark("result.json"))
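
# Reading of the configuration above (hedged, inferred from the framework
# code rather than stated in this diff): the "weighted" range_metrics
# entries look like (distance, score) knots of a piecewise scoring curve --
# a match at range-search distance 0.1 scores 0.928, decaying to 0.0 at
# 0.8 -- and range_ref_index_desc="Flat" makes the exact Flat index the
# ground-truth reference for those range queries.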

@@ -72,6 +72,9 @@ def get_code_size(d, indexkey):
    raise RuntimeError("cannot parse " + indexkey)


def get_hnsw_M(index):
    return index.hnsw.cum_nneighbor_per_level.at(1) // 2
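
# Why this works: HNSW allocates 2*M neighbor slots at level 0 and M at each
# higher level; cum_nneighbor_per_level is the cumulative table of those
# slots, so .at(1) -- the total after level 0 -- equals 2*M, and halving it
# recovers the M used at construction time.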


def reverse_index_factory(index):
    """

@@ -80,21 +83,47 @@ def reverse_index_factory(index):
     index = faiss.downcast_index(index)
     if isinstance(index, faiss.IndexFlat):
         return "Flat"
-    if isinstance(index, faiss.IndexIVF):
+    elif isinstance(index, faiss.IndexIVF):
         quantizer = faiss.downcast_index(index.quantizer)
 
         if isinstance(quantizer, faiss.IndexFlat):
-            prefix = "IVF%d" % index.nlist
+            prefix = f"IVF{index.nlist}"
         elif isinstance(quantizer, faiss.MultiIndexQuantizer):
-            prefix = "IMI%dx%d" % (quantizer.pq.M, quantizer.pq.nbit)
+            prefix = f"IMI{quantizer.pq.M}x{quantizer.pq.nbits}"
         elif isinstance(quantizer, faiss.IndexHNSW):
-            prefix = "IVF%d_HNSW%d" % (index.nlist, quantizer.hnsw.M)
+            prefix = f"IVF{index.nlist}_HNSW{get_hnsw_M(quantizer)}"
         else:
-            prefix = "IVF%d(%s)" % (index.nlist, reverse_index_factory(quantizer))
+            prefix = f"IVF{index.nlist}({reverse_index_factory(quantizer)})"
 
         if isinstance(index, faiss.IndexIVFFlat):
             return prefix + ",Flat"
         if isinstance(index, faiss.IndexIVFScalarQuantizer):
             return prefix + ",SQ8"
+        if isinstance(index, faiss.IndexIVFPQ):
+            return prefix + f",PQ{index.pq.M}x{index.pq.nbits}"
+
+    elif isinstance(index, faiss.IndexPreTransform):
+        assert index.chain.size() == 1
+        vt = faiss.downcast_VectorTransform(index.chain.at(0))
+        if isinstance(vt, faiss.OPQMatrix):
+            return f"OPQ{vt.M}_{vt.d_out},{reverse_index_factory(index.index)}"
+
+    elif isinstance(index, faiss.IndexHNSW):
+        return f"HNSW{get_hnsw_M(index)}"
+
+    elif isinstance(index, faiss.IndexRefine):
+        return f"{reverse_index_factory(index.base_index)},Refine({reverse_index_factory(index.refine_index)})"
+
+    elif isinstance(index, faiss.IndexPQFastScan):
+        return f"PQ{index.pq.M}x{index.pq.nbits}fs"
+
+    elif isinstance(index, faiss.IndexScalarQuantizer):
+        sqtypes = {
+            faiss.ScalarQuantizer.QT_8bit: "8",
+            faiss.ScalarQuantizer.QT_4bit: "4",
+            faiss.ScalarQuantizer.QT_6bit: "6",
+            faiss.ScalarQuantizer.QT_fp16: "fp16",
+        }
+        return f"SQ{sqtypes[index.sq.qtype]}"
+
     raise NotImplementedError()
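
# Round-trip sketch (runs in this module's context; the keys are synthetic
# examples): recipes whose parameters survive reconstruction map back to
# themselves.
for key in ["IVF512,Flat", "IVF256,SQ8", "HNSW32"]:
    assert reverse_index_factory(
        faiss.index_factory(128, key, faiss.METRIC_L2)
    ) == key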