Summary:
This PR introduces the offline IVF (OIVF) framework, which contains tooling to run searches with IVFPQ indexes (plus OPQ pretransforms) for large batches of queries using [big_batch_search](https://github.com/mlomeli1/faiss/blob/main/contrib/big_batch_search.py) and GPU Faiss. See the [README](36226f5fe8/demos/offline_ivf/README.md) for details about using this framework.
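As a rough illustration of the kind of index OIVF drives, here is a minimal, self-contained Faiss sketch of an OPQ pretransform chained with an IVFPQ index. The dimensionality, factory string, and `nprobe` value below are illustrative choices, not the ones used by the framework:
````
import faiss
import numpy as np

d = 256  # vector dimensionality (illustrative)
rng = np.random.default_rng(0)
xt = rng.random((10_000, d), dtype="float32")   # training sample
xb = rng.random((100_000, d), dtype="float32")  # database vectors
xq = rng.random((8, d), dtype="float32")        # queries

# OPQ rotation to 64 dims with 8 sub-quantizers, then IVF (1024 lists) + PQ codes
index = faiss.index_factory(d, "OPQ8_64,IVF1024,PQ8")
index.train(xt)
index.add(xb)
faiss.extract_index_ivf(index).nprobe = 64  # inverted lists visited per query
D, I = index.search(xq, 10)  # distances and ids of the 10 nearest neighbors
````
OIVF wraps this kind of index so that very large query batches are processed list by list via big_batch_search rather than query by query.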
This PR includes the following unit tests, which can be run with the unittest library as follows:
````
~/faiss/demos/offline_ivf$ python3 -m unittest tests/test_iterate_input.py -k test_iterate_back
````
In test_offline_ivf:
````
test_consistency_check
test_train_index
test_index_shard_equal_file_sizes
test_index_shard_unequal_file_sizes
test_search
test_evaluate_without_margin
test_evaluate_without_margin_OPQ
````
In test_iterate_input:
````
test_iterate_input_file_larger_than_batch
test_get_vs_iterate
test_iterate_back
````
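Each full suite can be run the same way by omitting the `-k` filter:
````
~/faiss/demos/offline_ivf$ python3 -m unittest tests/test_offline_ivf.py
~/faiss/demos/offline_ivf$ python3 -m unittest tests/test_iterate_input.py
````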
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/3202
Reviewed By: algoriddle
Differential Revision: D52734222
Pulled By: mlomeli1
fbshipit-source-id: 61fd0084277c1b14bdae1189db8ae43340611e16
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os

import numpy as np


def xbin_mmap(fname, dtype, maxn=-1):
    """
    Code from
    https://github.com/harsha-simhadri/big-ann-benchmarks/blob/main/benchmark/dataset_io.py#L94
    mmap the competition file format for a given type of items
    """
    # The file starts with an 8-byte header: two uint32 values (n, d).
    n, d = map(int, np.fromfile(fname, dtype="uint32", count=2))
    assert os.stat(fname).st_size == 8 + n * d * np.dtype(dtype).itemsize
    if maxn > 0:
        n = min(n, maxn)
    # Memory-map the payload so the full file is never loaded into RAM.
    return np.memmap(fname, dtype=dtype, mode="r", offset=8, shape=(n, d))


def main(args: argparse.Namespace):
    ssnpp_data = xbin_mmap(fname=args.filepath, dtype="uint8")
    num_batches = ssnpp_data.shape[0] // args.data_batch
    assert (
        ssnpp_data.shape[0] % args.data_batch == 0
    ), "num of embeddings per file should divide total num of embeddings"
    for i in range(num_batches):
        xb_batch = ssnpp_data[
            i * args.data_batch:(i + 1) * args.data_batch, :
        ]
        # Zero-padded shard index keeps the files lexicographically sorted.
        filename = args.output_dir + f"/ssnpp_{(i):010}.npy"
        np.save(filename, xb_batch)
        print(f"File {filename} is saved!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_batch",
        dest="data_batch",
        type=int,
        default=50000000,
        help="Number of embeddings per file, should be a divisor of 1B",
    )
    parser.add_argument(
        "--filepath",
        dest="filepath",
        type=str,
        default="/datasets01/big-ann-challenge-data/FB_ssnpp/FB_ssnpp_database.u8bin",
        help="path of 1B ssnpp database vectors' original file",
    )
    parser.add_argument(
        "--output_dir",
        dest="output_dir",
        type=str,
        default="/checkpoint/marialomeli/ssnpp_data",
        help="path to put sharded files",
    )

    args = parser.parse_args()
    main(args)
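For reference, a hypothetical invocation of this script and a quick check of one resulting shard. The script file name and all paths are placeholders, and the printed shape assumes the FB_ssnpp layout (uint8 vectors, `--data_batch` rows per shard):
````
~/faiss/demos/offline_ivf$ python3 create_sharded_files.py \
    --data_batch 50000000 \
    --filepath /path/to/FB_ssnpp_database.u8bin \
    --output_dir /path/to/ssnpp_shards
````
Each shard can then be loaded lazily:
````
import numpy as np

# mmap_mode avoids pulling the 50M-row shard into memory
shard = np.load("/path/to/ssnpp_shards/ssnpp_0000000000.npy", mmap_mode="r")
print(shard.shape, shard.dtype)  # e.g. (50000000, d) uint8
````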