faiss/demos/offline_ivf/create_sharded_ssnpp_files.py
Maria Lomeli 0fc8456e1d Offline IVF powered by faiss big batch search (#3202)
Summary:
This PR introduces the offline IVF (OIVF) framework which contains some tooling to run search using IVFPQ indexes (plus OPQ pretransforms) for large batches of queries using [big_batch_search](https://github.com/mlomeli1/faiss/blob/main/contrib/big_batch_search.py) and GPU faiss. See the [README](36226f5fe8/demos/offline_ivf/README.md) for details about using this framework.
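
As a minimal, illustrative sketch of the kind of index OIVF drives (an OPQ pretransform chained with an IVFPQ index via the faiss index factory; the dimensionality, factory string, and random training data below are placeholders, not the framework's defaults):
````
import faiss
import numpy as np

d = 256  # vector dimensionality (placeholder)
xt = np.random.rand(10000, d).astype("float32")  # random training vectors

# OPQ pretransform followed by an IVFPQ index
index = faiss.index_factory(d, "OPQ32,IVF256,PQ32")
index.train(xt)
index.add(xt)
D, I = index.search(xt[:5], 10)  # query a few vectors back
````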

This PR includes the following unit tests, which can be run with the unittest library as follows:
````
~/faiss/demos/offline_ivf$ python3 -m unittest tests/test_iterate_input.py -k test_iterate_back
````
In test_offline_ivf:
````
test_consistency_check
test_train_index
test_index_shard_equal_file_sizes
test_index_shard_unequal_file_sizes
test_search
test_evaluate_without_margin
test_evaluate_without_margin_OPQ
````
In test_iterate_input:
````
test_iterate_input_file_larger_than_batch
test_get_vs_iterate
test_iterate_back
````

Pull Request resolved: https://github.com/facebookresearch/faiss/pull/3202

Reviewed By: algoriddle

Differential Revision: D52734222

Pulled By: mlomeli1

fbshipit-source-id: 61fd0084277c1b14bdae1189db8ae43340611e16
2024-01-16 05:05:15 -08:00

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os

import numpy as np


def xbin_mmap(fname, dtype, maxn=-1):
    """
    Code from
    https://github.com/harsha-simhadri/big-ann-benchmarks/blob/main/benchmark/dataset_io.py#L94
    mmap the competition file format for a given type of items
    """
    # The header stores the number of vectors (n) and their dimension (d)
    # as two uint32 values; the n * d items of `dtype` follow immediately.
    n, d = map(int, np.fromfile(fname, dtype="uint32", count=2))
    assert os.stat(fname).st_size == 8 + n * d * np.dtype(dtype).itemsize
    if maxn > 0:
        n = min(n, maxn)
    return np.memmap(fname, dtype=dtype, mode="r", offset=8, shape=(n, d))
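
# Illustrative check of the expected file layout (not part of this script;
# the path and sizes are made up). A tiny synthetic .u8bin file can be built
# and mapped like so:
#
#     n, d = 10, 4
#     with open("/tmp/toy.u8bin", "wb") as f:
#         np.array([n, d], dtype="uint32").tofile(f)
#         np.random.randint(0, 256, size=(n, d), dtype="uint8").tofile(f)
#     arr = xbin_mmap("/tmp/toy.u8bin", dtype="uint8")
#     assert arr.shape == (n, d)
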
def main(args: argparse.Namespace):
    ssnpp_data = xbin_mmap(fname=args.filepath, dtype="uint8")
    num_batches = ssnpp_data.shape[0] // args.data_batch
    assert (
        ssnpp_data.shape[0] % args.data_batch == 0
    ), "num of embeddings per file should divide total num of embeddings"
    # Slice the memory-mapped database into equal-sized batches and save
    # each one as its own .npy shard with a zero-padded index in the name.
    for i in range(num_batches):
        xb_batch = ssnpp_data[i * args.data_batch:(i + 1) * args.data_batch, :]
        filename = os.path.join(args.output_dir, f"ssnpp_{i:010}.npy")
        np.save(filename, xb_batch)
        print(f"File {filename} is saved!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_batch",
        dest="data_batch",
        type=int,
        default=50000000,
        help="Number of embeddings per file, should be a divisor of 1B",
    )
    parser.add_argument(
        "--filepath",
        dest="filepath",
        type=str,
        default="/datasets01/big-ann-challenge-data/FB_ssnpp/FB_ssnpp_database.u8bin",
        help="path to the original file with the 1B ssnpp database vectors",
    )
    parser.add_argument(
        "--output_dir",
        dest="output_dir",
        type=str,
        default="/checkpoint/marialomeli/ssnpp_data",
        help="path to the output directory where the sharded files are written",
    )
    args = parser.parse_args()
    main(args)
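
# Example invocation (paths below are illustrative):
#
#     python3 create_sharded_ssnpp_files.py \
#         --data_batch 50000000 \
#         --filepath /path/to/FB_ssnpp_database.u8bin \
#         --output_dir /path/to/ssnpp_shards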