building blocks for hybrid CPU / GPU search (#2638)

Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2638

This diff provides a more streamlined way of searching IVF indexes with precomputed cluster assignments.
This will be used for experiments with hybrid CPU / GPU search.
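As an illustration of the intended hybrid pipeline (a sketch, not part of the diff; names and parameters are ours, and it assumes a GPU build of faiss): the coarse quantizer runs on the GPU to produce the cluster assignments, and the inverted lists are then scanned on the CPU through the new search_preassigned wrapper.

import faiss
import numpy as np
from faiss.contrib.ivf_tools import search_preassigned

d, nlist, k = 64, 1024, 10
xb = np.random.rand(100000, d).astype('float32')
xq = np.random.rand(100, d).astype('float32')

index = faiss.index_factory(d, f"IVF{nlist},Flat")
index.train(xb)
index.add(xb)
index.nprobe = 16

# step 1 (GPU): run only the coarse quantizer on the GPU to get, for each
# query, the nprobe nearest centroids and the distances to them
res = faiss.StandardGpuResources()
gpu_quantizer = faiss.index_cpu_to_gpu(res, 0, index.quantizer)
Dq, Iq = gpu_quantizer.search(xq, index.nprobe)

# step 2 (CPU): scan only the preassigned inverted lists on the CPU
D, I = search_preassigned(index, xq, k, Iq, Dq)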

Reviewed By: algoriddle

Differential Revision: D41301032

fbshipit-source-id: a1d645fd0f2bf806454dfd04971edc0a6200d20d
pull/2676/head
Matthijs Douze 2023-01-12 13:34:44 -08:00 committed by Facebook GitHub Bot
parent 1eb4f42639
commit 8fc3775472
21 changed files with 412 additions and 189 deletions

View File

@@ -70,7 +70,7 @@ if args.factory_string == "":
else:
factory_string = args.factory_string
print(f"instanciate {factory_string}")
print(f"instantiate {factory_string}")
index = faiss.index_factory(ds.d, factory_string)
if args.factory_string != "":

View File

@@ -32,7 +32,9 @@ def add_preassigned(index_ivf, x, a, ids=None):
def search_preassigned(index_ivf, xq, k, list_nos, coarse_dis=None):
"""
Perform a search in the IVF index, with predefined lists to search into
Perform a search in the IVF index, with predefined lists to search into.
Supports indexes with pretransforms (unlike
IndexIVF.search_preassigned, which cannot be applied with a pretransform).
"""
n, d = xq.shape
if isinstance(index_ivf, faiss.IndexBinaryIVF):
@@ -51,14 +53,7 @@ def search_preassigned(index_ivf, xq, k, list_nos, coarse_dis=None):
else:
assert coarse_dis.shape == (n, index_ivf.nprobe)
D = np.empty((n, k), dtype=dis_type)
I = np.empty((n, k), dtype='int64')
sp = faiss.swig_ptr
index_ivf.search_preassigned(
n, sp(xq), k,
sp(list_nos), sp(coarse_dis), sp(D), sp(I), False)
return D, I
return index_ivf.search_preassigned(xq, k, list_nos, coarse_dis)
def range_search_preassigned(index_ivf, x, radius, list_nos, coarse_dis=None):
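A small usage sketch for the simplified wrapper (illustrative, not part of the diff): the caller computes the coarse assignment explicitly and may edit it before scanning, e.g. dropping selected probes by setting the list number to -1 (invalid list numbers are skipped, as in the test at the end of this commit that masks sub_assign entries).

import faiss
import numpy as np
from faiss.contrib.ivf_tools import search_preassigned

d, nlist, k = 32, 100, 5
xb = np.random.rand(10000, d).astype('float32')
xq = np.random.rand(20, d).astype('float32')

index = faiss.index_factory(d, f"IVF{nlist},Flat")
index.train(xb)
index.add(xb)
index.nprobe = 8

# compute the coarse assignment explicitly...
Dq, Iq = index.quantizer.search(xq, index.nprobe)
# ...and drop the last probe of every query (-1 = skip this list)
Iq[:, -1] = -1

D, I = search_preassigned(index, xq, k, Iq, Dq)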

View File

@@ -1125,7 +1125,7 @@ void IndexIVF::replace_invlists(InvertedLists* il, bool own) {
void IndexIVF::copy_subset_to(
IndexIVF& other,
int subset_type,
InvertedLists::subset_type_t subset_type,
idx_t a1,
idx_t a2) const {
other.ntotal +=

View File

@@ -326,7 +326,7 @@ struct IndexIVF : Index, Level1Quantizer {
*/
virtual void copy_subset_to(
IndexIVF& other,
int subset_type,
InvertedLists::subset_type_t subset_type,
idx_t a1,
idx_t a2) const;

View File

@@ -33,92 +33,6 @@ void translate_labels(long n, idx_t* labels, long translation) {
}
}
/** Merge result tables from several shards. The per-shard results are assumed
* to be sorted. Note that the C comparator is reversed w.r.t. the usual top-k
* element heap because we want the best (ie. lowest for L2) result to be on
* top, not the worst.
*
* @param all_distances size nshard * n * k
* @param all_labels idem
* @param translations label translations to apply, size nshard
*/
template <class IndexClass, class C>
void merge_tables(
long n,
long k,
long nshard,
typename IndexClass::distance_t* distances,
idx_t* labels,
const std::vector<typename IndexClass::distance_t>& all_distances,
const std::vector<idx_t>& all_labels,
const std::vector<long>& translations) {
if (k == 0) {
return;
}
using distance_t = typename IndexClass::distance_t;
long stride = n * k;
#pragma omp parallel if (n * nshard * k > 100000)
{
std::vector<int> buf(2 * nshard);
// index in each shard's result list
int* pointer = buf.data();
// (shard_ids, heap_vals): heap that indexes
// shard -> current distance for this shard
int* shard_ids = pointer + nshard;
std::vector<distance_t> buf2(nshard);
distance_t* heap_vals = buf2.data();
#pragma omp for
for (long i = 0; i < n; i++) {
// the heap maps values to the shard where they are
// produced.
const distance_t* D_in = all_distances.data() + i * k;
const idx_t* I_in = all_labels.data() + i * k;
int heap_size = 0;
// push the first element of each shard (if not -1)
for (long s = 0; s < nshard; s++) {
pointer[s] = 0;
if (I_in[stride * s] >= 0) {
heap_push<C>(
++heap_size,
heap_vals,
shard_ids,
D_in[stride * s],
s);
}
}
distance_t* D = distances + i * k;
idx_t* I = labels + i * k;
int j;
for (j = 0; j < k && heap_size > 0; j++) {
// pop element from best shard
int s = shard_ids[0]; // top of heap
int& p = pointer[s];
D[j] = heap_vals[0];
I[j] = I_in[stride * s + p] + translations[s];
// pop from shard, advance pointer for this shard
heap_pop<C>(heap_size--, heap_vals, shard_ids);
p++;
if (p < k && I_in[stride * s + p] >= 0) {
heap_push<C>(
++heap_size,
heap_vals,
shard_ids,
D_in[stride * s + p],
s);
}
}
for (; j < k; j++) {
I[j] = -1;
D[j] = C::Crev::neutral();
}
}
}
}
} // anonymous namespace
template <typename IndexT>
@@ -303,27 +217,6 @@ void IndexShardsTemplate<IndexT>::search(
std::vector<distance_t> all_distances(nshard * k * n);
std::vector<idx_t> all_labels(nshard * k * n);
auto fn = [n, k, x, &all_distances, &all_labels](
int no, const IndexT* index) {
if (index->verbose) {
printf("begin query shard %d on %" PRId64 " points\n", no, n);
}
index->search(
n,
x,
k,
all_distances.data() + no * k * n,
all_labels.data() + no * k * n);
if (index->verbose) {
printf("end query shard %d\n", no);
}
};
this->runOnIndex(fn);
std::vector<long> translations(nshard, 0);
// Because we just called runOnIndex above, it is safe to access the
@@ -336,26 +229,47 @@ void IndexShardsTemplate<IndexT>::search(
}
}
auto fn = [n, k, x, &all_distances, &all_labels, &translations](
int no, const IndexT* index) {
if (index->verbose) {
printf("begin query shard %d on %" PRId64 " points\n", no, n);
}
index->search(
n,
x,
k,
all_distances.data() + no * k * n,
all_labels.data() + no * k * n);
translate_labels(
n * k, all_labels.data() + no * k * n, translations[no]);
if (index->verbose) {
printf("end query shard %d\n", no);
}
};
this->runOnIndex(fn);
if (this->metric_type == METRIC_L2) {
merge_tables<IndexT, CMin<distance_t, int>>(
merge_knn_results<idx_t, CMin<distance_t, int>>(
n,
k,
nshard,
all_distances.data(),
all_labels.data(),
distances,
labels,
all_distances,
all_labels,
translations);
labels);
} else {
merge_tables<IndexT, CMax<distance_t, int>>(
merge_knn_results<idx_t, CMax<distance_t, int>>(
n,
k,
nshard,
all_distances.data(),
all_labels.data(),
distances,
labels,
all_distances,
all_labels,
translations);
labels);
}
}

View File

@@ -71,7 +71,7 @@ struct IndexShardsTemplate : public ThreadedIndex<IndexT> {
* Cases (successive_ids, xids):
* - true, non-NULL ERROR: it makes no sense to pass in ids and
* request them to be shifted
* - true, NULL OK, but should be called only once (calls add()
* - true, NULL OK: but should be called only once (calls add()
* on sub-indexes).
* - false, non-NULL OK: will call add_with_ids with passed in xids
* distributed evenly over shards

View File

@@ -248,12 +248,14 @@ void ToGpuClonerMultiple::copy_ivf_shard(
if (verbose)
printf("IndexShards shard %ld indices %ld:%ld\n", i, i0, i1);
index_ivf->copy_subset_to(*idx2, 2, i0, i1);
index_ivf->copy_subset_to(
*idx2, InvertedLists::SUBSET_TYPE_ID_RANGE, i0, i1);
FAISS_ASSERT(idx2->ntotal == i1 - i0);
} else if (shard_type == 1) {
if (verbose)
printf("IndexShards shard %ld select modulo %ld = %ld\n", i, n, i);
index_ivf->copy_subset_to(*idx2, 1, n, i);
index_ivf->copy_subset_to(
*idx2, InvertedLists::SUBSET_TYPE_ID_MOD, n, i);
} else {
FAISS_THROW_FMT("shard_type %d not implemented", shard_type);
}

View File

@@ -636,7 +636,7 @@ void ZnSphereCodecRec::decode(uint64_t code, float* c) const {
}
}
// if not use_rec, instanciate an arbitrary harmless znc_rec
// if not use_rec, instantiate an arbitrary harmless znc_rec
ZnSphereCodecAlt::ZnSphereCodecAlt(int dim, int r2)
: ZnSphereCodec(dim, r2),
use_rec((dim & (dim - 1)) == 0),

View File

@@ -19,7 +19,7 @@
* otherwise register spilling becomes too large.
*
* The implementation of these functions is spread over 3 cpp files to reduce
* parallel compile times. Templates are instanciated explicitly.
* parallel compile times. Templates are instantiated explicitly.
*/
namespace faiss {

View File

@@ -189,7 +189,7 @@ void accumulate(
DISPATCH(3);
DISPATCH(4);
}
FAISS_THROW_FMT("accumulate nq=%d not instanciated", nq);
FAISS_THROW_FMT("accumulate nq=%d not instantiated", nq);
#undef DISPATCH
}
@@ -263,7 +263,7 @@ void pq4_accumulate_loop_qbs(
DISPATCH(4);
#undef DISPATCH
default:
FAISS_THROW_FMT("accumulate nq=%d not instanciated", nq);
FAISS_THROW_FMT("accumulate nq=%d not instantiated", nq);
}
i0 += nq;
LUT += nq * nsq * 16;

View File

@@ -88,13 +88,13 @@ void InvertedLists::merge_from(InvertedLists* oivf, size_t add_id) {
size_t InvertedLists::copy_subset_to(
InvertedLists& oivf,
int subset_type,
subset_type_t subset_type,
idx_t a1,
idx_t a2) const {
FAISS_THROW_IF_NOT(nlist == oivf.nlist);
FAISS_THROW_IF_NOT(code_size == oivf.code_size);
FAISS_THROW_IF_NOT_FMT(
subset_type >= 0 && subset_type <= 3,
subset_type >= 0 && subset_type <= 4,
"subset type %d not implemented",
subset_type);
size_t accu_n = 0;
@@ -111,7 +111,7 @@ size_t InvertedLists::copy_subset_to(
size_t n = list_size(list_no);
ScopedIds ids_in(this, list_no);
if (subset_type == 0) {
if (subset_type == SUBSET_TYPE_ID_RANGE) {
for (idx_t i = 0; i < n; i++) {
idx_t id = ids_in[i];
if (a1 <= id && id < a2) {
@@ -122,7 +122,7 @@ size_t InvertedLists::copy_subset_to(
n_added++;
}
}
} else if (subset_type == 1) {
} else if (subset_type == SUBSET_TYPE_ID_MOD) {
for (idx_t i = 0; i < n; i++) {
idx_t id = ids_in[i];
if (id % a1 == a2) {
@@ -133,7 +133,7 @@ size_t InvertedLists::copy_subset_to(
n_added++;
}
}
} else if (subset_type == 2) {
} else if (subset_type == SUBSET_TYPE_ELEMENT_RANGE) {
// see what is allocated to a1 and to a2
size_t next_accu_n = accu_n + n;
size_t next_accu_a1 = next_accu_n * a1 / ntotal;
@@ -151,7 +151,7 @@ size_t InvertedLists::copy_subset_to(
n_added += i2 - i1;
accu_a1 = next_accu_a1;
accu_a2 = next_accu_a2;
} else if (subset_type == 3) {
} else if (subset_type == SUBSET_TYPE_INVLIST_FRACTION) {
size_t i1 = n * a2 / a1;
size_t i2 = n * (a2 + 1) / a1;
@@ -163,6 +163,15 @@ size_t InvertedLists::copy_subset_to(
}
n_added += i2 - i1;
} else if (subset_type == SUBSET_TYPE_INVLIST) {
if (list_no >= a1 && list_no < a2) {
oivf.add_entries(
list_no,
n,
ScopedIds(this, list_no).get(),
ScopedCodes(this, list_no).get());
n_added += n;
}
}
accu_n += n;
}

View File

@@ -111,20 +111,28 @@ struct InvertedLists {
/// move all entries from oivf (empty on output)
void merge_from(InvertedLists* oivf, size_t add_id);
// how to copy a subset of elements from the inverted lists
// This depends on two integers, a1 and a2.
enum subset_type_t : int {
// depends on IDs
SUBSET_TYPE_ID_RANGE = 0, // copies ids in [a1, a2)
SUBSET_TYPE_ID_MOD = 1, // copies ids if id % a1 == a2
// depends on order within invlists
SUBSET_TYPE_ELEMENT_RANGE =
2, // copies fractions of invlists so that a1 elements are left
// before and a2 after
SUBSET_TYPE_INVLIST_FRACTION =
3, // take fraction a2 out of a1 from each invlist, 0 <= a2 < a1
// copy only inverted lists a1:a2
SUBSET_TYPE_INVLIST = 4
};
/** copy a subset of the entries index to the other index
*
* if subset_type == 0: copies ids in [a1, a2)
* if subset_type == 1: copies ids if id % a1 == a2
* if subset_type == 2: copies inverted lists such that a1
* elements are left before and a2 elements are after
* (insensitive to ids)
* if subset_type == 3: take fraction a2 out of a1 from each invlist
* (does not depend on ids). 0 <= a2 < a1
* @return number of entries copied
*/
size_t copy_subset_to(
InvertedLists& other,
int subset_type,
subset_type_t subset_type,
idx_t a1,
idx_t a2) const;
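To make the subset types concrete, here is a sketch of splitting an IVF index with copy_subset_to (illustrative, not part of the diff; it passes the raw enum values, which stay 0-3 for the existing types plus 4 for the new SUBSET_TYPE_INVLIST, and the empty_trained_clone helper is ours):

import faiss
import numpy as np

d, nlist = 16, 30
xb = np.random.rand(5000, d).astype('float32')

index = faiss.index_factory(d, f"IVF{nlist},Flat")
index.train(xb)
index.add(xb)

def empty_trained_clone(src):
    # helper (ours): same structure and training, no vectors
    dst = faiss.clone_index(src)
    dst.reset()
    return dst

# SUBSET_TYPE_INVLIST (4): copy whole inverted lists [a1, a2) to each shard
subs = [empty_trained_clone(index) for _ in range(2)]
index.copy_subset_to(subs[0], 4, 0, nlist // 2)
index.copy_subset_to(subs[1], 4, nlist // 2, nlist)
assert subs[0].ntotal + subs[1].ntotal == index.ntotal

# SUBSET_TYPE_ID_MOD (1): vector with a given id goes to shard id % 2
mods = [empty_trained_clone(index) for _ in range(2)]
index.copy_subset_to(mods[0], 1, 2, 0)
index.copy_subset_to(mods[1], 1, 2, 1)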

View File

@@ -21,7 +21,8 @@ from faiss.gpu_wrappers import *
from faiss.array_conversions import *
from faiss.extra_wrappers import kmin, kmax, pairwise_distances, rand, randint, \
lrand, randn, rand_smooth_vectors, eval_intersection, normalize_L2, \
ResultHeap, knn, Kmeans, checksum, matrix_bucket_sort_inplace, bucket_sort
ResultHeap, knn, Kmeans, checksum, matrix_bucket_sort_inplace, bucket_sort, \
merge_knn_results
__version__ = "%d.%d.%d" % (FAISS_VERSION_MAJOR,

View File

@@ -554,6 +554,70 @@ def handle_Index(the_class):
I = rev_swig_ptr(res.labels, nd).copy()
return lims, D, I
def replacement_search_preassigned(self, x, k, Iq, Dq, *, params=None, D=None, I=None):
"""Find the k nearest neighbors of the set of vectors x in an IVF index,
with precalculated coarse quantization assignment.
Parameters
----------
x : array_like
Query vectors, shape (n, d) where d is appropriate for the index.
`dtype` must be float32.
k : int
Number of nearest neighbors.
Iq : array_like
Nearest centroids, size (n, nprobe)
Dq : array_like, optional
Distance array to the centroids, size (n, nprobe)
params : SearchParameters, optional
Search parameters of the current search (not supported yet; must be None)
D : array_like, optional
Distance array to store the result.
I : array_like, optional
Labels array to store the results.
Returns
-------
D : array_like
Distances of the nearest neighbors, shape (n, k). When not enough results are found,
the distance is set to +Inf or -Inf.
I : array_like
Labels of the nearest neighbors, shape (n, k).
When not enough results are found, the label is set to -1
"""
n, d = x.shape
x = np.ascontiguousarray(x, dtype='float32')
assert d == self.d
assert k > 0
if D is None:
D = np.empty((n, k), dtype=np.float32)
else:
assert D.shape == (n, k)
if I is None:
I = np.empty((n, k), dtype=np.int64)
else:
assert I.shape == (n, k)
Iq = np.ascontiguousarray(Iq, dtype='int64')
assert params is None, "params not supported"
assert Iq.shape == (n, self.nprobe)
if Dq is not None:
Dq = np.ascontiguousarray(Dq, dtype='float32')
assert Dq.shape == Iq.shape
self.search_preassigned_c(
n, swig_ptr(x),
k,
swig_ptr(Iq), swig_ptr(Dq),
swig_ptr(D), swig_ptr(I),
False
)
return D, I
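A usage sketch for the new method (illustrative, not part of the diff): it mirrors Index.search but takes the coarse assignment as input, so with the quantizer's own top-nprobe assignment it should reproduce a regular search.

import faiss
import numpy as np

d, nlist, k = 24, 64, 4
xb = np.random.rand(8000, d).astype('float32')
xq = np.random.rand(10, d).astype('float32')

index = faiss.index_factory(d, f"IVF{nlist},Flat")
index.train(xb)
index.add(xb)
index.nprobe = 8

# quantize the queries ourselves, then search only the preassigned lists
Dq, Iq = index.quantizer.search(xq, index.nprobe)
D, I = index.search_preassigned(xq, k, Iq, Dq)

# should match a regular search with the same nprobe
Dref, Iref = index.search(xq, k)
np.testing.assert_array_equal(I, Iref)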
def replacement_sa_encode(self, x, codes=None):
n, d = x.shape
assert d == self.d
@@ -605,6 +669,8 @@ def handle_Index(the_class):
ignore_missing=True)
replace_method(the_class, 'search_and_reconstruct',
replacement_search_and_reconstruct, ignore_missing=True)
replace_method(the_class, 'search_preassigned',
replacement_search_preassigned, ignore_missing=True)
replace_method(the_class, 'sa_encode', replacement_sa_encode)
replace_method(the_class, 'sa_decode', replacement_sa_decode)
replace_method(the_class, 'add_sa_codes', replacement_add_sa_codes,
@@ -664,6 +730,31 @@ def handle_IndexBinary(the_class):
swig_ptr(labels))
return distances, labels
def replacement_search_preassigned(self, x, k, Iq, Dq):
n, d = x.shape
x = _check_dtype_uint8(x)
assert d * 8 == self.d
assert k > 0
D = np.empty((n, k), dtype=np.int32)
I = np.empty((n, k), dtype=np.int64)
Iq = np.ascontiguousarray(Iq, dtype='int64')
assert Iq.shape == (n, self.nprobe)
if Dq is not None:
Dq = np.ascontiguousarray(Dq, dtype='int32')
assert Dq.shape == Iq.shape
self.search_preassigned_c(
n, swig_ptr(x),
k,
swig_ptr(Iq), swig_ptr(Dq),
swig_ptr(D), swig_ptr(I),
False
)
return D, I
def replacement_range_search(self, x, thresh):
n, d = x.shape
x = _check_dtype_uint8(x)
@@ -693,6 +784,8 @@ def handle_IndexBinary(the_class):
replace_method(the_class, 'range_search', replacement_range_search)
replace_method(the_class, 'reconstruct', replacement_reconstruct)
replace_method(the_class, 'remove_ids', replacement_remove_ids)
replace_method(the_class, 'search_preassigned',
replacement_search_preassigned, ignore_missing=True)
def handle_VectorTransform(the_class):

View File

@@ -279,6 +279,23 @@ class ResultHeap:
self.heaps.reorder()
def merge_knn_results(Dall, Iall, keep_max=False):
"""
Merge a set of sorted knn results obtained from different shards of a dataset.
Dall and Iall are of size (nshard, nq, k); each Dall[i, j] should be sorted.
Returns D, I of size (nq, k) as the merged result set.
"""
assert Iall.shape == Dall.shape
nshard, n, k = Dall.shape
Dnew = np.empty((n, k), dtype=Dall.dtype)
Inew = np.empty((n, k), dtype=Iall.dtype)
func = merge_knn_results_CMax if keep_max else merge_knn_results_CMin
func(
n, k, nshard,
swig_ptr(Dall), swig_ptr(Iall),
swig_ptr(Dnew), swig_ptr(Inew)
)
return Dnew, Inew
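A quick usage sketch (illustrative, not part of the diff): merge the per-shard top-k tables of three shards. For inner-product results, where larger is better, pass keep_max=True.

import faiss
import numpy as np

nshard, nq, k = 3, 5, 4
rs = np.random.RandomState(0)
Dall = np.sort(rs.rand(nshard, nq, k).astype('float32'), axis=2)  # ascending per row
Iall = rs.randint(10000, size=(nshard, nq, k)).astype('int64')

D, I = faiss.merge_knn_results(Dall, Iall)  # keep_max=False: L2-style merge
assert D.shape == I.shape == (nq, k)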
######################################################
# KNN function

View File

@@ -938,7 +938,6 @@ REV_SWIG_PTR(uint64_t, NPY_UINT64);
%template(float_minheap_array_t) faiss::HeapArray<faiss::CMin<float, int64_t> >;
%template(int_minheap_array_t) faiss::HeapArray<faiss::CMin<int, int64_t> >;
%template(float_maxheap_array_t) faiss::HeapArray<faiss::CMax<float, int64_t> >;
%template(int_maxheap_array_t) faiss::HeapArray<faiss::CMax<int, int64_t> >;
@@ -951,46 +950,55 @@ REV_SWIG_PTR(uint64_t, NPY_UINT64);
%template(AlignedTableUint16) faiss::AlignedTable<uint16_t>;
%template(AlignedTableFloat32) faiss::AlignedTable<float>;
// SWIG seems to have some trouble resolving function template types here, so
// declare explicitly
%define INSTANTIATE_uint16_partition_fuzzy(C, id_t)
%inline %{
// SWIG seems to have has some trouble resolving the template type here, so
// declare explicitly
uint16_t CMax_uint16_partition_fuzzy(
uint16_t *vals, int64_t *ids, size_t n,
uint16_t C ## _uint16_partition_fuzzy(
uint16_t *vals, id_t *ids, size_t n,
size_t q_min, size_t q_max, size_t * q_out)
{
return faiss::partition_fuzzy<faiss::CMax<unsigned short, int64_t> >(
vals, ids, n, q_min, q_max, q_out);
}
uint16_t CMin_uint16_partition_fuzzy(
uint16_t *vals, int64_t *ids, size_t n,
size_t q_min, size_t q_max, size_t * q_out)
{
return faiss::partition_fuzzy<faiss::CMin<unsigned short, int64_t> >(
vals, ids, n, q_min, q_max, q_out);
}
// and overload with the int32 version
uint16_t CMax_uint16_partition_fuzzy(
uint16_t *vals, int *ids, size_t n,
size_t q_min, size_t q_max, size_t * q_out)
{
return faiss::partition_fuzzy<faiss::CMax<unsigned short, int> >(
vals, ids, n, q_min, q_max, q_out);
}
uint16_t CMin_uint16_partition_fuzzy(
uint16_t *vals, int *ids, size_t n,
size_t q_min, size_t q_max, size_t * q_out)
{
return faiss::partition_fuzzy<faiss::CMin<unsigned short, int> >(
return faiss::partition_fuzzy<faiss::C<unsigned short, id_t> >(
vals, ids, n, q_min, q_max, q_out);
}
%}
%enddef
INSTANTIATE_uint16_partition_fuzzy(CMin, int64_t)
INSTANTIATE_uint16_partition_fuzzy(CMax, int64_t)
INSTANTIATE_uint16_partition_fuzzy(CMin, int)
INSTANTIATE_uint16_partition_fuzzy(CMax, int)
// Same for merge_knn_results
// same define as the explicit instantiation in Heap.cpp
%define INSTANTIATE_merge_knn_results(C, distance_t)
%inline %{
void merge_knn_results_ ## C(
size_t n, size_t k, int nshard,
const distance_t *all_distances, const faiss::idx_t *all_labels,
distance_t *distances, faiss::idx_t *labels)
{
faiss::merge_knn_results<faiss::idx_t, faiss::C<distance_t, int>>(
n, k, nshard, all_distances, all_labels, distances, labels);
}
%}
%enddef
INSTANTIATE_merge_knn_results(CMin, float);
INSTANTIATE_merge_knn_results(CMax, float);
INSTANTIATE_merge_knn_results(CMin, int32_t);
INSTANTIATE_merge_knn_results(CMax, int32_t);
/*******************************************************************
* Expose a few basic functions
*******************************************************************/

View File

@@ -139,4 +139,111 @@ template struct HeapArray<CMax<float, int64_t>>;
template struct HeapArray<CMin<int, int64_t>>;
template struct HeapArray<CMax<int, int64_t>>;
/**********************************************************
* merge knn search results
**********************************************************/
/** Merge result tables from several shards. The per-shard results are assumed
* to be sorted. Note that the C comparator is reversed w.r.t. the usual top-k
* element heap because we want the best (ie. lowest for L2) result to be on
* top, not the worst.
*
* @param all_distances size (nshard, n, k)
* @param all_labels size (nshard, n, k)
* @param distances output distances, size (n, k)
* @param labels output labels, size (n, k)
*/
template <class idx_t, class C>
void merge_knn_results(
size_t n,
size_t k,
typename C::TI nshard,
const typename C::T* all_distances,
const idx_t* all_labels,
typename C::T* distances,
idx_t* labels) {
using distance_t = typename C::T;
if (k == 0) {
return;
}
long stride = n * k;
#pragma omp parallel if (n * nshard * k > 100000)
{
std::vector<int> buf(2 * nshard);
// index in each shard's result list
int* pointer = buf.data();
// (shard_ids, heap_vals): heap that indexes
// shard -> current distance for this shard
int* shard_ids = pointer + nshard;
std::vector<distance_t> buf2(nshard);
distance_t* heap_vals = buf2.data();
#pragma omp for
for (long i = 0; i < n; i++) {
// the heap maps values to the shard where they are
// produced.
const distance_t* D_in = all_distances + i * k;
const idx_t* I_in = all_labels + i * k;
int heap_size = 0;
// push the first element of each shard (if not -1)
for (long s = 0; s < nshard; s++) {
pointer[s] = 0;
if (I_in[stride * s] >= 0) {
heap_push<C>(
++heap_size,
heap_vals,
shard_ids,
D_in[stride * s],
s);
}
}
distance_t* D = distances + i * k;
idx_t* I = labels + i * k;
int j;
for (j = 0; j < k && heap_size > 0; j++) {
// pop element from best shard
int s = shard_ids[0]; // top of heap
int& p = pointer[s];
D[j] = heap_vals[0];
I[j] = I_in[stride * s + p];
// pop from shard, advance pointer for this shard
heap_pop<C>(heap_size--, heap_vals, shard_ids);
p++;
if (p < k && I_in[stride * s + p] >= 0) {
heap_push<C>(
++heap_size,
heap_vals,
shard_ids,
D_in[stride * s + p],
s);
}
}
for (; j < k; j++) {
I[j] = -1;
D[j] = C::Crev::neutral();
}
}
}
}
// explicit instantiations
#define INSTANTIATE(C, distance_t) \
template void merge_knn_results<int64_t, C<distance_t, int>>( \
size_t, \
size_t, \
int, \
const distance_t*, \
const int64_t*, \
distance_t*, \
int64_t*);
INSTANTIATE(CMin, float);
INSTANTIATE(CMax, float);
INSTANTIATE(CMin, int32_t);
INSTANTIATE(CMax, int32_t);
} // namespace faiss
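For reference, the kernel above is a k-way merge: a small heap keyed on each shard's current head distance pops the best shard, emits its head, and pushes that shard's next valid entry. A behaviorally equivalent pure-Python sketch of the min (L2) case, using heapq (illustrative only; the real implementation is the templated C++ above, and the names are ours):

import heapq
import numpy as np

def merge_knn_results_ref(Dall, Iall):
    """Dall, Iall: (nshard, n, k); float distances, each Dall[s, i] sorted
    ascending, labels of -1 mark missing results. Returns merged (n, k) tables."""
    nshard, n, k = Dall.shape
    D = np.full((n, k), np.inf, dtype=Dall.dtype)  # padding, like C::Crev::neutral()
    I = np.full((n, k), -1, dtype='int64')
    for i in range(n):
        pointer = [0] * nshard  # position in each shard's result list
        # heap of (current head distance, shard), best shard on top
        heap = [(Dall[s, i, 0], s) for s in range(nshard) if Iall[s, i, 0] >= 0]
        heapq.heapify(heap)
        for j in range(k):
            if not heap:
                break  # remaining slots stay at (-1, inf)
            dis, s = heapq.heappop(heap)
            p = pointer[s]
            D[i, j], I[i, j] = dis, Iall[s, i, p]
            pointer[s] = p = p + 1
            if p < k and Iall[s, i, p] >= 0:
                heapq.heappush(heap, (Dall[s, i, p], s))
    return D, I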

View File

@@ -444,7 +444,7 @@ typedef HeapArray<CMin<int, int64_t>> int_minheap_array_t;
typedef HeapArray<CMax<float, int64_t>> float_maxheap_array_t;
typedef HeapArray<CMax<int, int64_t>> int_maxheap_array_t;
// The heap templates are instanciated explicitly in Heap.cpp
// The heap templates are instantiated explicitly in Heap.cpp
/*********************************************************************
* Indirect heaps: instead of having
@@ -505,6 +505,27 @@ inline void indirect_heap_push(
bh_ids[i] = id;
}
/** Merge result tables from several shards. The per-shard results are assumed
* to be sorted. Note that the C comparator is reversed w.r.t. the usual top-k
* element heap because we want the best (ie. lowest for L2) result to be on
* top, not the worst. Also, the comparator's index type needs to hold a
* shard id (i.e. int32 is usually more than enough).
*
* @param all_distances size (nshard, n, k)
* @param all_labels size (nshard, n, k)
* @param distances output distances, size (n, k)
* @param labels output labels, size (n, k)
*/
template <class idx_t, class C>
void merge_knn_results(
size_t n,
size_t k,
typename C::TI nshard,
const typename C::T* all_distances,
const idx_t* all_labels,
typename C::T* distances,
idx_t* labels);
} // namespace faiss
#endif /* FAISS_Heap_h */

View File

@@ -645,3 +645,45 @@ class TestBucketSort(unittest.TestCase):
def test_bucket_sort_inplace_parallel_fewbucket(self):
self.do_test_bucket_sort_inplace(4, nbucket=5)
class TestMergeKNNResults(unittest.TestCase):
def do_test(self, ismax, dtype):
rs = np.random.RandomState()
n, k, nshard = 10, 5, 3
all_ids = rs.randint(100000, size=(nshard, n, k)).astype('int64')
all_dis = rs.rand(nshard, n, k)
if dtype == 'int32':
all_dis = (all_dis * 1000000).astype("int32")
else:
all_dis = all_dis.astype(dtype)
for i in range(nshard):
for j in range(n):
all_dis[i, j].sort()
if ismax:
all_dis[i, j] = all_dis[i, j][::-1]
Dref = np.zeros((n, k), dtype=dtype)
Iref = np.zeros((n, k), dtype='int64')
for i in range(n):
dis = all_dis[:, i, :].ravel()
ids = all_ids[:, i, :].ravel()
o = dis.argsort()
if ismax:
o = o[::-1]
Dref[i] = dis[o[:k]]
Iref[i] = ids[o[:k]]
Dnew, Inew = faiss.merge_knn_results(all_dis, all_ids, keep_max=ismax)
np.testing.assert_array_equal(Dnew, Dref)
np.testing.assert_array_equal(Inew, Iref)
def test_min_float(self):
self.do_test(ismax=False, dtype='float32')
def test_max_int(self):
self.do_test(ismax=True, dtype='int32')
def test_max_float(self):
self.do_test(ismax=True, dtype='float32')

View File

@@ -687,6 +687,7 @@ class TestSplitMerge(unittest.TestCase):
sub_indexes = [faiss.clone_index(index) for i in range(nsplit)]
index.add(xb)
Dref, Iref = index.search(xq, 10)
nlist = index.nlist
for i in range(nsplit):
if subset_type in (1, 3):
index.copy_subset_to(sub_indexes[i], subset_type, nsplit, i)
@@ -694,6 +695,10 @@ class TestSplitMerge(unittest.TestCase):
j0 = index.ntotal * i // nsplit
j1 = index.ntotal * (i + 1) // nsplit
index.copy_subset_to(sub_indexes[i], subset_type, j0, j1)
elif subset_type == 4:
index.copy_subset_to(
sub_indexes[i], subset_type,
i * nlist // nsplit, (i + 1) * nlist // nsplit)
index_shards = faiss.IndexShards(False, False)
for i in range(nsplit):
@@ -713,3 +718,6 @@ class TestSplitMerge(unittest.TestCase):
def test_Flat_subset_type_3(self):
self.do_test("IVF30,Flat", subset_type=3)
def test_Flat_subset_type_4(self):
self.do_test("IVF30,Flat", subset_type=4)

View File

@@ -47,10 +47,8 @@ def search_single_scan(index, xq, k, bs=128):
sub_assign[skip_rows, skip_cols] = -1
index.search_preassigned(
nq, faiss.swig_ptr(xq), k,
faiss.swig_ptr(sub_assign), faiss.swig_ptr(coarse_dis),
faiss.swig_ptr(rh.D), faiss.swig_ptr(rh.I),
False, None
xq, k, sub_assign, coarse_dis,
D=rh.D, I=rh.I
)
rh.finalize()