faiss/benchs/bench_fw_codecs.py

147 lines
5.0 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import argparse
import os
from faiss.benchs.bench_fw.benchmark import Benchmark
from faiss.benchs.bench_fw.benchmark_io import BenchmarkIO
from faiss.benchs.bench_fw.descriptors import DatasetDescriptor, IndexDescriptorClassic
from faiss.benchs.bench_fw.index import IndexFromFactory
# Configure the root logger at import time so INFO-level messages are emitted.
logging.basicConfig(level=logging.INFO)
def factory_factory(d):
    """Enumerate the codec configurations to benchmark for dimension ``d``.

    Every entry is a 4-tuple ``(factory, construction_params, training_size,
    search_params)``.  Entries are produced family by family, always in this
    order: fixed SQ/LSH strings, OPQ32_128+Residual+PQ, PCAR+SQ, OPQ+PQ, RQ,
    LSQ, PRQ and PLSQ.

    Args:
        d: vector dimension.  It is used only to filter out configurations
            that do not fit: PCAR/OPQ output sizes larger than ``d``, code
            sizes not below ``d * 8 * 2``, and PRQ/PLSQ split counts that do
            not divide ``d``.

    Returns:
        A new list of tuples; each call builds fresh list/dict parameters.
    """
    default_training_size = 256 * (2 ** 10)
    # Upper bound (exclusive) on the ``cs`` values of the additive families.
    # NOTE(review): cs is presumably the code size in bits -- confirm.
    size_cap = d * 8 * 2

    def usable(count):
        # The quantizer count in the factory string must be in 2..64.
        return 1 < count < 65

    def beam_grid(name, bits):
        # One entry per (max_beam_size, use_beam_LUT) search setting.
        return [
            (
                name,
                [{"max_beam_size": 32}],
                256 * (2 ** bits),
                {"max_beam_size": beam, "use_beam_LUT": lut},
            )
            for beam in (1, 2, 4, 8, 16, 32)
            for lut in (0, 1)
        ]

    def ils_grid(name, bits):
        # One entry per (encode_ils_iters, lsq_gpu) search setting.
        return [
            (
                name,
                [{"encode_ils_iters": 16}],
                256 * (2 ** bits),
                {"encode_ils_iters": iters, "lsq_gpu": gpu},
            )
            for iters in (2, 4, 8, 16)
            for gpu in (0, 1)
        ]

    configs = []

    # Fixed factory strings that do not depend on d.
    for name in (
        "SQ4",
        "SQ8",
        "SQfp16",
        "ITQ64,LSH",
        "Pad128,ITQ128,LSH",
        "Pad256,ITQ256,LSH",
    ):
        configs.append((name, None, default_training_size, None))

    for bits in range(8, 16, 2):
        configs.append(
            (f"OPQ32_128,Residual2x14,PQ32x{bits}", None, 256 * (2 ** 14), None)
        )

    # PCAR to a power-of-two size (64..1024) no larger than d, then SQ4/SQ8.
    for exponent in range(6, 11):
        out_dim = 2 ** exponent
        if out_dim <= d:
            for bits in (4, 8):
                configs.append(
                    (f"PCAR{out_dim},SQ{bits}", None, default_training_size, None)
                )

    # OPQ to m * per_sub dimensions (bounded by d), followed by PQ.
    for m in (8, 12, 16, 32, 64, 128):
        for per_sub in (2, 4, 6, 8, 12, 16):
            opq_dim = m * per_sub
            if opq_dim <= d:
                for bits in range(8, 16, 2):
                    configs.append(
                        (
                            f"OPQ{m}_{opq_dim},PQ{m}x{bits}",
                            None,
                            256 * (2 ** bits),
                            None,
                        )
                    )

    # The filters below depend only on the size/bits/split values, so they
    # are checked before expanding the search-parameter grid.
    for total in (64, 128, 256, 512):
        if total < size_cap:
            for bits in (6, 8, 10, 12):
                count = total // bits
                if usable(count):
                    configs.extend(beam_grid(f"RQ{count}x{bits}", bits))

    for total in (64, 128, 256, 512):
        if total < size_cap:
            for bits in (6, 8, 10, 12):
                count = total // bits
                if usable(count):
                    configs.extend(ils_grid(f"LSQ{count}x{bits}", bits))

    for splits in (2, 3, 4, 8, 16, 32):
        if d % splits == 0:
            for total in (64, 96, 128, 192, 256, 384, 512, 768, 1024, 2048):
                if total < size_cap:
                    for bits in (6, 8, 10, 12):
                        count = total // splits // bits
                        if usable(count):
                            configs.extend(
                                beam_grid(f"PRQ{splits}x{count}x{bits}", bits)
                            )

    for splits in (2, 3, 4, 8, 16, 32):
        if d % splits == 0:
            for total in (64, 128, 256, 512, 1024, 2048):
                if total < size_cap:
                    for bits in (6, 8, 10, 12):
                        count = total // splits // bits
                        if usable(count):
                            configs.extend(
                                ils_grid(f"PLSQ{splits}x{count}x{bits}", bits)
                            )

    return configs
def run_local(rp):
    """Build and run the codec benchmark described by ``rp``.

    Args:
        rp: a 4-tuple ``(bio, d, tablename, distance_metric)``.  ``bio`` is
            handed to ``Benchmark.set_io``, ``d`` selects the configurations
            returned by ``factory_factory``, ``tablename`` picks the dataset
            descriptors and ``distance_metric`` is forwarded to ``Benchmark``.

    The benchmark is invoked with ``train=True`` and with ``reconstruct``,
    ``knn`` and ``range`` all disabled, writing to ``result.json``.
    """
    bio, d, tablename, distance_metric = rp

    # Keyword arguments for the training, database and query descriptors,
    # in that order.  "contriever" uses explicit .npy tables; every other
    # table name is looked up under the std_t / std_d / std_q namespaces.
    if tablename == "contriever":
        descriptor_kwargs = (
            {"tablename": "training_set.npy"},
            {"tablename": "database1M.npy"},
            {"tablename": "queries.npy"},
        )
    else:
        descriptor_kwargs = tuple(
            {"namespace": namespace, "tablename": tablename}
            for namespace in ("std_t", "std_d", "std_q")
        )
    training_vectors, database_vectors, query_vectors = (
        DatasetDescriptor(**kwargs) for kwargs in descriptor_kwargs
    )

    # One index descriptor per configuration tuple from factory_factory.
    index_descs = []
    for factory, construction_params, training_size, search_params in factory_factory(d):
        index_descs.append(
            IndexDescriptorClassic(
                factory=factory,
                construction_params=construction_params,
                training_size=training_size,
                search_params=search_params,
            )
        )

    benchmark = Benchmark(
        num_threads=32,
        training_vectors=training_vectors,
        database_vectors=database_vectors,
        query_vectors=query_vectors,
        index_descs=index_descs,
        k=1,
        distance_metric=distance_metric,
    )
    benchmark.set_io(bio)
    benchmark.benchmark(
        result_file="result.json",
        train=True,
        reconstruct=False,
        knn=False,
        range=False,
    )
def run(bio, d, tablename, distance_metric):
    """Submit one ``run_local`` job through ``bio.launch_jobs`` with ``local=True``.

    The four arguments are packed into a single tuple, which is the only
    element of the job-parameter list passed to ``launch_jobs``.
    """
    job_params = (bio, d, tablename, distance_metric)
    bio.launch_jobs(run_local, [job_params], local=True)
if __name__ == "__main__":
    # Command-line entry point: ``<script> EXPERIMENT PATH``.
    #
    # Maps each supported experiment name to the
    # (d, tablename, distance_metric) arguments forwarded to run().
    experiments = {
        "sift1M": (128, "sift1M", "L2"),
        "bigann": (128, "bigann1M", "L2"),
        "deep1b": (96, "deep1M", "L2"),
        "contriever": (768, "contriever", "IP"),
    }

    parser = argparse.ArgumentParser()
    # Restricting the choices makes an unknown experiment a usage error
    # instead of a silent no-op that still creates an output directory.
    parser.add_argument('experiment', choices=list(experiments))
    parser.add_argument('path')
    args = parser.parse_args()

    # Explicit check rather than ``assert``: asserts are stripped under -O.
    if not os.path.exists(args.path):
        parser.error(f"path does not exist: {args.path}")

    # Results for each experiment go into their own sub-directory of PATH.
    path = os.path.join(args.path, args.experiment)
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs(path, exist_ok=True)

    bio = BenchmarkIO(
        path=path,
    )
    run(bio, *experiments[args.experiment])