diff --git a/faiss/IndexIVFFlat.cpp b/faiss/IndexIVFFlat.cpp index be9e07f69..e8834e864 100644 --- a/faiss/IndexIVFFlat.cpp +++ b/faiss/IndexIVFFlat.cpp @@ -9,6 +9,8 @@ #include +#include + #include #include @@ -47,19 +49,35 @@ void IndexIVFFlat::add_core( direct_map.check_can_add(xids); int64_t n_add = 0; - for (size_t i = 0; i < n; i++) { - idx_t id = xids ? xids[i] : ntotal + i; - idx_t list_no = coarse_idx[i]; - size_t offset; - if (list_no >= 0) { - const float* xi = x + i * d; - offset = invlists->add_entry(list_no, id, (const uint8_t*)xi); - n_add++; - } else { - offset = 0; +#pragma omp parallel reduction(+ : n_add) + { + int nt = omp_get_num_threads(); + int rank = omp_get_thread_num(); + + // each thread takes care of a subset of lists + for (size_t i = 0; i < n; i++) { + idx_t list_no = coarse_idx[i]; + + if (list_no % nt != rank) { + continue; + } + + idx_t id = xids ? xids[i] : ntotal + i; + size_t offset; + + if (list_no >= 0) { + const float* xi = x + i * d; + offset = invlists->add_entry(list_no, id, (const uint8_t*)xi); + n_add++; + } else { + offset = 0; + } + +#pragma omp critical + // executed by one thread at a time + direct_map.add_single_id(id, list_no, offset); } - direct_map.add_single_id(id, list_no, offset); } if (verbose) { @@ -249,38 +267,50 @@ void IndexIVFFlatDedup::add_with_ids( quantizer->assign(na, x, idx); int64_t n_add = 0, n_dup = 0; - // TODO make a omp loop with this - for (size_t i = 0; i < na; i++) { - idx_t id = xids ? xids[i] : ntotal + i; - int64_t list_no = idx[i]; - if (list_no < 0) { - continue; - } - const float* xi = x + i * d; +#pragma omp parallel reduction(+ : n_add, n_dup) + { + int nt = omp_get_num_threads(); + int rank = omp_get_thread_num(); - // search if there is already an entry with that id - InvertedLists::ScopedCodes codes(invlists, list_no); + // each thread takes care of a subset of lists + for (size_t i = 0; i < na; i++) { + int64_t list_no = idx[i]; - int64_t n = invlists->list_size(list_no); - int64_t offset = -1; - for (int64_t o = 0; o < n; o++) { - if (!memcmp(codes.get() + o * code_size, xi, code_size)) { - offset = o; - break; + if (list_no < 0 || list_no % nt != rank) { + continue; } - } - if (offset == -1) { // not found - invlists->add_entry(list_no, id, (const uint8_t*)xi); - } else { - // mark equivalence - idx_t id2 = invlists->get_single_id(list_no, offset); - std::pair pair(id2, id); - instances.insert(pair); - n_dup++; + idx_t id = xids ? xids[i] : ntotal + i; + const float* xi = x + i * d; + + // search if there is already an entry with that id + InvertedLists::ScopedCodes codes(invlists, list_no); + + int64_t n = invlists->list_size(list_no); + int64_t offset = -1; + for (int64_t o = 0; o < n; o++) { + if (!memcmp(codes.get() + o * code_size, xi, code_size)) { + offset = o; + break; + } + } + + if (offset == -1) { // not found + invlists->add_entry(list_no, id, (const uint8_t*)xi); + } else { + // mark equivalence + idx_t id2 = invlists->get_single_id(list_no, offset); + std::pair pair(id2, id); + +#pragma omp critical + // executed by one thread at a time + instances.insert(pair); + + n_dup++; + } + n_add++; } - n_add++; } if (verbose) { printf("IndexIVFFlat::add_with_ids: added %" PRId64 " / %" PRId64