faiss/IVFlib.cpp

360 lines
11 KiB
C++

/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#include <faiss/IVFlib.h>
#include <memory>
#include <faiss/IndexPreTransform.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/MetaIndexes.h>
namespace faiss { namespace ivflib {
/** Check that two indexes are structurally compatible, so that the content
 *  of one can be merged into the other.
 *
 * If index0 is an IndexPreTransform, index1 must be one too, with a chain of
 * the same length whose transforms have pairwise identical dynamic types;
 * the comparison then continues on the wrapped sub-indexes. The
 * (sub-)indexes must have the same dynamic type, dimension and metric type.
 * For IndexIVF, the check is additionally delegated to
 * IndexIVF::check_compatible_for_merge.
 *
 * Failures are reported through the FAISS_THROW_* macros.
 *
 * @param index0  first index
 * @param index1  second index
 */
void check_compatible_for_merge (const Index * index0,
                                 const Index * index1)
{
    const faiss::IndexPreTransform *pt0 =
        dynamic_cast<const faiss::IndexPreTransform *>(index0);

    if (pt0) {
        const faiss::IndexPreTransform *pt1 =
            dynamic_cast<const faiss::IndexPreTransform *>(index1);
        FAISS_THROW_IF_NOT_MSG (pt1, "both indexes should be pretransforms");

        FAISS_THROW_IF_NOT (pt0->chain.size() == pt1->chain.size());
        for (size_t i = 0; i < pt0->chain.size(); i++) {
            // typeid must be applied to the pointee: applied to the pointer
            // it only yields the static pointer type, which always matches.
            // NOTE(review): assumes the chain element type is polymorphic
            // and that chain entries are non-null -- confirm.
            FAISS_THROW_IF_NOT (typeid(*pt0->chain[i]) ==
                                typeid(*pt1->chain[i]));
        }
        index0 = pt0->index;
        index1 = pt1->index;
    }

    // typeid on a dereferenced null pointer would raise std::bad_typeid
    FAISS_THROW_IF_NOT (index0 && index1);

    // compare the dynamic types of the indexes, not the pointer types
    FAISS_THROW_IF_NOT (typeid(*index0) == typeid(*index1));
    FAISS_THROW_IF_NOT (index0->d == index1->d &&
                        index0->metric_type == index1->metric_type);

    const faiss::IndexIVF *ivf0 = dynamic_cast<const faiss::IndexIVF *>(index0);
    if (ivf0) {
        const faiss::IndexIVF *ivf1 =
            dynamic_cast<const faiss::IndexIVF *>(index1);
        FAISS_THROW_IF_NOT (ivf1);

        ivf0->check_compatible_for_merge (*ivf1);
    }

    // TODO: check as thoroughly for other index types
}
/** Look for an IndexIVF behind the wrappers this library knows about.
 *
 * At most one wrapper of each kind is peeled off, in this fixed order:
 * IndexPreTransform, IndexIDMap, IndexIDMap2.
 *
 * @param index  index to inspect
 * @return the IndexIVF that was reached, or nullptr if what remains after
 *         unwrapping is not an IndexIVF
 */
const IndexIVF * try_extract_index_ivf (const Index * index)
{
    const auto *pretransform = dynamic_cast<const IndexPreTransform *>(index);
    if (pretransform != nullptr) {
        index = pretransform->index;
    }

    const auto *id_map = dynamic_cast<const IndexIDMap *>(index);
    if (id_map != nullptr) {
        index = id_map->index;
    }

    const auto *id_map2 = dynamic_cast<const IndexIDMap2 *>(index);
    if (id_map2 != nullptr) {
        index = id_map2->index;
    }

    return dynamic_cast<const IndexIVF *>(index);
}
/** Mutable overload: forwards to the const version and removes the
 *  constness from the result. Returns nullptr when no IndexIVF is found. */
IndexIVF * try_extract_index_ivf (Index * index) {
    const Index *as_const = index;
    const IndexIVF *found = try_extract_index_ivf (as_const);
    return const_cast<IndexIVF*> (found);
}
/** Same as try_extract_index_ivf, but a missing IndexIVF is reported as an
 *  error through FAISS_THROW_IF_NOT rather than by returning nullptr.
 *
 * @param index  index to inspect
 * @return the IndexIVF that was found
 */
const IndexIVF * extract_index_ivf (const Index * index)
{
    const IndexIVF * const ivf = try_extract_index_ivf (index);
    FAISS_THROW_IF_NOT (ivf);
    return ivf;
}
/** Mutable overload: forwards to the const version (which raises an error
 *  when no IndexIVF is found) and removes the constness from the result. */
IndexIVF * extract_index_ivf (Index * index) {
    const Index *as_const = index;
    const IndexIVF *found = extract_index_ivf (as_const);
    return const_cast<IndexIVF*> (found);
}
/** Merge the IVF content of index1 into index0.
 *
 * Both indexes are first checked with check_compatible_for_merge, then the
 * underlying IndexIVF objects are extracted and IndexIVF::merge_from is
 * invoked on them.
 *
 * @param index0     destination index
 * @param index1     source index
 * @param shift_ids  when true, the current ntotal of the destination IVF is
 *                   passed as second argument of merge_from, otherwise 0
 *                   (presumably an offset added to the merged ids -- see
 *                   IndexIVF::merge_from)
 */
void merge_into(faiss::Index *index0, faiss::Index *index1, bool shift_ids) {
    check_compatible_for_merge (index0, index1);

    IndexIVF * const dst_ivf = extract_index_ivf (index0);
    IndexIVF * const src_ivf = extract_index_ivf (index1);

    auto id_shift = dst_ivf->ntotal;
    if (!shift_ids) {
        id_shift = 0;
    }
    dst_ivf->merge_from (*src_ivf, id_shift);

    // The outer indexes may be wrappers (e.g. IndexPreTransform) that keep
    // their own ntotal: bring them in line with the IVF indexes.
    index0->ntotal = dst_ivf->ntotal;
    index1->ntotal = src_ivf->ntotal;
}
/** Assign each query vector to a centroid of the IVF coarse quantizer.
 *
 * If the index is an IndexPreTransform, the transform chain is applied to
 * the vectors first and the wrapped index is used.
 *
 * @param index         an IndexIVF, possibly wrapped in an IndexPreTransform;
 *                      any other index type is reported as an error
 * @param x             query vectors
 * @param n             number of query vectors
 * @param centroid_ids  output buffer handed to the quantizer's assign()
 *                      (presumably one id per query -- confirm)
 */
void search_centroid(faiss::Index *index,
                     const float* x, int n,
                     idx_t* centroid_ids)
{
    std::unique_ptr<float[]> del;
    if (auto index_pre = dynamic_cast<faiss::IndexPreTransform*>(index)) {
        const float *xt = index_pre->apply_chain(n, x);
        // apply_chain may hand back the input pointer itself (see the same
        // guard in search_with_parameters); only a freshly allocated buffer
        // may be owned and freed here, never the caller's data.
        if (xt != x) {
            del.reset(const_cast<float*>(xt));
        }
        x = xt;
        index = index_pre->index;
    }
    faiss::IndexIVF* index_ivf = dynamic_cast<faiss::IndexIVF*>(index);
    // not an assert: it would be compiled out with NDEBUG and the null
    // pointer would be dereferenced just below
    FAISS_THROW_IF_NOT (index_ivf);
    index_ivf->quantizer->assign(n, x, centroid_ids);
}
/** Search an IVF index and also report which centroids were involved.
 *
 * The coarse quantizer is searched for the index's current nprobe, then
 * search_preassigned is run on that assignment. The labels it produces are
 * decoded with lo_listno / lo_offset and replaced by the id stored in the
 * inverted lists at that position.
 *
 * @param index                an IndexIVF, possibly wrapped in an
 *                             IndexPreTransform; any other index type is
 *                             reported as an error
 * @param n                    number of query vectors
 * @param xin                  query vectors
 * @param k                    number of results per query
 * @param distances            output distances, n * k entries
 * @param labels               output ids, n * k entries (negative entries
 *                             are left untouched)
 * @param query_centroid_ids   if non-null, receives for each query the
 *                             first centroid returned by the quantizer
 * @param result_centroid_ids  if non-null, receives for each result the
 *                             inverted list it came from, or -1 for
 *                             negative labels
 */
void search_and_return_centroids(faiss::Index *index,
                                 size_t n,
                                 const float* xin,
                                 long k,
                                 float *distances,
                                 idx_t* labels,
                                 idx_t* query_centroid_ids,
                                 idx_t* result_centroid_ids)
{
    const float *x = xin;
    std::unique_ptr<float []> del;
    if (auto index_pre = dynamic_cast<faiss::IndexPreTransform*>(index)) {
        x = index_pre->apply_chain(n, x);
        // apply_chain may hand back the input pointer itself (see the same
        // guard in search_with_parameters); only a freshly allocated buffer
        // may be owned and freed here, never the caller's data.
        if (x != xin) {
            del.reset(const_cast<float*>(x));
        }
        index = index_pre->index;
    }
    faiss::IndexIVF* index_ivf = dynamic_cast<faiss::IndexIVF*>(index);
    // not an assert: it would be compiled out with NDEBUG and the null
    // pointer would be dereferenced just below
    FAISS_THROW_IF_NOT (index_ivf);

    // coarse assignment: nprobe centroids (and distances) per query
    size_t nprobe = index_ivf->nprobe;
    std::vector<idx_t> cent_nos (n * nprobe);
    std::vector<float> cent_dis (n * nprobe);
    index_ivf->quantizer->search(
        n, x, nprobe, cent_dis.data(), cent_nos.data());

    if (query_centroid_ids) {
        for (size_t i = 0; i < n; i++)
            query_centroid_ids[i] = cent_nos[i * nprobe];
    }

    // NOTE(review): the `true` flag presumably makes search_preassigned
    // return (list number, offset) pairs in labels, which is what the
    // decoding below relies on -- confirm against IndexIVF.
    index_ivf->search_preassigned (n, x, k,
                                   cent_nos.data(), cent_dis.data(),
                                   distances, labels, true);

    for (size_t i = 0; i < n * k; i++) {
        idx_t label = labels[i];
        if (label < 0) {
            if (result_centroid_ids)
                result_centroid_ids[i] = -1;
        } else {
            long list_no = lo_listno (label);
            long list_index = lo_offset (label);
            if (result_centroid_ids)
                result_centroid_ids[i] = list_no;
            // replace the encoded position by the id stored in the list
            labels[i] = index_ivf->invlists->get_single_id(list_no, list_index);
        }
    }
}
/** Set up an empty sliding window on top of an existing index.
 *
 * The index must contain an IndexIVF (see extract_index_ivf) whose inverted
 * lists are ArrayInvertedLists; anything else is reported as an error.
 *
 * @param index  index that will hold the content of the window
 */
SlidingIndexWindow::SlidingIndexWindow (Index *index): index (index) {
    n_slice = 0;
    IndexIVF* index_ivf = extract_index_ivf (index);
    ils = dynamic_cast<ArrayInvertedLists *> (index_ivf->invlists);
    // validate the cast before touching ils
    FAISS_THROW_IF_NOT_MSG (ils,
                            "only supports indexes with ArrayInvertedLists");
    nlist = ils->nlist;
    sizes.resize(nlist);
}
/** Drop the first `remove` elements of dst, then append all of src.
 *
 * Implemented with std::vector operations rather than raw memmove/memcpy,
 * so it is valid for any element type and never hands a null data()
 * pointer to a C memory routine when one of the vectors is empty.
 *
 * Precondition: remove <= dst.size(); dst and src are distinct vectors.
 *
 * @param dst     vector modified in place
 * @param remove  number of leading elements to discard from dst
 * @param src     elements appended at the end of dst
 */
template<class T>
static void shift_and_add (std::vector<T> & dst,
                           size_t remove,
                           const std::vector<T> & src)
{
    if (remove > 0) {
        dst.erase (dst.begin(), dst.begin() + remove);
    }
    dst.insert (dst.end(), src.begin(), src.end());
}
/** Erase the first `remove` elements of v; does nothing when remove is 0.
 *
 * Precondition: remove <= v.size().
 */
template<class T>
static void remove_from_begin (std::vector<T> & v,
                               size_t remove)
{
    if (remove == 0) {
        return;
    }
    const auto head = v.begin();
    v.erase (head, head + remove);
}
void SlidingIndexWindow::step(const Index *sub_index, bool remove_oldest) {
FAISS_THROW_IF_NOT_MSG (!remove_oldest || n_slice > 0,
"cannot remove slice: there is none");
const ArrayInvertedLists *ils2 = nullptr;
if(sub_index) {
check_compatible_for_merge (index, sub_index);
ils2 = dynamic_cast<const ArrayInvertedLists*>(
extract_index_ivf (sub_index)->invlists);
FAISS_THROW_IF_NOT_MSG (ils2, "supports only ArrayInvertedLists");
}
IndexIVF *index_ivf = extract_index_ivf (index);
if (remove_oldest && ils2) {
for (int i = 0; i < nlist; i++) {
std::vector<size_t> & sizesi = sizes[i];
size_t amount_to_remove = sizesi[0];
index_ivf->ntotal += ils2->ids[i].size() - amount_to_remove;
shift_and_add (ils->ids[i], amount_to_remove, ils2->ids[i]);
shift_and_add (ils->codes[i], amount_to_remove * ils->code_size,
ils2->codes[i]);
for (int j = 0; j + 1 < n_slice; j++) {
sizesi[j] = sizesi[j + 1] - amount_to_remove;
}
sizesi[n_slice - 1] = ils->ids[i].size();
}
} else if (ils2) {
for (int i = 0; i < nlist; i++) {
index_ivf->ntotal += ils2->ids[i].size();
shift_and_add (ils->ids[i], 0, ils2->ids[i]);
shift_and_add (ils->codes[i], 0, ils2->codes[i]);
sizes[i].push_back(ils->ids[i].size());
}
n_slice++;
} else if (remove_oldest) {
for (int i = 0; i < nlist; i++) {
size_t amount_to_remove = sizes[i][0];
index_ivf->ntotal -= amount_to_remove;
remove_from_begin (ils->ids[i], amount_to_remove);
remove_from_begin (ils->codes[i],
amount_to_remove * ils->code_size);
for (int j = 0; j + 1 < n_slice; j++) {
sizes[i][j] = sizes[i][j + 1] - amount_to_remove;
}
sizes[i].pop_back ();
}
n_slice--;
} else {
FAISS_THROW_MSG ("nothing to do???");
}
index->ntotal = index_ivf->ntotal;
}
/** Copy the inverted lists [i0, i1) of an index into a new
 *  ArrayInvertedLists. Works on an IndexIVF and on an IndexIVF embedded in
 *  one of the wrappers handled by extract_index_ivf (e.g. IndexPreTransform).
 *
 * @param index  index to read from
 * @param i0     first list to copy
 * @param i1     one past the last list to copy; 0 <= i0 <= i1 <= nlist is
 *               required, otherwise an error is raised
 * @return newly allocated inverted lists with i1 - i0 lists, where list
 *         i - i0 holds the entries of source list i. Allocated with new;
 *         the caller is responsible for deleting it.
 */
ArrayInvertedLists *
get_invlist_range (const Index *index, long i0, long i1)
{
    const IndexIVF *ivf = extract_index_ivf (index);

    FAISS_THROW_IF_NOT (0 <= i0 && i0 <= i1 && i1 <= ivf->nlist);

    const InvertedLists *src = ivf->invlists;

    // Owned by a unique_ptr while it is being filled, so that the partial
    // copy is not leaked if anything below throws.
    std::unique_ptr<ArrayInvertedLists> il (
        new ArrayInvertedLists(i1 - i0, src->code_size));

    for (long i = i0; i < i1; i++) {
        // the scoped accessors stay alive until add_entries has returned
        il->add_entries(i - i0, src->list_size(i),
                        InvertedLists::ScopedIds (src, i).get(),
                        InvertedLists::ScopedCodes (src, i).get());
    }
    return il.release();
}
/** Exchange the inverted lists [i0, i1) of an index with the lists of src.
 *
 * The ids and codes are swapped, not copied: on return src holds the
 * previous content of lists [i0, i1) of the index. ntotal is updated on
 * both the IVF index and the index passed in.
 *
 * @param index  index to modify; its IVF must use ArrayInvertedLists
 * @param i0     first list to replace
 * @param i1     one past the last list to replace;
 *               0 <= i0 <= i1 <= nlist is required
 * @param src    inverted lists with exactly i1 - i0 lists and the same code
 *               size as the destination
 */
void set_invlist_range (Index *index, long i0, long i1,
                        ArrayInvertedLists * src)
{
    IndexIVF *ivf = extract_index_ivf (index);
    FAISS_THROW_IF_NOT (0 <= i0 && i0 <= i1 && i1 <= ivf->nlist);

    ArrayInvertedLists *dst = dynamic_cast<ArrayInvertedLists *>(ivf->invlists);
    FAISS_THROW_IF_NOT_MSG (dst, "only ArrayInvertedLists supported");
    FAISS_THROW_IF_NOT (src->nlist == i1 - i0 &&
                        dst->code_size == src->code_size);

    size_t running_total = index->ntotal;
    for (long dst_no = i0, src_no = 0; dst_no < i1; ++dst_no, ++src_no) {
        // account for the size change before the lists trade places
        running_total -= dst->list_size (dst_no);
        running_total += src->list_size (src_no);

        src->codes[src_no].swap (dst->codes[dst_no]);
        src->ids[src_no].swap (dst->ids[dst_no]);
    }

    index->ntotal = running_total;
    ivf->ntotal = index->ntotal;
}
void search_with_parameters (const Index *index,
idx_t n, const float *x, idx_t k,
float *distances, idx_t *labels,
IVFSearchParameters *params,
size_t *nb_dis_ptr)
{
FAISS_THROW_IF_NOT (params);
const float *prev_x = x;
ScopeDeleter<float> del;
if (auto ip = dynamic_cast<const IndexPreTransform *> (index)) {
x = ip->apply_chain (n, x);
if (x != prev_x) {
del.set(x);
}
index = ip->index;
}
std::vector<idx_t> Iq(params->nprobe * n);
std::vector<float> Dq(params->nprobe * n);
const IndexIVF *index_ivf = dynamic_cast<const IndexIVF *>(index);
FAISS_THROW_IF_NOT (index_ivf);
index_ivf->quantizer->search(n, x, params->nprobe,
Dq.data(), Iq.data());
if (nb_dis_ptr) {
size_t nb_dis = 0;
const InvertedLists *il = index_ivf->invlists;
for (idx_t i = 0; i < n * params->nprobe; i++) {
if (Iq[i] >= 0) {
nb_dis += il->list_size(Iq[i]);
}
}
*nb_dis_ptr = nb_dis;
}
index_ivf->search_preassigned(n, x, k, Iq.data(), Dq.data(),
distances, labels,
false, params);
}
} } // namespace faiss::ivflib