add dispatcher for VectorDistance and ResultHandlers

Summary: Add dispatcher function to avoid repeating dispatching code for distance computation and result handlers.

Reviewed By: asadoughi

Differential Revision: D59318865

fbshipit-source-id: 59046ede02f71a0da3b8061289fc70306bf875cb
pull/3649/head
Matthijs Douze 2024-07-11 02:40:38 -07:00 committed by Facebook GitHub Bot
parent 444614b076
commit 261edde514
5 changed files with 301 additions and 232 deletions

View File

@ -13,8 +13,10 @@
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/FaissException.h>
#include <faiss/impl/IDSelector.h>
#include <faiss/utils/Heap.h>
#include <faiss/utils/partitioning.h>
#include <algorithm>
#include <iostream>
namespace faiss {
@ -26,16 +28,21 @@ namespace faiss {
* - by instanciating a SingleResultHandler that tracks results for a single
* query
* - with begin_multiple/add_results/end_multiple calls where a whole block of
* resutls is submitted
* results is submitted
* All classes are templated on C which to define wheter the min or the max of
* results is to be kept.
* results is to be kept, and on sel, so that the codepaths for with / without
* selector can be separated at compile time.
*****************************************************************/
template <class C>
template <class C, bool use_sel = false>
struct BlockResultHandler {
size_t nq; // number of queries for which we search
const IDSelector* sel;
explicit BlockResultHandler(size_t nq) : nq(nq) {}
explicit BlockResultHandler(size_t nq, const IDSelector* sel = nullptr)
: nq(nq), sel(sel) {
assert(!use_sel || sel);
}
// currently handled query range
size_t i0 = 0, i1 = 0;
@ -53,13 +60,17 @@ struct BlockResultHandler {
virtual void end_multiple() {}
virtual ~BlockResultHandler() {}
bool is_in_selection(idx_t i) const {
return !use_sel || sel->is_member(i);
}
};
// handler for a single query
template <class C>
struct ResultHandler {
// if not better than threshold, then not necessary to call add_result
typename C::T threshold = 0;
typename C::T threshold = C::neutral();
// return whether threshold was updated
virtual bool add_result(typename C::T dis, typename C::TI idx) = 0;
@ -73,20 +84,26 @@ struct ResultHandler {
* some temporary data in memory.
*****************************************************************/
template <class C>
struct Top1BlockResultHandler : BlockResultHandler<C> {
template <class C, bool use_sel = false>
struct Top1BlockResultHandler : BlockResultHandler<C, use_sel> {
using T = typename C::T;
using TI = typename C::TI;
using BlockResultHandler<C>::i0;
using BlockResultHandler<C>::i1;
using BlockResultHandler<C, use_sel>::i0;
using BlockResultHandler<C, use_sel>::i1;
// contains exactly nq elements
T* dis_tab;
// contains exactly nq elements
TI* ids_tab;
Top1BlockResultHandler(size_t nq, T* dis_tab, TI* ids_tab)
: BlockResultHandler<C>(nq), dis_tab(dis_tab), ids_tab(ids_tab) {}
Top1BlockResultHandler(
size_t nq,
T* dis_tab,
TI* ids_tab,
const IDSelector* sel = nullptr)
: BlockResultHandler<C, use_sel>(nq, sel),
dis_tab(dis_tab),
ids_tab(ids_tab) {}
struct SingleResultHandler : ResultHandler<C> {
Top1BlockResultHandler& hr;
@ -165,12 +182,12 @@ struct Top1BlockResultHandler : BlockResultHandler<C> {
* Heap based result handler
*****************************************************************/
template <class C>
struct HeapBlockResultHandler : BlockResultHandler<C> {
template <class C, bool use_sel = false>
struct HeapBlockResultHandler : BlockResultHandler<C, use_sel> {
using T = typename C::T;
using TI = typename C::TI;
using BlockResultHandler<C>::i0;
using BlockResultHandler<C>::i1;
using BlockResultHandler<C, use_sel>::i0;
using BlockResultHandler<C, use_sel>::i1;
T* heap_dis_tab;
TI* heap_ids_tab;
@ -181,8 +198,9 @@ struct HeapBlockResultHandler : BlockResultHandler<C> {
size_t nq,
T* heap_dis_tab,
TI* heap_ids_tab,
size_t k)
: BlockResultHandler<C>(nq),
size_t k,
const IDSelector* sel = nullptr)
: BlockResultHandler<C, use_sel>(nq, sel),
heap_dis_tab(heap_dis_tab),
heap_ids_tab(heap_ids_tab),
k(k) {}
@ -347,12 +365,12 @@ struct ReservoirTopN : ResultHandler<C> {
}
};
template <class C>
struct ReservoirBlockResultHandler : BlockResultHandler<C> {
template <class C, bool use_sel = false>
struct ReservoirBlockResultHandler : BlockResultHandler<C, use_sel> {
using T = typename C::T;
using TI = typename C::TI;
using BlockResultHandler<C>::i0;
using BlockResultHandler<C>::i1;
using BlockResultHandler<C, use_sel>::i0;
using BlockResultHandler<C, use_sel>::i1;
T* heap_dis_tab;
TI* heap_ids_tab;
@ -364,8 +382,9 @@ struct ReservoirBlockResultHandler : BlockResultHandler<C> {
size_t nq,
T* heap_dis_tab,
TI* heap_ids_tab,
size_t k)
: BlockResultHandler<C>(nq),
size_t k,
const IDSelector* sel = nullptr)
: BlockResultHandler<C, use_sel>(nq, sel),
heap_dis_tab(heap_dis_tab),
heap_ids_tab(heap_ids_tab),
k(k) {
@ -460,18 +479,23 @@ struct ReservoirBlockResultHandler : BlockResultHandler<C> {
* Result handler for range searches
*****************************************************************/
template <class C>
struct RangeSearchBlockResultHandler : BlockResultHandler<C> {
template <class C, bool use_sel = false>
struct RangeSearchBlockResultHandler : BlockResultHandler<C, use_sel> {
using T = typename C::T;
using TI = typename C::TI;
using BlockResultHandler<C>::i0;
using BlockResultHandler<C>::i1;
using BlockResultHandler<C, use_sel>::i0;
using BlockResultHandler<C, use_sel>::i1;
RangeSearchResult* res;
T radius;
RangeSearchBlockResultHandler(RangeSearchResult* res, float radius)
: BlockResultHandler<C>(res->nq), res(res), radius(radius) {}
RangeSearchBlockResultHandler(
RangeSearchResult* res,
float radius,
const IDSelector* sel = nullptr)
: BlockResultHandler<C, use_sel>(res->nq, sel),
res(res),
radius(radius) {}
/******************************************************
* API for 1 result at a time (each SingleResultHandler is
@ -582,4 +606,81 @@ struct RangeSearchBlockResultHandler : BlockResultHandler<C> {
}
};
/*****************************************************************
* Dispatcher function to choose the right knn result handler depending on k
*****************************************************************/
// declared in distances.cpp
FAISS_API extern int distance_compute_min_k_reservoir;
template <class Consumer, class... Types>
typename Consumer::T dispatch_knn_ResultHandler(
size_t nx,
float* vals,
int64_t* ids,
size_t k,
MetricType metric,
const IDSelector* sel,
Consumer& consumer,
Types... args) {
#define DISPATCH_C_SEL(C, use_sel) \
if (k == 1) { \
Top1BlockResultHandler<C, use_sel> res(nx, vals, ids, sel); \
return consumer.template f<>(res, args...); \
} else if (k < distance_compute_min_k_reservoir) { \
HeapBlockResultHandler<C, use_sel> res(nx, vals, ids, k, sel); \
return consumer.template f<>(res, args...); \
} else { \
ReservoirBlockResultHandler<C, use_sel> res(nx, vals, ids, k, sel); \
return consumer.template f<>(res, args...); \
}
if (is_similarity_metric(metric)) {
using C = CMin<float, int64_t>;
if (sel) {
DISPATCH_C_SEL(C, true);
} else {
DISPATCH_C_SEL(C, false);
}
} else {
using C = CMax<float, int64_t>;
if (sel) {
DISPATCH_C_SEL(C, true);
} else {
DISPATCH_C_SEL(C, false);
}
}
#undef DISPATCH_C_SEL
}
template <class Consumer, class... Types>
typename Consumer::T dispatch_range_ResultHandler(
RangeSearchResult* res,
float radius,
MetricType metric,
const IDSelector* sel,
Consumer& consumer,
Types... args) {
#define DISPATCH_C_SEL(C, use_sel) \
RangeSearchBlockResultHandler<C, use_sel> resb(res, radius, sel); \
return consumer.template f<>(resb, args...);
if (is_similarity_metric(metric)) {
using C = CMin<float, int64_t>;
if (sel) {
DISPATCH_C_SEL(C, true);
} else {
DISPATCH_C_SEL(C, false);
}
} else {
using C = CMax<float, int64_t>;
if (sel) {
DISPATCH_C_SEL(C, true);
} else {
DISPATCH_C_SEL(C, false);
}
}
#undef DISPATCH_C_SEL
}
} // namespace faiss

View File

@ -130,21 +130,18 @@ void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) {
namespace {
/* Find the nearest neighbors for nx queries in a set of ny vectors */
template <class BlockResultHandler, bool use_sel = false>
template <class BlockResultHandler>
void exhaustive_inner_product_seq(
const float* x,
const float* y,
size_t d,
size_t nx,
size_t ny,
BlockResultHandler& res,
const IDSelector* sel = nullptr) {
BlockResultHandler& res) {
using SingleResultHandler =
typename BlockResultHandler::SingleResultHandler;
[[maybe_unused]] int nt = std::min(int(nx), omp_get_max_threads());
FAISS_ASSERT(use_sel == (sel != nullptr));
#pragma omp parallel num_threads(nt)
{
SingleResultHandler resi(res);
@ -156,7 +153,7 @@ void exhaustive_inner_product_seq(
resi.begin(i);
for (size_t j = 0; j < ny; j++, y_j += d) {
if (use_sel && !sel->is_member(j)) {
if (!res.is_in_selection(j)) {
continue;
}
float ip = fvec_inner_product(x_i, y_j, d);
@ -167,21 +164,18 @@ void exhaustive_inner_product_seq(
}
}
template <class BlockResultHandler, bool use_sel = false>
template <class BlockResultHandler>
void exhaustive_L2sqr_seq(
const float* x,
const float* y,
size_t d,
size_t nx,
size_t ny,
BlockResultHandler& res,
const IDSelector* sel = nullptr) {
BlockResultHandler& res) {
using SingleResultHandler =
typename BlockResultHandler::SingleResultHandler;
[[maybe_unused]] int nt = std::min(int(nx), omp_get_max_threads());
FAISS_ASSERT(use_sel == (sel != nullptr));
#pragma omp parallel num_threads(nt)
{
SingleResultHandler resi(res);
@ -191,7 +185,7 @@ void exhaustive_L2sqr_seq(
const float* y_j = y;
resi.begin(i);
for (size_t j = 0; j < ny; j++, y_j += d) {
if (use_sel && !sel->is_member(j)) {
if (!res.is_in_selection(j)) {
continue;
}
float disij = fvec_L2sqr(x_i, y_j, d);
@ -326,6 +320,9 @@ void exhaustive_L2sqr_blas_default_impl(
float ip = *ip_line;
float dis = x_norms[i] + y_norms[j] - 2 * ip;
if (!res.is_in_selection(j)) {
dis = HUGE_VALF;
}
// negative values can occur for identical vectors
// due to roundoff errors
if (dis < 0)
@ -601,44 +598,40 @@ void exhaustive_L2sqr_blas<Top1BlockResultHandler<CMax<float, int64_t>>>(
#endif
}
template <class BlockResultHandler>
void knn_L2sqr_select(
const float* x,
const float* y,
size_t d,
size_t nx,
size_t ny,
BlockResultHandler& res,
const float* y_norm2,
const IDSelector* sel) {
if (sel) {
exhaustive_L2sqr_seq<BlockResultHandler, true>(
x, y, d, nx, ny, res, sel);
} else if (nx < distance_compute_blas_threshold) {
exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
} else {
exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
struct Run_search_inner_product {
using T = void;
template <class BlockResultHandler>
void f(BlockResultHandler& res,
const float* x,
const float* y,
size_t d,
size_t nx,
size_t ny) {
if (res.sel || nx < distance_compute_blas_threshold) {
exhaustive_inner_product_seq(x, y, d, nx, ny, res);
} else {
exhaustive_inner_product_blas(x, y, d, nx, ny, res);
}
}
}
};
template <class BlockResultHandler>
void knn_inner_product_select(
const float* x,
const float* y,
size_t d,
size_t nx,
size_t ny,
BlockResultHandler& res,
const IDSelector* sel) {
if (sel) {
exhaustive_inner_product_seq<BlockResultHandler, true>(
x, y, d, nx, ny, res, sel);
} else if (nx < distance_compute_blas_threshold) {
exhaustive_inner_product_seq(x, y, d, nx, ny, res);
} else {
exhaustive_inner_product_blas(x, y, d, nx, ny, res);
struct Run_search_L2sqr {
using T = void;
template <class BlockResultHandler>
void f(BlockResultHandler& res,
const float* x,
const float* y,
size_t d,
size_t nx,
size_t ny,
const float* y_norm2) {
if (res.sel || nx < distance_compute_blas_threshold) {
exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
} else {
exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
}
}
}
};
} // anonymous namespace
@ -675,16 +668,9 @@ void knn_inner_product(
return;
}
if (k == 1) {
Top1BlockResultHandler<CMin<float, int64_t>> res(nx, vals, ids);
knn_inner_product_select(x, y, d, nx, ny, res, sel);
} else if (k < distance_compute_min_k_reservoir) {
HeapBlockResultHandler<CMin<float, int64_t>> res(nx, vals, ids, k);
knn_inner_product_select(x, y, d, nx, ny, res, sel);
} else {
ReservoirBlockResultHandler<CMin<float, int64_t>> res(nx, vals, ids, k);
knn_inner_product_select(x, y, d, nx, ny, res, sel);
}
Run_search_inner_product r;
dispatch_knn_ResultHandler(
nx, vals, ids, k, METRIC_INNER_PRODUCT, sel, r, x, y, d, nx, ny);
if (imin != 0) {
for (size_t i = 0; i < nx * k; i++) {
@ -730,16 +716,11 @@ void knn_L2sqr(
knn_L2sqr_by_idx(x, y, sela->ids, d, nx, ny, sela->n, k, vals, ids, 0);
return;
}
if (k == 1) {
Top1BlockResultHandler<CMax<float, int64_t>> res(nx, vals, ids);
knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
} else if (k < distance_compute_min_k_reservoir) {
HeapBlockResultHandler<CMax<float, int64_t>> res(nx, vals, ids, k);
knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
} else {
ReservoirBlockResultHandler<CMax<float, int64_t>> res(nx, vals, ids, k);
knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
}
Run_search_L2sqr r;
dispatch_knn_ResultHandler(
nx, vals, ids, k, METRIC_L2, sel, r, x, y, d, nx, ny, y_norm2);
if (imin != 0) {
for (size_t i = 0; i < nx * k; i++) {
if (ids[i] >= 0) {
@ -766,6 +747,7 @@ void knn_L2sqr(
* Range search
***************************************************************************/
// TODO accept a y_norm2 as well
void range_search_L2sqr(
const float* x,
const float* y,
@ -775,15 +757,9 @@ void range_search_L2sqr(
float radius,
RangeSearchResult* res,
const IDSelector* sel) {
using RH = RangeSearchBlockResultHandler<CMax<float, int64_t>>;
RH resh(res, radius);
if (sel) {
exhaustive_L2sqr_seq<RH, true>(x, y, d, nx, ny, resh, sel);
} else if (nx < distance_compute_blas_threshold) {
exhaustive_L2sqr_seq(x, y, d, nx, ny, resh, sel);
} else {
exhaustive_L2sqr_blas(x, y, d, nx, ny, resh);
}
Run_search_L2sqr r;
dispatch_range_ResultHandler(
res, radius, METRIC_L2, sel, r, x, y, d, nx, ny, nullptr);
}
void range_search_inner_product(
@ -795,15 +771,9 @@ void range_search_inner_product(
float radius,
RangeSearchResult* res,
const IDSelector* sel) {
using RH = RangeSearchBlockResultHandler<CMin<float, int64_t>>;
RH resh(res, radius);
if (sel) {
exhaustive_inner_product_seq<RH, true>(x, y, d, nx, ny, resh, sel);
} else if (nx < distance_compute_blas_threshold) {
exhaustive_inner_product_seq(x, y, d, nx, ny, resh);
} else {
exhaustive_inner_product_blas(x, y, d, nx, ny, resh);
}
Run_search_inner_product r;
dispatch_range_ResultHandler(
res, radius, METRIC_INNER_PRODUCT, sel, r, x, y, d, nx, ny);
}
/***************************************************************************

View File

@ -162,4 +162,39 @@ inline float VectorDistance<METRIC_ABS_INNER_PRODUCT>::operator()(
return accu;
}
/***************************************************************************
* Dispatching function that takes a metric type and a consumer object
* the consumer object should contain a retun type T and a operation template
* function f() that is called to perform the operation. The first argument
* of the function is the VectorDistance object. The rest are passed in as is.
**************************************************************************/
template <class Consumer, class... Types>
typename Consumer::T dispatch_VectorDistance(
size_t d,
MetricType metric,
float metric_arg,
Consumer& consumer,
Types... args) {
switch (metric) {
#define DISPATCH_VD(mt) \
case mt: { \
VectorDistance<mt> vd = {d, metric_arg}; \
return consumer.template f<VectorDistance<mt>>(vd, args...); \
}
DISPATCH_VD(METRIC_INNER_PRODUCT);
DISPATCH_VD(METRIC_L2);
DISPATCH_VD(METRIC_L1);
DISPATCH_VD(METRIC_Linf);
DISPATCH_VD(METRIC_Lp);
DISPATCH_VD(METRIC_Canberra);
DISPATCH_VD(METRIC_BrayCurtis);
DISPATCH_VD(METRIC_JensenShannon);
DISPATCH_VD(METRIC_Jaccard);
DISPATCH_VD(METRIC_NaNEuclidean);
DISPATCH_VD(METRIC_ABS_INNER_PRODUCT);
}
#undef DISPATCH_VD
}
} // namespace faiss

View File

@ -26,72 +26,77 @@ namespace faiss {
namespace {
template <class VD>
void pairwise_extra_distances_template(
VD vd,
int64_t nq,
const float* xq,
int64_t nb,
const float* xb,
float* dis,
int64_t ldq,
int64_t ldb,
int64_t ldd) {
#pragma omp parallel for if (nq > 10)
for (int64_t i = 0; i < nq; i++) {
const float* xqi = xq + i * ldq;
const float* xbj = xb;
float* disi = dis + ldd * i;
struct Run_pairwise_extra_distances {
using T = void;
for (int64_t j = 0; j < nb; j++) {
disi[j] = vd(xqi, xbj);
xbj += ldb;
template <class VD>
void f(VD vd,
int64_t nq,
const float* xq,
int64_t nb,
const float* xb,
float* dis,
int64_t ldq,
int64_t ldb,
int64_t ldd) {
#pragma omp parallel for if (nq > 10)
for (int64_t i = 0; i < nq; i++) {
const float* xqi = xq + i * ldq;
const float* xbj = xb;
float* disi = dis + ldd * i;
for (int64_t j = 0; j < nb; j++) {
disi[j] = vd(xqi, xbj);
xbj += ldb;
}
}
}
}
};
template <class VD>
void knn_extra_metrics_template(
VD vd,
const float* x,
const float* y,
size_t nx,
size_t ny,
size_t k,
float* distances,
int64_t* labels) {
size_t d = vd.d;
using C = typename VD::C;
size_t check_period = InterruptCallback::get_period_hint(ny * d);
check_period *= omp_get_max_threads();
struct Run_knn_extra_metrics {
using T = void;
template <class VD>
void f(VD vd,
const float* x,
const float* y,
size_t nx,
size_t ny,
size_t k,
float* distances,
int64_t* labels) {
size_t d = vd.d;
using C = typename VD::C;
size_t check_period = InterruptCallback::get_period_hint(ny * d);
check_period *= omp_get_max_threads();
for (size_t i0 = 0; i0 < nx; i0 += check_period) {
size_t i1 = std::min(i0 + check_period, nx);
for (size_t i0 = 0; i0 < nx; i0 += check_period) {
size_t i1 = std::min(i0 + check_period, nx);
#pragma omp parallel for
for (int64_t i = i0; i < i1; i++) {
const float* x_i = x + i * d;
const float* y_j = y;
size_t j;
float* simi = distances + k * i;
int64_t* idxi = labels + k * i;
for (int64_t i = i0; i < i1; i++) {
const float* x_i = x + i * d;
const float* y_j = y;
size_t j;
float* simi = distances + k * i;
int64_t* idxi = labels + k * i;
// maxheap_heapify(k, simi, idxi);
heap_heapify<C>(k, simi, idxi);
for (j = 0; j < ny; j++) {
float disij = vd(x_i, y_j);
// maxheap_heapify(k, simi, idxi);
heap_heapify<C>(k, simi, idxi);
for (j = 0; j < ny; j++) {
float disij = vd(x_i, y_j);
if (C::cmp(simi[0], disij)) {
heap_replace_top<C>(k, simi, idxi, disij, j);
if (C::cmp(simi[0], disij)) {
heap_replace_top<C>(k, simi, idxi, disij, j);
}
y_j += d;
}
y_j += d;
// maxheap_reorder(k, simi, idxi);
heap_reorder<C>(k, simi, idxi);
}
// maxheap_reorder(k, simi, idxi);
heap_reorder<C>(k, simi, idxi);
InterruptCallback::check();
}
InterruptCallback::check();
}
}
};
template <class VD>
struct ExtraDistanceComputer : FlatCodesDistanceComputer {
@ -124,6 +129,19 @@ struct ExtraDistanceComputer : FlatCodesDistanceComputer {
}
};
struct Run_get_distance_computer {
using T = FlatCodesDistanceComputer*;
template <class VD>
FlatCodesDistanceComputer* f(
VD vd,
const float* xb,
size_t nb,
const float* q = nullptr) {
return new ExtraDistanceComputer<VD>(vd, xb, nb, q);
}
};
} // anonymous namespace
void pairwise_extra_distances(
@ -147,28 +165,9 @@ void pairwise_extra_distances(
if (ldd == -1)
ldd = nb;
switch (mt) {
#define HANDLE_VAR(kw) \
case METRIC_##kw: { \
VectorDistance<METRIC_##kw> vd = {(size_t)d, metric_arg}; \
pairwise_extra_distances_template( \
vd, nq, xq, nb, xb, dis, ldq, ldb, ldd); \
break; \
}
HANDLE_VAR(L2);
HANDLE_VAR(L1);
HANDLE_VAR(Linf);
HANDLE_VAR(Canberra);
HANDLE_VAR(BrayCurtis);
HANDLE_VAR(JensenShannon);
HANDLE_VAR(Lp);
HANDLE_VAR(Jaccard);
HANDLE_VAR(NaNEuclidean);
HANDLE_VAR(ABS_INNER_PRODUCT);
#undef HANDLE_VAR
default:
FAISS_THROW_MSG("metric type not implemented");
}
Run_pairwise_extra_distances run;
dispatch_VectorDistance(
d, mt, metric_arg, run, nq, xq, nb, xb, dis, ldq, ldb, ldd);
}
void knn_extra_metrics(
@ -182,27 +181,9 @@ void knn_extra_metrics(
size_t k,
float* distances,
int64_t* indexes) {
switch (mt) {
#define HANDLE_VAR(kw) \
case METRIC_##kw: { \
VectorDistance<METRIC_##kw> vd = {(size_t)d, metric_arg}; \
knn_extra_metrics_template(vd, x, y, nx, ny, k, distances, indexes); \
break; \
}
HANDLE_VAR(L2);
HANDLE_VAR(L1);
HANDLE_VAR(Linf);
HANDLE_VAR(Canberra);
HANDLE_VAR(BrayCurtis);
HANDLE_VAR(JensenShannon);
HANDLE_VAR(Lp);
HANDLE_VAR(Jaccard);
HANDLE_VAR(NaNEuclidean);
HANDLE_VAR(ABS_INNER_PRODUCT);
#undef HANDLE_VAR
default:
FAISS_THROW_MSG("metric type not implemented");
}
Run_knn_extra_metrics run;
dispatch_VectorDistance(
d, mt, metric_arg, run, x, y, nx, ny, k, distances, indexes);
}
FlatCodesDistanceComputer* get_extra_distance_computer(
@ -211,27 +192,8 @@ FlatCodesDistanceComputer* get_extra_distance_computer(
float metric_arg,
size_t nb,
const float* xb) {
switch (mt) {
#define HANDLE_VAR(kw) \
case METRIC_##kw: { \
VectorDistance<METRIC_##kw> vd = {(size_t)d, metric_arg}; \
return new ExtraDistanceComputer<VectorDistance<METRIC_##kw>>( \
vd, xb, nb); \
}
HANDLE_VAR(L2);
HANDLE_VAR(L1);
HANDLE_VAR(Linf);
HANDLE_VAR(Canberra);
HANDLE_VAR(BrayCurtis);
HANDLE_VAR(JensenShannon);
HANDLE_VAR(Lp);
HANDLE_VAR(Jaccard);
HANDLE_VAR(NaNEuclidean);
HANDLE_VAR(ABS_INNER_PRODUCT);
#undef HANDLE_VAR
default:
FAISS_THROW_MSG("metric type not implemented");
}
Run_get_distance_computer run;
return dispatch_VectorDistance(d, mt, metric_arg, run, xb, nb);
}
} // namespace faiss

View File

@ -55,7 +55,7 @@ SPECIALIZED_HC(64);
/***************************************************************************
* Dispatching function that takes a code size and a consumer object
* the consumer object should contain a retun type t and a operation template
* function f() that to be called to perform the operation.
* function f() that must be called to perform the operation.
**************************************************************************/
template <class Consumer, class... Types>
@ -76,6 +76,7 @@ typename Consumer::T dispatch_HammingComputer(
default:
return consumer.template f<HammingComputerDefault>(args...);
}
#undef DISPATCH_HC
}
} // namespace faiss