faiss/benchs/bench_fw/benchmark.py
Gergely Szilvasy c3b9374984 bench_fw - fixes & nits for oss (#3102)
Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/3102

Reviewed By: pemazare

Differential Revision: D50426528

Pulled By: algoriddle

fbshipit-source-id: 886960b8b522318967fc5ec305666871b496cae8
2023-10-20 07:53:56 -07:00

723 lines
26 KiB
Python

# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
import json
import logging
from contextlib import contextmanager
from dataclasses import dataclass
from multiprocessing.pool import ThreadPool
from operator import itemgetter
from statistics import median, mean
from time import perf_counter
from typing import Any, Callable, Iterator, List, Optional

import faiss  # @manual=//faiss/python:pyfaiss_gpu
import numpy as np
from faiss.contrib.evaluation import (  # @manual=//faiss/contrib:faiss_contrib_gpu
    knn_intersection_measure,
    OperatingPointsWithRanges,
)
from faiss.contrib.ivf_tools import (  # @manual=//faiss/contrib:faiss_contrib_gpu
    add_preassigned,
)
from scipy.optimize import curve_fit

from .descriptors import DatasetDescriptor, IndexDescriptor
logger = logging.getLogger(__name__)
@contextmanager
def timer(name) -> Iterator[Callable[[], float]]:
    """Time the body of a ``with`` block and log the result.

    Yields a zero-argument callable returning the measured duration in
    seconds (based on ``perf_counter``). While the block is still running
    the callable returns 0.0; once the block has exited it returns the
    elapsed time.

    Args:
        name: label used in the log messages.
    """
    logger.info(f"Measuring {name}")
    t1 = t2 = perf_counter()
    try:
        yield lambda: t2 - t1
    finally:
        # Record and report the duration even when the timed block raises,
        # so a failing measurement still leaves a trace in the log.
        t2 = perf_counter()
        logger.info(f"Time for {name}: {t2 - t1:.3f} seconds")
def refine_distances_knn(
    D: np.ndarray, I: np.ndarray, xq: np.ndarray, xb: np.ndarray, metric
):
    """Recompute k-NN result distances from the uncompressed vectors.

    Args:
        D: distances returned by the index, same shape as ``I``.
        I: database labels returned by the index, one row per query.
        xq: query vectors, one row per query.
        xb: database vectors, indexed by the labels in ``I``.
        metric: ``faiss.METRIC_L2`` selects squared L2 distances; any other
            value is treated as inner product.

    Returns:
        An array shaped like ``I`` holding the recomputed distance wherever
        the label is non-negative and the original value from ``D``
        elsewhere.
    """
    # Rows with a negative label pick up an arbitrary database row here
    # (negative indexing); they are replaced by D in the np.where below.
    neighbors = xb[I]
    if metric == faiss.METRIC_L2:
        # Sum of squares directly: avoids the sqrt/square round trip of
        # np.square(np.linalg.norm(...)) and matches refine_distances_range.
        refined = np.sum(np.square(xq[:, None] - neighbors), axis=2)
    else:
        refined = np.einsum("qd,qkd->qk", xq, neighbors)
    return np.where(I >= 0, refined, D)
def refine_distances_range(
    lims: np.ndarray,
    D: np.ndarray,
    I: np.ndarray,
    xq: np.ndarray,
    xb: np.ndarray,
    metric,
):
    """Recompute range-search result distances from the uncompressed vectors.

    Args:
        lims: result offsets; the results of query ``i`` are
            ``I[lims[i]:lims[i + 1]]``.
        D: distances returned by the index (not read; the distances are
            recomputed from ``xq`` and ``xb``).
        I: flat array of database labels for all queries.
        xq: query vectors, one row per query.
        xb: database vectors, indexed by the labels in ``I``.
        metric: ``faiss.METRIC_L2`` selects squared L2 distances; any other
            value is treated as inner product.

    Returns:
        A flat array with one recomputed distance per entry of ``I``, in the
        same order. Empty when there are no queries or no results.
    """
    nq = len(lims) - 1
    # Use a typed empty array for queries without results: a bare [] would
    # make np.hstack promote the whole output to float64.
    out_dtype = np.result_type(xq.dtype, xb.dtype)
    if nq <= 0:
        # np.hstack([]) raises, so handle "no queries" explicitly.
        return np.empty(0, dtype=out_dtype)

    def refine_one(i):
        begin, end = lims[i], lims[i + 1]
        if end <= begin:
            return np.empty(0, dtype=out_dtype)
        candidates = xb[I[begin:end]]
        if metric == faiss.METRIC_L2:
            return np.sum(np.square(xq[i] - candidates), axis=1)
        return np.tensordot(xq[i], candidates, axes=(0, 1))

    # No point in starting more worker threads than there are queries.
    with ThreadPool(min(32, nq)) as pool:
        R = pool.map(refine_one, range(nq))
    return np.hstack(R)
def range_search_pr_curve(
    dist_ann: np.ndarray, metric_score: np.ndarray, gt_rsm: float
):
    """Build a precision/recall curve over increasing result distance.

    Results are sorted by ``dist_ann``; precision is the cumulative score
    divided by the number of results so far and recall is the cumulative
    score divided by ``gt_rsm``. Only the last row of each run of equal
    distances is kept, and rows are then de-duplicated on precision and
    recall rounded to whole percent (the first occurrence is kept).

    Args:
        dist_ann: 1-D array of result distances.
        metric_score: 1-D array of per-result scores, same shape.
        gt_rsm: ground-truth total score used as the recall denominator.

    Returns:
        A dict of equally long lists keyed by ``dist_ann``,
        ``metric_score_sample``, ``cum_score``, ``precision``, ``recall``
        and ``unique_key``. All lists are empty for empty input.
    """
    assert dist_ann.shape == metric_score.shape
    assert dist_ann.ndim == 1
    columns = (
        "dist_ann",
        "metric_score_sample",
        "cum_score",
        "precision",
        "recall",
        "unique_key",
    )
    if len(dist_ann) == 0:
        # A range search may legitimately return nothing; the mask logic
        # below would otherwise fail on ``mask[-1]``.
        return {column: [] for column in columns}

    sort_by_dist_ann = dist_ann.argsort()
    dist_ann = dist_ann[sort_by_dist_ann]
    metric_score = metric_score[sort_by_dist_ann]
    cum_score = np.cumsum(metric_score)
    precision = cum_score / np.arange(1, len(cum_score) + 1)
    recall = cum_score / gt_rsm
    # Packs rounded precision and recall (in percent) into one number.
    unique_key = np.round(precision * 100) * 100 + np.round(recall * 100)
    tbl = np.vstack(
        [dist_ann, metric_score, cum_score, precision, recall, unique_key]
    )

    # Keep the last row of every run of equal distances: it carries the
    # largest cumulative score for that distance.
    last_of_run = np.empty(len(dist_ann), bool)
    last_of_run[-1] = True
    last_of_run[:-1] = dist_ann[1:] != dist_ann[:-1]
    tbl = tbl[:, last_of_run]

    # np.unique returns first occurrences in key order; re-sort the indices
    # to restore increasing-distance order.
    _, unique_key_idx = np.unique(tbl[5], return_index=True)
    rows = tbl[:, np.sort(unique_key_idx)].tolist()
    return dict(zip(columns, rows))
def set_index_parameter(index, name, val):
    """Set one search-time parameter on a faiss index.

    The index is downcast first. An ``IndexPreTransform`` forwards the
    request to the index it wraps. A name starting with ``quantizer_`` is
    applied, without that prefix, to the quantizer of the contained IVF
    index. Supported names are ``efSearch``, ``nprobe`` and ``noop``
    (which does nothing).

    Raises:
        RuntimeError: if ``name`` is not one of the supported parameters.
    """
    index = faiss.downcast_index(index)

    if isinstance(index, faiss.IndexPreTransform):
        set_index_parameter(index.index, name, val)
        return

    if name.startswith("quantizer_"):
        quantizer = faiss.extract_index_ivf(index).quantizer
        set_index_parameter(quantizer, name.removeprefix("quantizer_"), val)
        return

    if name == "efSearch":
        # NOTE(review): the bare read presumably checks that the attribute
        # exists before it is assigned -- confirm.
        index.hnsw.efSearch
        index.hnsw.efSearch = int(val)
        return

    if name == "nprobe":
        index_ivf = faiss.extract_index_ivf(index)
        # NOTE(review): same presumed existence check as for efSearch.
        index_ivf.nprobe
        index_ivf.nprobe = int(val)
        return

    if name == "noop":
        return

    raise RuntimeError(f"could not set param {name} on {index}")
def optimizer(codec, search, cost_metric, perf_metric):
    """Explore search parameters, skipping settings that cannot be optimal.

    The parameter space always contains ``noop`` and, when ``codec``
    contains an IVF index, ``nprobe`` over the powers of two below 2**12
    that are smaller than 10% of ``nlist``. The first and the last
    experiment run first, the remaining ones in a fixed pseudo-random
    order. An experiment is skipped when its predicted bounds are not
    Pareto-optimal with respect to the points measured so far.

    Args:
        codec: faiss index used to derive the parameter ranges.
        search: callable invoked as ``search(parameters, cost_metric,
            perf_metric)``; must return a ``(cost, perf)`` pair.
        cost_metric: passed through to ``search``.
        perf_metric: passed through to ``search``.
    """
    op = OperatingPointsWithRanges()
    op.add_range("noop", [0])

    codec_ivf = faiss.try_extract_index_ivf(codec)
    if codec_ivf is not None:
        nprobe_values = []
        for exponent in range(12):
            candidate = 2**exponent
            if candidate < codec_ivf.nlist * 0.1:
                nprobe_values.append(candidate)
        op.add_range("nprobe", nprobe_values)

    totex = op.num_experiments()
    rs = np.random.RandomState(123)
    experiments = [0]
    if totex > 1:
        # Both extremes first, then everything in between in shuffled order.
        shuffled_middle = rs.permutation(totex - 2) + 1
        experiments = [0, totex - 1] + list(shuffled_middle)
    print(f"total nb experiments {totex}, running {len(experiments)}")

    for cno in experiments:
        key = op.cno_to_key(cno)
        parameters = op.get_parameters(key)
        max_perf, min_cost = op.predict_bounds(key)
        if op.is_pareto_optimal(max_perf, min_cost):
            logger.info(f"{cno=:4d} {str(parameters):50}: RUN")
            cost, perf = search(parameters, cost_metric, perf_metric)
            logger.info(
                f"{cno=:4d} {str(parameters):50}: DONE, {cost=:.3f} {perf=:.3f}"
            )
            op.add_operating_point(key, perf, cost)
        else:
            logger.info(
                f"{cno=:4d} {str(parameters):50}: SKIP, {max_perf=:.3f} {min_cost=:.3f}",
            )
def distance_ratio_measure(I, R, D_GT, metric):
    """Return the ratio between summed result and ground-truth distances.

    Only positions where the label in ``I`` is non-negative contribute to
    either sum. For inner product the result is ``sum(R) / sum(D_GT)``, for
    L2 it is ``sum(D_GT) / sum(R)``.

    Raises:
        RuntimeError: if ``metric`` is neither inner product nor L2.
    """
    has_result = I >= 0
    total_result = np.sum(np.where(has_result, R, 0))
    total_ground_truth = np.sum(np.where(has_result, D_GT, 0))

    if metric == faiss.METRIC_INNER_PRODUCT:
        ratio = total_result / total_ground_truth
    elif metric == faiss.METRIC_L2:
        ratio = total_ground_truth / total_result
    else:
        raise RuntimeError(f"unknown metric {metric}")
    return ratio.item()
# range_metric possible values:
#
# radius
#    [0..radius) -> 1
#    [radius..inf) -> 0
#
# [[radius1, score1], ...]
#    [0..radius1) -> score1
#    [radius1..radius2) -> score2
#
# [[radius1_from, radius1_to, score1], ...]
#    [radius1_from, radius1_to) -> score1,
#    [radius2_from, radius2_to) -> score2
def get_range_search_metric_function(range_metric, D, R):
    """Build the function that scores range-search results by distance.

    Args:
        range_metric: either a single radius or a list of
            ``[radius, score]`` / ``[radius_from, radius_to, score]``
            entries (see the comment above this function).
        D: distances of a reference range search, or None.
        R: recomputed distances for the same results, same shape as ``D``;
            required when ``D`` is given.

    Returns:
        A tuple ``(radius, score_fn, coefficients, training_data)``:

        * list form: ``radius`` is the largest bucket radius, ``score_fn``
          evaluates a sigmoid fitted to the buckets below that radius and
          0 from it on, ``coefficients`` are the fitted sigmoid parameters
          and ``training_data`` lists ``(radius, score, from, to)`` per
          bucket;
        * scalar form: ``radius`` is twice the effective radius,
          ``score_fn`` is 1 below the effective radius and 0 otherwise,
          and both lists are empty.

    Raises:
        AssertionError: on malformed bucket definitions, mismatching
            ``D``/``R``, or a bucket without any reference sample.
    """
    if D is not None:
        assert R is not None
        assert D.shape == R.shape
    if isinstance(range_metric, list):
        aradius, ascore, aradius_from, aradius_to = [], [], [], []
        radius_to = 0
        for rsd in range_metric:
            assert isinstance(rsd, list)
            if len(rsd) == 3:
                radius_from, radius_to, score = rsd
            elif len(rsd) == 2:
                # Two-element form: the bucket starts where the previous
                # one ended.
                radius_from = radius_to
                radius_to, score = rsd
            else:
                raise AssertionError(f"invalid range definition {rsd}")
            # radius_from and radius_to are compressed distances,
            # we need to convert them to real embedding distances.
            if D is not None:
                in_bucket = (D <= radius_to) & (D > radius_from)
                assert in_bucket.any(), f"no reference sample in {rsd}"
                real_radius = np.mean(R[in_bucket]).item()
            else:
                # statistics.mean returns an int for some integer inputs;
                # coerce so the float assertion below holds for int radii.
                real_radius = float(mean([radius_from, radius_to]))
            logger.info(
                f"range_search_metric_function {radius_from=} {radius_to=} {real_radius=} {score=}"
            )
            aradius.append(real_radius)
            ascore.append(score)
            aradius_from.append(radius_from)
            aradius_to.append(radius_to)

        def sigmoid(x, a, b, c):
            return a / (1 + np.exp(b * x - c))

        cutoff = max(aradius)
        popt, _ = curve_fit(sigmoid, aradius, ascore, [1, 5, 5])

        for r in np.arange(0, cutoff + 0.05, 0.05):
            logger.info(
                f"range_search_metric_function {r=} {sigmoid(r, *popt)=}"
            )

        assert isinstance(cutoff, float)
        return (
            cutoff,
            lambda x: np.where(x < cutoff, sigmoid(x, *popt), 0),
            popt.tolist(),
            list(zip(aradius, ascore, aradius_from, aradius_to, strict=True)),
        )
    else:
        # The range is [0..range_metric).
        # D is the result of a range_search with a radius of range_metric,
        # but both range_metric and D may be compressed distances.
        # We approximate the real embedding distance as max(R).
        if R is not None:
            real_range = np.max(R).item()
        else:
            # Accept integer radii (e.g. 1) as well as floats.
            real_range = float(range_metric)
        logger.info(
            f"range_search_metric_function {range_metric=} {real_range=}"
        )
        assert isinstance(real_range, float)
        return real_range * 2, lambda x: np.where(x < real_range, 1, 0), [], []
@dataclass
class Benchmark:
    """Run k-NN and range-search benchmarks over a list of index descriptors.

    An I/O backend must be attached with ``set_io`` before any search or
    benchmark method is called; datasets, codecs and cached search results
    are all obtained through it.
    """

    # Descriptor of the training vectors (not read by the methods below).
    training_vectors: Optional[DatasetDescriptor] = None
    # Descriptors of the database and query vectors.
    db_vectors: Optional[DatasetDescriptor] = None
    query_vectors: Optional[DatasetDescriptor] = None
    # Indices to benchmark; ``benchmark`` appends a "Flat" entry if missing.
    index_descs: Optional[List[IndexDescriptor]] = None
    # Factory string of the index whose range_metrics drive range search.
    range_ref_index_desc: Optional[str] = None
    # Default number of neighbors for k-NN searches.
    k: Optional[int] = None
    # Either "METRIC_L2" or "METRIC_INNER_PRODUCT".
    distance_metric: str = "METRIC_L2"

    def __post_init__(self):
        """Resolve ``distance_metric`` to the faiss constant.

        Raises:
            ValueError: if ``distance_metric`` is not a supported name.
        """
        if self.distance_metric == "METRIC_INNER_PRODUCT":
            self.distance_metric_type = faiss.METRIC_INNER_PRODUCT
        elif self.distance_metric == "METRIC_L2":
            self.distance_metric_type = faiss.METRIC_L2
        else:
            raise ValueError(
                f"unsupported distance_metric {self.distance_metric!r}, "
                "expected 'METRIC_L2' or 'METRIC_INNER_PRODUCT'"
            )
        self.cached_index_key = None

    def set_io(self, benchmark_io):
        """Attach the I/O backend and pass the distance metric on to it."""
        self.io = benchmark_io
        self.io.distance_metric = self.distance_metric
        self.io.distance_metric_type = self.distance_metric_type

    def get_index_desc(self, factory: str) -> Optional[IndexDescriptor]:
        """Return the first descriptor whose factory string matches, or None."""
        # index_descs defaults to None; treat that as "no descriptors".
        for desc in self.index_descs or ():
            if desc.factory == factory:
                return desc
        return None

    def get_index(self, index_desc: IndexDescriptor):
        """Return the index for ``index_desc`` with all database vectors added.

        Only one populated index is kept, keyed by the factory string; asking
        for a different factory rebuilds it from a clone of the codec.
        """
        if self.cached_index_key != index_desc.factory:
            xb = self.io.get_dataset(self.db_vectors)
            index = faiss.clone_index(
                self.io.get_codec(index_desc, xb.shape[1])
            )
            assert index.ntotal == 0
            logger.info("Adding vectors to index")
            index_ivf = faiss.try_extract_index_ivf(index)
            if index_ivf is not None:
                # Assign each database vector to a list with a (cacheable)
                # 1-NN search on the quantizer, then add with those
                # precomputed assignments.
                _, QI, _, _ = self.knn_search(
                    index_desc,
                    parameters=None,
                    db_vectors=None,
                    query_vectors=self.db_vectors,
                    k=1,
                    index=index_ivf.quantizer,
                    level=1,
                )
                print(f"{QI.ravel().shape=}")
                add_preassigned(index_ivf, xb, QI.ravel())
            else:
                index.add(xb)
            assert index.ntotal == xb.shape[0]
            logger.info("Added vectors to index")
            self.cached_index_key = index_desc.factory
            self.cached_index = index
        return self.cached_index

    def range_search_reference(self, index_desc, range_metric):
        """Derive the range scoring function from a reference index.

        Runs a range search on ``index_desc`` with the widest radius found in
        ``range_metric`` and feeds the result to
        ``get_range_search_metric_function``. For a "Flat" reference the
        search results are not passed on.

        Returns:
            ``(gt_radius, range_search_metric_function, coefficients,
            coefficients_training_data)``.
        """
        logger.info("range_search_reference: begin")
        if isinstance(range_metric, list):
            assert len(range_metric) > 0
            # Entries are [radius, score] or [from, to, score]: the outer
            # radius is always the second-to-last element (the last one is
            # the score).
            m_radius = (
                max(rm[-2] for rm in range_metric)
                if self.distance_metric_type == faiss.METRIC_L2
                else min(rm[-2] for rm in range_metric)
            )
        else:
            m_radius = range_metric

        lims, D, I, R, P = self.range_search(
            index_desc,
            index_desc.parameters,
            radius=m_radius,
        )
        flat = index_desc.factory == "Flat"
        (
            gt_radius,
            range_search_metric_function,
            coefficients,
            coefficients_training_data,
        ) = get_range_search_metric_function(
            range_metric,
            D if not flat else None,
            R if not flat else None,
        )
        logger.info("range_search_reference: end")
        return (
            gt_radius,
            range_search_metric_function,
            coefficients,
            coefficients_training_data,
        )

    def estimate_range(self, index_desc, parameters, range_scoring_radius):
        """Estimate the index-side radius matching a recomputed radius.

        Runs a k-NN search, keeps the results whose recomputed distance is
        below ``range_scoring_radius`` and returns the median index distance
        of the (up to) three results with the largest recomputed distance.

        Raises:
            statistics.StatisticsError: if no result lies below the radius.
        """
        D, I, R, P = self.knn_search(
            index_desc, parameters, self.db_vectors, self.query_vectors
        )
        within = R < range_scoring_radius
        recomputed, approximate = R[within], D[within]
        # Stable sort keeps ties in row-major order.
        order = np.argsort(recomputed, kind="stable")
        return median(approximate[order[-3:]].tolist())

    def range_search(
        self,
        index_desc: IndexDescriptor,
        parameters: Optional[dict[str, int]],
        radius: Optional[float] = None,
        gt_radius: Optional[float] = None,
    ):
        """Run a range search, or load its cached result.

        Args:
            index_desc: index to search.
            parameters: search parameters applied with
                ``set_index_parameter``; may be None.
            radius: search radius. When None it is derived from
                ``gt_radius``: used as is for "Flat", otherwise estimated
                with ``estimate_range``.
            gt_radius: required when ``radius`` is None.

        Returns:
            ``(lims, D, I, R, P)``: result offsets, index distances, labels,
            recomputed distances (``D`` itself for "Flat") and a dict of
            measurements.
        """
        logger.info("range_search: begin")
        flat = index_desc.factory == "Flat"
        if radius is None:
            assert gt_radius is not None
            radius = (
                gt_radius
                if flat
                else self.estimate_range(index_desc, parameters, gt_radius)
            )
        logger.info(f"Radius={radius}")
        filename = self.io.get_filename_range_search(
            factory=index_desc.factory,
            parameters=parameters,
            level=0,
            db_vectors=self.db_vectors,
            query_vectors=self.query_vectors,
            r=radius,
        )
        if self.io.file_exist(filename):
            logger.info(f"Using cached results for {index_desc.factory}")
            lims, D, I, R, P = self.io.read_file(
                filename, ["lims", "D", "I", "R", "P"]
            )
        else:
            xq = self.io.get_dataset(self.query_vectors)
            index = self.get_index(index_desc)
            if parameters:
                for name, val in parameters.items():
                    set_index_parameter(index, name, val)

            index_ivf = faiss.try_extract_index_ivf(index)
            if index_ivf is not None:
                # Coarse quantization is done (and cached) separately so
                # that only the list scan is timed below.
                QD, QI, _, QP = self.knn_search(
                    index_desc,
                    parameters=None,
                    db_vectors=None,
                    query_vectors=self.query_vectors,
                    k=index.nprobe,
                    index=index_ivf.quantizer,
                    level=1,
                )
                faiss.cvar.indexIVF_stats.reset()
                with timer("range_search_preassigned") as t:
                    lims, D, I = index.range_search_preassigned(
                        xq, radius, QI, QD
                    )
            else:
                with timer("range_search") as t:
                    lims, D, I = index.range_search(xq, radius)
            if flat:
                R = D
            else:
                xb = self.io.get_dataset(self.db_vectors)
                R = refine_distances_range(
                    lims, D, I, xq, xb, self.distance_metric_type
                )
            P = {
                "time": t(),
                "radius": radius,
                "count": lims[-1].item(),
                "parameters": parameters,
                "index": index_desc.factory,
            }
            if index_ivf is not None:
                stats = faiss.cvar.indexIVF_stats
                P |= {
                    "quantizer": QP,
                    "nq": stats.nq,
                    "nlist": stats.nlist,
                    "ndis": stats.ndis,
                    "nheap_updates": stats.nheap_updates,
                    "quantization_time": stats.quantization_time,
                    "search_time": stats.search_time,
                }
            self.io.write_file(
                filename, ["lims", "D", "I", "R", "P"], [lims, D, I, R, P]
            )
        logger.info("range_search: end")
        return lims, D, I, R, P

    def range_ground_truth(self, gt_radius, range_search_metric_function):
        """Return the total score of an exact ("Flat") range search."""
        logger.info("range_ground_truth: begin")
        flat_desc = self.get_index_desc("Flat")
        lims, D, I, R, P = self.range_search(
            flat_desc,
            flat_desc.parameters,
            radius=gt_radius,
        )
        gt_rsm = np.sum(range_search_metric_function(R)).item()
        logger.info("range_ground_truth: end")
        return gt_rsm

    def range_search_benchmark(
        self,
        results: dict[str, Any],
        index_desc: IndexDescriptor,
        metric_key: str,
        gt_radius: float,
        range_search_metric_function,
        gt_rsm: float,
    ):
        """Benchmark range search on one index over its parameter space.

        Metrics of every experiment are stored in ``results["experiments"]``
        under a key derived from the index, the parameters and
        ``metric_key``; experiments already present there are not re-run.

        Returns:
            The same ``results`` dict, updated in place.
        """
        logger.info(f"range_search_benchmark: begin {index_desc.factory=}")
        xq = self.io.get_dataset(self.query_vectors)
        (nq, d) = xq.shape
        logger.info(
            f"Searching {index_desc.factory} with {nq} vectors of dimension {d}"
        )
        codec = self.io.get_codec(index_desc, d)
        faiss.omp_set_num_threads(16)

        def experiment(parameters, cost_metric, perf_metric):
            key = self.io.get_filename_evaluation_name(
                factory=index_desc.factory,
                parameters=parameters,
                level=0,
                db_vectors=self.db_vectors,
                query_vectors=self.query_vectors,
                evaluation_name=metric_key,
            )
            if key in results["experiments"]:
                metrics = results["experiments"][key]
            else:
                lims, D, I, R, P = self.range_search(
                    index_desc, parameters, gt_radius=gt_radius
                )
                range_search_metric = range_search_metric_function(R)
                range_search_pr = range_search_pr_curve(
                    D, range_search_metric, gt_rsm
                )
                range_score_sum = np.sum(range_search_metric).item()
                metrics = P | {
                    "range_score_sum": range_score_sum,
                    "range_score_max_recall": range_score_sum / gt_rsm,
                    "range_search_pr": range_search_pr,
                }
                results["experiments"][key] = metrics
            return metrics[cost_metric], metrics[perf_metric]

        for cost_metric in ["time"]:
            for perf_metric in ["range_score_max_recall"]:
                optimizer(
                    codec,
                    experiment,
                    cost_metric,
                    perf_metric,
                )
        logger.info("range_search_benchmark: end")
        return results

    def knn_search(
        self,
        index_desc: IndexDescriptor,
        parameters: Optional[dict[str, int]],
        db_vectors: Optional[DatasetDescriptor],
        query_vectors: DatasetDescriptor,
        k: Optional[int] = None,
        index: Optional[faiss.Index] = None,
        level: int = 0,
    ):
        """Run a k-NN search, or load its cached result.

        Args:
            index_desc: descriptor of the top-level index.
            parameters: search parameters applied with
                ``set_index_parameter``; may be None.
            db_vectors: database descriptor; required at level 0 and must
                be None at deeper levels.
            query_vectors: descriptor of the query vectors.
            k: number of neighbors; defaults to ``self.k``.
            index: index to search at levels above 0 (a quantizer); must be
                None at level 0, where the index comes from ``get_index``.
            level: 0 for the index itself, incremented for each nested
                quantizer search.

        Returns:
            ``(D, I, R, P)``: index distances, labels, recomputed distances
            (``D`` itself for "Flat" or at levels above 0) and a dict of
            measurements.
        """
        assert level >= 0
        if level == 0:
            assert index is None
            assert db_vectors is not None
        else:
            assert index is not None  # quantizer
            assert db_vectors is None
        logger.info("knn_search: begin")
        k = k if k is not None else self.k
        flat = index_desc.factory == "Flat"
        filename = self.io.get_filename_knn_search(
            factory=index_desc.factory,
            parameters=parameters,
            level=level,
            db_vectors=db_vectors,
            query_vectors=query_vectors,
            k=k,
        )
        if self.io.file_exist(filename):
            logger.info(f"Using cached results for {index_desc.factory}")
            D, I, R, P = self.io.read_file(filename, ["D", "I", "R", "P"])
        else:
            xq = self.io.get_dataset(query_vectors)
            if index is None:
                index = self.get_index(index_desc)
            if parameters:
                for name, val in parameters.items():
                    set_index_parameter(index, name, val)

            index_ivf = faiss.try_extract_index_ivf(index)
            if index_ivf is not None:
                # Coarse quantization is done (and cached) one level down so
                # that only the list scan is timed below.
                QD, QI, _, QP = self.knn_search(
                    index_desc,
                    parameters=None,
                    db_vectors=None,
                    query_vectors=query_vectors,
                    k=index.nprobe,
                    index=index_ivf.quantizer,
                    level=level + 1,
                )
                faiss.cvar.indexIVF_stats.reset()
                with timer("knn search_preassigned") as t:
                    D, I = index.search_preassigned(xq, k, QI, QD)
            else:
                with timer("knn search") as t:
                    D, I = index.search(xq, k)
            if flat or level > 0:
                R = D
            else:
                xb = self.io.get_dataset(db_vectors)
                R = refine_distances_knn(
                    D, I, xq, xb, self.distance_metric_type
                )
            P = {
                "time": t(),
                "parameters": parameters,
                "index": index_desc.factory,
                "level": level,
            }
            if index_ivf is not None:
                stats = faiss.cvar.indexIVF_stats
                P |= {
                    "quantizer": QP,
                    "nq": stats.nq,
                    "nlist": stats.nlist,
                    "ndis": stats.ndis,
                    "nheap_updates": stats.nheap_updates,
                    "quantization_time": stats.quantization_time,
                    "search_time": stats.search_time,
                }
            self.io.write_file(filename, ["D", "I", "R", "P"], [D, I, R, P])
        logger.info("knn_search: end")
        return D, I, R, P

    def knn_ground_truth(self):
        """Store the exact ("Flat") k-NN distances and labels on ``self``."""
        logger.info("knn_ground_truth: begin")
        flat_desc = self.get_index_desc("Flat")
        self.gt_knn_D, self.gt_knn_I, _, _ = self.knn_search(
            flat_desc,
            flat_desc.parameters,
            self.db_vectors,
            self.query_vectors,
        )
        logger.info("knn_ground_truth: end")

    def knn_search_benchmark(
        self, results: dict[str, Any], index_desc: IndexDescriptor
    ):
        """Benchmark k-NN search on one index over its parameter space.

        Requires ``knn_ground_truth`` to have been called. Metrics of every
        experiment are stored in ``results["experiments"]``; experiments
        already present there are not re-run. For IVF indices ``nlist`` is
        recorded under ``results["indices"]``.

        Returns:
            The same ``results`` dict, updated in place.
        """
        logger.info(f"knn_search_benchmark: begin {index_desc.factory=}")
        xq = self.io.get_dataset(self.query_vectors)
        (nq, d) = xq.shape
        logger.info(
            f"Searching {index_desc.factory} with {nq} vectors of dimension {d}"
        )
        codec = self.io.get_codec(index_desc, d)
        codec_ivf = faiss.try_extract_index_ivf(codec)
        if codec_ivf is not None:
            results["indices"][index_desc.factory] = {"nlist": codec_ivf.nlist}

        faiss.omp_set_num_threads(16)

        def experiment(parameters, cost_metric, perf_metric):
            key = self.io.get_filename_evaluation_name(
                factory=index_desc.factory,
                parameters=parameters,
                level=0,
                db_vectors=self.db_vectors,
                query_vectors=self.query_vectors,
                evaluation_name="knn",
            )
            if key in results["experiments"]:
                metrics = results["experiments"][key]
            else:
                D, I, R, P = self.knn_search(
                    index_desc, parameters, self.db_vectors, self.query_vectors
                )
                metrics = P | {
                    "knn_intersection": knn_intersection_measure(
                        I, self.gt_knn_I
                    ),
                    "distance_ratio": distance_ratio_measure(
                        I, R, self.gt_knn_D, self.distance_metric_type
                    ),
                }
                results["experiments"][key] = metrics
            return metrics[cost_metric], metrics[perf_metric]

        for cost_metric in ["time"]:
            for perf_metric in ["knn_intersection", "distance_ratio"]:
                optimizer(
                    codec,
                    experiment,
                    cost_metric,
                    perf_metric,
                )
        logger.info("knn_search_benchmark: end")
        return results

    def benchmark(self) -> str:
        """Run all benchmarks and return the results as a JSON string.

        k-NN benchmarks run for every index descriptor (a "Flat" descriptor
        is appended when missing). When ``range_ref_index_desc`` is set,
        every range metric of that reference index is additionally
        benchmarked on every index. Results are also written to
        ``result.json`` through the I/O backend.

        Raises:
            ValueError: if the range reference index is unknown or defines
                no range metrics.
        """
        logger.info("begin evaluate")

        results = {"indices": {}, "experiments": {}}
        if self.index_descs is None:
            self.index_descs = []
        if self.get_index_desc("Flat") is None:
            self.index_descs.append(IndexDescriptor(factory="Flat"))

        self.knn_ground_truth()
        for index_desc in self.index_descs:
            results = self.knn_search_benchmark(
                results=results,
                index_desc=index_desc,
            )

        if self.range_ref_index_desc is not None:
            # Keep the reference descriptor under its own name: the loop
            # over index_descs below must not rebind it between metrics.
            ref_index_desc = self.get_index_desc(self.range_ref_index_desc)
            if ref_index_desc is None:
                raise ValueError(
                    f"Unknown range index {self.range_ref_index_desc}"
                )
            if ref_index_desc.range_metrics is None:
                raise ValueError(
                    f"Range index {ref_index_desc.factory} has no range_metrics"
                )
            results["metrics"] = {}
            for metric_key, range_metric in ref_index_desc.range_metrics.items():
                (
                    gt_radius,
                    range_search_metric_function,
                    coefficients,
                    coefficients_training_data,
                ) = self.range_search_reference(ref_index_desc, range_metric)
                results["metrics"][metric_key] = {
                    "coefficients": coefficients,
                    "training_data": coefficients_training_data,
                }
                gt_rsm = self.range_ground_truth(
                    gt_radius, range_search_metric_function
                )
                for index_desc in self.index_descs:
                    results = self.range_search_benchmark(
                        results=results,
                        index_desc=index_desc,
                        metric_key=metric_key,
                        gt_radius=gt_radius,
                        range_search_metric_function=range_search_metric_function,
                        gt_rsm=gt_rsm,
                    )
        self.io.write_json(results, "result.json", overwrite=True)
        logger.info("end evaluate")
        return json.dumps(results)