1
0
mirror of https://github.com/facebookresearch/faiss.git synced 2025-06-03 21:54:02 +08:00
faiss/benchs/bench_quantizer.py
Check Deng 48ae55348a Update codebooks with double type (#1975)
Summary:
## Description

The process of updating the codebook in LSQ may be unstable if the data is not zero-centered. This diff fixes it by using `double` instead of `float` during codebook updating. This does not affect performance, since the update process is quite fast.

Users can switch back to `float` mode by setting `update_codebooks_with_double = False`.

## Changes

1. Support `double` during codebook updating.
2. Add a unit test.
3. Add `__init__.py` under `contrib/` to avoid warnings.

Pull Request resolved: https://github.com/facebookresearch/faiss/pull/1975

Reviewed By: wickedfoo

Differential Revision: D29565632

Pulled By: mdouze

fbshipit-source-id: 932d7932ae9725c299cd83f87495542703ad6654
2021-07-07 03:29:49 -07:00

65 lines
1.4 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import sys
import time

import faiss

try:
    from faiss.contrib.datasets_fb import DatasetSIFT1M
except ImportError:
    from faiss.contrib.datasets import DatasetSIFT1M
def eval_codec(q, xb):
    """Encode and decode `xb` with quantizer `q` and measure the result.

    Args:
        q: object exposing `compute_codes(x)` and `decode(codes)`.
        xb: array of vectors to encode (presumably one vector per row);
            it must support `.shape`, subtraction with the decoded output,
            `** 2` and `.sum()`.

    Returns:
        A `(err, encode_t)` tuple: the sum of squared reconstruction
        errors divided by `xb.shape[0]`, and the seconds spent inside
        `compute_codes` only (decoding is not timed).
    """
    # perf_counter is the clock intended for measuring intervals; unlike
    # time.time() it is not affected by wall-clock adjustments.
    t0 = time.perf_counter()
    codes = q.compute_codes(xb)
    encode_t = time.perf_counter() - t0

    decoded = q.decode(codes)
    err = ((xb - decoded) ** 2).sum() / xb.shape[0]
    return err, encode_t
def eval_quantizer(q, xb, xt, name):
    """Train quantizer `q` on `xt`, evaluate it on `xb`, and print a report.

    Args:
        q: object exposing `train(x)` plus the `compute_codes` / `decode`
            methods that `eval_codec` calls.
        xb: vectors encoded for the reconstruction-error measurement.
        xt: training vectors passed to `q.train`.
        name: label printed in the report header.

    Returns:
        None. The error, training time and encoding time are printed to
        standard output.
    """
    # perf_counter is the clock intended for measuring intervals; unlike
    # time.time() it is not affected by wall-clock adjustments.
    t0 = time.perf_counter()
    q.train(xt)
    train_t = time.perf_counter() - t0

    err, encode_t = eval_codec(q, xb)

    print(f'===== {name}:')
    print(f'\tmean square error = {err}')
    print(f'\ttraining time: {train_t} s')
    print(f'\tencoding time: {encode_t} s')
# Quantizers to benchmark are selected by name on the command line
# (any of 'lsq', 'pq', 'rq'); names not listed there are skipped.
todo = sys.argv[1:]

ds = DatasetSIFT1M()
# Only the database and training vectors are used below, so the queries
# and ground truth are not loaded.
xb = ds.get_database()
xt = ds.get_train()
d = xt.shape[1]

# Forwarded as the 2nd and 3rd constructor arguments of every quantizer
# (presumably number of sub-quantizers and bits per index; see faiss docs).
M = 8
nbits = 8

if 'lsq' in todo:
    lsq = faiss.LocalSearchQuantizer(d, M, nbits)
    lsq.verbose = True
    eval_quantizer(lsq, xb, xt, 'lsq')

if 'pq' in todo:
    pq = faiss.ProductQuantizer(d, M, nbits)
    eval_quantizer(pq, xb, xt, 'pq')

if 'rq' in todo:
    rq = faiss.ResidualQuantizer(d, M, nbits)
    rq.train_type = faiss.ResidualQuantizer.Train_default
    rq.verbose = True
    eval_quantizer(rq, xb, xt, 'rq')