mirror of https://github.com/facebookresearch/faiss.git synced 2025-06-03 21:54:02 +08:00
faiss/benchs/bench_quantizer.py
Chengqi Deng c087f87730 Add LocalSearchQuantizer ()
Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/1906

This PR implemented LSQ/LSQ++, a vector quantization technique described in the following two papers:

1. Revisiting additive quantization
2. LSQ++: Lower running time and higher recall in multi-codebook quantization

Here is a benchmark running on SIFT1M for 64 bits encoding:
```
===== lsq:
        mean square error = 17335.390208
        training time: 312.729779958725 s
        encoding time: 244.6277096271515 s
===== pq:
        mean square error = 23743.004672
        training time: 1.1610801219940186 s
        encoding time: 2.636141061782837 s
===== rq:
        mean square error = 20999.737344
        training time: 31.813055515289307 s
        encoding time: 307.51959800720215 s
```
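
To put these figures in perspective, the reported errors can be compared against PQ as the baseline. The snippet below is only a small sketch that reuses the numbers printed above; the dictionary and variable names are illustrative and not part of the benchmark script.

```python
# Mean square errors copied from the benchmark output above (SIFT1M, 64-bit codes).
reported_mse = {"lsq": 17335.390208, "pq": 23743.004672, "rq": 20999.737344}

# Relative error reduction of each quantizer with respect to PQ.
baseline = reported_mse["pq"]
reduction = {name: 1.0 - mse / baseline for name, mse in reported_mse.items()}

for name, r in reduction.items():
    print(f"{name}: {r:.1%} lower error than pq")
```

On these numbers LSQ lowers the error by roughly 27% and RQ by roughly 12% relative to PQ, at the cost of the much longer training and encoding times listed above.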

Changes:

1. Add LocalSearchQuantizer object
2. Fix an out of memory bug in ResidualQuantizer
3. Add a benchmark for evaluating quantizers
4. Add tests for LocalSearchQuantizer

Pull Request resolved: https://github.com/facebookresearch/faiss/pull/1862

Test Plan:
```
buck test //faiss/tests/:test_lsq

buck run mode/opt //faiss/benchs/:bench_quantizer -- lsq pq rq
```

Reviewed By: beauby

Differential Revision: D28376369

Pulled By: mdouze

fbshipit-source-id: 2a394d38bf75b9de0a1c2cd6faddf7dd362a6fa8
2021-05-21 01:33:55 -07:00

65 lines
1.5 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import sys
import faiss
import time
try:
    from faiss.contrib.datasets_fb import DatasetSIFT1M
except ImportError:
    from faiss.contrib.datasets import DatasetSIFT1M


def eval_codec(q, xb):
    """Encode and decode xb; return (squared error per vector, encoding time in s)."""
    t0 = time.time()
    codes = q.compute_codes(xb)
    t1 = time.time()
    decoded = q.decode(codes)
    return ((xb - decoded) ** 2).sum() / xb.shape[0], t1 - t0


def eval_quantizer(q, xb, xt, name):
    """Train q on xt, then print its reconstruction error and timings on xb."""
    t0 = time.time()
    q.train(xt)
    t1 = time.time()
    train_t = t1 - t0
    err, encode_t = eval_codec(q, xb)
    print(f'===== {name}:')
    print(f'\tmean square error = {err}')
    print(f'\ttraining time: {train_t} s')
    print(f'\tencoding time: {encode_t} s')


todo = sys.argv[1:]
ds = DatasetSIFT1M()
xq = ds.get_queries()
xb = ds.get_database()
gt = ds.get_groundtruth()
xt = ds.get_train()
nb, d = xb.shape
nq, d = xq.shape
nt, d = xt.shape
M = 8
nbits = 8

if 'lsq' in todo:
    lsq = faiss.LocalSearchQuantizer(d, M, nbits)
    lsq.log_level = 2  # show detailed training progress
    eval_quantizer(lsq, xb, xt, 'lsq')

if 'pq' in todo:
    pq = faiss.ProductQuantizer(d, M, nbits)
    eval_quantizer(pq, xb, xt, 'pq')

if 'rq' in todo:
    rq = faiss.ResidualQuantizer(d, M, nbits)
    rq.train_type = faiss.ResidualQuantizer.Train_default
    rq.verbose = True
    eval_quantizer(rq, xb, xt, 'rq')
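
The script reports the squared reconstruction error summed over dimensions and averaged over vectors, and its `M = 8`, `nbits = 8` setting corresponds to the 64-bit codes quoted in the commit message. The sketch below reproduces both calculations without faiss, purely as an illustration: `code_size_bits`, `ToyRoundingQuantizer`, and `mean_squared_error` are hypothetical names, not part of the faiss API, and the toy quantizer merely rounds each component to the nearest integer.

```python
def code_size_bits(M, nbits):
    """Bits per encoded vector when M codebooks each use nbits-bit indices."""
    return M * nbits


class ToyRoundingQuantizer:
    """Hypothetical stand-in exposing the two methods eval_codec relies on."""

    def compute_codes(self, vectors):
        # "Encode" by rounding every component to the nearest integer.
        return [[round(v) for v in vec] for vec in vectors]

    def decode(self, codes):
        # "Decode" by converting the integer codes back to floats.
        return [[float(c) for c in code] for code in codes]


def mean_squared_error(q, vectors):
    """Sum of squared reconstruction errors divided by the number of vectors."""
    decoded = q.decode(q.compute_codes(vectors))
    total = sum(
        (a - b) ** 2
        for vec, dec in zip(vectors, decoded)
        for a, b in zip(vec, dec)
    )
    return total / len(vectors)


vectors = [[0.25, 1.0], [2.0, 2.75]]
bits = code_size_bits(8, 8)
err = mean_squared_error(ToyRoundingQuantizer(), vectors)
print(bits, err)
```

As in `eval_codec`, the error is not divided by the dimension `d`, so values are only comparable across quantizers evaluated on the same dataset.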