# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
from collections import OrderedDict
from copy import copy
from dataclasses import dataclass
from typing import ClassVar, Dict, List, Optional

import faiss  # @manual=//faiss/python:pyfaiss_gpu
import numpy as np
from faiss.contrib.evaluation import (  # @manual=//faiss/contrib:faiss_contrib_gpu
    knn_intersection_measure,
    OperatingPointsWithRanges,
)
from faiss.contrib.factory_tools import (  # @manual=//faiss/contrib:faiss_contrib_gpu
    reverse_index_factory,
)
from faiss.contrib.ivf_tools import (  # @manual=//faiss/contrib:faiss_contrib_gpu
    add_preassigned,
    replace_ivf_quantizer,
)

from .descriptors import DatasetDescriptor
from .utils import (
    distance_ratio_measure,
    get_cpu_info,
    refine_distances_knn,
    refine_distances_range,
    timer,
)

logger = logging.getLogger(__name__)


# The classes below are wrappers around Faiss indices, with different
# implementations for the case when we start with an already trained
# index (IndexFromCodec) vs factory strings (IndexFromFactory).
# In both cases the classes have operations for adding to an index
# and searching it, and outputs are cached on disk.
# IndexFromFactory also decomposes the index (pretransform and quantizer)
# and trains sub-indices independently.
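
# Example (illustrative sketch, not a supported entry point; the descriptor
# names "xt.npy"/"xb.npy"/"xq.npy" and the `benchmark_io` instance are
# hypothetical): describe an index, attach a BenchmarkIO, and let the wrapper
# train, add and search with on-disk caching.
#
#   index = IndexFromFactory(
#       num_threads=32,
#       d=96,
#       metric="L2",
#       database_vectors=DatasetDescriptor(tablename="xb.npy"),
#       construction_params=None,
#       search_params=None,
#       factory="IVF1024,Flat",
#       training_vectors=DatasetDescriptor(tablename="xt.npy"),
#   )
#   index.set_io(benchmark_io)  # any BenchmarkIO implementation
#   D, I, R, P, _ = index.knn_search(
#       dry_run=False,
#       search_parameters={"snap": 0, "nprobe": 32},
#       query_vectors=DatasetDescriptor(tablename="xq.npy"),
#       k=10,
#   )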


class IndexBase:
    def set_io(self, benchmark_io):
        self.io = benchmark_io

    @staticmethod
    def param_dict_list_to_name(param_dict_list):
        if not param_dict_list:
            return ""
        n = ""
        for i, param_dict in enumerate(param_dict_list):
            n += IndexBase.param_dict_to_name(param_dict, f"cp{i}")
        return n

    @staticmethod
    def param_dict_to_name(param_dict, prefix="sp"):
        if not param_dict:
            return ""
        n = prefix
        for name, val in param_dict.items():
            if name == "snap":
                continue
            if name == "lsq_gpu" and val == 0:
                continue
            if name == "use_beam_LUT" and val == 0:
                continue
            n += f"_{name}_{val}"
        if n == prefix:
            return ""
        n += "."
        return n

    @staticmethod
    def set_index_param_dict_list(index, param_dict_list, assert_same=False):
        if not param_dict_list:
            return
        index = faiss.downcast_index(index)
        for param_dict in param_dict_list:
            assert index is not None
            IndexBase.set_index_param_dict(index, param_dict, assert_same)
            index = faiss.try_extract_index_ivf(index)
            if index is not None:
                index = index.quantizer

    @staticmethod
    def set_index_param_dict(index, param_dict, assert_same=False):
        if not param_dict:
            return
        for name, val in param_dict.items():
            IndexBase.set_index_param(index, name, val, assert_same)

    @staticmethod
    def set_index_param(index, name, val, assert_same=False):
        index = faiss.downcast_index(index)
        val = int(val)
        if isinstance(index, faiss.IndexPreTransform):
            # propagate assert_same through the pretransform wrapper
            Index.set_index_param(index.index, name, val, assert_same)
            return
        elif name == "snap":
            return
        elif name == "lsq_gpu":
            if val == 1:
                ngpus = faiss.get_num_gpus()
                icm_encoder_factory = faiss.GpuIcmEncoderFactory(ngpus)
                if isinstance(index, faiss.IndexProductLocalSearchQuantizer):
                    for i in range(index.plsq.nsplits):
                        lsq = faiss.downcast_Quantizer(
                            index.plsq.subquantizer(i)
                        )
                        if lsq.icm_encoder_factory is None:
                            lsq.icm_encoder_factory = icm_encoder_factory
                else:
                    if index.lsq.icm_encoder_factory is None:
                        index.lsq.icm_encoder_factory = icm_encoder_factory
            return
        elif name in ["efSearch", "efConstruction"]:
            obj = index.hnsw
        elif name in ["nprobe", "parallel_mode"]:
            obj = faiss.extract_index_ivf(index)
        elif name in ["use_beam_LUT", "max_beam_size"]:
            if isinstance(index, faiss.IndexProductResidualQuantizer):
                obj = [
                    faiss.downcast_Quantizer(index.prq.subquantizer(i))
                    for i in range(index.prq.nsplits)
                ]
            else:
                obj = index.rq
        elif name == "encode_ils_iters":
            if isinstance(index, faiss.IndexProductLocalSearchQuantizer):
                obj = [
                    faiss.downcast_Quantizer(index.plsq.subquantizer(i))
                    for i in range(index.plsq.nsplits)
                ]
            else:
                obj = index.lsq
        else:
            obj = index

        if not isinstance(obj, list):
            obj = [obj]
        for o in obj:
            test = getattr(o, name)
            if assert_same and name != "use_beam_LUT":
                assert test == val
            else:
                setattr(o, name, val)

    @staticmethod
    def filter_index_param_dict_list(param_dict_list):
        if (
            param_dict_list is not None
            and param_dict_list[0] is not None
            and "k_factor" in param_dict_list[0]
        ):
            filtered = copy(param_dict_list)
            del filtered[0]["k_factor"]
            return filtered
        else:
            return param_dict_list

    def is_flat(self):
        model = faiss.downcast_index(self.get_model())
        return isinstance(model, faiss.IndexFlat)

    def is_ivf(self):
        model = self.get_model()
        return faiss.try_extract_index_ivf(model) is not None

    def is_2layer(self):
        def is_2layer_(index):
            index = faiss.downcast_index(index)
            if isinstance(index, faiss.IndexPreTransform):
                return is_2layer_(index.index)
            return isinstance(index, faiss.Index2Layer)

        model = self.get_model()
        return is_2layer_(model)

    def is_decode_supported(self):
        model = self.get_model()
        if isinstance(model, faiss.IndexPreTransform):
            for i in range(model.chain.size()):
                vt = faiss.downcast_VectorTransform(model.chain.at(i))
                if isinstance(vt, faiss.ITQTransform):
                    return False
        return True

    def is_pretransform(self):
        codec = self.get_model()
        if isinstance(codec, faiss.IndexRefine):
            codec = faiss.downcast_index(codec.base_index)
        return isinstance(codec, faiss.IndexPreTransform)
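
    # The accessors below distinguish three stages of an index's lifecycle.
    # Illustrative sketch with plain faiss (hypothetical dimension/factory):
    #
    #   model = faiss.index_factory(96, "IVF1024,Flat")  # untrained: model
    #   model.train(xt)   # -> codec: trained, but holds no vectors
    #   model.add(xb)     # -> index: codec + database vectors
    #
    # Subclasses implement get_model/get_codec/get_index accordingly.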

    # index is a codec + database vectors
    # in other words: a trained Faiss index
    # that contains database vectors
    def get_index_name(self):
        raise NotImplementedError

    def get_index(self):
        raise NotImplementedError

    # codec is a trained model
    # in other words: a trained Faiss index
    # without any database vectors
    def get_codec_name(self):
        raise NotImplementedError

    def get_codec(self):
        raise NotImplementedError

    # model is an untrained Faiss index
    # it can be used for training (see codec)
    # or to inspect its structure
    def get_model_name(self):
        raise NotImplementedError

    def get_model(self):
        raise NotImplementedError

    def get_construction_params(self):
        raise NotImplementedError

    def transform(self, vectors):
        transformed_vectors = DatasetDescriptor(
            tablename=f"{vectors.get_filename()}{self.get_codec_name()}transform.npy"
        )
        if not self.io.file_exist(transformed_vectors.tablename):
            codec = self.get_codec()
            assert isinstance(codec, faiss.IndexPreTransform)
            transform = faiss.downcast_VectorTransform(codec.chain.at(0))
            x = self.io.get_dataset(vectors)
            xt = transform.apply(x)
            self.io.write_nparray(xt, transformed_vectors.tablename)
        return transformed_vectors

    def snap(self, vectors):
        transformed_vectors = DatasetDescriptor(
            tablename=f"{vectors.get_filename()}{self.get_codec_name()}snap.npy"
        )
        if not self.io.file_exist(transformed_vectors.tablename):
            codec = self.get_codec()
            x = self.io.get_dataset(vectors)
            xt = codec.sa_decode(codec.sa_encode(x))
            self.io.write_nparray(xt, transformed_vectors.tablename)
        return transformed_vectors

    def knn_search_quantizer(self, query_vectors, k):
        if self.is_pretransform():
            pretransform = self.get_pretransform()
            quantizer_query_vectors = pretransform.transform(query_vectors)
        else:
            pretransform = None
            quantizer_query_vectors = query_vectors

        quantizer, _, _ = self.get_quantizer(
            dry_run=False, pretransform=pretransform
        )
        QD, QI, _, QP, _ = quantizer.knn_search(
            dry_run=False,
            search_parameters=None,
            query_vectors=quantizer_query_vectors,
            k=k,
        )
        xqt = self.io.get_dataset(quantizer_query_vectors)
        return xqt, QD, QI, QP

    def get_knn_search_name(
        self,
        search_parameters: Optional[Dict[str, int]],
        query_vectors: DatasetDescriptor,
        k: int,
        reconstruct: bool = False,
    ):
        name = self.get_index_name()
        name += Index.param_dict_to_name(search_parameters)
        name += query_vectors.get_filename("q")
        name += f"k_{k}."
        name += f"t_{self.num_threads}."
        if reconstruct:
            name += "rec."
        else:
            name += "knn."
        return name
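
    # Example: with search_parameters={"nprobe": 32}, k=10 and 32 threads,
    # the name ends in "sp_nprobe_32.<query filename>k_10.t_32.knn.", and
    # knn_search appends "zip" to this to form the on-disk cache key.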

    def knn_search(
        self,
        dry_run,
        search_parameters: Optional[Dict[str, int]],
        query_vectors: DatasetDescriptor,
        k: int,
        I_gt=None,
        D_gt=None,
    ):
        logger.info("knn_search: begin")
        if search_parameters is not None and search_parameters["snap"] == 1:
            query_vectors = self.snap(query_vectors)
        filename = (
            self.get_knn_search_name(search_parameters, query_vectors, k)
            + "zip"
        )
        if self.io.file_exist(filename):
            logger.info(f"Using cached results for {filename}")
            D, I, R, P = self.io.read_file(filename, ["D", "I", "R", "P"])
        else:
            if dry_run:
                return None, None, None, None, filename
            index = self.get_index()
            Index.set_index_param_dict(index, search_parameters)

            if self.is_2layer():
                # Index2Layer doesn't support search
                xq = self.io.get_dataset(query_vectors)
                xb = index.reconstruct_n(0, index.ntotal)
                (D, I), t, _ = timer(
                    "knn_search 2layer",
                    lambda: faiss.knn(xq, xb, k),
                )
            elif self.is_ivf() and not isinstance(index, faiss.IndexRefine):
                index_ivf = faiss.extract_index_ivf(index)
                nprobe = (
                    search_parameters["nprobe"]
                    if search_parameters is not None
                    and "nprobe" in search_parameters
                    else index_ivf.nprobe
                )
                xqt, QD, QI, QP = self.knn_search_quantizer(
                    query_vectors=query_vectors,
                    k=nprobe,
                )
                if index_ivf.parallel_mode != 2:
                    logger.info("Setting IVF parallel mode")
                    index_ivf.parallel_mode = 2
                (D, I), t, repeat = timer(
                    "knn_search_preassigned",
                    lambda: index_ivf.search_preassigned(xqt, k, QI, QD),
                )
                # Dref, Iref = index.search(xq, k)
                # np.testing.assert_array_equal(I, Iref)
                # np.testing.assert_allclose(D, Dref)
            else:
                xq = self.io.get_dataset(query_vectors)
                (D, I), t, _ = timer("knn_search", lambda: index.search(xq, k))

            if self.is_flat() or not hasattr(self, "database_vectors"):  # TODO
                R = D
            else:
                xq = self.io.get_dataset(query_vectors)
                xb = self.io.get_dataset(self.database_vectors)
                R = refine_distances_knn(xq, xb, I, self.metric_type)
            P = {
                "time": t,
                "k": k,
            }
            if self.is_ivf() and not isinstance(index, faiss.IndexRefine):
                stats = faiss.cvar.indexIVF_stats
                P |= {
                    "quantizer": QP,
                    "nq": int(stats.nq // repeat),
                    "nlist": int(stats.nlist // repeat),
                    "ndis": int(stats.ndis // repeat),
                    "nheap_updates": int(stats.nheap_updates // repeat),
                    "quantization_time": int(
                        stats.quantization_time // repeat
                    ),
                    "search_time": int(stats.search_time // repeat),
                }
            self.io.write_file(filename, ["D", "I", "R", "P"], [D, I, R, P])
        P |= {
            "index": self.get_index_name(),
            "codec": self.get_codec_name(),
            "factory": self.get_model_name(),
            "construction_params": self.get_construction_params(),
            "search_params": search_parameters,
            "knn_intersection": knn_intersection_measure(
                I,
                I_gt,
            )
            if I_gt is not None
            else None,
            "distance_ratio": distance_ratio_measure(
                I,
                R,
                D_gt,
                self.metric_type,
            )
            if D_gt is not None
            else None,
        }
        logger.info("knn_search: end")
        return D, I, R, P, None
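
    # The IVF branch above splits the search in two steps so the coarse
    # quantizer can be benchmarked (and cached) as a separate index.
    # Equivalent plain-faiss sketch (hypothetical xq/k/nprobe), modulo the
    # parallel-mode and stats handling:
    #
    #   index_ivf = faiss.extract_index_ivf(index)
    #   QD, QI = index_ivf.quantizer.search(xq, nprobe)  # coarse assignment
    #   D, I = index_ivf.search_preassigned(xq, k, QI, QD)
    #
    # which is what index.search(xq, k) would do end to end.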

    def reconstruct(
        self,
        dry_run,
        parameters: Optional[Dict[str, int]],
        query_vectors: DatasetDescriptor,
        k: int,
        I_gt,
    ):
        logger.info("reconstruct: begin")
        filename = (
            self.get_knn_search_name(
                parameters, query_vectors, k, reconstruct=True
            )
            + "zip"
        )
        if self.io.file_exist(filename):
            logger.info(f"Using cached results for {filename}")
            (P,) = self.io.read_file(filename, ["P"])
            P["index"] = self.get_index_name()
            P["codec"] = self.get_codec_name()
            P["factory"] = self.get_model_name()
            P["reconstruct_params"] = parameters
            P["construction_params"] = self.get_construction_params()
        else:
            if dry_run:
                return None, filename
            codec = self.get_codec()
            codec_meta = self.fetch_meta()
            Index.set_index_param_dict(codec, parameters)
            xb = self.io.get_dataset(self.database_vectors)
            xb_encoded, encode_t, _ = timer(
                "sa_encode", lambda: codec.sa_encode(xb)
            )
            xq = self.io.get_dataset(query_vectors)
            if self.is_decode_supported():
                xb_decoded, decode_t, _ = timer(
                    "sa_decode", lambda: codec.sa_decode(xb_encoded)
                )
                mse = np.square(xb_decoded - xb).sum(axis=1).mean().item()
                _, I = faiss.knn(xq, xb_decoded, k, metric=self.metric_type)
                asym_recall = knn_intersection_measure(I, I_gt)
                xq_decoded = codec.sa_decode(codec.sa_encode(xq))
                _, I = faiss.knn(
                    xq_decoded, xb_decoded, k, metric=self.metric_type
                )
            else:
                mse = None
                asym_recall = None
                decode_t = None
                # assume hamming for sym
                xq_encoded = codec.sa_encode(xq)
                binary_index = faiss.IndexBinaryFlat(xq_encoded.shape[1] * 8)
                binary_index.add(xb_encoded)
                _, I = binary_index.search(xq_encoded, k)
            sym_recall = knn_intersection_measure(I, I_gt)
            P = {
                "encode_time": encode_t,
                "decode_time": decode_t,
                "mse": mse,
                "sym_recall": sym_recall,
                "asym_recall": asym_recall,
                "cpu": get_cpu_info(),
                "num_threads": self.num_threads,
                "index": self.get_index_name(),
                "codec": self.get_codec_name(),
                "factory": self.get_model_name(),
                "reconstruct_params": parameters,
                "construction_params": self.get_construction_params(),
                "codec_meta": codec_meta,
            }
            self.io.write_file(filename, ["P"], [P])
        logger.info("reconstruct: end")
        return P, None

    def get_range_search_name(
        self,
        search_parameters: Optional[Dict[str, int]],
        query_vectors: DatasetDescriptor,
        radius: Optional[float] = None,
    ):
        name = self.get_index_name()
        name += Index.param_dict_to_name(search_parameters)
        name += query_vectors.get_filename("q")
        if radius is not None:
            name += f"r_{int(radius * 1000)}."
        else:
            name += "r_auto."
        return name

    def range_search(
        self,
        dry_run,
        search_parameters: Optional[Dict[str, int]],
        query_vectors: DatasetDescriptor,
        radius: Optional[float] = None,
    ):
        logger.info("range_search: begin")
        if search_parameters is not None and search_parameters["snap"] == 1:
            query_vectors = self.snap(query_vectors)
        filename = (
            self.get_range_search_name(
                search_parameters, query_vectors, radius
            )
            + "zip"
        )
        if self.io.file_exist(filename):
            logger.info(f"Using cached results for {filename}")
            lims, D, I, R, P = self.io.read_file(
                filename, ["lims", "D", "I", "R", "P"]
            )
        else:
            if dry_run:
                return None, None, None, None, None, filename
            xq = self.io.get_dataset(query_vectors)
            index = self.get_index()
            Index.set_index_param_dict(index, search_parameters)
            if self.is_ivf():
                xqt, QD, QI, QP = self.knn_search_quantizer(
                    query_vectors, search_parameters["nprobe"]
                )
                index_ivf = faiss.extract_index_ivf(index)
                if index_ivf.parallel_mode != 2:
                    logger.info("Setting IVF parallel mode")
                    index_ivf.parallel_mode = 2
                (lims, D, I), t, repeat = timer(
                    "range_search_preassigned",
                    lambda: index_ivf.range_search_preassigned(
                        xqt, radius, QI, QD
                    ),
                )
            else:
                (lims, D, I), t, _ = timer(
                    "range_search", lambda: index.range_search(xq, radius)
                )
            if self.is_flat():
                R = D
            else:
                xb = self.io.get_dataset(self.database_vectors)
                R = refine_distances_range(
                    lims, D, I, xq, xb, self.metric_type
                )
            P = {
                "time": t,
                "radius": radius,
                "count": len(I),
            }
            if self.is_ivf():
                stats = faiss.cvar.indexIVF_stats
                P |= {
                    "quantizer": QP,
                    "nq": int(stats.nq // repeat),
                    "nlist": int(stats.nlist // repeat),
                    "ndis": int(stats.ndis // repeat),
                    "nheap_updates": int(stats.nheap_updates // repeat),
                    "quantization_time": int(
                        stats.quantization_time // repeat
                    ),
                    "search_time": int(stats.search_time // repeat),
                }
            self.io.write_file(
                filename, ["lims", "D", "I", "R", "P"], [lims, D, I, R, P]
            )
        P |= {
            "index": self.get_index_name(),
            "codec": self.get_codec_name(),
            "factory": self.get_model_name(),
            "construction_params": self.get_construction_params(),
            "search_params": search_parameters,
        }
        logger.info("range_search: end")
        return lims, D, I, R, P, None


# Common base for IndexFromCodec and IndexFromFactory,
# but not for the sub-indices of codec-based indices
# IndexFromQuantizer and IndexFromPreTransform, because
# they share the configuration of their parent IndexFromCodec
@dataclass
class Index(IndexBase):
    num_threads: int
    d: int
    metric: str
    database_vectors: DatasetDescriptor
    construction_params: List[Dict[str, int]]
    search_params: Dict[str, int]

    cached_codec: ClassVar[OrderedDict[str, faiss.Index]] = OrderedDict()
    cached_index: ClassVar[OrderedDict[str, faiss.Index]] = OrderedDict()

    def __post_init__(self):
        if isinstance(self.metric, str):
            if self.metric == "IP":
                self.metric_type = faiss.METRIC_INNER_PRODUCT
            elif self.metric == "L2":
                self.metric_type = faiss.METRIC_L2
            else:
                raise ValueError
        elif isinstance(self.metric, int):
            self.metric_type = self.metric
            if self.metric_type == faiss.METRIC_INNER_PRODUCT:
                self.metric = "IP"
            elif self.metric_type == faiss.METRIC_L2:
                self.metric = "L2"
            else:
                raise ValueError
        else:
            raise ValueError

    def supports_range_search(self):
        codec = self.get_codec()
        return type(codec) not in [
            faiss.IndexHNSWFlat,
            faiss.IndexIVFFastScan,
            faiss.IndexRefine,
            faiss.IndexPQ,
        ]

    def fetch_codec(self):
        raise NotImplementedError

    def get_codec(self):
        codec_name = self.get_codec_name()
        if codec_name not in Index.cached_codec:
            Index.cached_codec[codec_name], _, _ = self.fetch_codec()
            if len(Index.cached_codec) > 1:
                Index.cached_codec.popitem(last=False)
        return Index.cached_codec[codec_name]

    def get_index_name(self):
        name = self.get_codec_name()
        assert self.database_vectors is not None
        name += self.database_vectors.get_filename("xb")
        return name

    def fetch_index(self):
        index = self.get_codec()
        index.reset()
        assert index.ntotal == 0
        logger.info("Adding vectors to index")
        xb = self.io.get_dataset(self.database_vectors)
        if self.is_ivf() and not isinstance(index, faiss.IndexRefine):
            xbt, QD, QI, QP = self.knn_search_quantizer(
                query_vectors=self.database_vectors,
                k=1,
            )
            index_ivf = faiss.extract_index_ivf(index)
            if index_ivf.parallel_mode != 2:
                logger.info("Setting IVF parallel mode")
                index_ivf.parallel_mode = 2
            _, t, _ = timer(
                "add_preassigned",
                lambda: add_preassigned(index_ivf, xbt, QI.ravel()),
                once=True,
            )
        else:
            _, t, _ = timer(
                "add",
                lambda: index.add(xb),
                once=True,
            )
        # with preassigned add, the vectors land in the wrapped IVF index
        assert index.ntotal == xb.shape[0] or index_ivf.ntotal == xb.shape[0]
        logger.info("Added vectors to index")
        return index, t

    def get_index(self):
        index_name = self.get_index_name()
        if index_name not in Index.cached_index:
            Index.cached_index[index_name], _ = self.fetch_index()
            if len(Index.cached_index) > 3:
                Index.cached_index.popitem(last=False)
        return Index.cached_index[index_name]

    def get_construction_params(self):
        return self.construction_params

    def get_code_size(self, codec=None):
        def get_index_code_size(index):
            index = faiss.downcast_index(index)
            if isinstance(index, faiss.IndexPreTransform):
                return get_index_code_size(index.index)
            elif type(index) in [faiss.IndexRefine, faiss.IndexRefineFlat]:
                return get_index_code_size(
                    index.base_index
                ) + get_index_code_size(index.refine_index)
            else:
                return index.code_size if hasattr(index, "code_size") else 0

        if codec is None:
            codec = self.get_codec()
        return get_index_code_size(codec)

    def get_sa_code_size(self, codec=None):
        if codec is None:
            codec = self.get_codec()
        try:
            return codec.sa_code_size()
        except Exception:
            return None
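
    # Example (illustrative): for an "OPQ32,IVF1024,PQ32" codec, code_size
    # reports the 32 bytes of PQ code stored per vector, while sa_code_size()
    # also counts the coarse list id that the standalone sa_encode() interface
    # prepends, e.g. 32 + 2 bytes for nlist=1024, i.e. 34.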

    def get_operating_points(self):
        op = OperatingPointsWithRanges()

        def add_range_or_val(name, range):
            op.add_range(
                name,
                [self.search_params[name]]
                if self.search_params and name in self.search_params
                else range,
            )

        add_range_or_val("snap", [0])
        model = self.get_model()
        model_ivf = faiss.try_extract_index_ivf(model)
        if model_ivf is not None:
            add_range_or_val(
                "nprobe",
                [2**i for i in range(12) if 2**i <= model_ivf.nlist * 0.5],
                # [1, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28] + [
                #     i
                #     for i in range(32, 64, 8)
                #     if i <= model_ivf.nlist * 0.1
                # ] + [
                #     i
                #     for i in range(64, 128, 16)
                #     if i <= model_ivf.nlist * 0.1
                # ] + [
                #     i
                #     for i in range(128, 256, 32)
                #     if i <= model_ivf.nlist * 0.1
                # ] + [
                #     i
                #     for i in range(256, 512, 64)
                #     if i <= model_ivf.nlist * 0.1
                # ] + [
                #     2**i
                #     for i in range(9, 12)
                #     if 2**i <= model_ivf.nlist * 0.1
                # ],
            )
        model = faiss.downcast_index(model)
        if isinstance(model, faiss.IndexRefine):
            add_range_or_val(
                "k_factor",
                [2**i for i in range(13)],
            )
        elif isinstance(model, faiss.IndexHNSWFlat):
            add_range_or_val(
                "efSearch",
                [2**i for i in range(3, 11)],
            )
        elif isinstance(model, faiss.IndexResidualQuantizer) or isinstance(
            model, faiss.IndexProductResidualQuantizer
        ):
            add_range_or_val(
                "max_beam_size",
                [1, 2, 4, 8, 16, 32],
            )
            add_range_or_val(
                "use_beam_LUT",
                [1],
            )
        elif isinstance(model, faiss.IndexLocalSearchQuantizer) or isinstance(
            model, faiss.IndexProductLocalSearchQuantizer
        ):
            add_range_or_val(
                "encode_ils_iters",
                [2, 4, 8, 16],
            )
            add_range_or_val(
                "lsq_gpu",
                [1],
            )
        return op


# IndexFromCodec, IndexFromQuantizer and IndexFromPreTransform
# are used to wrap pre-trained Faiss indices (codecs)
@dataclass
class IndexFromCodec(Index):
    path: str
    bucket: Optional[str] = None

    def get_quantizer(self):
        if not self.is_ivf():
            raise ValueError("Not an IVF index")
        quantizer = IndexFromQuantizer(self)
        quantizer.set_io(self.io)
        return quantizer

    def get_pretransform(self):
        if not self.is_ivf():
            raise ValueError("Not an IVF index")
        pretransform = IndexFromPreTransform(self)
        pretransform.set_io(self.io)
        return pretransform

    def get_model_name(self):
        return os.path.basename(self.path)

    def get_codec_name(self):
        assert self.path is not None
        name = os.path.basename(self.path)
        name += Index.param_dict_list_to_name(self.construction_params)
        return name

    def fetch_codec(self):
        codec = self.io.read_index(
            os.path.basename(self.path),
            self.bucket,
            os.path.dirname(self.path),
        )
        assert self.d == codec.d
        assert self.metric_type == codec.metric_type
        Index.set_index_param_dict_list(codec, self.construction_params)
        return codec, None, None

    def get_model(self):
        return self.get_codec()


class IndexFromQuantizer(IndexBase):
    ivf_index: Index

    def __init__(self, ivf_index: Index):
        self.ivf_index = ivf_index
        super().__init__()

    def get_model_name(self):
        return self.get_index_name()

    def get_codec_name(self):
        return self.get_index_name()

    def get_codec(self):
        return self.get_index()

    def get_index_name(self):
        ivf_codec_name = self.ivf_index.get_codec_name()
        return f"{ivf_codec_name}quantizer."

    def get_index(self):
        ivf_codec = faiss.extract_index_ivf(self.ivf_index.get_codec())
        return ivf_codec.quantizer


class IndexFromPreTransform(IndexBase):
    pre_transform_index: Index

    def __init__(self, pre_transform_index: Index):
        self.pre_transform_index = pre_transform_index
        super().__init__()

    def get_codec_name(self):
        pre_transform_codec_name = self.pre_transform_index.get_codec_name()
        return f"{pre_transform_codec_name}pretransform."

    def get_codec(self):
        # the codec of the parent index, which is an IndexPreTransform
        return self.pre_transform_index.get_codec()
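
# Example (illustrative, hypothetical path and `benchmark_io`): wrapping an
# already trained index stored on disk, then reusing the generic machinery:
#
#   index = IndexFromCodec(
#       num_threads=32,
#       d=96,
#       metric="L2",
#       database_vectors=DatasetDescriptor(tablename="xb.npy"),
#       construction_params=None,
#       search_params=None,
#       path="/checkpoints/my_index.faissindex",
#       bucket=None,
#   )
#   index.set_io(benchmark_io)
#   quantizer = index.get_quantizer()  # IVF-only, benchmarkable on its own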


# IndexFromFactory is for creating and training indices from scratch
@dataclass
class IndexFromFactory(Index):
    factory: str
    training_vectors: DatasetDescriptor

    def get_codec_name(self):
        assert self.factory is not None
        name = f"{self.factory.replace(',', '_')}."
        assert self.d is not None
        assert self.metric is not None
        name += f"d_{self.d}.{self.metric.upper()}."
        if self.factory != "Flat":
            assert self.training_vectors is not None
            name += self.training_vectors.get_filename("xt")
        name += Index.param_dict_list_to_name(self.construction_params)
        return name

    def fetch_meta(self, dry_run=False):
        meta_filename = self.get_codec_name() + "json"
        if self.io.file_exist(meta_filename):
            meta = self.io.read_json(meta_filename)
            report = None
        else:
            _, meta, report = self.fetch_codec(dry_run=dry_run)
        return meta, report

    def fetch_codec(self, dry_run=False):
        codec_filename = self.get_codec_name() + "codec"
        meta_filename = self.get_codec_name() + "json"
        if self.io.file_exist(codec_filename) and self.io.file_exist(
            meta_filename
        ):
            codec = self.io.read_index(codec_filename)
            assert self.d == codec.d
            assert self.metric_type == codec.metric_type
            meta = self.io.read_json(meta_filename)
        else:
            codec, training_time, requires = self.assemble(dry_run=dry_run)
            if requires is not None:
                assert dry_run
                if requires == "":
                    return None, None, codec_filename
                else:
                    return None, None, requires
            codec_size = self.io.write_index(codec, codec_filename)
            assert codec_size is not None
            meta = {
                "training_time": training_time,
                "training_size": self.training_vectors.num_vectors,
                "codec_size": codec_size,
                "sa_code_size": self.get_sa_code_size(codec),
                "code_size": self.get_code_size(codec),
                "cpu": get_cpu_info(),
            }
            self.io.write_json(meta, meta_filename, overwrite=True)
        Index.set_index_param_dict_list(
            codec, self.construction_params, assert_same=True
        )
        return codec, meta, None

    def get_model_name(self):
        return self.factory

    def get_model(self):
        model = faiss.index_factory(self.d, self.factory, self.metric_type)
        Index.set_index_param_dict_list(model, self.construction_params)
        return model

    def get_pretransform(self):
        model = self.get_model()
        assert isinstance(model, faiss.IndexPreTransform)
        sub_index = faiss.downcast_index(model.index)
        if isinstance(sub_index, faiss.IndexFlat):
            return self
        # replace the sub-index with Flat
        model.index = faiss.IndexFlat(model.index.d, model.index.metric_type)
        pretransform = IndexFromFactory(
            num_threads=self.num_threads,
            d=model.d,
            metric=model.metric_type,
            database_vectors=self.database_vectors,
            construction_params=self.construction_params,
            search_params=None,
            factory=reverse_index_factory(model),
            training_vectors=self.training_vectors,
        )
        pretransform.set_io(self.io)
        return pretransform

    def get_quantizer(self, dry_run, pretransform=None):
        model = self.get_model()
        model_ivf = faiss.extract_index_ivf(model)
        assert isinstance(model_ivf, faiss.IndexIVF)
        assert ord(model_ivf.quantizer_trains_alone) in [0, 2]
        if pretransform is None:
            training_vectors = self.training_vectors
        else:
            training_vectors = pretransform.transform(self.training_vectors)
        centroids, t, requires = training_vectors.k_means(
            self.io, model_ivf.nlist, dry_run
        )
        if requires is not None:
            return None, None, requires
        quantizer = IndexFromFactory(
            num_threads=self.num_threads,
            d=model_ivf.quantizer.d,
            metric=model_ivf.quantizer.metric_type,
            database_vectors=centroids,
            construction_params=self.construction_params[1:]
            if self.construction_params is not None
            else None,
            search_params=None,
            factory=reverse_index_factory(model_ivf.quantizer),
            training_vectors=centroids,
        )
        quantizer.set_io(self.io)
        return quantizer, t, None

    def assemble(self, dry_run):
        logger.info(f"assemble {self.factory}")
        model = self.get_model()
        # The decomposition path below is currently disabled: opaque is
        # hardcoded to True, so the model is always trained as a whole.
        opaque = True
        t_aggregate = 0
        # try:
        #     reverse_index_factory(model)
        #     opaque = False
        # except NotImplementedError:
        #     opaque = True
        if opaque:
            codec = model
        else:
            if isinstance(model, faiss.IndexPreTransform):
                logger.info(f"assemble: pretransform {self.factory}")
                sub_index = faiss.downcast_index(model.index)
                if not isinstance(sub_index, faiss.IndexFlat):
                    # replace the sub-index with Flat and fetch pre-trained
                    pretransform = self.get_pretransform()
                    codec, meta, report = pretransform.fetch_codec(
                        dry_run=dry_run
                    )
                    if report is not None:
                        return None, None, report
                    t_aggregate += meta["training_time"]
                    assert codec.is_trained
                    transformed_training_vectors = pretransform.transform(
                        self.training_vectors
                    )
                    # replace the Flat index with the required sub-index
                    wrapper = IndexFromFactory(
                        num_threads=self.num_threads,
                        d=sub_index.d,
                        metric=sub_index.metric_type,
                        database_vectors=None,
                        construction_params=self.construction_params,
                        search_params=None,
                        factory=reverse_index_factory(sub_index),
                        training_vectors=transformed_training_vectors,
                    )
                    wrapper.set_io(self.io)
                    codec.index, meta, report = wrapper.fetch_codec(
                        dry_run=dry_run
                    )
                    if report is not None:
                        return None, None, report
                    t_aggregate += meta["training_time"]
                    assert codec.index.is_trained
                else:
                    codec = model
            elif isinstance(model, faiss.IndexIVF):
                logger.info(f"assemble: ivf {self.factory}")
                # replace the quantizer
                quantizer, t, requires = self.get_quantizer(dry_run=dry_run)
                if requires is not None:
                    return None, None, requires
                t_aggregate += t
                codec = faiss.clone_index(model)
                quantizer_index, t = quantizer.fetch_index()
                t_aggregate += t
                replace_ivf_quantizer(codec, quantizer_index)
                assert codec.quantizer.is_trained
                assert codec.nlist == codec.quantizer.ntotal
            elif isinstance(model, faiss.IndexRefine) or isinstance(
                model, faiss.IndexRefineFlat
            ):
                logger.info(f"assemble: refine {self.factory}")
                # replace base_index
                wrapper = IndexFromFactory(
                    num_threads=self.num_threads,
                    d=model.base_index.d,
                    metric=model.base_index.metric_type,
                    database_vectors=self.database_vectors,
                    construction_params=IndexBase.filter_index_param_dict_list(
                        self.construction_params
                    ),
                    search_params=None,
                    factory=reverse_index_factory(model.base_index),
                    training_vectors=self.training_vectors,
                )
                wrapper.set_io(self.io)
                codec = faiss.clone_index(model)
                codec.base_index, meta, requires = wrapper.fetch_codec(
                    dry_run=dry_run
                )
                if requires is not None:
                    return None, None, requires
                t_aggregate += meta["training_time"]
                assert codec.base_index.is_trained
            else:
                codec = model

        if self.factory != "Flat":
            if dry_run:
                return None, None, ""
            logger.info(f"assemble, train {self.factory}")
            xt = self.io.get_dataset(self.training_vectors)
            _, t, _ = timer("train", lambda: codec.train(xt), once=True)
            t_aggregate += t

        return codec, t_aggregate, None
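

# Example (illustrative sketch, hypothetical file names): fetch_codec()
# caches two artifacts per codec, a serialized index and a JSON sidecar
# with training stats, both keyed by the codec name:
#
#   index = IndexFromFactory(..., factory="OPQ32,IVF1024,PQ32", ...)
#   index.set_io(benchmark_io)
#   codec, meta, _ = index.fetch_codec()
#   # writes "OPQ32_IVF1024_PQ32.d_96.L2.<training filename>codec"
#   #    and "OPQ32_IVF1024_PQ32.d_96.L2.<training filename>json"
#   print(meta["training_time"], meta["code_size"])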