index optimizer (#3154)
Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/3154

Using the benchmark to find Pareto-optimal indices, in this case on BigANN as an example. The coarse quantizer and the vector codec are optimized separately, and the Pareto-optimal configurations are then used to construct IVF indices, which are retested at various scales. See `optimize()` in `optimize.py` as the main function driving the process.

The results can be interpreted with `bench_fw_notebook.ipynb`, which allows:
* filtering by maximum code size
* filtering by maximum search time
* filtering by minimum accuracy
* selecting the space- or time-Pareto-optimal options
* visualizing the results and outputting them as a table

This version is intentionally limited to IVF(Flat|HNSW),PQ|SQ indices...

Reviewed By: mdouze

Differential Revision: D51781670

fbshipit-source-id: 2c0f800d374ea845255934f519cc28095c00a51f
parent 75ae0bfb7f
commit 1d0e8d489f
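The end-to-end flow is easiest to see from the small runner script added at the bottom of this diff. Below is a minimal sketch of the same call pattern; the cache path is a placeholder, and `run_local=True` is used here so everything runs in-process instead of being submitted as cluster jobs:

```python
# Sketch of driving the optimizer on BigANN, mirroring the runner script in this PR.
# The path below is a placeholder; the std_* dataset descriptors resolve through
# faiss.contrib.datasets.dataset_from_name via BenchmarkIO.get_dataset.
from bench_fw.benchmark_io import BenchmarkIO
from bench_fw.descriptors import DatasetDescriptor
from bench_fw.optimize import Optimizer

bio = BenchmarkIO(path="/tmp/bench_fw")  # caches codecs, indices and result JSON files
optimizer = Optimizer(distance_metric="L2", num_threads=32, run_local=True)
optimizer.set_io(bio)

query_vectors = DatasetDescriptor(namespace="std_q", tablename="bigann1M")
xt = bio.get_dataset(query_vectors)  # only read here to get the dimensionality

optimizer.optimize(
    d=xt.shape[1],
    training_vectors=DatasetDescriptor(
        namespace="std_t", tablename="bigann1M", num_vectors=2_000_000
    ),
    database_vectors_list=[
        DatasetDescriptor(namespace="std_d", tablename="bigann1M"),
        DatasetDescriptor(namespace="std_d", tablename="bigann10M"),
    ],
    query_vectors=query_vectors,
    min_accuracy=0.85,
)
```

`optimize()` first trains an IVF-Flat index to find the `nprobe` needed for roughly 95% knn intersection, sweeps SQ/PQ/OPQ codecs at that `nprobe` and keeps the time-and-space Pareto optima, separately optimizes Flat and HNSW coarse quantizers over several `nlist` values, and finally combines quantizers and codecs into IVF indices that are re-benchmarked at each database scale.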
@@ -7,19 +7,20 @@ import logging
from copy import copy
from dataclasses import dataclass
from operator import itemgetter
from statistics import median, mean
from statistics import mean, median
from typing import Any, Dict, List, Optional

from .utils import dict_merge
from .index import Index, IndexFromCodec, IndexFromFactory
from .descriptors import DatasetDescriptor, IndexDescriptor

import faiss  # @manual=//faiss/python:pyfaiss_gpu

import numpy as np

from scipy.optimize import curve_fit

from .descriptors import DatasetDescriptor, IndexDescriptor
from .index import Index, IndexFromCodec, IndexFromFactory

from .utils import dict_merge

logger = logging.getLogger(__name__)
@ -274,8 +275,8 @@ class Benchmark:
|
|||
search_parameters: Optional[Dict[str, int]],
|
||||
radius: Optional[float] = None,
|
||||
gt_radius: Optional[float] = None,
|
||||
range_search_metric_function = None,
|
||||
gt_rsm = None,
|
||||
range_search_metric_function=None,
|
||||
gt_rsm=None,
|
||||
):
|
||||
logger.info("range_search: begin")
|
||||
if radius is None:
|
||||
|
@ -328,7 +329,13 @@ class Benchmark:
|
|||
logger.info("knn_ground_truth: begin")
|
||||
flat_desc = self.get_index_desc("Flat")
|
||||
self.build_index_wrapper(flat_desc)
|
||||
self.gt_knn_D, self.gt_knn_I, _, _, requires = flat_desc.index.knn_search(
|
||||
(
|
||||
self.gt_knn_D,
|
||||
self.gt_knn_I,
|
||||
_,
|
||||
_,
|
||||
requires,
|
||||
) = flat_desc.index.knn_search(
|
||||
dry_run=False,
|
||||
search_parameters=None,
|
||||
query_vectors=self.query_vectors,
|
||||
|
@ -338,13 +345,13 @@ class Benchmark:
|
|||
logger.info("knn_ground_truth: end")
|
||||
|
||||
def search_benchmark(
|
||||
self,
|
||||
self,
|
||||
name,
|
||||
search_func,
|
||||
key_func,
|
||||
cost_metrics,
|
||||
perf_metrics,
|
||||
results: Dict[str, Any],
|
||||
results: Dict[str, Any],
|
||||
index: Index,
|
||||
):
|
||||
index_name = index.get_index_name()
|
||||
|
@ -376,11 +383,18 @@ class Benchmark:
|
|||
logger.info(f"{name}_benchmark: end")
|
||||
return results, requires
|
||||
|
||||
def knn_search_benchmark(self, dry_run, results: Dict[str, Any], index: Index):
|
||||
def knn_search_benchmark(
|
||||
self, dry_run, results: Dict[str, Any], index: Index
|
||||
):
|
||||
return self.search_benchmark(
|
||||
name="knn_search",
|
||||
search_func=lambda parameters: index.knn_search(
|
||||
dry_run, parameters, self.query_vectors, self.k, self.gt_knn_I, self.gt_knn_D,
|
||||
dry_run,
|
||||
parameters,
|
||||
self.query_vectors,
|
||||
self.k,
|
||||
self.gt_knn_I,
|
||||
self.gt_knn_D,
|
||||
)[3:],
|
||||
key_func=lambda parameters: index.get_knn_search_name(
|
||||
search_parameters=parameters,
|
||||
|
@ -394,11 +408,17 @@ class Benchmark:
|
|||
index=index,
|
||||
)
|
||||
|
||||
def reconstruct_benchmark(self, dry_run, results: Dict[str, Any], index: Index):
|
||||
def reconstruct_benchmark(
|
||||
self, dry_run, results: Dict[str, Any], index: Index
|
||||
):
|
||||
return self.search_benchmark(
|
||||
name="reconstruct",
|
||||
search_func=lambda parameters: index.reconstruct(
|
||||
dry_run, parameters, self.query_vectors, self.k, self.gt_knn_I,
|
||||
dry_run,
|
||||
parameters,
|
||||
self.query_vectors,
|
||||
self.k,
|
||||
self.gt_knn_I,
|
||||
),
|
||||
key_func=lambda parameters: index.get_knn_search_name(
|
||||
search_parameters=parameters,
|
||||
|
@ -426,19 +446,20 @@ class Benchmark:
|
|||
return self.search_benchmark(
|
||||
name="range_search",
|
||||
search_func=lambda parameters: self.range_search(
|
||||
dry_run=dry_run,
|
||||
index=index,
|
||||
search_parameters=parameters,
|
||||
dry_run=dry_run,
|
||||
index=index,
|
||||
search_parameters=parameters,
|
||||
radius=radius,
|
||||
gt_radius=gt_radius,
|
||||
range_search_metric_function=range_search_metric_function,
|
||||
range_search_metric_function=range_search_metric_function,
|
||||
gt_rsm=gt_rsm,
|
||||
)[4:],
|
||||
key_func=lambda parameters: index.get_range_search_name(
|
||||
search_parameters=parameters,
|
||||
query_vectors=self.query_vectors,
|
||||
radius=radius,
|
||||
) + metric_key,
|
||||
)
|
||||
+ metric_key,
|
||||
cost_metrics=["time"],
|
||||
perf_metrics=["range_score_max_recall"],
|
||||
results=results,
|
||||
|
@ -446,11 +467,12 @@ class Benchmark:
|
|||
)
|
||||
|
||||
def build_index_wrapper(self, index_desc: IndexDescriptor):
|
||||
if hasattr(index_desc, 'index'):
|
||||
if hasattr(index_desc, "index"):
|
||||
return
|
||||
if index_desc.factory is not None:
|
||||
training_vectors = copy(self.training_vectors)
|
||||
training_vectors.num_vectors = index_desc.training_size
|
||||
if index_desc.training_size is not None:
|
||||
training_vectors.num_vectors = index_desc.training_size
|
||||
index = IndexFromFactory(
|
||||
num_threads=self.num_threads,
|
||||
d=self.d,
|
||||
|
@ -481,15 +503,24 @@ class Benchmark:
|
|||
training_vectors=self.training_vectors,
|
||||
database_vectors=self.database_vectors,
|
||||
query_vectors=self.query_vectors,
|
||||
index_descs = [self.get_index_desc("Flat"), index_desc],
|
||||
index_descs=[self.get_index_desc("Flat"), index_desc],
|
||||
range_ref_index_desc=self.range_ref_index_desc,
|
||||
k=self.k,
|
||||
distance_metric=self.distance_metric,
|
||||
)
|
||||
benchmark.set_io(self.io)
|
||||
benchmark.set_io(self.io.clone())
|
||||
return benchmark
|
||||
|
||||
def benchmark_one(self, dry_run, results: Dict[str, Any], index_desc: IndexDescriptor, train, reconstruct, knn, range):
|
||||
def benchmark_one(
|
||||
self,
|
||||
dry_run,
|
||||
results: Dict[str, Any],
|
||||
index_desc: IndexDescriptor,
|
||||
train,
|
||||
reconstruct,
|
||||
knn,
|
||||
range,
|
||||
):
|
||||
faiss.omp_set_num_threads(self.num_threads)
|
||||
if not dry_run:
|
||||
self.knn_ground_truth()
|
||||
|
@ -531,9 +562,12 @@ class Benchmark:
|
|||
)
|
||||
assert requires is None
|
||||
|
||||
if self.range_ref_index_desc is None or not index_desc.index.supports_range_search():
|
||||
if (
|
||||
self.range_ref_index_desc is None
|
||||
or not index_desc.index.supports_range_search()
|
||||
):
|
||||
return results, None
|
||||
|
||||
|
||||
ref_index_desc = self.get_index_desc(self.range_ref_index_desc)
|
||||
if ref_index_desc is None:
|
||||
raise ValueError(
|
||||
|
@ -550,7 +584,9 @@ class Benchmark:
|
|||
coefficients,
|
||||
coefficients_training_data,
|
||||
) = self.range_search_reference(
|
||||
ref_index_desc.index, ref_index_desc.search_params, range_metric
|
||||
ref_index_desc.index,
|
||||
ref_index_desc.search_params,
|
||||
range_metric,
|
||||
)
|
||||
gt_rsm = self.range_ground_truth(
|
||||
gt_radius, range_search_metric_function
|
||||
|
@ -583,7 +619,15 @@ class Benchmark:
|
|||
|
||||
return results, None
|
||||
|
||||
def benchmark(self, result_file=None, local=False, train=False, reconstruct=False, knn=False, range=False):
|
||||
def benchmark(
|
||||
self,
|
||||
result_file=None,
|
||||
local=False,
|
||||
train=False,
|
||||
reconstruct=False,
|
||||
knn=False,
|
||||
range=False,
|
||||
):
|
||||
logger.info("begin evaluate")
|
||||
|
||||
faiss.omp_set_num_threads(self.num_threads)
|
||||
|
@ -656,20 +700,34 @@ class Benchmark:
|
|||
|
||||
if current_todo:
|
||||
results_one = {"indices": {}, "experiments": {}}
|
||||
params = [(self.clone_one(index_desc), results_one, index_desc, train, reconstruct, knn, range) for index_desc in current_todo]
|
||||
for result in self.io.launch_jobs(run_benchmark_one, params, local=local):
|
||||
params = [
|
||||
(
|
||||
index_desc,
|
||||
self.clone_one(index_desc),
|
||||
results_one,
|
||||
train,
|
||||
reconstruct,
|
||||
knn,
|
||||
range,
|
||||
)
|
||||
for index_desc in current_todo
|
||||
]
|
||||
for result in self.io.launch_jobs(
|
||||
run_benchmark_one, params, local=local
|
||||
):
|
||||
dict_merge(results, result)
|
||||
|
||||
todo = next_todo
|
||||
todo = next_todo
|
||||
|
||||
if result_file is not None:
|
||||
self.io.write_json(results, result_file, overwrite=True)
|
||||
logger.info("end evaluate")
|
||||
return results
|
||||
|
||||
|
||||
def run_benchmark_one(params):
|
||||
logger.info(params)
|
||||
benchmark, results, index_desc, train, reconstruct, knn, range = params
|
||||
index_desc, benchmark, results, train, reconstruct, knn, range = params
|
||||
results, requires = benchmark.benchmark_one(
|
||||
dry_run=False,
|
||||
results=results,
|
||||
|
|
|
@@ -10,13 +10,13 @@ import logging
import os
import pickle
from dataclasses import dataclass
import submitit
from typing import Any, List, Optional
from zipfile import ZipFile

import faiss  # @manual=//faiss/python:pyfaiss_gpu

import numpy as np
import submitit
from faiss.contrib.datasets import (  # @manual=//faiss/contrib:faiss_contrib_gpu
    dataset_from_name,
)
@@ -47,6 +47,9 @@ def merge_rcq_itq(
class BenchmarkIO:
    path: str

    def clone(self):
        return BenchmarkIO(path=self.path)

    def __post_init__(self):
        self.cached_ds = {}
@ -119,18 +122,27 @@ class BenchmarkIO:
|
|||
|
||||
def get_dataset(self, dataset):
|
||||
if dataset not in self.cached_ds:
|
||||
if dataset.namespace is not None and dataset.namespace[:4] == "std_":
|
||||
if (
|
||||
dataset.namespace is not None
|
||||
and dataset.namespace[:4] == "std_"
|
||||
):
|
||||
if dataset.tablename not in self.cached_ds:
|
||||
self.cached_ds[dataset.tablename] = dataset_from_name(
|
||||
dataset.tablename,
|
||||
)
|
||||
p = dataset.namespace[4]
|
||||
if p == "t":
|
||||
self.cached_ds[dataset] = self.cached_ds[dataset.tablename].get_train(dataset.num_vectors)
|
||||
self.cached_ds[dataset] = self.cached_ds[
|
||||
dataset.tablename
|
||||
].get_train(dataset.num_vectors)
|
||||
elif p == "d":
|
||||
self.cached_ds[dataset] = self.cached_ds[dataset.tablename].get_database()
|
||||
self.cached_ds[dataset] = self.cached_ds[
|
||||
dataset.tablename
|
||||
].get_database()
|
||||
elif p == "q":
|
||||
self.cached_ds[dataset] = self.cached_ds[dataset.tablename].get_queries()
|
||||
self.cached_ds[dataset] = self.cached_ds[
|
||||
dataset.tablename
|
||||
].get_queries()
|
||||
else:
|
||||
raise ValueError
|
||||
elif dataset.namespace == "syn":
|
||||
|
@ -233,8 +245,8 @@ class BenchmarkIO:
|
|||
if local:
|
||||
results = [func(p) for p in params]
|
||||
return results
|
||||
print(f'launching {len(params)} jobs')
|
||||
executor = submitit.AutoExecutor(folder='/checkpoint/gsz/jobs')
|
||||
logger.info(f"launching {len(params)} jobs")
|
||||
executor = submitit.AutoExecutor(folder="/checkpoint/gsz/jobs")
|
||||
executor.update_parameters(
|
||||
nodes=1,
|
||||
gpus_per_node=8,
|
||||
|
@ -248,9 +260,9 @@ class BenchmarkIO:
|
|||
slurm_constraint="bldg1",
|
||||
)
|
||||
jobs = executor.map_array(func, params)
|
||||
print(f'launched {len(jobs)} jobs')
|
||||
# for job, param in zip(jobs, params):
|
||||
# print(f"{job.job_id=} {param=}")
|
||||
logger.info(f"launched {len(jobs)} jobs")
|
||||
for job, param in zip(jobs, params):
|
||||
logger.info(f"{job.job_id=} {param[0]=}")
|
||||
results = [job.result() for job in jobs]
|
||||
print(f'received {len(results)} results')
|
||||
print(f"received {len(results)} results")
|
||||
return results
|
||||
|
|
|
@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional
|
|||
|
||||
import faiss # @manual=//faiss/python:pyfaiss_gpu
|
||||
from .utils import timer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -101,7 +102,9 @@ class DatasetDescriptor:
|
|||
tablename=f"{self.get_filename()}kmeans_{k}.npy"
|
||||
)
|
||||
meta_filename = kmeans_vectors.tablename + ".json"
|
||||
if not io.file_exist(kmeans_vectors.tablename) or not io.file_exist(meta_filename):
|
||||
if not io.file_exist(kmeans_vectors.tablename) or not io.file_exist(
|
||||
meta_filename
|
||||
):
|
||||
if dry_run:
|
||||
return None, None, kmeans_vectors.tablename
|
||||
x = io.get_dataset(self)
|
||||
|
|
|
@@ -4,19 +4,19 @@
# LICENSE file in the root directory of this source tree.


from copy import copy
import logging
import os
from collections import OrderedDict
from copy import copy
from dataclasses import dataclass
from typing import ClassVar, Dict, List, Optional

import faiss  # @manual=//faiss/python:pyfaiss_gpu

import numpy as np

from faiss.contrib.evaluation import (  # @manual=//faiss/contrib:faiss_contrib_gpu
    OperatingPointsWithRanges,
    knn_intersection_measure,
    OperatingPointsWithRanges,
)
from faiss.contrib.factory_tools import (  # @manual=//faiss/contrib:faiss_contrib_gpu
    reverse_index_factory,
@ -27,7 +27,13 @@ from faiss.contrib.ivf_tools import ( # @manual=//faiss/contrib:faiss_contrib_g
|
|||
)
|
||||
|
||||
from .descriptors import DatasetDescriptor
|
||||
from .utils import distance_ratio_measure, get_cpu_info, timer, refine_distances_knn, refine_distances_range
|
||||
from .utils import (
|
||||
distance_ratio_measure,
|
||||
get_cpu_info,
|
||||
refine_distances_knn,
|
||||
refine_distances_range,
|
||||
timer,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -106,7 +112,9 @@ class IndexBase:
|
|||
icm_encoder_factory = faiss.GpuIcmEncoderFactory(ngpus)
|
||||
if isinstance(index, faiss.IndexProductLocalSearchQuantizer):
|
||||
for i in range(index.plsq.nsplits):
|
||||
lsq = faiss.downcast_Quantizer(index.plsq.subquantizer(i))
|
||||
lsq = faiss.downcast_Quantizer(
|
||||
index.plsq.subquantizer(i)
|
||||
)
|
||||
if lsq.icm_encoder_factory is None:
|
||||
lsq.icm_encoder_factory = icm_encoder_factory
|
||||
else:
|
||||
|
@ -119,29 +127,39 @@ class IndexBase:
|
|||
obj = faiss.extract_index_ivf(index)
|
||||
elif name in ["use_beam_LUT", "max_beam_size"]:
|
||||
if isinstance(index, faiss.IndexProductResidualQuantizer):
|
||||
obj = [faiss.downcast_Quantizer(index.prq.subquantizer(i)) for i in range(index.prq.nsplits)]
|
||||
obj = [
|
||||
faiss.downcast_Quantizer(index.prq.subquantizer(i))
|
||||
for i in range(index.prq.nsplits)
|
||||
]
|
||||
else:
|
||||
obj = index.rq
|
||||
elif name == "encode_ils_iters":
|
||||
if isinstance(index, faiss.IndexProductLocalSearchQuantizer):
|
||||
obj = [faiss.downcast_Quantizer(index.plsq.subquantizer(i)) for i in range(index.plsq.nsplits)]
|
||||
obj = [
|
||||
faiss.downcast_Quantizer(index.plsq.subquantizer(i))
|
||||
for i in range(index.plsq.nsplits)
|
||||
]
|
||||
else:
|
||||
obj = index.lsq
|
||||
else:
|
||||
obj = index
|
||||
|
||||
|
||||
if not isinstance(obj, list):
|
||||
obj = [obj]
|
||||
for o in obj:
|
||||
test = getattr(o, name)
|
||||
if assert_same and not name == 'use_beam_LUT':
|
||||
if assert_same and not name == "use_beam_LUT":
|
||||
assert test == val
|
||||
else:
|
||||
setattr(o, name, val)
|
||||
|
||||
@staticmethod
|
||||
def filter_index_param_dict_list(param_dict_list):
|
||||
if param_dict_list is not None and param_dict_list[0] is not None and "k_factor" in param_dict_list[0]:
|
||||
if (
|
||||
param_dict_list is not None
|
||||
and param_dict_list[0] is not None
|
||||
and "k_factor" in param_dict_list[0]
|
||||
):
|
||||
filtered = copy(param_dict_list)
|
||||
del filtered[0]["k_factor"]
|
||||
return filtered
|
||||
|
@ -153,6 +171,7 @@ class IndexBase:
|
|||
return isinstance(model, faiss.IndexFlat)
|
||||
|
||||
def is_ivf(self):
|
||||
return False
|
||||
model = self.get_model()
|
||||
return faiss.try_extract_index_ivf(model) is not None
|
||||
|
||||
|
@ -243,7 +262,9 @@ class IndexBase:
|
|||
pretransform = None
|
||||
quantizer_query_vectors = query_vectors
|
||||
|
||||
quantizer, _, _ = self.get_quantizer(dry_run=False, pretransform=pretransform)
|
||||
quantizer, _, _ = self.get_quantizer(
|
||||
dry_run=False, pretransform=pretransform
|
||||
)
|
||||
QD, QI, _, QP, _ = quantizer.knn_search(
|
||||
dry_run=False,
|
||||
search_parameters=None,
|
||||
|
@ -300,7 +321,9 @@ class IndexBase:
|
|||
# Index2Layer doesn't support search
|
||||
xq = self.io.get_dataset(query_vectors)
|
||||
xb = index.reconstruct_n(0, index.ntotal)
|
||||
(D, I), t, _ = timer("knn_search 2layer", lambda: faiss.knn(xq, xb, k))
|
||||
(D, I), t, _ = timer(
|
||||
"knn_search 2layer", lambda: faiss.knn(xq, xb, k)
|
||||
)
|
||||
elif self.is_ivf() and not isinstance(index, faiss.IndexRefine):
|
||||
index_ivf = faiss.extract_index_ivf(index)
|
||||
nprobe = (
|
||||
|
@ -310,7 +333,7 @@ class IndexBase:
|
|||
else index_ivf.nprobe
|
||||
)
|
||||
xqt, QD, QI, QP = self.knn_search_quantizer(
|
||||
query_vectors=query_vectors,
|
||||
query_vectors=query_vectors,
|
||||
k=nprobe,
|
||||
)
|
||||
if index_ivf.parallel_mode != 2:
|
||||
|
@ -358,11 +381,19 @@ class IndexBase:
|
|||
"construction_params": self.get_construction_params(),
|
||||
"search_params": search_parameters,
|
||||
"knn_intersection": knn_intersection_measure(
|
||||
I, I_gt,
|
||||
) if I_gt is not None else None,
|
||||
I,
|
||||
I_gt,
|
||||
)
|
||||
if I_gt is not None
|
||||
else None,
|
||||
"distance_ratio": distance_ratio_measure(
|
||||
I, R, D_gt, self.metric_type,
|
||||
) if D_gt is not None else None,
|
||||
I,
|
||||
R,
|
||||
D_gt,
|
||||
self.metric_type,
|
||||
)
|
||||
if D_gt is not None
|
||||
else None,
|
||||
}
|
||||
logger.info("knn_search: end")
|
||||
return D, I, R, P, None
|
||||
|
@ -377,12 +408,14 @@ class IndexBase:
|
|||
):
|
||||
logger.info("reconstruct: begin")
|
||||
filename = (
|
||||
self.get_knn_search_name(parameters, query_vectors, k, reconstruct=True)
|
||||
self.get_knn_search_name(
|
||||
parameters, query_vectors, k, reconstruct=True
|
||||
)
|
||||
+ "zip"
|
||||
)
|
||||
if self.io.file_exist(filename):
|
||||
logger.info(f"Using cached results for {filename}")
|
||||
P, = self.io.read_file(filename, ["P"])
|
||||
(P,) = self.io.read_file(filename, ["P"])
|
||||
P["index"] = self.get_index_name()
|
||||
P["codec"] = self.get_codec_name()
|
||||
P["factory"] = self.get_model_name()
|
||||
|
@ -395,15 +428,21 @@ class IndexBase:
|
|||
codec_meta = self.fetch_meta()
|
||||
Index.set_index_param_dict(codec, parameters)
|
||||
xb = self.io.get_dataset(self.database_vectors)
|
||||
xb_encoded, encode_t, _ = timer("sa_encode", lambda: codec.sa_encode(xb))
|
||||
xb_encoded, encode_t, _ = timer(
|
||||
"sa_encode", lambda: codec.sa_encode(xb)
|
||||
)
|
||||
xq = self.io.get_dataset(query_vectors)
|
||||
if self.is_decode_supported():
|
||||
xb_decoded, decode_t, _ = timer("sa_decode", lambda: codec.sa_decode(xb_encoded))
|
||||
xb_decoded, decode_t, _ = timer(
|
||||
"sa_decode", lambda: codec.sa_decode(xb_encoded)
|
||||
)
|
||||
mse = np.square(xb_decoded - xb).sum(axis=1).mean().item()
|
||||
_, I = faiss.knn(xq, xb_decoded, k, metric=self.metric_type)
|
||||
asym_recall = knn_intersection_measure(I, I_gt)
|
||||
xq_decoded = codec.sa_decode(codec.sa_encode(xq))
|
||||
_, I = faiss.knn(xq_decoded, xb_decoded, k, metric=self.metric_type)
|
||||
_, I = faiss.knn(
|
||||
xq_decoded, xb_decoded, k, metric=self.metric_type
|
||||
)
|
||||
else:
|
||||
mse = None
|
||||
asym_recall = None
|
||||
|
@ -604,7 +643,7 @@ class Index(IndexBase):
|
|||
|
||||
if self.is_ivf() and not isinstance(index, faiss.IndexRefine):
|
||||
xbt, QD, QI, QP = self.knn_search_quantizer(
|
||||
query_vectors=self.database_vectors,
|
||||
query_vectors=self.database_vectors,
|
||||
k=1,
|
||||
)
|
||||
index_ivf = faiss.extract_index_ivf(index)
|
||||
|
@ -638,22 +677,21 @@ class Index(IndexBase):
|
|||
def get_construction_params(self):
|
||||
return self.construction_params
|
||||
|
||||
# def get_code_size(self):
|
||||
# def get_index_code_size(index):
|
||||
# index = faiss.downcast_index(index)
|
||||
# if isinstance(index, faiss.IndexPreTransform):
|
||||
# return get_index_code_size(index.index)
|
||||
# elif isinstance(index, faiss.IndexHNSWFlat):
|
||||
# return index.d * 4 # TODO
|
||||
# elif type(index) in [faiss.IndexRefine, faiss.IndexRefineFlat]:
|
||||
# return get_index_code_size(
|
||||
# index.base_index
|
||||
# ) + get_index_code_size(index.refine_index)
|
||||
# else:
|
||||
# return index.code_size
|
||||
def get_code_size(self, codec=None):
|
||||
def get_index_code_size(index):
|
||||
index = faiss.downcast_index(index)
|
||||
if isinstance(index, faiss.IndexPreTransform):
|
||||
return get_index_code_size(index.index)
|
||||
elif type(index) in [faiss.IndexRefine, faiss.IndexRefineFlat]:
|
||||
return get_index_code_size(
|
||||
index.base_index
|
||||
) + get_index_code_size(index.refine_index)
|
||||
else:
|
||||
return index.code_size if hasattr(index, "code_size") else 0
|
||||
|
||||
# codec = self.get_codec()
|
||||
# return get_index_code_size(codec)
|
||||
if codec is None:
|
||||
codec = self.get_codec()
|
||||
return get_index_code_size(codec)
|
||||
|
||||
def get_sa_code_size(self, codec=None):
|
||||
if codec is None:
|
||||
|
@ -680,32 +718,28 @@ class Index(IndexBase):
|
|||
if model_ivf is not None:
|
||||
add_range_or_val(
|
||||
"nprobe",
|
||||
# [
|
||||
[2**i for i in range(12) if 2**i <= model_ivf.nlist * 0.5],
|
||||
# [1, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28] + [
|
||||
# i
|
||||
# for i in range(32, 64, 8)
|
||||
# if i <= model_ivf.nlist * 0.1
|
||||
# ] + [
|
||||
# i
|
||||
# for i in range(64, 128, 16)
|
||||
# if i <= model_ivf.nlist * 0.1
|
||||
# ] + [
|
||||
# i
|
||||
# for i in range(128, 256, 32)
|
||||
# if i <= model_ivf.nlist * 0.1
|
||||
# ] + [
|
||||
# i
|
||||
# for i in range(256, 512, 64)
|
||||
# if i <= model_ivf.nlist * 0.1
|
||||
# ] + [
|
||||
# 2**i
|
||||
# for i in range(12)
|
||||
# if 2**i <= model_ivf.nlist * 0.5
|
||||
# for i in range(9, 12)
|
||||
# if 2**i <= model_ivf.nlist * 0.1
|
||||
# ],
|
||||
[1, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28] + [
|
||||
i
|
||||
for i in range(32, 64, 8)
|
||||
if i <= model_ivf.nlist * 0.1
|
||||
] + [
|
||||
i
|
||||
for i in range(64, 128, 16)
|
||||
if i <= model_ivf.nlist * 0.1
|
||||
] + [
|
||||
i
|
||||
for i in range(128, 256, 32)
|
||||
if i <= model_ivf.nlist * 0.1
|
||||
] + [
|
||||
i
|
||||
for i in range(256, 512, 64)
|
||||
if i <= model_ivf.nlist * 0.1
|
||||
] + [
|
||||
2**i
|
||||
for i in range(9, 12)
|
||||
if 2**i <= model_ivf.nlist * 0.1
|
||||
],
|
||||
)
|
||||
model = faiss.downcast_index(model)
|
||||
if isinstance(model, faiss.IndexRefine):
|
||||
|
@ -718,7 +752,9 @@ class Index(IndexBase):
|
|||
"efSearch",
|
||||
[2**i for i in range(3, 11)],
|
||||
)
|
||||
elif isinstance(model, faiss.IndexResidualQuantizer) or isinstance(model, faiss.IndexProductResidualQuantizer):
|
||||
elif isinstance(model, faiss.IndexResidualQuantizer) or isinstance(
|
||||
model, faiss.IndexProductResidualQuantizer
|
||||
):
|
||||
add_range_or_val(
|
||||
"max_beam_size",
|
||||
[1, 2, 4, 8, 16, 32],
|
||||
|
@ -727,7 +763,9 @@ class Index(IndexBase):
|
|||
"use_beam_LUT",
|
||||
[1],
|
||||
)
|
||||
elif isinstance(model, faiss.IndexLocalSearchQuantizer) or isinstance(model, faiss.IndexProductLocalSearchQuantizer):
|
||||
elif isinstance(model, faiss.IndexLocalSearchQuantizer) or isinstance(
|
||||
model, faiss.IndexProductLocalSearchQuantizer
|
||||
):
|
||||
add_range_or_val(
|
||||
"encode_ils_iters",
|
||||
[2, 4, 8, 16],
|
||||
|
@ -854,7 +892,9 @@ class IndexFromFactory(Index):
|
|||
def fetch_codec(self, dry_run=False):
|
||||
codec_filename = self.get_codec_name() + "codec"
|
||||
meta_filename = self.get_codec_name() + "json"
|
||||
if self.io.file_exist(codec_filename) and self.io.file_exist(meta_filename):
|
||||
if self.io.file_exist(codec_filename) and self.io.file_exist(
|
||||
meta_filename
|
||||
):
|
||||
codec = self.io.read_index(codec_filename)
|
||||
assert self.d == codec.d
|
||||
assert self.metric_type == codec.metric_type
|
||||
|
@ -874,6 +914,7 @@ class IndexFromFactory(Index):
|
|||
"training_size": self.training_vectors.num_vectors,
|
||||
"codec_size": codec_size,
|
||||
"sa_code_size": self.get_sa_code_size(codec),
|
||||
"code_size": self.get_code_size(codec),
|
||||
"cpu": get_cpu_info(),
|
||||
}
|
||||
self.io.write_json(meta, meta_filename, overwrite=True)
|
||||
|
@ -921,7 +962,9 @@ class IndexFromFactory(Index):
|
|||
training_vectors = self.training_vectors
|
||||
else:
|
||||
training_vectors = pretransform.transform(self.training_vectors)
|
||||
centroids, t, requires = training_vectors.k_means(self.io, model_ivf.nlist, dry_run)
|
||||
centroids, t, requires = training_vectors.k_means(
|
||||
self.io, model_ivf.nlist, dry_run
|
||||
)
|
||||
if requires is not None:
|
||||
return None, None, requires
|
||||
quantizer = IndexFromFactory(
|
||||
|
@ -944,11 +987,11 @@ class IndexFromFactory(Index):
|
|||
model = self.get_model()
|
||||
opaque = True
|
||||
t_aggregate = 0
|
||||
try:
|
||||
reverse_index_factory(model)
|
||||
opaque = False
|
||||
except NotImplementedError:
|
||||
opaque = True
|
||||
# try:
|
||||
# reverse_index_factory(model)
|
||||
# opaque = False
|
||||
# except NotImplementedError:
|
||||
# opaque = True
|
||||
if opaque:
|
||||
codec = model
|
||||
else:
|
||||
|
@ -958,7 +1001,9 @@ class IndexFromFactory(Index):
|
|||
if not isinstance(sub_index, faiss.IndexFlat):
|
||||
# replace the sub-index with Flat and fetch pre-trained
|
||||
pretransform = self.get_pretransform()
|
||||
codec, meta, report = pretransform.fetch_codec(dry_run=dry_run)
|
||||
codec, meta, report = pretransform.fetch_codec(
|
||||
dry_run=dry_run
|
||||
)
|
||||
if report is not None:
|
||||
return None, None, report
|
||||
t_aggregate += meta["training_time"]
|
||||
|
@ -978,7 +1023,9 @@ class IndexFromFactory(Index):
|
|||
training_vectors=transformed_training_vectors,
|
||||
)
|
||||
wrapper.set_io(self.io)
|
||||
codec.index, meta, report = wrapper.fetch_codec(dry_run=dry_run)
|
||||
codec.index, meta, report = wrapper.fetch_codec(
|
||||
dry_run=dry_run
|
||||
)
|
||||
if report is not None:
|
||||
return None, None, report
|
||||
t_aggregate += meta["training_time"]
|
||||
|
@ -1008,14 +1055,18 @@ class IndexFromFactory(Index):
|
|||
d=model.base_index.d,
|
||||
metric=model.base_index.metric_type,
|
||||
database_vectors=self.database_vectors,
|
||||
construction_params=IndexBase.filter_index_param_dict_list(self.construction_params),
|
||||
construction_params=IndexBase.filter_index_param_dict_list(
|
||||
self.construction_params
|
||||
),
|
||||
search_params=None,
|
||||
factory=reverse_index_factory(model.base_index),
|
||||
training_vectors=self.training_vectors,
|
||||
)
|
||||
wrapper.set_io(self.io)
|
||||
codec = faiss.clone_index(model)
|
||||
codec.base_index, meta, requires = wrapper.fetch_codec(dry_run=dry_run)
|
||||
codec.base_index, meta, requires = wrapper.fetch_codec(
|
||||
dry_run=dry_run
|
||||
)
|
||||
if requires is not None:
|
||||
return None, None, requires
|
||||
t_aggregate += meta["training_time"]
|
||||
|
|
|
@@ -0,0 +1,333 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass
from typing import Dict, List, Tuple

import faiss  # @manual=//faiss/python:pyfaiss_gpu

# from faiss.contrib.evaluation import (  # @manual=//faiss/contrib:faiss_contrib_gpu
#     OperatingPoints,
# )

from .benchmark import Benchmark
from .descriptors import DatasetDescriptor, IndexDescriptor
from .utils import dict_merge, filter_results, ParetoMetric, ParetoMode

logger = logging.getLogger(__name__)


@dataclass
class Optimizer:
    distance_metric: str = "L2"
    num_threads: int = 32
    run_local: bool = True

    def __post_init__(self):
        self.cached_benchmark = None
        if self.distance_metric == "IP":
            self.distance_metric_type = faiss.METRIC_INNER_PRODUCT
        elif self.distance_metric == "L2":
            self.distance_metric_type = faiss.METRIC_L2
        else:
            raise ValueError

    def set_io(self, benchmark_io):
        self.io = benchmark_io
        self.io.distance_metric = self.distance_metric
        self.io.distance_metric_type = self.distance_metric_type

def benchmark_and_filter_candidates(
|
||||
self,
|
||||
index_descs,
|
||||
training_vectors,
|
||||
database_vectors,
|
||||
query_vectors,
|
||||
result_file,
|
||||
include_flat,
|
||||
min_accuracy,
|
||||
pareto_metric,
|
||||
):
|
||||
benchmark = Benchmark(
|
||||
num_threads=self.num_threads,
|
||||
training_vectors=training_vectors,
|
||||
database_vectors=database_vectors,
|
||||
query_vectors=query_vectors,
|
||||
index_descs=index_descs,
|
||||
k=10,
|
||||
distance_metric=self.distance_metric,
|
||||
)
|
||||
benchmark.set_io(self.io)
|
||||
results = benchmark.benchmark(
|
||||
result_file=result_file, local=self.run_local, train=True, knn=True
|
||||
)
|
||||
assert results
|
||||
filtered = filter_results(
|
||||
results=results,
|
||||
evaluation="knn",
|
||||
accuracy_metric="knn_intersection",
|
||||
min_accuracy=min_accuracy,
|
||||
name_filter=None
|
||||
if include_flat
|
||||
else (lambda n: not n.startswith("Flat")),
|
||||
pareto_mode=ParetoMode.GLOBAL,
|
||||
pareto_metric=pareto_metric,
|
||||
)
|
||||
assert filtered
|
||||
index_descs = [
|
||||
IndexDescriptor(
|
||||
factory=v["factory"],
|
||||
construction_params=v["construction_params"],
|
||||
search_params=v["search_params"],
|
||||
)
|
||||
for _, _, _, _, v in filtered
|
||||
]
|
||||
return index_descs, filtered
|
||||
|
||||
def optimize_quantizer(
|
||||
self,
|
||||
training_vectors: DatasetDescriptor,
|
||||
query_vectors: DatasetDescriptor,
|
||||
nlists: List[int],
|
||||
min_accuracy: float,
|
||||
):
|
||||
quantizer_descs = {}
|
||||
for nlist in nlists:
|
||||
# cluster
|
||||
centroids, _, _ = training_vectors.k_means(
|
||||
self.io,
|
||||
nlist,
|
||||
dry_run=False,
|
||||
)
|
||||
|
||||
descs = [IndexDescriptor(factory="Flat"),] + [
|
||||
IndexDescriptor(
|
||||
factory="HNSW32",
|
||||
construction_params=[{"efConstruction": 2**i}],
|
||||
)
|
||||
for i in range(6, 11)
|
||||
]
|
||||
|
||||
descs, _ = self.benchmark_and_filter_candidates(
|
||||
descs,
|
||||
training_vectors=centroids,
|
||||
database_vectors=centroids,
|
||||
query_vectors=query_vectors,
|
||||
result_file=f"result_{centroids.get_filename()}json",
|
||||
include_flat=True,
|
||||
min_accuracy=min_accuracy,
|
||||
pareto_metric=ParetoMetric.TIME,
|
||||
)
|
||||
quantizer_descs[nlist] = descs
|
||||
|
||||
return quantizer_descs
|
||||
|
||||
def optimize_ivf(
|
||||
self,
|
||||
result_file: str,
|
||||
training_vectors: DatasetDescriptor,
|
||||
database_vectors: DatasetDescriptor,
|
||||
query_vectors: DatasetDescriptor,
|
||||
quantizers: Dict[int, List[IndexDescriptor]],
|
||||
codecs: List[Tuple[str, str]],
|
||||
min_accuracy: float,
|
||||
):
|
||||
ivf_descs = []
|
||||
for nlist, quantizer_descs in quantizers.items():
|
||||
# build IVF index
|
||||
for quantizer_desc in quantizer_descs:
|
||||
for pretransform, fine_ivf in codecs:
|
||||
if pretransform is None:
|
||||
pretransform = ""
|
||||
else:
|
||||
pretransform = pretransform + ","
|
||||
if quantizer_desc.construction_params is None:
|
||||
construction_params = [
|
||||
None,
|
||||
quantizer_desc.search_params,
|
||||
]
|
||||
else:
|
||||
construction_params = [
|
||||
None
|
||||
] + quantizer_desc.construction_params
|
||||
if quantizer_desc.search_params is not None:
|
||||
dict_merge(
|
||||
construction_params[1],
|
||||
quantizer_desc.search_params,
|
||||
)
|
||||
ivf_descs.append(
|
||||
IndexDescriptor(
|
||||
factory=f"{pretransform}IVF{nlist}({quantizer_desc.factory}),{fine_ivf}",
|
||||
construction_params=construction_params,
|
||||
)
|
||||
)
|
||||
return self.benchmark_and_filter_candidates(
|
||||
ivf_descs,
|
||||
training_vectors,
|
||||
database_vectors,
|
||||
query_vectors,
|
||||
result_file,
|
||||
include_flat=False,
|
||||
min_accuracy=min_accuracy,
|
||||
pareto_metric=ParetoMetric.TIME_SPACE,
|
||||
)
|
||||
|
||||
# train an IVFFlat index
|
||||
# find the nprobe required for the given accuracy
|
||||
def ivf_flat_nprobe_required_for_accuracy(
|
||||
self,
|
||||
result_file: str,
|
||||
training_vectors: DatasetDescriptor,
|
||||
database_vectors: DatasetDescriptor,
|
||||
query_vectors: DatasetDescriptor,
|
||||
nlist,
|
||||
accuracy,
|
||||
):
|
||||
_, results = self.benchmark_and_filter_candidates(
|
||||
index_descs=[
|
||||
IndexDescriptor(factory=f"IVF{nlist}(Flat),Flat"),
|
||||
],
|
||||
training_vectors=training_vectors,
|
||||
database_vectors=database_vectors,
|
||||
query_vectors=query_vectors,
|
||||
result_file=result_file,
|
||||
include_flat=False,
|
||||
min_accuracy=accuracy,
|
||||
pareto_metric=ParetoMetric.TIME,
|
||||
)
|
||||
nprobe = nlist // 2
|
||||
for _, _, _, k, v in results:
|
||||
if (
|
||||
".knn" in k
|
||||
and "nprobe" in v["search_params"]
|
||||
and v["knn_intersection"] >= accuracy
|
||||
):
|
||||
nprobe = min(nprobe, v["search_params"]["nprobe"])
|
||||
return nprobe
|
||||
|
||||
# train candidate IVF codecs
|
||||
# benchmark them at the same nprobe
|
||||
# keep only the space _and_ time Pareto optimal
|
||||
def optimize_codec(
|
||||
self,
|
||||
result_file: str,
|
||||
d: int,
|
||||
training_vectors: DatasetDescriptor,
|
||||
database_vectors: DatasetDescriptor,
|
||||
query_vectors: DatasetDescriptor,
|
||||
nlist: int,
|
||||
nprobe: int,
|
||||
min_accuracy: float,
|
||||
):
|
||||
codecs = (
|
||||
[
|
||||
(None, "Flat"),
|
||||
(None, "SQfp16"),
|
||||
(None, "SQ8"),
|
||||
] + [
|
||||
(f"OPQ{M}_{M * dim}", f"PQ{M}x{b}")
|
||||
for M in [8, 12, 16, 32, 48, 64, 96, 128, 192, 256]
|
||||
if d % M == 0
|
||||
for dim in range(2, 18, 2)
|
||||
if M * dim <= d
|
||||
for b in range(4, 14, 2)
|
||||
if M * b < d * 8 # smaller than SQ8
|
||||
] + [
|
||||
(None, f"PQ{M}x{b}")
|
||||
for M in [8, 12, 16, 32, 48, 64, 96, 128, 192, 256]
|
||||
if d % M == 0
|
||||
for b in range(8, 14, 2)
|
||||
if M * b < d * 8 # smaller than SQ8
|
||||
]
|
||||
)
|
||||
factory = {}
|
||||
for opq, pq in codecs:
|
||||
factory[
|
||||
f"IVF{nlist},{pq}" if opq is None else f"{opq},IVF{nlist},{pq}"
|
||||
] = (
|
||||
opq,
|
||||
pq,
|
||||
)
|
||||
|
||||
_, filtered = self.benchmark_and_filter_candidates(
|
||||
index_descs=[
|
||||
IndexDescriptor(
|
||||
factory=f"IVF{nlist},{pq}"
|
||||
if opq is None
|
||||
else f"{opq},IVF{nlist},{pq}",
|
||||
search_params={
|
||||
"nprobe": nprobe,
|
||||
},
|
||||
)
|
||||
for opq, pq in codecs
|
||||
],
|
||||
training_vectors=training_vectors,
|
||||
database_vectors=database_vectors,
|
||||
query_vectors=query_vectors,
|
||||
result_file=result_file,
|
||||
include_flat=False,
|
||||
min_accuracy=min_accuracy,
|
||||
pareto_metric=ParetoMetric.TIME_SPACE,
|
||||
)
|
||||
results = [
|
||||
factory[r] for r in set(v["factory"] for _, _, _, k, v in filtered)
|
||||
]
|
||||
return results
|
||||
|
||||
def optimize(
|
||||
self,
|
||||
d: int,
|
||||
training_vectors: DatasetDescriptor,
|
||||
database_vectors_list: List[DatasetDescriptor],
|
||||
query_vectors: DatasetDescriptor,
|
||||
min_accuracy: float,
|
||||
):
|
||||
# train an IVFFlat index
|
||||
# find the nprobe required for near perfect accuracy
|
||||
nlist = 4096
|
||||
nprobe_at_95 = self.ivf_flat_nprobe_required_for_accuracy(
|
||||
result_file=f"result_ivf{nlist}_flat.json",
|
||||
training_vectors=training_vectors,
|
||||
database_vectors=database_vectors_list[0],
|
||||
query_vectors=query_vectors,
|
||||
nlist=nlist,
|
||||
accuracy=0.95,
|
||||
)
|
||||
|
||||
# train candidate IVF codecs
|
||||
# benchmark them at the same nprobe
|
||||
# keep only the space and time Pareto optima
|
||||
codecs = self.optimize_codec(
|
||||
result_file=f"result_ivf{nlist}_codec.json",
|
||||
d=d,
|
||||
training_vectors=training_vectors,
|
||||
database_vectors=database_vectors_list[0],
|
||||
query_vectors=query_vectors,
|
||||
nlist=nlist,
|
||||
nprobe=nprobe_at_95,
|
||||
min_accuracy=min_accuracy,
|
||||
)
|
||||
|
||||
# optimize coarse quantizers
|
||||
quantizers = self.optimize_quantizer(
|
||||
training_vectors=training_vectors,
|
||||
query_vectors=query_vectors,
|
||||
nlists=[4096, 8192, 16384, 32768],
|
||||
min_accuracy=0.7,
|
||||
)
|
||||
|
||||
# combine them with the codecs
|
||||
# test them at different scales
|
||||
for database_vectors in database_vectors_list:
|
||||
self.optimize_ivf(
|
||||
result_file=f"result_{database_vectors.get_filename()}json",
|
||||
training_vectors=training_vectors,
|
||||
database_vectors=database_vectors,
|
||||
query_vectors=query_vectors,
|
||||
quantizers=quantizers,
|
||||
codecs=codecs,
|
||||
min_accuracy=min_accuracy,
|
||||
)
|
|
@@ -3,15 +3,22 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from time import perf_counter
import logging
from multiprocessing.pool import ThreadPool
import numpy as np
import faiss  # @manual=//faiss/python:pyfaiss_gpu
import functools
import logging
from enum import Enum
from multiprocessing.pool import ThreadPool
from time import perf_counter

import faiss  # @manual=//faiss/python:pyfaiss_gpu
import numpy as np

from faiss.contrib.evaluation import (  # @manual=//faiss/contrib:faiss_contrib_gpu
    OperatingPoints,
)

logger = logging.getLogger(__name__)


def timer(name, func, once=False) -> float:
    logger.info(f"Measuring {name}")
    t1 = perf_counter()
@ -34,28 +41,41 @@ def timer(name, func, once=False) -> float:
|
|||
|
||||
|
||||
def refine_distances_knn(
|
||||
xq: np.ndarray, xb: np.ndarray, I: np.ndarray, metric,
|
||||
xq: np.ndarray,
|
||||
xb: np.ndarray,
|
||||
I: np.ndarray,
|
||||
metric,
|
||||
):
|
||||
""" Recompute distances between xq[i] and xb[I[i, :]] """
|
||||
"""Recompute distances between xq[i] and xb[I[i, :]]"""
|
||||
nq, k = I.shape
|
||||
xq = np.ascontiguousarray(xq, dtype='float32')
|
||||
xq = np.ascontiguousarray(xq, dtype="float32")
|
||||
nq2, d = xq.shape
|
||||
xb = np.ascontiguousarray(xb, dtype='float32')
|
||||
xb = np.ascontiguousarray(xb, dtype="float32")
|
||||
nb, d2 = xb.shape
|
||||
I = np.ascontiguousarray(I, dtype='int64')
|
||||
I = np.ascontiguousarray(I, dtype="int64")
|
||||
assert nq2 == nq
|
||||
assert d2 == d
|
||||
D = np.empty(I.shape, dtype='float32')
|
||||
D = np.empty(I.shape, dtype="float32")
|
||||
D[:] = np.inf
|
||||
if metric == faiss.METRIC_L2:
|
||||
faiss.fvec_L2sqr_by_idx(
|
||||
faiss.swig_ptr(D), faiss.swig_ptr(xq), faiss.swig_ptr(xb),
|
||||
faiss.swig_ptr(I), d, nq, k
|
||||
faiss.swig_ptr(D),
|
||||
faiss.swig_ptr(xq),
|
||||
faiss.swig_ptr(xb),
|
||||
faiss.swig_ptr(I),
|
||||
d,
|
||||
nq,
|
||||
k,
|
||||
)
|
||||
else:
|
||||
faiss.fvec_inner_products_by_idx(
|
||||
faiss.swig_ptr(D), faiss.swig_ptr(xq), faiss.swig_ptr(xb),
|
||||
faiss.swig_ptr(I), d, nq, k
|
||||
faiss.swig_ptr(D),
|
||||
faiss.swig_ptr(xq),
|
||||
faiss.swig_ptr(xb),
|
||||
faiss.swig_ptr(I),
|
||||
d,
|
||||
nq,
|
||||
k,
|
||||
)
|
||||
return D
|
||||
|
||||
|
@ -97,7 +117,10 @@ def distance_ratio_measure(I, R, D_GT, metric):
|
|||
|
||||
@functools.cache
|
||||
def get_cpu_info():
|
||||
return [l for l in open("/proc/cpuinfo", "r") if "model name" in l][0][13:].strip()
|
||||
return [l for l in open("/proc/cpuinfo", "r") if "model name" in l][0][
|
||||
13:
|
||||
].strip()
|
||||
|
||||
|
||||
def dict_merge(target, source):
|
||||
for k, v in source.items():
|
||||
|
@ -105,3 +128,121 @@ def dict_merge(target, source):
|
|||
dict_merge(target[k], v)
|
||||
else:
|
||||
target[k] = v
|
||||
|
||||
|
||||
class Cost:
|
||||
def __init__(self, values):
|
||||
self.values = values
|
||||
|
||||
def __le__(self, other):
|
||||
return all(
|
||||
v1 <= v2 for v1, v2 in zip(self.values, other.values, strict=True)
|
||||
)
|
||||
|
||||
def __lt__(self, other):
|
||||
return all(
|
||||
v1 < v2 for v1, v2 in zip(self.values, other.values, strict=True)
|
||||
)
|
||||
|
||||
|
||||
class ParetoMode(Enum):
|
||||
DISABLE = 1 # no Pareto filtering
|
||||
INDEX = 2 # index-local optima
|
||||
GLOBAL = 3 # global optima
|
||||
|
||||
|
||||
class ParetoMetric(Enum):
|
||||
TIME = 0 # time vs accuracy
|
||||
SPACE = 1 # space vs accuracy
|
||||
TIME_SPACE = 2 # (time, space) vs accuracy
|
||||
|
||||
|
||||
def range_search_recall_at_precision(experiment, precision):
|
||||
return round(
|
||||
max(
|
||||
r
|
||||
for r, p in zip(
|
||||
experiment["range_search_pr"]["recall"],
|
||||
experiment["range_search_pr"]["precision"],
|
||||
)
|
||||
if p > precision
|
||||
),
|
||||
6,
|
||||
)
|
||||
|
||||
|
||||
def filter_results(
|
||||
results,
|
||||
evaluation,
|
||||
accuracy_metric, # str or func
|
||||
time_metric=None, # func or None -> use default
|
||||
space_metric=None, # func or None -> use default
|
||||
min_accuracy=0,
|
||||
max_space=0,
|
||||
max_time=0,
|
||||
scaling_factor=1.0,
|
||||
name_filter=None, # func
|
||||
pareto_mode=ParetoMode.DISABLE,
|
||||
pareto_metric=ParetoMetric.TIME,
|
||||
):
|
||||
if isinstance(accuracy_metric, str):
|
||||
accuracy_key = accuracy_metric
|
||||
accuracy_metric = lambda v: v[accuracy_key]
|
||||
|
||||
if time_metric is None:
|
||||
time_metric = lambda v: v["time"] * scaling_factor + (
|
||||
v["quantizer"]["time"] if "quantizer" in v else 0
|
||||
)
|
||||
|
||||
if space_metric is None:
|
||||
space_metric = lambda v: results["indices"][v["codec"]]["code_size"]
|
||||
|
||||
fe = []
|
||||
ops = {}
|
||||
if pareto_mode == ParetoMode.GLOBAL:
|
||||
op = OperatingPoints()
|
||||
ops["global"] = op
|
||||
for k, v in results["experiments"].items():
|
||||
if f".{evaluation}" in k:
|
||||
accuracy = accuracy_metric(v)
|
||||
if min_accuracy > 0 and accuracy < min_accuracy:
|
||||
continue
|
||||
space = space_metric(v)
|
||||
if space is None:
|
||||
space = 0
|
||||
if max_space > 0 and space > max_space:
|
||||
continue
|
||||
time = time_metric(v)
|
||||
if max_time > 0 and time > max_time:
|
||||
continue
|
||||
idx_name = v["index"] + (
|
||||
"snap"
|
||||
if "search_params" in v and v["search_params"]["snap"] == 1
|
||||
else ""
|
||||
)
|
||||
if name_filter is not None and not name_filter(idx_name):
|
||||
continue
|
||||
experiment = (accuracy, space, time, k, v)
|
||||
if pareto_mode == ParetoMode.DISABLE:
|
||||
fe.append(experiment)
|
||||
continue
|
||||
if pareto_mode == ParetoMode.INDEX:
|
||||
if idx_name not in ops:
|
||||
ops[idx_name] = OperatingPoints()
|
||||
op = ops[idx_name]
|
||||
if pareto_metric == ParetoMetric.TIME:
|
||||
op.add_operating_point(experiment, accuracy, time)
|
||||
elif pareto_metric == ParetoMetric.SPACE:
|
||||
op.add_operating_point(experiment, accuracy, space)
|
||||
else:
|
||||
op.add_operating_point(
|
||||
experiment, accuracy, Cost([time, space])
|
||||
)
|
||||
|
||||
if ops:
|
||||
for op in ops.values():
|
||||
for v, _, _ in op.operating_points:
|
||||
fe.append(v)
|
||||
|
||||
fe.sort()
|
||||
return fe
|
||||
|
|
File diff suppressed because it is too large
@@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import logging
import os

from bench_fw.benchmark_io import BenchmarkIO
from bench_fw.descriptors import DatasetDescriptor
from bench_fw.optimize import Optimizer

logging.basicConfig(level=logging.INFO)


def bigann(bio):
    optimizer = Optimizer(
        distance_metric="L2",
        num_threads=32,
        run_local=False,
    )
    optimizer.set_io(bio)
    query_vectors = DatasetDescriptor(namespace="std_q", tablename="bigann1M")
    xt = bio.get_dataset(query_vectors)
    optimizer.optimize(
        d=xt.shape[1],
        training_vectors=DatasetDescriptor(
            namespace="std_t",
            tablename="bigann1M",
            num_vectors=2_000_000,
        ),
        database_vectors_list=[
            DatasetDescriptor(
                namespace="std_d",
                tablename="bigann1M",
            ),
            DatasetDescriptor(namespace="std_d", tablename="bigann10M"),
        ],
        query_vectors=query_vectors,
        min_accuracy=0.85,
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("experiment")
    parser.add_argument("path")
    args = parser.parse_args()
    assert os.path.exists(args.path)
    path = os.path.join(args.path, args.experiment)
    if not os.path.exists(path):
        os.mkdir(path)
    bio = BenchmarkIO(
        path=path,
    )
    if args.experiment == "bigann":
        bigann(bio)
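Assuming this runner is checked in as the bench_fw optimization driver script (the filename itself is not visible in this extract), the argparse block above implies an invocation of the form `python <script> bigann <existing-results-dir>`: the positional `experiment` argument selects the `bigann(bio)` branch, and a per-experiment subdirectory is created under `path` for the `BenchmarkIO` cache.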