faiss/benchs/bench_fw/index.py

1086 lines
39 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
from collections import OrderedDict
from copy import copy
from dataclasses import dataclass
from typing import ClassVar, Dict, List, Optional
import faiss # @manual=//faiss/python:pyfaiss_gpu
import numpy as np
from faiss.contrib.evaluation import ( # @manual=//faiss/contrib:faiss_contrib_gpu
knn_intersection_measure,
OperatingPointsWithRanges,
)
from faiss.contrib.factory_tools import ( # @manual=//faiss/contrib:faiss_contrib_gpu
reverse_index_factory,
)
from faiss.contrib.ivf_tools import ( # @manual=//faiss/contrib:faiss_contrib_gpu
add_preassigned,
replace_ivf_quantizer,
)
from .descriptors import DatasetDescriptor
from .utils import (
distance_ratio_measure,
get_cpu_info,
refine_distances_knn,
refine_distances_range,
timer,
)
logger = logging.getLogger(__name__)
# The classes below are wrappers around Faiss indices, with different
# implementations for the case when we start with an already trained
# index (IndexFromCodec) vs factory strings (IndexFromFactory).
# In both cases the classes have operations for adding to an index
# and searching it, and outputs are cached on disk.
# IndexFromFactory also decomposes the index (pretransform and quantizer)
# and trains sub-indices independently.
class IndexBase:
def set_io(self, benchmark_io):
self.io = benchmark_io
@staticmethod
def param_dict_list_to_name(param_dict_list):
if not param_dict_list:
return ""
l = 0
n = ""
for param_dict in param_dict_list:
n += IndexBase.param_dict_to_name(param_dict, f"cp{l}")
l += 1
return n
@staticmethod
def param_dict_to_name(param_dict, prefix="sp"):
if not param_dict:
return ""
n = prefix
for name, val in param_dict.items():
if name == "snap":
continue
if name == "lsq_gpu" and val == 0:
continue
if name == "use_beam_LUT" and val == 0:
continue
n += f"_{name}_{val}"
if n == prefix:
return ""
n += "."
return n
@staticmethod
def set_index_param_dict_list(index, param_dict_list, assert_same=False):
if not param_dict_list:
return
index = faiss.downcast_index(index)
for param_dict in param_dict_list:
assert index is not None
IndexBase.set_index_param_dict(index, param_dict, assert_same)
index = faiss.try_extract_index_ivf(index)
if index is not None:
index = index.quantizer
@staticmethod
def set_index_param_dict(index, param_dict, assert_same=False):
if not param_dict:
return
for name, val in param_dict.items():
IndexBase.set_index_param(index, name, val, assert_same)
@staticmethod
def set_index_param(index, name, val, assert_same=False):
index = faiss.downcast_index(index)
val = int(val)
if isinstance(index, faiss.IndexPreTransform):
Index.set_index_param(index.index, name, val)
return
elif name == "snap":
return
elif name == "lsq_gpu":
if val == 1:
ngpus = faiss.get_num_gpus()
icm_encoder_factory = faiss.GpuIcmEncoderFactory(ngpus)
if isinstance(index, faiss.IndexProductLocalSearchQuantizer):
for i in range(index.plsq.nsplits):
lsq = faiss.downcast_Quantizer(
index.plsq.subquantizer(i)
)
if lsq.icm_encoder_factory is None:
lsq.icm_encoder_factory = icm_encoder_factory
else:
if index.lsq.icm_encoder_factory is None:
index.lsq.icm_encoder_factory = icm_encoder_factory
return
elif name in ["efSearch", "efConstruction"]:
obj = index.hnsw
elif name in ["nprobe", "parallel_mode"]:
obj = faiss.extract_index_ivf(index)
elif name in ["use_beam_LUT", "max_beam_size"]:
if isinstance(index, faiss.IndexProductResidualQuantizer):
obj = [
faiss.downcast_Quantizer(index.prq.subquantizer(i))
for i in range(index.prq.nsplits)
]
else:
obj = index.rq
elif name == "encode_ils_iters":
if isinstance(index, faiss.IndexProductLocalSearchQuantizer):
obj = [
faiss.downcast_Quantizer(index.plsq.subquantizer(i))
for i in range(index.plsq.nsplits)
]
else:
obj = index.lsq
else:
obj = index
if not isinstance(obj, list):
obj = [obj]
for o in obj:
test = getattr(o, name)
if assert_same and not name == "use_beam_LUT":
assert test == val
else:
setattr(o, name, val)
@staticmethod
def filter_index_param_dict_list(param_dict_list):
if (
param_dict_list is not None
and param_dict_list[0] is not None
and "k_factor" in param_dict_list[0]
):
filtered = copy(param_dict_list)
del filtered[0]["k_factor"]
return filtered
else:
return param_dict_list
def is_flat(self):
model = faiss.downcast_index(self.get_model())
return isinstance(model, faiss.IndexFlat)
def is_ivf(self):
return False
model = self.get_model()
return faiss.try_extract_index_ivf(model) is not None
def is_2layer(self):
def is_2layer_(index):
index = faiss.downcast_index(index)
if isinstance(index, faiss.IndexPreTransform):
return is_2layer_(index.index)
return isinstance(index, faiss.Index2Layer)
model = self.get_model()
return is_2layer_(model)
def is_decode_supported(self):
model = self.get_model()
if isinstance(model, faiss.IndexPreTransform):
for i in range(model.chain.size()):
vt = faiss.downcast_VectorTransform(model.chain.at(i))
if isinstance(vt, faiss.ITQTransform):
return False
return True
def is_pretransform(self):
codec = self.get_model()
if isinstance(codec, faiss.IndexRefine):
codec = faiss.downcast_index(codec.base_index)
return isinstance(codec, faiss.IndexPreTransform)
# index is a codec + database vectors
# in other words: a trained Faiss index
# that contains database vectors
def get_index_name(self):
raise NotImplementedError
def get_index(self):
raise NotImplementedError
# codec is a trained model
# in other words: a trained Faiss index
# without any database vectors
def get_codec_name(self):
raise NotImplementedError
def get_codec(self):
raise NotImplementedError
# model is an untrained Faiss index
# it can be used for training (see codec)
# or to inspect its structure
def get_model_name(self):
raise NotImplementedError
def get_model(self):
raise NotImplementedError
def get_construction_params(self):
raise NotImplementedError
def transform(self, vectors):
transformed_vectors = DatasetDescriptor(
tablename=f"{vectors.get_filename()}{self.get_codec_name()}transform.npy"
)
if not self.io.file_exist(transformed_vectors.tablename):
codec = self.get_codec()
assert isinstance(codec, faiss.IndexPreTransform)
transform = faiss.downcast_VectorTransform(codec.chain.at(0))
x = self.io.get_dataset(vectors)
xt = transform.apply(x)
self.io.write_nparray(xt, transformed_vectors.tablename)
return transformed_vectors
def snap(self, vectors):
transformed_vectors = DatasetDescriptor(
tablename=f"{vectors.get_filename()}{self.get_codec_name()}snap.npy"
)
if not self.io.file_exist(transformed_vectors.tablename):
codec = self.get_codec()
x = self.io.get_dataset(vectors)
xt = codec.sa_decode(codec.sa_encode(x))
self.io.write_nparray(xt, transformed_vectors.tablename)
return transformed_vectors
def knn_search_quantizer(self, query_vectors, k):
if self.is_pretransform():
pretransform = self.get_pretransform()
quantizer_query_vectors = pretransform.transform(query_vectors)
else:
pretransform = None
quantizer_query_vectors = query_vectors
quantizer, _, _ = self.get_quantizer(
dry_run=False, pretransform=pretransform
)
QD, QI, _, QP, _ = quantizer.knn_search(
dry_run=False,
search_parameters=None,
query_vectors=quantizer_query_vectors,
k=k,
)
xqt = self.io.get_dataset(quantizer_query_vectors)
return xqt, QD, QI, QP
def get_knn_search_name(
self,
search_parameters: Optional[Dict[str, int]],
query_vectors: DatasetDescriptor,
k: int,
reconstruct: bool = False,
):
name = self.get_index_name()
name += Index.param_dict_to_name(search_parameters)
name += query_vectors.get_filename("q")
name += f"k_{k}."
name += f"t_{self.num_threads}."
if reconstruct:
name += "rec."
else:
name += "knn."
return name
def knn_search(
self,
dry_run,
search_parameters: Optional[Dict[str, int]],
query_vectors: DatasetDescriptor,
k: int,
I_gt=None,
D_gt=None,
):
logger.info("knn_search: begin")
if search_parameters is not None and search_parameters["snap"] == 1:
query_vectors = self.snap(query_vectors)
filename = (
self.get_knn_search_name(search_parameters, query_vectors, k)
+ "zip"
)
if self.io.file_exist(filename):
logger.info(f"Using cached results for {filename}")
D, I, R, P = self.io.read_file(filename, ["D", "I", "R", "P"])
else:
if dry_run:
return None, None, None, None, filename
index = self.get_index()
Index.set_index_param_dict(index, search_parameters)
if self.is_2layer():
# Index2Layer doesn't support search
xq = self.io.get_dataset(query_vectors)
xb = index.reconstruct_n(0, index.ntotal)
(D, I), t, _ = timer(
"knn_search 2layer", lambda: faiss.knn(xq, xb, k)
)
elif self.is_ivf() and not isinstance(index, faiss.IndexRefine):
index_ivf = faiss.extract_index_ivf(index)
nprobe = (
search_parameters["nprobe"]
if search_parameters is not None
and "nprobe" in search_parameters
else index_ivf.nprobe
)
xqt, QD, QI, QP = self.knn_search_quantizer(
query_vectors=query_vectors,
k=nprobe,
)
if index_ivf.parallel_mode != 2:
logger.info("Setting IVF parallel mode")
index_ivf.parallel_mode = 2
(D, I), t, repeat = timer(
"knn_search_preassigned",
lambda: index_ivf.search_preassigned(xqt, k, QI, QD),
)
# Dref, Iref = index.search(xq, k)
# np.testing.assert_array_equal(I, Iref)
# np.testing.assert_allclose(D, Dref)
else:
xq = self.io.get_dataset(query_vectors)
(D, I), t, _ = timer("knn_search", lambda: index.search(xq, k))
if self.is_flat() or not hasattr(self, "database_vectors"): # TODO
R = D
else:
xq = self.io.get_dataset(query_vectors)
xb = self.io.get_dataset(self.database_vectors)
R = refine_distances_knn(xq, xb, I, self.metric_type)
P = {
"time": t,
"k": k,
}
if self.is_ivf() and not isinstance(index, faiss.IndexRefine):
stats = faiss.cvar.indexIVF_stats
P |= {
"quantizer": QP,
"nq": int(stats.nq // repeat),
"nlist": int(stats.nlist // repeat),
"ndis": int(stats.ndis // repeat),
"nheap_updates": int(stats.nheap_updates // repeat),
"quantization_time": int(
stats.quantization_time // repeat
),
"search_time": int(stats.search_time // repeat),
}
self.io.write_file(filename, ["D", "I", "R", "P"], [D, I, R, P])
P |= {
"index": self.get_index_name(),
"codec": self.get_codec_name(),
"factory": self.get_model_name(),
"construction_params": self.get_construction_params(),
"search_params": search_parameters,
"knn_intersection": knn_intersection_measure(
I,
I_gt,
)
if I_gt is not None
else None,
"distance_ratio": distance_ratio_measure(
I,
R,
D_gt,
self.metric_type,
)
if D_gt is not None
else None,
}
logger.info("knn_search: end")
return D, I, R, P, None
def reconstruct(
self,
dry_run,
parameters: Optional[Dict[str, int]],
query_vectors: DatasetDescriptor,
k: int,
I_gt,
):
logger.info("reconstruct: begin")
filename = (
self.get_knn_search_name(
parameters, query_vectors, k, reconstruct=True
)
+ "zip"
)
if self.io.file_exist(filename):
logger.info(f"Using cached results for {filename}")
(P,) = self.io.read_file(filename, ["P"])
P["index"] = self.get_index_name()
P["codec"] = self.get_codec_name()
P["factory"] = self.get_model_name()
P["reconstruct_params"] = parameters
P["construction_params"] = self.get_construction_params()
else:
if dry_run:
return None, filename
codec = self.get_codec()
codec_meta = self.fetch_meta()
Index.set_index_param_dict(codec, parameters)
xb = self.io.get_dataset(self.database_vectors)
xb_encoded, encode_t, _ = timer(
"sa_encode", lambda: codec.sa_encode(xb)
)
xq = self.io.get_dataset(query_vectors)
if self.is_decode_supported():
xb_decoded, decode_t, _ = timer(
"sa_decode", lambda: codec.sa_decode(xb_encoded)
)
mse = np.square(xb_decoded - xb).sum(axis=1).mean().item()
_, I = faiss.knn(xq, xb_decoded, k, metric=self.metric_type)
asym_recall = knn_intersection_measure(I, I_gt)
xq_decoded = codec.sa_decode(codec.sa_encode(xq))
_, I = faiss.knn(
xq_decoded, xb_decoded, k, metric=self.metric_type
)
else:
mse = None
asym_recall = None
decode_t = None
# assume hamming for sym
xq_encoded = codec.sa_encode(xq)
bin = faiss.IndexBinaryFlat(xq_encoded.shape[1] * 8)
bin.add(xb_encoded)
_, I = bin.search(xq_encoded, k)
sym_recall = knn_intersection_measure(I, I_gt)
P = {
"encode_time": encode_t,
"decode_time": decode_t,
"mse": mse,
"sym_recall": sym_recall,
"asym_recall": asym_recall,
"cpu": get_cpu_info(),
"num_threads": self.num_threads,
"index": self.get_index_name(),
"codec": self.get_codec_name(),
"factory": self.get_model_name(),
"reconstruct_params": parameters,
"construction_params": self.get_construction_params(),
"codec_meta": codec_meta,
}
self.io.write_file(filename, ["P"], [P])
logger.info("reconstruct: end")
return P, None
def get_range_search_name(
self,
search_parameters: Optional[Dict[str, int]],
query_vectors: DatasetDescriptor,
radius: Optional[float] = None,
):
name = self.get_index_name()
name += Index.param_dict_to_name(search_parameters)
name += query_vectors.get_filename("q")
if radius is not None:
name += f"r_{int(radius * 1000)}."
else:
name += "r_auto."
return name
def range_search(
self,
dry_run,
search_parameters: Optional[Dict[str, int]],
query_vectors: DatasetDescriptor,
radius: Optional[float] = None,
):
logger.info("range_search: begin")
if search_parameters is not None and search_parameters["snap"] == 1:
query_vectors = self.snap(query_vectors)
filename = (
self.get_range_search_name(
search_parameters, query_vectors, radius
)
+ "zip"
)
if self.io.file_exist(filename):
logger.info(f"Using cached results for {filename}")
lims, D, I, R, P = self.io.read_file(
filename, ["lims", "D", "I", "R", "P"]
)
else:
if dry_run:
return None, None, None, None, None, filename
xq = self.io.get_dataset(query_vectors)
index = self.get_index()
Index.set_index_param_dict(index, search_parameters)
if self.is_ivf():
xqt, QD, QI, QP = self.knn_search_quantizer(
query_vectors, search_parameters["nprobe"]
)
index_ivf = faiss.extract_index_ivf(index)
if index_ivf.parallel_mode != 2:
logger.info("Setting IVF parallel mode")
index_ivf.parallel_mode = 2
(lims, D, I), t, repeat = timer(
"range_search_preassigned",
lambda: index_ivf.range_search_preassigned(
xqt, radius, QI, QD
),
)
else:
(lims, D, I), t, _ = timer(
"range_search", lambda: index.range_search(xq, radius)
)
if self.is_flat():
R = D
else:
xb = self.io.get_dataset(self.database_vectors)
R = refine_distances_range(
lims, D, I, xq, xb, self.metric_type
)
P = {
"time": t,
"radius": radius,
"count": len(I),
}
if self.is_ivf():
stats = faiss.cvar.indexIVF_stats
P |= {
"quantizer": QP,
"nq": int(stats.nq // repeat),
"nlist": int(stats.nlist // repeat),
"ndis": int(stats.ndis // repeat),
"nheap_updates": int(stats.nheap_updates // repeat),
"quantization_time": int(
stats.quantization_time // repeat
),
"search_time": int(stats.search_time // repeat),
}
self.io.write_file(
filename, ["lims", "D", "I", "R", "P"], [lims, D, I, R, P]
)
P |= {
"index": self.get_index_name(),
"codec": self.get_codec_name(),
"factory": self.get_model_name(),
"construction_params": self.get_construction_params(),
"search_params": search_parameters,
}
logger.info("range_seach: end")
return lims, D, I, R, P, None
# Common base for IndexFromCodec and IndexFromFactory,
# but not for the sub-indices of codec-based indices
# IndexFromQuantizer and IndexFromPreTransform, because
# they share the configuration of their parent IndexFromCodec
@dataclass
class Index(IndexBase):
num_threads: int
d: int
metric: str
database_vectors: DatasetDescriptor
construction_params: List[Dict[str, int]]
search_params: Dict[str, int]
cached_codec: ClassVar[OrderedDict[str, faiss.Index]] = OrderedDict()
cached_index: ClassVar[OrderedDict[str, faiss.Index]] = OrderedDict()
def __post_init__(self):
if isinstance(self.metric, str):
if self.metric == "IP":
self.metric_type = faiss.METRIC_INNER_PRODUCT
elif self.metric == "L2":
self.metric_type = faiss.METRIC_L2
else:
raise ValueError
elif isinstance(self.metric, int):
self.metric_type = self.metric
if self.metric_type == faiss.METRIC_INNER_PRODUCT:
self.metric = "IP"
elif self.metric_type == faiss.METRIC_L2:
self.metric = "L2"
else:
raise ValueError
else:
raise ValueError
def supports_range_search(self):
codec = self.get_codec()
return not type(codec) in [
faiss.IndexHNSWFlat,
faiss.IndexIVFFastScan,
faiss.IndexRefine,
faiss.IndexPQ,
]
def fetch_codec(self):
raise NotImplementedError
def get_codec(self):
codec_name = self.get_codec_name()
if codec_name not in Index.cached_codec:
Index.cached_codec[codec_name], _, _ = self.fetch_codec()
if len(Index.cached_codec) > 1:
Index.cached_codec.popitem(last=False)
return Index.cached_codec[codec_name]
def get_index_name(self):
name = self.get_codec_name()
assert self.database_vectors is not None
name += self.database_vectors.get_filename("xb")
return name
def fetch_index(self):
index = self.get_codec()
index.reset()
assert index.ntotal == 0
logger.info("Adding vectors to index")
xb = self.io.get_dataset(self.database_vectors)
if self.is_ivf() and not isinstance(index, faiss.IndexRefine):
xbt, QD, QI, QP = self.knn_search_quantizer(
query_vectors=self.database_vectors,
k=1,
)
index_ivf = faiss.extract_index_ivf(index)
if index_ivf.parallel_mode != 2:
logger.info("Setting IVF parallel mode")
index_ivf.parallel_mode = 2
_, t, _ = timer(
"add_preassigned",
lambda: add_preassigned(index_ivf, xbt, QI.ravel()),
once=True,
)
else:
_, t, _ = timer(
"add",
lambda: index.add(xb),
once=True,
)
assert index.ntotal == xb.shape[0] or index_ivf.ntotal == xb.shape[0]
logger.info("Added vectors to index")
return index, t
def get_index(self):
index_name = self.get_index_name()
if index_name not in Index.cached_index:
Index.cached_index[index_name], _ = self.fetch_index()
if len(Index.cached_index) > 3:
Index.cached_index.popitem(last=False)
return Index.cached_index[index_name]
def get_construction_params(self):
return self.construction_params
def get_code_size(self, codec=None):
def get_index_code_size(index):
index = faiss.downcast_index(index)
if isinstance(index, faiss.IndexPreTransform):
return get_index_code_size(index.index)
elif type(index) in [faiss.IndexRefine, faiss.IndexRefineFlat]:
return get_index_code_size(
index.base_index
) + get_index_code_size(index.refine_index)
else:
return index.code_size if hasattr(index, "code_size") else 0
if codec is None:
codec = self.get_codec()
return get_index_code_size(codec)
def get_sa_code_size(self, codec=None):
if codec is None:
codec = self.get_codec()
try:
return codec.sa_code_size()
except:
return None
def get_operating_points(self):
op = OperatingPointsWithRanges()
def add_range_or_val(name, range):
op.add_range(
name,
[self.search_params[name]]
if self.search_params and name in self.search_params
else range,
)
add_range_or_val("snap", [0])
model = self.get_model()
model_ivf = faiss.try_extract_index_ivf(model)
if model_ivf is not None:
add_range_or_val(
"nprobe",
[2**i for i in range(12) if 2**i <= model_ivf.nlist * 0.5],
# [1, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28] + [
# i
# for i in range(32, 64, 8)
# if i <= model_ivf.nlist * 0.1
# ] + [
# i
# for i in range(64, 128, 16)
# if i <= model_ivf.nlist * 0.1
# ] + [
# i
# for i in range(128, 256, 32)
# if i <= model_ivf.nlist * 0.1
# ] + [
# i
# for i in range(256, 512, 64)
# if i <= model_ivf.nlist * 0.1
# ] + [
# 2**i
# for i in range(9, 12)
# if 2**i <= model_ivf.nlist * 0.1
# ],
)
model = faiss.downcast_index(model)
if isinstance(model, faiss.IndexRefine):
add_range_or_val(
"k_factor",
[2**i for i in range(13)],
)
elif isinstance(model, faiss.IndexHNSWFlat):
add_range_or_val(
"efSearch",
[2**i for i in range(3, 11)],
)
elif isinstance(model, faiss.IndexResidualQuantizer) or isinstance(
model, faiss.IndexProductResidualQuantizer
):
add_range_or_val(
"max_beam_size",
[1, 2, 4, 8, 16, 32],
)
add_range_or_val(
"use_beam_LUT",
[1],
)
elif isinstance(model, faiss.IndexLocalSearchQuantizer) or isinstance(
model, faiss.IndexProductLocalSearchQuantizer
):
add_range_or_val(
"encode_ils_iters",
[2, 4, 8, 16],
)
add_range_or_val(
"lsq_gpu",
[1],
)
return op
# IndexFromCodec, IndexFromQuantizer and IndexFromPreTransform
# are used to wrap pre-trained Faiss indices (codecs)
@dataclass
class IndexFromCodec(Index):
path: str
bucket: Optional[str] = None
def get_quantizer(self):
if not self.is_ivf():
raise ValueError("Not an IVF index")
quantizer = IndexFromQuantizer(self)
quantizer.set_io(self.io)
return quantizer
def get_pretransform(self):
if not self.is_ivf():
raise ValueError("Not an IVF index")
quantizer = IndexFromPreTransform(self)
quantizer.set_io(self.io)
return quantizer
def get_model_name(self):
return os.path.basename(self.path)
def get_codec_name(self):
assert self.path is not None
name = os.path.basename(self.path)
name += Index.param_dict_list_to_name(self.construction_params)
return name
def fetch_codec(self):
codec = self.io.read_index(
os.path.basename(self.path),
self.bucket,
os.path.dirname(self.path),
)
assert self.d == codec.d
assert self.metric_type == codec.metric_type
Index.set_index_param_dict_list(codec, self.construction_params)
return codec, None, None
def get_model(self):
return self.get_codec()
class IndexFromQuantizer(IndexBase):
ivf_index: Index
def __init__(self, ivf_index: Index):
self.ivf_index = ivf_index
super().__init__()
def get_model_name(self):
return self.get_index_name()
def get_codec_name(self):
return self.get_index_name()
def get_codec(self):
return self.get_index()
def get_index_name(self):
ivf_codec_name = self.ivf_index.get_codec_name()
return f"{ivf_codec_name}quantizer."
def get_index(self):
ivf_codec = faiss.extract_index_ivf(self.ivf_index.get_codec())
return ivf_codec.quantizer
class IndexFromPreTransform(IndexBase):
pre_transform_index: Index
def __init__(self, pre_transform_index: Index):
self.pre_transform_index = pre_transform_index
super().__init__()
def get_codec_name(self):
pre_transform_codec_name = self.pre_transform_index.get_codec_name()
return f"{pre_transform_codec_name}pretransform."
def get_codec(self):
return self.get_codec()
# IndexFromFactory is for creating and training indices from scratch
@dataclass
class IndexFromFactory(Index):
factory: str
training_vectors: DatasetDescriptor
def get_codec_name(self):
assert self.factory is not None
name = f"{self.factory.replace(',', '_')}."
assert self.d is not None
assert self.metric is not None
name += f"d_{self.d}.{self.metric.upper()}."
if self.factory != "Flat":
assert self.training_vectors is not None
name += self.training_vectors.get_filename("xt")
name += Index.param_dict_list_to_name(self.construction_params)
return name
def fetch_meta(self, dry_run=False):
meta_filename = self.get_codec_name() + "json"
if self.io.file_exist(meta_filename):
meta = self.io.read_json(meta_filename)
report = None
else:
_, meta, report = self.fetch_codec(dry_run=dry_run)
return meta, report
def fetch_codec(self, dry_run=False):
codec_filename = self.get_codec_name() + "codec"
meta_filename = self.get_codec_name() + "json"
if self.io.file_exist(codec_filename) and self.io.file_exist(
meta_filename
):
codec = self.io.read_index(codec_filename)
assert self.d == codec.d
assert self.metric_type == codec.metric_type
meta = self.io.read_json(meta_filename)
else:
codec, training_time, requires = self.assemble(dry_run=dry_run)
if requires is not None:
assert dry_run
if requires == "":
return None, None, codec_filename
else:
return None, None, requires
codec_size = self.io.write_index(codec, codec_filename)
assert codec_size is not None
meta = {
"training_time": training_time,
"training_size": self.training_vectors.num_vectors,
"codec_size": codec_size,
"sa_code_size": self.get_sa_code_size(codec),
"code_size": self.get_code_size(codec),
"cpu": get_cpu_info(),
}
self.io.write_json(meta, meta_filename, overwrite=True)
Index.set_index_param_dict_list(
codec, self.construction_params, assert_same=True
)
return codec, meta, None
def get_model_name(self):
return self.factory
def get_model(self):
model = faiss.index_factory(self.d, self.factory, self.metric_type)
Index.set_index_param_dict_list(model, self.construction_params)
return model
def get_pretransform(self):
model = self.get_model()
assert isinstance(model, faiss.IndexPreTransform)
sub_index = faiss.downcast_index(model.index)
if isinstance(sub_index, faiss.IndexFlat):
return self
# replace the sub-index with Flat
model.index = faiss.IndexFlat(model.index.d, model.index.metric_type)
pretransform = IndexFromFactory(
num_threads=self.num_threads,
d=model.d,
metric=model.metric_type,
database_vectors=self.database_vectors,
construction_params=self.construction_params,
search_params=None,
factory=reverse_index_factory(model),
training_vectors=self.training_vectors,
)
pretransform.set_io(self.io)
return pretransform
def get_quantizer(self, dry_run, pretransform=None):
model = self.get_model()
model_ivf = faiss.extract_index_ivf(model)
assert isinstance(model_ivf, faiss.IndexIVF)
assert ord(model_ivf.quantizer_trains_alone) in [0, 2]
if pretransform is None:
training_vectors = self.training_vectors
else:
training_vectors = pretransform.transform(self.training_vectors)
centroids, t, requires = training_vectors.k_means(
self.io, model_ivf.nlist, dry_run
)
if requires is not None:
return None, None, requires
quantizer = IndexFromFactory(
num_threads=self.num_threads,
d=model_ivf.quantizer.d,
metric=model_ivf.quantizer.metric_type,
database_vectors=centroids,
construction_params=self.construction_params[1:]
if self.construction_params is not None
else None,
search_params=None,
factory=reverse_index_factory(model_ivf.quantizer),
training_vectors=centroids,
)
quantizer.set_io(self.io)
return quantizer, t, None
def assemble(self, dry_run):
logger.info(f"assemble {self.factory}")
model = self.get_model()
opaque = True
t_aggregate = 0
# try:
# reverse_index_factory(model)
# opaque = False
# except NotImplementedError:
# opaque = True
if opaque:
codec = model
else:
if isinstance(model, faiss.IndexPreTransform):
logger.info(f"assemble: pretransform {self.factory}")
sub_index = faiss.downcast_index(model.index)
if not isinstance(sub_index, faiss.IndexFlat):
# replace the sub-index with Flat and fetch pre-trained
pretransform = self.get_pretransform()
codec, meta, report = pretransform.fetch_codec(
dry_run=dry_run
)
if report is not None:
return None, None, report
t_aggregate += meta["training_time"]
assert codec.is_trained
transformed_training_vectors = pretransform.transform(
self.training_vectors
)
# replace the Flat index with the required sub-index
wrapper = IndexFromFactory(
num_threads=self.num_threads,
d=sub_index.d,
metric=sub_index.metric_type,
database_vectors=None,
construction_params=self.construction_params,
search_params=None,
factory=reverse_index_factory(sub_index),
training_vectors=transformed_training_vectors,
)
wrapper.set_io(self.io)
codec.index, meta, report = wrapper.fetch_codec(
dry_run=dry_run
)
if report is not None:
return None, None, report
t_aggregate += meta["training_time"]
assert codec.index.is_trained
else:
codec = model
elif isinstance(model, faiss.IndexIVF):
logger.info(f"assemble: ivf {self.factory}")
# replace the quantizer
quantizer, t, requires = self.get_quantizer(dry_run=dry_run)
if requires is not None:
return None, None, requires
t_aggregate += t
codec = faiss.clone_index(model)
quantizer_index, t = quantizer.fetch_index()
t_aggregate += t
replace_ivf_quantizer(codec, quantizer_index)
assert codec.quantizer.is_trained
assert codec.nlist == codec.quantizer.ntotal
elif isinstance(model, faiss.IndexRefine) or isinstance(
model, faiss.IndexRefineFlat
):
logger.info(f"assemble: refine {self.factory}")
# replace base_index
wrapper = IndexFromFactory(
num_threads=self.num_threads,
d=model.base_index.d,
metric=model.base_index.metric_type,
database_vectors=self.database_vectors,
construction_params=IndexBase.filter_index_param_dict_list(
self.construction_params
),
search_params=None,
factory=reverse_index_factory(model.base_index),
training_vectors=self.training_vectors,
)
wrapper.set_io(self.io)
codec = faiss.clone_index(model)
codec.base_index, meta, requires = wrapper.fetch_codec(
dry_run=dry_run
)
if requires is not None:
return None, None, requires
t_aggregate += meta["training_time"]
assert codec.base_index.is_trained
else:
codec = model
if self.factory != "Flat":
if dry_run:
return None, None, ""
logger.info(f"assemble, train {self.factory}")
xt = self.io.get_dataset(self.training_vectors)
_, t, _ = timer("train", lambda: codec.train(xt), once=True)
t_aggregate += t
return codec, t_aggregate, None