# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import faiss
import numpy as np
import os
from tqdm import tqdm, trange
import sys
import logging
from faiss.contrib.ondisk import merge_ondisk
from faiss.contrib.big_batch_search import big_batch_search
from faiss.contrib.exhaustive_search import knn_ground_truth
from faiss.contrib.evaluation import knn_intersection_measure
from utils import (
    get_intersection_cardinality_frequencies,
    margin,
    is_pretransform_index,
)
from dataset import create_dataset_from_oivf_config

logging.basicConfig(
    format=(
        "%(asctime)s.%(msecs)03d %(levelname)-8s %(threadName)-12s %(message)s"
    ),
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
    force=True,
)

EMBEDDINGS_BATCH_SIZE: int = 100_000
NUM_SUBSAMPLES: int = 100
SMALL_DATA_SAMPLE: int = 10000


class OfflineIVF:
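    """
    Offline IVF pipeline over memory-mapped embedding datasets described by
    the config dict `cfg`: trains an index template, adds vectors to on-disk
    shards, merges or stacks the shards, runs large-batch k-NN search, and
    evaluates the results. Each public method implements one pipeline step.
    """
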
    def __init__(self, cfg, args, nprobe, index_factory_str):
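        """
        Reads dimensions, dataset descriptors and search parameters from
        `cfg` and the CLI `args`, creates the output directory layout, and
        opens the query (xq) and database (xb) datasets.
        """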
        self.input_d = cfg["d"]
        self.dt = cfg["datasets"][args.xb]["files"][0]["dtype"]
        assert self.input_d > 0
        output_dir = cfg["output"]
        assert os.path.exists(output_dir)
        self.index_factory = index_factory_str
        assert self.index_factory is not None
        self.index_factory_fn = self.index_factory.replace(",", "_")
        self.index_template_file = (
            f"{output_dir}/{args.xb}/{self.index_factory_fn}.empty.faissindex"
        )
        logging.info(f"index template: {self.index_template_file}")

        if not args.xq:
            args.xq = args.xb

        self.by_residual = True
        if args.no_residuals:
            self.by_residual = False

        xb_output_dir = f"{output_dir}/{args.xb}"
        if not os.path.exists(xb_output_dir):
            os.makedirs(xb_output_dir)
        xq_output_dir = f"{output_dir}/{args.xq}"
        if not os.path.exists(xq_output_dir):
            os.makedirs(xq_output_dir)
        search_output_dir = f"{output_dir}/{args.xq}_in_{args.xb}"
        if not os.path.exists(search_output_dir):
            os.makedirs(search_output_dir)
        self.knn_dir = f"{search_output_dir}/knn"
        if not os.path.exists(self.knn_dir):
            os.makedirs(self.knn_dir)
        self.eval_dir = f"{search_output_dir}/eval"
        if not os.path.exists(self.eval_dir):
            os.makedirs(self.eval_dir)
        self.index = {}  # to keep a reference to opened indices,
        self.ivls = {}  # hstack inverted lists,
        self.index_shards = {}  # and index shards
        self.index_shard_prefix = (
            f"{xb_output_dir}/{self.index_factory_fn}.shard_"
        )
        self.xq_index_shard_prefix = (
            f"{xq_output_dir}/{self.index_factory_fn}.shard_"
        )
        self.index_file = (  # TODO: added back temporarily for evaluate, handle name of non-sharded index file and remove.
            f"{xb_output_dir}/{self.index_factory_fn}.faissindex"
        )
        self.xq_index_file = (
            f"{xq_output_dir}/{self.index_factory_fn}.faissindex"
        )
        self.training_sample = cfg["training_sample"]
        self.evaluation_sample = cfg["evaluation_sample"]
        self.xq_ds = create_dataset_from_oivf_config(cfg, args.xq)
        self.xb_ds = create_dataset_from_oivf_config(cfg, args.xb)
        file_descriptors = self.xq_ds.file_descriptors
        self.file_sizes = [fd.size for fd in file_descriptors]
        self.shard_size = cfg["index_shard_size"]  # ~100GB
        self.nshards = self.xb_ds.size // self.shard_size
        if self.xb_ds.size % self.shard_size != 0:
            self.nshards += 1
        self.xq_nshards = self.xq_ds.size // self.shard_size
        if self.xq_ds.size % self.shard_size != 0:
            self.xq_nshards += 1
        self.nprobe = nprobe
        assert self.nprobe > 0, "Invalid nprobe parameter."
        if "deduper" in cfg:
            self.deduper = cfg["deduper"]
            self.deduper_codec_fn = [
                f"{xb_output_dir}/deduper_codec_{codec.replace(',', '_')}"
                for codec in self.deduper
            ]
            self.deduper_idx_fn = [
                f"{xb_output_dir}/deduper_idx_{codec.replace(',', '_')}"
                for codec in self.deduper
            ]
        else:
            self.deduper = None
        self.k = cfg["k"]
        assert self.k > 0, "Invalid number of neighbours parameter."
        self.knn_output_file_suffix = (
            f"{self.index_factory_fn}_np{self.nprobe}.npy"
        )

        fp = 32
        if self.dt == "float16":
            fp = 16

        self.xq_bs = cfg["query_batch_size"]
        if "metric" in cfg:
            self.metric = getattr(faiss, cfg["metric"])  # e.g. "METRIC_L2"
        else:
            self.metric = faiss.METRIC_L2

        if "evaluate_by_margin" in cfg:
            self.evaluate_by_margin = cfg["evaluate_by_margin"]
        else:
            self.evaluate_by_margin = False

        os.system("grep -m1 'model name' < /proc/cpuinfo")
        os.system("grep -E 'MemTotal|MemFree' /proc/meminfo")
        os.system("nvidia-smi")
        os.system("nvcc --version")

        self.knn_queries_memory_limit = 4 * 1024 * 1024 * 1024  # 4 GB
        self.knn_vectors_memory_limit = 8 * 1024 * 1024 * 1024  # 8 GB

    def input_stats(self):
        """
        Computes and logs MatrixStats on a training-sample-sized subset of
        the database vectors (no index is trained or written here).
        """
        xb_sample = self.xb_ds.get_first_n(self.training_sample, np.float32)
        logging.info(f"input shape: {xb_sample.shape}")
        logging.info("running MatrixStats on training sample...")
        logging.info(faiss.MatrixStats(xb_sample).comments)
        logging.info("done")

    def dedupe(self):
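        """
        Deduplicates the database: for each configured dedupe codec, trains
        (or loads) the codec, encodes the database in batches of 1M vectors,
        and saves the indices of the unique vectors to one .npy file per
        codec.
        """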
        logging.info(self.deduper)
        if self.deduper is None:
            logging.info("No deduper configured")
            return
        codecs = []
        codesets = []
        idxs = []
        for factory, filename in zip(self.deduper, self.deduper_codec_fn):
            if os.path.exists(filename):
                logging.info(f"loading trained dedupe codec: {filename}")
                codec = faiss.read_index(filename)
            else:
                logging.info(f"training dedupe codec: {factory}")
                codec = faiss.index_factory(self.input_d, factory)
                xb_sample = np.unique(
                    self.xb_ds.get_first_n(100_000, np.float32), axis=0
                )
                faiss.ParameterSpace().set_index_parameter(codec, "verbose", 1)
                codec.train(xb_sample)
                logging.info(f"writing trained dedupe codec: {filename}")
                faiss.write_index(codec, filename)
            codecs.append(codec)
            codesets.append(faiss.CodeSet(codec.sa_code_size()))
            idxs.append(np.empty((0,), dtype=np.uint32))
        bs = 1_000_000
        i = 0
        for buffer in tqdm(self._iterate_transformed(self.xb_ds, 0, bs, np.float32)):
            for j in range(len(codecs)):
                codec, codeset, idx = codecs[j], codesets[j], idxs[j]
                uniq = codeset.insert(codec.sa_encode(buffer))
                idxs[j] = np.append(
                    idx,
                    np.arange(i, i + buffer.shape[0], dtype=np.uint32)[uniq],
                )
            i += buffer.shape[0]
        for idx, filename in zip(idxs, self.deduper_idx_fn):
            logging.info(f"writing {filename}, shape: {idx.shape}")
            np.save(filename, idx)
        logging.info("done")

    def train_index(self):
        """
        Trains the index using a subsample of the first chunk of data in the
        database and saves it in the template file (with no vectors added).
        """
        assert not os.path.exists(self.index_template_file), (
            "The train command has already been run: the index template file"
            " already exists."
        )
        xb_sample = np.unique(
            self.xb_ds.get_first_n(self.training_sample, np.float32), axis=0
        )
        logging.info(f"input shape: {xb_sample.shape}")
        index = faiss.index_factory(
            self.input_d, self.index_factory, self.metric
        )
        index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
        index_ivf.by_residual = self.by_residual  # apply the configured by-residual setting
        faiss.ParameterSpace().set_index_parameter(index, "verbose", 1)
        logging.info("running training...")
        index.train(xb_sample)
        logging.info(f"writing trained index {self.index_template_file}...")
        faiss.write_index(index, self.index_template_file)
        logging.info("done")

    def _iterate_transformed(self, ds, start, batch_size, dt):
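        """
        Iterates over dataset `ds` in batches, applying the index's
        pre-transform (the first element of the IndexPreTransform chain)
        when the trained template contains one.
        """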
        assert os.path.exists(self.index_template_file)
        index = faiss.read_index(self.index_template_file)
        if is_pretransform_index(index):
            vt = index.chain.at(0)  # fetch pretransform
            for buffer in ds.iterate(start, batch_size, dt):
                yield vt.apply(buffer)
        else:
            for buffer in ds.iterate(start, batch_size, dt):
                yield buffer

    def index_shard(self):
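        """
        Populates the index shards: for each shard the trained template is
        reset, its coarse quantizer is swapped for a GPU clone to speed up
        assignment, and `shard_size` database vectors are added with
        sequential ids before the shard is written back with the CPU
        quantizer. Shards whose files already exist are skipped.
        """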
        assert os.path.exists(self.index_template_file)
        index = faiss.read_index(self.index_template_file)
        index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
        assert self.nprobe <= index_ivf.quantizer.ntotal, (
            f"the number of vectors {index_ivf.quantizer.ntotal} is not enough"
            f" to retrieve {self.nprobe} neighbours, check."
        )
        cpu_quantizer = index_ivf.quantizer
        gpu_quantizer = faiss.index_cpu_to_all_gpus(cpu_quantizer)

        for i in range(0, self.nshards):
            sfn = f"{self.index_shard_prefix}{i}"
            try:
                index.reset()
                index_ivf.quantizer = gpu_quantizer
                with open(sfn, "xb"):
                    start = i * self.shard_size
                    jj = 0
                    embeddings_batch_size = min(
                        EMBEDDINGS_BATCH_SIZE, self.shard_size
                    )
                    assert (
                        self.shard_size % embeddings_batch_size == 0
                        or EMBEDDINGS_BATCH_SIZE % embeddings_batch_size == 0
                    ), (
                        f"the shard size {self.shard_size} and embeddings"
                        f" batch size {EMBEDDINGS_BATCH_SIZE} are not"
                        " divisible"
                    )

                    for xb_j in tqdm(
                        self._iterate_transformed(
                            self.xb_ds,
                            start,
                            embeddings_batch_size,
                            np.float32,
                        ),
                        file=sys.stdout,
                    ):
                        if is_pretransform_index(index):
                            assert xb_j.shape[1] == index.chain.at(0).d_out
                            index_ivf.add_with_ids(
                                xb_j,
                                np.arange(start + jj, start + jj + xb_j.shape[0]),
                            )
                        else:
                            assert xb_j.shape[1] == index.d
                            index.add_with_ids(
                                xb_j,
                                np.arange(start + jj, start + jj + xb_j.shape[0]),
                            )
                        jj += xb_j.shape[0]
                        logging.info(jj)
                        assert (
                            jj <= self.shard_size
                        ), f"jj {jj} and shard_size {self.shard_size}"
                        if jj == self.shard_size:
                            break
                logging.info(f"writing {sfn}...")
                index_ivf.quantizer = cpu_quantizer
                faiss.write_index(index, sfn)
            except FileExistsError:
                logging.info(f"skipping shard: {i}")
                continue
        logging.info("done")

    def merge_index(self):
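        """
        Merges all on-disk shards into a single index: the inverted lists are
        concatenated into `<index_file>.ivfdata` and the merged index header
        is written to `index_file`.
        """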
        ivf_file = f"{self.index_file}.ivfdata"

        assert os.path.exists(self.index_template_file)
        assert not os.path.exists(
            ivf_file
        ), f"file with embeddings data {ivf_file} already exists, check."
        assert not os.path.exists(self.index_file)
        index = faiss.read_index(self.index_template_file)
        block_fnames = [
            f"{self.index_shard_prefix}{i}" for i in range(self.nshards)
        ]
        for fn in block_fnames:
            assert os.path.exists(fn)
        logging.info(block_fnames)
        logging.info("merging...")
        merge_ondisk(index, block_fnames, ivf_file)
        logging.info("writing index...")
        faiss.write_index(index, self.index_file)
        logging.info("done")

    def _cached_search(
        self,
        sample,
        xq_ds,
        xb_ds,
        idx_file,
        vecs_file,
        I_file,
        D_file,
        index_file=None,
        nprobe=None,
    ):
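        """
        Runs a k-NN search for `sample` queries drawn from `xq_ds` against
        `xb_ds`, caching everything on disk. If the result files already
        exist they are loaded and shape-checked; otherwise the queries are
        sampled, searched (against the non-sharded index when `index_file`
        is given, by brute-force ground truth otherwise) and the I/D arrays
        are saved.
        """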
        if not os.path.exists(I_file):
            assert not os.path.exists(I_file), f"file {I_file} already exists"
            assert not os.path.exists(D_file), f"file {D_file} already exists"
            xq = xq_ds.sample(sample, idx_file, vecs_file)

            if index_file:
                D, I = self._index_nonsharded_search(index_file, xq, nprobe)
            else:
                logging.info("ground truth computations")
                db_iterator = xb_ds.iterate(0, 100_000, np.float32)
                D, I = knn_ground_truth(
                    xq, db_iterator, self.k, metric_type=self.metric
                )
                assert np.amin(I) >= 0

            np.save(I_file, I)
            np.save(D_file, D)
        else:
            assert os.path.exists(idx_file), f"file {idx_file} does not exist "
            assert os.path.exists(
                vecs_file
            ), f"file {vecs_file} does not exist "
            assert os.path.exists(I_file), f"file {I_file} does not exist "
            assert os.path.exists(D_file), f"file {D_file} does not exist "
            I = np.load(I_file)
            D = np.load(D_file)
        assert I.shape == (sample, self.k), f"{I_file} shape mismatch"
        assert D.shape == (sample, self.k), f"{D_file} shape mismatch"
        return (D, I)

    def _index_search(self, index_shard_prefix, xq, nprobe):
        assert nprobe is not None
        logging.info(
            f"open sharded index: {index_shard_prefix}, {self.nshards}"
        )
        index = self._open_sharded_index(index_shard_prefix)
        index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
        logging.info(f"setting nprobe to {nprobe}")
        index_ivf.nprobe = nprobe
        return index.search(xq, self.k)

    def _index_nonsharded_search(self, index_file, xq, nprobe):
        assert nprobe is not None
        logging.info(f"index {index_file}")
        assert os.path.exists(index_file), f"file {index_file} does not exist "
        index = faiss.read_index(index_file, faiss.IO_FLAG_ONDISK_SAME_DIR)
        logging.info(f"index size {index.ntotal} ")
        index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
        logging.info(f"setting nprobe to {nprobe}")
        index_ivf.nprobe = nprobe
        return index.search(xq, self.k)

    def _refine_distances(self, xq_ds, idx, xb_ds, I):
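        """
        Recomputes exact distances between the sampled queries (rows `idx` of
        `xq_ds`) and their retrieved neighbours `I` in `xb_ds`, using either
        inner product or squared L2 depending on the configured metric.
        """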
        xq = xq_ds.get(idx).repeat(self.k, axis=0)
        xb = xb_ds.get(I.reshape(-1))
        if self.metric == faiss.METRIC_INNER_PRODUCT:
            return (xq * xb).sum(axis=1).reshape(I.shape)
        elif self.metric == faiss.METRIC_L2:
            return ((xq - xb) ** 2).sum(axis=1).reshape(I.shape)
        else:
            raise ValueError(f"metric not supported {self.metric}")

    def evaluate(self):
        self._evaluate(
            self.index_factory_fn,
            self.index_file,
            self.xq_index_file,
            self.nprobe,
        )

    def _evaluate(self, index_factory_fn, index_file, xq_index_file, nprobe):
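        """
        Evaluates approximate against exact search on an evaluation sample of
        queries: runs ground-truth and ANN forward searches, refines the ANN
        distances, optionally computes margin scores (forward and backward),
        and logs knn intersection measures and mean distances. All
        intermediate arrays are cached in `eval_dir`.
        """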
        idx_a_file = f"{self.eval_dir}/idx_a.npy"
        idx_b_gt_file = f"{self.eval_dir}/idx_b_gt.npy"
        idx_b_ann_file = (
            f"{self.eval_dir}/idx_b_ann_{index_factory_fn}_np{nprobe}.npy"
        )
        vecs_a_file = f"{self.eval_dir}/vecs_a.npy"
        vecs_b_gt_file = f"{self.eval_dir}/vecs_b_gt.npy"
        vecs_b_ann_file = (
            f"{self.eval_dir}/vecs_b_ann_{index_factory_fn}_np{nprobe}.npy"
        )
        D_a_gt_file = f"{self.eval_dir}/D_a_gt.npy"
        D_a_ann_file = (
            f"{self.eval_dir}/D_a_ann_{index_factory_fn}_np{nprobe}.npy"
        )
        D_a_ann_refined_file = f"{self.eval_dir}/D_a_ann_refined_{index_factory_fn}_np{nprobe}.npy"
        D_b_gt_file = f"{self.eval_dir}/D_b_gt.npy"
        D_b_ann_file = (
            f"{self.eval_dir}/D_b_ann_{index_factory_fn}_np{nprobe}.npy"
        )
        D_b_ann_gt_file = (
            f"{self.eval_dir}/D_b_ann_gt_{index_factory_fn}_np{nprobe}.npy"
        )
        I_a_gt_file = f"{self.eval_dir}/I_a_gt.npy"
        I_a_ann_file = (
            f"{self.eval_dir}/I_a_ann_{index_factory_fn}_np{nprobe}.npy"
        )
        I_b_gt_file = f"{self.eval_dir}/I_b_gt.npy"
        I_b_ann_file = (
            f"{self.eval_dir}/I_b_ann_{index_factory_fn}_np{nprobe}.npy"
        )
        I_b_ann_gt_file = (
            f"{self.eval_dir}/I_b_ann_gt_{index_factory_fn}_np{nprobe}.npy"
        )
        margin_gt_file = f"{self.eval_dir}/margin_gt.npy"
        margin_refined_file = (
            f"{self.eval_dir}/margin_refined_{index_factory_fn}_np{nprobe}.npy"
        )
        margin_ann_file = (
            f"{self.eval_dir}/margin_ann_{index_factory_fn}_np{nprobe}.npy"
        )

        logging.info("exact search forward")
        # xq -> xb AKA a -> b
        D_a_gt, I_a_gt = self._cached_search(
            self.evaluation_sample,
            self.xq_ds,
            self.xb_ds,
            idx_a_file,
            vecs_a_file,
            I_a_gt_file,
            D_a_gt_file,
        )
        idx_a = np.load(idx_a_file)

        logging.info("approximate search forward")
        D_a_ann, I_a_ann = self._cached_search(
            self.evaluation_sample,
            self.xq_ds,
            self.xb_ds,
            idx_a_file,
            vecs_a_file,
            I_a_ann_file,
            D_a_ann_file,
            index_file,
            nprobe,
        )

        logging.info(
            "calculate refined distances on approximate search forward"
        )
        if os.path.exists(D_a_ann_refined_file):
            D_a_ann_refined = np.load(D_a_ann_refined_file)
            assert D_a_ann.shape == D_a_ann_refined.shape
        else:
            D_a_ann_refined = self._refine_distances(
                self.xq_ds, idx_a, self.xb_ds, I_a_ann
            )
            np.save(D_a_ann_refined_file, D_a_ann_refined)

        if self.evaluate_by_margin:
            k_extract = self.k
            margin_threshold = 1.05
            logging.info(
                "exact search backward from the k_extract NN results of"
                " forward search"
            )
            # xb -> xq AKA b -> a
            D_a_b_gt = D_a_gt[:, :k_extract].ravel()
            idx_b_gt = I_a_gt[:, :k_extract].ravel()
            assert len(idx_b_gt) == self.evaluation_sample * k_extract
            np.save(idx_b_gt_file, idx_b_gt)
            # exact search
            D_b_gt, _ = self._cached_search(
                len(idx_b_gt),
                self.xb_ds,
                self.xq_ds,
                idx_b_gt_file,
                vecs_b_gt_file,
                I_b_gt_file,
                D_b_gt_file,
            )  # xb and xq ^^^ are inverted

            logging.info("margin on exact search")
            margin_gt = margin(
                self.evaluation_sample,
                idx_a,
                idx_b_gt,
                D_a_b_gt,
                D_a_gt,
                D_b_gt,
                self.k,
                k_extract,
                margin_threshold,
            )
            np.save(margin_gt_file, margin_gt)

            logging.info(
                "exact search backward from the k_extract NN results of"
                " approximate forward search"
            )
            D_a_b_refined = D_a_ann_refined[:, :k_extract].ravel()
            idx_b_ann = I_a_ann[:, :k_extract].ravel()
            assert len(idx_b_ann) == self.evaluation_sample * k_extract
            np.save(idx_b_ann_file, idx_b_ann)
            # exact search
            D_b_ann_gt, _ = self._cached_search(
                len(idx_b_ann),
                self.xb_ds,
                self.xq_ds,
                idx_b_ann_file,
                vecs_b_ann_file,
                I_b_ann_gt_file,
                D_b_ann_gt_file,
            )  # xb and xq ^^^ are inverted

            logging.info("refined margin on approximate search")
            margin_refined = margin(
                self.evaluation_sample,
                idx_a,
                idx_b_ann,
                D_a_b_refined,
                D_a_gt,  # not D_a_ann_refined(!)
                D_b_ann_gt,
                self.k,
                k_extract,
                margin_threshold,
            )
            np.save(margin_refined_file, margin_refined)

            D_b_ann, I_b_ann = self._cached_search(
                len(idx_b_ann),
                self.xb_ds,
                self.xq_ds,
                idx_b_ann_file,
                vecs_b_ann_file,
                I_b_ann_file,
                D_b_ann_file,
                xq_index_file,
                nprobe,
            )

            D_a_b_ann = D_a_ann[:, :k_extract].ravel()

            logging.info("approximate search margin")

            margin_ann = margin(
                self.evaluation_sample,
                idx_a,
                idx_b_ann,
                D_a_b_ann,
                D_a_ann,
                D_b_ann,
                self.k,
                k_extract,
                margin_threshold,
            )
            np.save(margin_ann_file, margin_ann)

        logging.info("intersection")
        logging.info(I_a_gt)
        logging.info(I_a_ann)

        for i in range(1, self.k + 1):
            logging.info(
                f"{i}: {knn_intersection_measure(I_a_gt[:, :i], I_a_ann[:, :i])}"
            )

        logging.info(f"mean of gt distances: {D_a_gt.mean()}")
        logging.info(f"mean of approx distances: {D_a_ann.mean()}")
        logging.info(f"mean of refined distances: {D_a_ann_refined.mean()}")

        logging.info("intersection cardinality frequencies")
        logging.info(get_intersection_cardinality_frequencies(I_a_ann, I_a_gt))

        logging.info("done")

    def _knn_function(self, xq, xb, k, metric, thread_id=None):
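        """
        GPU brute-force k-NN used as the search kernel for big_batch_search:
        each worker thread uses the GPU (and GpuResources object) matching
        its `thread_id`, with memory limits so that large query/vector
        blocks are processed in chunks.
        """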
        try:
            return faiss.knn_gpu(
                self.all_gpu_resources[thread_id],
                xq,
                xb,
                k,
                metric=metric,
                device=thread_id,
                vectorsMemoryLimit=self.knn_vectors_memory_limit,
                queriesMemoryLimit=self.knn_queries_memory_limit,
            )
        except Exception:
            logging.info(f"knn_function failed: {xq.shape}, {xb.shape}")
            raise

    def _coarse_quantize(self, index_ivf, xq, nprobe):
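        """
        Assigns each query to its `nprobe` nearest coarse centroids using a
        GPU clone of the quantizer, processing the queries in batches of
        100k and returning an (nq, nprobe) int32 assignment matrix.
        """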
        assert nprobe <= index_ivf.quantizer.ntotal
        quantizer = faiss.index_cpu_to_all_gpus(index_ivf.quantizer)
        bs = 100_000
        nq = len(xq)
        q_assign = np.empty((nq, nprobe), dtype="int32")
        for i0 in trange(0, nq, bs):
            i1 = min(nq, i0 + bs)
            _, q_assign_i = quantizer.search(xq[i0:i1], nprobe)
            q_assign[i0:i1] = q_assign_i
        return q_assign

    def search(self):
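        """
        Runs the approximate k-NN search of the query dataset against the
        sharded index, one batch of `query_batch_size` queries at a time.
        Queries are coarse-quantized on GPU, then big_batch_search runs the
        GPU brute-force kernel per inverted list. Per-batch I/D result files
        are created exclusively so concurrent or restarted jobs skip batches
        that are already done; under SLURM, zero-length outputs left by a
        previous run of the same job id are cleaned up before retrying, and
        periodic checkpoints allow long batches to resume.
        """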
logging.info(f"search: {self.knn_dir}")
|
|
slurm_job_id = os.environ.get("SLURM_JOB_ID")
|
|
|
|
ngpu = faiss.get_num_gpus()
|
|
logging.info(f"number of gpus: {ngpu}")
|
|
self.all_gpu_resources = [
|
|
faiss.StandardGpuResources() for _ in range(ngpu)
|
|
]
|
|
self._knn_function(
|
|
np.zeros((10, 10), dtype=np.float16),
|
|
np.zeros((10, 10), dtype=np.float16),
|
|
self.k,
|
|
metric=self.metric,
|
|
thread_id=0,
|
|
)
|
|
|
|
index = self._open_sharded_index()
|
|
index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
|
|
logging.info(f"setting nprobe to {self.nprobe}")
|
|
index_ivf.nprobe = self.nprobe
|
|
# quantizer = faiss.index_cpu_to_all_gpus(index_ivf.quantizer)
|
|
for i in range(0, self.xq_ds.size, self.xq_bs):
|
|
Ifn = f"{self.knn_dir}/I{(i):010}_{self.knn_output_file_suffix}"
|
|
Dfn = f"{self.knn_dir}/D_approx{(i):010}_{self.knn_output_file_suffix}"
|
|
CPfn = f"{self.knn_dir}/CP{(i):010}_{self.knn_output_file_suffix}"
|
|
|
|
if slurm_job_id:
|
|
worker_record = (
|
|
self.knn_dir
|
|
+ f"/record_{(i):010}_{self.knn_output_file_suffix}.txt"
|
|
)
|
|
if not os.path.exists(worker_record):
|
|
logging.info(
|
|
f"creating record file {worker_record} and saving job"
|
|
f" id: {slurm_job_id}"
|
|
)
|
|
with open(worker_record, "w") as h:
|
|
h.write(slurm_job_id)
|
|
else:
|
|
old_slurm_id = open(worker_record, "r").read()
|
|
logging.info(
|
|
f"old job slurm id {old_slurm_id} and current job id:"
|
|
f" {slurm_job_id}"
|
|
)
|
|
if old_slurm_id == slurm_job_id:
|
|
if os.path.getsize(Ifn) == 0:
|
|
logging.info(
|
|
f"cleaning up zero length files {Ifn} and"
|
|
f" {Dfn}"
|
|
)
|
|
os.remove(Ifn)
|
|
os.remove(Dfn)
|
|
|
|
try:
|
|
if is_pretransform_index(index):
|
|
d = index.chain.at(0).d_out
|
|
else:
|
|
d = self.input_d
|
|
with open(Ifn, "xb") as f, open(Dfn, "xb") as g:
|
|
xq_i = np.empty(
|
|
shape=(self.xq_bs, d), dtype=np.float16
|
|
)
|
|
q_assign = np.empty(
|
|
(self.xq_bs, self.nprobe), dtype=np.int32
|
|
)
|
|
j = 0
|
|
quantizer = faiss.index_cpu_to_all_gpus(
|
|
index_ivf.quantizer
|
|
)
|
|
for xq_i_j in tqdm(
|
|
self._iterate_transformed(
|
|
self.xq_ds, i, min(100_000, self.xq_bs), np.float16
|
|
),
|
|
file=sys.stdout,
|
|
):
|
|
xq_i[j:j + xq_i_j.shape[0]] = xq_i_j
|
|
(
|
|
_,
|
|
q_assign[j:j + xq_i_j.shape[0]],
|
|
) = quantizer.search(xq_i_j, self.nprobe)
|
|
j += xq_i_j.shape[0]
|
|
assert j <= xq_i.shape[0]
|
|
if j == xq_i.shape[0]:
|
|
break
|
|
xq_i = xq_i[:j]
|
|
q_assign = q_assign[:j]
|
|
|
|
assert q_assign.shape == (xq_i.shape[0], index_ivf.nprobe)
|
|
del quantizer
|
|
logging.info(f"computing: {Ifn}")
|
|
logging.info(f"computing: {Dfn}")
|
|
prefetch_threads = faiss.get_num_gpus()
|
|
D_ann, I = big_batch_search(
|
|
index_ivf,
|
|
xq_i,
|
|
self.k,
|
|
verbose=10,
|
|
method="knn_function",
|
|
knn=self._knn_function,
|
|
threaded=faiss.get_num_gpus() * 8,
|
|
use_float16=True,
|
|
prefetch_threads=prefetch_threads,
|
|
computation_threads=faiss.get_num_gpus(),
|
|
q_assign=q_assign,
|
|
checkpoint=CPfn,
|
|
checkpoint_freq=7200, # in seconds
|
|
)
|
|
assert (
|
|
np.amin(I) >= 0
|
|
), f"{I}, there exists negative indices, check"
|
|
logging.info(f"saving: {Ifn}")
|
|
np.save(f, I)
|
|
logging.info(f"saving: {Dfn}")
|
|
np.save(g, D_ann)
|
|
|
|
if os.path.exists(CPfn):
|
|
logging.info(f"removing: {CPfn}")
|
|
os.remove(CPfn)
|
|
|
|
except FileExistsError:
|
|
logging.info(f"skipping {Ifn}, already exists")
|
|
logging.info(f"skipping {Dfn}, already exists")
|
|
continue
|
|
|
|
    def _open_index_shard(self, fn):
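        """
        Opens an index shard memory-mapped and read-only, caching the handle
        in `self.index_shards` so each shard is opened at most once.
        """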
        if fn in self.index_shards:
            index_shard = self.index_shards[fn]
        else:
            logging.info(f"open index shard: {fn}")
            index_shard = faiss.read_index(
                fn, faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY
            )
            self.index_shards[fn] = index_shard
        return index_shard

    def _open_sharded_index(self, index_shard_prefix=None):
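        """
        Builds a searchable index without merging: reads the trained template
        and replaces its inverted lists with an HStackInvertedLists over the
        invlists of all shards. The stacked lists and the resulting index are
        cached in `self.ivls` / `self.index` to keep references alive.
        """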
        if index_shard_prefix is None:
            index_shard_prefix = self.index_shard_prefix
        if index_shard_prefix in self.index:
            return self.index[index_shard_prefix]
        assert os.path.exists(
            self.index_template_file
        ), f"file {self.index_template_file} does not exist "
        logging.info(f"open index template: {self.index_template_file}")
        index = faiss.read_index(self.index_template_file)
        index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
        ilv = faiss.InvertedListsPtrVector()
        for i in range(self.nshards):
            fn = f"{index_shard_prefix}{i}"
            assert os.path.exists(fn), f"file {fn} does not exist "
            logging.info(fn)
            index_shard = self._open_index_shard(fn)
            il = faiss.downcast_index(
                faiss.extract_index_ivf(index_shard)
            ).invlists
            ilv.push_back(il)
        hsil = faiss.HStackInvertedLists(ilv.size(), ilv.data())
        index_ivf.replace_invlists(hsil, False)
        self.ivls[index_shard_prefix] = hsil
        self.index[index_shard_prefix] = index
        return index

    def index_shard_stats(self):
        for i in range(self.nshards):
            fn = f"{self.index_shard_prefix}{i}"
            assert os.path.exists(fn)
            index = faiss.read_index(
                fn, faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY
            )
            index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
            il = index_ivf.invlists
            il.print_stats()

    def index_stats(self):
        index = self._open_sharded_index()
        index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
        il = index_ivf.invlists
        list_sizes = [il.list_size(i) for i in range(il.nlist)]
        logging.info(np.max(list_sizes))
        logging.info(np.mean(list_sizes))
        logging.info(np.argmax(list_sizes))
        logging.info("index_stats:")
        il.print_stats()

    def consistency_check(self):
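        """
        Sanity checks the pipeline outputs: every sampled vector must
        retrieve itself both from its own shard and from the stacked index,
        and the cached search results must closely match an online search
        (>95% id overlap, distance sums within 1% relative tolerance).
        """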
logging.info("consistency-check")
|
|
|
|
logging.info("index template...")
|
|
|
|
assert os.path.exists(self.index_template_file)
|
|
index = faiss.read_index(self.index_template_file)
|
|
|
|
offset = 0 # 2**24
|
|
assert self.shard_size > offset + SMALL_DATA_SAMPLE
|
|
|
|
logging.info("index shards...")
|
|
for i in range(self.nshards):
|
|
r = i * self.shard_size + offset
|
|
xb = next(self.xb_ds.iterate(r, SMALL_DATA_SAMPLE, np.float32))
|
|
fn = f"{self.index_shard_prefix}{i}"
|
|
assert os.path.exists(fn), f"There is no index shard file {fn}"
|
|
index = self._open_index_shard(fn)
|
|
index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
|
|
index_ivf.nprobe = 1
|
|
_, I = index.search(xb, 100)
|
|
for j in range(SMALL_DATA_SAMPLE):
|
|
assert np.where(I[j] == j + r)[0].size > 0, (
|
|
f"I[j]: {I[j]}, j: {j}, i: {i}, shard_size:"
|
|
f" {self.shard_size}"
|
|
)
|
|
|
|
logging.info("merged index...")
|
|
index = self._open_sharded_index()
|
|
index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index))
|
|
index_ivf.nprobe = 1
|
|
for i in range(self.nshards):
|
|
r = i * self.shard_size + offset
|
|
xb = next(self.xb_ds.iterate(r, SMALL_DATA_SAMPLE, np.float32))
|
|
_, I = index.search(xb, 100)
|
|
for j in range(SMALL_DATA_SAMPLE):
|
|
assert np.where(I[j] == j + r)[0].size > 0, (
|
|
f"I[j]: {I[j]}, j: {j}, i: {i}, shard_size:"
|
|
f" {self.shard_size}")
|
|
|
|
logging.info("search results...")
|
|
index_ivf.nprobe = self.nprobe
|
|
for i in range(0, self.xq_ds.size, self.xq_bs):
|
|
Ifn = f"{self.knn_dir}/I{i:010}_{self.index_factory_fn}_np{self.nprobe}.npy"
|
|
assert os.path.exists(Ifn)
|
|
assert os.path.getsize(Ifn) > 0, f"The file {Ifn} is empty."
|
|
logging.info(Ifn)
|
|
I = np.load(Ifn, mmap_mode="r")
|
|
|
|
assert I.shape[1] == self.k
|
|
assert I.shape[0] == min(self.xq_bs, self.xq_ds.size - i)
|
|
assert np.all(I[:, 1] >= 0)
|
|
|
|
Dfn = f"{self.knn_dir}/D_approx{i:010}_{self.index_factory_fn}_np{self.nprobe}.npy"
|
|
assert os.path.exists(Dfn)
|
|
assert os.path.getsize(Dfn) > 0, f"The file {Dfn} is empty."
|
|
logging.info(Dfn)
|
|
D = np.load(Dfn, mmap_mode="r")
|
|
assert D.shape == I.shape
|
|
|
|
xq = next(self.xq_ds.iterate(i, SMALL_DATA_SAMPLE, np.float32))
|
|
D_online, I_online = index.search(xq, self.k)
|
|
assert (
|
|
np.where(I[:SMALL_DATA_SAMPLE] == I_online)[0].size
|
|
/ (self.k * SMALL_DATA_SAMPLE)
|
|
> 0.95
|
|
), (
|
|
"the ratio is"
|
|
f" {np.where(I[:SMALL_DATA_SAMPLE] == I_online)[0].size / (self.k * SMALL_DATA_SAMPLE)}"
|
|
)
|
|
assert np.allclose(
|
|
D[:SMALL_DATA_SAMPLE].sum(axis=1),
|
|
D_online.sum(axis=1),
|
|
rtol=0.01,
|
|
), (
|
|
"the difference is"
|
|
f" {D[:SMALL_DATA_SAMPLE].sum(axis=1), D_online.sum(axis=1)}"
|
|
)
|
|
|
|
logging.info("done")
|