97 lines
2.6 KiB
Python
97 lines
2.6 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import os
|
|
import faiss
|
|
import argparse
|
|
from multiprocessing.pool import ThreadPool
|
|
|
|
if __name__ == '__main__':
|
|
|
|
parser = argparse.ArgumentParser()
|
|
|
|
parser.add_argument('--inputs', nargs='*', required=True,
|
|
help='input indexes to merge')
|
|
parser.add_argument('--l0', type=int, default=0)
|
|
parser.add_argument('--l1', type=int, default=-1)
|
|
|
|
parser.add_argument('--nt', default=-1,
|
|
help='nb threads')
|
|
|
|
parser.add_argument('--output', required=True,
|
|
help='output index filename')
|
|
parser.add_argument('--outputIL',
|
|
help='output invfile filename')
|
|
|
|
args = parser.parse_args()
|
|
|
|
if args.nt != -1:
|
|
print('set nb of threads to', args.nt)
|
|
|
|
|
|
ils = faiss.InvertedListsPtrVector()
|
|
ils_dont_dealloc = []
|
|
|
|
pool = ThreadPool(20)
|
|
|
|
def load_index(fname):
|
|
print("loading", fname)
|
|
try:
|
|
index = faiss.read_index(fname, faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY)
|
|
except RuntimeError as e:
|
|
print('could not load %s: %s' % (fname, e))
|
|
return fname, None
|
|
|
|
print(" %d entries" % index.ntotal)
|
|
return fname, index
|
|
|
|
index0 = None
|
|
|
|
for _, index in pool.imap(load_index, args.inputs):
|
|
if index is None:
|
|
continue
|
|
index_ivf = faiss.extract_index_ivf(index)
|
|
il = faiss.downcast_InvertedLists(index_ivf.invlists)
|
|
index_ivf.invlists = None
|
|
il.this.own()
|
|
ils_dont_dealloc.append(il)
|
|
if (args.l0, args.l1) != (0, -1):
|
|
print('restricting to lists %d:%d' % (args.l0, args.l1))
|
|
# il = faiss.SliceInvertedLists(il, args.l0, args.l1)
|
|
|
|
il.crop_invlists(args.l0, args.l1)
|
|
ils_dont_dealloc.append(il)
|
|
ils.push_back(il)
|
|
|
|
if index0 is None:
|
|
index0 = index
|
|
|
|
print("loaded %d invlists" % ils.size())
|
|
|
|
if not args.outputIL:
|
|
args.outputIL = args.output + '_invlists'
|
|
|
|
il0 = ils.at(0)
|
|
|
|
il = faiss.OnDiskInvertedLists(
|
|
il0.nlist, il0.code_size,
|
|
args.outputIL)
|
|
|
|
print("perform merge")
|
|
|
|
ntotal = il.merge_from(ils.data(), ils.size(), True)
|
|
|
|
print("swap into index0")
|
|
|
|
index0_ivf = faiss.extract_index_ivf(index0)
|
|
index0_ivf.nlist = il0.nlist
|
|
index0_ivf.ntotal = index0.ntotal = ntotal
|
|
index0_ivf.invlists = il
|
|
index0_ivf.own_invlists = False
|
|
|
|
print("write", args.output)
|
|
|
|
faiss.write_index(index0, args.output)
|