add dispatcher for VectorDistance and ResultHandlers
Summary: Add dispatcher function to avoid repeating dispatching code for distance computation and result handlers.
Reviewed By: asadoughi
Differential Revision: D59318865
fbshipit-source-id: 59046ede02f71a0da3b8061289fc70306bf875cb
pull/3649/head
parent
444614b076
commit
261edde514
|
@ -13,8 +13,10 @@
|
|||
|
||||
#include <faiss/impl/AuxIndexStructures.h>
|
||||
#include <faiss/impl/FaissException.h>
|
||||
#include <faiss/impl/IDSelector.h>
|
||||
#include <faiss/utils/Heap.h>
|
||||
#include <faiss/utils/partitioning.h>
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
|
||||
namespace faiss {
|
||||
|
@ -26,16 +28,21 @@ namespace faiss {
|
|||
* - by instantiating a SingleResultHandler that tracks results for a single
|
||||
* query
|
||||
* - with begin_multiple/add_results/end_multiple calls where a whole block of
|
||||
* resutls is submitted
|
||||
* results is submitted
|
||||
* All classes are templated on C, which defines whether the min or the max of
|
||||
* results is to be kept.
|
||||
* results is to be kept, and on sel, so that the codepaths for with / without
|
||||
* selector can be separated at compile time.
|
||||
*****************************************************************/
|
||||
|
||||
template <class C>
|
||||
template <class C, bool use_sel = false>
|
||||
struct BlockResultHandler {
|
||||
size_t nq; // number of queries for which we search
|
||||
const IDSelector* sel;
|
||||
|
||||
explicit BlockResultHandler(size_t nq) : nq(nq) {}
|
||||
explicit BlockResultHandler(size_t nq, const IDSelector* sel = nullptr)
|
||||
: nq(nq), sel(sel) {
|
||||
assert(!use_sel || sel);
|
||||
}
|
||||
|
||||
// currently handled query range
|
||||
size_t i0 = 0, i1 = 0;
|
||||
|
@ -53,13 +60,17 @@ struct BlockResultHandler {
|
|||
virtual void end_multiple() {}
|
||||
|
||||
virtual ~BlockResultHandler() {}
|
||||
|
||||
bool is_in_selection(idx_t i) const {
|
||||
return !use_sel || sel->is_member(i);
|
||||
}
|
||||
};
|
||||
|
||||
// handler for a single query
|
||||
template <class C>
|
||||
struct ResultHandler {
|
||||
// if not better than threshold, then not necessary to call add_result
|
||||
typename C::T threshold = 0;
|
||||
typename C::T threshold = C::neutral();
|
||||
|
||||
// return whether threshold was updated
|
||||
virtual bool add_result(typename C::T dis, typename C::TI idx) = 0;
|
||||
|
@ -73,20 +84,26 @@ struct ResultHandler {
|
|||
* some temporary data in memory.
|
||||
*****************************************************************/
|
||||
|
||||
template <class C>
|
||||
struct Top1BlockResultHandler : BlockResultHandler<C> {
|
||||
template <class C, bool use_sel = false>
|
||||
struct Top1BlockResultHandler : BlockResultHandler<C, use_sel> {
|
||||
using T = typename C::T;
|
||||
using TI = typename C::TI;
|
||||
using BlockResultHandler<C>::i0;
|
||||
using BlockResultHandler<C>::i1;
|
||||
using BlockResultHandler<C, use_sel>::i0;
|
||||
using BlockResultHandler<C, use_sel>::i1;
|
||||
|
||||
// contains exactly nq elements
|
||||
T* dis_tab;
|
||||
// contains exactly nq elements
|
||||
TI* ids_tab;
|
||||
|
||||
Top1BlockResultHandler(size_t nq, T* dis_tab, TI* ids_tab)
|
||||
: BlockResultHandler<C>(nq), dis_tab(dis_tab), ids_tab(ids_tab) {}
|
||||
Top1BlockResultHandler(
|
||||
size_t nq,
|
||||
T* dis_tab,
|
||||
TI* ids_tab,
|
||||
const IDSelector* sel = nullptr)
|
||||
: BlockResultHandler<C, use_sel>(nq, sel),
|
||||
dis_tab(dis_tab),
|
||||
ids_tab(ids_tab) {}
|
||||
|
||||
struct SingleResultHandler : ResultHandler<C> {
|
||||
Top1BlockResultHandler& hr;
|
||||
|
@ -165,12 +182,12 @@ struct Top1BlockResultHandler : BlockResultHandler<C> {
|
|||
* Heap based result handler
|
||||
*****************************************************************/
|
||||
|
||||
template <class C>
|
||||
struct HeapBlockResultHandler : BlockResultHandler<C> {
|
||||
template <class C, bool use_sel = false>
|
||||
struct HeapBlockResultHandler : BlockResultHandler<C, use_sel> {
|
||||
using T = typename C::T;
|
||||
using TI = typename C::TI;
|
||||
using BlockResultHandler<C>::i0;
|
||||
using BlockResultHandler<C>::i1;
|
||||
using BlockResultHandler<C, use_sel>::i0;
|
||||
using BlockResultHandler<C, use_sel>::i1;
|
||||
|
||||
T* heap_dis_tab;
|
||||
TI* heap_ids_tab;
|
||||
|
@ -181,8 +198,9 @@ struct HeapBlockResultHandler : BlockResultHandler<C> {
|
|||
size_t nq,
|
||||
T* heap_dis_tab,
|
||||
TI* heap_ids_tab,
|
||||
size_t k)
|
||||
: BlockResultHandler<C>(nq),
|
||||
size_t k,
|
||||
const IDSelector* sel = nullptr)
|
||||
: BlockResultHandler<C, use_sel>(nq, sel),
|
||||
heap_dis_tab(heap_dis_tab),
|
||||
heap_ids_tab(heap_ids_tab),
|
||||
k(k) {}
|
||||
|
@ -347,12 +365,12 @@ struct ReservoirTopN : ResultHandler<C> {
|
|||
}
|
||||
};
|
||||
|
||||
template <class C>
|
||||
struct ReservoirBlockResultHandler : BlockResultHandler<C> {
|
||||
template <class C, bool use_sel = false>
|
||||
struct ReservoirBlockResultHandler : BlockResultHandler<C, use_sel> {
|
||||
using T = typename C::T;
|
||||
using TI = typename C::TI;
|
||||
using BlockResultHandler<C>::i0;
|
||||
using BlockResultHandler<C>::i1;
|
||||
using BlockResultHandler<C, use_sel>::i0;
|
||||
using BlockResultHandler<C, use_sel>::i1;
|
||||
|
||||
T* heap_dis_tab;
|
||||
TI* heap_ids_tab;
|
||||
|
@ -364,8 +382,9 @@ struct ReservoirBlockResultHandler : BlockResultHandler<C> {
|
|||
size_t nq,
|
||||
T* heap_dis_tab,
|
||||
TI* heap_ids_tab,
|
||||
size_t k)
|
||||
: BlockResultHandler<C>(nq),
|
||||
size_t k,
|
||||
const IDSelector* sel = nullptr)
|
||||
: BlockResultHandler<C, use_sel>(nq, sel),
|
||||
heap_dis_tab(heap_dis_tab),
|
||||
heap_ids_tab(heap_ids_tab),
|
||||
k(k) {
|
||||
|
@ -460,18 +479,23 @@ struct ReservoirBlockResultHandler : BlockResultHandler<C> {
|
|||
* Result handler for range searches
|
||||
*****************************************************************/
|
||||
|
||||
template <class C>
|
||||
struct RangeSearchBlockResultHandler : BlockResultHandler<C> {
|
||||
template <class C, bool use_sel = false>
|
||||
struct RangeSearchBlockResultHandler : BlockResultHandler<C, use_sel> {
|
||||
using T = typename C::T;
|
||||
using TI = typename C::TI;
|
||||
using BlockResultHandler<C>::i0;
|
||||
using BlockResultHandler<C>::i1;
|
||||
using BlockResultHandler<C, use_sel>::i0;
|
||||
using BlockResultHandler<C, use_sel>::i1;
|
||||
|
||||
RangeSearchResult* res;
|
||||
T radius;
|
||||
|
||||
RangeSearchBlockResultHandler(RangeSearchResult* res, float radius)
|
||||
: BlockResultHandler<C>(res->nq), res(res), radius(radius) {}
|
||||
RangeSearchBlockResultHandler(
|
||||
RangeSearchResult* res,
|
||||
float radius,
|
||||
const IDSelector* sel = nullptr)
|
||||
: BlockResultHandler<C, use_sel>(res->nq, sel),
|
||||
res(res),
|
||||
radius(radius) {}
|
||||
|
||||
/******************************************************
|
||||
* API for 1 result at a time (each SingleResultHandler is
|
||||
|
@ -582,4 +606,81 @@ struct RangeSearchBlockResultHandler : BlockResultHandler<C> {
|
|||
}
|
||||
};
|
||||
|
||||
/*****************************************************************
|
||||
* Dispatcher function to choose the right knn result handler depending on k
|
||||
*****************************************************************/
|
||||
|
||||
// declared in distances.cpp
|
||||
FAISS_API extern int distance_compute_min_k_reservoir;
|
||||
|
||||
template <class Consumer, class... Types>
|
||||
typename Consumer::T dispatch_knn_ResultHandler(
|
||||
size_t nx,
|
||||
float* vals,
|
||||
int64_t* ids,
|
||||
size_t k,
|
||||
MetricType metric,
|
||||
const IDSelector* sel,
|
||||
Consumer& consumer,
|
||||
Types... args) {
|
||||
#define DISPATCH_C_SEL(C, use_sel) \
|
||||
if (k == 1) { \
|
||||
Top1BlockResultHandler<C, use_sel> res(nx, vals, ids, sel); \
|
||||
return consumer.template f<>(res, args...); \
|
||||
} else if (k < distance_compute_min_k_reservoir) { \
|
||||
HeapBlockResultHandler<C, use_sel> res(nx, vals, ids, k, sel); \
|
||||
return consumer.template f<>(res, args...); \
|
||||
} else { \
|
||||
ReservoirBlockResultHandler<C, use_sel> res(nx, vals, ids, k, sel); \
|
||||
return consumer.template f<>(res, args...); \
|
||||
}
|
||||
|
||||
if (is_similarity_metric(metric)) {
|
||||
using C = CMin<float, int64_t>;
|
||||
if (sel) {
|
||||
DISPATCH_C_SEL(C, true);
|
||||
} else {
|
||||
DISPATCH_C_SEL(C, false);
|
||||
}
|
||||
} else {
|
||||
using C = CMax<float, int64_t>;
|
||||
if (sel) {
|
||||
DISPATCH_C_SEL(C, true);
|
||||
} else {
|
||||
DISPATCH_C_SEL(C, false);
|
||||
}
|
||||
}
|
||||
#undef DISPATCH_C_SEL
|
||||
}
|
||||
|
||||
template <class Consumer, class... Types>
|
||||
typename Consumer::T dispatch_range_ResultHandler(
|
||||
RangeSearchResult* res,
|
||||
float radius,
|
||||
MetricType metric,
|
||||
const IDSelector* sel,
|
||||
Consumer& consumer,
|
||||
Types... args) {
|
||||
#define DISPATCH_C_SEL(C, use_sel) \
|
||||
RangeSearchBlockResultHandler<C, use_sel> resb(res, radius, sel); \
|
||||
return consumer.template f<>(resb, args...);
|
||||
|
||||
if (is_similarity_metric(metric)) {
|
||||
using C = CMin<float, int64_t>;
|
||||
if (sel) {
|
||||
DISPATCH_C_SEL(C, true);
|
||||
} else {
|
||||
DISPATCH_C_SEL(C, false);
|
||||
}
|
||||
} else {
|
||||
using C = CMax<float, int64_t>;
|
||||
if (sel) {
|
||||
DISPATCH_C_SEL(C, true);
|
||||
} else {
|
||||
DISPATCH_C_SEL(C, false);
|
||||
}
|
||||
}
|
||||
#undef DISPATCH_C_SEL
|
||||
}
|
||||
|
||||
} // namespace faiss
|
||||
|
|
|
@ -130,21 +130,18 @@ void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) {
|
|||
namespace {
|
||||
|
||||
/* Find the nearest neighbors for nx queries in a set of ny vectors */
|
||||
template <class BlockResultHandler, bool use_sel = false>
|
||||
template <class BlockResultHandler>
|
||||
void exhaustive_inner_product_seq(
|
||||
const float* x,
|
||||
const float* y,
|
||||
size_t d,
|
||||
size_t nx,
|
||||
size_t ny,
|
||||
BlockResultHandler& res,
|
||||
const IDSelector* sel = nullptr) {
|
||||
BlockResultHandler& res) {
|
||||
using SingleResultHandler =
|
||||
typename BlockResultHandler::SingleResultHandler;
|
||||
[[maybe_unused]] int nt = std::min(int(nx), omp_get_max_threads());
|
||||
|
||||
FAISS_ASSERT(use_sel == (sel != nullptr));
|
||||
|
||||
#pragma omp parallel num_threads(nt)
|
||||
{
|
||||
SingleResultHandler resi(res);
|
||||
|
@ -156,7 +153,7 @@ void exhaustive_inner_product_seq(
|
|||
resi.begin(i);
|
||||
|
||||
for (size_t j = 0; j < ny; j++, y_j += d) {
|
||||
if (use_sel && !sel->is_member(j)) {
|
||||
if (!res.is_in_selection(j)) {
|
||||
continue;
|
||||
}
|
||||
float ip = fvec_inner_product(x_i, y_j, d);
|
||||
|
@ -167,21 +164,18 @@ void exhaustive_inner_product_seq(
|
|||
}
|
||||
}
|
||||
|
||||
template <class BlockResultHandler, bool use_sel = false>
|
||||
template <class BlockResultHandler>
|
||||
void exhaustive_L2sqr_seq(
|
||||
const float* x,
|
||||
const float* y,
|
||||
size_t d,
|
||||
size_t nx,
|
||||
size_t ny,
|
||||
BlockResultHandler& res,
|
||||
const IDSelector* sel = nullptr) {
|
||||
BlockResultHandler& res) {
|
||||
using SingleResultHandler =
|
||||
typename BlockResultHandler::SingleResultHandler;
|
||||
[[maybe_unused]] int nt = std::min(int(nx), omp_get_max_threads());
|
||||
|
||||
FAISS_ASSERT(use_sel == (sel != nullptr));
|
||||
|
||||
#pragma omp parallel num_threads(nt)
|
||||
{
|
||||
SingleResultHandler resi(res);
|
||||
|
@ -191,7 +185,7 @@ void exhaustive_L2sqr_seq(
|
|||
const float* y_j = y;
|
||||
resi.begin(i);
|
||||
for (size_t j = 0; j < ny; j++, y_j += d) {
|
||||
if (use_sel && !sel->is_member(j)) {
|
||||
if (!res.is_in_selection(j)) {
|
||||
continue;
|
||||
}
|
||||
float disij = fvec_L2sqr(x_i, y_j, d);
|
||||
|
@ -326,6 +320,9 @@ void exhaustive_L2sqr_blas_default_impl(
|
|||
float ip = *ip_line;
|
||||
float dis = x_norms[i] + y_norms[j] - 2 * ip;
|
||||
|
||||
if (!res.is_in_selection(j)) {
|
||||
dis = HUGE_VALF;
|
||||
}
|
||||
// negative values can occur for identical vectors
|
||||
// due to roundoff errors
|
||||
if (dis < 0)
|
||||
|
@ -601,44 +598,40 @@ void exhaustive_L2sqr_blas<Top1BlockResultHandler<CMax<float, int64_t>>>(
|
|||
#endif
|
||||
}
|
||||
|
||||
template <class BlockResultHandler>
|
||||
void knn_L2sqr_select(
|
||||
const float* x,
|
||||
const float* y,
|
||||
size_t d,
|
||||
size_t nx,
|
||||
size_t ny,
|
||||
BlockResultHandler& res,
|
||||
const float* y_norm2,
|
||||
const IDSelector* sel) {
|
||||
if (sel) {
|
||||
exhaustive_L2sqr_seq<BlockResultHandler, true>(
|
||||
x, y, d, nx, ny, res, sel);
|
||||
} else if (nx < distance_compute_blas_threshold) {
|
||||
exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
|
||||
} else {
|
||||
exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
|
||||
struct Run_search_inner_product {
|
||||
using T = void;
|
||||
template <class BlockResultHandler>
|
||||
void f(BlockResultHandler& res,
|
||||
const float* x,
|
||||
const float* y,
|
||||
size_t d,
|
||||
size_t nx,
|
||||
size_t ny) {
|
||||
if (res.sel || nx < distance_compute_blas_threshold) {
|
||||
exhaustive_inner_product_seq(x, y, d, nx, ny, res);
|
||||
} else {
|
||||
exhaustive_inner_product_blas(x, y, d, nx, ny, res);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <class BlockResultHandler>
|
||||
void knn_inner_product_select(
|
||||
const float* x,
|
||||
const float* y,
|
||||
size_t d,
|
||||
size_t nx,
|
||||
size_t ny,
|
||||
BlockResultHandler& res,
|
||||
const IDSelector* sel) {
|
||||
if (sel) {
|
||||
exhaustive_inner_product_seq<BlockResultHandler, true>(
|
||||
x, y, d, nx, ny, res, sel);
|
||||
} else if (nx < distance_compute_blas_threshold) {
|
||||
exhaustive_inner_product_seq(x, y, d, nx, ny, res);
|
||||
} else {
|
||||
exhaustive_inner_product_blas(x, y, d, nx, ny, res);
|
||||
struct Run_search_L2sqr {
|
||||
using T = void;
|
||||
template <class BlockResultHandler>
|
||||
void f(BlockResultHandler& res,
|
||||
const float* x,
|
||||
const float* y,
|
||||
size_t d,
|
||||
size_t nx,
|
||||
size_t ny,
|
||||
const float* y_norm2) {
|
||||
if (res.sel || nx < distance_compute_blas_threshold) {
|
||||
exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
|
||||
} else {
|
||||
exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
|
@ -675,16 +668,9 @@ void knn_inner_product(
|
|||
return;
|
||||
}
|
||||
|
||||
if (k == 1) {
|
||||
Top1BlockResultHandler<CMin<float, int64_t>> res(nx, vals, ids);
|
||||
knn_inner_product_select(x, y, d, nx, ny, res, sel);
|
||||
} else if (k < distance_compute_min_k_reservoir) {
|
||||
HeapBlockResultHandler<CMin<float, int64_t>> res(nx, vals, ids, k);
|
||||
knn_inner_product_select(x, y, d, nx, ny, res, sel);
|
||||
} else {
|
||||
ReservoirBlockResultHandler<CMin<float, int64_t>> res(nx, vals, ids, k);
|
||||
knn_inner_product_select(x, y, d, nx, ny, res, sel);
|
||||
}
|
||||
Run_search_inner_product r;
|
||||
dispatch_knn_ResultHandler(
|
||||
nx, vals, ids, k, METRIC_INNER_PRODUCT, sel, r, x, y, d, nx, ny);
|
||||
|
||||
if (imin != 0) {
|
||||
for (size_t i = 0; i < nx * k; i++) {
|
||||
|
@ -730,16 +716,11 @@ void knn_L2sqr(
|
|||
knn_L2sqr_by_idx(x, y, sela->ids, d, nx, ny, sela->n, k, vals, ids, 0);
|
||||
return;
|
||||
}
|
||||
if (k == 1) {
|
||||
Top1BlockResultHandler<CMax<float, int64_t>> res(nx, vals, ids);
|
||||
knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
|
||||
} else if (k < distance_compute_min_k_reservoir) {
|
||||
HeapBlockResultHandler<CMax<float, int64_t>> res(nx, vals, ids, k);
|
||||
knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
|
||||
} else {
|
||||
ReservoirBlockResultHandler<CMax<float, int64_t>> res(nx, vals, ids, k);
|
||||
knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
|
||||
}
|
||||
|
||||
Run_search_L2sqr r;
|
||||
dispatch_knn_ResultHandler(
|
||||
nx, vals, ids, k, METRIC_L2, sel, r, x, y, d, nx, ny, y_norm2);
|
||||
|
||||
if (imin != 0) {
|
||||
for (size_t i = 0; i < nx * k; i++) {
|
||||
if (ids[i] >= 0) {
|
||||
|
@ -766,6 +747,7 @@ void knn_L2sqr(
|
|||
* Range search
|
||||
***************************************************************************/
|
||||
|
||||
// TODO accept a y_norm2 as well
|
||||
void range_search_L2sqr(
|
||||
const float* x,
|
||||
const float* y,
|
||||
|
@ -775,15 +757,9 @@ void range_search_L2sqr(
|
|||
float radius,
|
||||
RangeSearchResult* res,
|
||||
const IDSelector* sel) {
|
||||
using RH = RangeSearchBlockResultHandler<CMax<float, int64_t>>;
|
||||
RH resh(res, radius);
|
||||
if (sel) {
|
||||
exhaustive_L2sqr_seq<RH, true>(x, y, d, nx, ny, resh, sel);
|
||||
} else if (nx < distance_compute_blas_threshold) {
|
||||
exhaustive_L2sqr_seq(x, y, d, nx, ny, resh, sel);
|
||||
} else {
|
||||
exhaustive_L2sqr_blas(x, y, d, nx, ny, resh);
|
||||
}
|
||||
Run_search_L2sqr r;
|
||||
dispatch_range_ResultHandler(
|
||||
res, radius, METRIC_L2, sel, r, x, y, d, nx, ny, nullptr);
|
||||
}
|
||||
|
||||
void range_search_inner_product(
|
||||
|
@ -795,15 +771,9 @@ void range_search_inner_product(
|
|||
float radius,
|
||||
RangeSearchResult* res,
|
||||
const IDSelector* sel) {
|
||||
using RH = RangeSearchBlockResultHandler<CMin<float, int64_t>>;
|
||||
RH resh(res, radius);
|
||||
if (sel) {
|
||||
exhaustive_inner_product_seq<RH, true>(x, y, d, nx, ny, resh, sel);
|
||||
} else if (nx < distance_compute_blas_threshold) {
|
||||
exhaustive_inner_product_seq(x, y, d, nx, ny, resh);
|
||||
} else {
|
||||
exhaustive_inner_product_blas(x, y, d, nx, ny, resh);
|
||||
}
|
||||
Run_search_inner_product r;
|
||||
dispatch_range_ResultHandler(
|
||||
res, radius, METRIC_INNER_PRODUCT, sel, r, x, y, d, nx, ny);
|
||||
}
|
||||
|
||||
/***************************************************************************
|
||||
|
|
|
@ -162,4 +162,39 @@ inline float VectorDistance<METRIC_ABS_INNER_PRODUCT>::operator()(
|
|||
return accu;
|
||||
}
|
||||
|
||||
/***************************************************************************
|
||||
* Dispatching function that takes a metric type and a consumer object
|
||||
* the consumer object should contain a return type T and an operation template
|
||||
* function f() that is called to perform the operation. The first argument
|
||||
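* of the function is the VectorDistance object. The rest are passed in as is.
*
* NOTE(review): the switch below has no default case, so a metric that is not
* one of the listed cases is silently ignored, and no value is returned when
* Consumer::T is non-void. Callers should only pass one of the listed metrics.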
|
||||
**************************************************************************/
|
||||
|
||||
template <class Consumer, class... Types>
|
||||
typename Consumer::T dispatch_VectorDistance(
|
||||
size_t d,
|
||||
MetricType metric,
|
||||
float metric_arg,
|
||||
Consumer& consumer,
|
||||
Types... args) {
|
||||
switch (metric) {
|
||||
#define DISPATCH_VD(mt) \
|
||||
case mt: { \
|
||||
VectorDistance<mt> vd = {d, metric_arg}; \
|
||||
return consumer.template f<VectorDistance<mt>>(vd, args...); \
|
||||
}
|
||||
DISPATCH_VD(METRIC_INNER_PRODUCT);
|
||||
DISPATCH_VD(METRIC_L2);
|
||||
DISPATCH_VD(METRIC_L1);
|
||||
DISPATCH_VD(METRIC_Linf);
|
||||
DISPATCH_VD(METRIC_Lp);
|
||||
DISPATCH_VD(METRIC_Canberra);
|
||||
DISPATCH_VD(METRIC_BrayCurtis);
|
||||
DISPATCH_VD(METRIC_JensenShannon);
|
||||
DISPATCH_VD(METRIC_Jaccard);
|
||||
DISPATCH_VD(METRIC_NaNEuclidean);
|
||||
DISPATCH_VD(METRIC_ABS_INNER_PRODUCT);
|
||||
}
|
||||
#undef DISPATCH_VD
|
||||
}
|
||||
|
||||
} // namespace faiss
|
||||
|
|
|
@ -26,72 +26,77 @@ namespace faiss {
|
|||
|
||||
namespace {
|
||||
|
||||
template <class VD>
|
||||
void pairwise_extra_distances_template(
|
||||
VD vd,
|
||||
int64_t nq,
|
||||
const float* xq,
|
||||
int64_t nb,
|
||||
const float* xb,
|
||||
float* dis,
|
||||
int64_t ldq,
|
||||
int64_t ldb,
|
||||
int64_t ldd) {
|
||||
#pragma omp parallel for if (nq > 10)
|
||||
for (int64_t i = 0; i < nq; i++) {
|
||||
const float* xqi = xq + i * ldq;
|
||||
const float* xbj = xb;
|
||||
float* disi = dis + ldd * i;
|
||||
struct Run_pairwise_extra_distances {
|
||||
using T = void;
|
||||
|
||||
for (int64_t j = 0; j < nb; j++) {
|
||||
disi[j] = vd(xqi, xbj);
|
||||
xbj += ldb;
|
||||
template <class VD>
|
||||
void f(VD vd,
|
||||
int64_t nq,
|
||||
const float* xq,
|
||||
int64_t nb,
|
||||
const float* xb,
|
||||
float* dis,
|
||||
int64_t ldq,
|
||||
int64_t ldb,
|
||||
int64_t ldd) {
|
||||
#pragma omp parallel for if (nq > 10)
|
||||
for (int64_t i = 0; i < nq; i++) {
|
||||
const float* xqi = xq + i * ldq;
|
||||
const float* xbj = xb;
|
||||
float* disi = dis + ldd * i;
|
||||
|
||||
for (int64_t j = 0; j < nb; j++) {
|
||||
disi[j] = vd(xqi, xbj);
|
||||
xbj += ldb;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <class VD>
|
||||
void knn_extra_metrics_template(
|
||||
VD vd,
|
||||
const float* x,
|
||||
const float* y,
|
||||
size_t nx,
|
||||
size_t ny,
|
||||
size_t k,
|
||||
float* distances,
|
||||
int64_t* labels) {
|
||||
size_t d = vd.d;
|
||||
using C = typename VD::C;
|
||||
size_t check_period = InterruptCallback::get_period_hint(ny * d);
|
||||
check_period *= omp_get_max_threads();
|
||||
struct Run_knn_extra_metrics {
|
||||
using T = void;
|
||||
template <class VD>
|
||||
void f(VD vd,
|
||||
const float* x,
|
||||
const float* y,
|
||||
size_t nx,
|
||||
size_t ny,
|
||||
size_t k,
|
||||
float* distances,
|
||||
int64_t* labels) {
|
||||
size_t d = vd.d;
|
||||
using C = typename VD::C;
|
||||
size_t check_period = InterruptCallback::get_period_hint(ny * d);
|
||||
check_period *= omp_get_max_threads();
|
||||
|
||||
for (size_t i0 = 0; i0 < nx; i0 += check_period) {
|
||||
size_t i1 = std::min(i0 + check_period, nx);
|
||||
for (size_t i0 = 0; i0 < nx; i0 += check_period) {
|
||||
size_t i1 = std::min(i0 + check_period, nx);
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int64_t i = i0; i < i1; i++) {
|
||||
const float* x_i = x + i * d;
|
||||
const float* y_j = y;
|
||||
size_t j;
|
||||
float* simi = distances + k * i;
|
||||
int64_t* idxi = labels + k * i;
|
||||
for (int64_t i = i0; i < i1; i++) {
|
||||
const float* x_i = x + i * d;
|
||||
const float* y_j = y;
|
||||
size_t j;
|
||||
float* simi = distances + k * i;
|
||||
int64_t* idxi = labels + k * i;
|
||||
|
||||
// maxheap_heapify(k, simi, idxi);
|
||||
heap_heapify<C>(k, simi, idxi);
|
||||
for (j = 0; j < ny; j++) {
|
||||
float disij = vd(x_i, y_j);
|
||||
// maxheap_heapify(k, simi, idxi);
|
||||
heap_heapify<C>(k, simi, idxi);
|
||||
for (j = 0; j < ny; j++) {
|
||||
float disij = vd(x_i, y_j);
|
||||
|
||||
if (C::cmp(simi[0], disij)) {
|
||||
heap_replace_top<C>(k, simi, idxi, disij, j);
|
||||
if (C::cmp(simi[0], disij)) {
|
||||
heap_replace_top<C>(k, simi, idxi, disij, j);
|
||||
}
|
||||
y_j += d;
|
||||
}
|
||||
y_j += d;
|
||||
// maxheap_reorder(k, simi, idxi);
|
||||
heap_reorder<C>(k, simi, idxi);
|
||||
}
|
||||
// maxheap_reorder(k, simi, idxi);
|
||||
heap_reorder<C>(k, simi, idxi);
|
||||
InterruptCallback::check();
|
||||
}
|
||||
InterruptCallback::check();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <class VD>
|
||||
struct ExtraDistanceComputer : FlatCodesDistanceComputer {
|
||||
|
@ -124,6 +129,19 @@ struct ExtraDistanceComputer : FlatCodesDistanceComputer {
|
|||
}
|
||||
};
|
||||
|
||||
struct Run_get_distance_computer {
|
||||
using T = FlatCodesDistanceComputer*;
|
||||
|
||||
template <class VD>
|
||||
FlatCodesDistanceComputer* f(
|
||||
VD vd,
|
||||
const float* xb,
|
||||
size_t nb,
|
||||
const float* q = nullptr) {
|
||||
return new ExtraDistanceComputer<VD>(vd, xb, nb, q);
|
||||
}
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
void pairwise_extra_distances(
|
||||
|
@ -147,28 +165,9 @@ void pairwise_extra_distances(
|
|||
if (ldd == -1)
|
||||
ldd = nb;
|
||||
|
||||
switch (mt) {
|
||||
#define HANDLE_VAR(kw) \
|
||||
case METRIC_##kw: { \
|
||||
VectorDistance<METRIC_##kw> vd = {(size_t)d, metric_arg}; \
|
||||
pairwise_extra_distances_template( \
|
||||
vd, nq, xq, nb, xb, dis, ldq, ldb, ldd); \
|
||||
break; \
|
||||
}
|
||||
HANDLE_VAR(L2);
|
||||
HANDLE_VAR(L1);
|
||||
HANDLE_VAR(Linf);
|
||||
HANDLE_VAR(Canberra);
|
||||
HANDLE_VAR(BrayCurtis);
|
||||
HANDLE_VAR(JensenShannon);
|
||||
HANDLE_VAR(Lp);
|
||||
HANDLE_VAR(Jaccard);
|
||||
HANDLE_VAR(NaNEuclidean);
|
||||
HANDLE_VAR(ABS_INNER_PRODUCT);
|
||||
#undef HANDLE_VAR
|
||||
default:
|
||||
FAISS_THROW_MSG("metric type not implemented");
|
||||
}
|
||||
Run_pairwise_extra_distances run;
|
||||
dispatch_VectorDistance(
|
||||
d, mt, metric_arg, run, nq, xq, nb, xb, dis, ldq, ldb, ldd);
|
||||
}
|
||||
|
||||
void knn_extra_metrics(
|
||||
|
@ -182,27 +181,9 @@ void knn_extra_metrics(
|
|||
size_t k,
|
||||
float* distances,
|
||||
int64_t* indexes) {
|
||||
switch (mt) {
|
||||
#define HANDLE_VAR(kw) \
|
||||
case METRIC_##kw: { \
|
||||
VectorDistance<METRIC_##kw> vd = {(size_t)d, metric_arg}; \
|
||||
knn_extra_metrics_template(vd, x, y, nx, ny, k, distances, indexes); \
|
||||
break; \
|
||||
}
|
||||
HANDLE_VAR(L2);
|
||||
HANDLE_VAR(L1);
|
||||
HANDLE_VAR(Linf);
|
||||
HANDLE_VAR(Canberra);
|
||||
HANDLE_VAR(BrayCurtis);
|
||||
HANDLE_VAR(JensenShannon);
|
||||
HANDLE_VAR(Lp);
|
||||
HANDLE_VAR(Jaccard);
|
||||
HANDLE_VAR(NaNEuclidean);
|
||||
HANDLE_VAR(ABS_INNER_PRODUCT);
|
||||
#undef HANDLE_VAR
|
||||
default:
|
||||
FAISS_THROW_MSG("metric type not implemented");
|
||||
}
|
||||
Run_knn_extra_metrics run;
|
||||
dispatch_VectorDistance(
|
||||
d, mt, metric_arg, run, x, y, nx, ny, k, distances, indexes);
|
||||
}
|
||||
|
||||
FlatCodesDistanceComputer* get_extra_distance_computer(
|
||||
|
@ -211,27 +192,8 @@ FlatCodesDistanceComputer* get_extra_distance_computer(
|
|||
float metric_arg,
|
||||
size_t nb,
|
||||
const float* xb) {
|
||||
switch (mt) {
|
||||
#define HANDLE_VAR(kw) \
|
||||
case METRIC_##kw: { \
|
||||
VectorDistance<METRIC_##kw> vd = {(size_t)d, metric_arg}; \
|
||||
return new ExtraDistanceComputer<VectorDistance<METRIC_##kw>>( \
|
||||
vd, xb, nb); \
|
||||
}
|
||||
HANDLE_VAR(L2);
|
||||
HANDLE_VAR(L1);
|
||||
HANDLE_VAR(Linf);
|
||||
HANDLE_VAR(Canberra);
|
||||
HANDLE_VAR(BrayCurtis);
|
||||
HANDLE_VAR(JensenShannon);
|
||||
HANDLE_VAR(Lp);
|
||||
HANDLE_VAR(Jaccard);
|
||||
HANDLE_VAR(NaNEuclidean);
|
||||
HANDLE_VAR(ABS_INNER_PRODUCT);
|
||||
#undef HANDLE_VAR
|
||||
default:
|
||||
FAISS_THROW_MSG("metric type not implemented");
|
||||
}
|
||||
Run_get_distance_computer run;
|
||||
return dispatch_VectorDistance(d, mt, metric_arg, run, xb, nb);
|
||||
}
|
||||
|
||||
} // namespace faiss
|
||||
|
|
|
@ -55,7 +55,7 @@ SPECIALIZED_HC(64);
|
|||
/***************************************************************************
|
||||
* Dispatching function that takes a code size and a consumer object
|
||||
* the consumer object should contain a return type T and an operation template
|
||||
* function f() that to be called to perform the operation.
|
||||
* function f() that must be called to perform the operation.
|
||||
**************************************************************************/
|
||||
|
||||
template <class Consumer, class... Types>
|
||||
|
@ -76,6 +76,7 @@ typename Consumer::T dispatch_HammingComputer(
|
|||
default:
|
||||
return consumer.template f<HammingComputerDefault>(args...);
|
||||
}
|
||||
#undef DISPATCH_HC
|
||||
}
|
||||
|
||||
} // namespace faiss
|
||||
|
|
Loading…
Reference in New Issue