337 lines
8.1 KiB
C++
337 lines
8.1 KiB
C++
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */
|
|
|
|
// -*- c++ -*-
|
|
|
|
#include <faiss/utils/distances.h>
|
|
|
|
#include <cmath>
|
|
#include <omp.h>
|
|
|
|
|
|
#include <faiss/utils/utils.h>
|
|
#include <faiss/impl/FaissAssert.h>
|
|
#include <faiss/impl/AuxIndexStructures.h>
|
|
|
|
namespace faiss {
|
|
|
|
/***************************************************************************
 * Distance functions (other than L2 and IP)
 ***************************************************************************/
|
|
|
|
struct VectorDistanceL2 {
|
|
size_t d;
|
|
|
|
float operator () (const float *x, const float *y) const {
|
|
return fvec_L2sqr (x, y, d);
|
|
}
|
|
};
|
|
|
|
struct VectorDistanceL1 {
|
|
size_t d;
|
|
|
|
float operator () (const float *x, const float *y) const {
|
|
return fvec_L1 (x, y, d);
|
|
}
|
|
};
|
|
|
|
struct VectorDistanceLinf {
|
|
size_t d;
|
|
|
|
float operator () (const float *x, const float *y) const {
|
|
return fvec_Linf (x, y, d);
|
|
/*
|
|
float vmax = 0;
|
|
for (size_t i = 0; i < d; i++) {
|
|
float diff = fabs (x[i] - y[i]);
|
|
if (diff > vmax) vmax = diff;
|
|
}
|
|
return vmax;*/
|
|
}
|
|
};
|
|
|
|
/** Functor summing |x[i] - y[i]|^p over the first d components.
 *
 * Note that the sum is returned as is: no 1/p root is applied.
 * Kept as a plain aggregate so it can be brace-initialized as {d, p}.
 */
struct VectorDistanceLp {
    /// number of components read from each vector
    size_t d;
    /// exponent applied to every absolute difference
    const float p;

    /** Evaluate the sum of p-th powers of the absolute differences.
     *
     * @param x  first vector, at least d floats
     * @param y  second vector, at least d floats
     * @return   sum_i |x[i] - y[i]|^p (0 when d == 0)
     */
    float operator()(const float* x, const float* y) const {
        float total = 0;
        // walk both arrays in lockstep until d components were consumed
        for (const float* x_end = x + d; x != x_end; ++x, ++y) {
            float gap = fabs(*x - *y);
            total += powf(gap, p);
        }
        return total;
    }
};
|
|
|
|
/** Canberra-style distance functor:
 *  sum_i |x[i] - y[i]| / (|x[i]| + |y[i]|).
 *
 * Kept as a plain aggregate so it can be brace-initialized from the
 * dimension alone.
 */
struct VectorDistanceCanberra {
    /// number of components read from each vector
    size_t d;

    /** Evaluate the distance between two vectors.
     *
     * Components where both inputs are zero contribute 0 to the sum
     * (the usual 0/0 = 0 convention) instead of turning the whole
     * result into NaN.
     *
     * @param x  first vector, at least d floats
     * @param y  second vector, at least d floats
     * @return   the accumulated per-component ratios (0 when d == 0)
     */
    float operator()(const float* x, const float* y) const {
        float accu = 0;
        for (size_t i = 0; i < d; i++) {
            const float xi = x[i], yi = y[i];
            const auto den = fabs(xi) + fabs(yi);
            // den is exactly 0 only when xi == yi == 0; dividing would
            // give 0/0 = NaN. "!= 0" (rather than "> 0") keeps NaN inputs
            // propagating to the result instead of being silently dropped.
            if (den != 0) {
                accu += fabs(xi - yi) / den;
            }
        }
        return accu;
    }
};
|
|
|
|
/** Bray-Curtis-style dissimilarity functor:
 *  (sum_i |x[i] - y[i]|) / (sum_i |x[i] + y[i]|).
 *
 * Kept as a plain aggregate so it can be brace-initialized from the
 * dimension alone.
 */
struct VectorDistanceBrayCurtis {
    /// number of components read from each vector
    size_t d;

    /** Evaluate the dissimilarity between two vectors.
     *
     * The ratio is taken once, after both sums are complete; a zero
     * denominator is not special-cased.
     *
     * @param x  first vector, at least d floats
     * @param y  second vector, at least d floats
     * @return   numerator sum divided by denominator sum
     */
    float operator()(const float* x, const float* y) const {
        float diff_sum = 0;
        float pair_sum = 0;
        size_t pos = 0;
        while (pos < d) {
            const float a = x[pos];
            const float b = y[pos];
            diff_sum += fabs(a - b);
            pair_sum += fabs(a + b);
            ++pos;
        }
        return diff_sum / pair_sum;
    }
};
|
|
|
|
/** Jensen-Shannon divergence functor.
 *
 * With m = (x + y) / 2, computes
 *   0.5 * sum_i ( x[i] * log(x[i] / m[i]) + y[i] * log(y[i] / m[i]) )
 * using the natural logarithm.
 *
 * NOTE(review): the formula only makes sense for non-negative inputs
 * (histograms / probability vectors); normalization is not checked here.
 *
 * Kept as a plain aggregate so it can be brace-initialized from the
 * dimension alone.
 */
struct VectorDistanceJensenShannon {
    /// number of components read from each vector
    size_t d;

    /** Evaluate the divergence between two vectors.
     *
     * A zero component contributes nothing to the sum (0 * log 0 = 0
     * convention). Without that guard the term evaluates to
     * -0 * log(inf) = NaN and poisons the whole result, which happens
     * for any histogram with an empty bin.
     *
     * @param x  first vector, at least d floats
     * @param y  second vector, at least d floats
     * @return   half of the accumulated per-component terms
     */
    float operator()(const float* x, const float* y) const {
        float accu = 0;

        for (size_t i = 0; i < d; i++) {
            const float xi = x[i], yi = y[i];
            const float mi = 0.5 * (xi + yi);
            float kl1 = 0;
            float kl2 = 0;
            // "!= 0" (rather than "> 0") so that NaN inputs still
            // propagate to the result instead of being silently dropped
            if (xi != 0) {
                kl1 = -xi * log(mi / xi);
            }
            if (yi != 0) {
                kl2 = -yi * log(mi / yi);
            }
            accu += kl1 + kl2;
        }
        return 0.5 * accu;
    }
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
/** Fill a distance table by applying vd to every (query, base) pair.
 *
 * Entry (i, j) is written to dis[i * ldd + j] and holds
 * vd(xq + i * ldq, xb + j * ldb).
 *
 * @param vd   callable taking two const float pointers, returning float
 * @param nq   number of query vectors
 * @param xq   query vectors, consecutive ones are ldq floats apart
 * @param nb   number of base vectors
 * @param xb   base vectors, consecutive ones are ldb floats apart
 * @param dis  output table, consecutive rows are ldd floats apart
 * @param ldq  stride between queries
 * @param ldb  stride between base vectors
 * @param ldd  stride between output rows
 */
template <class VD>
void pairwise_extra_distances_template(
        VD vd,
        int64_t nq, const float* xq,
        int64_t nb, const float* xb,
        float* dis,
        int64_t ldq, int64_t ldb, int64_t ldd)
{
    // the OpenMP clause only enables the parallel region when nq > 10
#pragma omp parallel for if(nq > 10)
    for (int64_t qi = 0; qi < nq; qi++) {
        const float* query = xq + qi * ldq;
        float* row = dis + qi * ldd;

        for (int64_t bj = 0; bj < nb; bj++) {
            row[bj] = vd(query, xb + bj * ldb);
        }
    }
}
|
|
|
|
|
|
template<class VD>
|
|
void knn_extra_metrics_template (
|
|
VD vd,
|
|
const float * x,
|
|
const float * y,
|
|
size_t nx, size_t ny,
|
|
float_maxheap_array_t * res)
|
|
{
|
|
size_t k = res->k;
|
|
size_t d = vd.d;
|
|
size_t check_period = InterruptCallback::get_period_hint (ny * d);
|
|
check_period *= omp_get_max_threads();
|
|
|
|
for (size_t i0 = 0; i0 < nx; i0 += check_period) {
|
|
size_t i1 = std::min(i0 + check_period, nx);
|
|
|
|
#pragma omp parallel for
|
|
for (size_t i = i0; i < i1; i++) {
|
|
const float * x_i = x + i * d;
|
|
const float * y_j = y;
|
|
size_t j;
|
|
float * simi = res->get_val(i);
|
|
int64_t * idxi = res->get_ids (i);
|
|
|
|
maxheap_heapify (k, simi, idxi);
|
|
for (j = 0; j < ny; j++) {
|
|
float disij = vd (x_i, y_j);
|
|
|
|
if (disij < simi[0]) {
|
|
maxheap_pop (k, simi, idxi);
|
|
maxheap_push (k, simi, idxi, disij, j);
|
|
}
|
|
y_j += d;
|
|
}
|
|
maxheap_reorder (k, simi, idxi);
|
|
}
|
|
InterruptCallback::check ();
|
|
}
|
|
|
|
}
|
|
|
|
|
|
template<class VD>
|
|
struct ExtraDistanceComputer : DistanceComputer {
|
|
VD vd;
|
|
Index::idx_t nb;
|
|
const float *q;
|
|
const float *b;
|
|
|
|
float operator () (idx_t i) override {
|
|
return vd (q, b + i * vd.d);
|
|
}
|
|
|
|
float symmetric_dis(idx_t i, idx_t j) override {
|
|
return vd (b + j * vd.d, b + i * vd.d);
|
|
}
|
|
|
|
ExtraDistanceComputer(const VD & vd, const float *xb,
|
|
size_t nb, const float *q = nullptr)
|
|
: vd(vd), nb(nb), q(q), b(xb) {}
|
|
|
|
void set_query(const float *x) override {
|
|
q = x;
|
|
}
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
} // anonymous namespace
|
|
|
|
void pairwise_extra_distances (
|
|
int64_t d,
|
|
int64_t nq, const float *xq,
|
|
int64_t nb, const float *xb,
|
|
MetricType mt, float metric_arg,
|
|
float *dis,
|
|
int64_t ldq, int64_t ldb, int64_t ldd)
|
|
{
|
|
if (nq == 0 || nb == 0) return;
|
|
if (ldq == -1) ldq = d;
|
|
if (ldb == -1) ldb = d;
|
|
if (ldd == -1) ldd = nb;
|
|
|
|
switch(mt) {
|
|
#define HANDLE_VAR(kw) \
|
|
case METRIC_ ## kw: { \
|
|
VectorDistance ## kw vd({(size_t)d}); \
|
|
pairwise_extra_distances_template (vd, nq, xq, nb, xb, \
|
|
dis, ldq, ldb, ldd); \
|
|
break; \
|
|
}
|
|
HANDLE_VAR(L2);
|
|
HANDLE_VAR(L1);
|
|
HANDLE_VAR(Linf);
|
|
HANDLE_VAR(Canberra);
|
|
HANDLE_VAR(BrayCurtis);
|
|
HANDLE_VAR(JensenShannon);
|
|
#undef HANDLE_VAR
|
|
case METRIC_Lp: {
|
|
VectorDistanceLp vd({(size_t)d, metric_arg});
|
|
pairwise_extra_distances_template (vd, nq, xq, nb, xb,
|
|
dis, ldq, ldb, ldd);
|
|
break;
|
|
}
|
|
default:
|
|
FAISS_THROW_MSG ("metric type not implemented");
|
|
}
|
|
|
|
}
|
|
|
|
void knn_extra_metrics (
|
|
const float * x,
|
|
const float * y,
|
|
size_t d, size_t nx, size_t ny,
|
|
MetricType mt, float metric_arg,
|
|
float_maxheap_array_t * res)
|
|
{
|
|
|
|
switch(mt) {
|
|
#define HANDLE_VAR(kw) \
|
|
case METRIC_ ## kw: { \
|
|
VectorDistance ## kw vd({(size_t)d}); \
|
|
knn_extra_metrics_template (vd, x, y, nx, ny, res); \
|
|
break; \
|
|
}
|
|
HANDLE_VAR(L2);
|
|
HANDLE_VAR(L1);
|
|
HANDLE_VAR(Linf);
|
|
HANDLE_VAR(Canberra);
|
|
HANDLE_VAR(BrayCurtis);
|
|
HANDLE_VAR(JensenShannon);
|
|
#undef HANDLE_VAR
|
|
case METRIC_Lp: {
|
|
VectorDistanceLp vd({(size_t)d, metric_arg});
|
|
knn_extra_metrics_template (vd, x, y, nx, ny, res);
|
|
break;
|
|
}
|
|
default:
|
|
FAISS_THROW_MSG ("metric type not implemented");
|
|
}
|
|
|
|
}
|
|
|
|
/** Allocate a DistanceComputer for one of the extra metrics.
 *
 * The returned object is created with new and keeps the xb pointer
 * without copying the data; the caller receives the raw pointer.
 *
 * @param d           vector dimension
 * @param mt          metric selector
 * @param metric_arg  exponent p, only read for METRIC_Lp
 * @param nb          number of stored vectors
 * @param xb          stored vectors
 * @return            newly allocated ExtraDistanceComputer for mt
 *
 * Raises through FAISS_THROW_MSG when mt is not one of the handled
 * metrics.
 */
DistanceComputer *get_extra_distance_computer(
        size_t d,
        MetricType mt, float metric_arg,
        size_t nb, const float* xb)
{
    const size_t dim = (size_t)d;

    switch (mt) {
// one case per single-parameter functor
#define COMPUTER_CASE(kw)                                                   \
    case METRIC_ ## kw: {                                                   \
        VectorDistance ## kw vd = {dim};                                    \
        return new ExtraDistanceComputer<VectorDistance ## kw>(vd, xb, nb); \
    }
    COMPUTER_CASE(L2)
    COMPUTER_CASE(L1)
    COMPUTER_CASE(Linf)
    COMPUTER_CASE(Canberra)
    COMPUTER_CASE(BrayCurtis)
    COMPUTER_CASE(JensenShannon)
#undef COMPUTER_CASE
    case METRIC_Lp: {
        // Lp is the only functor carrying a second field (the exponent)
        VectorDistanceLp vd = {dim, metric_arg};
        return new ExtraDistanceComputer<VectorDistanceLp>(vd, xb, nb);
    }
    default:
        FAISS_THROW_MSG ("metric type not implemented");
    }
}
|
|
|
|
|
|
} // namespace faiss
|