/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */
|
|
|
|
// -*- c++ -*-
|
|
|
|
#include <faiss/IndexFlat.h>
|
|
|
|
#include <cstring>
|
|
#include <faiss/utils/distances.h>
|
|
#include <faiss/utils/extra_distances.h>
|
|
#include <faiss/utils/utils.h>
|
|
#include <faiss/utils/Heap.h>
|
|
#include <faiss/impl/FaissAssert.h>
|
|
#include <faiss/impl/AuxIndexStructures.h>
|
|
|
|
|
|
namespace faiss {
|
|
|
|
/// Build an empty flat index: the dimension and the metric are simply
/// forwarded to the Index base class, nothing else is initialised here.
IndexFlat::IndexFlat(idx_t d, MetricType metric)
        : Index(d, metric) {}
|
|
|
|
|
|
|
|
/// Append n vectors (n * d consecutive floats read from x) to the end of
/// the raw storage xb and bump ntotal by n.
void IndexFlat::add(idx_t n, const float* x) {
    const float* const x_end = x + n * d;
    xb.insert(xb.end(), x, x_end);
    ntotal += n;
}
|
|
|
|
|
|
/// Drop every stored vector: the vector count goes back to zero and the
/// raw storage is cleared.
void IndexFlat::reset() {
    ntotal = 0;
    xb.clear();
}
|
|
|
|
|
|
/// Exhaustive k-nearest-neighbour search of n queries against all stored
/// vectors.
///
/// The caller-provided output arrays are wrapped as heap arrays (one heap
/// of size k per query) and handed to the metric-specific kernel:
/// a min-heap array for inner product, a max-heap array for L2 and for
/// every other metric.
///
/// @param n          number of query vectors
/// @param x          query vectors, read as n * d floats
/// @param k          number of results per query
/// @param distances  output distances, viewed as n heaps of k entries
/// @param labels     output labels, viewed as n heaps of k entries
void IndexFlat::search(idx_t n, const float* x, idx_t k,
                       float* distances, idx_t* labels) const {
    const size_t nq = size_t(n);
    const size_t topk = size_t(k);

    switch (metric_type) {
        case METRIC_INNER_PRODUCT: {
            float_minheap_array_t heaps = {nq, topk, labels, distances};
            knn_inner_product(x, xb.data(), d, n, ntotal, &heaps);
            break;
        }
        case METRIC_L2: {
            float_maxheap_array_t heaps = {nq, topk, labels, distances};
            knn_L2sqr(x, xb.data(), d, n, ntotal, &heaps);
            break;
        }
        default: {
            // any other metric goes through the generic extra-metrics kernel
            float_maxheap_array_t heaps = {nq, topk, labels, distances};
            knn_extra_metrics(x, xb.data(), d, n, ntotal,
                              metric_type, metric_arg,
                              &heaps);
            break;
        }
    }
}
|
|
|
|
/// Range search of n queries against all stored vectors, delegated to the
/// inner-product or L2 kernel depending on metric_type.
///
/// @param n       number of query vectors
/// @param x       query vectors, read as n * d floats
/// @param radius  search radius forwarded unchanged to the kernel
/// @param result  output structure filled by the kernel
/// @throws via FAISS_THROW_MSG when the metric is neither inner product
///         nor L2
void IndexFlat::range_search(idx_t n, const float* x, float radius,
                             RangeSearchResult* result) const {
    if (metric_type == METRIC_INNER_PRODUCT) {
        range_search_inner_product(x, xb.data(), d, n, ntotal,
                                   radius, result);
    } else if (metric_type == METRIC_L2) {
        range_search_L2sqr(x, xb.data(), d, n, ntotal, radius, result);
    } else {
        FAISS_THROW_MSG("metric type not supported");
    }
}
|
|
|
|
|
|
/// Compute distances between each query and a caller-selected subset of
/// the stored vectors, identified by labels.
///
/// The work is delegated to fvec_inner_products_by_idx or
/// fvec_L2sqr_by_idx depending on metric_type; both receive the arguments
/// (distances, x, xb, labels, d, n, k) unchanged.
///
/// @param n          number of query vectors
/// @param x          query vectors
/// @param k          number of labels per query
/// @param distances  output array written by the kernel
/// @param labels     indices of the stored vectors to compare against
/// @throws via FAISS_THROW_MSG when the metric is neither inner product
///         nor L2
void IndexFlat::compute_distance_subset(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        const idx_t* labels) const {
    if (metric_type == METRIC_INNER_PRODUCT) {
        fvec_inner_products_by_idx(distances, x, xb.data(), labels, d, n, k);
    } else if (metric_type == METRIC_L2) {
        fvec_L2sqr_by_idx(distances, x, xb.data(), labels, d, n, k);
    } else {
        FAISS_THROW_MSG("metric type not supported");
    }
}
|
|
|
|
size_t IndexFlat::remove_ids (const IDSelector & sel)
|
|
{
|
|
idx_t j = 0;
|
|
for (idx_t i = 0; i < ntotal; i++) {
|
|
if (sel.is_member (i)) {
|
|
// should be removed
|
|
} else {
|
|
if (i > j) {
|
|
memmove (&xb[d * j], &xb[d * i], sizeof(xb[0]) * d);
|
|
}
|
|
j++;
|
|
}
|
|
}
|
|
size_t nremove = ntotal - j;
|
|
if (nremove > 0) {
|
|
ntotal = j;
|
|
xb.resize (ntotal * d);
|
|
}
|
|
return nremove;
|
|
}
|
|
|
|
|
|
namespace {
|
|
|
|
|
|
struct FlatL2Dis : DistanceComputer {
|
|
size_t d;
|
|
Index::idx_t nb;
|
|
const float *q;
|
|
const float *b;
|
|
size_t ndis;
|
|
|
|
float operator () (idx_t i) override {
|
|
ndis++;
|
|
return fvec_L2sqr(q, b + i * d, d);
|
|
}
|
|
|
|
float symmetric_dis(idx_t i, idx_t j) override {
|
|
return fvec_L2sqr(b + j * d, b + i * d, d);
|
|
}
|
|
|
|
explicit FlatL2Dis(const IndexFlat& storage, const float *q = nullptr)
|
|
: d(storage.d),
|
|
nb(storage.ntotal),
|
|
q(q),
|
|
b(storage.xb.data()),
|
|
ndis(0) {}
|
|
|
|
void set_query(const float *x) override {
|
|
q = x;
|
|
}
|
|
};
|
|
|
|
struct FlatIPDis : DistanceComputer {
|
|
size_t d;
|
|
Index::idx_t nb;
|
|
const float *q;
|
|
const float *b;
|
|
size_t ndis;
|
|
|
|
float operator () (idx_t i) override {
|
|
ndis++;
|
|
return fvec_inner_product (q, b + i * d, d);
|
|
}
|
|
|
|
float symmetric_dis(idx_t i, idx_t j) override {
|
|
return fvec_inner_product (b + j * d, b + i * d, d);
|
|
}
|
|
|
|
explicit FlatIPDis(const IndexFlat& storage, const float *q = nullptr)
|
|
: d(storage.d),
|
|
nb(storage.ntotal),
|
|
q(q),
|
|
b(storage.xb.data()),
|
|
ndis(0) {}
|
|
|
|
void set_query(const float *x) override {
|
|
q = x;
|
|
}
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace
|
|
|
|
|
|
/// Create a DistanceComputer bound to this index's storage.
///
/// L2 and inner product use the file-local FlatL2Dis / FlatIPDis; every
/// other metric is delegated to get_extra_distance_computer. The L2 and
/// inner-product objects are allocated with `new`.
/// NOTE(review): ownership presumably passes to the caller - confirm
/// against the Index interface.
DistanceComputer* IndexFlat::get_distance_computer() const {
    switch (metric_type) {
        case METRIC_L2:
            return new FlatL2Dis(*this);
        case METRIC_INNER_PRODUCT:
            return new FlatIPDis(*this);
        default:
            return get_extra_distance_computer(d, metric_type, metric_arg,
                                               ntotal, xb.data());
    }
}
|
|
|
|
|
|
/// Copy stored vector number `key` into `recons`.
///
/// @param key     position of the vector, must satisfy 0 <= key < ntotal
/// @param recons  output buffer receiving d floats
/// @throws via FAISS_THROW_IF_NOT_MSG when key is outside [0, ntotal);
///         without this check an invalid key would read past the bounds
///         of xb.
void IndexFlat::reconstruct(idx_t key, float* recons) const {
    FAISS_THROW_IF_NOT_MSG(key >= 0 && key < ntotal,
                           "reconstruct: key out of range");
    memcpy(recons, &(xb[key * d]), sizeof(*recons) * d);
}
|
|
|
|
|
|
/* The standalone codec interface */
|
|
size_t IndexFlat::sa_code_size () const
|
|
{
|
|
return sizeof(float) * d;
|
|
}
|
|
|
|
/// Encode n vectors: the "code" is the raw float data, so this is a plain
/// byte copy of n * d floats from x into bytes.
void IndexFlat::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
    const size_t nbytes = sizeof(float) * d * n;
    memcpy(bytes, x, nbytes);
}
|
|
|
|
/// Decode n vectors: the inverse of sa_encode, i.e. a plain byte copy of
/// n * d floats from bytes into x.
void IndexFlat::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
    const size_t nbytes = sizeof(float) * d * n;
    memcpy(x, bytes, nbytes);
}
|
|
|
|
|
|
|
|
|
|
/***************************************************
 * IndexFlatL2BaseShift
 ***************************************************/
|
|
|
|
/// Build an empty L2 index of dimension d together with a private copy of
/// the nshift shift values read from `shift`.
///
/// The member is first sized to nshift entries and then overwritten with
/// the caller's data; `this->` is needed because the parameter hides the
/// member of the same name.
IndexFlatL2BaseShift::IndexFlatL2BaseShift(
        idx_t d, size_t nshift, const float* shift)
        : IndexFlatL2(d), shift(nshift) {
    const size_t nbytes = sizeof(float) * nshift;
    memcpy(this->shift.data(), shift, nbytes);
}
|
|
|
|
/// k-nearest-neighbour search delegated to knn_L2sqr_base_shift, which
/// receives the shift values alongside the stored vectors.
///
/// Requires exactly one shift value per stored vector
/// (shift.size() == ntotal); otherwise FAISS_THROW_IF_NOT fires.
/// The output arrays are viewed as n max-heaps of k entries.
void IndexFlatL2BaseShift::search(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels) const {
    FAISS_THROW_IF_NOT (shift.size() == ntotal);

    const size_t nq = size_t(n);
    const size_t topk = size_t(k);
    float_maxheap_array_t heaps = {nq, topk, labels, distances};

    knn_L2sqr_base_shift(x, xb.data(), d, n, ntotal, &heaps, shift.data());
}
|
|
|
|
|
|
|
|
/***************************************************
 * IndexRefineFlat
 ***************************************************/
|
|
|
|
/// Wrap an existing (still empty) base index with a flat refinement
/// index of the same dimension and metric.
///
/// The wrapper does not take ownership (own_fields is false), starts with
/// k_factor == 1 and inherits the base index's trained state.
/// @throws via FAISS_THROW_IF_NOT_MSG if base_index already holds vectors
IndexRefineFlat::IndexRefineFlat(Index* base_index)
        : Index(base_index->d, base_index->metric_type),
          refine_index(base_index->d, base_index->metric_type),
          base_index(base_index),
          own_fields(false),
          k_factor(1) {
    FAISS_THROW_IF_NOT_MSG (base_index->ntotal == 0,
                            "base_index should be empty in the beginning");
    is_trained = base_index->is_trained;
}
|
|
|
|
/// Default constructor: no base index attached, nothing owned,
/// k_factor == 1. The base class and refine_index are default-constructed.
IndexRefineFlat::IndexRefineFlat()
        : base_index(nullptr),
          own_fields(false),
          k_factor(1) {}
|
|
|
|
|
|
/// Train the wrapped base index on n vectors from x, then mark this
/// wrapper as trained (the flag is only set if the base call returns).
void IndexRefineFlat::train(idx_t n, const float* x) {
    Index* const base = base_index;
    base->train(n, x);
    is_trained = true;
}
|
|
|
|
/// Add n vectors to both the base index and the flat refinement index,
/// then mirror the refinement index's count into ntotal.
/// @throws via FAISS_THROW_IF_NOT when the index is not trained
void IndexRefineFlat::add(idx_t n, const float* x) {
    FAISS_THROW_IF_NOT (is_trained);

    // base first, refinement second: same order as the stored ids
    base_index->add(n, x);
    refine_index.add(n, x);

    ntotal = refine_index.ntotal;
}
|
|
|
|
void IndexRefineFlat::reset ()
|
|
{
|
|
base_index->reset ();
|
|
refine_index.reset ();
|
|
ntotal = 0;
|
|
}
|
|
|
|
namespace {
|
|
typedef faiss::Index::idx_t idx_t;
|
|
|
|
template<class C>
|
|
static void reorder_2_heaps (
|
|
idx_t n,
|
|
idx_t k, idx_t *labels, float *distances,
|
|
idx_t k_base, const idx_t *base_labels, const float *base_distances)
|
|
{
|
|
#pragma omp parallel for
|
|
for (idx_t i = 0; i < n; i++) {
|
|
idx_t *idxo = labels + i * k;
|
|
float *diso = distances + i * k;
|
|
const idx_t *idxi = base_labels + i * k_base;
|
|
const float *disi = base_distances + i * k_base;
|
|
|
|
heap_heapify<C> (k, diso, idxo, disi, idxi, k);
|
|
if (k_base != k) { // add remaining elements
|
|
heap_addn<C> (k, diso, idxo, disi + k, idxi + k, k_base - k);
|
|
}
|
|
heap_reorder<C> (k, diso, idxo);
|
|
}
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
/// Two-stage search: query the base index for k * k_factor candidates,
/// recompute their distances with the flat refinement index, and keep
/// the k results selected by reorder_2_heaps.
///
/// When k_base == k the caller's output arrays double as the candidate
/// buffers; otherwise temporary buffers of n * k_base entries are
/// allocated and released on scope exit by the ScopeDeleter objects.
///
/// @param n          number of query vectors
/// @param x          query vectors
/// @param k          number of results per query
/// @param distances  output distances, n * k entries
/// @param labels     output labels, n * k entries
/// @throws via FAISS_THROW_IF_NOT* when the index is untrained, when
///         k * k_factor is smaller than k, or when the metric is neither
///         L2 nor inner product
void IndexRefineFlat::search(
        idx_t n, const float* x, idx_t k,
        float* distances, idx_t* labels) const {
    FAISS_THROW_IF_NOT (is_trained);

    const idx_t k_base = idx_t (k * k_factor);
    // The reordering step reads k candidates per query and then adds
    // k_base - k more; fewer than k candidates would read out of bounds.
    FAISS_THROW_IF_NOT_MSG (k_base >= k,
                            "k * k_factor must be at least k");

    idx_t* base_labels = labels;
    float* base_distances = distances;
    ScopeDeleter<idx_t> del1;
    ScopeDeleter<float> del2;

    if (k != k_base) {
        base_labels = new idx_t [n * k_base];
        del1.set (base_labels);
        base_distances = new float [n * k_base];
        del2.set (base_distances);
    }

    base_index->search (n, x, k_base, base_distances, base_labels);

    // Debug-only sanity check of the candidate labels. The index is
    // idx_t (not int) so that n * k_base beyond INT_MAX cannot overflow.
    const idx_t n_candidates = n * k_base;
    for (idx_t i = 0; i < n_candidates; i++) {
        assert (base_labels[i] >= -1 &&
                base_labels[i] < ntotal);
    }

    // compute refined distances
    refine_index.compute_distance_subset (
        n, x, k_base, base_distances, base_labels);

    // sort and store result
    if (metric_type == METRIC_L2) {
        typedef CMax <float, idx_t> C;
        reorder_2_heaps<C> (
            n, k, labels, distances,
            k_base, base_labels, base_distances);
    } else if (metric_type == METRIC_INNER_PRODUCT) {
        typedef CMin <float, idx_t> C;
        reorder_2_heaps<C> (
            n, k, labels, distances,
            k_base, base_labels, base_distances);
    } else {
        FAISS_THROW_MSG("Metric type not supported");
    }
}
|
|
|
|
|
|
|
|
/// Delete the base index only when this wrapper was told it owns it.
IndexRefineFlat::~IndexRefineFlat() {
    if (!own_fields) {
        return;
    }
    delete base_index;
}
|
|
|
|
/***************************************************
 * IndexFlat1D
 ***************************************************/
|
|
|
|
|
|
/// One-dimensional L2 index. When continuous_update is true the sorted
/// permutation is refreshed by every add(); otherwise the caller has to
/// invoke update_permutation() before searching.
IndexFlat1D::IndexFlat1D(bool continuous_update)
        : IndexFlatL2(1), continuous_update(continuous_update) {}
|
|
|
|
/// if not continuous_update, call this between the last add and
|
|
/// the first search
|
|
void IndexFlat1D::update_permutation ()
|
|
{
|
|
perm.resize (ntotal);
|
|
if (ntotal < 1000000) {
|
|
fvec_argsort (ntotal, xb.data(), (size_t*)perm.data());
|
|
} else {
|
|
fvec_argsort_parallel (ntotal, xb.data(), (size_t*)perm.data());
|
|
}
|
|
}
|
|
|
|
/// Append n values through IndexFlatL2::add and, in continuous-update
/// mode, immediately rebuild the sorted permutation.
void IndexFlat1D::add(idx_t n, const float* x) {
    IndexFlatL2::add(n, x);
    if (!continuous_update) {
        return; // caller is expected to call update_permutation() later
    }
    update_permutation();
}
|
|
|
|
void IndexFlat1D::reset()
|
|
{
|
|
IndexFlatL2::reset();
|
|
perm.clear();
|
|
}
|
|
|
|
/// k-nearest-neighbour search on sorted 1D data.
///
/// For every query a binary search over the sorted permutation locates
/// the pair of neighbouring stored values that bracket the query; results
/// are then emitted by walking outwards, taking at each step whichever
/// side is closer. Once one side is exhausted the remaining slots are
/// filled from the other side only.
///
/// Reported distances are the plain differences (q - x or x - q), not
/// squared. Slots for which no stored value is left receive distance
/// +infinity and label -1; this includes every slot when the index is
/// empty.
///
/// @param n          number of queries (one float each)
/// @param x          query values
/// @param k          number of results per query
/// @param distances  output distances, n * k entries
/// @param labels     output labels, n * k entries
/// @throws via FAISS_THROW_IF_NOT_MSG when perm is out of date
///         (perm.size() != ntotal)
void IndexFlat1D::search(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels) const {
    FAISS_THROW_IF_NOT_MSG (perm.size() == ntotal,
                            "Call update_permutation before search");

#pragma omp parallel for
    for (idx_t i = 0; i < n; i++) {

        float q = x[i]; // query
        float* D = distances + i * k;
        idx_t* I = labels + i * k;

        // binary search
        idx_t i0 = 0, i1 = ntotal;
        idx_t wp = 0; // write position in D / I

        if (ntotal == 0) {
            // Empty index: perm and xb hold nothing, so the boundary
            // probes below would read out of bounds. With i1 == ntotal
            // the finish_right loop only pads with (+inf, -1).
            goto finish_right;
        }

        if (xb[perm[i0]] > q) {
            // query is left of every stored value
            i1 = 0;
            goto finish_right;
        }

        if (xb[perm[i1 - 1]] <= q) {
            // query is right of (or equal to) the largest stored value
            i0 = i1 - 1;
            goto finish_left;
        }

        while (i0 + 1 < i1) {
            idx_t imed = (i0 + i1) / 2;
            if (xb[perm[imed]] <= q) {
                i0 = imed;
            } else {
                i1 = imed;
            }
        }

        // query is between xb[perm[i0]] and xb[perm[i1]]
        // expand to nearest neighs

        while (wp < k) {
            float xleft = xb[perm[i0]];
            float xright = xb[perm[i1]];

            if (q - xleft < xright - q) {
                D[wp] = q - xleft;
                I[wp] = perm[i0];
                i0--; wp++;
                if (i0 < 0) { goto finish_right; }
            } else {
                D[wp] = xright - q;
                I[wp] = perm[i1];
                i1++; wp++;
                if (i1 >= ntotal) { goto finish_left; }
            }
        }
        goto done;

    finish_right:
        // grow to the right from i1
        while (wp < k) {
            if (i1 < ntotal) {
                D[wp] = xb[perm[i1]] - q;
                I[wp] = perm[i1];
                i1++;
            } else {
                D[wp] = std::numeric_limits<float>::infinity();
                I[wp] = -1;
            }
            wp++;
        }
        goto done;

    finish_left:
        // grow to the left from i0
        while (wp < k) {
            if (i0 >= 0) {
                D[wp] = q - xb[perm[i0]];
                I[wp] = perm[i0];
                i0--;
            } else {
                D[wp] = std::numeric_limits<float>::infinity();
                I[wp] = -1;
            }
            wp++;
        }
    done: ;
    }
}
|
|
|
|
|
|
|
|
} // namespace faiss
|