faiss/benchs/bench_gpu_1bn.py
Matthijs Douze 6d0bc58db6 Implementation of PQ4 search with SIMD instructions (#1542)
Summary:
IndexPQ and IndexIVFPQ implementations with AVX shuffle instructions.

The training and computing of the codes does not change wrt. the original PQ versions but the code layout is "packed" so that it can be used efficiently by the SIMD computation kernels.

The main changes are:

- new IndexPQFastScan and IndexIVFPQFastScan objects

- simdib.h for an abstraction above the AVX2 intrinsics

- BlockInvertedLists for invlists that are 32-byte aligned and where codes are not sequential

- pq4_fast_scan.h/.cpp:  for packing codes and look-up tables + optmized distance comptuation kernels

- simd_result_hander.h: SIMD version of result collection in heaps / reservoirs

Misc changes:

- added contrib.inspect_tools to access fields in C++ objects

- moved .h and .cpp code for inverted lists to an invlists/ subdirectory, and made a .h/.cpp for InvertedListsIOHook

- added a new inverted lists type with 32-byte aligned codes (for consumption by SIMD)

- moved Windows-specific intrinsics to platfrom_macros.h

Pull Request resolved: https://github.com/facebookresearch/faiss/pull/1542

Test Plan:
```
buck test mode/opt  -j 4  //faiss/tests/:test_fast_scan_ivf //faiss/tests/:test_fast_scan
buck test mode/opt  //faiss/manifold/...
```

Reviewed By: wickedfoo

Differential Revision: D25175439

Pulled By: mdouze

fbshipit-source-id: ad1a40c0df8c10f4b364bdec7172e43d71b56c34
2020-12-03 10:06:38 -08:00

748 lines
22 KiB
Python

#! /usr/bin/env python2
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import print_function
import numpy as np
import time
import os
import sys
import faiss
import re
from multiprocessing.dummy import Pool as ThreadPool
from datasets import ivecs_read
####################################################################
# Parse command line
####################################################################
def usage():
print("""
Usage: bench_gpu_1bn.py dataset indextype [options]
dataset: set of vectors to operate on.
Supported: SIFT1M, SIFT2M, ..., SIFT1000M or Deep1B
indextype: any index type supported by index_factory that runs on GPU.
General options
-ngpu ngpu nb of GPUs to use (default = all)
-tempmem N use N bytes of temporary GPU memory
-nocache do not read or write intermediate files
-float16 use 16-bit floats on the GPU side
Add options
-abs N split adds in blocks of no more than N vectors
-max_add N copy sharded dataset to CPU each max_add additions
(to avoid memory overflows with geometric reallocations)
-altadd Alternative add function, where the index is not stored
on GPU during add. Slightly faster for big datasets on
slow GPUs
Search options
-R R: nb of replicas of the same dataset (the dataset
will be copied across ngpu/R, default R=1)
-noptables do not use precomputed tables in IVFPQ.
-qbs N split queries in blocks of no more than N vectors
-nnn N search N neighbors for each query
-nprobe 4,16,64 try this number of probes
-knngraph instead of the standard setup for the dataset,
compute a k-nn graph with nnn neighbors per element
-oI xx%d.npy output the search result indices to this numpy file,
%d will be replaced with the nprobe
-oD xx%d.npy output the search result distances to this file
""", file=sys.stderr)
sys.exit(1)
# default values
dbname = None
index_key = None
ngpu = faiss.get_num_gpus()
replicas = 1 # nb of replicas of sharded dataset
add_batch_size = 32768
query_batch_size = 16384
nprobes = [1 << l for l in range(9)]
knngraph = False
use_precomputed_tables = True
tempmem = -1 # if -1, use system default
max_add = -1
use_float16 = False
use_cache = True
nnn = 10
altadd = False
I_fname = None
D_fname = None
args = sys.argv[1:]
while args:
a = args.pop(0)
if a == '-h': usage()
elif a == '-ngpu': ngpu = int(args.pop(0))
elif a == '-R': replicas = int(args.pop(0))
elif a == '-noptables': use_precomputed_tables = False
elif a == '-abs': add_batch_size = int(args.pop(0))
elif a == '-qbs': query_batch_size = int(args.pop(0))
elif a == '-nnn': nnn = int(args.pop(0))
elif a == '-tempmem': tempmem = int(args.pop(0))
elif a == '-nocache': use_cache = False
elif a == '-knngraph': knngraph = True
elif a == '-altadd': altadd = True
elif a == '-float16': use_float16 = True
elif a == '-nprobe': nprobes = [int(x) for x in args.pop(0).split(',')]
elif a == '-max_add': max_add = int(args.pop(0))
elif not dbname: dbname = a
elif not index_key: index_key = a
else:
print("argument %s unknown" % a, file=sys.stderr)
sys.exit(1)
cacheroot = '/tmp/bench_gpu_1bn'
if not os.path.isdir(cacheroot):
print("%s does not exist, creating it" % cacheroot)
os.mkdir(cacheroot)
#################################################################
# Small Utility Functions
#################################################################
# we mem-map the biggest files to avoid having them in memory all at
# once
def mmap_fvecs(fname):
x = np.memmap(fname, dtype='int32', mode='r')
d = x[0]
return x.view('float32').reshape(-1, d + 1)[:, 1:]
def mmap_bvecs(fname):
x = np.memmap(fname, dtype='uint8', mode='r')
d = x[:4].view('int32')[0]
return x.reshape(-1, d + 4)[:, 4:]
def rate_limited_imap(f, l):
"""A threaded imap that does not produce elements faster than they
are consumed"""
pool = ThreadPool(1)
res = None
for i in l:
res_next = pool.apply_async(f, (i, ))
if res:
yield res.get()
res = res_next
yield res.get()
class IdentPreproc:
"""a pre-processor is either a faiss.VectorTransform or an IndentPreproc"""
def __init__(self, d):
self.d_in = self.d_out = d
def apply_py(self, x):
return x
def sanitize(x):
""" convert array to a c-contiguous float array """
return np.ascontiguousarray(x.astype('float32'))
def dataset_iterator(x, preproc, bs):
""" iterate over the lines of x in blocks of size bs"""
nb = x.shape[0]
block_ranges = [(i0, min(nb, i0 + bs))
for i0 in range(0, nb, bs)]
def prepare_block(i01):
i0, i1 = i01
xb = sanitize(x[i0:i1])
return i0, preproc.apply_py(xb)
return rate_limited_imap(prepare_block, block_ranges)
def eval_intersection_measure(gt_I, I):
""" measure intersection measure (used for knngraph)"""
inter = 0
rank = I.shape[1]
assert gt_I.shape[1] >= rank
for q in range(nq_gt):
inter += faiss.ranklist_intersection_size(
rank, faiss.swig_ptr(gt_I[q, :]),
rank, faiss.swig_ptr(I[q, :].astype('int64')))
return inter / float(rank * nq_gt)
#################################################################
# Prepare dataset
#################################################################
print("Preparing dataset", dbname)
if dbname.startswith('SIFT'):
# SIFT1M to SIFT1000M
dbsize = int(dbname[4:-1])
xb = mmap_bvecs('bigann/bigann_base.bvecs')
xq = mmap_bvecs('bigann/bigann_query.bvecs')
xt = mmap_bvecs('bigann/bigann_learn.bvecs')
# trim xb to correct size
xb = xb[:dbsize * 1000 * 1000]
gt_I = ivecs_read('bigann/gnd/idx_%dM.ivecs' % dbsize)
elif dbname == 'Deep1B':
xb = mmap_fvecs('deep1b/base.fvecs')
xq = mmap_fvecs('deep1b/deep1B_queries.fvecs')
xt = mmap_fvecs('deep1b/learn.fvecs')
# deep1B's train is is outrageously big
xt = xt[:10 * 1000 * 1000]
gt_I = ivecs_read('deep1b/deep1B_groundtruth.ivecs')
else:
print('unknown dataset', dbname, file=sys.stderr)
sys.exit(1)
if knngraph:
# convert to knn-graph dataset
xq = xb
xt = xb
# we compute the ground-truth on this number of queries for validation
nq_gt = 10000
gt_sl = 100
# ground truth will be computed below
gt_I = None
print("sizes: B %s Q %s T %s gt %s" % (
xb.shape, xq.shape, xt.shape,
gt_I.shape if gt_I is not None else None))
#################################################################
# Parse index_key and set cache files
#
# The index_key is a valid factory key that would work, but we
# decompose the training to do it faster
#################################################################
pat = re.compile('(OPQ[0-9]+(_[0-9]+)?,|PCAR[0-9]+,)?' +
'(IVF[0-9]+),' +
'(PQ[0-9]+|Flat)')
matchobject = pat.match(index_key)
assert matchobject, 'could not parse ' + index_key
mog = matchobject.groups()
preproc_str = mog[0]
ivf_str = mog[2]
pqflat_str = mog[3]
ncent = int(ivf_str[3:])
prefix = ''
if knngraph:
gt_cachefile = '%s/BK_gt_%s.npy' % (cacheroot, dbname)
prefix = 'BK_'
# files must be kept distinct because the training set is not the
# same for the knngraph
if preproc_str:
preproc_cachefile = '%s/%spreproc_%s_%s.vectrans' % (
cacheroot, prefix, dbname, preproc_str[:-1])
else:
preproc_cachefile = None
preproc_str = ''
cent_cachefile = '%s/%scent_%s_%s%s.npy' % (
cacheroot, prefix, dbname, preproc_str, ivf_str)
index_cachefile = '%s/%s%s_%s%s,%s.index' % (
cacheroot, prefix, dbname, preproc_str, ivf_str, pqflat_str)
if not use_cache:
preproc_cachefile = None
cent_cachefile = None
index_cachefile = None
print("cachefiles:")
print(preproc_cachefile)
print(cent_cachefile)
print(index_cachefile)
#################################################################
# Wake up GPUs
#################################################################
print("preparing resources for %d GPUs" % ngpu)
gpu_resources = []
for i in range(ngpu):
res = faiss.StandardGpuResources()
if tempmem >= 0:
res.setTempMemory(tempmem)
gpu_resources.append(res)
def make_vres_vdev(i0=0, i1=-1):
" return vectors of device ids and resources useful for gpu_multiple"
vres = faiss.GpuResourcesVector()
vdev = faiss.IntVector()
if i1 == -1:
i1 = ngpu
for i in range(i0, i1):
vdev.push_back(i)
vres.push_back(gpu_resources[i])
return vres, vdev
#################################################################
# Prepare ground truth (for the knngraph)
#################################################################
def compute_GT():
print("compute GT")
t0 = time.time()
gt_I = np.zeros((nq_gt, gt_sl), dtype='int64')
gt_D = np.zeros((nq_gt, gt_sl), dtype='float32')
heaps = faiss.float_maxheap_array_t()
heaps.k = gt_sl
heaps.nh = nq_gt
heaps.val = faiss.swig_ptr(gt_D)
heaps.ids = faiss.swig_ptr(gt_I)
heaps.heapify()
bs = 10 ** 5
n, d = xb.shape
xqs = sanitize(xq[:nq_gt])
db_gt = faiss.IndexFlatL2(d)
vres, vdev = make_vres_vdev()
db_gt_gpu = faiss.index_cpu_to_gpu_multiple(
vres, vdev, db_gt)
# compute ground-truth by blocks of bs, and add to heaps
for i0, xsl in dataset_iterator(xb, IdentPreproc(d), bs):
db_gt_gpu.add(xsl)
D, I = db_gt_gpu.search(xqs, gt_sl)
I += i0
heaps.addn_with_ids(
gt_sl, faiss.swig_ptr(D), faiss.swig_ptr(I), gt_sl)
db_gt_gpu.reset()
print("\r %d/%d, %.3f s" % (i0, n, time.time() - t0), end=' ')
print()
heaps.reorder()
print("GT time: %.3f s" % (time.time() - t0))
return gt_I
if knngraph:
if gt_cachefile and os.path.exists(gt_cachefile):
print("load GT", gt_cachefile)
gt_I = np.load(gt_cachefile)
else:
gt_I = compute_GT()
if gt_cachefile:
print("store GT", gt_cachefile)
np.save(gt_cachefile, gt_I)
#################################################################
# Prepare the vector transformation object (pure CPU)
#################################################################
def train_preprocessor():
print("train preproc", preproc_str)
d = xt.shape[1]
t0 = time.time()
if preproc_str.startswith('OPQ'):
fi = preproc_str[3:-1].split('_')
m = int(fi[0])
dout = int(fi[1]) if len(fi) == 2 else d
preproc = faiss.OPQMatrix(d, m, dout)
elif preproc_str.startswith('PCAR'):
dout = int(preproc_str[4:-1])
preproc = faiss.PCAMatrix(d, dout, 0, True)
else:
assert False
preproc.train(sanitize(xt[:1000000]))
print("preproc train done in %.3f s" % (time.time() - t0))
return preproc
def get_preprocessor():
if preproc_str:
if not preproc_cachefile or not os.path.exists(preproc_cachefile):
preproc = train_preprocessor()
if preproc_cachefile:
print("store", preproc_cachefile)
faiss.write_VectorTransform(preproc, preproc_cachefile)
else:
print("load", preproc_cachefile)
preproc = faiss.read_VectorTransform(preproc_cachefile)
else:
d = xb.shape[1]
preproc = IdentPreproc(d)
return preproc
#################################################################
# Prepare the coarse quantizer
#################################################################
def train_coarse_quantizer(x, k, preproc):
d = preproc.d_out
clus = faiss.Clustering(d, k)
clus.verbose = True
# clus.niter = 2
clus.max_points_per_centroid = 10000000
print("apply preproc on shape", x.shape, 'k=', k)
t0 = time.time()
x = preproc.apply_py(sanitize(x))
print(" preproc %.3f s output shape %s" % (
time.time() - t0, x.shape))
vres, vdev = make_vres_vdev()
index = faiss.index_cpu_to_gpu_multiple(
vres, vdev, faiss.IndexFlatL2(d))
clus.train(x, index)
centroids = faiss.vector_float_to_array(clus.centroids)
return centroids.reshape(k, d)
def prepare_coarse_quantizer(preproc):
if cent_cachefile and os.path.exists(cent_cachefile):
print("load centroids", cent_cachefile)
centroids = np.load(cent_cachefile)
else:
nt = max(1000000, 256 * ncent)
print("train coarse quantizer...")
t0 = time.time()
centroids = train_coarse_quantizer(xt[:nt], ncent, preproc)
print("Coarse train time: %.3f s" % (time.time() - t0))
if cent_cachefile:
print("store centroids", cent_cachefile)
np.save(cent_cachefile, centroids)
coarse_quantizer = faiss.IndexFlatL2(preproc.d_out)
coarse_quantizer.add(centroids)
return coarse_quantizer
#################################################################
# Make index and add elements to it
#################################################################
def prepare_trained_index(preproc):
coarse_quantizer = prepare_coarse_quantizer(preproc)
d = preproc.d_out
if pqflat_str == 'Flat':
print("making an IVFFlat index")
idx_model = faiss.IndexIVFFlat(coarse_quantizer, d, ncent,
faiss.METRIC_L2)
else:
m = int(pqflat_str[2:])
assert m < 56 or use_float16, "PQ%d will work only with -float16" % m
print("making an IVFPQ index, m = ", m)
idx_model = faiss.IndexIVFPQ(coarse_quantizer, d, ncent, m, 8)
coarse_quantizer.this.disown()
idx_model.own_fields = True
# finish training on CPU
t0 = time.time()
print("Training vector codes")
x = preproc.apply_py(sanitize(xt[:1000000]))
idx_model.train(x)
print(" done %.3f s" % (time.time() - t0))
return idx_model
def compute_populated_index(preproc):
"""Add elements to a sharded index. Return the index and if available
a sharded gpu_index that contains the same data. """
indexall = prepare_trained_index(preproc)
co = faiss.GpuMultipleClonerOptions()
co.useFloat16 = use_float16
co.useFloat16CoarseQuantizer = False
co.usePrecomputed = use_precomputed_tables
co.indicesOptions = faiss.INDICES_CPU
co.verbose = True
co.reserveVecs = max_add if max_add > 0 else xb.shape[0]
co.shard = True
assert co.shard_type in (0, 1, 2)
vres, vdev = make_vres_vdev()
gpu_index = faiss.index_cpu_to_gpu_multiple(
vres, vdev, indexall, co)
print("add...")
t0 = time.time()
nb = xb.shape[0]
for i0, xs in dataset_iterator(xb, preproc, add_batch_size):
i1 = i0 + xs.shape[0]
gpu_index.add_with_ids(xs, np.arange(i0, i1))
if max_add > 0 and gpu_index.ntotal > max_add:
print("Flush indexes to CPU")
for i in range(ngpu):
index_src_gpu = faiss.downcast_index(gpu_index.at(i))
index_src = faiss.index_gpu_to_cpu(index_src_gpu)
print(" index %d size %d" % (i, index_src.ntotal))
index_src.copy_subset_to(indexall, 0, 0, nb)
index_src_gpu.reset()
index_src_gpu.reserveMemory(max_add)
gpu_index.sync_with_shard_indexes()
print('\r%d/%d (%.3f s) ' % (
i0, nb, time.time() - t0), end=' ')
sys.stdout.flush()
print("Add time: %.3f s" % (time.time() - t0))
print("Aggregate indexes to CPU")
t0 = time.time()
if hasattr(gpu_index, 'at'):
# it is a sharded index
for i in range(ngpu):
index_src = faiss.index_gpu_to_cpu(gpu_index.at(i))
print(" index %d size %d" % (i, index_src.ntotal))
index_src.copy_subset_to(indexall, 0, 0, nb)
else:
# simple index
index_src = faiss.index_gpu_to_cpu(gpu_index)
index_src.copy_subset_to(indexall, 0, 0, nb)
print(" done in %.3f s" % (time.time() - t0))
if max_add > 0:
# it does not contain all the vectors
gpu_index = None
return gpu_index, indexall
def compute_populated_index_2(preproc):
indexall = prepare_trained_index(preproc)
# set up a 3-stage pipeline that does:
# - stage 1: load + preproc
# - stage 2: assign on GPU
# - stage 3: add to index
stage1 = dataset_iterator(xb, preproc, add_batch_size)
vres, vdev = make_vres_vdev()
coarse_quantizer_gpu = faiss.index_cpu_to_gpu_multiple(
vres, vdev, indexall.quantizer)
def quantize(args):
(i0, xs) = args
_, assign = coarse_quantizer_gpu.search(xs, 1)
return i0, xs, assign.ravel()
stage2 = rate_limited_imap(quantize, stage1)
print("add...")
t0 = time.time()
nb = xb.shape[0]
for i0, xs, assign in stage2:
i1 = i0 + xs.shape[0]
if indexall.__class__ == faiss.IndexIVFPQ:
indexall.add_core_o(i1 - i0, faiss.swig_ptr(xs),
None, None, faiss.swig_ptr(assign))
elif indexall.__class__ == faiss.IndexIVFFlat:
indexall.add_core(i1 - i0, faiss.swig_ptr(xs), None,
faiss.swig_ptr(assign))
else:
assert False
print('\r%d/%d (%.3f s) ' % (
i0, nb, time.time() - t0), end=' ')
sys.stdout.flush()
print("Add time: %.3f s" % (time.time() - t0))
return None, indexall
def get_populated_index(preproc):
if not index_cachefile or not os.path.exists(index_cachefile):
if not altadd:
gpu_index, indexall = compute_populated_index(preproc)
else:
gpu_index, indexall = compute_populated_index_2(preproc)
if index_cachefile:
print("store", index_cachefile)
faiss.write_index(indexall, index_cachefile)
else:
print("load", index_cachefile)
indexall = faiss.read_index(index_cachefile)
gpu_index = None
co = faiss.GpuMultipleClonerOptions()
co.useFloat16 = use_float16
co.useFloat16CoarseQuantizer = False
co.usePrecomputed = use_precomputed_tables
co.indicesOptions = 0
co.verbose = True
co.shard = True # the replicas will be made "manually"
t0 = time.time()
print("CPU index contains %d vectors, move to GPU" % indexall.ntotal)
if replicas == 1:
if not gpu_index:
print("copying loaded index to GPUs")
vres, vdev = make_vres_vdev()
index = faiss.index_cpu_to_gpu_multiple(
vres, vdev, indexall, co)
else:
index = gpu_index
else:
del gpu_index # We override the GPU index
print("Copy CPU index to %d sharded GPU indexes" % replicas)
index = faiss.IndexReplicas()
for i in range(replicas):
gpu0 = ngpu * i / replicas
gpu1 = ngpu * (i + 1) / replicas
vres, vdev = make_vres_vdev(gpu0, gpu1)
print(" dispatch to GPUs %d:%d" % (gpu0, gpu1))
index1 = faiss.index_cpu_to_gpu_multiple(
vres, vdev, indexall, co)
index1.this.disown()
index.addIndex(index1)
index.own_fields = True
del indexall
print("move to GPU done in %.3f s" % (time.time() - t0))
return index
#################################################################
# Perform search
#################################################################
def eval_dataset(index, preproc):
ps = faiss.GpuParameterSpace()
ps.initialize(index)
nq_gt = gt_I.shape[0]
print("search...")
sl = query_batch_size
nq = xq.shape[0]
for nprobe in nprobes:
ps.set_index_parameter(index, 'nprobe', nprobe)
t0 = time.time()
if sl == 0:
D, I = index.search(preproc.apply_py(sanitize(xq)), nnn)
else:
I = np.empty((nq, nnn), dtype='int32')
D = np.empty((nq, nnn), dtype='float32')
inter_res = ''
for i0, xs in dataset_iterator(xq, preproc, sl):
print('\r%d/%d (%.3f s%s) ' % (
i0, nq, time.time() - t0, inter_res), end=' ')
sys.stdout.flush()
i1 = i0 + xs.shape[0]
Di, Ii = index.search(xs, nnn)
I[i0:i1] = Ii
D[i0:i1] = Di
if knngraph and not inter_res and i1 >= nq_gt:
ires = eval_intersection_measure(
gt_I[:, :nnn], I[:nq_gt])
inter_res = ', %.4f' % ires
t1 = time.time()
if knngraph:
ires = eval_intersection_measure(gt_I[:, :nnn], I[:nq_gt])
print(" probe=%-3d: %.3f s rank-%d intersection results: %.4f" % (
nprobe, t1 - t0, nnn, ires))
else:
print(" probe=%-3d: %.3f s" % (nprobe, t1 - t0), end=' ')
gtc = gt_I[:, :1]
nq = xq.shape[0]
for rank in 1, 10, 100:
if rank > nnn: continue
nok = (I[:, :rank] == gtc).sum()
print("1-R@%d: %.4f" % (rank, nok / float(nq)), end=' ')
print()
if I_fname:
I_fname_i = I_fname % I
print("storing", I_fname_i)
np.save(I, I_fname_i)
if D_fname:
D_fname_i = I_fname % I
print("storing", D_fname_i)
np.save(D, D_fname_i)
#################################################################
# Driver
#################################################################
preproc = get_preprocessor()
index = get_populated_index(preproc)
eval_dataset(index, preproc)
# make sure index is deleted before the resources
del index