#! /usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
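
"""Benchmark for the InterruptCallback overhead (benchs/bench_for_interrupt.py).

Builds an IVF index on random vectors, runs a reference search on the full
database, then searches the database split into slices and merges the partial
results with a ResultHeap. Both timings are reported and the merged results
are compared against the reference. Passing --no_intcallback removes the
global InterruptCallback, which is presumably the overhead this benchmark is
meant to isolate. Example invocation (illustrative, not from the source):
python bench_for_interrupt.py --nb 200000 --no_intcallback
"""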
from __future__ import print_function
import numpy as np
import faiss
import time
import os
import argparse
parser = argparse.ArgumentParser()
def aa(*args, **kwargs):
    "shorthand for add_argument on the current argument group"
    group.add_argument(*args, **kwargs)


group = parser.add_argument_group('dataset options')

aa('--dim', type=int, default=64)
aa('--nb', type=int, default=int(1e6))
aa('--subset_len', type=int, default=int(1e5))
aa('--key', default='IVF1000,Flat')
aa('--nprobe', type=int, default=640)
aa('--no_intcallback', default=False, action='store_true')
aa('--twostage', default=False, action='store_true')
aa('--nt', type=int, default=-1)
args = parser.parse_args()
print("args:", args)
d = args.dim # dimension
nb = args.nb # database size
nq = 1000 # nb of queries
nt = 100000
subset_len = args.subset_len
np.random.seed(1234) # make reproducible
xb = np.random.random((nb, d)).astype('float32')
xq = np.random.random((nq, d)).astype('float32')
xt = np.random.random((nt, d)).astype('float32')
k = 100
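
# --no_intcallback clears the global InterruptCallback singleton, so searches
# run without the periodic interruption checks (presumably the overhead this
# benchmark is interested in).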
if args.no_intcallback:
    faiss.InterruptCallback.clear_instance()

if args.nt != -1:
    faiss.omp_set_num_threads(args.nt)

nprobe = args.nprobe
key = args.key
#key = 'IVF1000,Flat'
# key = 'IVF1000,PQ64'
# key = 'IVF100_HNSW32,PQ64'
# faiss.omp_set_num_threads(1)
pf = 'dim%d_' % d
if d == 64:
    pf = ''

basename = '/tmp/base%s%s.index' % (pf, key)
if os.path.exists(basename):
    print('load', basename)
    index_1 = faiss.read_index(basename)
else:
    print('train + write', basename)
    index_1 = faiss.index_factory(d, key)
    index_1.train(xt)
    faiss.write_index(index_1, basename)

print('add')
index_1.add(xb)
print('set nprobe=', nprobe)
faiss.ParameterSpace().set_index_parameter(index_1, 'nprobe', nprobe)
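

# ResultHeap keeps one max-heap of size k per query so that the per-slice
# results can be merged into a single top-k list over the whole database.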
class ResultHeap:
""" Combine query results from a sliced dataset """
def __init__(self, nq, k):
" nq: number of query vectors, k: number of results per query "
self.I = np.zeros((nq, k), dtype='int64')
self.D = np.zeros((nq, k), dtype='float32')
self.nq, self.k = nq, k
heaps = faiss.float_maxheap_array_t()
heaps.k = k
heaps.nh = nq
heaps.val = faiss.swig_ptr(self.D)
heaps.ids = faiss.swig_ptr(self.I)
heaps.heapify()
self.heaps = heaps
def add_batch_result(self, D, I, i0):
assert D.shape == (self.nq, self.k)
assert I.shape == (self.nq, self.k)
I += i0
self.heaps.addn_with_ids(
self.k, faiss.swig_ptr(D),
faiss.swig_ptr(I), self.k)
def finalize(self):
self.heaps.reorder()
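
# Global IVF search statistics: quantization_time and search_time are
# accumulated in milliseconds over the searches below.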
stats = faiss.cvar.indexIVF_stats
stats.reset()

print('index size', index_1.ntotal,
      'imbalance', index_1.invlists.imbalance_factor())

start = time.time()
Dref, Iref = index_1.search(xq, k)
print('time of searching: %.3f s (quantization %.3f ms + search %.3f ms)' % (
    time.time() - start, stats.quantization_time, stats.search_time))
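
# Search the same database split into slices of subset_len vectors and merge
# the per-slice results; with --twostage all slice indexes are built before
# the timed search loop.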
indexes = {}

if args.twostage:
    for i in range(0, nb, subset_len):
        index = faiss.read_index(basename)
        faiss.ParameterSpace().set_index_parameter(index, 'nprobe', nprobe)
        print("add %d:%d" % (i, i + subset_len))
        index.add(xb[i:i + subset_len])
        indexes[i] = index

rh = ResultHeap(nq, k)
sum_time = tq = ts = 0
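
# Without --twostage, each slice index is rebuilt inside the loop just before
# it is searched; only the search itself is timed.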
for i in range(0, nb, subset_len):
    if not args.twostage:
        index = faiss.read_index(basename)
        faiss.ParameterSpace().set_index_parameter(index, 'nprobe', nprobe)
        print("add %d:%d" % (i, i + subset_len))
        index.add(xb[i:i + subset_len])
    else:
        index = indexes[i]

    stats.reset()
    start = time.time()
    Di, Ii = index.search(xq, k)
    sum_time = sum_time + time.time() - start
    tq += stats.quantization_time
    ts += stats.search_time
    rh.add_batch_result(Di, Ii, i)

print('time of searching separately: %.3f s '
      '(quantization %.3f ms + search %.3f ms)' % (sum_time, tq, ts))

rh.finalize()
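
# The merged slice results should match the reference search (up to ties in
# the distances).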
print('diffs: %d / %d' % ((Iref != rh.I).sum(), Iref.size))