building blocks for hybrid CPU / GPU search (#2638)
Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2638 This diff is a more streamlined way of searching IVF indexes with precomputed clusters. This will be used for experiments with hybrid CPU / GPU search. Reviewed By: algoriddle Differential Revision: D41301032 fbshipit-source-id: a1d645fd0f2bf806454dfd04971edc0a6200d20dpull/2676/head
parent
1eb4f42639
commit
8fc3775472
|
@ -70,7 +70,7 @@ if args.factory_string == "":
|
||||||
else:
|
else:
|
||||||
factory_string = args.factory_string
|
factory_string = args.factory_string
|
||||||
|
|
||||||
print(f"instanciate {factory_string}")
|
print(f"instantiate {factory_string}")
|
||||||
index = faiss.index_factory(ds.d, factory_string)
|
index = faiss.index_factory(ds.d, factory_string)
|
||||||
|
|
||||||
if args.factory_string != "":
|
if args.factory_string != "":
|
||||||
|
|
|
@ -32,7 +32,9 @@ def add_preassigned(index_ivf, x, a, ids=None):
|
||||||
|
|
||||||
def search_preassigned(index_ivf, xq, k, list_nos, coarse_dis=None):
|
def search_preassigned(index_ivf, xq, k, list_nos, coarse_dis=None):
|
||||||
"""
|
"""
|
||||||
Perform a search in the IVF index, with predefined lists to search into
|
Perform a search in the IVF index, with predefined lists to search into.
|
||||||
|
Supports indexes with pretransforms (as opposed to the
|
||||||
|
IndexIVF.search_preassigned, that cannot be applied with pretransform).
|
||||||
"""
|
"""
|
||||||
n, d = xq.shape
|
n, d = xq.shape
|
||||||
if isinstance(index_ivf, faiss.IndexBinaryIVF):
|
if isinstance(index_ivf, faiss.IndexBinaryIVF):
|
||||||
|
@ -51,14 +53,7 @@ def search_preassigned(index_ivf, xq, k, list_nos, coarse_dis=None):
|
||||||
else:
|
else:
|
||||||
assert coarse_dis.shape == (n, index_ivf.nprobe)
|
assert coarse_dis.shape == (n, index_ivf.nprobe)
|
||||||
|
|
||||||
D = np.empty((n, k), dtype=dis_type)
|
return index_ivf.search_preassigned(xq, k, list_nos, coarse_dis)
|
||||||
I = np.empty((n, k), dtype='int64')
|
|
||||||
|
|
||||||
sp = faiss.swig_ptr
|
|
||||||
index_ivf.search_preassigned(
|
|
||||||
n, sp(xq), k,
|
|
||||||
sp(list_nos), sp(coarse_dis), sp(D), sp(I), False)
|
|
||||||
return D, I
|
|
||||||
|
|
||||||
|
|
||||||
def range_search_preassigned(index_ivf, x, radius, list_nos, coarse_dis=None):
|
def range_search_preassigned(index_ivf, x, radius, list_nos, coarse_dis=None):
|
||||||
|
|
|
@ -1125,7 +1125,7 @@ void IndexIVF::replace_invlists(InvertedLists* il, bool own) {
|
||||||
|
|
||||||
void IndexIVF::copy_subset_to(
|
void IndexIVF::copy_subset_to(
|
||||||
IndexIVF& other,
|
IndexIVF& other,
|
||||||
int subset_type,
|
InvertedLists::subset_type_t subset_type,
|
||||||
idx_t a1,
|
idx_t a1,
|
||||||
idx_t a2) const {
|
idx_t a2) const {
|
||||||
other.ntotal +=
|
other.ntotal +=
|
||||||
|
|
|
@ -326,7 +326,7 @@ struct IndexIVF : Index, Level1Quantizer {
|
||||||
*/
|
*/
|
||||||
virtual void copy_subset_to(
|
virtual void copy_subset_to(
|
||||||
IndexIVF& other,
|
IndexIVF& other,
|
||||||
int subset_type,
|
InvertedLists::subset_type_t subset_type,
|
||||||
idx_t a1,
|
idx_t a1,
|
||||||
idx_t a2) const;
|
idx_t a2) const;
|
||||||
|
|
||||||
|
|
|
@ -33,92 +33,6 @@ void translate_labels(long n, idx_t* labels, long translation) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Merge result tables from several shards. The per-shard results are assumed
|
|
||||||
* to be sorted. Note that the C comparator is reversed w.r.t. the usual top-k
|
|
||||||
* element heap because we want the best (ie. lowest for L2) result to be on
|
|
||||||
* top, not the worst.
|
|
||||||
*
|
|
||||||
* @param all_distances size nshard * n * k
|
|
||||||
* @param all_labels idem
|
|
||||||
* @param translations label translations to apply, size nshard
|
|
||||||
*/
|
|
||||||
template <class IndexClass, class C>
|
|
||||||
void merge_tables(
|
|
||||||
long n,
|
|
||||||
long k,
|
|
||||||
long nshard,
|
|
||||||
typename IndexClass::distance_t* distances,
|
|
||||||
idx_t* labels,
|
|
||||||
const std::vector<typename IndexClass::distance_t>& all_distances,
|
|
||||||
const std::vector<idx_t>& all_labels,
|
|
||||||
const std::vector<long>& translations) {
|
|
||||||
if (k == 0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
using distance_t = typename IndexClass::distance_t;
|
|
||||||
long stride = n * k;
|
|
||||||
#pragma omp parallel if (n * nshard * k > 100000)
|
|
||||||
{
|
|
||||||
std::vector<int> buf(2 * nshard);
|
|
||||||
// index in each shard's result list
|
|
||||||
int* pointer = buf.data();
|
|
||||||
// (shard_ids, heap_vals): heap that indexes
|
|
||||||
// shard -> current distance for this shard
|
|
||||||
int* shard_ids = pointer + nshard;
|
|
||||||
std::vector<distance_t> buf2(nshard);
|
|
||||||
distance_t* heap_vals = buf2.data();
|
|
||||||
#pragma omp for
|
|
||||||
for (long i = 0; i < n; i++) {
|
|
||||||
// the heap maps values to the shard where they are
|
|
||||||
// produced.
|
|
||||||
const distance_t* D_in = all_distances.data() + i * k;
|
|
||||||
const idx_t* I_in = all_labels.data() + i * k;
|
|
||||||
int heap_size = 0;
|
|
||||||
|
|
||||||
// push the first element of each shard (if not -1)
|
|
||||||
for (long s = 0; s < nshard; s++) {
|
|
||||||
pointer[s] = 0;
|
|
||||||
if (I_in[stride * s] >= 0) {
|
|
||||||
heap_push<C>(
|
|
||||||
++heap_size,
|
|
||||||
heap_vals,
|
|
||||||
shard_ids,
|
|
||||||
D_in[stride * s],
|
|
||||||
s);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
distance_t* D = distances + i * k;
|
|
||||||
idx_t* I = labels + i * k;
|
|
||||||
|
|
||||||
int j;
|
|
||||||
for (j = 0; j < k && heap_size > 0; j++) {
|
|
||||||
// pop element from best shard
|
|
||||||
int s = shard_ids[0]; // top of heap
|
|
||||||
int& p = pointer[s];
|
|
||||||
D[j] = heap_vals[0];
|
|
||||||
I[j] = I_in[stride * s + p] + translations[s];
|
|
||||||
|
|
||||||
// pop from shard, advance pointer for this shard
|
|
||||||
heap_pop<C>(heap_size--, heap_vals, shard_ids);
|
|
||||||
p++;
|
|
||||||
if (p < k && I_in[stride * s + p] >= 0) {
|
|
||||||
heap_push<C>(
|
|
||||||
++heap_size,
|
|
||||||
heap_vals,
|
|
||||||
shard_ids,
|
|
||||||
D_in[stride * s + p],
|
|
||||||
s);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (; j < k; j++) {
|
|
||||||
I[j] = -1;
|
|
||||||
D[j] = C::Crev::neutral();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
template <typename IndexT>
|
template <typename IndexT>
|
||||||
|
@ -303,27 +217,6 @@ void IndexShardsTemplate<IndexT>::search(
|
||||||
|
|
||||||
std::vector<distance_t> all_distances(nshard * k * n);
|
std::vector<distance_t> all_distances(nshard * k * n);
|
||||||
std::vector<idx_t> all_labels(nshard * k * n);
|
std::vector<idx_t> all_labels(nshard * k * n);
|
||||||
|
|
||||||
auto fn = [n, k, x, &all_distances, &all_labels](
|
|
||||||
int no, const IndexT* index) {
|
|
||||||
if (index->verbose) {
|
|
||||||
printf("begin query shard %d on %" PRId64 " points\n", no, n);
|
|
||||||
}
|
|
||||||
|
|
||||||
index->search(
|
|
||||||
n,
|
|
||||||
x,
|
|
||||||
k,
|
|
||||||
all_distances.data() + no * k * n,
|
|
||||||
all_labels.data() + no * k * n);
|
|
||||||
|
|
||||||
if (index->verbose) {
|
|
||||||
printf("end query shard %d\n", no);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
this->runOnIndex(fn);
|
|
||||||
|
|
||||||
std::vector<long> translations(nshard, 0);
|
std::vector<long> translations(nshard, 0);
|
||||||
|
|
||||||
// Because we just called runOnIndex above, it is safe to access the
|
// Because we just called runOnIndex above, it is safe to access the
|
||||||
|
@ -336,26 +229,47 @@ void IndexShardsTemplate<IndexT>::search(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto fn = [n, k, x, &all_distances, &all_labels, &translations](
|
||||||
|
int no, const IndexT* index) {
|
||||||
|
if (index->verbose) {
|
||||||
|
printf("begin query shard %d on %" PRId64 " points\n", no, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
index->search(
|
||||||
|
n,
|
||||||
|
x,
|
||||||
|
k,
|
||||||
|
all_distances.data() + no * k * n,
|
||||||
|
all_labels.data() + no * k * n);
|
||||||
|
|
||||||
|
translate_labels(
|
||||||
|
n * k, all_labels.data() + no * k * n, translations[no]);
|
||||||
|
|
||||||
|
if (index->verbose) {
|
||||||
|
printf("end query shard %d\n", no);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
this->runOnIndex(fn);
|
||||||
|
|
||||||
if (this->metric_type == METRIC_L2) {
|
if (this->metric_type == METRIC_L2) {
|
||||||
merge_tables<IndexT, CMin<distance_t, int>>(
|
merge_knn_results<idx_t, CMin<distance_t, int>>(
|
||||||
n,
|
n,
|
||||||
k,
|
k,
|
||||||
nshard,
|
nshard,
|
||||||
|
all_distances.data(),
|
||||||
|
all_labels.data(),
|
||||||
distances,
|
distances,
|
||||||
labels,
|
labels);
|
||||||
all_distances,
|
|
||||||
all_labels,
|
|
||||||
translations);
|
|
||||||
} else {
|
} else {
|
||||||
merge_tables<IndexT, CMax<distance_t, int>>(
|
merge_knn_results<idx_t, CMax<distance_t, int>>(
|
||||||
n,
|
n,
|
||||||
k,
|
k,
|
||||||
nshard,
|
nshard,
|
||||||
|
all_distances.data(),
|
||||||
|
all_labels.data(),
|
||||||
distances,
|
distances,
|
||||||
labels,
|
labels);
|
||||||
all_distances,
|
|
||||||
all_labels,
|
|
||||||
translations);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -71,7 +71,7 @@ struct IndexShardsTemplate : public ThreadedIndex<IndexT> {
|
||||||
* Cases (successive_ids, xids):
|
* Cases (successive_ids, xids):
|
||||||
* - true, non-NULL ERROR: it makes no sense to pass in ids and
|
* - true, non-NULL ERROR: it makes no sense to pass in ids and
|
||||||
* request them to be shifted
|
* request them to be shifted
|
||||||
* - true, NULL OK, but should be called only once (calls add()
|
* - true, NULL OK: but should be called only once (calls add()
|
||||||
* on sub-indexes).
|
* on sub-indexes).
|
||||||
* - false, non-NULL OK: will call add_with_ids with passed in xids
|
* - false, non-NULL OK: will call add_with_ids with passed in xids
|
||||||
* distributed evenly over shards
|
* distributed evenly over shards
|
||||||
|
|
|
@ -248,12 +248,14 @@ void ToGpuClonerMultiple::copy_ivf_shard(
|
||||||
|
|
||||||
if (verbose)
|
if (verbose)
|
||||||
printf("IndexShards shard %ld indices %ld:%ld\n", i, i0, i1);
|
printf("IndexShards shard %ld indices %ld:%ld\n", i, i0, i1);
|
||||||
index_ivf->copy_subset_to(*idx2, 2, i0, i1);
|
index_ivf->copy_subset_to(
|
||||||
|
*idx2, InvertedLists::SUBSET_TYPE_ID_RANGE, i0, i1);
|
||||||
FAISS_ASSERT(idx2->ntotal == i1 - i0);
|
FAISS_ASSERT(idx2->ntotal == i1 - i0);
|
||||||
} else if (shard_type == 1) {
|
} else if (shard_type == 1) {
|
||||||
if (verbose)
|
if (verbose)
|
||||||
printf("IndexShards shard %ld select modulo %ld = %ld\n", i, n, i);
|
printf("IndexShards shard %ld select modulo %ld = %ld\n", i, n, i);
|
||||||
index_ivf->copy_subset_to(*idx2, 1, n, i);
|
index_ivf->copy_subset_to(
|
||||||
|
*idx2, InvertedLists::SUBSET_TYPE_ID_MOD, n, i);
|
||||||
} else {
|
} else {
|
||||||
FAISS_THROW_FMT("shard_type %d not implemented", shard_type);
|
FAISS_THROW_FMT("shard_type %d not implemented", shard_type);
|
||||||
}
|
}
|
||||||
|
|
|
@ -636,7 +636,7 @@ void ZnSphereCodecRec::decode(uint64_t code, float* c) const {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// if not use_rec, instanciate an arbitrary harmless znc_rec
|
// if not use_rec, instantiate an arbitrary harmless znc_rec
|
||||||
ZnSphereCodecAlt::ZnSphereCodecAlt(int dim, int r2)
|
ZnSphereCodecAlt::ZnSphereCodecAlt(int dim, int r2)
|
||||||
: ZnSphereCodec(dim, r2),
|
: ZnSphereCodec(dim, r2),
|
||||||
use_rec((dim & (dim - 1)) == 0),
|
use_rec((dim & (dim - 1)) == 0),
|
||||||
|
|
|
@ -19,7 +19,7 @@
|
||||||
* otherwise register spilling becomes too large.
|
* otherwise register spilling becomes too large.
|
||||||
*
|
*
|
||||||
* The implementation of these functions is spread over 3 cpp files to reduce
|
* The implementation of these functions is spread over 3 cpp files to reduce
|
||||||
* parallel compile times. Templates are instanciated explicitly.
|
* parallel compile times. Templates are instantiated explicitly.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
namespace faiss {
|
namespace faiss {
|
||||||
|
|
|
@ -189,7 +189,7 @@ void accumulate(
|
||||||
DISPATCH(3);
|
DISPATCH(3);
|
||||||
DISPATCH(4);
|
DISPATCH(4);
|
||||||
}
|
}
|
||||||
FAISS_THROW_FMT("accumulate nq=%d not instanciated", nq);
|
FAISS_THROW_FMT("accumulate nq=%d not instantiated", nq);
|
||||||
|
|
||||||
#undef DISPATCH
|
#undef DISPATCH
|
||||||
}
|
}
|
||||||
|
@ -263,7 +263,7 @@ void pq4_accumulate_loop_qbs(
|
||||||
DISPATCH(4);
|
DISPATCH(4);
|
||||||
#undef DISPATCH
|
#undef DISPATCH
|
||||||
default:
|
default:
|
||||||
FAISS_THROW_FMT("accumulate nq=%d not instanciated", nq);
|
FAISS_THROW_FMT("accumulate nq=%d not instantiated", nq);
|
||||||
}
|
}
|
||||||
i0 += nq;
|
i0 += nq;
|
||||||
LUT += nq * nsq * 16;
|
LUT += nq * nsq * 16;
|
||||||
|
|
|
@ -88,13 +88,13 @@ void InvertedLists::merge_from(InvertedLists* oivf, size_t add_id) {
|
||||||
|
|
||||||
size_t InvertedLists::copy_subset_to(
|
size_t InvertedLists::copy_subset_to(
|
||||||
InvertedLists& oivf,
|
InvertedLists& oivf,
|
||||||
int subset_type,
|
subset_type_t subset_type,
|
||||||
idx_t a1,
|
idx_t a1,
|
||||||
idx_t a2) const {
|
idx_t a2) const {
|
||||||
FAISS_THROW_IF_NOT(nlist == oivf.nlist);
|
FAISS_THROW_IF_NOT(nlist == oivf.nlist);
|
||||||
FAISS_THROW_IF_NOT(code_size == oivf.code_size);
|
FAISS_THROW_IF_NOT(code_size == oivf.code_size);
|
||||||
FAISS_THROW_IF_NOT_FMT(
|
FAISS_THROW_IF_NOT_FMT(
|
||||||
subset_type >= 0 && subset_type <= 3,
|
subset_type >= 0 && subset_type <= 4,
|
||||||
"subset type %d not implemented",
|
"subset type %d not implemented",
|
||||||
subset_type);
|
subset_type);
|
||||||
size_t accu_n = 0;
|
size_t accu_n = 0;
|
||||||
|
@ -111,7 +111,7 @@ size_t InvertedLists::copy_subset_to(
|
||||||
size_t n = list_size(list_no);
|
size_t n = list_size(list_no);
|
||||||
ScopedIds ids_in(this, list_no);
|
ScopedIds ids_in(this, list_no);
|
||||||
|
|
||||||
if (subset_type == 0) {
|
if (subset_type == SUBSET_TYPE_ID_RANGE) {
|
||||||
for (idx_t i = 0; i < n; i++) {
|
for (idx_t i = 0; i < n; i++) {
|
||||||
idx_t id = ids_in[i];
|
idx_t id = ids_in[i];
|
||||||
if (a1 <= id && id < a2) {
|
if (a1 <= id && id < a2) {
|
||||||
|
@ -122,7 +122,7 @@ size_t InvertedLists::copy_subset_to(
|
||||||
n_added++;
|
n_added++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if (subset_type == 1) {
|
} else if (subset_type == SUBSET_TYPE_ID_MOD) {
|
||||||
for (idx_t i = 0; i < n; i++) {
|
for (idx_t i = 0; i < n; i++) {
|
||||||
idx_t id = ids_in[i];
|
idx_t id = ids_in[i];
|
||||||
if (id % a1 == a2) {
|
if (id % a1 == a2) {
|
||||||
|
@ -133,7 +133,7 @@ size_t InvertedLists::copy_subset_to(
|
||||||
n_added++;
|
n_added++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if (subset_type == 2) {
|
} else if (subset_type == SUBSET_TYPE_ELEMENT_RANGE) {
|
||||||
// see what is allocated to a1 and to a2
|
// see what is allocated to a1 and to a2
|
||||||
size_t next_accu_n = accu_n + n;
|
size_t next_accu_n = accu_n + n;
|
||||||
size_t next_accu_a1 = next_accu_n * a1 / ntotal;
|
size_t next_accu_a1 = next_accu_n * a1 / ntotal;
|
||||||
|
@ -151,7 +151,7 @@ size_t InvertedLists::copy_subset_to(
|
||||||
n_added += i2 - i1;
|
n_added += i2 - i1;
|
||||||
accu_a1 = next_accu_a1;
|
accu_a1 = next_accu_a1;
|
||||||
accu_a2 = next_accu_a2;
|
accu_a2 = next_accu_a2;
|
||||||
} else if (subset_type == 3) {
|
} else if (subset_type == SUBSET_TYPE_INVLIST_FRACTION) {
|
||||||
size_t i1 = n * a2 / a1;
|
size_t i1 = n * a2 / a1;
|
||||||
size_t i2 = n * (a2 + 1) / a1;
|
size_t i2 = n * (a2 + 1) / a1;
|
||||||
|
|
||||||
|
@ -163,6 +163,15 @@ size_t InvertedLists::copy_subset_to(
|
||||||
}
|
}
|
||||||
|
|
||||||
n_added += i2 - i1;
|
n_added += i2 - i1;
|
||||||
|
} else if (subset_type == SUBSET_TYPE_INVLIST) {
|
||||||
|
if (list_no >= a1 && list_no < a2) {
|
||||||
|
oivf.add_entries(
|
||||||
|
list_no,
|
||||||
|
n,
|
||||||
|
ScopedIds(this, list_no).get(),
|
||||||
|
ScopedCodes(this, list_no).get());
|
||||||
|
n_added += n;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
accu_n += n;
|
accu_n += n;
|
||||||
}
|
}
|
||||||
|
|
|
@ -111,20 +111,28 @@ struct InvertedLists {
|
||||||
/// move all entries from oivf (empty on output)
|
/// move all entries from oivf (empty on output)
|
||||||
void merge_from(InvertedLists* oivf, size_t add_id);
|
void merge_from(InvertedLists* oivf, size_t add_id);
|
||||||
|
|
||||||
|
// how to copy a subset of elements from the inverted lists
|
||||||
|
// This depends on two integers, a1 and a2.
|
||||||
|
enum subset_type_t : int {
|
||||||
|
// depends on IDs
|
||||||
|
SUBSET_TYPE_ID_RANGE = 0, // copies ids in [a1, a2)
|
||||||
|
SUBSET_TYPE_ID_MOD = 1, // copies ids if id % a1 == a2
|
||||||
|
// depends on order within invlists
|
||||||
|
SUBSET_TYPE_ELEMENT_RANGE =
|
||||||
|
2, // copies fractions of invlists so that a1 elements are left
|
||||||
|
// before and a2 after
|
||||||
|
SUBSET_TYPE_INVLIST_FRACTION =
|
||||||
|
3, // take fraction a2 out of a1 from each invlist, 0 <= a2 < a1
|
||||||
|
// copy only inverted lists a1:a2
|
||||||
|
SUBSET_TYPE_INVLIST = 4
|
||||||
|
};
|
||||||
|
|
||||||
/** copy a subset of the entries index to the other index
|
/** copy a subset of the entries index to the other index
|
||||||
*
|
|
||||||
* if subset_type == 0: copies ids in [a1, a2)
|
|
||||||
* if subset_type == 1: copies ids if id % a1 == a2
|
|
||||||
* if subset_type == 2: copies inverted lists such that a1
|
|
||||||
* elements are left before and a2 elements are after
|
|
||||||
* (insensitive to ids)
|
|
||||||
* if subset_type == 3: take fraction a2 out of a1 from each invlist
|
|
||||||
* (does not depend on ids). 0 <= a2 < a1
|
|
||||||
* @return number of entries copied
|
* @return number of entries copied
|
||||||
*/
|
*/
|
||||||
size_t copy_subset_to(
|
size_t copy_subset_to(
|
||||||
InvertedLists& other,
|
InvertedLists& other,
|
||||||
int subset_type,
|
subset_type_t subset_type,
|
||||||
idx_t a1,
|
idx_t a1,
|
||||||
idx_t a2) const;
|
idx_t a2) const;
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,8 @@ from faiss.gpu_wrappers import *
|
||||||
from faiss.array_conversions import *
|
from faiss.array_conversions import *
|
||||||
from faiss.extra_wrappers import kmin, kmax, pairwise_distances, rand, randint, \
|
from faiss.extra_wrappers import kmin, kmax, pairwise_distances, rand, randint, \
|
||||||
lrand, randn, rand_smooth_vectors, eval_intersection, normalize_L2, \
|
lrand, randn, rand_smooth_vectors, eval_intersection, normalize_L2, \
|
||||||
ResultHeap, knn, Kmeans, checksum, matrix_bucket_sort_inplace, bucket_sort
|
ResultHeap, knn, Kmeans, checksum, matrix_bucket_sort_inplace, bucket_sort, \
|
||||||
|
merge_knn_results
|
||||||
|
|
||||||
|
|
||||||
__version__ = "%d.%d.%d" % (FAISS_VERSION_MAJOR,
|
__version__ = "%d.%d.%d" % (FAISS_VERSION_MAJOR,
|
||||||
|
|
|
@ -554,6 +554,70 @@ def handle_Index(the_class):
|
||||||
I = rev_swig_ptr(res.labels, nd).copy()
|
I = rev_swig_ptr(res.labels, nd).copy()
|
||||||
return lims, D, I
|
return lims, D, I
|
||||||
|
|
||||||
|
def replacement_search_preassigned(self, x, k, Iq, Dq, *, params=None, D=None, I=None):
|
||||||
|
"""Find the k nearest neighbors of the set of vectors x in an IVF index,
|
||||||
|
with precalculated coarse quantization assignment.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
x : array_like
|
||||||
|
Query vectors, shape (n, d) where d is appropriate for the index.
|
||||||
|
`dtype` must be float32.
|
||||||
|
k : int
|
||||||
|
Number of nearest neighbors.
|
||||||
|
Dq : array_like, optional
|
||||||
|
Distance array to the centroids, size (n, nprobe)
|
||||||
|
Iq : array_like, optional
|
||||||
|
Nearest centroids, size (n, nprobe)
|
||||||
|
|
||||||
|
params : SearchParameters
|
||||||
|
Search parameters of the current search (overrides the class-level params)
|
||||||
|
D : array_like, optional
|
||||||
|
Distance array to store the result.
|
||||||
|
I : array_like, optional
|
||||||
|
Labels array to store the results.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
D : array_like
|
||||||
|
Distances of the nearest neighbors, shape (n, k). When not enough results are found
|
||||||
|
the label is set to +Inf or -Inf.
|
||||||
|
I : array_like
|
||||||
|
Labels of the nearest neighbors, shape (n, k).
|
||||||
|
When not enough results are found, the label is set to -1
|
||||||
|
"""
|
||||||
|
n, d = x.shape
|
||||||
|
x = np.ascontiguousarray(x, dtype='float32')
|
||||||
|
assert d == self.d
|
||||||
|
assert k > 0
|
||||||
|
|
||||||
|
if D is None:
|
||||||
|
D = np.empty((n, k), dtype=np.float32)
|
||||||
|
else:
|
||||||
|
assert D.shape == (n, k)
|
||||||
|
|
||||||
|
if I is None:
|
||||||
|
I = np.empty((n, k), dtype=np.int64)
|
||||||
|
else:
|
||||||
|
assert I.shape == (n, k)
|
||||||
|
|
||||||
|
Iq = np.ascontiguousarray(Iq, dtype='int64')
|
||||||
|
assert params is None, "params not supported"
|
||||||
|
assert Iq.shape == (n, self.nprobe)
|
||||||
|
|
||||||
|
if Dq is not None:
|
||||||
|
Dq = np.ascontiguousarray(Dq, dtype='float32')
|
||||||
|
assert Dq.shape == Iq.shape
|
||||||
|
|
||||||
|
self.search_preassigned_c(
|
||||||
|
n, swig_ptr(x),
|
||||||
|
k,
|
||||||
|
swig_ptr(Iq), swig_ptr(Dq),
|
||||||
|
swig_ptr(D), swig_ptr(I),
|
||||||
|
False
|
||||||
|
)
|
||||||
|
return D, I
|
||||||
|
|
||||||
def replacement_sa_encode(self, x, codes=None):
|
def replacement_sa_encode(self, x, codes=None):
|
||||||
n, d = x.shape
|
n, d = x.shape
|
||||||
assert d == self.d
|
assert d == self.d
|
||||||
|
@ -605,6 +669,8 @@ def handle_Index(the_class):
|
||||||
ignore_missing=True)
|
ignore_missing=True)
|
||||||
replace_method(the_class, 'search_and_reconstruct',
|
replace_method(the_class, 'search_and_reconstruct',
|
||||||
replacement_search_and_reconstruct, ignore_missing=True)
|
replacement_search_and_reconstruct, ignore_missing=True)
|
||||||
|
replace_method(the_class, 'search_preassigned',
|
||||||
|
replacement_search_preassigned, ignore_missing=True)
|
||||||
replace_method(the_class, 'sa_encode', replacement_sa_encode)
|
replace_method(the_class, 'sa_encode', replacement_sa_encode)
|
||||||
replace_method(the_class, 'sa_decode', replacement_sa_decode)
|
replace_method(the_class, 'sa_decode', replacement_sa_decode)
|
||||||
replace_method(the_class, 'add_sa_codes', replacement_add_sa_codes,
|
replace_method(the_class, 'add_sa_codes', replacement_add_sa_codes,
|
||||||
|
@ -664,6 +730,31 @@ def handle_IndexBinary(the_class):
|
||||||
swig_ptr(labels))
|
swig_ptr(labels))
|
||||||
return distances, labels
|
return distances, labels
|
||||||
|
|
||||||
|
def replacement_search_preassigned(self, x, k, Iq, Dq):
|
||||||
|
n, d = x.shape
|
||||||
|
x = _check_dtype_uint8(x)
|
||||||
|
assert d * 8 == self.d
|
||||||
|
assert k > 0
|
||||||
|
|
||||||
|
D = np.empty((n, k), dtype=np.int32)
|
||||||
|
I = np.empty((n, k), dtype=np.int64)
|
||||||
|
|
||||||
|
Iq = np.ascontiguousarray(Iq, dtype='int64')
|
||||||
|
assert Iq.shape == (n, self.nprobe)
|
||||||
|
|
||||||
|
if Dq is not None:
|
||||||
|
Dq = np.ascontiguousarray(Dq, dtype='int32')
|
||||||
|
assert Dq.shape == Iq.shape
|
||||||
|
|
||||||
|
self.search_preassigned_c(
|
||||||
|
n, swig_ptr(x),
|
||||||
|
k,
|
||||||
|
swig_ptr(Iq), swig_ptr(Dq),
|
||||||
|
swig_ptr(D), swig_ptr(I),
|
||||||
|
False
|
||||||
|
)
|
||||||
|
return D, I
|
||||||
|
|
||||||
def replacement_range_search(self, x, thresh):
|
def replacement_range_search(self, x, thresh):
|
||||||
n, d = x.shape
|
n, d = x.shape
|
||||||
x = _check_dtype_uint8(x)
|
x = _check_dtype_uint8(x)
|
||||||
|
@ -693,6 +784,8 @@ def handle_IndexBinary(the_class):
|
||||||
replace_method(the_class, 'range_search', replacement_range_search)
|
replace_method(the_class, 'range_search', replacement_range_search)
|
||||||
replace_method(the_class, 'reconstruct', replacement_reconstruct)
|
replace_method(the_class, 'reconstruct', replacement_reconstruct)
|
||||||
replace_method(the_class, 'remove_ids', replacement_remove_ids)
|
replace_method(the_class, 'remove_ids', replacement_remove_ids)
|
||||||
|
replace_method(the_class, 'search_preassigned',
|
||||||
|
replacement_search_preassigned, ignore_missing=True)
|
||||||
|
|
||||||
|
|
||||||
def handle_VectorTransform(the_class):
|
def handle_VectorTransform(the_class):
|
||||||
|
|
|
@ -279,6 +279,23 @@ class ResultHeap:
|
||||||
self.heaps.reorder()
|
self.heaps.reorder()
|
||||||
|
|
||||||
|
|
||||||
|
def merge_knn_results(Dall, Iall, keep_max=False):
|
||||||
|
"""
|
||||||
|
Merge a set of sorted knn-results obtained from different shards in a dataset
|
||||||
|
Dall and Iall are of size (nshard, nq, k) each D[i, j] should be sorted
|
||||||
|
returns D, I of size (nq, k) as the merged result set
|
||||||
|
"""
|
||||||
|
assert Iall.shape == Dall.shape
|
||||||
|
nshard, n, k = Dall.shape
|
||||||
|
Dnew = np.empty((n, k), dtype=Dall.dtype)
|
||||||
|
Inew = np.empty((n, k), dtype=Iall.dtype)
|
||||||
|
func = merge_knn_results_CMax if keep_max else merge_knn_results_CMin
|
||||||
|
func(
|
||||||
|
n, k, nshard,
|
||||||
|
swig_ptr(Dall), swig_ptr(Iall),
|
||||||
|
swig_ptr(Dnew), swig_ptr(Inew)
|
||||||
|
)
|
||||||
|
return Dnew, Inew
|
||||||
|
|
||||||
######################################################
|
######################################################
|
||||||
# KNN function
|
# KNN function
|
||||||
|
|
|
@ -938,7 +938,6 @@ REV_SWIG_PTR(uint64_t, NPY_UINT64);
|
||||||
|
|
||||||
%template(float_minheap_array_t) faiss::HeapArray<faiss::CMin<float, int64_t> >;
|
%template(float_minheap_array_t) faiss::HeapArray<faiss::CMin<float, int64_t> >;
|
||||||
%template(int_minheap_array_t) faiss::HeapArray<faiss::CMin<int, int64_t> >;
|
%template(int_minheap_array_t) faiss::HeapArray<faiss::CMin<int, int64_t> >;
|
||||||
|
|
||||||
%template(float_maxheap_array_t) faiss::HeapArray<faiss::CMax<float, int64_t> >;
|
%template(float_maxheap_array_t) faiss::HeapArray<faiss::CMax<float, int64_t> >;
|
||||||
%template(int_maxheap_array_t) faiss::HeapArray<faiss::CMax<int, int64_t> >;
|
%template(int_maxheap_array_t) faiss::HeapArray<faiss::CMax<int, int64_t> >;
|
||||||
|
|
||||||
|
@ -951,46 +950,55 @@ REV_SWIG_PTR(uint64_t, NPY_UINT64);
|
||||||
%template(AlignedTableUint16) faiss::AlignedTable<uint16_t>;
|
%template(AlignedTableUint16) faiss::AlignedTable<uint16_t>;
|
||||||
%template(AlignedTableFloat32) faiss::AlignedTable<float>;
|
%template(AlignedTableFloat32) faiss::AlignedTable<float>;
|
||||||
|
|
||||||
|
|
||||||
|
// SWIG seems to have some trouble resolving function template types here, so
|
||||||
|
// declare explicitly
|
||||||
|
|
||||||
|
%define INSTANTIATE_uint16_partition_fuzzy(C, id_t)
|
||||||
|
|
||||||
%inline %{
|
%inline %{
|
||||||
|
|
||||||
// SWIG seems to have has some trouble resolving the template type here, so
|
uint16_t C ## _uint16_partition_fuzzy(
|
||||||
// declare explicitly
|
uint16_t *vals, id_t *ids, size_t n,
|
||||||
uint16_t CMax_uint16_partition_fuzzy(
|
|
||||||
uint16_t *vals, int64_t *ids, size_t n,
|
|
||||||
size_t q_min, size_t q_max, size_t * q_out)
|
size_t q_min, size_t q_max, size_t * q_out)
|
||||||
{
|
{
|
||||||
return faiss::partition_fuzzy<faiss::CMax<unsigned short, int64_t> >(
|
return faiss::partition_fuzzy<faiss::C<unsigned short, id_t> >(
|
||||||
vals, ids, n, q_min, q_max, q_out);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint16_t CMin_uint16_partition_fuzzy(
|
|
||||||
uint16_t *vals, int64_t *ids, size_t n,
|
|
||||||
size_t q_min, size_t q_max, size_t * q_out)
|
|
||||||
{
|
|
||||||
return faiss::partition_fuzzy<faiss::CMin<unsigned short, int64_t> >(
|
|
||||||
vals, ids, n, q_min, q_max, q_out);
|
|
||||||
}
|
|
||||||
|
|
||||||
// and overload with the int32 version
|
|
||||||
|
|
||||||
uint16_t CMax_uint16_partition_fuzzy(
|
|
||||||
uint16_t *vals, int *ids, size_t n,
|
|
||||||
size_t q_min, size_t q_max, size_t * q_out)
|
|
||||||
{
|
|
||||||
return faiss::partition_fuzzy<faiss::CMax<unsigned short, int> >(
|
|
||||||
vals, ids, n, q_min, q_max, q_out);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint16_t CMin_uint16_partition_fuzzy(
|
|
||||||
uint16_t *vals, int *ids, size_t n,
|
|
||||||
size_t q_min, size_t q_max, size_t * q_out)
|
|
||||||
{
|
|
||||||
return faiss::partition_fuzzy<faiss::CMin<unsigned short, int> >(
|
|
||||||
vals, ids, n, q_min, q_max, q_out);
|
vals, ids, n, q_min, q_max, q_out);
|
||||||
}
|
}
|
||||||
|
|
||||||
%}
|
%}
|
||||||
|
|
||||||
|
%enddef
|
||||||
|
|
||||||
|
INSTANTIATE_uint16_partition_fuzzy(CMin, int64_t)
|
||||||
|
INSTANTIATE_uint16_partition_fuzzy(CMax, int64_t)
|
||||||
|
INSTANTIATE_uint16_partition_fuzzy(CMin, int)
|
||||||
|
INSTANTIATE_uint16_partition_fuzzy(CMax, int)
|
||||||
|
|
||||||
|
// Same for merge_knn_results
|
||||||
|
|
||||||
|
// same define as explicit instanciation in Heap.cpp
|
||||||
|
%define INSTANTIATE_merge_knn_results(C, distance_t)
|
||||||
|
|
||||||
|
%inline %{
|
||||||
|
void merge_knn_results_ ## C(
|
||||||
|
size_t n, size_t k, int nshard,
|
||||||
|
const distance_t *all_distances, const faiss::idx_t *all_labels,
|
||||||
|
distance_t *distances, faiss::idx_t *labels)
|
||||||
|
{
|
||||||
|
faiss::merge_knn_results<faiss::idx_t, faiss::C<distance_t, int>>(
|
||||||
|
n, k, nshard, all_distances, all_labels, distances, labels);
|
||||||
|
}
|
||||||
|
%}
|
||||||
|
|
||||||
|
%enddef
|
||||||
|
|
||||||
|
INSTANTIATE_merge_knn_results(CMin, float);
|
||||||
|
INSTANTIATE_merge_knn_results(CMax, float);
|
||||||
|
INSTANTIATE_merge_knn_results(CMin, int32_t);
|
||||||
|
INSTANTIATE_merge_knn_results(CMax, int32_t);
|
||||||
|
|
||||||
|
|
||||||
/*******************************************************************
|
/*******************************************************************
|
||||||
* Expose a few basic functions
|
* Expose a few basic functions
|
||||||
*******************************************************************/
|
*******************************************************************/
|
||||||
|
|
|
@ -139,4 +139,111 @@ template struct HeapArray<CMax<float, int64_t>>;
|
||||||
template struct HeapArray<CMin<int, int64_t>>;
|
template struct HeapArray<CMin<int, int64_t>>;
|
||||||
template struct HeapArray<CMax<int, int64_t>>;
|
template struct HeapArray<CMax<int, int64_t>>;
|
||||||
|
|
||||||
|
/**********************************************************
|
||||||
|
* merge knn search results
|
||||||
|
**********************************************************/
|
||||||
|
|
||||||
|
/** Merge result tables from several shards. The per-shard results are assumed
|
||||||
|
* to be sorted. Note that the C comparator is reversed w.r.t. the usual top-k
|
||||||
|
* element heap because we want the best (ie. lowest for L2) result to be on
|
||||||
|
* top, not the worst.
|
||||||
|
*
|
||||||
|
* @param all_distances size (nshard, n, k)
|
||||||
|
* @param all_labels size (nshard, n, k)
|
||||||
|
* @param distances output distances, size (n, k)
|
||||||
|
* @param labels output labels, size (n, k)
|
||||||
|
*/
|
||||||
|
template <class idx_t, class C>
|
||||||
|
void merge_knn_results(
|
||||||
|
size_t n,
|
||||||
|
size_t k,
|
||||||
|
typename C::TI nshard,
|
||||||
|
const typename C::T* all_distances,
|
||||||
|
const idx_t* all_labels,
|
||||||
|
typename C::T* distances,
|
||||||
|
idx_t* labels) {
|
||||||
|
using distance_t = typename C::T;
|
||||||
|
if (k == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
long stride = n * k;
|
||||||
|
#pragma omp parallel if (n * nshard * k > 100000)
|
||||||
|
{
|
||||||
|
std::vector<int> buf(2 * nshard);
|
||||||
|
// index in each shard's result list
|
||||||
|
int* pointer = buf.data();
|
||||||
|
// (shard_ids, heap_vals): heap that indexes
|
||||||
|
// shard -> current distance for this shard
|
||||||
|
int* shard_ids = pointer + nshard;
|
||||||
|
std::vector<distance_t> buf2(nshard);
|
||||||
|
distance_t* heap_vals = buf2.data();
|
||||||
|
#pragma omp for
|
||||||
|
for (long i = 0; i < n; i++) {
|
||||||
|
// the heap maps values to the shard where they are
|
||||||
|
// produced.
|
||||||
|
const distance_t* D_in = all_distances + i * k;
|
||||||
|
const idx_t* I_in = all_labels + i * k;
|
||||||
|
int heap_size = 0;
|
||||||
|
|
||||||
|
// push the first element of each shard (if not -1)
|
||||||
|
for (long s = 0; s < nshard; s++) {
|
||||||
|
pointer[s] = 0;
|
||||||
|
if (I_in[stride * s] >= 0) {
|
||||||
|
heap_push<C>(
|
||||||
|
++heap_size,
|
||||||
|
heap_vals,
|
||||||
|
shard_ids,
|
||||||
|
D_in[stride * s],
|
||||||
|
s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
distance_t* D = distances + i * k;
|
||||||
|
idx_t* I = labels + i * k;
|
||||||
|
|
||||||
|
int j;
|
||||||
|
for (j = 0; j < k && heap_size > 0; j++) {
|
||||||
|
// pop element from best shard
|
||||||
|
int s = shard_ids[0]; // top of heap
|
||||||
|
int& p = pointer[s];
|
||||||
|
D[j] = heap_vals[0];
|
||||||
|
I[j] = I_in[stride * s + p];
|
||||||
|
|
||||||
|
// pop from shard, advance pointer for this shard
|
||||||
|
heap_pop<C>(heap_size--, heap_vals, shard_ids);
|
||||||
|
p++;
|
||||||
|
if (p < k && I_in[stride * s + p] >= 0) {
|
||||||
|
heap_push<C>(
|
||||||
|
++heap_size,
|
||||||
|
heap_vals,
|
||||||
|
shard_ids,
|
||||||
|
D_in[stride * s + p],
|
||||||
|
s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (; j < k; j++) {
|
||||||
|
I[j] = -1;
|
||||||
|
D[j] = C::Crev::neutral();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// explicit instanciations
|
||||||
|
#define INSTANTIATE(C, distance_t) \
|
||||||
|
template void merge_knn_results<int64_t, C<distance_t, int>>( \
|
||||||
|
size_t, \
|
||||||
|
size_t, \
|
||||||
|
int, \
|
||||||
|
const distance_t*, \
|
||||||
|
const int64_t*, \
|
||||||
|
distance_t*, \
|
||||||
|
int64_t*);
|
||||||
|
|
||||||
|
INSTANTIATE(CMin, float);
|
||||||
|
INSTANTIATE(CMax, float);
|
||||||
|
INSTANTIATE(CMin, int32_t);
|
||||||
|
INSTANTIATE(CMax, int32_t);
|
||||||
|
|
||||||
|
|
||||||
} // namespace faiss
|
} // namespace faiss
|
||||||
|
|
|
@ -444,7 +444,7 @@ typedef HeapArray<CMin<int, int64_t>> int_minheap_array_t;
|
||||||
typedef HeapArray<CMax<float, int64_t>> float_maxheap_array_t;
|
typedef HeapArray<CMax<float, int64_t>> float_maxheap_array_t;
|
||||||
typedef HeapArray<CMax<int, int64_t>> int_maxheap_array_t;
|
typedef HeapArray<CMax<int, int64_t>> int_maxheap_array_t;
|
||||||
|
|
||||||
// The heap templates are instanciated explicitly in Heap.cpp
|
// The heap templates are instantiated explicitly in Heap.cpp
|
||||||
|
|
||||||
/*********************************************************************
|
/*********************************************************************
|
||||||
* Indirect heaps: instead of having
|
* Indirect heaps: instead of having
|
||||||
|
@ -505,6 +505,27 @@ inline void indirect_heap_push(
|
||||||
bh_ids[i] = id;
|
bh_ids[i] = id;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Merge result tables from several shards. The per-shard results are assumed
|
||||||
|
* to be sorted. Note that the C comparator is reversed w.r.t. the usual top-k
|
||||||
|
* element heap because we want the best (ie. lowest for L2) result to be on
|
||||||
|
* top, not the worst. Also, it needs to hold an index of a shard id (ie.
|
||||||
|
* usually int32 is more than enough).
|
||||||
|
*
|
||||||
|
* @param all_distances size (nshard, n, k)
|
||||||
|
* @param all_labels size (nshard, n, k)
|
||||||
|
* @param distances output distances, size (n, k)
|
||||||
|
* @param labels output labels, size (n, k)
|
||||||
|
*/
|
||||||
|
template <class idx_t, class C>
|
||||||
|
void merge_knn_results(
|
||||||
|
size_t n,
|
||||||
|
size_t k,
|
||||||
|
typename C::TI nshard,
|
||||||
|
const typename C::T* all_distances,
|
||||||
|
const idx_t* all_labels,
|
||||||
|
typename C::T* distances,
|
||||||
|
idx_t* labels);
|
||||||
|
|
||||||
} // namespace faiss
|
} // namespace faiss
|
||||||
|
|
||||||
#endif /* FAISS_Heap_h */
|
#endif /* FAISS_Heap_h */
|
||||||
|
|
|
@ -645,3 +645,45 @@ class TestBucketSort(unittest.TestCase):
|
||||||
|
|
||||||
def test_bucket_sort_inplace_parallel_fewbucket(self):
|
def test_bucket_sort_inplace_parallel_fewbucket(self):
|
||||||
self.do_test_bucket_sort_inplace(4, nbucket=5)
|
self.do_test_bucket_sort_inplace(4, nbucket=5)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMergeKNNResults(unittest.TestCase):
|
||||||
|
|
||||||
|
def do_test(self, ismax, dtype):
|
||||||
|
rs = np.random.RandomState()
|
||||||
|
n, k, nshard = 10, 5, 3
|
||||||
|
all_ids = rs.randint(100000, size=(nshard, n, k)).astype('int64')
|
||||||
|
all_dis = rs.rand(nshard, n, k)
|
||||||
|
if dtype == 'int32':
|
||||||
|
all_dis = (all_dis * 1000000).astype("int32")
|
||||||
|
else:
|
||||||
|
all_dis = all_dis.astype(dtype)
|
||||||
|
for i in range(nshard):
|
||||||
|
for j in range(n):
|
||||||
|
all_dis[i, j].sort()
|
||||||
|
if ismax:
|
||||||
|
all_dis[i, j] = all_dis[i, j][::-1]
|
||||||
|
Dref = np.zeros((n, k), dtype=dtype)
|
||||||
|
Iref = np.zeros((n, k), dtype='int64')
|
||||||
|
|
||||||
|
for i in range(n):
|
||||||
|
dis = all_dis[:, i, :].ravel()
|
||||||
|
ids = all_ids[:, i, :].ravel()
|
||||||
|
o = dis.argsort()
|
||||||
|
if ismax:
|
||||||
|
o = o[::-1]
|
||||||
|
Dref[i] = dis[o[:k]]
|
||||||
|
Iref[i] = ids[o[:k]]
|
||||||
|
|
||||||
|
Dnew, Inew = faiss.merge_knn_results(all_dis, all_ids, keep_max=ismax)
|
||||||
|
np.testing.assert_array_equal(Dnew, Dref)
|
||||||
|
np.testing.assert_array_equal(Inew, Iref)
|
||||||
|
|
||||||
|
def test_min_float(self):
|
||||||
|
self.do_test(ismax=False, dtype='float32')
|
||||||
|
|
||||||
|
def test_max_int(self):
|
||||||
|
self.do_test(ismax=True, dtype='int32')
|
||||||
|
|
||||||
|
def test_max_float(self):
|
||||||
|
self.do_test(ismax=True, dtype='float32')
|
||||||
|
|
|
@ -687,6 +687,7 @@ class TestSplitMerge(unittest.TestCase):
|
||||||
sub_indexes = [faiss.clone_index(index) for i in range(nsplit)]
|
sub_indexes = [faiss.clone_index(index) for i in range(nsplit)]
|
||||||
index.add(xb)
|
index.add(xb)
|
||||||
Dref, Iref = index.search(xq, 10)
|
Dref, Iref = index.search(xq, 10)
|
||||||
|
nlist = index.nlist
|
||||||
for i in range(nsplit):
|
for i in range(nsplit):
|
||||||
if subset_type in (1, 3):
|
if subset_type in (1, 3):
|
||||||
index.copy_subset_to(sub_indexes[i], subset_type, nsplit, i)
|
index.copy_subset_to(sub_indexes[i], subset_type, nsplit, i)
|
||||||
|
@ -694,6 +695,10 @@ class TestSplitMerge(unittest.TestCase):
|
||||||
j0 = index.ntotal * i // nsplit
|
j0 = index.ntotal * i // nsplit
|
||||||
j1 = index.ntotal * (i + 1) // nsplit
|
j1 = index.ntotal * (i + 1) // nsplit
|
||||||
index.copy_subset_to(sub_indexes[i], subset_type, j0, j1)
|
index.copy_subset_to(sub_indexes[i], subset_type, j0, j1)
|
||||||
|
elif subset_type == 4:
|
||||||
|
index.copy_subset_to(
|
||||||
|
sub_indexes[i], subset_type,
|
||||||
|
i * nlist // nsplit, (i + 1) * nlist // nsplit)
|
||||||
|
|
||||||
index_shards = faiss.IndexShards(False, False)
|
index_shards = faiss.IndexShards(False, False)
|
||||||
for i in range(nsplit):
|
for i in range(nsplit):
|
||||||
|
@ -713,3 +718,6 @@ class TestSplitMerge(unittest.TestCase):
|
||||||
|
|
||||||
def test_Flat_subset_type_3(self):
|
def test_Flat_subset_type_3(self):
|
||||||
self.do_test("IVF30,Flat", subset_type=3)
|
self.do_test("IVF30,Flat", subset_type=3)
|
||||||
|
|
||||||
|
def test_Flat_subset_type_4(self):
|
||||||
|
self.do_test("IVF30,Flat", subset_type=4)
|
||||||
|
|
|
@ -47,10 +47,8 @@ def search_single_scan(index, xq, k, bs=128):
|
||||||
sub_assign[skip_rows, skip_cols] = -1
|
sub_assign[skip_rows, skip_cols] = -1
|
||||||
|
|
||||||
index.search_preassigned(
|
index.search_preassigned(
|
||||||
nq, faiss.swig_ptr(xq), k,
|
xq, k, sub_assign, coarse_dis,
|
||||||
faiss.swig_ptr(sub_assign), faiss.swig_ptr(coarse_dis),
|
D=rh.D, I=rh.I
|
||||||
faiss.swig_ptr(rh.D), faiss.swig_ptr(rh.I),
|
|
||||||
False, None
|
|
||||||
)
|
)
|
||||||
|
|
||||||
rh.finalize()
|
rh.finalize()
|
||||||
|
|
Loading…
Reference in New Issue