Support for parallelization in IVFFastScan over both queries and probes (#2380)

Summary:
For search requests with few queries, or a single query, this PR adds the ability to run threads over both the queries and the different clusters of the IVF. For applications where latency is important, this can **dramatically reduce latency for single-query requests**.

A new implementation (14) is added. It could be merged into implementation 12, but for simplicity in this PR I kept it as a separate function.

Tests are added to cover the new implementation, including tests that specifically cover the case where a single query is used.

In my benchmarks, a substantial reduction in latency is observed for single-query requests.
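As a minimal usage sketch (the index string, sizes, and thread count here are illustrative, not from the PR; `implem` and `omp_set_num_threads` are the same knobs exercised by the benchmark and tests below):

```python
import faiss
import numpy as np

d = 64
xb = np.random.rand(100_000, d).astype('float32')

# an IVFFastScan index: IVF with 1024 lists, PQ 32x4 in the fast-scan layout
index = faiss.index_factory(d, "IVF1024,PQ32x4fs")
index.train(xb)
index.add(xb)

faiss.omp_set_num_threads(8)  # threads are spread over queries *and* probes
index.nprobe = 16
index.implem = 14             # 14 = heap-based handler, 15 = reservoir variant
D, I = index.search(xb[:1], k=10)  # even a single query uses all 8 threads
```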

Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2380

Test Plan:
```
buck test //faiss/tests/:test_fast_scan_ivf -- implem14
buck test //faiss/tests/:test_fast_scan_ivf -- implem15
```

Reviewed By: alexanderguzhva

Differential Revision: D38074577

Pulled By: mdouze

fbshipit-source-id: e7a20b6ea2f9216e0a045764b5d7b7f550ea89fe


@@ -0,0 +1,122 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import faiss
import time
import os
import multiprocessing as mp
import numpy as np
import matplotlib.pyplot as plt

try:
    from faiss.contrib.datasets_fb import \
        DatasetSIFT1M, DatasetDeep1B, DatasetBigANN
except ImportError:
    from faiss.contrib.datasets import \
        DatasetSIFT1M, DatasetDeep1B, DatasetBigANN

# ds = DatasetDeep1B(10**6)
ds = DatasetBigANN(nb_M=50)
# ds = DatasetSIFT1M()

xq = ds.get_queries()
xb = ds.get_database()
gt = ds.get_groundtruth()
xt = ds.get_train()

nb, d = xb.shape
nq, d = xq.shape
nt, d = xt.shape

print('database size {}, dimension {}'.format(nb, d))

k = 64


def eval_recall(index, name, single_query=False):
    t0 = time.time()
    D, I = index.search(xq, k=k)
    t = time.time() - t0
    if single_query:
        # re-run the queries one by one to measure per-request latency
        t0 = time.time()
        for row in range(nq):
            Ds, Is = index.search(xq[row:row + 1], k=k)
            D[row, :] = Ds
            I[row, :] = Is
        t = time.time() - t0
    speed = t * 1000 / nq  # ms per query
    qps = 1000 / speed

    corrects = (gt[:, :1] == I[:, :k]).sum()
    recall = corrects / nq
    print(
        f'\tnprobe {index.nprobe:3d}, 1Recall@{k}: '
        f'{recall:.6f}, speed: {speed:.6f} ms/query'
    )
    return recall, qps


def eval_and_plot(
        name, rescale_norm=True, plot=True, single_query=False,
        implem=None, num_threads=1):
    index = faiss.index_factory(d, name)
    index_path = f"indices/{name}.faissindex"
    os.makedirs("indices", exist_ok=True)  # trained indices are cached here

    if os.path.exists(index_path):
        index = faiss.read_index(index_path)
    else:
        faiss.omp_set_num_threads(mp.cpu_count())
        index.train(xt)
        index.add(xb)
        faiss.write_index(index, index_path)

    # search params
    if hasattr(index, 'rescale_norm'):
        index.rescale_norm = rescale_norm
        name += f"(rescale_norm={rescale_norm})"
    if implem is not None and hasattr(index, 'implem'):
        index.implem = implem
        name += f"(implem={implem})"
    if single_query:
        name += f"(single_query={single_query})"
    if num_threads > 1:
        name += f"(num_threads={num_threads})"
    faiss.omp_set_num_threads(num_threads)

    data = []
    print(f"======{name}")
    for nprobe in 1, 4, 8, 16, 32, 64, 128, 256:
        index.nprobe = nprobe
        recall, qps = eval_recall(index, name, single_query=single_query)
        data.append((recall, qps))

    if plot:
        data = np.array(data)
        plt.plot(data[:, 0], data[:, 1], label=name)  # x - recall, y - qps


M, nlist = 64, 4096

# just for warmup...
# eval_and_plot(f"IVF{nlist},PQ{M}x4fs", plot=False)

# benchmark
plt.figure(figsize=(8, 6), dpi=80)

eval_and_plot(f"IVF{nlist},PQ{M}x4fs", num_threads=8)
eval_and_plot(f"IVF{nlist},PQ{M}x4fs", single_query=True, implem=0, num_threads=8)
eval_and_plot(f"IVF{nlist},PQ{M}x4fs", single_query=True, implem=14, num_threads=8)
eval_and_plot(f"IVF{nlist},PQ{M}x4fs", single_query=True, implem=15, num_threads=8)

plt.title("Indices on Bigann50M")
plt.xlabel("1Recall@{}".format(k))
plt.ylabel("QPS")
plt.legend(bbox_to_anchor=(1.02, 0.1), loc='upper left', borderaxespad=0)
plt.savefig("bench_ivf_fastscan.png", bbox_inches='tight')


@@ -10,6 +10,7 @@
 #include <cassert>
 #include <cinttypes>
 #include <cstdio>
+#include <set>

 #include <omp.h>
@@ -42,7 +43,9 @@ IndexIVFFastScan::IndexIVFFastScan(
         size_t nlist,
         size_t code_size,
         MetricType metric)
-        : IndexIVF(quantizer, d, nlist, code_size, metric) {}
+        : IndexIVF(quantizer, d, nlist, code_size, metric) {
+    FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
+}

 IndexIVFFastScan::IndexIVFFastScan() {
     bbs = 0;
@@ -352,7 +355,7 @@ void IndexIVFFastScan::search_dispatch_implem(
     } else if (impl == 2) {
         search_implem_2<C>(n, x, k, distances, labels, scaler);

-    } else if (impl >= 10 && impl <= 13) {
+    } else if (impl >= 10 && impl <= 15) {
         size_t ndis = 0, nlist_visited = 0;

         if (n < 2) {
@@ -367,6 +370,8 @@ void IndexIVFFastScan::search_dispatch_implem(
                         &ndis,
                         &nlist_visited,
                         scaler);
+            } else if (impl == 14 || impl == 15) {
+                search_implem_14<C>(n, x, k, distances, labels, impl, scaler);
             } else {
                 search_implem_10<C>(
                         n,
@@ -400,35 +405,40 @@ void IndexIVFFastScan::search_dispatch_implem(
                 // LUTs unlikely to be a limiting factor
                 nslice = omp_get_max_threads();
             }
+            if (impl == 14 ||
+                impl == 15) { // this might require slicing if there are too
+                              // many queries (for now we keep this simple)
+                search_implem_14<C>(n, x, k, distances, labels, impl, scaler);
+            } else {
 #pragma omp parallel for reduction(+ : ndis, nlist_visited)
-            for (int slice = 0; slice < nslice; slice++) {
-                idx_t i0 = n * slice / nslice;
-                idx_t i1 = n * (slice + 1) / nslice;
-                float* dis_i = distances + i0 * k;
-                idx_t* lab_i = labels + i0 * k;
-                if (impl == 12 || impl == 13) {
-                    search_implem_12<C>(
-                            i1 - i0,
-                            x + i0 * d,
-                            k,
-                            dis_i,
-                            lab_i,
-                            impl,
-                            &ndis,
-                            &nlist_visited,
-                            scaler);
-                } else {
-                    search_implem_10<C>(
-                            i1 - i0,
-                            x + i0 * d,
-                            k,
-                            dis_i,
-                            lab_i,
-                            impl,
-                            &ndis,
-                            &nlist_visited,
-                            scaler);
-                }
-            }
+                for (int slice = 0; slice < nslice; slice++) {
+                    idx_t i0 = n * slice / nslice;
+                    idx_t i1 = n * (slice + 1) / nslice;
+                    float* dis_i = distances + i0 * k;
+                    idx_t* lab_i = labels + i0 * k;
+                    if (impl == 12 || impl == 13) {
+                        search_implem_12<C>(
+                                i1 - i0,
+                                x + i0 * d,
+                                k,
+                                dis_i,
+                                lab_i,
+                                impl,
+                                &ndis,
+                                &nlist_visited,
+                                scaler);
+                    } else {
+                        search_implem_10<C>(
+                                i1 - i0,
+                                x + i0 * d,
+                                k,
+                                dis_i,
+                                lab_i,
+                                impl,
+                                &ndis,
+                                &nlist_visited,
+                                scaler);
+                    }
+                }
+            }
@@ -922,6 +932,280 @@ void IndexIVFFastScan::search_implem_12(
    *nlist_out = nlist;
}
template <class C, class Scaler>
void IndexIVFFastScan::search_implem_14(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels,
        int impl,
        const Scaler& scaler) const {
    if (n == 0) { // does not work well with reservoir
        return;
    }
    FAISS_THROW_IF_NOT(bbs == 32);

    std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
    std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);

    uint64_t ttg0 = get_cy();
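
    // coarse quantization: find the nprobe closest inverted lists per query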
    quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());

    uint64_t ttg1 = get_cy();
    uint64_t coarse_search_tt = ttg1 - ttg0;

    size_t dim12 = ksub * M2;
    AlignedTable<uint8_t> dis_tables;
    AlignedTable<uint16_t> biases;
    std::unique_ptr<float[]> normalizers(new float[2 * n]);

    compute_LUT_uint8(
            n,
            x,
            coarse_ids.get(),
            coarse_dis.get(),
            dis_tables,
            biases,
            normalizers.get());

    uint64_t ttg2 = get_cy();
    uint64_t lut_compute_tt = ttg2 - ttg1;

    struct QC {
        int qno;     // sequence number of the query
        int list_no; // list to visit
        int rank;    // this is the rank'th result of the coarse quantizer
    };
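
    // with a flat ("2D") LUT there is one table per query, shared by all of
    // its probes; a 3D LUT needs one table per (query, probe) pair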
    bool single_LUT = !lookup_table_is_3d();

    std::vector<QC> qcs;
    {
        int ij = 0;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < nprobe; j++) {
                if (coarse_ids[ij] >= 0) {
                    qcs.push_back(QC{i, int(coarse_ids[ij]), int(j)});
                }
                ij++;
            }
        }
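        // sort so that all (query, probe) pairs targeting the same inverted
        // list become contiguous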
        std::sort(qcs.begin(), qcs.end(), [](const QC& a, const QC& b) {
            return a.list_no < b.list_no;
        });
    }

    struct SE {
        size_t start; // start in the QC vector
        size_t end;   // end in the QC vector
        size_t list_size;
    };
    std::vector<SE> ses;
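
    // cut the sorted (query, list) pairs into chunks (SE) that each visit a
    // single inverted list, with at most qbs2 entries per chunk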
    size_t i0_l = 0;
    while (i0_l < qcs.size()) {
        // find all queries that access this inverted list
        int list_no = qcs[i0_l].list_no;
        size_t i1 = i0_l + 1;

        while (i1 < qcs.size() && i1 < i0_l + qbs2) {
            if (qcs[i1].list_no != list_no) {
                break;
            }
            i1++;
        }

        size_t list_size = invlists->list_size(list_no);
        if (list_size == 0) {
            i0_l = i1;
            continue;
        }
        ses.push_back(SE{i0_l, i1, list_size});
        i0_l = i1;
    }
    uint64_t ttg3 = get_cy();
    uint64_t compute_clusters_tt = ttg3 - ttg2;

    // function to handle the global heap
    using HeapForIP = CMin<float, idx_t>;
    using HeapForL2 = CMax<float, idx_t>;

    auto init_result = [&](float* simi, idx_t* idxi) {
        if (metric_type == METRIC_INNER_PRODUCT) {
            heap_heapify<HeapForIP>(k, simi, idxi);
        } else {
            heap_heapify<HeapForL2>(k, simi, idxi);
        }
    };

    auto add_local_results = [&](const float* local_dis,
                                 const idx_t* local_idx,
                                 float* simi,
                                 idx_t* idxi) {
        if (metric_type == METRIC_INNER_PRODUCT) {
            heap_addn<HeapForIP>(k, simi, idxi, local_dis, local_idx, k);
        } else {
            heap_addn<HeapForL2>(k, simi, idxi, local_dis, local_idx, k);
        }
    };

    auto reorder_result = [&](float* simi, idx_t* idxi) {
        if (metric_type == METRIC_INNER_PRODUCT) {
            heap_reorder<HeapForIP>(k, simi, idxi);
        } else {
            heap_reorder<HeapForL2>(k, simi, idxi);
        }
    };
    uint64_t ttg4 = get_cy();
    uint64_t fn_tt = ttg4 - ttg3;

    size_t ndis = 0;
    size_t nlist_visited = 0;
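
    // parallelize over the SE chunks rather than over queries: each thread
    // accumulates results for the queries it sees into thread-local buffers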
#pragma omp parallel reduction(+ : ndis, nlist_visited)
    {
        // storage for each thread
        std::vector<idx_t> local_idx(k * n);
        std::vector<float> local_dis(k * n);

        // prepare the result handlers
        std::unique_ptr<SIMDResultHandler<C, true>> handler;
        AlignedTable<uint16_t> tmp_distances;

        using HeapHC = HeapHandler<C, true>;
        using ReservoirHC = ReservoirHandler<C, true>;
        using SingleResultHC = SingleResultHandler<C, true>;

        if (k == 1) {
            handler.reset(new SingleResultHC(n, 0));
        } else if (impl == 14) {
            tmp_distances.resize(n * k);
            handler.reset(
                    new HeapHC(n, tmp_distances.get(), local_idx.data(), k, 0));
        } else if (impl == 15) {
            handler.reset(new ReservoirHC(n, 0, k, 2 * k));
        }

        int qbs2 = this->qbs2 ? this->qbs2 : 11;

        std::vector<uint16_t> tmp_bias;
        if (biases.get()) {
            tmp_bias.resize(qbs2);
            handler->dbias = tmp_bias.data();
        }

        uint64_t ttg5 = get_cy();
        uint64_t handler_tt = ttg5 - ttg4;

        std::set<int> q_set;
        uint64_t t_copy_pack = 0, t_scan = 0;
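
        // dynamic scheduling: inverted lists have very uneven sizes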
#pragma omp for schedule(dynamic)
        for (idx_t cluster = 0; cluster < ses.size(); cluster++) {
            uint64_t tt0 = get_cy();
            size_t i0 = ses[cluster].start;
            size_t i1 = ses[cluster].end;
            size_t list_size = ses[cluster].list_size;
            nlist_visited++;
            int list_no = qcs[i0].list_no;

            // re-organize LUTs and biases into the right order
            int nc = i1 - i0;

            std::vector<int> q_map(nc), lut_entries(nc);
            AlignedTable<uint8_t> LUT(nc * dim12);
            memset(LUT.get(), -1, nc * dim12);
            int qbs = pq4_preferred_qbs(nc);

            for (size_t i = i0; i < i1; i++) {
                const QC& qc = qcs[i];
                q_map[i - i0] = qc.qno;
                q_set.insert(qc.qno);
                int ij = qc.qno * nprobe + qc.rank;
                lut_entries[i - i0] = single_LUT ? qc.qno : ij;
                if (biases.get()) {
                    tmp_bias[i - i0] = biases[ij];
                }
            }
            pq4_pack_LUT_qbs_q_map(
                    qbs, M2, dis_tables.get(), lut_entries.data(), LUT.get());

            // access the inverted list
            ndis += (i1 - i0) * list_size;

            InvertedLists::ScopedCodes codes(invlists, list_no);
            InvertedLists::ScopedIds ids(invlists, list_no);

            // prepare the handler
            handler->ntotal = list_size;
            handler->q_map = q_map.data();
            handler->id_map = ids.get();
            uint64_t tt1 = get_cy();
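
            // dispatch to the scan loop matching the handler's dynamic type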
#define DISPATCH(classHC)                                                  \
    if (dynamic_cast<classHC*>(handler.get())) {                           \
        auto* res = static_cast<classHC*>(handler.get());                  \
        pq4_accumulate_loop_qbs(                                           \
                qbs, list_size, M2, codes.get(), LUT.get(), *res, scaler); \
    }
            DISPATCH(HeapHC)
            else DISPATCH(ReservoirHC) else DISPATCH(SingleResultHC)

            uint64_t tt2 = get_cy();
            t_copy_pack += tt1 - tt0;
            t_scan += tt2 - tt1;
        }

        // labels is in-place for HeapHC
        handler->to_flat_arrays(
                local_dis.data(),
                local_idx.data(),
                skip & 16 ? nullptr : normalizers.get());

#pragma omp single
        {
            // we init the results as a heap
            for (idx_t i = 0; i < n; i++) {
                init_result(distances + i * k, labels + i * k);
            }
        }
#pragma omp barrier
#pragma omp critical
        {
            // merge the thread-local results into the global heaps, going
            // over only the queries this thread actually visited
            for (std::set<int>::iterator it = q_set.begin(); it != q_set.end();
                 ++it) {
                add_local_results(
                        local_dis.data() + *it * k,
                        local_idx.data() + *it * k,
                        distances + *it * k,
                        labels + *it * k);
            }
            IVFFastScan_stats.t_copy_pack += t_copy_pack;
            IVFFastScan_stats.t_scan += t_scan;

            if (auto* rh = dynamic_cast<ReservoirHC*>(handler.get())) {
                for (int i = 0; i < 4; i++) {
                    IVFFastScan_stats.reservoir_times[i] += rh->times[i];
                }
            }
        }
#pragma omp barrier
#pragma omp single
        {
            for (idx_t i = 0; i < n; i++) {
                reorder_result(distances + i * k, labels + i * k);
            }
        }
    }

    indexIVF_stats.nq += n;
    indexIVF_stats.ndis += ndis;
    indexIVF_stats.nlist += nlist_visited;
}

void IndexIVFFastScan::reconstruct_from_offset(
        int64_t list_no,
        int64_t offset,


@@ -157,6 +157,17 @@ struct IndexIVFFastScan : IndexIVF {
            size_t* nlist_out,
            const Scaler& scaler) const;

    // implem 14 is multithreaded internally across probes and queries
    template <class C, class Scaler>
    void search_implem_14(
            idx_t n,
            const float* x,
            idx_t k,
            float* distances,
            idx_t* labels,
            int impl,
            const Scaler& scaler) const;

    // reconstruct vectors from packed invlists
    void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
            const override;


@@ -298,8 +298,8 @@ class TestIVFImplem12(unittest.TestCase):
     IMPLEM = 12

-    def do_test(self, by_residual, metric=faiss.METRIC_L2, d=32):
-        ds = datasets.SyntheticDataset(d, 2000, 5000, 200)
+    def do_test(self, by_residual, metric=faiss.METRIC_L2, d=32, nq=200):
+        ds = datasets.SyntheticDataset(d, 2000, 5000, nq)
         index = faiss.index_factory(d, f"IVF32,PQ{d//2}x4np", metric)

         # force coarse quantizer
@@ -350,6 +350,26 @@ class TestIVFImplem12(unittest.TestCase):
     def test_by_residual_odd_dim(self):
         self.do_test(True, d=30)

+    # testing single query
+    def test_no_residual_single_query(self):
+        self.do_test(False, nq=1)
+
+    def test_by_residual_single_query(self):
+        self.do_test(True, nq=1)
+
+    def test_no_residual_ip_single_query(self):
+        self.do_test(False, metric=faiss.METRIC_INNER_PRODUCT, nq=1)
+
+    def test_by_residual_ip_single_query(self):
+        self.do_test(True, metric=faiss.METRIC_INNER_PRODUCT, nq=1)
+
+    def test_no_residual_odd_dim_single_query(self):
+        self.do_test(False, d=30, nq=1)
+
+    def test_by_residual_odd_dim_single_query(self):
+        self.do_test(True, d=30, nq=1)
+

 class TestIVFImplem10(TestIVFImplem12):
     IMPLEM = 10
@@ -363,6 +383,14 @@ class TestIVFImplem13(TestIVFImplem12):
     IMPLEM = 13

+
+class TestIVFImplem14(TestIVFImplem12):
+    IMPLEM = 14
+
+
+class TestIVFImplem15(TestIVFImplem12):
+    IMPLEM = 15
+

 class TestAdd(unittest.TestCase):
     def do_test(self, by_residual=False, metric=faiss.METRIC_L2, d=32, bbs=32):
@@ -536,7 +564,7 @@ class TestIVFAQFastScan(unittest.TestCase):
         # generated programmatically below
         for metric in 'L2', 'IP':
             for byr in True, False:
-                for implem in 0, 10, 11, 12, 13:
+                for implem in 0, 10, 11, 12, 13, 14, 15:
                     self.subtest_accuracy('RQ', 'rq', byr, implem, metric)
                     self.subtest_accuracy('LSQ', 'lsq', byr, implem, metric)
@@ -579,7 +607,7 @@ class TestIVFAQFastScan(unittest.TestCase):
     def xx_test_rescale_accuracy(self):
         for byr in True, False:
-            for implem in 0, 10, 11, 12, 13:
+            for implem in 0, 10, 11, 12, 13, 14, 15:
                 self.subtest_accuracy('RQ', 'rq', byr, implem, 'L2')
                 self.subtest_accuracy('LSQ', 'lsq', byr, implem, 'L2')
@@ -702,7 +730,7 @@ def add_TestIVFAQFastScan_subtest_rescale_accuracy(aq, st, by_residual, implem):
     )

 for byr in True, False:
-    for implem in 0, 10, 11, 12, 13:
+    for implem in 0, 10, 11, 12, 13, 14, 15:
         for mt in 'L2', 'IP':
             add_TestIVFAQFastScan_subtest_accuracy('RQ', 'rq', byr, implem, mt)
             add_TestIVFAQFastScan_subtest_accuracy('LSQ', 'lsq', byr, implem, mt)