/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD+Patents license found in the
 * LICENSE file in the root directory of this source tree.
 */

/* Copyright 2004-present Facebook. All Rights Reserved.
   Inverted list structure.
*/

#include "IndexIVFFlat.h"

#include <cassert>
#include <cstdio>
#include <cstring>
#include <vector>

#include "utils.h"

#include "FaissAssert.h"
#include "IndexFlat.h"
#include "AuxIndexStructures.h"


namespace faiss {


/*****************************************
 * IndexIVFFlat implementation
 ******************************************/

/* Each code stored in the inverted lists is the raw float vector, so the
 * code size is d floats. */
IndexIVFFlat::IndexIVFFlat (Index * quantizer,
                            size_t d, size_t nlist, MetricType metric):
    IndexIVF (quantizer, d, nlist, sizeof(float) * d, metric)
{
    code_size = sizeof(float) * d;
}


/* Add n vectors with optional user ids; the coarse assignment is computed
 * by add_core. */
void IndexIVFFlat::add_with_ids (idx_t n, const float * x, const long *xids)
{
    add_core (n, x, xids, nullptr);
}


/* Add n vectors x (n * d floats) to their inverted lists.
 *
 * xids            optional ids; when null, ids are ntotal + i
 * precomputed_idx optional list number per vector; when null, the
 *                 quantizer assigns each vector to one list
 *
 * Vectors whose list number is negative are skipped. Throws if the index
 * is not trained, or if ids are given while a direct map is maintained. */
void IndexIVFFlat::add_core (idx_t n, const float * x, const long *xids,
                             const long *precomputed_idx)
{
    FAISS_THROW_IF_NOT (is_trained);
    assert (invlists);
    FAISS_THROW_IF_NOT_MSG (!(maintain_direct_map && xids),
                            "cannot have direct map and add with ids");

    // Owns the coarse assignment when we have to compute it ourselves.
    std::vector<long> assigned;
    const long * idx = precomputed_idx;
    if (!idx) {
        assigned.resize (n);
        quantizer->assign (n, x, assigned.data());
        idx = assigned.data();
    }

    long n_add = 0;
    for (idx_t i = 0; i < n; i++) {
        // NOTE(review): when some list_no < 0 is skipped, later default ids
        // (ntotal + i) no longer match the direct_map position they are
        // pushed to -- confirm skipped vectors cannot occur with a direct map.
        long id = xids ? xids[i] : ntotal + i;
        long list_no = idx [i];
        if (list_no < 0)
            continue;
        const float *xi = x + i * d;
        size_t offset = invlists->add_entry (
              list_no, id, (const uint8_t*) xi);

        // Direct map entry packs (list number, offset in list) as hi/lo 32 bits.
        if (maintain_direct_map)
            direct_map.push_back (list_no << 32 | offset);
        n_add++;
    }
    if (verbose) {
        printf("IndexIVFFlat::add_core: added %ld / %ld vectors\n",
               n_add, n);
    }
    ntotal += n_add;
}

namespace {

/* Exhaustive inner-product scan of the probed lists.
 *
 * keys holds ivf.nprobe list numbers per query (query i starts at
 * keys + i * ivf.nprobe); negative entries are skipped. The k largest
 * similarities per query are kept in a min-heap in res. With store_pairs,
 * the stored label is (list_no << 32 | offset) instead of the vector id.
 * Probing stops early once ivf.max_codes codes were scanned (0 = no limit).
 * Updates the global indexIVF_stats counters. */
void search_knn_inner_product (const IndexIVFFlat & ivf,
                               size_t nx,
                               const float * x,
                               const long * keys,
                               float_minheap_array_t * res,
                               bool store_pairs)
{
    const size_t k = res->k;
    size_t nlistv = 0, ndis = 0;
    size_t d = ivf.d;

#pragma omp parallel for reduction(+: nlistv, ndis)
    for (size_t i = 0; i < nx; i++) {
        const float * xi = x + i * d;
        const long * keysi = keys + i * ivf.nprobe;
        float * __restrict simi = res->get_val (i);
        long * __restrict idxi = res->get_ids (i);
        minheap_heapify (k, simi, idxi);
        size_t nscan = 0;

        for (size_t ik = 0; ik < ivf.nprobe; ik++) {
            long key = keysi[ik];  /* select the list  */
            if (key < 0) {
                // not enough centroids for multiprobe
                continue;
            }
            FAISS_THROW_IF_NOT_FMT (
                key < (long) ivf.nlist,
                "Invalid key=%ld at ik=%ld nlist=%ld\n",
                key, (long) ik, (long) ivf.nlist);

            nlistv++;
            size_t list_size = ivf.invlists->list_size(key);
            const float * list_vecs =
                (const float*)ivf.invlists->get_codes (key);
            // ids are not needed when labels encode (list, offset) pairs
            const Index::idx_t * ids = store_pairs ? nullptr :
                ivf.invlists->get_ids (key);

            for (size_t j = 0; j < list_size; j++) {
                const float * yj = list_vecs + d * j;
                float ip = fvec_inner_product (xi, yj, d);
                if (ip > simi[0]) {
                    minheap_pop (k, simi, idxi);
                    long id = store_pairs ? (key << 32 | j) : ids[j];
                    minheap_push (k, simi, idxi, ip, id);
                }
            }
            nscan += list_size;
            if (ivf.max_codes && nscan >= ivf.max_codes)
                break;
        }
        ndis += nscan;
        minheap_reorder (k, simi, idxi);
    }
    indexIVF_stats.nq += nx;
    indexIVF_stats.nlist += nlistv;
    indexIVF_stats.ndis += ndis;
}


/* Exhaustive squared-L2 scan of the probed lists; same contract as
 * search_knn_inner_product but keeps the k smallest distances per query
 * in a max-heap. */
void search_knn_L2sqr (const IndexIVFFlat &ivf,
                       size_t nx,
                       const float * x,
                       const long * keys,
                       float_maxheap_array_t * res,
                       bool store_pairs)
{
    const size_t k = res->k;
    size_t nlistv = 0, ndis = 0;
    size_t d = ivf.d;

#pragma omp parallel for reduction(+: nlistv, ndis)
    for (size_t i = 0; i < nx; i++) {
        const float * xi = x + i * d;
        const long * keysi = keys + i * ivf.nprobe;
        float * __restrict disi = res->get_val (i);
        long * __restrict idxi = res->get_ids (i);
        maxheap_heapify (k, disi, idxi);

        size_t nscan = 0;

        for (size_t ik = 0; ik < ivf.nprobe; ik++) {
            long key = keysi[ik];  /* select the list  */
            if (key < 0) {
                // not enough centroids for multiprobe
                continue;
            }
            FAISS_THROW_IF_NOT_FMT (
                key < (long) ivf.nlist,
                "Invalid key=%ld at ik=%ld nlist=%ld\n",
                key, (long) ik, (long) ivf.nlist);

            nlistv++;
            size_t list_size = ivf.invlists->list_size(key);
            const float * list_vecs =
                (const float*)ivf.invlists->get_codes (key);
            // ids are not needed when labels encode (list, offset) pairs
            const Index::idx_t * ids = store_pairs ? nullptr :
                ivf.invlists->get_ids (key);

            for (size_t j = 0; j < list_size; j++) {
                const float * yj = list_vecs + d * j;
                float disij = fvec_L2sqr (xi, yj, d);
                if (disij < disi[0]) {
                    maxheap_pop (k, disi, idxi);
                    long id = store_pairs ? (key << 32 | j) : ids[j];
                    maxheap_push (k, disi, idxi, disij, id);
                }
            }
            nscan += list_size;
            if (ivf.max_codes && nscan >= ivf.max_codes)
                break;
        }
        ndis += nscan;
        maxheap_reorder (k, disi, idxi);
    }
    indexIVF_stats.nq += nx;
    indexIVF_stats.nlist += nlistv;
    indexIVF_stats.ndis += ndis;
}


} // anonymous namespace

/* Search with a precomputed coarse assignment idx (n * nprobe list numbers).
 * Writes k results per query into distances / labels. Throws for metrics
 * other than L2 and inner product (previously such calls silently left the
 * outputs untouched). */
void IndexIVFFlat::search_preassigned (idx_t n, const float *x, idx_t k,
                                       const idx_t *idx,
                                       const float * /* coarse_dis */,
                                       float *distances, idx_t *labels,
                                       bool store_pairs) const
{
    FAISS_THROW_IF_NOT_MSG (metric_type == METRIC_INNER_PRODUCT ||
                            metric_type == METRIC_L2,
                            "metric type not supported");

    if (metric_type == METRIC_INNER_PRODUCT) {
        float_minheap_array_t res = {
            size_t(n), size_t(k), labels, distances};
        search_knn_inner_product (*this, n, x, idx, &res, store_pairs);
    } else {
        float_maxheap_array_t res = {
            size_t(n), size_t(k), labels, distances};
        search_knn_L2sqr (*this, n, x, idx, &res, store_pairs);
    }
}


/* Range search: for each of the nx queries, report every vector in the
 * nprobe closest lists with L2 distance < radius (METRIC_L2) or inner
 * product > radius (METRIC_INNER_PRODUCT).
 *
 * Negative list numbers from the quantizer (not enough centroids) are
 * skipped, consistent with the k-NN search functions. Each query is
 * handled by exactly one OpenMP thread. */
void IndexIVFFlat::range_search (idx_t nx, const float *x, float radius,
                                 RangeSearchResult *result) const
{
    FAISS_THROW_IF_NOT_MSG (metric_type == METRIC_L2 ||
                            metric_type == METRIC_INNER_PRODUCT,
                            "metric type not supported");

    std::vector<idx_t> keys (nx * nprobe);
    quantizer->assign (nx, x, keys.data(), nprobe);

#pragma omp parallel
    {
        // Per-thread partial result; presumably merged into `result`
        // by finalize(), which every thread must reach.
        RangeSearchPartialResult pres(result);

        // Without a worksharing construct every thread would process every
        // query and write duplicate results concurrently.
#pragma omp for
        for (idx_t i = 0; i < nx; i++) {
            const float * xi = x + i * d;
            const long * keysi = keys.data() + i * nprobe;

            RangeSearchPartialResult::QueryResult & qres =
                pres.new_result (i);

            for (size_t ik = 0; ik < nprobe; ik++) {
                long key = keysi[ik];  /* select the list  */
                if (key < 0) {
                    // not enough centroids for multiprobe
                    continue;
                }
                // Internal invariant; a bare `throw;` here used to call
                // std::terminate unconditionally.
                FAISS_THROW_IF_NOT_FMT (
                    key < (long) nlist,
                    "Invalid key=%ld at ik=%ld nlist=%ld\n",
                    key, (long) ik, (long) nlist);

                const size_t list_size = invlists->list_size(key);
                const float * list_vecs =
                    (const float*)invlists->get_codes (key);
                const Index::idx_t * ids = invlists->get_ids (key);

                for (size_t j = 0; j < list_size; j++) {
                    const float * yj = list_vecs + d * j;
                    if (metric_type == METRIC_L2) {
                        float disij = fvec_L2sqr (xi, yj, d);
                        if (disij < radius) {
                            qres.add (disij, ids[j]);
                        }
                    } else if (metric_type == METRIC_INNER_PRODUCT) {
                        float disij = fvec_inner_product(xi, yj, d);
                        if (disij > radius) {
                            qres.add (disij, ids[j]);
                        }
                    }
                }
            }
        }

        pres.finalize ();
    }
}

/* Replace the stored vectors of ids new_ids[0..n-1] with x (n * d floats),
 * re-assigning each to its nearest list. Requires the direct map, whose
 * entries pack (list_no << 32 | offset). The old entry is removed by moving
 * the list's last entry into its slot and shrinking the list; the new one
 * is appended to its assigned list. */
void IndexIVFFlat::update_vectors (int n, idx_t *new_ids, const float *x)
{
    FAISS_THROW_IF_NOT (maintain_direct_map);
    FAISS_THROW_IF_NOT (is_trained);
    std::vector<idx_t> assign (n);
    quantizer->assign (n, x, assign.data());

    for (size_t i = 0; i < size_t(n); i++) {
        idx_t id = new_ids[i];
        FAISS_THROW_IF_NOT_MSG (0 <= id && id < ntotal,
                                "id to update out of range");
        { // remove old one
            long dm = direct_map[id];
            long ofs = dm & 0xffffffff;
            long il = dm >> 32;
            size_t l = invlists->list_size (il);
            if (ofs != l - 1) {
                // move l - 1 to ofs
                long id2 = invlists->get_single_id (il, l - 1);
                direct_map[id2] = (il << 32) | ofs;
                invlists->update_entry (il, ofs, id2,
                                        invlists->get_single_code (il, l - 1));
            }
            invlists->resize (il, l - 1);
        }
        { // insert new one
            long il = assign[i];
            size_t l = invlists->list_size (il);
            long dm = (il << 32) | l;
            direct_map[id] = dm;
            invlists->add_entry (il, id, (const uint8_t*)(x + i * d));
        }
    }

}

/* Copy the stored code (the raw float vector) at (list_no, offset) into
 * recons, which must have room for d floats. */
void IndexIVFFlat::reconstruct_from_offset (long list_no, long offset,
                                            float* recons) const
{
    memcpy (recons, invlists->get_single_code (list_no, offset), code_size);
}


} // namespace faiss