faiss/benchs/bench_pq_tables.py
Matthijs Douze e1adde0d84 Faster brute force search (#1502)
Summary:
This diff streamlines the code that collects results for brute force distance computations for the L2 / IP and range search / knn search combinations.

It introduces a `ResultHandler` template class that abstracts what happens with the computed distances and ids. In addition to the heap result handler and the range search result handler, it introduces a reservoir result handler that improves the search speed for large k (>= 100).
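
To make the reservoir idea concrete, here is an illustrative Python sketch (not the actual C++ `ResultHandler` code; the class name `ReservoirTopK` is made up for this example): accepted candidates are appended to a buffer of roughly 2k entries, and only when the buffer overflows is it shrunk back to the k best and the pruning threshold tightened, which avoids a heap update on every insertion.

```
import numpy as np

class ReservoirTopK:
    """Illustrative top-k collector: buffer candidates, prune only on overflow."""

    def __init__(self, k, capacity=None):
        self.k = k
        self.capacity = capacity if capacity is not None else 2 * k
        self.dis = []            # candidate distances
        self.ids = []            # candidate ids
        self.threshold = np.inf  # only candidates below this are kept

    def add(self, distance, idx):
        if distance >= self.threshold:
            return
        self.dis.append(distance)
        self.ids.append(idx)
        if len(self.dis) > self.capacity:
            self._shrink()

    def _shrink(self):
        # keep the k best candidates and tighten the pruning threshold
        keep = np.argpartition(np.array(self.dis), self.k - 1)[:self.k]
        self.dis = [self.dis[i] for i in keep]
        self.ids = [self.ids[i] for i in keep]
        self.threshold = max(self.dis)

    def result(self):
        order = np.argsort(self.dis)[:self.k]
        return [self.dis[i] for i in order], [self.ids[i] for i in order]
```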

Benchmark results (https://fb.quip.com/y0g1ACLEqJXx#OCaACA2Gm45) show that on small datasets (10k vectors) search is 10-50% faster, with larger improvements for small k. There is room for improvement in the reservoir handler, whose current implementation is quite naive, but the diff is already useful in its current form.

Experiments with precomputed db vector norms for L2 distance computations were not conclusive performance-wise, so that implementation is removed from IndexFlatL2.
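
For context, the precomputed-norms idea relies on the expansion ||x - y||^2 = ||x||^2 - 2<x, y> + ||y||^2, where the database-side term ||y||^2 can be cached at add time so that search reduces to a matrix multiplication plus a cheap correction. A minimal numpy sketch of that idea (the function name is hypothetical; this is not the faiss API):

```
import numpy as np

def l2_sqr_with_precomputed_norms(xq, xb, xb_norms):
    """Squared L2 distances using cached database norms.

    xb_norms = (xb ** 2).sum(axis=1) is computed once when vectors are added.
    """
    xq_norms = (xq ** 2).sum(axis=1, keepdims=True)  # (nq, 1)
    dot = xq @ xb.T                                  # (nq, nb), the expensive part
    return xq_norms - 2 * dot + xb_norms             # broadcasts xb_norms over rows
```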

This diff also removes IndexL2BaseShift, which was never used.

Pull Request resolved: https://github.com/facebookresearch/faiss/pull/1502

Test Plan:
```
buck test //faiss/tests/:test_product_quantizer
buck test //faiss/tests/:test_index -- TestIndexFlat
```

Reviewed By: wickedfoo

Differential Revision: D24705464

Pulled By: mdouze

fbshipit-source-id: 270e10b19f3c89ed7b607ec30549aca0ac5027fe
2020-11-04 22:16:23 -08:00


#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import time
import os

import numpy as np
import faiss

os.system("grep -m1 'model name' < /proc/cpuinfo")


def format_tab(x):
    return "\n".join("\t".join("%g" % xi for xi in row) for row in x)


def run_bench(d, dsub, nbit=8, metric=None):
    """Benchmark PQ distance-table computation for one (d, dsub, nbit) setting."""
    M = d // dsub
    pq = faiss.ProductQuantizer(d, M, nbit)
    pq.train(faiss.randn((max(1000, pq.ksub * 50), d), 123))
    sp = faiss.swig_ptr
    nrun = 100
    print(f"d={d} dsub={dsub} ksub={pq.ksub}", end="\t")
    res = []
    for nx in 1, 10, 100:
        x = faiss.randn((nx, d), 555)
        times = []
        for run in range(nrun):
            t0 = time.time()
            new_tab = np.zeros((nx, M, pq.ksub), "float32")
            if metric == faiss.METRIC_INNER_PRODUCT:
                pq.compute_inner_prod_tables(nx, sp(x), sp(new_tab))
            elif metric == faiss.METRIC_L2:
                pq.compute_distance_tables(nx, sp(x), sp(new_tab))
            else:
                assert False
            t1 = time.time()
            if run >= nrun // 5:  # the first nrun // 5 runs are warmup
                times.append(t1 - t0)
        times = np.array(times) * 1000
        print(f"nx={nx}: {np.mean(times):.3f} ms (± {np.std(times):.4f})",
              end="\t")
        res.append(times.mean())
    print()
    return res


# for have_threads in True, False:
for have_threads in False, True:
    if have_threads:
        # good config for Intel(R) Xeon(R) CPU E5-2698 v4 @ 2.20GHz
        nthread = 32
    else:
        nthread = 1
    faiss.omp_set_num_threads(nthread)

    for metric in faiss.METRIC_INNER_PRODUCT, faiss.METRIC_L2:
        print("============= nthread=", nthread, "metric=", metric)
        allres = []
        for dsub in 2, 4, 8:
            for nbit in 4, 8:
                for M in 8, 20:
                    res = run_bench(M * dsub, dsub, nbit, metric)
                    allres.append(res)
        allres = np.array(allres)
        print("formatted results:")
        print(format_tab(allres))