benchmark refactor

Summary:
1. Support for index construction parameters outside of the factory string (arbitrary depth of quantizers); see the example below.
2. A refactor that introduces an index wrapper, a prerequisite for the optimizer, which will generate indices from pre-optimized components (particularly quantizers).
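For illustration, construction parameters now live on the index descriptor instead of being encoded in the factory string; this snippet is taken from the range-search example script in this diff:

    IndexDescriptor(
        factory="HNSW32",
        construction_params=[
            {
                "efConstruction": 64,
            }
        ],
    )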

Reviewed By: mdouze

Differential Revision: D51427452

fbshipit-source-id: 014d05dd798d856360f2546963e7cad64c2fcaeb
pull/3164/head
Gergely Szilvasy 2023-12-04 05:53:17 -08:00 committed by Facebook GitHub Bot
parent a5b03cb9f6
commit 9519a19f42
7 changed files with 1202 additions and 483 deletions

View File

@@ -1,23 +1,20 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from contextlib import contextmanager
import json
import logging
from dataclasses import dataclass
from multiprocessing.pool import ThreadPool
from operator import itemgetter
from statistics import median, mean
from time import perf_counter
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional
from .index import Index, IndexFromCodec, IndexFromFactory
from .descriptors import DatasetDescriptor, IndexDescriptor
import faiss # @manual=//faiss/python:pyfaiss_gpu
from faiss.contrib.evaluation import ( # @manual=//faiss/contrib:faiss_contrib_gpu
knn_intersection_measure,
OperatingPointsWithRanges,
)
from faiss.contrib.ivf_tools import ( # @manual=//faiss/contrib:faiss_contrib_gpu
add_preassigned,
)
import numpy as np
@@ -27,56 +24,21 @@ from scipy.optimize import curve_fit
logger = logging.getLogger(__name__)
@contextmanager
def timer(name) -> float:
logger.info(f"Measuring {name}")
t1 = t2 = perf_counter()
yield lambda: t2 - t1
t2 = perf_counter()
logger.info(f"Time for {name}: {t2 - t1:.3f} seconds")
def refine_distances_knn(
D: np.ndarray, I: np.ndarray, xq: np.ndarray, xb: np.ndarray, metric
):
return np.where(
I >= 0,
np.square(np.linalg.norm(xq[:, None] - xb[I], axis=2))
if metric == faiss.METRIC_L2
else np.einsum("qd,qkd->qk", xq, xb[I]),
D,
)
def refine_distances_range(
lims: np.ndarray,
D: np.ndarray,
I: np.ndarray,
xq: np.ndarray,
xb: np.ndarray,
metric,
):
with ThreadPool(32) as pool:
R = pool.map(
lambda i: (
np.sum(np.square(xq[i] - xb[I[lims[i]:lims[i + 1]]]), axis=1)
if metric == faiss.METRIC_L2
else np.tensordot(
xq[i], xb[I[lims[i]:lims[i + 1]]], axes=(0, 1)
)
)
if lims[i + 1] > lims[i]
else [],
range(len(lims) - 1),
)
return np.hstack(R)
def range_search_pr_curve(
dist_ann: np.ndarray, metric_score: np.ndarray, gt_rsm: float
):
assert dist_ann.shape == metric_score.shape
assert dist_ann.ndim == 1
l = len(dist_ann)
if l == 0:
return {
"dist_ann": [],
"metric_score_sample": [],
"cum_score": [],
"precision": [],
"recall": [],
"unique_key": [],
}
sort_by_dist_ann = dist_ann.argsort()
dist_ann = dist_ann[sort_by_dist_ann]
metric_score = metric_score[sort_by_dist_ann]
@@ -87,7 +49,7 @@ def range_search_pr_curve(
tbl = np.vstack(
[dist_ann, metric_score, cum_score, precision, recall, unique_key]
)
group_by_dist_max_cum_score = np.empty(len(dist_ann), bool)
group_by_dist_max_cum_score = np.empty(l, bool)
group_by_dist_max_cum_score[-1] = True
group_by_dist_max_cum_score[:-1] = dist_ann[1:] != dist_ann[:-1]
tbl = tbl[:, group_by_dist_max_cum_score]
@@ -105,39 +67,7 @@ def range_search_pr_curve(
}
def set_index_parameter(index, name, val):
index = faiss.downcast_index(index)
if isinstance(index, faiss.IndexPreTransform):
set_index_parameter(index.index, name, val)
elif name.startswith("quantizer_"):
index_ivf = faiss.extract_index_ivf(index)
set_index_parameter(
index_ivf.quantizer, name[name.find("_") + 1:], val
)
elif name == "efSearch":
index.hnsw.efSearch
index.hnsw.efSearch = int(val)
elif name == "nprobe":
index_ivf = faiss.extract_index_ivf(index)
index_ivf.nprobe
index_ivf.nprobe = int(val)
elif name == "noop":
pass
else:
raise RuntimeError(f"could not set param {name} on {index}")
def optimizer(codec, search, cost_metric, perf_metric):
op = OperatingPointsWithRanges()
op.add_range("noop", [0])
codec_ivf = faiss.try_extract_index_ivf(codec)
if codec_ivf is not None:
op.add_range(
"nprobe",
[2**i for i in range(12) if 2**i < codec_ivf.nlist * 0.1],
)
def optimizer(op, search, cost_metric, perf_metric):
totex = op.num_experiments()
rs = np.random.RandomState(123)
if totex > 1:
@@ -243,7 +173,7 @@ def get_range_search_metric_function(range_metric, D, R):
cutoff,
lambda x: np.where(x < cutoff, sigmoid(x, *popt), 0),
popt.tolist(),
list(zip(aradius, ascore, aradius_from, aradius_to, strict=True))
list(zip(aradius, ascore, aradius_from, aradius_to, strict=True)),
)
else:
# Assuming that the range_metric is a float,
@@ -265,21 +195,20 @@ def get_range_search_metric_function(range_metric, D, R):
@dataclass
class Benchmark:
training_vectors: Optional[DatasetDescriptor] = None
db_vectors: Optional[DatasetDescriptor] = None
database_vectors: Optional[DatasetDescriptor] = None
query_vectors: Optional[DatasetDescriptor] = None
index_descs: Optional[List[IndexDescriptor]] = None
range_ref_index_desc: Optional[str] = None
k: Optional[int] = None
distance_metric: str = "METRIC_L2"
distance_metric: str = "L2"
def __post_init__(self):
if self.distance_metric == "METRIC_INNER_PRODUCT":
if self.distance_metric == "IP":
self.distance_metric_type = faiss.METRIC_INNER_PRODUCT
elif self.distance_metric == "METRIC_L2":
elif self.distance_metric == "L2":
self.distance_metric_type = faiss.METRIC_L2
else:
raise ValueError
self.cached_index_key = None
def set_io(self, benchmark_io):
self.io = benchmark_io
@@ -292,54 +221,24 @@ class Benchmark:
return desc
return None
def get_index(self, index_desc: IndexDescriptor):
if self.cached_index_key != index_desc.factory:
xb = self.io.get_dataset(self.db_vectors)
index = faiss.clone_index(
self.io.get_codec(index_desc, xb.shape[1])
)
assert index.ntotal == 0
logger.info("Adding vectors to index")
index_ivf = faiss.try_extract_index_ivf(index)
if index_ivf is not None:
QD, QI, _, QP = self.knn_search(
index_desc,
parameters=None,
db_vectors=None,
query_vectors=self.db_vectors,
k=1,
index=index_ivf.quantizer,
level=1,
)
print(f"{QI.ravel().shape=}")
add_preassigned(index_ivf, xb, QI.ravel())
else:
index.add(xb)
assert index.ntotal == xb.shape[0]
logger.info("Added vectors to index")
self.cached_index_key = index_desc.factory
self.cached_index = index
return self.cached_index
def range_search_reference(self, index_desc, range_metric):
def range_search_reference(self, index, parameters, range_metric):
logger.info("range_search_reference: begin")
if isinstance(range_metric, list):
assert len(range_metric) > 0
ri = len(range_metric[0]) - 1
m_radius = (
max(rm[ri] for rm in range_metric)
max(rm[-2] for rm in range_metric)
if self.distance_metric_type == faiss.METRIC_L2
else min(rm[ri] for rm in range_metric)
else min(rm[-2] for rm in range_metric)
)
else:
m_radius = range_metric
lims, D, I, R, P = self.range_search(
index_desc,
index_desc.parameters,
index,
parameters,
radius=m_radius,
)
flat = index_desc.factory == "Flat"
flat = index.factory == "Flat"
(
gt_radius,
range_search_metric_function,
@@ -351,111 +250,61 @@ class Benchmark:
R if not flat else None,
)
logger.info("range_search_reference: end")
return gt_radius, range_search_metric_function, coefficients, coefficients_training_data
return (
gt_radius,
range_search_metric_function,
coefficients,
coefficients_training_data,
)
def estimate_range(self, index_desc, parameters, range_scoring_radius):
D, I, R, P = self.knn_search(
index_desc, parameters, self.db_vectors, self.query_vectors
def estimate_range(self, index, parameters, range_scoring_radius):
D, I, R, P = index.knn_search(
parameters,
self.query_vectors,
self.k,
)
samples = []
for i, j in np.argwhere(R < range_scoring_radius):
samples.append((R[i, j].item(), D[i, j].item()))
samples.sort(key=itemgetter(0))
return median(r for _, r in samples[-3:])
if len(samples) > 0: # estimate range
samples.sort(key=itemgetter(0))
return median(r for _, r in samples[-3:])
else: # ensure at least one result
i, j = np.argwhere(R.min() == R)[0]
return D[i, j].item()
def range_search(
self,
index_desc: IndexDescriptor,
parameters: Optional[dict[str, int]],
index: Index,
search_parameters: Optional[Dict[str, int]],
radius: Optional[float] = None,
gt_radius: Optional[float] = None,
):
logger.info("range_search: begin")
flat = index_desc.factory == "Flat"
if radius is None:
assert gt_radius is not None
radius = (
gt_radius
if flat
else self.estimate_range(index_desc, parameters, gt_radius)
if index.is_flat()
else self.estimate_range(
index,
search_parameters,
gt_radius,
)
)
logger.info(f"Radius={radius}")
filename = self.io.get_filename_range_search(
factory=index_desc.factory,
parameters=parameters,
level=0,
db_vectors=self.db_vectors,
return index.range_search(
search_parameters=search_parameters,
query_vectors=self.query_vectors,
r=radius,
radius=radius,
)
if self.io.file_exist(filename):
logger.info(f"Using cached results for {index_desc.factory}")
lims, D, I, R, P = self.io.read_file(
filename, ["lims", "D", "I", "R", "P"]
)
else:
xq = self.io.get_dataset(self.query_vectors)
index = self.get_index(index_desc)
if parameters:
for name, val in parameters.items():
set_index_parameter(index, name, val)
index_ivf = faiss.try_extract_index_ivf(index)
if index_ivf is not None:
QD, QI, _, QP = self.knn_search(
index_desc,
parameters=None,
db_vectors=None,
query_vectors=self.query_vectors,
k=index.nprobe,
index=index_ivf.quantizer,
level=1,
)
# QD = QD[:, :index.nprobe]
# QI = QI[:, :index.nprobe]
faiss.cvar.indexIVF_stats.reset()
with timer("range_search_preassigned") as t:
lims, D, I = index.range_search_preassigned(xq, radius, QI, QD)
else:
with timer("range_search") as t:
lims, D, I = index.range_search(xq, radius)
if flat:
R = D
else:
xb = self.io.get_dataset(self.db_vectors)
R = refine_distances_range(
lims, D, I, xq, xb, self.distance_metric_type
)
P = {
"time": t(),
"radius": radius,
"count": lims[-1].item(),
"parameters": parameters,
"index": index_desc.factory,
}
if index_ivf is not None:
stats = faiss.cvar.indexIVF_stats
P |= {
"quantizer": QP,
"nq": stats.nq,
"nlist": stats.nlist,
"ndis": stats.ndis,
"nheap_updates": stats.nheap_updates,
"quantization_time": stats.quantization_time,
"search_time": stats.search_time,
}
self.io.write_file(
filename, ["lims", "D", "I", "R", "P"], [lims, D, I, R, P]
)
logger.info("range_seach: end")
return lims, D, I, R, P
def range_ground_truth(self, gt_radius, range_search_metric_function):
logger.info("range_ground_truth: begin")
flat_desc = self.get_index_desc("Flat")
lims, D, I, R, P = self.range_search(
flat_desc,
flat_desc.parameters,
flat_desc.index,
search_parameters=None,
radius=gt_radius,
)
gt_rsm = np.sum(range_search_metric_function(R)).item()
@@ -464,37 +313,32 @@ class Benchmark:
def range_search_benchmark(
self,
results: dict[str, Any],
index_desc: IndexDescriptor,
results: Dict[str, Any],
index: Index,
metric_key: str,
radius: float,
gt_radius: float,
range_search_metric_function,
gt_rsm: float,
):
logger.info(f"range_search_benchmark: begin {index_desc.factory=}")
xq = self.io.get_dataset(self.query_vectors)
(nq, d) = xq.shape
logger.info(
f"Searching {index_desc.factory} with {nq} vectors of dimension {d}"
)
codec = self.io.get_codec(index_desc, d)
faiss.omp_set_num_threads(16)
logger.info(f"range_search_benchmark: begin {index.get_index_name()}")
def experiment(parameters, cost_metric, perf_metric):
nonlocal results
key = self.io.get_filename_evaluation_name(
factory=index_desc.factory,
parameters=parameters,
level=0,
db_vectors=self.db_vectors,
key = index.get_range_search_name(
search_parameters=parameters,
query_vectors=self.query_vectors,
evaluation_name=metric_key,
radius=radius,
)
key += metric_key
if key in results["experiments"]:
metrics = results["experiments"][key]
else:
lims, D, I, R, P = self.range_search(
index_desc, parameters, gt_radius=gt_radius
index,
parameters,
radius=radius,
gt_radius=gt_radius,
)
range_search_metric = range_search_metric_function(R)
range_search_pr = range_search_pr_curve(
@@ -511,8 +355,9 @@ class Benchmark:
for cost_metric in ["time"]:
for perf_metric in ["range_score_max_recall"]:
op = index.get_operating_points()
optimizer(
codec,
op,
experiment,
cost_metric,
perf_metric,
@@ -520,134 +365,33 @@ class Benchmark:
logger.info("range_search_benchmark: end")
return results
def knn_search(
self,
index_desc: IndexDescriptor,
parameters: Optional[dict[str, int]],
db_vectors: Optional[DatasetDescriptor],
query_vectors: DatasetDescriptor,
k: Optional[int] = None,
index: Optional[faiss.Index] = None,
level: int = 0,
):
assert level >= 0
if level == 0:
assert index is None
assert db_vectors is not None
else:
assert index is not None # quantizer
assert db_vectors is None
logger.info("knn_seach: begin")
k = k if k is not None else self.k
flat = index_desc.factory == "Flat"
filename = self.io.get_filename_knn_search(
factory=index_desc.factory,
parameters=parameters,
level=level,
db_vectors=db_vectors,
query_vectors=query_vectors,
k=k,
)
if self.io.file_exist(filename):
logger.info(f"Using cached results for {index_desc.factory}")
D, I, R, P = self.io.read_file(filename, ["D", "I", "R", "P"])
else:
xq = self.io.get_dataset(query_vectors)
if index is None:
index = self.get_index(index_desc)
if parameters:
for name, val in parameters.items():
set_index_parameter(index, name, val)
index_ivf = faiss.try_extract_index_ivf(index)
if index_ivf is not None:
QD, QI, _, QP = self.knn_search(
index_desc,
parameters=None,
db_vectors=None,
query_vectors=query_vectors,
k=index.nprobe,
index=index_ivf.quantizer,
level=level + 1,
)
# QD = QD[:, :index.nprobe]
# QI = QI[:, :index.nprobe]
faiss.cvar.indexIVF_stats.reset()
with timer("knn search_preassigned") as t:
D, I = index.search_preassigned(xq, k, QI, QD)
else:
with timer("knn search") as t:
D, I = index.search(xq, k)
if flat or level > 0:
R = D
else:
xb = self.io.get_dataset(db_vectors)
R = refine_distances_knn(
D, I, xq, xb, self.distance_metric_type
)
P = {
"time": t(),
"parameters": parameters,
"index": index_desc.factory,
"level": level,
}
if index_ivf is not None:
stats = faiss.cvar.indexIVF_stats
P |= {
"quantizer": QP,
"nq": stats.nq,
"nlist": stats.nlist,
"ndis": stats.ndis,
"nheap_updates": stats.nheap_updates,
"quantization_time": stats.quantization_time,
"search_time": stats.search_time,
}
self.io.write_file(filename, ["D", "I", "R", "P"], [D, I, R, P])
logger.info("knn_seach: end")
return D, I, R, P
def knn_ground_truth(self):
logger.info("knn_ground_truth: begin")
flat_desc = self.get_index_desc("Flat")
self.gt_knn_D, self.gt_knn_I, _, _ = self.knn_search(
flat_desc,
flat_desc.parameters,
self.db_vectors,
self.query_vectors,
self.gt_knn_D, self.gt_knn_I, _, _ = flat_desc.index.knn_search(
search_parameters=None,
query_vectors=self.query_vectors,
k=self.k,
)
logger.info("knn_ground_truth: end")
def knn_search_benchmark(
self, results: dict[str, Any], index_desc: IndexDescriptor
):
logger.info(f"knn_search_benchmark: begin {index_desc.factory=}")
xq = self.io.get_dataset(self.query_vectors)
(nq, d) = xq.shape
logger.info(
f"Searching {index_desc.factory} with {nq} vectors of dimension {d}"
)
codec = self.io.get_codec(index_desc, d)
codec_ivf = faiss.try_extract_index_ivf(codec)
if codec_ivf is not None:
results["indices"][index_desc.factory] = {"nlist": codec_ivf.nlist}
faiss.omp_set_num_threads(16)
def knn_search_benchmark(self, results: Dict[str, Any], index: Index):
index_name = index.get_index_name()
logger.info(f"knn_search_benchmark: begin {index_name}")
def experiment(parameters, cost_metric, perf_metric):
nonlocal results
key = self.io.get_filename_evaluation_name(
factory=index_desc.factory,
parameters=parameters,
level=0,
db_vectors=self.db_vectors,
query_vectors=self.query_vectors,
evaluation_name="knn",
key = index.get_knn_search_name(
parameters,
self.query_vectors,
self.k,
)
key += "knn"
if key in results["experiments"]:
metrics = results["experiments"][key]
else:
D, I, R, P = self.knn_search(
index_desc, parameters, self.db_vectors, self.query_vectors
D, I, R, P = index.knn_search(
parameters, self.query_vectors, self.k
)
metrics = P | {
"knn_intersection": knn_intersection_measure(
@@ -662,8 +406,9 @@ class Benchmark:
for cost_metric in ["time"]:
for perf_metric in ["knn_intersection", "distance_ratio"]:
op = index.get_operating_points()
optimizer(
codec,
op,
experiment,
cost_metric,
perf_metric,
@@ -671,18 +416,61 @@ class Benchmark:
logger.info("knn_search_benchmark: end")
return results
def benchmark(self) -> str:
logger.info("begin evaluate")
results = {"indices": {}, "experiments": {}}
def train(self, results):
xq = self.io.get_dataset(self.query_vectors)
self.d = xq.shape[1]
if self.get_index_desc("Flat") is None:
self.index_descs.append(IndexDescriptor(factory="Flat"))
for index_desc in self.index_descs:
if index_desc.factory is not None:
index = IndexFromFactory(
d=self.d,
metric=self.distance_metric,
database_vectors=self.database_vectors,
search_params=index_desc.search_params,
construction_params=index_desc.construction_params,
factory=index_desc.factory,
training_vectors=self.training_vectors,
)
index.set_io(self.io)
index.train()
index_desc.index = index
results["indices"][index.get_codec_name()] = {
"code_size": index.get_code_size()
}
else:
index = IndexFromCodec(
d=self.d,
metric=self.distance_metric,
database_vectors=self.database_vectors,
search_params=index_desc.search_params,
construction_params=index_desc.construction_params,
path=index_desc.path,
bucket=index_desc.bucket,
)
index.set_io(self.io)
index_desc.index = index
results["indices"][index.get_codec_name()] = {
"code_size": index.get_code_size()
}
return results
def benchmark(self, result_file=None):
logger.info("begin evaluate")
faiss.omp_set_num_threads(24)
results = {"indices": {}, "experiments": {}}
results = self.train(results)
# knn search
self.knn_ground_truth()
for index_desc in self.index_descs:
results = self.knn_search_benchmark(
results=results,
index_desc=index_desc,
index=index_desc.index,
)
# range search
if self.range_ref_index_desc is not None:
index_desc = self.get_index_desc(self.range_ref_index_desc)
if index_desc is None:
@@ -700,7 +488,9 @@ class Benchmark:
range_search_metric_function,
coefficients,
coefficients_training_data,
) = self.range_search_reference(index_desc, range_metric)
) = self.range_search_reference(
index_desc.index, index_desc.search_params, range_metric
)
results["metrics"][metric_key] = {
"coefficients": coefficients,
"training_data": coefficients_training_data,
@@ -709,14 +499,18 @@ class Benchmark:
gt_radius, range_search_metric_function
)
for index_desc in self.index_descs:
if not index_desc.index.supports_range_search():
continue
results = self.range_search_benchmark(
results=results,
index_desc=index_desc,
index=index_desc.index,
metric_key=metric_key,
radius=index_desc.radius,
gt_radius=gt_radius,
range_search_metric_function=range_search_metric_function,
gt_rsm=gt_rsm,
)
self.io.write_json(results, "result.json", overwrite=True)
if result_file is not None:
self.io.write_json(results, result_file, overwrite=True)
logger.info("end evaluate")
return json.dumps(results)
return results

View File

@@ -1,7 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import hashlib
import io
import json
import logging
import os
import pickle
from dataclasses import dataclass
from typing import Any, List, Optional
from zipfile import ZipFile
@@ -9,115 +16,45 @@ from zipfile import ZipFile
import faiss # @manual=//faiss/python:pyfaiss_gpu
import numpy as np
from .descriptors import DatasetDescriptor, IndexDescriptor
from faiss.contrib.datasets import ( # @manual=//faiss/contrib:faiss_contrib_gpu
dataset_from_name,
)
logger = logging.getLogger(__name__)
# merge an RCQ coarse quantizer and an ITQ encoder into one Faiss index
def merge_rcq_itq(
# pyre-ignore[11]: `faiss.ResidualCoarseQuantizer` is not defined as a type
rcq_coarse_quantizer: faiss.ResidualCoarseQuantizer,
itq_encoder: faiss.IndexPreTransform,
# pyre-ignore[11]: `faiss.IndexIVFSpectralHash` is not defined as a type.
) -> faiss.IndexIVFSpectralHash:
# pyre-ignore[16]: `faiss` has no attribute `IndexIVFSpectralHash`.
index = faiss.IndexIVFSpectralHash(
rcq_coarse_quantizer,
rcq_coarse_quantizer.d,
rcq_coarse_quantizer.ntotal,
itq_encoder.sa_code_size() * 8,
1000000, # larger than the magnitude of the vectors
)
index.replace_vt(itq_encoder)
return index
@dataclass
class BenchmarkIO:
path: str
def __post_init__(self):
self.cached_ds = {}
self.cached_codec_key = None
def get_filename_search(
self,
factory: str,
parameters: Optional[dict[str, int]],
level: int,
db_vectors: DatasetDescriptor,
query_vectors: DatasetDescriptor,
k: Optional[int] = None,
r: Optional[float] = None,
evaluation_name: Optional[str] = None,
):
assert factory is not None
assert level is not None
assert self.distance_metric is not None
assert query_vectors is not None
filename = f"{factory.lower().replace(',', '_')}."
if level > 0:
filename += f"l_{level}."
if db_vectors is not None:
filename += db_vectors.get_filename("d")
filename += query_vectors.get_filename("q")
filename += self.distance_metric.upper() + "."
if k is not None:
filename += f"k_{k}."
if r is not None:
filename += f"r_{int(r * 1000)}."
if parameters is not None:
for name, val in parameters.items():
if name != "noop":
filename += f"{name}_{val}."
if evaluation_name is None:
filename += "zip"
else:
filename += evaluation_name
return filename
def get_filename_knn_search(
self,
factory: str,
parameters: Optional[dict[str, int]],
level: int,
db_vectors: DatasetDescriptor,
query_vectors: DatasetDescriptor,
k: int,
):
assert k is not None
return self.get_filename_search(
factory=factory,
parameters=parameters,
level=level,
db_vectors=db_vectors,
query_vectors=query_vectors,
k=k,
)
def get_filename_range_search(
self,
factory: str,
parameters: Optional[dict[str, int]],
level: int,
db_vectors: DatasetDescriptor,
query_vectors: DatasetDescriptor,
r: float,
):
assert r is not None
return self.get_filename_search(
factory=factory,
parameters=parameters,
level=level,
db_vectors=db_vectors,
query_vectors=query_vectors,
r=r,
)
def get_filename_evaluation_name(
self,
factory: str,
parameters: Optional[dict[str, int]],
level: int,
db_vectors: DatasetDescriptor,
query_vectors: DatasetDescriptor,
evaluation_name: str,
):
assert evaluation_name is not None
return self.get_filename_search(
factory=factory,
parameters=parameters,
level=level,
db_vectors=db_vectors,
query_vectors=query_vectors,
evaluation_name=evaluation_name,
)
def get_local_filename(self, filename):
if len(filename) > 184:
fn, ext = os.path.splitext(filename)
filename = (
fn[:184] + hashlib.sha256(filename.encode()).hexdigest() + ext
)
return os.path.join(self.path, filename)
def download_file_from_blobstore(
@@ -143,22 +80,6 @@ class BenchmarkIO:
logger.info(f"{filename} {exists=}")
return exists
def get_codec(self, index_desc: IndexDescriptor, d: int):
if index_desc.factory == "Flat":
return faiss.IndexFlat(d, self.distance_metric_type)
else:
if self.cached_codec_key != index_desc.factory:
codec = faiss.read_index(
self.get_local_filename(index_desc.path)
)
assert (
codec.metric_type == self.distance_metric_type
), f"{codec.metric_type=} != {self.distance_metric_type=}"
logger.info(f"Loaded codec from {index_desc.path}")
self.cached_codec_key = index_desc.factory
self.cached_codec = codec
return self.cached_codec
def read_file(self, filename: str, keys: List[str]):
fn = self.download_file_from_blobstore(filename)
logger.info(f"Loading file {fn}")
@@ -196,19 +117,50 @@ class BenchmarkIO:
self.upload_file_to_blobstore(filename, overwrite=overwrite)
def get_dataset(self, dataset):
if dataset not in self.cached_ds:
self.cached_ds[dataset] = self.read_nparray(
os.path.join(self.path, dataset.tablename)
)
if dataset.namespace is not None and dataset.namespace[:4] == "std_":
if dataset.tablename not in self.cached_ds:
self.cached_ds[dataset.tablename] = dataset_from_name(
dataset.tablename,
)
p = dataset.namespace[4]
if p == "t":
return self.cached_ds[dataset.tablename].get_train()
elif p == "d":
return self.cached_ds[dataset.tablename].get_database()
elif p == "q":
return self.cached_ds[dataset.tablename].get_queries()
else:
raise ValueError
elif dataset not in self.cached_ds:
if dataset.namespace == "syn":
d, seed = dataset.tablename.split("_")
d = int(d)
seed = int(seed)
n = dataset.num_vectors
# based on faiss.contrib.datasets.SyntheticDataset
d1 = 10
rs = np.random.RandomState(seed)
x = rs.normal(size=(n, d1))
x = np.dot(x, rs.rand(d1, d))
x = x * (rs.rand(d) * 4 + 0.1)
x = np.sin(x)
x = x.astype(np.float32)
self.cached_ds[dataset] = x
else:
self.cached_ds[dataset] = self.read_nparray(
os.path.join(self.path, dataset.tablename),
mmap_mode="r",
)[: dataset.num_vectors].copy()
return self.cached_ds[dataset]
def read_nparray(
self,
filename: str,
mmap_mode: Optional[str] = None,
):
fn = self.download_file_from_blobstore(filename)
logger.info(f"Loading nparray from {fn}")
nparray = np.load(fn)
nparray = np.load(fn, mmap_mode=mmap_mode)
logger.info(f"Loaded nparray {nparray.shape} from {fn}")
return nparray
@@ -244,3 +196,32 @@ class BenchmarkIO:
with open(fn, "w") as fp:
json.dump(json_dict, fp)
self.upload_file_to_blobstore(filename, overwrite=overwrite)
def read_index(
self,
filename: str,
bucket: Optional[str] = None,
path: Optional[str] = None,
):
fn = self.download_file_from_blobstore(filename, bucket, path)
logger.info(f"Loading index {fn}")
ext = os.path.splitext(fn)[1]
if ext in [".faiss", ".codec"]:
index = faiss.read_index(fn)
elif ext == ".pkl":
with open(fn, "rb") as model_file:
model = pickle.load(model_file)
rcq_coarse_quantizer, itq_encoder = model["model"]
index = merge_rcq_itq(rcq_coarse_quantizer, itq_encoder)
logger.info(f"Loaded index from {fn}")
return index
def write_index(
self,
index: faiss.Index,
filename: str,
):
fn = self.get_local_filename(filename)
logger.info(f"Saving index to {fn}")
faiss.write_index(index, fn)
self.upload_file_to_blobstore(filename)

View File

@@ -1,15 +1,21 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional
@dataclass
class IndexDescriptor:
factory: str
bucket: Optional[str] = None
# either path or factory should be set,
# but not both at the same time.
path: Optional[str] = None
parameters: Optional[dict[str, int]] = None
factory: Optional[str] = None
construction_params: Optional[List[Dict[str, int]]] = None
search_params: Optional[Dict[str, int]] = None
# range metric definitions
# key: name
# value: one of the following:
@@ -25,14 +31,38 @@ class IndexDescriptor:
# [[radius1_from, radius1_to, score1], ...]
# [radius1_from, radius1_to) -> score1,
# [radius2_from, radius2_to) -> score2
range_metrics: Optional[dict[str, Any]] = None
range_metrics: Optional[Dict[str, Any]] = None
radius: Optional[float] = None
@dataclass
class DatasetDescriptor:
# namespace possible values:
# 1. a hive namespace
# 2. 'std_t', 'std_d', 'std_q' for the standard datasets
# via faiss.contrib.datasets.dataset_from_name()
# t - training, d - database, q - queries
# eg. "std_t"
# 3. 'syn' for synthetic data
# 4. None for local files
namespace: Optional[str] = None
# tablename possible values, corresponding to the
# namespace value above:
# 1. a hive table name
# 2. name of the standard dataset as recognized
# by faiss.contrib.datasets.dataset_from_name()
# eg. "bigann1M"
# 3. d_seed, e.g. 128_1234 for 128-dimensional vectors
# with seed 1234
# 4. a local file name (relative to benchmark_io.path)
tablename: Optional[str] = None
# partition names and values for hive
# eg. ["ds=2021-09-01"]
partitions: Optional[List[str]] = None
# number of vectors to load from the dataset
num_vectors: Optional[int] = None
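# examples: bigann1M and 128_1234 appear in the comments above and
# database.npy in the example scripts below; the num_vectors of the
# synthetic descriptor is a hypothetical value:
# DatasetDescriptor(namespace="std_d", tablename="bigann1M")
# DatasetDescriptor(namespace="syn", tablename="128_1234", num_vectors=10000)
# DatasetDescriptor(tablename="database.npy", num_vectors=200000)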
def __hash__(self):
@@ -40,9 +70,11 @@ class DatasetDescriptor:
def get_filename(
self,
prefix: str = "v",
prefix: Optional[str] = None,
) -> str:
filename = prefix + "_"
filename = ""
if prefix is not None:
filename += prefix + "_"
if self.namespace is not None:
filename += self.namespace + "_"
assert self.tablename is not None

View File

@@ -0,0 +1,785 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
from dataclasses import dataclass
from multiprocessing.pool import ThreadPool
from time import perf_counter
from typing import ClassVar, Dict, List, Optional
import faiss # @manual=//faiss/python:pyfaiss_gpu
import numpy as np
from faiss.contrib.evaluation import ( # @manual=//faiss/contrib:faiss_contrib_gpu
OperatingPointsWithRanges,
)
from faiss.contrib.factory_tools import ( # @manual=//faiss/contrib:faiss_contrib_gpu
reverse_index_factory,
)
from faiss.contrib.ivf_tools import ( # @manual=//faiss/contrib:faiss_contrib_gpu
add_preassigned,
replace_ivf_quantizer,
)
from .descriptors import DatasetDescriptor
logger = logging.getLogger(__name__)
def timer(name, func, once=False):
logger.info(f"Measuring {name}")
t1 = perf_counter()
res = func()
t2 = perf_counter()
t = t2 - t1
repeat = 1
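# a run faster than 1 second is repeated until roughly 2 seconds have
# elapsed and the mean time is reported, to reduce measurement noise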
if not once and t < 1.0:
repeat = int(2.0 // t)
logger.info(
f"Time for {name}: {t:.3f} seconds, repeating {repeat} times"
)
t1 = perf_counter()
for _ in range(repeat):
res = func()
t2 = perf_counter()
t = (t2 - t1) / repeat
logger.info(f"Time for {name}: {t:.3f} seconds")
return res, t, repeat
def refine_distances_knn(
D: np.ndarray, I: np.ndarray, xq: np.ndarray, xb: np.ndarray, metric
):
return np.where(
I >= 0,
np.square(np.linalg.norm(xq[:, None] - xb[I], axis=2))
if metric == faiss.METRIC_L2
else np.einsum("qd,qkd->qk", xq, xb[I]),
D,
)
def refine_distances_range(
lims: np.ndarray,
D: np.ndarray,
I: np.ndarray,
xq: np.ndarray,
xb: np.ndarray,
metric,
):
with ThreadPool(32) as pool:
R = pool.map(
lambda i: (
np.sum(np.square(xq[i] - xb[I[lims[i]:lims[i + 1]]]), axis=1)
if metric == faiss.METRIC_L2
else np.tensordot(
xq[i], xb[I[lims[i]:lims[i + 1]]], axes=(0, 1)
)
)
if lims[i + 1] > lims[i]
else [],
range(len(lims) - 1),
)
return np.hstack(R)
# The classes below are wrappers around Faiss indices, with different
# implementations depending on whether we start from an already trained
# index (IndexFromCodec) or from a factory string (IndexFromFactory).
# In both cases the classes provide operations for adding to an index
# and searching it, and their outputs are cached on disk.
# IndexFromFactory also decomposes the index (pretransform and quantizer)
# and trains the sub-indices independently.
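# A typical flow, as a hypothetical sketch: `io` is a configured
# BenchmarkIO, and the descriptors (xt_desc, xb_desc, xq_desc), the
# factory string and the parameter values are assumptions, not from this diff:
# index = IndexFromFactory(
# d=128, metric="L2", database_vectors=xb_desc,
# construction_params=None, search_params=None,
# factory="IVF1024,Flat", training_vectors=xt_desc,
# )
# index.set_io(io)
# index.train() # training the codec caches it on disk
# D, I, R, P = index.knn_search({"nprobe": 8}, xq_desc, k=10)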
class IndexBase:
def set_io(self, benchmark_io):
self.io = benchmark_io
@staticmethod
def param_dict_list_to_name(param_dict_list):
if not param_dict_list:
return ""
l = 0
n = ""
for param_dict in param_dict_list:
n += IndexBase.param_dict_to_name(param_dict, f"cp{l}")
l += 1  # advance the level so each dict gets a distinct cp<level> prefix
return n
@staticmethod
def param_dict_to_name(param_dict, prefix="sp"):
if not param_dict:
return ""
n = prefix
for name, val in param_dict.items():
if name != "noop":
n += f"_{name}_{val}"
if n == prefix:
return ""
n += "."
return n
@staticmethod
def set_index_param_dict_list(index, param_dict_list):
if not param_dict_list:
return
index = faiss.downcast_index(index)
for param_dict in param_dict_list:
assert index is not None
IndexBase.set_index_param_dict(index, param_dict)
index = faiss.try_extract_index_ivf(index)
@staticmethod
def set_index_param_dict(index, param_dict):
if not param_dict:
return
for name, val in param_dict.items():
IndexBase.set_index_param(index, name, val)
@staticmethod
def set_index_param(index, name, val):
index = faiss.downcast_index(index)
if isinstance(index, faiss.IndexPreTransform):
Index.set_index_param(index.index, name, val)
elif name == "efSearch":
index.hnsw.efSearch
index.hnsw.efSearch = int(val)
elif name == "efConstruction":
index.hnsw.efConstruction
index.hnsw.efConstruction = int(val)
elif name == "nprobe":
index_ivf = faiss.extract_index_ivf(index)
index_ivf.nprobe
index_ivf.nprobe = int(val)
elif name == "k_factor":
index.k_factor
index.k_factor = int(val)
elif name == "parallel_mode":
index_ivf = faiss.extract_index_ivf(index)
index_ivf.parallel_mode
index_ivf.parallel_mode = int(val)
elif name == "noop":
pass
else:
raise RuntimeError(f"could not set param {name} on {index}")
def is_flat(self):
codec = faiss.downcast_index(self.get_model())
return isinstance(codec, faiss.IndexFlat)
def is_ivf(self):
codec = self.get_model()
return faiss.try_extract_index_ivf(codec) is not None
def is_pretransform(self):
codec = self.get_model()
if isinstance(codec, faiss.IndexRefine):
codec = faiss.downcast_index(codec.base_index)
return isinstance(codec, faiss.IndexPreTransform)
# index is a codec + database vectors
# in other words: a trained Faiss index
# that contains database vectors
def get_index_name(self):
raise NotImplementedError
def get_index(self):
raise NotImplementedError
# codec is a trained model
# in other words: a trained Faiss index
# without any database vectors
def get_codec_name(self):
raise NotImplementedError
def get_codec(self):
raise NotImplementedError
# model is an untrained Faiss index
# it can be used for training (see codec)
# or to inspect its structure
def get_model_name(self):
raise NotImplementedError
def get_model(self):
raise NotImplementedError
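# To illustrate the three notions for the factory "IVF1024,Flat"
# (one instance of the IVF factories generated in the example script below):
# model: faiss.index_factory(d, "IVF1024,Flat"), untrained, ntotal == 0
# codec: the same index after train(xt), trained, still ntotal == 0
# index: the codec after add(xb), trained, ntotal == len(xb)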
def transform(self, vectors):
transformed_vectors = DatasetDescriptor(
tablename=f"{vectors.get_filename()}{self.get_codec_name()}transform.npy"
)
if not self.io.file_exist(transformed_vectors.tablename):
codec = self.fetch_codec()
assert isinstance(codec, faiss.IndexPreTransform)
transform = faiss.downcast_VectorTransform(codec.chain.at(0))
x = self.io.get_dataset(vectors)
xt = transform.apply(x)
self.io.write_nparray(xt, transformed_vectors.tablename)
return transformed_vectors
def knn_search_quantizer(self, index, query_vectors, k):
if self.is_pretransform():
pretransform = self.get_pretransform()
quantizer_query_vectors = pretransform.transform(query_vectors)
else:
pretransform = None
quantizer_query_vectors = query_vectors
QD, QI, _, QP = self.get_quantizer(pretransform).knn_search(
search_parameters=None,
query_vectors=quantizer_query_vectors,
k=k,
)
xqt = self.io.get_dataset(quantizer_query_vectors)
return xqt, QD, QI, QP
def get_knn_search_name(
self,
search_parameters: Optional[Dict[str, int]],
query_vectors: DatasetDescriptor,
k: int,
):
name = self.get_index_name()
name += Index.param_dict_to_name(search_parameters)
name += query_vectors.get_filename("q")
name += f"k_{k}."
return name
def knn_search(
self,
search_parameters: Optional[Dict[str, int]],
query_vectors: DatasetDescriptor,
k: int,
):
logger.info("knn_seach: begin")
filename = (
self.get_knn_search_name(search_parameters, query_vectors, k)
+ "zip"
)
if self.io.file_exist(filename):
logger.info(f"Using cached results for {filename}")
D, I, R, P = self.io.read_file(filename, ["D", "I", "R", "P"])
else:
xq = self.io.get_dataset(query_vectors)
index = self.get_index()
Index.set_index_param_dict(index, search_parameters)
if self.is_ivf():
xqt, QD, QI, QP = self.knn_search_quantizer(
index, query_vectors, search_parameters["nprobe"]
)
index_ivf = faiss.extract_index_ivf(index)
if index_ivf.parallel_mode != 2:
logger.info("Setting IVF parallel mode")
index_ivf.parallel_mode = 2
(D, I), t, repeat = timer(
"knn_search_preassigned",
lambda: index_ivf.search_preassigned(xqt, k, QI, QD),
)
else:
(D, I), t, _ = timer("knn_search", lambda: index.search(xq, k))
if self.is_flat() or not hasattr(self, "database_vectors"): # TODO
R = D
else:
xb = self.io.get_dataset(self.database_vectors)
R = refine_distances_knn(D, I, xq, xb, self.metric_type)
P = {
"time": t,
"index": self.get_index_name(),
"codec": self.get_codec_name(),
"factory": self.factory if hasattr(self, "factory") else "",
"search_params": search_parameters,
"k": k,
}
if self.is_ivf():
stats = faiss.cvar.indexIVF_stats
P |= {
"quantizer": QP,
"nq": int(stats.nq // repeat),
"nlist": int(stats.nlist // repeat),
"ndis": int(stats.ndis // repeat),
"nheap_updates": int(stats.nheap_updates // repeat),
"quantization_time": int(
stats.quantization_time // repeat
),
"search_time": int(stats.search_time // repeat),
}
self.io.write_file(filename, ["D", "I", "R", "P"], [D, I, R, P])
logger.info("knn_seach: end")
return D, I, R, P
def range_search(
self,
search_parameters: Optional[Dict[str, int]],
query_vectors: DatasetDescriptor,
radius: Optional[float] = None,
):
logger.info("range_search: begin")
filename = (
self.get_range_search_name(
search_parameters, query_vectors, radius
)
+ "zip"
)
if self.io.file_exist(filename):
logger.info(f"Using cached results for {filename}")
lims, D, I, R, P = self.io.read_file(
filename, ["lims", "D", "I", "R", "P"]
)
else:
xq = self.io.get_dataset(query_vectors)
index = self.get_index()
Index.set_index_param_dict(index, search_parameters)
if self.is_ivf():
xqt, QD, QI, QP = self.knn_search_quantizer(
index, query_vectors, search_parameters["nprobe"]
)
index_ivf = faiss.extract_index_ivf(index)
if index_ivf.parallel_mode != 2:
logger.info("Setting IVF parallel mode")
index_ivf.parallel_mode = 2
(lims, D, I), t, repeat = timer(
"range_search_preassigned",
lambda: index_ivf.range_search_preassigned(
xqt, radius, QI, QD
),
)
else:
(lims, D, I), t, _ = timer(
"range_search", lambda: index.range_search(xq, radius)
)
if self.is_flat():
R = D
else:
xb = self.io.get_dataset(self.database_vectors)
R = refine_distances_range(
lims, D, I, xq, xb, self.metric_type
)
P = {
"time": t,
"index": self.get_codec_name(),
"codec": self.get_codec_name(),
"search_params": search_parameters,
"radius": radius,
"count": len(I),
}
if self.is_ivf():
stats = faiss.cvar.indexIVF_stats
P |= {
"quantizer": QP,
"nq": int(stats.nq // repeat),
"nlist": int(stats.nlist // repeat),
"ndis": int(stats.ndis // repeat),
"nheap_updates": int(stats.nheap_updates // repeat),
"quantization_time": int(
stats.quantization_time // repeat
),
"search_time": int(stats.search_time // repeat),
}
self.io.write_file(
filename, ["lims", "D", "I", "R", "P"], [lims, D, I, R, P]
)
logger.info("range_seach: end")
return lims, D, I, R, P
# Common base for IndexFromCodec and IndexFromFactory,
# but not for the sub-indices of codec-based indices
# IndexFromQuantizer and IndexFromPreTransform, because
# they share the configuration of their parent IndexFromCodec
@dataclass
class Index(IndexBase):
d: int
metric: str
database_vectors: DatasetDescriptor
construction_params: List[Dict[str, int]]
search_params: Dict[str, int]
cached_codec_name: ClassVar[str] = None
cached_codec: ClassVar[faiss.Index] = None
cached_index_name: ClassVar[str] = None
cached_index: ClassVar[faiss.Index] = None
def __post_init__(self):
if isinstance(self.metric, str):
if self.metric == "IP":
self.metric_type = faiss.METRIC_INNER_PRODUCT
elif self.metric == "L2":
self.metric_type = faiss.METRIC_L2
else:
raise ValueError
elif isinstance(self.metric, int):
self.metric_type = self.metric
if self.metric_type == faiss.METRIC_INNER_PRODUCT:
self.metric = "IP"
elif self.metric_type == faiss.METRIC_L2:
self.metric = "L2"
else:
raise ValueError
else:
raise ValueError
def supports_range_search(self):
codec = self.get_codec()
return type(codec) not in [
faiss.IndexHNSWFlat,
faiss.IndexIVFFastScan,
faiss.IndexRefine,
faiss.IndexPQ,
]
def fetch_codec(self):
raise NotImplementedError
def train(self):
# fetching the codec triggers training, if necessary
self.get_codec()
def get_codec(self):
codec_name = self.get_codec_name()
if Index.cached_codec_name != codec_name:
Index.cached_codec = self.fetch_codec()
Index.cached_codec_name = codec_name
return Index.cached_codec
def get_index_name(self):
name = self.get_codec_name()
assert self.database_vectors is not None
name += self.database_vectors.get_filename("xb")
return name
def fetch_index(self):
index = faiss.clone_index(self.get_codec())
assert index.ntotal == 0
logger.info("Adding vectors to index")
xb = self.io.get_dataset(self.database_vectors)
if self.is_ivf():
xbt, QD, QI, QP = self.knn_search_quantizer(
index, self.database_vectors, 1
)
index_ivf = faiss.extract_index_ivf(index)
if index_ivf.parallel_mode != 2:
logger.info("Setting IVF parallel mode")
index_ivf.parallel_mode = 2
_, t, _ = timer(
"add_preassigned",
lambda: add_preassigned(index_ivf, xbt, QI.ravel()),
once=True,
)
else:
_, t, _ = timer(
"add",
lambda: index.add(xb),
once=True,
)
assert index.ntotal == xb.shape[0] or index_ivf.ntotal == xb.shape[0]
logger.info("Added vectors to index")
return index
def get_index(self):
index_name = self.get_index_name()
if Index.cached_index_name != index_name:
Index.cached_index = self.fetch_index()
Index.cached_index_name = index_name
return Index.cached_index
def get_code_size(self):
def get_index_code_size(index):
index = faiss.downcast_index(index)
if isinstance(index, faiss.IndexPreTransform):
return get_index_code_size(index.index)
elif isinstance(index, faiss.IndexHNSWFlat):
return index.d * 4 # TODO
elif type(index) in [faiss.IndexRefine, faiss.IndexRefineFlat]:
return get_index_code_size(
index.base_index
) + get_index_code_size(index.refine_index)
else:
return index.code_size
codec = self.get_codec()
return get_index_code_size(codec)
def get_operating_points(self):
op = OperatingPointsWithRanges()
def add_range_or_val(name, range):
op.add_range(
name,
[self.search_params[name]]
if self.search_params and name in self.search_params
else range,
)
op.add_range("noop", [0])
codec = faiss.downcast_index(self.get_codec())
codec_ivf = faiss.try_extract_index_ivf(codec)
if codec_ivf is not None:
add_range_or_val(
"nprobe",
[
2**i
for i in range(12)
if 2**i <= codec_ivf.nlist * 0.25
],
)
if isinstance(codec, faiss.IndexRefine):
add_range_or_val(
"k_factor",
[2**i for i in range(11)],
)
if isinstance(codec, faiss.IndexHNSWFlat):
add_range_or_val(
"efSearch",
[2**i for i in range(3, 11)],
)
return op
def get_range_search_name(
self,
search_parameters: Optional[Dict[str, int]],
query_vectors: DatasetDescriptor,
radius: Optional[float] = None,
):
name = self.get_index_name()
name += Index.param_dict_to_name(search_parameters)
name += query_vectors.get_filename("q")
if radius is not None:
name += f"r_{int(radius * 1000)}."
else:
name += "r_auto."
return name
# IndexFromCodec, IndexFromQuantizer and IndexFromPreTransform
# are used to wrap pre-trained Faiss indices (codecs)
@dataclass
class IndexFromCodec(Index):
path: str
bucket: Optional[str] = None
def get_quantizer(self):
if not self.is_ivf():
raise ValueError("Not an IVF index")
quantizer = IndexFromQuantizer(self)
quantizer.set_io(self.io)
return quantizer
def get_pretransform(self):
if not self.is_ivf():
raise ValueError("Not an IVF index")
quantizer = IndexFromPreTransform(self)
quantizer.set_io(self.io)
return quantizer
def get_codec_name(self):
assert self.path is not None
name = os.path.basename(self.path)
name += Index.param_dict_list_to_name(self.construction_params)
return name
def fetch_codec(self):
codec = self.io.read_index(
os.path.basename(self.path),
self.bucket,
os.path.dirname(self.path),
)
assert self.d == codec.d
assert self.metric_type == codec.metric_type
Index.set_index_param_dict_list(codec, self.construction_params)
return codec
def get_model(self):
return self.get_codec()
class IndexFromQuantizer(IndexBase):
ivf_index: Index
def __init__(self, ivf_index: Index):
self.ivf_index = ivf_index
super().__init__()
def get_codec_name(self):
return self.get_index_name()
def get_codec(self):
return self.get_index()
def get_index_name(self):
ivf_codec_name = self.ivf_index.get_codec_name()
return f"{ivf_codec_name}quantizer."
def get_index(self):
ivf_codec = faiss.extract_index_ivf(self.ivf_index.get_codec())
return ivf_codec.quantizer
class IndexFromPreTransform(IndexBase):
pre_transform_index: Index
def __init__(self, pre_transform_index: Index):
self.pre_transform_index = pre_transform_index
super().__init__()
def get_codec_name(self):
pre_transform_codec_name = self.pre_transform_index.get_codec_name()
return f"{pre_transform_codec_name}pretransform."
def get_codec(self):
# delegate to the wrapped index; calling self.get_codec() here would recurse forever
return self.pre_transform_index.get_codec()
# IndexFromFactory is for creating and training indices from scratch;
# it decomposes the factory string and trains each component separately
# (see the illustrative sketch after assemble() below)
@dataclass
class IndexFromFactory(Index):
factory: str
training_vectors: DatasetDescriptor
def get_codec_name(self):
assert self.factory is not None
name = f"{self.factory.replace(',', '_')}."
assert self.d is not None
assert self.metric is not None
name += f"d_{self.d}.{self.metric.upper()}."
if self.factory != "Flat":
assert self.training_vectors is not None
name += self.training_vectors.get_filename("xt")
name += Index.param_dict_list_to_name(self.construction_params)
return name
def fetch_codec(self):
codec_filename = self.get_codec_name() + "codec"
if self.io.file_exist(codec_filename):
codec = self.io.read_index(codec_filename)
assert self.d == codec.d
assert self.metric_type == codec.metric_type
else:
codec = self.assemble()
if self.factory != "Flat":
self.io.write_index(codec, codec_filename)
return codec
def get_model(self):
model = faiss.index_factory(self.d, self.factory, self.metric_type)
Index.set_index_param_dict_list(model, self.construction_params)
return model
def get_pretransform(self):
model = faiss.index_factory(self.d, self.factory, self.metric_type)
assert isinstance(model, faiss.IndexPreTransform)
sub_index = faiss.downcast_index(model.index)
if isinstance(sub_index, faiss.IndexFlat):
return self
# replace the sub-index with Flat
codec = faiss.clone_index(model)
codec.index = faiss.IndexFlat(codec.index.d, codec.index.metric_type)
pretransform = IndexFromFactory(
d=codec.d,
metric=codec.metric_type,
database_vectors=self.database_vectors,
construction_params=self.construction_params,
search_params=self.search_params,
factory=reverse_index_factory(codec),
training_vectors=self.training_vectors,
)
pretransform.set_io(self.io)
return pretransform
def get_quantizer(self, pretransform=None):
model = self.get_model()
model_ivf = faiss.extract_index_ivf(model)
assert isinstance(model_ivf, faiss.IndexIVF)
assert ord(model_ivf.quantizer_trains_alone) in [0, 2]
if pretransform is None:
training_vectors = self.training_vectors
else:
training_vectors = pretransform.transform(self.training_vectors)
centroids = self.k_means(training_vectors, model_ivf.nlist)
quantizer = IndexFromFactory(
d=model_ivf.quantizer.d,
metric=model_ivf.quantizer.metric_type,
database_vectors=centroids,
construction_params=None, # self.construction_params[1:],
search_params=None, # self.construction_params[0], # TODO: verify
factory=reverse_index_factory(model_ivf.quantizer),
training_vectors=centroids,
)
quantizer.set_io(self.io)
return quantizer
def k_means(self, vectors, k):
kmeans_vectors = DatasetDescriptor(
tablename=f"{vectors.get_filename()}kmeans_{k}.npy"
)
if not self.io.file_exist(kmeans_vectors.tablename):
x = self.io.get_dataset(vectors)
kmeans = faiss.Kmeans(d=x.shape[1], k=k, gpu=True)
kmeans.train(x)
self.io.write_nparray(kmeans.centroids, kmeans_vectors.tablename)
return kmeans_vectors
def assemble(self):
model = self.get_model()
codec = faiss.clone_index(model)
if isinstance(model, faiss.IndexPreTransform):
sub_index = faiss.downcast_index(model.index)
if not isinstance(sub_index, faiss.IndexFlat):
# replace the sub-index with Flat and fetch pre-trained
pretransform = self.get_pretransform()
codec = pretransform.fetch_codec()
assert codec.is_trained
transformed_training_vectors = pretransform.transform(
self.training_vectors
)
transformed_database_vectors = pretransform.transform(
self.database_vectors
)
# replace the Flat index with the required sub-index
wrapper = IndexFromFactory(
d=sub_index.d,
metric=sub_index.metric_type,
database_vectors=transformed_database_vectors,
construction_params=self.construction_params,
search_params=self.search_params,
factory=reverse_index_factory(sub_index),
training_vectors=transformed_training_vectors,
)
wrapper.set_io(self.io)
codec.index = wrapper.fetch_codec()
assert codec.index.is_trained
elif isinstance(model, faiss.IndexIVF):
# replace the quantizer
quantizer = self.get_quantizer()
replace_ivf_quantizer(codec, quantizer.fetch_index())
assert codec.quantizer.is_trained
assert codec.nlist == codec.quantizer.ntotal
elif isinstance(model, faiss.IndexRefine) or isinstance(
model, faiss.IndexRefineFlat
):
# replace base_index
wrapper = IndexFromFactory(
d=model.base_index.d,
metric=model.base_index.metric_type,
database_vectors=self.database_vectors,
construction_params=self.construction_params,
search_params=self.search_params,
factory=reverse_index_factory(model.base_index),
training_vectors=self.training_vectors,
)
wrapper.set_io(self.io)
codec.base_index = wrapper.fetch_codec()
assert codec.base_index.is_trained
xt = self.io.get_dataset(self.training_vectors)
codec.train(xt)
return codec
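# An illustrative sketch of how assemble() decomposes the factory
# "OPQ32_128,IVF512,PQ32" (a factory that appears in the example scripts):
# 1. the OPQ32_128 pretransform is trained on its own, with its sub-index
# temporarily replaced by Flat;
# 2. the IVF512 coarse quantizer is rebuilt from k-means centroids of the
# OPQ-transformed training vectors (see get_quantizer and k_means above);
# 3. the PQ32 encoder is trained last, on the transformed training set.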

View File

@@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from bench_fw.benchmark import Benchmark
from bench_fw.benchmark_io import BenchmarkIO
from bench_fw.descriptors import DatasetDescriptor, IndexDescriptor
logging.basicConfig(level=logging.INFO)
benchmark = Benchmark(
training_vectors=DatasetDescriptor(
namespace="std_d", tablename="sift1M"
),
database_vectors=DatasetDescriptor(
namespace="std_d", tablename="sift1M"
),
query_vectors=DatasetDescriptor(
namespace="std_q", tablename="sift1M"
),
index_descs=[
IndexDescriptor(
factory=f"IVF{2 ** nlist},Flat",
)
for nlist in range(8, 15)
],
k=1,
distance_metric="L2",
)
io = BenchmarkIO(
path="/checkpoint",
)
benchmark.set_io(io)
print(benchmark.benchmark("result.json"))

View File

@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from bench_fw.benchmark import Benchmark
from bench_fw.benchmark_io import BenchmarkIO
from bench_fw.descriptors import DatasetDescriptor, IndexDescriptor
logging.basicConfig(level=logging.INFO)
benchmark = Benchmark(
training_vectors=DatasetDescriptor(
tablename="training.npy", num_vectors=200000
),
database_vectors=DatasetDescriptor(
tablename="database.npy", num_vectors=200000
),
query_vectors=DatasetDescriptor(tablename="query.npy", num_vectors=2000),
index_descs=[
IndexDescriptor(
factory="Flat",
range_metrics={
"weighted": [
[0.1, 0.928],
[0.2, 0.865],
[0.3, 0.788],
[0.4, 0.689],
[0.5, 0.49],
[0.6, 0.308],
[0.7, 0.193],
[0.8, 0.0],
]
},
),
IndexDescriptor(
factory="OPQ32_128,IVF512,PQ32",
),
IndexDescriptor(
factory="OPQ32_256,IVF512,PQ32",
),
IndexDescriptor(
factory="HNSW32",
construction_params=[
{
"efConstruction": 64,
}
],
),
],
k=10,
distance_metric="L2",
range_ref_index_desc="Flat",
)
io = BenchmarkIO(
path="/checkpoint",
)
benchmark.set_io(io)
print(benchmark.benchmark("result.json"))

View File

@@ -72,6 +72,9 @@ def get_code_size(d, indexkey):
raise RuntimeError("cannot parse " + indexkey)
def get_hnsw_M(index):
return index.hnsw.cum_nneighbor_per_level.at(1) // 2
def reverse_index_factory(index):
"""
@@ -80,21 +83,47 @@ def reverse_index_factory(index):
index = faiss.downcast_index(index)
if isinstance(index, faiss.IndexFlat):
return "Flat"
if isinstance(index, faiss.IndexIVF):
elif isinstance(index, faiss.IndexIVF):
quantizer = faiss.downcast_index(index.quantizer)
if isinstance(quantizer, faiss.IndexFlat):
prefix = "IVF%d" % index.nlist
prefix = f"IVF{index.nlist}"
elif isinstance(quantizer, faiss.MultiIndexQuantizer):
prefix = "IMI%dx%d" % (quantizer.pq.M, quantizer.pq.nbit)
prefix = f"IMI{quantizer.pq.M}x{quantizer.pq.nbits}"
elif isinstance(quantizer, faiss.IndexHNSW):
prefix = "IVF%d_HNSW%d" % (index.nlist, quantizer.hnsw.M)
prefix = f"IVF{index.nlist}_HNSW{get_hnsw_M(quantizer)}"
else:
prefix = "IVF%d(%s)" % (index.nlist, reverse_index_factory(quantizer))
prefix = f"IVF{index.nlist}({reverse_index_factory(quantizer)})"
if isinstance(index, faiss.IndexIVFFlat):
return prefix + ",Flat"
if isinstance(index, faiss.IndexIVFScalarQuantizer):
return prefix + ",SQ8"
if isinstance(index, faiss.IndexIVFPQ):
return prefix + f",PQ{index.pq.M}x{index.pq.nbits}"
elif isinstance(index, faiss.IndexPreTransform):
assert index.chain.size() == 1
vt = faiss.downcast_VectorTransform(index.chain.at(0))
if isinstance(vt, faiss.OPQMatrix):
return f"OPQ{vt.M}_{vt.d_out},{reverse_index_factory(index.index)}"
elif isinstance(index, faiss.IndexHNSW):
return f"HNSW{get_hnsw_M(index)}"
elif isinstance(index, faiss.IndexRefine):
return f"{reverse_index_factory(index.base_index)},Refine({reverse_index_factory(index.refine_index)})"
elif isinstance(index, faiss.IndexPQFastScan):
return f"PQ{index.pq.M}x{index.pq.nbits}fs"
elif isinstance(index, faiss.IndexScalarQuantizer):
sqtypes = {
faiss.ScalarQuantizer.QT_8bit: "8",
faiss.ScalarQuantizer.QT_4bit: "4",
faiss.ScalarQuantizer.QT_6bit: "6",
faiss.ScalarQuantizer.QT_fp16: "fp16",
}
return f"SQ{sqtypes[index.sq.qtype]}"
raise NotImplementedError()