# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import sys

import faiss
import numpy as np
from tqdm import tqdm, trange

from faiss.contrib.ondisk import merge_ondisk
from faiss.contrib.big_batch_search import big_batch_search
from faiss.contrib.exhaustive_search import knn_ground_truth
from faiss.contrib.evaluation import knn_intersection_measure
from utils import (
    get_intersection_cardinality_frequencies,
    margin,
    is_pretransform_index,
)
from dataset import create_dataset_from_oivf_config

logging.basicConfig(
    format=(
        "%(asctime)s.%(msecs)03d %(levelname)-8s %(threadName)-12s %(message)s"
    ),
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
    force=True,
)

EMBEDDINGS_BATCH_SIZE: int = 100_000
NUM_SUBSAMPLES: int = 100
SMALL_DATA_SAMPLE: int = 10_000


class OfflineIVF:
    def __init__(self, cfg, args, nprobe, index_factory_str):
        self.input_d = cfg["d"]
        self.dt = cfg["datasets"][args.xb]["files"][0]["dtype"]
        assert self.input_d > 0
        output_dir = cfg["output"]
        assert os.path.exists(output_dir)
        self.index_factory = index_factory_str
        assert self.index_factory is not None
        self.index_factory_fn = self.index_factory.replace(",", "_")
        self.index_template_file = (
            f"{output_dir}/{args.xb}/{self.index_factory_fn}.empty.faissindex"
        )
        logging.info(f"index template: {self.index_template_file}")

        if not args.xq:
            args.xq = args.xb

        self.by_residual = True
        if args.no_residuals:
            self.by_residual = False

        xb_output_dir = f"{output_dir}/{args.xb}"
        if not os.path.exists(xb_output_dir):
            os.makedirs(xb_output_dir)
        xq_output_dir = f"{output_dir}/{args.xq}"
        if not os.path.exists(xq_output_dir):
            os.makedirs(xq_output_dir)
        search_output_dir = f"{output_dir}/{args.xq}_in_{args.xb}"
        if not os.path.exists(search_output_dir):
            os.makedirs(search_output_dir)
        self.knn_dir = f"{search_output_dir}/knn"
        if not os.path.exists(self.knn_dir):
            os.makedirs(self.knn_dir)
        self.eval_dir = f"{search_output_dir}/eval"
        if not os.path.exists(self.eval_dir):
            os.makedirs(self.eval_dir)
        self.index = {}  # to keep a reference to opened indices,
        self.ivls = {}  # hstack inverted lists,
        self.index_shards = {}  # and index shards
        self.index_shard_prefix = (
            f"{xb_output_dir}/{self.index_factory_fn}.shard_"
        )
        self.xq_index_shard_prefix = (
            f"{xq_output_dir}/{self.index_factory_fn}.shard_"
        )
        # TODO: added back temporarily for evaluate, handle name of
        # non-sharded index file and remove.
        self.index_file = (
            f"{xb_output_dir}/{self.index_factory_fn}.faissindex"
        )
        self.xq_index_file = (
            f"{xq_output_dir}/{self.index_factory_fn}.faissindex"
        )
        self.training_sample = cfg["training_sample"]
        self.evaluation_sample = cfg["evaluation_sample"]
        self.xq_ds = create_dataset_from_oivf_config(cfg, args.xq)
        self.xb_ds = create_dataset_from_oivf_config(cfg, args.xb)
        file_descriptors = self.xq_ds.file_descriptors
        self.file_sizes = [fd.size for fd in file_descriptors]
        self.shard_size = cfg["index_shard_size"]  # ~100GB
        self.nshards = self.xb_ds.size // self.shard_size
        if self.xb_ds.size % self.shard_size != 0:
            self.nshards += 1
        self.xq_nshards = self.xq_ds.size // self.shard_size
        if self.xq_ds.size % self.shard_size != 0:
            self.xq_nshards += 1
        self.nprobe = nprobe
        assert self.nprobe > 0, "Invalid nprobe parameter."
if "deduper" in cfg: self.deduper = cfg["deduper"] self.deduper_codec_fn = [ f"{xb_output_dir}/deduper_codec_{codec.replace(',', '_')}" for codec in self.deduper ] self.deduper_idx_fn = [ f"{xb_output_dir}/deduper_idx_{codec.replace(',', '_')}" for codec in self.deduper ] else: self.deduper = None self.k = cfg["k"] assert self.k > 0, "Invalid number of neighbours parameter." self.knn_output_file_suffix = ( f"{self.index_factory_fn}_np{self.nprobe}.npy" ) fp = 32 if self.dt == "float16": fp = 16 self.xq_bs = cfg["query_batch_size"] if "metric" in cfg: self.metric = eval(f'faiss.{cfg["metric"]}') else: self.metric = faiss.METRIC_L2 if "evaluate_by_margin" in cfg: self.evaluate_by_margin = cfg["evaluate_by_margin"] else: self.evaluate_by_margin = False os.system("grep -m1 'model name' < /proc/cpuinfo") os.system("grep -E 'MemTotal|MemFree' /proc/meminfo") os.system("nvidia-smi") os.system("nvcc --version") self.knn_queries_memory_limit = 4 * 1024 * 1024 * 1024 # 4 GB self.knn_vectors_memory_limit = 8 * 1024 * 1024 * 1024 # 8 GB def input_stats(self): """ Trains the index using a subsample of the first chunk of data in the database and saves it in the template file (with no vectors added). """ xb_sample = self.xb_ds.get_first_n(self.training_sample, np.float32) logging.info(f"input shape: {xb_sample.shape}") logging.info("running MatrixStats on training sample...") logging.info(faiss.MatrixStats(xb_sample).comments) logging.info("done") def dedupe(self): logging.info(self.deduper) if self.deduper is None: logging.info("No deduper configured") return codecs = [] codesets = [] idxs = [] for factory, filename in zip(self.deduper, self.deduper_codec_fn): if os.path.exists(filename): logging.info(f"loading trained dedupe codec: {filename}") codec = faiss.read_index(filename) else: logging.info(f"training dedupe codec: {factory}") codec = faiss.index_factory(self.input_d, factory) xb_sample = np.unique( self.xb_ds.get_first_n(100_000, np.float32), axis=0 ) faiss.ParameterSpace().set_index_parameter(codec, "verbose", 1) codec.train(xb_sample) logging.info(f"writing trained dedupe codec: {filename}") faiss.write_index(codec, filename) codecs.append(codec) codesets.append(faiss.CodeSet(codec.sa_code_size())) idxs.append(np.empty((0,), dtype=np.uint32)) bs = 1_000_000 i = 0 for buffer in tqdm(self._iterate_transformed(self.xb_ds, 0, bs, np.float32)): for j in range(len(codecs)): codec, codeset, idx = codecs[j], codesets[j], idxs[j] uniq = codeset.insert(codec.sa_encode(buffer)) idxs[j] = np.append( idx, np.arange(i, i + buffer.shape[0], dtype=np.uint32)[uniq], ) i += buffer.shape[0] for idx, filename in zip(idxs, self.deduper_idx_fn): logging.info(f"writing {filename}, shape: {idx.shape}") np.save(filename, idx) logging.info("done") def train_index(self): """ Trains the index using a subsample of the first chunk of data in the database and saves it in the template file (with no vectors added). """ assert not os.path.exists(self.index_template_file), ( "The train command has been ran, the index template file already" " exists." 
        )
        xb_sample = np.unique(
            self.xb_ds.get_first_n(self.training_sample, np.float32), axis=0
        )
        logging.info(f"input shape: {xb_sample.shape}")
        index = faiss.index_factory(
            self.input_d, self.index_factory, self.metric
        )
        index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
        index_ivf.by_residual = True
        faiss.ParameterSpace().set_index_parameter(index, "verbose", 1)
        logging.info("running training...")
        index.train(xb_sample)
        logging.info(f"writing trained index {self.index_template_file}...")
        faiss.write_index(index, self.index_template_file)
        logging.info("done")

    def _iterate_transformed(self, ds, start, batch_size, dt):
        assert os.path.exists(self.index_template_file)
        index = faiss.read_index(self.index_template_file)
        if is_pretransform_index(index):
            vt = index.chain.at(0)  # fetch pretransform
            for buffer in ds.iterate(start, batch_size, dt):
                yield vt.apply(buffer)
        else:
            for buffer in ds.iterate(start, batch_size, dt):
                yield buffer

    def index_shard(self):
        """
        Adds the database vectors to the trained template index, one shard of
        shard_size vectors at a time, and writes each shard to its own file.
        List assignment is done with a GPU-replicated quantizer.
        """
        assert os.path.exists(self.index_template_file)
        index = faiss.read_index(self.index_template_file)
        index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
        assert self.nprobe <= index_ivf.quantizer.ntotal, (
            f"the quantizer has only {index_ivf.quantizer.ntotal} centroids,"
            f" fewer than nprobe={self.nprobe}, check."
        )
        cpu_quantizer = index_ivf.quantizer
        gpu_quantizer = faiss.index_cpu_to_all_gpus(cpu_quantizer)

        for i in range(0, self.nshards):
            sfn = f"{self.index_shard_prefix}{i}"
            try:
                index.reset()
                index_ivf.quantizer = gpu_quantizer
                # "xb" mode: an existing shard raises FileExistsError and the
                # shard is skipped below
                with open(sfn, "xb"):
                    start = i * self.shard_size
                    jj = 0
                    embeddings_batch_size = min(
                        EMBEDDINGS_BATCH_SIZE, self.shard_size
                    )
                    assert (
                        self.shard_size % embeddings_batch_size == 0
                        or EMBEDDINGS_BATCH_SIZE % embeddings_batch_size == 0
                    ), (
                        f"the shard size {self.shard_size} and embeddings"
                        f" batch size {EMBEDDINGS_BATCH_SIZE} do not divide"
                        " evenly"
                    )

                    for xb_j in tqdm(
                        self._iterate_transformed(
                            self.xb_ds,
                            start,
                            embeddings_batch_size,
                            np.float32,
                        ),
                        file=sys.stdout,
                    ):
                        if is_pretransform_index(index):
                            assert xb_j.shape[1] == index.chain.at(0).d_out
                            index_ivf.add_with_ids(
                                xb_j,
                                np.arange(
                                    start + jj, start + jj + xb_j.shape[0]
                                ),
                            )
                        else:
                            assert xb_j.shape[1] == index.d
                            index.add_with_ids(
                                xb_j,
                                np.arange(
                                    start + jj, start + jj + xb_j.shape[0]
                                ),
                            )
                        jj += xb_j.shape[0]
                        logging.info(jj)
                        assert (
                            jj <= self.shard_size
                        ), f"jj {jj} and shard_size {self.shard_size}"
                        if jj == self.shard_size:
                            break
                    logging.info(f"writing {sfn}...")
                    index_ivf.quantizer = cpu_quantizer
                    faiss.write_index(index, sfn)
            except FileExistsError:
                logging.info(f"skipping shard: {i}")
                continue
        logging.info("done")

    def merge_index(self):
        """
        Merges the inverted lists of all shard files into a single on-disk
        index (merge_ondisk) and writes the merged index file.
        """
        ivf_file = f"{self.index_file}.ivfdata"

        assert os.path.exists(self.index_template_file)
        assert not os.path.exists(ivf_file), (
            f"file with inverted list data {ivf_file} already exists, remove"
            " it to re-run the merge."
        )
        assert not os.path.exists(self.index_file)
        index = faiss.read_index(self.index_template_file)
        block_fnames = [
            f"{self.index_shard_prefix}{i}" for i in range(self.nshards)
        ]
        for fn in block_fnames:
            assert os.path.exists(fn)
        logging.info(block_fnames)
        logging.info("merging...")
        merge_ondisk(index, block_fnames, ivf_file)
        logging.info("writing index...")
        faiss.write_index(index, self.index_file)
        logging.info("done")

    def _cached_search(
        self,
        sample,
        xq_ds,
        xb_ds,
        idx_file,
        vecs_file,
        I_file,
        D_file,
        index_file=None,
        nprobe=None,
    ):
        """
        Runs a search (exact ground truth when index_file is None, otherwise
        a search on the given index) and caches the resulting I/D arrays as
        .npy files; on subsequent calls the cached arrays are loaded instead.
        """
        if not os.path.exists(I_file):
            assert not os.path.exists(
                D_file
            ), f"file {D_file} already exists but {I_file} does not, check."
            xq = xq_ds.sample(sample, idx_file, vecs_file)

            if index_file:
                D, I = self._index_nonsharded_search(index_file, xq, nprobe)
            else:
                logging.info("ground truth computations")
                db_iterator = xb_ds.iterate(0, 100_000, np.float32)
                D, I = knn_ground_truth(
                    xq, db_iterator, self.k, metric_type=self.metric
                )
                assert np.amin(I) >= 0

            np.save(I_file, I)
            np.save(D_file, D)
        else:
            assert os.path.exists(idx_file), f"file {idx_file} does not exist "
            assert os.path.exists(
                vecs_file
            ), f"file {vecs_file} does not exist "
            assert os.path.exists(I_file), f"file {I_file} does not exist "
            assert os.path.exists(D_file), f"file {D_file} does not exist "
            I = np.load(I_file)
            D = np.load(D_file)
        assert I.shape == (sample, self.k), f"{I_file} shape mismatch"
        assert D.shape == (sample, self.k), f"{D_file} shape mismatch"
        return (D, I)

    def _index_search(self, index_shard_prefix, xq, nprobe):
        assert nprobe is not None
        logging.info(
            f"open sharded index: {index_shard_prefix}, {self.nshards}"
        )
        index = self._open_sharded_index(index_shard_prefix)
        index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
        logging.info(f"setting nprobe to {nprobe}")
        index_ivf.nprobe = nprobe
        return index.search(xq, self.k)

    def _index_nonsharded_search(self, index_file, xq, nprobe):
        assert nprobe is not None
        logging.info(f"index {index_file}")
        assert os.path.exists(index_file), f"file {index_file} does not exist "
        index = faiss.read_index(index_file, faiss.IO_FLAG_ONDISK_SAME_DIR)
        logging.info(f"index size {index.ntotal} ")
        index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
        logging.info(f"setting nprobe to {nprobe}")
        index_ivf.nprobe = nprobe
        return index.search(xq, self.k)

    def _refine_distances(self, xq_ds, idx, xb_ds, I):
        xq = xq_ds.get(idx).repeat(self.k, axis=0)
        xb = xb_ds.get(I.reshape(-1))
        if self.metric == faiss.METRIC_INNER_PRODUCT:
            return (xq * xb).sum(axis=1).reshape(I.shape)
        elif self.metric == faiss.METRIC_L2:
            return ((xq - xb) ** 2).sum(axis=1).reshape(I.shape)
        else:
            raise ValueError(f"metric not supported {self.metric}")

    def evaluate(self):
        self._evaluate(
            self.index_factory_fn,
            self.index_file,
            self.xq_index_file,
            self.nprobe,
        )

    def _evaluate(self, index_factory_fn, index_file, xq_index_file, nprobe):
        idx_a_file = f"{self.eval_dir}/idx_a.npy"
        idx_b_gt_file = f"{self.eval_dir}/idx_b_gt.npy"
        idx_b_ann_file = (
            f"{self.eval_dir}/idx_b_ann_{index_factory_fn}_np{nprobe}.npy"
        )
        vecs_a_file = f"{self.eval_dir}/vecs_a.npy"
        vecs_b_gt_file = f"{self.eval_dir}/vecs_b_gt.npy"
        vecs_b_ann_file = (
            f"{self.eval_dir}/vecs_b_ann_{index_factory_fn}_np{nprobe}.npy"
        )
        D_a_gt_file = f"{self.eval_dir}/D_a_gt.npy"
        D_a_ann_file = (
            f"{self.eval_dir}/D_a_ann_{index_factory_fn}_np{nprobe}.npy"
        )
        D_a_ann_refined_file = (
            f"{self.eval_dir}/D_a_ann_refined_{index_factory_fn}_np{nprobe}.npy"
        )
        D_b_gt_file = f"{self.eval_dir}/D_b_gt.npy"
        D_b_ann_file = (
f"{self.eval_dir}/D_b_ann_{index_factory_fn}_np{nprobe}.npy" ) D_b_ann_gt_file = ( f"{self.eval_dir}/D_b_ann_gt_{index_factory_fn}_np{nprobe}.npy" ) I_a_gt_file = f"{self.eval_dir}/I_a_gt.npy" I_a_ann_file = ( f"{self.eval_dir}/I_a_ann_{index_factory_fn}_np{nprobe}.npy" ) I_b_gt_file = f"{self.eval_dir}/I_b_gt.npy" I_b_ann_file = ( f"{self.eval_dir}/I_b_ann_{index_factory_fn}_np{nprobe}.npy" ) I_b_ann_gt_file = ( f"{self.eval_dir}/I_b_ann_gt_{index_factory_fn}_np{nprobe}.npy" ) margin_gt_file = f"{self.eval_dir}/margin_gt.npy" margin_refined_file = ( f"{self.eval_dir}/margin_refined_{index_factory_fn}_np{nprobe}.npy" ) margin_ann_file = ( f"{self.eval_dir}/margin_ann_{index_factory_fn}_np{nprobe}.npy" ) logging.info("exact search forward") # xq -> xb AKA a -> b D_a_gt, I_a_gt = self._cached_search( self.evaluation_sample, self.xq_ds, self.xb_ds, idx_a_file, vecs_a_file, I_a_gt_file, D_a_gt_file, ) idx_a = np.load(idx_a_file) logging.info("approximate search forward") D_a_ann, I_a_ann = self._cached_search( self.evaluation_sample, self.xq_ds, self.xb_ds, idx_a_file, vecs_a_file, I_a_ann_file, D_a_ann_file, index_file, nprobe, ) logging.info( "calculate refined distances on approximate search forward" ) if os.path.exists(D_a_ann_refined_file): D_a_ann_refined = np.load(D_a_ann_refined_file) assert D_a_ann.shape == D_a_ann_refined.shape else: D_a_ann_refined = self._refine_distances( self.xq_ds, idx_a, self.xb_ds, I_a_ann ) np.save(D_a_ann_refined_file, D_a_ann_refined) if self.evaluate_by_margin: k_extract = self.k margin_threshold = 1.05 logging.info( "exact search backward from the k_extract NN results of" " forward search" ) # xb -> xq AKA b -> a D_a_b_gt = D_a_gt[:, :k_extract].ravel() idx_b_gt = I_a_gt[:, :k_extract].ravel() assert len(idx_b_gt) == self.evaluation_sample * k_extract np.save(idx_b_gt_file, idx_b_gt) # exact search D_b_gt, _ = self._cached_search( len(idx_b_gt), self.xb_ds, self.xq_ds, idx_b_gt_file, vecs_b_gt_file, I_b_gt_file, D_b_gt_file, ) # xb and xq ^^^ are inverted logging.info("margin on exact search") margin_gt = margin( self.evaluation_sample, idx_a, idx_b_gt, D_a_b_gt, D_a_gt, D_b_gt, self.k, k_extract, margin_threshold, ) np.save(margin_gt_file, margin_gt) logging.info( "exact search backward from the k_extract NN results of" " approximate forward search" ) D_a_b_refined = D_a_ann_refined[:, :k_extract].ravel() idx_b_ann = I_a_ann[:, :k_extract].ravel() assert len(idx_b_ann) == self.evaluation_sample * k_extract np.save(idx_b_ann_file, idx_b_ann) # exact search D_b_ann_gt, _ = self._cached_search( len(idx_b_ann), self.xb_ds, self.xq_ds, idx_b_ann_file, vecs_b_ann_file, I_b_ann_gt_file, D_b_ann_gt_file, ) # xb and xq ^^^ are inverted logging.info("refined margin on approximate search") margin_refined = margin( self.evaluation_sample, idx_a, idx_b_ann, D_a_b_refined, D_a_gt, # not D_a_ann_refined(!) 
                D_b_ann_gt,
                self.k,
                k_extract,
                margin_threshold,
            )
            np.save(margin_refined_file, margin_refined)

            D_b_ann, I_b_ann = self._cached_search(
                len(idx_b_ann),
                self.xb_ds,
                self.xq_ds,
                idx_b_ann_file,
                vecs_b_ann_file,
                I_b_ann_file,
                D_b_ann_file,
                xq_index_file,
                nprobe,
            )
            D_a_b_ann = D_a_ann[:, :k_extract].ravel()

            logging.info("approximate search margin")
            margin_ann = margin(
                self.evaluation_sample,
                idx_a,
                idx_b_ann,
                D_a_b_ann,
                D_a_ann,
                D_b_ann,
                self.k,
                k_extract,
                margin_threshold,
            )
            np.save(margin_ann_file, margin_ann)

        logging.info("intersection")
        logging.info(I_a_gt)
        logging.info(I_a_ann)

        for i in range(1, self.k + 1):
            logging.info(
                f"{i}: {knn_intersection_measure(I_a_gt[:, :i], I_a_ann[:, :i])}"
            )

        logging.info(f"mean of gt distances: {D_a_gt.mean()}")
        logging.info(f"mean of approx distances: {D_a_ann.mean()}")
        logging.info(f"mean of refined distances: {D_a_ann_refined.mean()}")

        logging.info("intersection cardinality frequencies")
        logging.info(get_intersection_cardinality_frequencies(I_a_ann, I_a_gt))
        logging.info("done")

    def _knn_function(self, xq, xb, k, metric, thread_id=None):
        try:
            return faiss.knn_gpu(
                self.all_gpu_resources[thread_id],
                xq,
                xb,
                k,
                metric=metric,
                device=thread_id,
                vectorsMemoryLimit=self.knn_vectors_memory_limit,
                queriesMemoryLimit=self.knn_queries_memory_limit,
            )
        except Exception:
            logging.info(f"knn_function failed: {xq.shape}, {xb.shape}")
            raise

    def _coarse_quantize(self, index_ivf, xq, nprobe):
        assert nprobe <= index_ivf.quantizer.ntotal
        quantizer = faiss.index_cpu_to_all_gpus(index_ivf.quantizer)
        bs = 100_000
        nq = len(xq)
        q_assign = np.empty((nq, nprobe), dtype="int32")
        for i0 in trange(0, nq, bs):
            i1 = min(nq, i0 + bs)
            _, q_assign_i = quantizer.search(xq[i0:i1], nprobe)
            q_assign[i0:i1] = q_assign_i
        return q_assign

    def search(self):
        """
        Runs batched approximate search of the query dataset against the
        sharded index with big_batch_search on GPU, writing one pair of I/D
        .npy files per query batch under the knn output directory.
        """
        logging.info(f"search: {self.knn_dir}")
        slurm_job_id = os.environ.get("SLURM_JOB_ID")

        ngpu = faiss.get_num_gpus()
        logging.info(f"number of gpus: {ngpu}")
        self.all_gpu_resources = [
            faiss.StandardGpuResources() for _ in range(ngpu)
        ]
        # warm-up / sanity-check the GPU knn function before the main loop
        self._knn_function(
            np.zeros((10, 10), dtype=np.float16),
            np.zeros((10, 10), dtype=np.float16),
            self.k,
            metric=self.metric,
            thread_id=0,
        )

        index = self._open_sharded_index()
        index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
        logging.info(f"setting nprobe to {self.nprobe}")
        index_ivf.nprobe = self.nprobe
        # quantizer = faiss.index_cpu_to_all_gpus(index_ivf.quantizer)
        for i in range(0, self.xq_ds.size, self.xq_bs):
            Ifn = f"{self.knn_dir}/I{(i):010}_{self.knn_output_file_suffix}"
            Dfn = (
                f"{self.knn_dir}/D_approx{(i):010}_"
                f"{self.knn_output_file_suffix}"
            )
            CPfn = f"{self.knn_dir}/CP{(i):010}_{self.knn_output_file_suffix}"

            if slurm_job_id:
                worker_record = (
                    self.knn_dir
                    + f"/record_{(i):010}_{self.knn_output_file_suffix}.txt"
                )
                if not os.path.exists(worker_record):
                    logging.info(
                        f"creating record file {worker_record} and saving job"
                        f" id: {slurm_job_id}"
                    )
                    with open(worker_record, "w") as h:
                        h.write(slurm_job_id)
                else:
                    with open(worker_record, "r") as h:
                        old_slurm_id = h.read()
                    logging.info(
                        f"old job slurm id {old_slurm_id} and current job id:"
                        f" {slurm_job_id}"
                    )
                    if old_slurm_id == slurm_job_id:
                        if os.path.getsize(Ifn) == 0:
                            logging.info(
                                f"cleaning up zero length files {Ifn} and"
                                f" {Dfn}"
                            )
                            os.remove(Ifn)
                            os.remove(Dfn)

            try:
                if is_pretransform_index(index):
                    d = index.chain.at(0).d_out
                else:
                    d = self.input_d
                # "xb" mode: if the batch was already processed the open
                # raises FileExistsError and the batch is skipped below
                with open(Ifn, "xb") as f, open(Dfn, "xb") as g:
                    xq_i = np.empty(shape=(self.xq_bs, d), dtype=np.float16)
                    q_assign = np.empty(
                        (self.xq_bs, self.nprobe), dtype=np.int32
                    )
                    j = 0
                    # pre-assign the queries to inverted lists on GPU
                    quantizer = faiss.index_cpu_to_all_gpus(
                        index_ivf.quantizer
                    )
                    for xq_i_j in tqdm(
                        self._iterate_transformed(
                            self.xq_ds,
                            i,
                            min(100_000, self.xq_bs),
                            np.float16,
                        ),
                        file=sys.stdout,
                    ):
                        xq_i[j:j + xq_i_j.shape[0]] = xq_i_j
                        (
                            _,
                            q_assign[j:j + xq_i_j.shape[0]],
                        ) = quantizer.search(xq_i_j, self.nprobe)
                        j += xq_i_j.shape[0]
                        assert j <= xq_i.shape[0]
                        if j == xq_i.shape[0]:
                            break
                    xq_i = xq_i[:j]
                    q_assign = q_assign[:j]

                    assert q_assign.shape == (xq_i.shape[0], index_ivf.nprobe)
                    del quantizer
                    logging.info(f"computing: {Ifn}")
                    logging.info(f"computing: {Dfn}")
                    prefetch_threads = faiss.get_num_gpus()
                    D_ann, I = big_batch_search(
                        index_ivf,
                        xq_i,
                        self.k,
                        verbose=10,
                        method="knn_function",
                        knn=self._knn_function,
                        threaded=faiss.get_num_gpus() * 8,
                        use_float16=True,
                        prefetch_threads=prefetch_threads,
                        computation_threads=faiss.get_num_gpus(),
                        q_assign=q_assign,
                        checkpoint=CPfn,
                        checkpoint_freq=7200,  # in seconds
                    )
                    assert (
                        np.amin(I) >= 0
                    ), f"{I}, there exist negative indices, check"
                    logging.info(f"saving: {Ifn}")
                    np.save(f, I)
                    logging.info(f"saving: {Dfn}")
                    np.save(g, D_ann)
                    if os.path.exists(CPfn):
                        logging.info(f"removing: {CPfn}")
                        os.remove(CPfn)
            except FileExistsError:
                logging.info(f"skipping {Ifn}, already exists")
                logging.info(f"skipping {Dfn}, already exists")
                continue

    def _open_index_shard(self, fn):
        if fn in self.index_shards:
            index_shard = self.index_shards[fn]
        else:
            logging.info(f"open index shard: {fn}")
            index_shard = faiss.read_index(
                fn, faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY
            )
            self.index_shards[fn] = index_shard
        return index_shard

    def _open_sharded_index(self, index_shard_prefix=None):
        """
        Opens all index shards (memory-mapped) and exposes them through the
        template index by stacking their inverted lists into a single
        HStackInvertedLists view.
        """
        if index_shard_prefix is None:
            index_shard_prefix = self.index_shard_prefix
        if index_shard_prefix in self.index:
            return self.index[index_shard_prefix]
        assert os.path.exists(
            self.index_template_file
        ), f"file {self.index_template_file} does not exist "
        logging.info(f"open index template: {self.index_template_file}")
        index = faiss.read_index(self.index_template_file)
        index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
        ilv = faiss.InvertedListsPtrVector()
        for i in range(self.nshards):
            fn = f"{index_shard_prefix}{i}"
            assert os.path.exists(fn), f"file {fn} does not exist "
            logging.info(fn)
            index_shard = self._open_index_shard(fn)
            il = faiss.downcast_index(
                faiss.extract_index_ivf(index_shard)
            ).invlists
            ilv.push_back(il)
        hsil = faiss.HStackInvertedLists(ilv.size(), ilv.data())
        index_ivf.replace_invlists(hsil, False)
        self.ivls[index_shard_prefix] = hsil
        self.index[index_shard_prefix] = index
        return index

    def index_shard_stats(self):
        for i in range(self.nshards):
            fn = f"{self.index_shard_prefix}{i}"
            assert os.path.exists(fn)
            index = faiss.read_index(
                fn, faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY
            )
            index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
            il = index_ivf.invlists
            il.print_stats()

    def index_stats(self):
        index = self._open_sharded_index()
        index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
        il = index_ivf.invlists
        list_sizes = [il.list_size(i) for i in range(il.nlist)]
        logging.info(np.max(list_sizes))
        logging.info(np.mean(list_sizes))
        logging.info(np.argmax(list_sizes))
        logging.info("index_stats:")
        il.print_stats()

    def consistency_check(self):
        """
        Spot-checks that each shard and the merged index return the inserted
        vectors as their own nearest neighbours, and that the saved search
        results are consistent with an online search on a small sample.
        """
        logging.info("consistency-check")

        logging.info("index template...")
        assert os.path.exists(self.index_template_file)
        index = faiss.read_index(self.index_template_file)

        offset = 0  # 2**24
        assert self.shard_size > offset + SMALL_DATA_SAMPLE

        logging.info("index shards...")
        for i in range(self.nshards):
            r = i * self.shard_size + offset
            xb = next(self.xb_ds.iterate(r, SMALL_DATA_SAMPLE, np.float32))
            fn = f"{self.index_shard_prefix}{i}"
            assert os.path.exists(fn), f"There is no index shard file {fn}"
            index = self._open_index_shard(fn)
            index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
            index_ivf.nprobe = 1
            _, I = index.search(xb, 100)
            for j in range(SMALL_DATA_SAMPLE):
                assert np.where(I[j] == j + r)[0].size > 0, (
                    f"I[j]: {I[j]}, j: {j}, i: {i}, shard_size:"
                    f" {self.shard_size}"
                )

        logging.info("merged index...")
        index = self._open_sharded_index()
        index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
        index_ivf.nprobe = 1
        for i in range(self.nshards):
            r = i * self.shard_size + offset
            xb = next(self.xb_ds.iterate(r, SMALL_DATA_SAMPLE, np.float32))
            _, I = index.search(xb, 100)
            for j in range(SMALL_DATA_SAMPLE):
                assert np.where(I[j] == j + r)[0].size > 0, (
                    f"I[j]: {I[j]}, j: {j}, i: {i}, shard_size:"
                    f" {self.shard_size}"
                )

        logging.info("search results...")
        index_ivf.nprobe = self.nprobe
        for i in range(0, self.xq_ds.size, self.xq_bs):
            Ifn = (
                f"{self.knn_dir}/I{i:010}_"
                f"{self.index_factory_fn}_np{self.nprobe}.npy"
            )
            assert os.path.exists(Ifn)
            assert os.path.getsize(Ifn) > 0, f"The file {Ifn} is empty."
            logging.info(Ifn)
            I = np.load(Ifn, mmap_mode="r")

            assert I.shape[1] == self.k
            assert I.shape[0] == min(self.xq_bs, self.xq_ds.size - i)
            assert np.all(I[:, 1] >= 0)

            Dfn = (
                f"{self.knn_dir}/D_approx{i:010}_"
                f"{self.index_factory_fn}_np{self.nprobe}.npy"
            )
            assert os.path.exists(Dfn)
            assert os.path.getsize(Dfn) > 0, f"The file {Dfn} is empty."
            logging.info(Dfn)
            D = np.load(Dfn, mmap_mode="r")
            assert D.shape == I.shape

            xq = next(self.xq_ds.iterate(i, SMALL_DATA_SAMPLE, np.float32))
            D_online, I_online = index.search(xq, self.k)
            assert (
                np.where(I[:SMALL_DATA_SAMPLE] == I_online)[0].size
                / (self.k * SMALL_DATA_SAMPLE)
                > 0.95
            ), (
                "the ratio is"
                f" {np.where(I[:SMALL_DATA_SAMPLE] == I_online)[0].size / (self.k * SMALL_DATA_SAMPLE)}"
            )
            assert np.allclose(
                D[:SMALL_DATA_SAMPLE].sum(axis=1),
                D_online.sum(axis=1),
                rtol=0.01,
            ), (
                "the difference is"
                f" {D[:SMALL_DATA_SAMPLE].sum(axis=1), D_online.sum(axis=1)}"
            )
        logging.info("done")
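

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): OfflineIVF is normally driven by an
# external launcher that loads a config and CLI arguments and calls one
# method per pipeline stage. The snippet below is a hypothetical example,
# assuming `cfg` is a dict with the keys read in __init__ ("d", "output",
# "datasets", "training_sample", "evaluation_sample", "index_shard_size",
# "query_batch_size", "k") and `args` carries xb/xq/no_residuals; the factory
# string and nprobe value are placeholders, not recommendations.
#
#   import argparse
#
#   args = argparse.Namespace(xb="my_db", xq=None, no_residuals=False)
#   oivf = OfflineIVF(cfg, args, nprobe=64, index_factory_str="OPQ8,IVF1024,PQ8")
#   oivf.train_index()         # train the empty template index
#   oivf.index_shard()         # add database vectors, one shard file per chunk
#   oivf.merge_index()         # merge shard inverted lists on disk
#   oivf.search()              # big-batch GPU search, results saved under knn/
#   oivf.consistency_check()   # sanity-check shards, merged index and results
# ---------------------------------------------------------------------------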