index optimizer (#3154)

Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/3154

Use the benchmark framework to find Pareto-optimal indices, with BigANN as the example dataset.

Optimize the coarse quantizer and the vector codec separately, then combine the Pareto-optimal configurations into IVF indices, which are retested at various database scales. See `optimize()` in `optimize.py` for the main function driving the process.
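
As a rough sketch, condensing the new `bench_fw_optimize.py` driver added in this diff (the experiment path is illustrative), the process is driven like this:

```python
from bench_fw.benchmark_io import BenchmarkIO
from bench_fw.descriptors import DatasetDescriptor
from bench_fw.optimize import Optimizer

# Set up IO rooted at an (illustrative) experiment directory.
bio = BenchmarkIO(path="/path/to/experiments/bigann")
optimizer = Optimizer(distance_metric="L2", num_threads=32, run_local=False)
optimizer.set_io(bio)

query_vectors = DatasetDescriptor(namespace="std_q", tablename="bigann1M")
xq = bio.get_dataset(query_vectors)  # used only to read the dimensionality

# Optimize coarse quantizers and codecs separately, then benchmark the
# combined IVF indices at 1M and 10M database scale.
optimizer.optimize(
    d=xq.shape[1],
    training_vectors=DatasetDescriptor(
        namespace="std_t", tablename="bigann1M", num_vectors=2_000_000
    ),
    database_vectors_list=[
        DatasetDescriptor(namespace="std_d", tablename="bigann1M"),
        DatasetDescriptor(namespace="std_d", tablename="bigann10M"),
    ],
    query_vectors=query_vectors,
    min_accuracy=0.85,
)
```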

The results can be interpreted with `bench_fw_notebook.ipynb`, which allows (see the sketch below):
* filtering by maximum code size
* filtering by maximum search time
* filtering by minimum accuracy
* keeping only the space- and/or time-Pareto-optimal options
* visualizing the results and outputting them as a table
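
A minimal sketch of the same filtering outside the notebook, using `filter_results` from `bench_fw.utils` as added in this diff; it assumes the result file written by `Benchmark.benchmark()` is plain JSON, and the filename and thresholds are illustrative:

```python
import json

from bench_fw.utils import ParetoMetric, ParetoMode, filter_results

# Load a results file written by Benchmark.benchmark() (illustrative filename).
with open("result_bigann1M.json") as f:
    results = json.load(f)

# Keep only kNN experiments above 90% knn_intersection, below the given
# code-size and time budgets, on the global time/space Pareto front.
filtered = filter_results(
    results=results,
    evaluation="knn",
    accuracy_metric="knn_intersection",
    min_accuracy=0.9,   # illustrative threshold
    max_space=64,       # maximum code size, per the index metadata
    max_time=0.01,      # maximum per-query time
    pareto_mode=ParetoMode.GLOBAL,
    pareto_metric=ParetoMetric.TIME_SPACE,
)
for accuracy, space, time, key, _ in filtered:
    print(f"{key}: accuracy={accuracy:.3f} code_size={space} time={time:.4f}")
```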

This version is intentionally limited to IVF(Flat|HNSW),PQ|SQ indices...

Reviewed By: mdouze

Differential Revision: D51781670

fbshipit-source-id: 2c0f800d374ea845255934f519cc28095c00a51f
pull/3233/head
Gergely Szilvasy 2024-01-30 10:58:13 -08:00 committed by Facebook GitHub Bot
parent 75ae0bfb7f
commit 1d0e8d489f
8 changed files with 1319 additions and 751 deletions

View File

@ -7,19 +7,20 @@ import logging
from copy import copy
from dataclasses import dataclass
from operator import itemgetter
from statistics import median, mean
from statistics import mean, median
from typing import Any, Dict, List, Optional
from .utils import dict_merge
from .index import Index, IndexFromCodec, IndexFromFactory
from .descriptors import DatasetDescriptor, IndexDescriptor
import faiss # @manual=//faiss/python:pyfaiss_gpu
import numpy as np
from scipy.optimize import curve_fit
from .descriptors import DatasetDescriptor, IndexDescriptor
from .index import Index, IndexFromCodec, IndexFromFactory
from .utils import dict_merge
logger = logging.getLogger(__name__)
@ -274,8 +275,8 @@ class Benchmark:
search_parameters: Optional[Dict[str, int]],
radius: Optional[float] = None,
gt_radius: Optional[float] = None,
range_search_metric_function = None,
gt_rsm = None,
range_search_metric_function=None,
gt_rsm=None,
):
logger.info("range_search: begin")
if radius is None:
@ -328,7 +329,13 @@ class Benchmark:
logger.info("knn_ground_truth: begin")
flat_desc = self.get_index_desc("Flat")
self.build_index_wrapper(flat_desc)
self.gt_knn_D, self.gt_knn_I, _, _, requires = flat_desc.index.knn_search(
(
self.gt_knn_D,
self.gt_knn_I,
_,
_,
requires,
) = flat_desc.index.knn_search(
dry_run=False,
search_parameters=None,
query_vectors=self.query_vectors,
@ -338,13 +345,13 @@ class Benchmark:
logger.info("knn_ground_truth: end")
def search_benchmark(
self,
self,
name,
search_func,
key_func,
cost_metrics,
perf_metrics,
results: Dict[str, Any],
results: Dict[str, Any],
index: Index,
):
index_name = index.get_index_name()
@ -376,11 +383,18 @@ class Benchmark:
logger.info(f"{name}_benchmark: end")
return results, requires
def knn_search_benchmark(self, dry_run, results: Dict[str, Any], index: Index):
def knn_search_benchmark(
self, dry_run, results: Dict[str, Any], index: Index
):
return self.search_benchmark(
name="knn_search",
search_func=lambda parameters: index.knn_search(
dry_run, parameters, self.query_vectors, self.k, self.gt_knn_I, self.gt_knn_D,
dry_run,
parameters,
self.query_vectors,
self.k,
self.gt_knn_I,
self.gt_knn_D,
)[3:],
key_func=lambda parameters: index.get_knn_search_name(
search_parameters=parameters,
@ -394,11 +408,17 @@ class Benchmark:
index=index,
)
def reconstruct_benchmark(self, dry_run, results: Dict[str, Any], index: Index):
def reconstruct_benchmark(
self, dry_run, results: Dict[str, Any], index: Index
):
return self.search_benchmark(
name="reconstruct",
search_func=lambda parameters: index.reconstruct(
dry_run, parameters, self.query_vectors, self.k, self.gt_knn_I,
dry_run,
parameters,
self.query_vectors,
self.k,
self.gt_knn_I,
),
key_func=lambda parameters: index.get_knn_search_name(
search_parameters=parameters,
@ -426,19 +446,20 @@ class Benchmark:
return self.search_benchmark(
name="range_search",
search_func=lambda parameters: self.range_search(
dry_run=dry_run,
index=index,
search_parameters=parameters,
dry_run=dry_run,
index=index,
search_parameters=parameters,
radius=radius,
gt_radius=gt_radius,
range_search_metric_function=range_search_metric_function,
range_search_metric_function=range_search_metric_function,
gt_rsm=gt_rsm,
)[4:],
key_func=lambda parameters: index.get_range_search_name(
search_parameters=parameters,
query_vectors=self.query_vectors,
radius=radius,
) + metric_key,
)
+ metric_key,
cost_metrics=["time"],
perf_metrics=["range_score_max_recall"],
results=results,
@ -446,11 +467,12 @@ class Benchmark:
)
def build_index_wrapper(self, index_desc: IndexDescriptor):
if hasattr(index_desc, 'index'):
if hasattr(index_desc, "index"):
return
if index_desc.factory is not None:
training_vectors = copy(self.training_vectors)
training_vectors.num_vectors = index_desc.training_size
if index_desc.training_size is not None:
training_vectors.num_vectors = index_desc.training_size
index = IndexFromFactory(
num_threads=self.num_threads,
d=self.d,
@ -481,15 +503,24 @@ class Benchmark:
training_vectors=self.training_vectors,
database_vectors=self.database_vectors,
query_vectors=self.query_vectors,
index_descs = [self.get_index_desc("Flat"), index_desc],
index_descs=[self.get_index_desc("Flat"), index_desc],
range_ref_index_desc=self.range_ref_index_desc,
k=self.k,
distance_metric=self.distance_metric,
)
benchmark.set_io(self.io)
benchmark.set_io(self.io.clone())
return benchmark
def benchmark_one(self, dry_run, results: Dict[str, Any], index_desc: IndexDescriptor, train, reconstruct, knn, range):
def benchmark_one(
self,
dry_run,
results: Dict[str, Any],
index_desc: IndexDescriptor,
train,
reconstruct,
knn,
range,
):
faiss.omp_set_num_threads(self.num_threads)
if not dry_run:
self.knn_ground_truth()
@ -531,9 +562,12 @@ class Benchmark:
)
assert requires is None
if self.range_ref_index_desc is None or not index_desc.index.supports_range_search():
if (
self.range_ref_index_desc is None
or not index_desc.index.supports_range_search()
):
return results, None
ref_index_desc = self.get_index_desc(self.range_ref_index_desc)
if ref_index_desc is None:
raise ValueError(
@ -550,7 +584,9 @@ class Benchmark:
coefficients,
coefficients_training_data,
) = self.range_search_reference(
ref_index_desc.index, ref_index_desc.search_params, range_metric
ref_index_desc.index,
ref_index_desc.search_params,
range_metric,
)
gt_rsm = self.range_ground_truth(
gt_radius, range_search_metric_function
@ -583,7 +619,15 @@ class Benchmark:
return results, None
def benchmark(self, result_file=None, local=False, train=False, reconstruct=False, knn=False, range=False):
def benchmark(
self,
result_file=None,
local=False,
train=False,
reconstruct=False,
knn=False,
range=False,
):
logger.info("begin evaluate")
faiss.omp_set_num_threads(self.num_threads)
@ -656,20 +700,34 @@ class Benchmark:
if current_todo:
results_one = {"indices": {}, "experiments": {}}
params = [(self.clone_one(index_desc), results_one, index_desc, train, reconstruct, knn, range) for index_desc in current_todo]
for result in self.io.launch_jobs(run_benchmark_one, params, local=local):
params = [
(
index_desc,
self.clone_one(index_desc),
results_one,
train,
reconstruct,
knn,
range,
)
for index_desc in current_todo
]
for result in self.io.launch_jobs(
run_benchmark_one, params, local=local
):
dict_merge(results, result)
todo = next_todo
todo = next_todo
if result_file is not None:
self.io.write_json(results, result_file, overwrite=True)
logger.info("end evaluate")
return results
def run_benchmark_one(params):
logger.info(params)
benchmark, results, index_desc, train, reconstruct, knn, range = params
index_desc, benchmark, results, train, reconstruct, knn, range = params
results, requires = benchmark.benchmark_one(
dry_run=False,
results=results,

View File

@ -10,13 +10,13 @@ import logging
import os
import pickle
from dataclasses import dataclass
import submitit
from typing import Any, List, Optional
from zipfile import ZipFile
import faiss # @manual=//faiss/python:pyfaiss_gpu
import numpy as np
import submitit
from faiss.contrib.datasets import ( # @manual=//faiss/contrib:faiss_contrib_gpu
dataset_from_name,
)
@ -47,6 +47,9 @@ def merge_rcq_itq(
class BenchmarkIO:
path: str
def clone(self):
return BenchmarkIO(path=self.path)
def __post_init__(self):
self.cached_ds = {}
@ -119,18 +122,27 @@ class BenchmarkIO:
def get_dataset(self, dataset):
if dataset not in self.cached_ds:
if dataset.namespace is not None and dataset.namespace[:4] == "std_":
if (
dataset.namespace is not None
and dataset.namespace[:4] == "std_"
):
if dataset.tablename not in self.cached_ds:
self.cached_ds[dataset.tablename] = dataset_from_name(
dataset.tablename,
)
p = dataset.namespace[4]
if p == "t":
self.cached_ds[dataset] = self.cached_ds[dataset.tablename].get_train(dataset.num_vectors)
self.cached_ds[dataset] = self.cached_ds[
dataset.tablename
].get_train(dataset.num_vectors)
elif p == "d":
self.cached_ds[dataset] = self.cached_ds[dataset.tablename].get_database()
self.cached_ds[dataset] = self.cached_ds[
dataset.tablename
].get_database()
elif p == "q":
self.cached_ds[dataset] = self.cached_ds[dataset.tablename].get_queries()
self.cached_ds[dataset] = self.cached_ds[
dataset.tablename
].get_queries()
else:
raise ValueError
elif dataset.namespace == "syn":
@ -233,8 +245,8 @@ class BenchmarkIO:
if local:
results = [func(p) for p in params]
return results
print(f'launching {len(params)} jobs')
executor = submitit.AutoExecutor(folder='/checkpoint/gsz/jobs')
logger.info(f"launching {len(params)} jobs")
executor = submitit.AutoExecutor(folder="/checkpoint/gsz/jobs")
executor.update_parameters(
nodes=1,
gpus_per_node=8,
@ -248,9 +260,9 @@ class BenchmarkIO:
slurm_constraint="bldg1",
)
jobs = executor.map_array(func, params)
print(f'launched {len(jobs)} jobs')
# for job, param in zip(jobs, params):
# print(f"{job.job_id=} {param=}")
logger.info(f"launched {len(jobs)} jobs")
for job, param in zip(jobs, params):
logger.info(f"{job.job_id=} {param[0]=}")
results = [job.result() for job in jobs]
print(f'received {len(results)} results')
print(f"received {len(results)} results")
return results

View File

@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional
import faiss # @manual=//faiss/python:pyfaiss_gpu
from .utils import timer
logger = logging.getLogger(__name__)
@ -101,7 +102,9 @@ class DatasetDescriptor:
tablename=f"{self.get_filename()}kmeans_{k}.npy"
)
meta_filename = kmeans_vectors.tablename + ".json"
if not io.file_exist(kmeans_vectors.tablename) or not io.file_exist(meta_filename):
if not io.file_exist(kmeans_vectors.tablename) or not io.file_exist(
meta_filename
):
if dry_run:
return None, None, kmeans_vectors.tablename
x = io.get_dataset(self)

View File

@ -4,19 +4,19 @@
# LICENSE file in the root directory of this source tree.
from copy import copy
import logging
import os
from collections import OrderedDict
from copy import copy
from dataclasses import dataclass
from typing import ClassVar, Dict, List, Optional
import faiss # @manual=//faiss/python:pyfaiss_gpu
import numpy as np
from faiss.contrib.evaluation import ( # @manual=//faiss/contrib:faiss_contrib_gpu
OperatingPointsWithRanges,
knn_intersection_measure,
OperatingPointsWithRanges,
)
from faiss.contrib.factory_tools import ( # @manual=//faiss/contrib:faiss_contrib_gpu
reverse_index_factory,
@ -27,7 +27,13 @@ from faiss.contrib.ivf_tools import ( # @manual=//faiss/contrib:faiss_contrib_g
)
from .descriptors import DatasetDescriptor
from .utils import distance_ratio_measure, get_cpu_info, timer, refine_distances_knn, refine_distances_range
from .utils import (
distance_ratio_measure,
get_cpu_info,
refine_distances_knn,
refine_distances_range,
timer,
)
logger = logging.getLogger(__name__)
@ -106,7 +112,9 @@ class IndexBase:
icm_encoder_factory = faiss.GpuIcmEncoderFactory(ngpus)
if isinstance(index, faiss.IndexProductLocalSearchQuantizer):
for i in range(index.plsq.nsplits):
lsq = faiss.downcast_Quantizer(index.plsq.subquantizer(i))
lsq = faiss.downcast_Quantizer(
index.plsq.subquantizer(i)
)
if lsq.icm_encoder_factory is None:
lsq.icm_encoder_factory = icm_encoder_factory
else:
@ -119,29 +127,39 @@ class IndexBase:
obj = faiss.extract_index_ivf(index)
elif name in ["use_beam_LUT", "max_beam_size"]:
if isinstance(index, faiss.IndexProductResidualQuantizer):
obj = [faiss.downcast_Quantizer(index.prq.subquantizer(i)) for i in range(index.prq.nsplits)]
obj = [
faiss.downcast_Quantizer(index.prq.subquantizer(i))
for i in range(index.prq.nsplits)
]
else:
obj = index.rq
elif name == "encode_ils_iters":
if isinstance(index, faiss.IndexProductLocalSearchQuantizer):
obj = [faiss.downcast_Quantizer(index.plsq.subquantizer(i)) for i in range(index.plsq.nsplits)]
obj = [
faiss.downcast_Quantizer(index.plsq.subquantizer(i))
for i in range(index.plsq.nsplits)
]
else:
obj = index.lsq
else:
obj = index
if not isinstance(obj, list):
obj = [obj]
for o in obj:
test = getattr(o, name)
if assert_same and not name == 'use_beam_LUT':
if assert_same and not name == "use_beam_LUT":
assert test == val
else:
setattr(o, name, val)
@staticmethod
def filter_index_param_dict_list(param_dict_list):
if param_dict_list is not None and param_dict_list[0] is not None and "k_factor" in param_dict_list[0]:
if (
param_dict_list is not None
and param_dict_list[0] is not None
and "k_factor" in param_dict_list[0]
):
filtered = copy(param_dict_list)
del filtered[0]["k_factor"]
return filtered
@ -153,6 +171,7 @@ class IndexBase:
return isinstance(model, faiss.IndexFlat)
def is_ivf(self):
return False
model = self.get_model()
return faiss.try_extract_index_ivf(model) is not None
@ -243,7 +262,9 @@ class IndexBase:
pretransform = None
quantizer_query_vectors = query_vectors
quantizer, _, _ = self.get_quantizer(dry_run=False, pretransform=pretransform)
quantizer, _, _ = self.get_quantizer(
dry_run=False, pretransform=pretransform
)
QD, QI, _, QP, _ = quantizer.knn_search(
dry_run=False,
search_parameters=None,
@ -300,7 +321,9 @@ class IndexBase:
# Index2Layer doesn't support search
xq = self.io.get_dataset(query_vectors)
xb = index.reconstruct_n(0, index.ntotal)
(D, I), t, _ = timer("knn_search 2layer", lambda: faiss.knn(xq, xb, k))
(D, I), t, _ = timer(
"knn_search 2layer", lambda: faiss.knn(xq, xb, k)
)
elif self.is_ivf() and not isinstance(index, faiss.IndexRefine):
index_ivf = faiss.extract_index_ivf(index)
nprobe = (
@ -310,7 +333,7 @@ class IndexBase:
else index_ivf.nprobe
)
xqt, QD, QI, QP = self.knn_search_quantizer(
query_vectors=query_vectors,
query_vectors=query_vectors,
k=nprobe,
)
if index_ivf.parallel_mode != 2:
@ -358,11 +381,19 @@ class IndexBase:
"construction_params": self.get_construction_params(),
"search_params": search_parameters,
"knn_intersection": knn_intersection_measure(
I, I_gt,
) if I_gt is not None else None,
I,
I_gt,
)
if I_gt is not None
else None,
"distance_ratio": distance_ratio_measure(
I, R, D_gt, self.metric_type,
) if D_gt is not None else None,
I,
R,
D_gt,
self.metric_type,
)
if D_gt is not None
else None,
}
logger.info("knn_search: end")
return D, I, R, P, None
@ -377,12 +408,14 @@ class IndexBase:
):
logger.info("reconstruct: begin")
filename = (
self.get_knn_search_name(parameters, query_vectors, k, reconstruct=True)
self.get_knn_search_name(
parameters, query_vectors, k, reconstruct=True
)
+ "zip"
)
if self.io.file_exist(filename):
logger.info(f"Using cached results for {filename}")
P, = self.io.read_file(filename, ["P"])
(P,) = self.io.read_file(filename, ["P"])
P["index"] = self.get_index_name()
P["codec"] = self.get_codec_name()
P["factory"] = self.get_model_name()
@ -395,15 +428,21 @@ class IndexBase:
codec_meta = self.fetch_meta()
Index.set_index_param_dict(codec, parameters)
xb = self.io.get_dataset(self.database_vectors)
xb_encoded, encode_t, _ = timer("sa_encode", lambda: codec.sa_encode(xb))
xb_encoded, encode_t, _ = timer(
"sa_encode", lambda: codec.sa_encode(xb)
)
xq = self.io.get_dataset(query_vectors)
if self.is_decode_supported():
xb_decoded, decode_t, _ = timer("sa_decode", lambda: codec.sa_decode(xb_encoded))
xb_decoded, decode_t, _ = timer(
"sa_decode", lambda: codec.sa_decode(xb_encoded)
)
mse = np.square(xb_decoded - xb).sum(axis=1).mean().item()
_, I = faiss.knn(xq, xb_decoded, k, metric=self.metric_type)
asym_recall = knn_intersection_measure(I, I_gt)
xq_decoded = codec.sa_decode(codec.sa_encode(xq))
_, I = faiss.knn(xq_decoded, xb_decoded, k, metric=self.metric_type)
_, I = faiss.knn(
xq_decoded, xb_decoded, k, metric=self.metric_type
)
else:
mse = None
asym_recall = None
@ -604,7 +643,7 @@ class Index(IndexBase):
if self.is_ivf() and not isinstance(index, faiss.IndexRefine):
xbt, QD, QI, QP = self.knn_search_quantizer(
query_vectors=self.database_vectors,
query_vectors=self.database_vectors,
k=1,
)
index_ivf = faiss.extract_index_ivf(index)
@ -638,22 +677,21 @@ class Index(IndexBase):
def get_construction_params(self):
return self.construction_params
# def get_code_size(self):
# def get_index_code_size(index):
# index = faiss.downcast_index(index)
# if isinstance(index, faiss.IndexPreTransform):
# return get_index_code_size(index.index)
# elif isinstance(index, faiss.IndexHNSWFlat):
# return index.d * 4 # TODO
# elif type(index) in [faiss.IndexRefine, faiss.IndexRefineFlat]:
# return get_index_code_size(
# index.base_index
# ) + get_index_code_size(index.refine_index)
# else:
# return index.code_size
def get_code_size(self, codec=None):
def get_index_code_size(index):
index = faiss.downcast_index(index)
if isinstance(index, faiss.IndexPreTransform):
return get_index_code_size(index.index)
elif type(index) in [faiss.IndexRefine, faiss.IndexRefineFlat]:
return get_index_code_size(
index.base_index
) + get_index_code_size(index.refine_index)
else:
return index.code_size if hasattr(index, "code_size") else 0
# codec = self.get_codec()
# return get_index_code_size(codec)
if codec is None:
codec = self.get_codec()
return get_index_code_size(codec)
def get_sa_code_size(self, codec=None):
if codec is None:
@ -680,32 +718,28 @@ class Index(IndexBase):
if model_ivf is not None:
add_range_or_val(
"nprobe",
# [
[2**i for i in range(12) if 2**i <= model_ivf.nlist * 0.5],
# [1, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28] + [
# i
# for i in range(32, 64, 8)
# if i <= model_ivf.nlist * 0.1
# ] + [
# i
# for i in range(64, 128, 16)
# if i <= model_ivf.nlist * 0.1
# ] + [
# i
# for i in range(128, 256, 32)
# if i <= model_ivf.nlist * 0.1
# ] + [
# i
# for i in range(256, 512, 64)
# if i <= model_ivf.nlist * 0.1
# ] + [
# 2**i
# for i in range(12)
# if 2**i <= model_ivf.nlist * 0.5
# for i in range(9, 12)
# if 2**i <= model_ivf.nlist * 0.1
# ],
[1, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28] + [
i
for i in range(32, 64, 8)
if i <= model_ivf.nlist * 0.1
] + [
i
for i in range(64, 128, 16)
if i <= model_ivf.nlist * 0.1
] + [
i
for i in range(128, 256, 32)
if i <= model_ivf.nlist * 0.1
] + [
i
for i in range(256, 512, 64)
if i <= model_ivf.nlist * 0.1
] + [
2**i
for i in range(9, 12)
if 2**i <= model_ivf.nlist * 0.1
],
)
model = faiss.downcast_index(model)
if isinstance(model, faiss.IndexRefine):
@ -718,7 +752,9 @@ class Index(IndexBase):
"efSearch",
[2**i for i in range(3, 11)],
)
elif isinstance(model, faiss.IndexResidualQuantizer) or isinstance(model, faiss.IndexProductResidualQuantizer):
elif isinstance(model, faiss.IndexResidualQuantizer) or isinstance(
model, faiss.IndexProductResidualQuantizer
):
add_range_or_val(
"max_beam_size",
[1, 2, 4, 8, 16, 32],
@ -727,7 +763,9 @@ class Index(IndexBase):
"use_beam_LUT",
[1],
)
elif isinstance(model, faiss.IndexLocalSearchQuantizer) or isinstance(model, faiss.IndexProductLocalSearchQuantizer):
elif isinstance(model, faiss.IndexLocalSearchQuantizer) or isinstance(
model, faiss.IndexProductLocalSearchQuantizer
):
add_range_or_val(
"encode_ils_iters",
[2, 4, 8, 16],
@ -854,7 +892,9 @@ class IndexFromFactory(Index):
def fetch_codec(self, dry_run=False):
codec_filename = self.get_codec_name() + "codec"
meta_filename = self.get_codec_name() + "json"
if self.io.file_exist(codec_filename) and self.io.file_exist(meta_filename):
if self.io.file_exist(codec_filename) and self.io.file_exist(
meta_filename
):
codec = self.io.read_index(codec_filename)
assert self.d == codec.d
assert self.metric_type == codec.metric_type
@ -874,6 +914,7 @@ class IndexFromFactory(Index):
"training_size": self.training_vectors.num_vectors,
"codec_size": codec_size,
"sa_code_size": self.get_sa_code_size(codec),
"code_size": self.get_code_size(codec),
"cpu": get_cpu_info(),
}
self.io.write_json(meta, meta_filename, overwrite=True)
@ -921,7 +962,9 @@ class IndexFromFactory(Index):
training_vectors = self.training_vectors
else:
training_vectors = pretransform.transform(self.training_vectors)
centroids, t, requires = training_vectors.k_means(self.io, model_ivf.nlist, dry_run)
centroids, t, requires = training_vectors.k_means(
self.io, model_ivf.nlist, dry_run
)
if requires is not None:
return None, None, requires
quantizer = IndexFromFactory(
@ -944,11 +987,11 @@ class IndexFromFactory(Index):
model = self.get_model()
opaque = True
t_aggregate = 0
try:
reverse_index_factory(model)
opaque = False
except NotImplementedError:
opaque = True
# try:
# reverse_index_factory(model)
# opaque = False
# except NotImplementedError:
# opaque = True
if opaque:
codec = model
else:
@ -958,7 +1001,9 @@ class IndexFromFactory(Index):
if not isinstance(sub_index, faiss.IndexFlat):
# replace the sub-index with Flat and fetch pre-trained
pretransform = self.get_pretransform()
codec, meta, report = pretransform.fetch_codec(dry_run=dry_run)
codec, meta, report = pretransform.fetch_codec(
dry_run=dry_run
)
if report is not None:
return None, None, report
t_aggregate += meta["training_time"]
@ -978,7 +1023,9 @@ class IndexFromFactory(Index):
training_vectors=transformed_training_vectors,
)
wrapper.set_io(self.io)
codec.index, meta, report = wrapper.fetch_codec(dry_run=dry_run)
codec.index, meta, report = wrapper.fetch_codec(
dry_run=dry_run
)
if report is not None:
return None, None, report
t_aggregate += meta["training_time"]
@ -1008,14 +1055,18 @@ class IndexFromFactory(Index):
d=model.base_index.d,
metric=model.base_index.metric_type,
database_vectors=self.database_vectors,
construction_params=IndexBase.filter_index_param_dict_list(self.construction_params),
construction_params=IndexBase.filter_index_param_dict_list(
self.construction_params
),
search_params=None,
factory=reverse_index_factory(model.base_index),
training_vectors=self.training_vectors,
)
wrapper.set_io(self.io)
codec = faiss.clone_index(model)
codec.base_index, meta, requires = wrapper.fetch_codec(dry_run=dry_run)
codec.base_index, meta, requires = wrapper.fetch_codec(
dry_run=dry_run
)
if requires is not None:
return None, None, requires
t_aggregate += meta["training_time"]

View File

@ -0,0 +1,333 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import dataclass
from typing import Dict, List, Tuple
import faiss # @manual=//faiss/python:pyfaiss_gpu
# from faiss.contrib.evaluation import ( # @manual=//faiss/contrib:faiss_contrib_gpu
# OperatingPoints,
# )
from .benchmark import Benchmark
from .descriptors import DatasetDescriptor, IndexDescriptor
from .utils import dict_merge, filter_results, ParetoMetric, ParetoMode
logger = logging.getLogger(__name__)
@dataclass
class Optimizer:
distance_metric: str = "L2"
num_threads: int = 32
run_local: bool = True
def __post_init__(self):
self.cached_benchmark = None
if self.distance_metric == "IP":
self.distance_metric_type = faiss.METRIC_INNER_PRODUCT
elif self.distance_metric == "L2":
self.distance_metric_type = faiss.METRIC_L2
else:
raise ValueError
def set_io(self, benchmark_io):
self.io = benchmark_io
self.io.distance_metric = self.distance_metric
self.io.distance_metric_type = self.distance_metric_type
def benchmark_and_filter_candidates(
self,
index_descs,
training_vectors,
database_vectors,
query_vectors,
result_file,
include_flat,
min_accuracy,
pareto_metric,
):
benchmark = Benchmark(
num_threads=self.num_threads,
training_vectors=training_vectors,
database_vectors=database_vectors,
query_vectors=query_vectors,
index_descs=index_descs,
k=10,
distance_metric=self.distance_metric,
)
benchmark.set_io(self.io)
results = benchmark.benchmark(
result_file=result_file, local=self.run_local, train=True, knn=True
)
assert results
filtered = filter_results(
results=results,
evaluation="knn",
accuracy_metric="knn_intersection",
min_accuracy=min_accuracy,
name_filter=None
if include_flat
else (lambda n: not n.startswith("Flat")),
pareto_mode=ParetoMode.GLOBAL,
pareto_metric=pareto_metric,
)
assert filtered
index_descs = [
IndexDescriptor(
factory=v["factory"],
construction_params=v["construction_params"],
search_params=v["search_params"],
)
for _, _, _, _, v in filtered
]
return index_descs, filtered
def optimize_quantizer(
self,
training_vectors: DatasetDescriptor,
query_vectors: DatasetDescriptor,
nlists: List[int],
min_accuracy: float,
):
quantizer_descs = {}
for nlist in nlists:
# cluster
centroids, _, _ = training_vectors.k_means(
self.io,
nlist,
dry_run=False,
)
descs = [IndexDescriptor(factory="Flat"),] + [
IndexDescriptor(
factory="HNSW32",
construction_params=[{"efConstruction": 2**i}],
)
for i in range(6, 11)
]
descs, _ = self.benchmark_and_filter_candidates(
descs,
training_vectors=centroids,
database_vectors=centroids,
query_vectors=query_vectors,
result_file=f"result_{centroids.get_filename()}json",
include_flat=True,
min_accuracy=min_accuracy,
pareto_metric=ParetoMetric.TIME,
)
quantizer_descs[nlist] = descs
return quantizer_descs
def optimize_ivf(
self,
result_file: str,
training_vectors: DatasetDescriptor,
database_vectors: DatasetDescriptor,
query_vectors: DatasetDescriptor,
quantizers: Dict[int, List[IndexDescriptor]],
codecs: List[Tuple[str, str]],
min_accuracy: float,
):
ivf_descs = []
for nlist, quantizer_descs in quantizers.items():
# build IVF index
for quantizer_desc in quantizer_descs:
for pretransform, fine_ivf in codecs:
if pretransform is None:
pretransform = ""
else:
pretransform = pretransform + ","
if quantizer_desc.construction_params is None:
construction_params = [
None,
quantizer_desc.search_params,
]
else:
construction_params = [
None
] + quantizer_desc.construction_params
if quantizer_desc.search_params is not None:
dict_merge(
construction_params[1],
quantizer_desc.search_params,
)
ivf_descs.append(
IndexDescriptor(
factory=f"{pretransform}IVF{nlist}({quantizer_desc.factory}),{fine_ivf}",
construction_params=construction_params,
)
)
return self.benchmark_and_filter_candidates(
ivf_descs,
training_vectors,
database_vectors,
query_vectors,
result_file,
include_flat=False,
min_accuracy=min_accuracy,
pareto_metric=ParetoMetric.TIME_SPACE,
)
# train an IVFFlat index
# find the nprobe required for the given accuracy
def ivf_flat_nprobe_required_for_accuracy(
self,
result_file: str,
training_vectors: DatasetDescriptor,
database_vectors: DatasetDescriptor,
query_vectors: DatasetDescriptor,
nlist,
accuracy,
):
_, results = self.benchmark_and_filter_candidates(
index_descs=[
IndexDescriptor(factory=f"IVF{nlist}(Flat),Flat"),
],
training_vectors=training_vectors,
database_vectors=database_vectors,
query_vectors=query_vectors,
result_file=result_file,
include_flat=False,
min_accuracy=accuracy,
pareto_metric=ParetoMetric.TIME,
)
nprobe = nlist // 2
for _, _, _, k, v in results:
if (
".knn" in k
and "nprobe" in v["search_params"]
and v["knn_intersection"] >= accuracy
):
nprobe = min(nprobe, v["search_params"]["nprobe"])
return nprobe
# train candidate IVF codecs
# benchmark them at the same nprobe
# keep only the space _and_ time Pareto optimal
def optimize_codec(
self,
result_file: str,
d: int,
training_vectors: DatasetDescriptor,
database_vectors: DatasetDescriptor,
query_vectors: DatasetDescriptor,
nlist: int,
nprobe: int,
min_accuracy: float,
):
codecs = (
[
(None, "Flat"),
(None, "SQfp16"),
(None, "SQ8"),
] + [
(f"OPQ{M}_{M * dim}", f"PQ{M}x{b}")
for M in [8, 12, 16, 32, 48, 64, 96, 128, 192, 256]
if d % M == 0
for dim in range(2, 18, 2)
if M * dim <= d
for b in range(4, 14, 2)
if M * b < d * 8 # smaller than SQ8
] + [
(None, f"PQ{M}x{b}")
for M in [8, 12, 16, 32, 48, 64, 96, 128, 192, 256]
if d % M == 0
for b in range(8, 14, 2)
if M * b < d * 8 # smaller than SQ8
]
)
factory = {}
for opq, pq in codecs:
factory[
f"IVF{nlist},{pq}" if opq is None else f"{opq},IVF{nlist},{pq}"
] = (
opq,
pq,
)
_, filtered = self.benchmark_and_filter_candidates(
index_descs=[
IndexDescriptor(
factory=f"IVF{nlist},{pq}"
if opq is None
else f"{opq},IVF{nlist},{pq}",
search_params={
"nprobe": nprobe,
},
)
for opq, pq in codecs
],
training_vectors=training_vectors,
database_vectors=database_vectors,
query_vectors=query_vectors,
result_file=result_file,
include_flat=False,
min_accuracy=min_accuracy,
pareto_metric=ParetoMetric.TIME_SPACE,
)
results = [
factory[r] for r in set(v["factory"] for _, _, _, k, v in filtered)
]
return results
def optimize(
self,
d: int,
training_vectors: DatasetDescriptor,
database_vectors_list: List[DatasetDescriptor],
query_vectors: DatasetDescriptor,
min_accuracy: float,
):
# train an IVFFlat index
# find the nprobe required for near perfect accuracy
nlist = 4096
nprobe_at_95 = self.ivf_flat_nprobe_required_for_accuracy(
result_file=f"result_ivf{nlist}_flat.json",
training_vectors=training_vectors,
database_vectors=database_vectors_list[0],
query_vectors=query_vectors,
nlist=nlist,
accuracy=0.95,
)
# train candidate IVF codecs
# benchmark them at the same nprobe
# keep only the space and time Pareto optima
codecs = self.optimize_codec(
result_file=f"result_ivf{nlist}_codec.json",
d=d,
training_vectors=training_vectors,
database_vectors=database_vectors_list[0],
query_vectors=query_vectors,
nlist=nlist,
nprobe=nprobe_at_95,
min_accuracy=min_accuracy,
)
# optimize coarse quantizers
quantizers = self.optimize_quantizer(
training_vectors=training_vectors,
query_vectors=query_vectors,
nlists=[4096, 8192, 16384, 32768],
min_accuracy=0.7,
)
# combine them with the codecs
# test them at different scales
for database_vectors in database_vectors_list:
self.optimize_ivf(
result_file=f"result_{database_vectors.get_filename()}json",
training_vectors=training_vectors,
database_vectors=database_vectors,
query_vectors=query_vectors,
quantizers=quantizers,
codecs=codecs,
min_accuracy=min_accuracy,
)

View File

@ -3,15 +3,22 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from time import perf_counter
import logging
from multiprocessing.pool import ThreadPool
import numpy as np
import faiss # @manual=//faiss/python:pyfaiss_gpu
import functools
import logging
from enum import Enum
from multiprocessing.pool import ThreadPool
from time import perf_counter
import faiss # @manual=//faiss/python:pyfaiss_gpu
import numpy as np
from faiss.contrib.evaluation import ( # @manual=//faiss/contrib:faiss_contrib_gpu
OperatingPoints,
)
logger = logging.getLogger(__name__)
def timer(name, func, once=False) -> float:
logger.info(f"Measuring {name}")
t1 = perf_counter()
@ -34,28 +41,41 @@ def timer(name, func, once=False) -> float:
def refine_distances_knn(
xq: np.ndarray, xb: np.ndarray, I: np.ndarray, metric,
xq: np.ndarray,
xb: np.ndarray,
I: np.ndarray,
metric,
):
""" Recompute distances between xq[i] and xb[I[i, :]] """
"""Recompute distances between xq[i] and xb[I[i, :]]"""
nq, k = I.shape
xq = np.ascontiguousarray(xq, dtype='float32')
xq = np.ascontiguousarray(xq, dtype="float32")
nq2, d = xq.shape
xb = np.ascontiguousarray(xb, dtype='float32')
xb = np.ascontiguousarray(xb, dtype="float32")
nb, d2 = xb.shape
I = np.ascontiguousarray(I, dtype='int64')
I = np.ascontiguousarray(I, dtype="int64")
assert nq2 == nq
assert d2 == d
D = np.empty(I.shape, dtype='float32')
D = np.empty(I.shape, dtype="float32")
D[:] = np.inf
if metric == faiss.METRIC_L2:
faiss.fvec_L2sqr_by_idx(
faiss.swig_ptr(D), faiss.swig_ptr(xq), faiss.swig_ptr(xb),
faiss.swig_ptr(I), d, nq, k
faiss.swig_ptr(D),
faiss.swig_ptr(xq),
faiss.swig_ptr(xb),
faiss.swig_ptr(I),
d,
nq,
k,
)
else:
faiss.fvec_inner_products_by_idx(
faiss.swig_ptr(D), faiss.swig_ptr(xq), faiss.swig_ptr(xb),
faiss.swig_ptr(I), d, nq, k
faiss.swig_ptr(D),
faiss.swig_ptr(xq),
faiss.swig_ptr(xb),
faiss.swig_ptr(I),
d,
nq,
k,
)
return D
@ -97,7 +117,10 @@ def distance_ratio_measure(I, R, D_GT, metric):
@functools.cache
def get_cpu_info():
return [l for l in open("/proc/cpuinfo", "r") if "model name" in l][0][13:].strip()
return [l for l in open("/proc/cpuinfo", "r") if "model name" in l][0][
13:
].strip()
def dict_merge(target, source):
for k, v in source.items():
@ -105,3 +128,121 @@ def dict_merge(target, source):
dict_merge(target[k], v)
else:
target[k] = v
class Cost:
def __init__(self, values):
self.values = values
def __le__(self, other):
return all(
v1 <= v2 for v1, v2 in zip(self.values, other.values, strict=True)
)
def __lt__(self, other):
return all(
v1 < v2 for v1, v2 in zip(self.values, other.values, strict=True)
)
class ParetoMode(Enum):
DISABLE = 1 # no Pareto filtering
INDEX = 2 # index-local optima
GLOBAL = 3 # global optima
class ParetoMetric(Enum):
TIME = 0 # time vs accuracy
SPACE = 1 # space vs accuracy
TIME_SPACE = 2 # (time, space) vs accuracy
def range_search_recall_at_precision(experiment, precision):
return round(
max(
r
for r, p in zip(
experiment["range_search_pr"]["recall"],
experiment["range_search_pr"]["precision"],
)
if p > precision
),
6,
)
def filter_results(
results,
evaluation,
accuracy_metric, # str or func
time_metric=None, # func or None -> use default
space_metric=None, # func or None -> use default
min_accuracy=0,
max_space=0,
max_time=0,
scaling_factor=1.0,
name_filter=None, # func
pareto_mode=ParetoMode.DISABLE,
pareto_metric=ParetoMetric.TIME,
):
if isinstance(accuracy_metric, str):
accuracy_key = accuracy_metric
accuracy_metric = lambda v: v[accuracy_key]
if time_metric is None:
time_metric = lambda v: v["time"] * scaling_factor + (
v["quantizer"]["time"] if "quantizer" in v else 0
)
if space_metric is None:
space_metric = lambda v: results["indices"][v["codec"]]["code_size"]
fe = []
ops = {}
if pareto_mode == ParetoMode.GLOBAL:
op = OperatingPoints()
ops["global"] = op
for k, v in results["experiments"].items():
if f".{evaluation}" in k:
accuracy = accuracy_metric(v)
if min_accuracy > 0 and accuracy < min_accuracy:
continue
space = space_metric(v)
if space is None:
space = 0
if max_space > 0 and space > max_space:
continue
time = time_metric(v)
if max_time > 0 and time > max_time:
continue
idx_name = v["index"] + (
"snap"
if "search_params" in v and v["search_params"]["snap"] == 1
else ""
)
if name_filter is not None and not name_filter(idx_name):
continue
experiment = (accuracy, space, time, k, v)
if pareto_mode == ParetoMode.DISABLE:
fe.append(experiment)
continue
if pareto_mode == ParetoMode.INDEX:
if idx_name not in ops:
ops[idx_name] = OperatingPoints()
op = ops[idx_name]
if pareto_metric == ParetoMetric.TIME:
op.add_operating_point(experiment, accuracy, time)
elif pareto_metric == ParetoMetric.SPACE:
op.add_operating_point(experiment, accuracy, space)
else:
op.add_operating_point(
experiment, accuracy, Cost([time, space])
)
if ops:
for op in ops.values():
for v, _, _ in op.operating_points:
fe.append(v)
fe.sort()
return fe

File diff suppressed because it is too large

View File

@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import logging
import os
from bench_fw.benchmark_io import BenchmarkIO
from bench_fw.descriptors import DatasetDescriptor
from bench_fw.optimize import Optimizer
logging.basicConfig(level=logging.INFO)
def bigann(bio):
optimizer = Optimizer(
distance_metric="L2",
num_threads=32,
run_local=False,
)
optimizer.set_io(bio)
query_vectors = DatasetDescriptor(namespace="std_q", tablename="bigann1M")
xt = bio.get_dataset(query_vectors)
optimizer.optimize(
d=xt.shape[1],
training_vectors=DatasetDescriptor(
namespace="std_t",
tablename="bigann1M",
num_vectors=2_000_000,
),
database_vectors_list=[
DatasetDescriptor(
namespace="std_d",
tablename="bigann1M",
),
DatasetDescriptor(namespace="std_d", tablename="bigann10M"),
],
query_vectors=query_vectors,
min_accuracy=0.85,
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("experiment")
parser.add_argument("path")
args = parser.parse_args()
assert os.path.exists(args.path)
path = os.path.join(args.path, args.experiment)
if not os.path.exists(path):
os.mkdir(path)
bio = BenchmarkIO(
path=path,
)
if args.experiment == "bigann":
bigann(bio)