/** * Copyright (c) Facebook, Inc. and its affiliates. * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ // -*- c++ -*- #include #include #include #include #include namespace faiss { DirectMap::DirectMap(): type(NoMap) {} void DirectMap::set_type (Type new_type, const InvertedLists *invlists, size_t ntotal) { FAISS_THROW_IF_NOT (new_type == NoMap || new_type == Array || new_type == Hashtable); if (new_type == type) { // nothing to do return; } array.clear (); hashtable.clear (); type = new_type; if (new_type == NoMap) { return; } else if (new_type == Array) { array.resize (ntotal, -1); } else if (new_type == Hashtable) { hashtable.reserve (ntotal); } for (size_t key = 0; key < invlists->nlist; key++) { size_t list_size = invlists->list_size (key); InvertedLists::ScopedIds idlist (invlists, key); if (new_type == Array) { for (long ofs = 0; ofs < list_size; ofs++) { FAISS_THROW_IF_NOT_MSG ( 0 <= idlist [ofs] && idlist[ofs] < ntotal, "direct map supported only for seuquential ids"); array [idlist [ofs]] = lo_build(key, ofs); } } else if (new_type == Hashtable) { for (long ofs = 0; ofs < list_size; ofs++) { hashtable [idlist [ofs]] = lo_build(key, ofs); } } } } void DirectMap::clear() { array.clear (); hashtable.clear (); } DirectMap::idx_t DirectMap::get (idx_t key) const { if (type == Array) { FAISS_THROW_IF_NOT_MSG ( key >= 0 && key < array.size(), "invalid key" ); idx_t lo = array[key]; FAISS_THROW_IF_NOT_MSG(lo >= 0, "-1 entry in direct_map"); return lo; } else if (type == Hashtable) { auto res = hashtable.find (key); FAISS_THROW_IF_NOT_MSG (res != hashtable.end(), "key not found"); return res->second; } else { FAISS_THROW_MSG ("direct map not initialized"); } } void DirectMap::add_single_id (idx_t id, idx_t list_no, size_t offset) { if (type == NoMap) return; if (type == Array) { assert (id == array.size()); if (list_no >= 0) { array.push_back (lo_build (list_no, offset)); } else { array.push_back (-1); } } else if (type == Hashtable) { if (list_no >= 0) { hashtable[id] = lo_build (list_no, offset); } } } void DirectMap::check_can_add (const idx_t *ids) { if (type == Array && ids) { FAISS_THROW_MSG ("cannot have array direct map and add with ids"); } } /********************* DirectMapAdd implementation */ DirectMapAdd::DirectMapAdd (DirectMap &direct_map, size_t n, const idx_t *xids): direct_map(direct_map), type(direct_map.type), n(n), xids(xids) { if (type == DirectMap::Array) { FAISS_THROW_IF_NOT (xids == nullptr); ntotal = direct_map.array.size(); direct_map.array.resize (ntotal + n, -1); } else if (type == DirectMap::Hashtable) { // can't parallel update hashtable so use temp array all_ofs.resize (n, -1); } } void DirectMapAdd::add (size_t i, idx_t list_no, size_t ofs) { if (type == DirectMap::Array) { direct_map.array [ntotal + i] = lo_build (list_no, ofs); } else if (type == DirectMap::Hashtable) { all_ofs [i] = lo_build (list_no, ofs); } } DirectMapAdd::~DirectMapAdd () { if (type == DirectMap::Hashtable) { for (int i = 0; i < n; i++) { idx_t id = xids ? xids[i] : ntotal + i; direct_map.hashtable [id] = all_ofs [i]; } } } /********************************************************/ using ScopedCodes = InvertedLists::ScopedCodes; using ScopedIds = InvertedLists::ScopedIds; size_t DirectMap::remove_ids(const IDSelector& sel, InvertedLists *invlists) { size_t nlist = invlists->nlist; std::vector toremove(nlist); size_t nremove = 0; if (type == NoMap) { // exhaustive scan of IVF #pragma omp parallel for for (idx_t i = 0; i < nlist; i++) { idx_t l0 = invlists->list_size (i), l = l0, j = 0; ScopedIds idsi (invlists, i); while (j < l) { if (sel.is_member (idsi[j])) { l--; invlists->update_entry ( i, j, invlists->get_single_id (i, l), ScopedCodes (invlists, i, l).get() ); } else { j++; } } toremove[i] = l0 - l; } // this will not run well in parallel on ondisk because of // possible shrinks for (idx_t i = 0; i < nlist; i++) { if (toremove[i] > 0) { nremove += toremove[i]; invlists->resize(i, invlists->list_size(i) - toremove[i]); } } } else if (type == Hashtable) { const IDSelectorArray *sela = dynamic_cast(&sel); FAISS_THROW_IF_NOT_MSG ( sela, "remove with hashtable works only with IDSelectorArray" ); for (idx_t i = 0; i < sela->n; i++) { idx_t id = sela->ids[i]; auto res = hashtable.find (id); if (res != hashtable.end()) { size_t list_no = lo_listno (res->second); size_t offset = lo_offset (res->second); idx_t last = invlists->list_size (list_no) - 1; hashtable.erase (res); if (offset < last) { idx_t last_id = invlists->get_single_id (list_no, last); invlists->update_entry ( list_no, offset, last_id, ScopedCodes (invlists, list_no, last).get() ); // update hash entry for last element hashtable [last_id] = list_no << 32 | offset; } invlists->resize(list_no, last); nremove++; } } } else { FAISS_THROW_MSG("remove not supported with this direct_map format"); } return nremove; } void DirectMap::update_codes (InvertedLists *invlists, int n, const idx_t *ids, const idx_t *assign, const uint8_t *codes) { FAISS_THROW_IF_NOT (type == Array); size_t code_size = invlists->code_size; for (size_t i = 0; i < n; i++) { idx_t id = ids[i]; FAISS_THROW_IF_NOT_MSG (0 <= id && id < array.size(), "id to update out of range"); { // remove old one idx_t dm = array [id]; int64_t ofs = lo_offset (dm); int64_t il = lo_listno (dm); size_t l = invlists->list_size (il); if (ofs != l - 1) { // move l - 1 to ofs int64_t id2 = invlists->get_single_id (il, l - 1); array[id2] = lo_build (il, ofs); invlists->update_entry (il, ofs, id2, invlists->get_single_code (il, l - 1)); } invlists->resize (il, l - 1); } { // insert new one int64_t il = assign[i]; size_t l = invlists->list_size (il); idx_t dm = lo_build (il, l); array [id] = dm; invlists->add_entry (il, id, codes + i * code_size); } } } }