Fix performance regression in ResultHandler (#1840)

Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/1840

This diff is related to

https://github.com/facebookresearch/faiss/issues/1762

The ResultHandler introduced for FlatL2 and FlatIP was not multithreaded. This diff attempts to fix that. To be verified if it is indeed faster.

Reviewed By: wickedfoo

Differential Revision: D27939173

fbshipit-source-id: c85f01a97d4249fe0c6bfb04396b68a7a9fe643d
This commit is contained in:
Matthijs Douze 2021-04-30 00:01:44 -07:00 committed by Facebook GitHub Bot
parent c3842ae5ff
commit 061b68b43a
2 changed files with 9 additions and 6 deletions

View File

@ -92,13 +92,14 @@ struct HeapResultHandler {
/// add results for query i0..i1 and j0..j1 /// add results for query i0..i1 and j0..j1
void add_results(size_t j0, size_t j1, const T* dis_tab) { void add_results(size_t j0, size_t j1, const T* dis_tab) {
// maybe parallel for #pragma omp parallel for
for (size_t i = i0; i < i1; i++) { for (int64_t i = i0; i < i1; i++) {
T* heap_dis = heap_dis_tab + i * k; T* heap_dis = heap_dis_tab + i * k;
TI* heap_ids = heap_ids_tab + i * k; TI* heap_ids = heap_ids_tab + i * k;
const T* dis_tab_i = dis_tab + (j1 - j0) * (i - i0) - j0;
T thresh = heap_dis[0]; T thresh = heap_dis[0];
for (size_t j = j0; j < j1; j++) { for (size_t j = j0; j < j1; j++) {
T dis = *dis_tab++; T dis = dis_tab_i[j];
if (C::cmp(thresh, dis)) { if (C::cmp(thresh, dis)) {
heap_replace_top<C>(k, heap_dis, heap_ids, dis, j); heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
thresh = heap_dis[0]; thresh = heap_dis[0];
@ -281,10 +282,12 @@ struct ReservoirResultHandler {
/// add results for query i0..i1 and j0..j1 /// add results for query i0..i1 and j0..j1
void add_results(size_t j0, size_t j1, const T* dis_tab) { void add_results(size_t j0, size_t j1, const T* dis_tab) {
// maybe parallel for // maybe parallel for
for (size_t i = i0; i < i1; i++) { #pragma omp parallel for
for (int64_t i = i0; i < i1; i++) {
ReservoirTopN<C>& reservoir = reservoirs[i - i0]; ReservoirTopN<C>& reservoir = reservoirs[i - i0];
const T* dis_tab_i = dis_tab + (j1 - j0) * (i - i0) - j0;
for (size_t j = j0; j < j1; j++) { for (size_t j = j0; j < j1; j++) {
T dis = *dis_tab++; T dis = dis_tab_i[j];
reservoir.add(dis, j); reservoir.add(dis, j);
} }
} }

View File

@ -286,7 +286,7 @@ void exhaustive_L2sqr_blas(
ip_block.get(), ip_block.get(),
&nyi); &nyi);
} }
#pragma omp parallel for
for (int64_t i = i0; i < i1; i++) { for (int64_t i = i0; i < i1; i++) {
float* ip_line = ip_block.get() + (i - i0) * (j1 - j0); float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);