faiss/demos/rocksdb_ivf/RocksDBInvertedLists.cpp

115 lines
3.4 KiB
C++

/**
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include "RocksDBInvertedLists.h"

#include <cstring>

#include <faiss/impl/FaissAssert.h>
using namespace faiss;
namespace faiss_rocksdb {
/**
 * Builds an iterator over the entries of inverted list `list_no`.
 *
 * A fresh RocksDB iterator is created on `db` with default read options and
 * positioned with Seek() on the raw bytes of `list_no`, which is the key
 * prefix that is_available() tests for.
 *
 * @param db         database holding the inverted lists
 * @param list_no    inverted list to iterate over
 * @param code_size  size in bytes of one stored code
 */
RocksDBInvertedListsIterator::RocksDBInvertedListsIterator(
        rocksdb::DB* db,
        size_t list_no,
        size_t code_size)
        : InvertedListsIterator(),
          it(db->NewIterator(rocksdb::ReadOptions())),
          list_no(list_no),
          code_size(code_size),
          codes(code_size) {
    // `list_no` here is the constructor argument (it shadows the member);
    // its in-memory bytes are used directly as the seek target.
    const char* prefix_bytes = reinterpret_cast<const char*>(&list_no);
    const rocksdb::Slice list_prefix(prefix_bytes, sizeof(size_t));
    it->Seek(list_prefix);
}
bool RocksDBInvertedListsIterator::is_available() const {
return it->Valid() &&
it->key().starts_with(rocksdb::Slice(
reinterpret_cast<const char*>(&list_no), sizeof(size_t)));
}
/**
 * Advances the underlying RocksDB iterator by one entry. Whether the new
 * position still belongs to this list is reported by is_available().
 */
void RocksDBInvertedListsIterator::next() {
    it->Next();
}
std::pair<idx_t, const uint8_t*> RocksDBInvertedListsIterator::
get_id_and_codes() {
idx_t id =
*reinterpret_cast<const idx_t*>(&it->key().data()[sizeof(size_t)]);
assert(code_size == it->value().size());
return {id, reinterpret_cast<const uint8_t*>(it->value().data())};
}
/**
 * Opens the RocksDB database located at `db_directory` (creating it when it
 * does not exist yet) and uses it as the storage for the inverted lists.
 *
 * @param db_directory  path of the RocksDB database directory
 * @param nlist         number of inverted lists, forwarded to InvertedLists
 * @param code_size     size in bytes of one code, forwarded to InvertedLists
 * @throws faiss::FaissException if RocksDB fails to open the database.
 */
RocksDBInvertedLists::RocksDBInvertedLists(
        const char* db_directory,
        size_t nlist,
        size_t code_size)
        : InvertedLists(nlist, code_size) {
    // list_size/get_codes/get_ids of this class throw; contents are only
    // reachable through get_iterator().
    use_iterator = true;

    rocksdb::Options options;
    options.create_if_missing = true;

    // Start from nullptr so db_ never adopts an indeterminate pointer.
    rocksdb::DB* db = nullptr;
    rocksdb::Status status = rocksdb::DB::Open(options, db_directory, &db);
    db_ = std::unique_ptr<rocksdb::DB>(db);

    // Checked in every build type: a bare assert() disappears under NDEBUG
    // and would leave db_ null for the first Put()/NewIterator() call.
    FAISS_THROW_IF_NOT_FMT(
            status.ok(),
            "could not open RocksDB database at %s: %s",
            db_directory,
            status.ToString().c_str());
}
/**
 * Not supported by this implementation: the call always raises an error
 * through FAISS_THROW_MSG and never returns a size.
 */
size_t RocksDBInvertedLists::list_size(size_t /*list_no*/) const {
    FAISS_THROW_MSG("list_size is not supported");
}
/**
 * Not supported by this implementation: the call always raises an error
 * through FAISS_THROW_MSG and never returns a code pointer.
 */
const uint8_t* RocksDBInvertedLists::get_codes(size_t /*list_no*/) const {
    FAISS_THROW_MSG("get_codes is not supported");
}
/**
 * Not supported by this implementation: the call always raises an error
 * through FAISS_THROW_MSG and never returns an id pointer.
 */
const idx_t* RocksDBInvertedLists::get_ids(size_t /*list_no*/) const {
    FAISS_THROW_MSG("get_ids is not supported");
}
size_t RocksDBInvertedLists::add_entries(
size_t list_no,
size_t n_entry,
const idx_t* ids,
const uint8_t* code) {
rocksdb::WriteOptions wo;
std::vector<char> key(sizeof(size_t) + sizeof(idx_t));
memcpy(key.data(), &list_no, sizeof(size_t));
for (size_t i = 0; i < n_entry; i++) {
memcpy(key.data() + sizeof(size_t), ids + i, sizeof(idx_t));
rocksdb::Status status = db_->Put(
wo,
rocksdb::Slice(key.data(), key.size()),
rocksdb::Slice(
reinterpret_cast<const char*>(code + i * code_size),
code_size));
assert(status.ok());
}
return 0; // ignored
}
/**
 * Not supported by this implementation: the call always raises an error
 * through FAISS_THROW_MSG and leaves the stored entries untouched.
 */
void RocksDBInvertedLists::update_entries(
        size_t /*list_no*/,
        size_t /*offset*/,
        size_t /*n_entry*/,
        const idx_t* /*ids*/,
        const uint8_t* /*code*/) {
    FAISS_THROW_MSG("update_entries is not supported");
}
/**
 * Not supported by this implementation: the call always raises an error
 * through FAISS_THROW_MSG and no list is resized.
 */
void RocksDBInvertedLists::resize(size_t /*list_no*/, size_t /*new_size*/) {
    FAISS_THROW_MSG("resize is not supported");
}
/**
 * Creates an iterator over the entries of inverted list `list_no`.
 *
 * The iterator is allocated with `new` and is not released by this class.
 * NOTE(review): presumably the caller takes ownership and deletes it —
 * confirm against the InvertedLists::get_iterator contract.
 *
 * @param list_no  inverted list to iterate over
 * @return newly allocated RocksDBInvertedListsIterator
 */
InvertedListsIterator* RocksDBInvertedLists::get_iterator(
        size_t list_no,
        void* /*inverted_list_context*/) const {
    // The context argument is not used; its name is commented out like the
    // other unused parameters in this file to avoid unused-parameter warnings.
    return new RocksDBInvertedListsIterator(db_.get(), list_no, code_size);
}
} // namespace faiss_rocksdb