766 lines
21 KiB
C++
766 lines
21 KiB
C++
/**
|
|
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
*
|
|
* This source code is licensed under the MIT license found in the
|
|
* LICENSE file in the root directory of this source tree.
|
|
*/
|
|
|
|
// -*- c++ -*-
|
|
|
|
#include <faiss/utils/distances.h>
|
|
|
|
#include <cstdio>
|
|
#include <cassert>
|
|
#include <cstring>
|
|
#include <cmath>
|
|
|
|
#include <omp.h>
|
|
|
|
#include <faiss/impl/AuxIndexStructures.h>
|
|
#include <faiss/impl/FaissAssert.h>
|
|
|
|
|
|
|
|
#ifndef FINTEGER
|
|
#define FINTEGER long
|
|
#endif
|
|
|
|
|
|
extern "C" {
|
|
|
|
/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
|
|
|
|
int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER *
|
|
n, FINTEGER *k, const float *alpha, const float *a,
|
|
FINTEGER *lda, const float *b, FINTEGER *
|
|
ldb, float *beta, float *c, FINTEGER *ldc);
|
|
|
|
/* Lapack functions, see http://www.netlib.org/clapack/old/single/sgeqrf.c */
|
|
|
|
int sgeqrf_ (FINTEGER *m, FINTEGER *n, float *a, FINTEGER *lda,
|
|
float *tau, float *work, FINTEGER *lwork, FINTEGER *info);
|
|
|
|
int sgemv_(const char *trans, FINTEGER *m, FINTEGER *n, float *alpha,
|
|
const float *a, FINTEGER *lda, const float *x, FINTEGER *incx,
|
|
float *beta, float *y, FINTEGER *incy);
|
|
|
|
}
|
|
|
|
|
|
namespace faiss {
|
|
|
|
|
|
|
|
/***************************************************************************
|
|
* Matrix/vector ops
|
|
***************************************************************************/
|
|
|
|
|
|
|
|
/* Compute the inner product between a vector x and
|
|
a set of ny vectors y.
|
|
These functions are not intended to replace BLAS matrix-matrix, as they
|
|
would be significantly less efficient in this case. */
|
|
void fvec_inner_products_ny (float * ip,
|
|
const float * x,
|
|
const float * y,
|
|
size_t d, size_t ny)
|
|
{
|
|
// Not sure which one is fastest
|
|
#if 0
|
|
{
|
|
FINTEGER di = d;
|
|
FINTEGER nyi = ny;
|
|
float one = 1.0, zero = 0.0;
|
|
FINTEGER onei = 1;
|
|
sgemv_ ("T", &di, &nyi, &one, y, &di, x, &onei, &zero, ip, &onei);
|
|
}
|
|
#endif
|
|
for (size_t i = 0; i < ny; i++) {
|
|
ip[i] = fvec_inner_product (x, y, d);
|
|
y += d;
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
/* Compute the L2 norm of a set of nx vectors */
|
|
void fvec_norms_L2 (float * __restrict nr,
|
|
const float * __restrict x,
|
|
size_t d, size_t nx)
|
|
{
|
|
|
|
#pragma omp parallel for
|
|
for (size_t i = 0; i < nx; i++) {
|
|
nr[i] = sqrtf (fvec_norm_L2sqr (x + i * d, d));
|
|
}
|
|
}
|
|
|
|
void fvec_norms_L2sqr (float * __restrict nr,
|
|
const float * __restrict x,
|
|
size_t d, size_t nx)
|
|
{
|
|
#pragma omp parallel for
|
|
for (size_t i = 0; i < nx; i++)
|
|
nr[i] = fvec_norm_L2sqr (x + i * d, d);
|
|
}
|
|
|
|
|
|
|
|
void fvec_renorm_L2 (size_t d, size_t nx, float * __restrict x)
|
|
{
|
|
#pragma omp parallel for
|
|
for (size_t i = 0; i < nx; i++) {
|
|
float * __restrict xi = x + i * d;
|
|
|
|
float nr = fvec_norm_L2sqr (xi, d);
|
|
|
|
if (nr > 0) {
|
|
size_t j;
|
|
const float inv_nr = 1.0 / sqrtf (nr);
|
|
for (j = 0; j < d; j++)
|
|
xi[j] *= inv_nr;
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/***************************************************************************
|
|
* KNN functions
|
|
***************************************************************************/
|
|
|
|
|
|
|
|
/* Find the nearest neighbors for nx queries in a set of ny vectors */
|
|
static void knn_inner_product_sse (const float * x,
|
|
const float * y,
|
|
size_t d, size_t nx, size_t ny,
|
|
float_minheap_array_t * res)
|
|
{
|
|
size_t k = res->k;
|
|
size_t check_period = InterruptCallback::get_period_hint (ny * d);
|
|
|
|
check_period *= omp_get_max_threads();
|
|
|
|
for (size_t i0 = 0; i0 < nx; i0 += check_period) {
|
|
size_t i1 = std::min(i0 + check_period, nx);
|
|
|
|
#pragma omp parallel for
|
|
for (size_t i = i0; i < i1; i++) {
|
|
const float * x_i = x + i * d;
|
|
const float * y_j = y;
|
|
|
|
float * __restrict simi = res->get_val(i);
|
|
int64_t * __restrict idxi = res->get_ids (i);
|
|
|
|
minheap_heapify (k, simi, idxi);
|
|
|
|
for (size_t j = 0; j < ny; j++) {
|
|
float ip = fvec_inner_product (x_i, y_j, d);
|
|
|
|
if (ip > simi[0]) {
|
|
minheap_pop (k, simi, idxi);
|
|
minheap_push (k, simi, idxi, ip, j);
|
|
}
|
|
y_j += d;
|
|
}
|
|
minheap_reorder (k, simi, idxi);
|
|
}
|
|
InterruptCallback::check ();
|
|
}
|
|
|
|
}
|
|
|
|
static void knn_L2sqr_sse (
|
|
const float * x,
|
|
const float * y,
|
|
size_t d, size_t nx, size_t ny,
|
|
float_maxheap_array_t * res)
|
|
{
|
|
size_t k = res->k;
|
|
|
|
size_t check_period = InterruptCallback::get_period_hint (ny * d);
|
|
check_period *= omp_get_max_threads();
|
|
|
|
for (size_t i0 = 0; i0 < nx; i0 += check_period) {
|
|
size_t i1 = std::min(i0 + check_period, nx);
|
|
|
|
#pragma omp parallel for
|
|
for (size_t i = i0; i < i1; i++) {
|
|
const float * x_i = x + i * d;
|
|
const float * y_j = y;
|
|
size_t j;
|
|
float * simi = res->get_val(i);
|
|
int64_t * idxi = res->get_ids (i);
|
|
|
|
maxheap_heapify (k, simi, idxi);
|
|
for (j = 0; j < ny; j++) {
|
|
float disij = fvec_L2sqr (x_i, y_j, d);
|
|
|
|
if (disij < simi[0]) {
|
|
maxheap_pop (k, simi, idxi);
|
|
maxheap_push (k, simi, idxi, disij, j);
|
|
}
|
|
y_j += d;
|
|
}
|
|
maxheap_reorder (k, simi, idxi);
|
|
}
|
|
InterruptCallback::check ();
|
|
}
|
|
|
|
}
|
|
|
|
|
|
/** Find the nearest neighbors for nx queries in a set of ny vectors */
|
|
static void knn_inner_product_blas (
|
|
const float * x,
|
|
const float * y,
|
|
size_t d, size_t nx, size_t ny,
|
|
float_minheap_array_t * res)
|
|
{
|
|
res->heapify ();
|
|
|
|
// BLAS does not like empty matrices
|
|
if (nx == 0 || ny == 0) return;
|
|
|
|
/* block sizes */
|
|
const size_t bs_x = 4096, bs_y = 1024;
|
|
// const size_t bs_x = 16, bs_y = 16;
|
|
std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);
|
|
|
|
for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
|
|
size_t i1 = i0 + bs_x;
|
|
if(i1 > nx) i1 = nx;
|
|
|
|
for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
|
|
size_t j1 = j0 + bs_y;
|
|
if (j1 > ny) j1 = ny;
|
|
/* compute the actual dot products */
|
|
{
|
|
float one = 1, zero = 0;
|
|
FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
|
|
sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one,
|
|
y + j0 * d, &di,
|
|
x + i0 * d, &di, &zero,
|
|
ip_block.get(), &nyi);
|
|
}
|
|
|
|
/* collect maxima */
|
|
res->addn (j1 - j0, ip_block.get(), j0, i0, i1 - i0);
|
|
}
|
|
InterruptCallback::check ();
|
|
}
|
|
res->reorder ();
|
|
}
|
|
|
|
// distance correction is an operator that can be applied to transform
|
|
// the distances
|
|
template<class DistanceCorrection>
|
|
static void knn_L2sqr_blas (const float * x,
|
|
const float * y,
|
|
size_t d, size_t nx, size_t ny,
|
|
float_maxheap_array_t * res,
|
|
const DistanceCorrection &corr)
|
|
{
|
|
res->heapify ();
|
|
|
|
// BLAS does not like empty matrices
|
|
if (nx == 0 || ny == 0) return;
|
|
|
|
size_t k = res->k;
|
|
|
|
/* block sizes */
|
|
const size_t bs_x = 4096, bs_y = 1024;
|
|
// const size_t bs_x = 16, bs_y = 16;
|
|
float *ip_block = new float[bs_x * bs_y];
|
|
float *x_norms = new float[nx];
|
|
float *y_norms = new float[ny];
|
|
ScopeDeleter<float> del1(ip_block), del3(x_norms), del2(y_norms);
|
|
|
|
fvec_norms_L2sqr (x_norms, x, d, nx);
|
|
fvec_norms_L2sqr (y_norms, y, d, ny);
|
|
|
|
|
|
for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
|
|
size_t i1 = i0 + bs_x;
|
|
if(i1 > nx) i1 = nx;
|
|
|
|
for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
|
|
size_t j1 = j0 + bs_y;
|
|
if (j1 > ny) j1 = ny;
|
|
/* compute the actual dot products */
|
|
{
|
|
float one = 1, zero = 0;
|
|
FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
|
|
sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one,
|
|
y + j0 * d, &di,
|
|
x + i0 * d, &di, &zero,
|
|
ip_block, &nyi);
|
|
}
|
|
|
|
/* collect minima */
|
|
#pragma omp parallel for
|
|
for (size_t i = i0; i < i1; i++) {
|
|
float * __restrict simi = res->get_val(i);
|
|
int64_t * __restrict idxi = res->get_ids (i);
|
|
const float *ip_line = ip_block + (i - i0) * (j1 - j0);
|
|
|
|
for (size_t j = j0; j < j1; j++) {
|
|
float ip = *ip_line++;
|
|
float dis = x_norms[i] + y_norms[j] - 2 * ip;
|
|
|
|
// negative values can occur for identical vectors
|
|
// due to roundoff errors
|
|
if (dis < 0) dis = 0;
|
|
|
|
dis = corr (dis, i, j);
|
|
|
|
if (dis < simi[0]) {
|
|
maxheap_pop (k, simi, idxi);
|
|
maxheap_push (k, simi, idxi, dis, j);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
InterruptCallback::check ();
|
|
}
|
|
res->reorder ();
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/*******************************************************
|
|
* KNN driver functions
|
|
*******************************************************/
|
|
|
|
// Threshold on the number of query vectors nx used by the driver functions
// of this file (knn_inner_product, knn_L2sqr, range_search_L2sqr,
// range_search_inner_product): when d is a multiple of 4 and nx is strictly
// below this value they call the *_sse implementation, otherwise the
// BLAS-based one.
int distance_compute_blas_threshold = 20;
|
|
|
|
void knn_inner_product (const float * x,
|
|
const float * y,
|
|
size_t d, size_t nx, size_t ny,
|
|
float_minheap_array_t * res)
|
|
{
|
|
if (d % 4 == 0 && nx < distance_compute_blas_threshold) {
|
|
knn_inner_product_sse (x, y, d, nx, ny, res);
|
|
} else {
|
|
knn_inner_product_blas (x, y, d, nx, ny, res);
|
|
}
|
|
}
|
|
|
|
|
|
|
|
/* Distance correction that returns every distance unchanged; the query
   and database indices are ignored. */
struct NopDistanceCorrection {
    float operator()(float dis, size_t qno, size_t bno) const {
        (void) qno;
        (void) bno;
        return dis;
    }
};
|
|
|
|
void knn_L2sqr (const float * x,
|
|
const float * y,
|
|
size_t d, size_t nx, size_t ny,
|
|
float_maxheap_array_t * res)
|
|
{
|
|
if (d % 4 == 0 && nx < distance_compute_blas_threshold) {
|
|
knn_L2sqr_sse (x, y, d, nx, ny, res);
|
|
} else {
|
|
NopDistanceCorrection nop;
|
|
knn_L2sqr_blas (x, y, d, nx, ny, res, nop);
|
|
}
|
|
}
|
|
|
|
/* Distance correction that subtracts a per-database-vector offset:
   the result is dis - base_shift[bno]. The query index is ignored.
   The struct only stores the pointer to the offsets. */
struct BaseShiftDistanceCorrection {
    const float *base_shift;

    float operator()(float dis, size_t /*qno*/, size_t bno) const {
        const float shift = base_shift[bno];
        return dis - shift;
    }
};
|
|
|
|
void knn_L2sqr_base_shift (
|
|
const float * x,
|
|
const float * y,
|
|
size_t d, size_t nx, size_t ny,
|
|
float_maxheap_array_t * res,
|
|
const float *base_shift)
|
|
{
|
|
BaseShiftDistanceCorrection corr = {base_shift};
|
|
knn_L2sqr_blas (x, y, d, nx, ny, res, corr);
|
|
}
|
|
|
|
|
|
|
|
/***************************************************************************
|
|
* compute a subset of distances
|
|
***************************************************************************/
|
|
|
|
/* compute the inner product between x and a subset y of ny vectors,
|
|
whose indices are given by idy. */
|
|
void fvec_inner_products_by_idx (float * __restrict ip,
|
|
const float * x,
|
|
const float * y,
|
|
const int64_t * __restrict ids, /* for y vecs */
|
|
size_t d, size_t nx, size_t ny)
|
|
{
|
|
#pragma omp parallel for
|
|
for (size_t j = 0; j < nx; j++) {
|
|
const int64_t * __restrict idsj = ids + j * ny;
|
|
const float * xj = x + j * d;
|
|
float * __restrict ipj = ip + j * ny;
|
|
for (size_t i = 0; i < ny; i++) {
|
|
if (idsj[i] < 0)
|
|
continue;
|
|
ipj[i] = fvec_inner_product (xj, y + d * idsj[i], d);
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
|
|
/* compute the inner product between x and a subset y of ny vectors,
|
|
whose indices are given by idy. */
|
|
void fvec_L2sqr_by_idx (float * __restrict dis,
|
|
const float * x,
|
|
const float * y,
|
|
const int64_t * __restrict ids, /* ids of y vecs */
|
|
size_t d, size_t nx, size_t ny)
|
|
{
|
|
#pragma omp parallel for
|
|
for (size_t j = 0; j < nx; j++) {
|
|
const int64_t * __restrict idsj = ids + j * ny;
|
|
const float * xj = x + j * d;
|
|
float * __restrict disj = dis + j * ny;
|
|
for (size_t i = 0; i < ny; i++) {
|
|
if (idsj[i] < 0)
|
|
continue;
|
|
disj[i] = fvec_L2sqr (xj, y + d * idsj[i], d);
|
|
}
|
|
}
|
|
}
|
|
|
|
void pairwise_indexed_L2sqr (
|
|
size_t d, size_t n,
|
|
const float * x, const int64_t *ix,
|
|
const float * y, const int64_t *iy,
|
|
float *dis)
|
|
{
|
|
#pragma omp parallel for
|
|
for (size_t j = 0; j < n; j++) {
|
|
if (ix[j] >= 0 && iy[j] >= 0) {
|
|
dis[j] = fvec_L2sqr (x + d * ix[j], y + d * iy[j], d);
|
|
}
|
|
}
|
|
}
|
|
|
|
void pairwise_indexed_inner_product (
|
|
size_t d, size_t n,
|
|
const float * x, const int64_t *ix,
|
|
const float * y, const int64_t *iy,
|
|
float *dis)
|
|
{
|
|
#pragma omp parallel for
|
|
for (size_t j = 0; j < n; j++) {
|
|
if (ix[j] >= 0 && iy[j] >= 0) {
|
|
dis[j] = fvec_inner_product (x + d * ix[j], y + d * iy[j], d);
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
/* Find the nearest neighbors for nx queries in a set of ny vectors
|
|
indexed by ids. May be useful for re-ranking a pre-selected vector list */
|
|
void knn_inner_products_by_idx (const float * x,
|
|
const float * y,
|
|
const int64_t * ids,
|
|
size_t d, size_t nx, size_t ny,
|
|
float_minheap_array_t * res)
|
|
{
|
|
size_t k = res->k;
|
|
|
|
#pragma omp parallel for
|
|
for (size_t i = 0; i < nx; i++) {
|
|
const float * x_ = x + i * d;
|
|
const int64_t * idsi = ids + i * ny;
|
|
size_t j;
|
|
float * __restrict simi = res->get_val(i);
|
|
int64_t * __restrict idxi = res->get_ids (i);
|
|
minheap_heapify (k, simi, idxi);
|
|
|
|
for (j = 0; j < ny; j++) {
|
|
if (idsi[j] < 0) break;
|
|
float ip = fvec_inner_product (x_, y + d * idsi[j], d);
|
|
|
|
if (ip > simi[0]) {
|
|
minheap_pop (k, simi, idxi);
|
|
minheap_push (k, simi, idxi, ip, idsi[j]);
|
|
}
|
|
}
|
|
minheap_reorder (k, simi, idxi);
|
|
}
|
|
|
|
}
|
|
|
|
void knn_L2sqr_by_idx (const float * x,
|
|
const float * y,
|
|
const int64_t * __restrict ids,
|
|
size_t d, size_t nx, size_t ny,
|
|
float_maxheap_array_t * res)
|
|
{
|
|
size_t k = res->k;
|
|
|
|
#pragma omp parallel for
|
|
for (size_t i = 0; i < nx; i++) {
|
|
const float * x_ = x + i * d;
|
|
const int64_t * __restrict idsi = ids + i * ny;
|
|
float * __restrict simi = res->get_val(i);
|
|
int64_t * __restrict idxi = res->get_ids (i);
|
|
maxheap_heapify (res->k, simi, idxi);
|
|
for (size_t j = 0; j < ny; j++) {
|
|
float disij = fvec_L2sqr (x_, y + d * idsi[j], d);
|
|
|
|
if (disij < simi[0]) {
|
|
maxheap_pop (k, simi, idxi);
|
|
maxheap_push (k, simi, idxi, disij, idsi[j]);
|
|
}
|
|
}
|
|
maxheap_reorder (res->k, simi, idxi);
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
/***************************************************************************
|
|
* Range search
|
|
***************************************************************************/
|
|
|
|
/** Find the nearest neighbors for nx queries in a set of ny vectors
|
|
* compute_l2 = compute pairwise squared L2 distance rather than inner prod
|
|
*/
|
|
template <bool compute_l2>
|
|
static void range_search_blas (
|
|
const float * x,
|
|
const float * y,
|
|
size_t d, size_t nx, size_t ny,
|
|
float radius,
|
|
RangeSearchResult *result)
|
|
{
|
|
|
|
// BLAS does not like empty matrices
|
|
if (nx == 0 || ny == 0) return;
|
|
|
|
/* block sizes */
|
|
const size_t bs_x = 4096, bs_y = 1024;
|
|
// const size_t bs_x = 16, bs_y = 16;
|
|
float *ip_block = new float[bs_x * bs_y];
|
|
ScopeDeleter<float> del0(ip_block);
|
|
|
|
float *x_norms = nullptr, *y_norms = nullptr;
|
|
ScopeDeleter<float> del1, del2;
|
|
if (compute_l2) {
|
|
x_norms = new float[nx];
|
|
del1.set (x_norms);
|
|
fvec_norms_L2sqr (x_norms, x, d, nx);
|
|
|
|
y_norms = new float[ny];
|
|
del2.set (y_norms);
|
|
fvec_norms_L2sqr (y_norms, y, d, ny);
|
|
}
|
|
|
|
std::vector <RangeSearchPartialResult *> partial_results;
|
|
|
|
for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
|
|
size_t j1 = j0 + bs_y;
|
|
if (j1 > ny) j1 = ny;
|
|
RangeSearchPartialResult * pres = new RangeSearchPartialResult (result);
|
|
partial_results.push_back (pres);
|
|
|
|
for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
|
|
size_t i1 = i0 + bs_x;
|
|
if(i1 > nx) i1 = nx;
|
|
|
|
/* compute the actual dot products */
|
|
{
|
|
float one = 1, zero = 0;
|
|
FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
|
|
sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one,
|
|
y + j0 * d, &di,
|
|
x + i0 * d, &di, &zero,
|
|
ip_block, &nyi);
|
|
}
|
|
|
|
|
|
for (size_t i = i0; i < i1; i++) {
|
|
const float *ip_line = ip_block + (i - i0) * (j1 - j0);
|
|
|
|
RangeQueryResult & qres = pres->new_result (i);
|
|
|
|
for (size_t j = j0; j < j1; j++) {
|
|
float ip = *ip_line++;
|
|
if (compute_l2) {
|
|
float dis = x_norms[i] + y_norms[j] - 2 * ip;
|
|
if (dis < radius) {
|
|
qres.add (dis, j);
|
|
}
|
|
} else {
|
|
if (ip > radius) {
|
|
qres.add (ip, j);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
InterruptCallback::check ();
|
|
}
|
|
|
|
RangeSearchPartialResult::merge (partial_results);
|
|
}
|
|
|
|
|
|
template <bool compute_l2>
|
|
static void range_search_sse (const float * x,
|
|
const float * y,
|
|
size_t d, size_t nx, size_t ny,
|
|
float radius,
|
|
RangeSearchResult *res)
|
|
{
|
|
FAISS_THROW_IF_NOT (d % 4 == 0);
|
|
|
|
#pragma omp parallel
|
|
{
|
|
RangeSearchPartialResult pres (res);
|
|
|
|
#pragma omp for
|
|
for (size_t i = 0; i < nx; i++) {
|
|
const float * x_ = x + i * d;
|
|
const float * y_ = y;
|
|
size_t j;
|
|
|
|
RangeQueryResult & qres = pres.new_result (i);
|
|
|
|
for (j = 0; j < ny; j++) {
|
|
if (compute_l2) {
|
|
float disij = fvec_L2sqr (x_, y_, d);
|
|
if (disij < radius) {
|
|
qres.add (disij, j);
|
|
}
|
|
} else {
|
|
float ip = fvec_inner_product (x_, y_, d);
|
|
if (ip > radius) {
|
|
qres.add (ip, j);
|
|
}
|
|
}
|
|
y_ += d;
|
|
}
|
|
|
|
}
|
|
pres.finalize ();
|
|
}
|
|
|
|
// check just at the end because the use case is typically just
|
|
// when the nb of queries is low.
|
|
InterruptCallback::check();
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void range_search_L2sqr (
|
|
const float * x,
|
|
const float * y,
|
|
size_t d, size_t nx, size_t ny,
|
|
float radius,
|
|
RangeSearchResult *res)
|
|
{
|
|
|
|
if (d % 4 == 0 && nx < distance_compute_blas_threshold) {
|
|
range_search_sse<true> (x, y, d, nx, ny, radius, res);
|
|
} else {
|
|
range_search_blas<true> (x, y, d, nx, ny, radius, res);
|
|
}
|
|
}
|
|
|
|
void range_search_inner_product (
|
|
const float * x,
|
|
const float * y,
|
|
size_t d, size_t nx, size_t ny,
|
|
float radius,
|
|
RangeSearchResult *res)
|
|
{
|
|
|
|
if (d % 4 == 0 && nx < distance_compute_blas_threshold) {
|
|
range_search_sse<false> (x, y, d, nx, ny, radius, res);
|
|
} else {
|
|
range_search_blas<false> (x, y, d, nx, ny, radius, res);
|
|
}
|
|
}
|
|
|
|
|
|
void pairwise_L2sqr (int64_t d,
|
|
int64_t nq, const float *xq,
|
|
int64_t nb, const float *xb,
|
|
float *dis,
|
|
int64_t ldq, int64_t ldb, int64_t ldd)
|
|
{
|
|
if (nq == 0 || nb == 0) return;
|
|
if (ldq == -1) ldq = d;
|
|
if (ldb == -1) ldb = d;
|
|
if (ldd == -1) ldd = nb;
|
|
|
|
// store in beginning of distance matrix to avoid malloc
|
|
float *b_norms = dis;
|
|
|
|
#pragma omp parallel for
|
|
for (int64_t i = 0; i < nb; i++)
|
|
b_norms [i] = fvec_norm_L2sqr (xb + i * ldb, d);
|
|
|
|
#pragma omp parallel for
|
|
for (int64_t i = 1; i < nq; i++) {
|
|
float q_norm = fvec_norm_L2sqr (xq + i * ldq, d);
|
|
for (int64_t j = 0; j < nb; j++)
|
|
dis[i * ldd + j] = q_norm + b_norms [j];
|
|
}
|
|
|
|
{
|
|
float q_norm = fvec_norm_L2sqr (xq, d);
|
|
for (int64_t j = 0; j < nb; j++)
|
|
dis[j] += q_norm;
|
|
}
|
|
|
|
{
|
|
FINTEGER nbi = nb, nqi = nq, di = d, ldqi = ldq, ldbi = ldb, lddi = ldd;
|
|
float one = 1.0, minus_2 = -2.0;
|
|
|
|
sgemm_ ("Transposed", "Not transposed",
|
|
&nbi, &nqi, &di,
|
|
&minus_2,
|
|
xb, &ldbi,
|
|
xq, &ldqi,
|
|
&one, dis, &lddi);
|
|
}
|
|
|
|
}
|
|
|
|
|
|
} // namespace faiss
|