faiss/benchs/distributed_ondisk/make_index_vslice.py

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
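
"""Fill an empty, pre-trained index with a slice i0:i1 of vectors from an
.fvecs file (one vertical slice of the database) and write the result."""
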
import os
import time
import numpy as np
import faiss
import argparse
from multiprocessing.dummy import Pool as ThreadPool

def ivecs_mmap(fname):
    """Memory-map an .ivecs file as a 2D int32 array, one vector per row."""
    a = np.memmap(fname, dtype='int32', mode='r')
    d = a[0]  # each vector is stored as [dimension, components...]
    return a.reshape(-1, d + 1)[:, 1:]

def fvecs_mmap(fname):
    """Memory-map an .fvecs file: same layout as .ivecs, float32 components."""
    return ivecs_mmap(fname).view('float32')
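
# Note on the file layout assumed above: in .ivecs/.fvecs files every vector
# is stored as an int32 dimension d followed by its d components (int32 or
# float32 respectively), hence the (d + 1)-column reshape and the [:, 1:]
# slice. All vectors in one file must share the same dimension.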

def produce_batches(args):
    """Yield (ids, vectors) batches of size args.bs from the memory-mapped
    input file, covering the slice args.i0:args.i1."""
    x = fvecs_mmap(args.input)
    if args.i1 == -1:
        args.i1 = len(x)
    print("Iterating on vectors %d:%d from %s by batches of size %d" % (
        args.i0, args.i1, args.input, args.bs))
    for j0 in range(args.i0, args.i1, args.bs):
        j1 = min(j0 + args.bs, args.i1)
        yield np.arange(j0, j1), x[j0:j1]
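
# Each yielded pair is (ids, batch): e.g. with --i0 0 and the default
# --bs 262144 (2**18), the first batch is (np.arange(0, 262144), x[0:262144]).
# The ids are fed to add_with_ids() below, so every vector keeps its global
# offset in the .fvecs file as its id.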

def rate_limited_iter(it):
    """Wrap iterator `it` so that a worker thread pre-fetches the next
    element while the current one is being consumed."""
    pool = ThreadPool(1)
    res = None

    def next_or_None():
        try:
            return next(it)
        except StopIteration:
            return None

    while True:
        res_next = pool.apply_async(next_or_None)
        if res is not None:
            res = res.get()
            if res is None:
                return
            yield res
        res = res_next
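
# Minimal usage sketch (illustrative only, not executed by this script):
# wrapping any iterator overlaps producing element i+1 with consuming
# element i, which below hides the .fvecs read latency behind
# index.add_with_ids(). consume() is a stand-in for per-element work.
#
#     for v in rate_limited_iter(iter(range(10))):
#         consume(v)  # while consume() runs, the worker fetches the next v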

# default locations of the Deep1B dataset and the working directory
# (specific to the original authors' cluster)
deep1bdir = "/datasets01_101/simsearch/041218/deep1b/"
workdir = "/checkpoint/matthijs/ondisk_distributed/"

def main():
    parser = argparse.ArgumentParser(
        description='make index for a subset of the data')

    def aa(*args, **kwargs):
        # add an argument to the *current* group ("group" is looked up at
        # call time, so reassigning it below switches groups)
        group.add_argument(*args, **kwargs)

    group = parser.add_argument_group('index type')
    aa('--inputindex',
       default=workdir + 'trained.faissindex',
       help='empty input index to fill in')
    aa('--nt', default=-1, type=int,
       help='number of OpenMP threads to use')

    group = parser.add_argument_group('db options')
    aa('--input', default=deep1bdir + "base.fvecs",
       help='input vector file (.fvecs)')
    aa('--bs', default=2**18, type=int,
       help='batch size for db access')
    aa('--i0', default=0, type=int, help='lower bound to index')
    aa('--i1', default=-1, type=int, help='upper bound of vectors to index')

    group = parser.add_argument_group('output')
    aa('-o', default='/tmp/x', help='output index')
    aa('--keepquantizer', default=False, action='store_true',
       help='by default we remove the data from the quantizer to save space')

    args = parser.parse_args()
    print('args=', args)

    print('start accessing data')
    src = produce_batches(args)

    print('loading index', args.inputindex)
    index = faiss.read_index(args.inputindex)

    if args.nt != -1:
        faiss.omp_set_num_threads(args.nt)

    t0 = time.time()
    ntot = 0
    for ids, x in rate_limited_iter(src):
        print('add %d:%d (%.3f s)' % (ntot, ntot + ids.size, time.time() - t0))
        index.add_with_ids(np.ascontiguousarray(x, dtype='float32'), ids)
        ntot += ids.size

    index_ivf = faiss.extract_index_ivf(index)
    print('invlists stats: imbalance %.3f' %
          index_ivf.invlists.imbalance_factor())
    index_ivf.invlists.print_stats()

    if not args.keepquantizer:
        # by default, drop the quantizer content to save space in each slice
        print('resetting quantizer content')
        index_ivf.quantizer.reset()

    print('store output', args.o)
    faiss.write_index(index, args.o)


if __name__ == '__main__':
    main()
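
# Example invocation (a sketch: paths and slice bounds are placeholders to
# adapt; the defaults above point to a specific cluster setup):
#
#   python make_index_vslice.py \
#       --inputindex /path/to/trained.faissindex \
#       --input /path/to/base.fvecs \
#       --i0 0 --i1 100000000 \
#       -o /path/to/slice0.faissindex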