faiss/DirectMap.cpp

268 lines
7.7 KiB
C++

/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#include <faiss/DirectMap.h>
#include <cstdio>
#include <cassert>
#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/AuxIndexStructures.h>
namespace faiss {
/// Default construction: the map starts in the NoMap state.
DirectMap::DirectMap() : type(NoMap) {}
/**
 * Switch the direct map to `new_type` and rebuild it from `invlists`.
 *
 * Nothing happens when `new_type` equals the current type. Otherwise both
 * containers are emptied, `type` is updated and, unless the new type is
 * NoMap, every inverted list is scanned to record the (list, offset) of each
 * stored id via lo_build().
 *
 * @param new_type  one of NoMap, Array or Hashtable (anything else throws)
 * @param invlists  inverted lists to scan; dereferenced only when
 *                  new_type is Array or Hashtable
 * @param ntotal    size of the array (Array) or reserve hint (Hashtable)
 *
 * Throws (through the FAISS_THROW_* macros) when `new_type` is invalid or,
 * in Array mode, when a stored id lies outside [0, ntotal).
 */
void DirectMap::set_type (Type new_type, const InvertedLists *invlists, size_t ntotal) {
    FAISS_THROW_IF_NOT (new_type == NoMap || new_type == Array ||
                        new_type == Hashtable);

    if (new_type == type) {
        // nothing to do
        return;
    }

    array.clear ();
    hashtable.clear ();
    type = new_type;

    if (new_type == NoMap) {
        return;
    } else if (new_type == Array) {
        // -1 marks ids for which no location has been recorded
        array.resize (ntotal, -1);
    } else if (new_type == Hashtable) {
        hashtable.reserve (ntotal);
    }

    for (size_t key = 0; key < invlists->nlist; key++) {
        size_t list_size = invlists->list_size (key);
        InvertedLists::ScopedIds idlist (invlists, key);

        if (new_type == Array) {
            // unsigned index: list_size is a size_t
            for (size_t ofs = 0; ofs < list_size; ofs++) {
                // the "0 <=" test runs first, so the comparison against the
                // unsigned ntotal only ever sees a non-negative id
                FAISS_THROW_IF_NOT_MSG (
                    0 <= idlist [ofs] && idlist[ofs] < ntotal,
                    "direct map supported only for sequential ids");
                array [idlist [ofs]] = lo_build(key, ofs);
            }
        } else if (new_type == Hashtable) {
            for (size_t ofs = 0; ofs < list_size; ofs++) {
                hashtable [idlist [ofs]] = lo_build(key, ofs);
            }
        }
    }
}
/// Empties both backing containers; `type` itself is not modified here.
void DirectMap::clear() {
    hashtable.clear();
    array.clear();
}
/**
 * Return the entry stored for `key`.
 *
 * Hashtable mode: throws "key not found" when the key is absent.
 * Array mode: throws "invalid key" when the key is outside the array and
 * "-1 entry in direct_map" when the stored entry is negative.
 * Any other mode: throws "direct map not initialized".
 */
DirectMap::idx_t DirectMap::get (idx_t key) const
{
    if (type == Hashtable) {
        auto res = hashtable.find (key);
        FAISS_THROW_IF_NOT_MSG (res != hashtable.end(), "key not found");
        return res->second;
    }

    if (type != Array) {
        FAISS_THROW_MSG ("direct map not initialized");
    }

    // from here on the map is array-backed
    FAISS_THROW_IF_NOT_MSG (
        key >= 0 && key < array.size(), "invalid key"
    );
    const idx_t lo = array[key];
    FAISS_THROW_IF_NOT_MSG(lo >= 0, "-1 entry in direct_map");
    return lo;
}
/**
 * Record where a single id was stored.
 *
 * @param id       id of the entry
 * @param list_no  inverted list it was stored in; a negative value means
 *                 "not stored"
 * @param offset   position inside that list
 *
 * Array mode always appends one slot (-1 when list_no is negative) and
 * asserts that `id` equals the current array size. Hashtable mode inserts
 * only when list_no is non-negative. Other modes do nothing.
 */
void DirectMap::add_single_id (idx_t id, idx_t list_no, size_t offset)
{
    if (type == Hashtable) {
        if (list_no >= 0) {
            hashtable[id] = lo_build (list_no, offset);
        }
        return;
    }

    if (type != Array) {
        return;
    }

    assert (id == array.size());
    idx_t slot = -1;
    if (list_no >= 0) {
        slot = lo_build (list_no, offset);
    }
    array.push_back (slot);
}
/// Throws when explicit ids are supplied while the map is array-backed;
/// otherwise does nothing.
void DirectMap::check_can_add (const idx_t *ids) {
    if (!ids || type != Array) {
        return;
    }
    FAISS_THROW_MSG ("cannot have array direct map and add with ids");
}
/********************* DirectMapAdd implementation */
/**
 * Prepare `direct_map` for receiving `n` new entries.
 *
 * Array mode: explicit ids are rejected, the current array size is stored in
 * `ntotal`, and the array grows by `n` slots initialised to -1.
 * Hashtable mode: a temporary buffer of `n` offsets (all -1) is allocated.
 * Other modes: nothing to prepare.
 *
 * NOTE(review): `ntotal` is assigned only in the Array branch of this
 * constructor, yet the destructor reads it in Hashtable mode when `xids`
 * is null -- confirm that it is initialised elsewhere (e.g. in the header).
 */
DirectMapAdd::DirectMapAdd (DirectMap &direct_map, size_t n, const idx_t *xids):
    direct_map(direct_map), type(direct_map.type), n(n), xids(xids)
{
    if (type == DirectMap::Hashtable) {
        // can't parallel update hashtable so use temp array
        all_ofs.resize (n, -1);
        return;
    }

    if (type != DirectMap::Array) {
        return;
    }

    FAISS_THROW_IF_NOT (xids == nullptr);
    ntotal = direct_map.array.size();
    direct_map.array.resize (ntotal + n, -1);
}
void DirectMapAdd::add (size_t i, idx_t list_no, size_t ofs)
{
if (type == DirectMap::Array) {
direct_map.array [ntotal + i] = lo_build (list_no, ofs);
} else if (type == DirectMap::Hashtable) {
all_ofs [i] = lo_build (list_no, ofs);
}
}
/**
 * In Hashtable mode, copies the offsets buffered in `all_ofs` into the
 * hashtable. The key of entry i is xids[i] when explicit ids were given and
 * ntotal + i otherwise. Other modes have nothing to flush.
 */
DirectMapAdd::~DirectMapAdd ()
{
    if (type == DirectMap::Hashtable) {
        // size_t index: `n` is a size_t, so an int counter would mix
        // signedness and could overflow for very large batches
        for (size_t i = 0; i < n; i++) {
            idx_t id = xids ? xids[i] : ntotal + i;
            direct_map.hashtable [id] = all_ofs [i];
        }
    }
}
/********************************************************/
// Shorthand names for the InvertedLists scoped accessors used by the
// functions below.
using ScopedIds = InvertedLists::ScopedIds;
using ScopedCodes = InvertedLists::ScopedCodes;
size_t DirectMap::remove_ids(const IDSelector& sel, InvertedLists *invlists)
{
size_t nlist = invlists->nlist;
std::vector<idx_t> toremove(nlist);
size_t nremove = 0;
if (type == NoMap) {
// exhaustive scan of IVF
#pragma omp parallel for
for (idx_t i = 0; i < nlist; i++) {
idx_t l0 = invlists->list_size (i), l = l0, j = 0;
ScopedIds idsi (invlists, i);
while (j < l) {
if (sel.is_member (idsi[j])) {
l--;
invlists->update_entry (
i, j,
invlists->get_single_id (i, l),
ScopedCodes (invlists, i, l).get()
);
} else {
j++;
}
}
toremove[i] = l0 - l;
}
// this will not run well in parallel on ondisk because of
// possible shrinks
for (idx_t i = 0; i < nlist; i++) {
if (toremove[i] > 0) {
nremove += toremove[i];
invlists->resize(i, invlists->list_size(i) - toremove[i]);
}
}
} else if (type == Hashtable) {
const IDSelectorArray *sela =
dynamic_cast<const IDSelectorArray*>(&sel);
FAISS_THROW_IF_NOT_MSG (
sela,
"remove with hashtable works only with IDSelectorArray"
);
for (idx_t i = 0; i < sela->n; i++) {
idx_t id = sela->ids[i];
auto res = hashtable.find (id);
if (res != hashtable.end()) {
size_t list_no = lo_listno (res->second);
size_t offset = lo_offset (res->second);
idx_t last = invlists->list_size (list_no) - 1;
hashtable.erase (res);
if (offset < last) {
idx_t last_id = invlists->get_single_id (list_no, last);
invlists->update_entry (
list_no, offset,
last_id,
ScopedCodes (invlists, list_no, last).get()
);
// update hash entry for last element
hashtable [last_id] = list_no << 32 | offset;
}
invlists->resize(list_no, last);
nremove++;
}
}
} else {
FAISS_THROW_MSG("remove not supported with this direct_map format");
}
return nremove;
}
/**
 * Replace the stored codes of `n` existing entries (Array direct map only).
 *
 * For each i, the entry with id ids[i] is removed from its current list
 * (the last entry of that list is moved into the freed slot and the list
 * shrinks by one), then appended to list assign[i] with the code found at
 * codes + i * code_size. The array is kept in sync for both the moved and
 * the re-inserted entry.
 *
 * @param invlists  inverted lists to modify
 * @param n         number of entries to update; values <= 0 do nothing
 * @param ids       ids of the entries to update, size n
 * @param assign    destination list of each entry, size n
 * @param codes     new codes, n * invlists->code_size bytes
 *
 * Throws when the map is not of type Array, when an id is outside the
 * array, or when an id has no recorded location (negative entry).
 */
void DirectMap::update_codes (InvertedLists *invlists,
                              int n, const idx_t *ids,
                              const idx_t *assign,
                              const uint8_t *codes)
{
    FAISS_THROW_IF_NOT (type == Array);

    size_t code_size = invlists->code_size;

    // signed counter: `n` is an int, and comparing it with an unsigned
    // index would turn a negative count into a huge loop bound
    for (int i = 0; i < n; i++) {
        idx_t id = ids[i];
        FAISS_THROW_IF_NOT_MSG (0 <= id && id < array.size(),
                                "id to update out of range");
        { // remove old one
            idx_t dm = array [id];
            // same guard as get(): a negative entry has no list/offset
            // to decode
            FAISS_THROW_IF_NOT_MSG (dm >= 0, "-1 entry in direct_map");
            int64_t ofs = lo_offset (dm);
            int64_t il = lo_listno (dm);
            size_t l = invlists->list_size (il);
            if (ofs != l - 1) { // move l - 1 to ofs
                int64_t id2 = invlists->get_single_id (il, l - 1);
                array[id2] = lo_build (il, ofs);
                // ScopedCodes, as in remove_ids, so the fetched code is
                // handed back once update_entry has returned
                invlists->update_entry (il, ofs, id2,
                                        ScopedCodes (invlists, il, l - 1).get());
            }
            invlists->resize (il, l - 1);
        }
        { // insert new one
            int64_t il = assign[i];
            size_t l = invlists->list_size (il);
            idx_t dm = lo_build (il, l);
            array [id] = dm;
            invlists->add_entry (il, id,
                                 codes + static_cast<size_t>(i) * code_size);
        }
    }
}
}