faiss/demos/demo_qinco.py
Matthijs Douze dd72e4121d QINCo implementation in CPU Faiss (#3608)
Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/3608

This is a straightforward implementation of QINCo in CPU Faiss, with encoding and decoding capabilities (not training).

For this, we translate a simplified version of some torch classes:

- tensors, restricted to 2D and int32 + float32

- Linear and Embedding layers

Then the QINCoStep and QINCo can just be defined as C++ objects that are copy-constructable.
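
As a rough illustration of the two layer types listed above, here is a minimal NumPy sketch of their forward passes on 2D float32 and int32 arrays. The class names `LinearSketch` and `EmbeddingSketch` are hypothetical, and the weight layout `(out_features, in_features)` is assumed from the usual torch convention. This is not the C++ code added by the PR.

```python
import numpy as np


class LinearSketch:
    """Hypothetical stand-in for a linear layer: y = x @ W.T + b."""

    def __init__(self, weight, bias=None):
        # weight: (out_features, in_features); bias: (out_features,) or None
        self.weight = np.asarray(weight, dtype=np.float32)
        self.bias = None if bias is None else np.asarray(bias, dtype=np.float32)

    def __call__(self, x):
        y = x @ self.weight.T
        if self.bias is not None:
            y = y + self.bias
        return y


class EmbeddingSketch:
    """Hypothetical stand-in for an embedding layer: row lookup by index."""

    def __init__(self, weight):
        # weight: (num_entries, dim)
        self.weight = np.asarray(weight, dtype=np.float32)

    def __call__(self, idx):
        # idx is an (n, 1) int32 array, to stay within the 2D restriction
        return self.weight[idx[:, 0]]


lin = LinearSketch(weight=[[1, 0], [0, 2], [1, 1]], bias=[0, 0, 1])
x = np.array([[1, 2], [3, 4]], dtype=np.float32)
y = lin(x)  # shape (2, 3), float32

emb = EmbeddingSketch(weight=[[0, 0], [1, 1], [2, 2]])
rows = emb(np.array([[2], [0]], dtype=np.int32))  # shape (2, 2), float32
```

Holding the weights as plain arrays inside small value-like objects is what makes such layers easy to copy, which is presumably why the step objects built from them can be copy-constructed.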

Some plumbing is required in the wrapping layers to support the integration. PyTorch tensors are converted to numpy arrays for getting / setting them in C++.

Reviewed By: asadoughi

Differential Revision: D59132952

fbshipit-source-id: eea4856507a5b7c5f219efcf8d19fe56944df088
2024-07-11 02:40:38 -07:00


# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
This demonstrates how to reproduce the QINCo paper results using the Faiss
QINCo implementation. The code loads the reference model because training
is not implemented in Faiss.

Prepare the data with

cd /tmp

# get the reference qinco code
git clone https://github.com/facebookresearch/Qinco.git

# get the data
wget https://dl.fbaipublicfiles.com/QINCo/datasets/bigann/bigann1M.bvecs

# get the model
wget https://dl.fbaipublicfiles.com/QINCo/models/bigann_8x8_L2.pt

"""
import numpy as np
from faiss.contrib.vecs_io import bvecs_mmap
import sys
import time
import torch
import faiss
# make sure pickle deserialization will work
sys.path.append("/tmp/Qinco")
import model_qinco
with torch.no_grad():

    qinco = torch.load("/tmp/bigann_8x8_L2.pt")
    qinco.eval()
    # print(qinco)
    if True:
        torch.set_num_threads(1)
        faiss.omp_set_num_threads(1)

    x_base = bvecs_mmap("/tmp/bigann1M.bvecs")[:1000].astype('float32')
    x_scaled = torch.from_numpy(x_base) / qinco.db_scale

    t0 = time.time()
    codes, _ = qinco.encode(x_scaled)
    x_decoded_scaled = qinco.decode(codes)
    print(f"Pytorch encode {time.time() - t0:.3f} s")
    # multi-thread: 1.13s, single-thread: 7.744

    x_decoded = x_decoded_scaled.numpy() * qinco.db_scale

    err = ((x_decoded - x_base) ** 2).sum(1).mean()
    print("MSE=", err)  # = 14211.956, near the L=2 result in Fig 4 of the paper

    qinco2 = faiss.QINCo(qinco)
    t0 = time.time()
    codes2 = qinco2.encode(faiss.Tensor2D(x_scaled))
    x_decoded2 = qinco2.decode(codes2).numpy() * qinco.db_scale
    print(f"Faiss encode {time.time() - t0:.3f} s")
    # multi-thread: 3.2s, single thread: 7.019

    # these tests don't work because there are outlier encodings
    # np.testing.assert_array_equal(codes.numpy(), codes2.numpy())
    # np.testing.assert_allclose(x_decoded, x_decoded2)

    ndiff = (codes.numpy() != codes2.numpy()).sum() / codes.numel()
    assert ndiff < 0.01

    ndiff = (((x_decoded - x_decoded2) ** 2).sum(1) > 1e-5).sum()
    assert ndiff / len(x_base) < 0.01

    err = ((x_decoded2 - x_base) ** 2).sum(1).mean()
    print("MSE=", err)  # = 14213.551
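
The agreement checks at the end of the demo can be tried without the model or dataset. The sketch below restates them as small NumPy helpers and runs them on tiny made-up arrays. The function names and the toy data are assumptions for illustration only, and the "MSE" follows the demo's own definition: squared L2 error per vector, averaged over vectors.

```python
import numpy as np


def code_mismatch_fraction(codes_a, codes_b):
    """Fraction of individual code entries that differ between two encoders."""
    return float((codes_a != codes_b).mean())


def outlier_fraction(x_a, x_b, tol=1e-5):
    """Fraction of vectors whose squared L2 distance exceeds tol."""
    sq_dist = ((x_a - x_b) ** 2).sum(axis=1)
    return float((sq_dist > tol).mean())


def mean_squared_l2_error(x, x_rec):
    """Squared L2 reconstruction error, averaged over vectors."""
    return float(((x_rec - x) ** 2).sum(axis=1).mean())


# toy data (made up): two vectors, two codes each
codes_a = np.array([[1, 2], [3, 4]], dtype=np.int32)
codes_b = np.array([[1, 2], [3, 5]], dtype=np.int32)
x_a = np.array([[0, 0], [1, 1]], dtype=np.float32)
x_b = np.array([[0, 0], [1, 2]], dtype=np.float32)

print(code_mismatch_fraction(codes_a, codes_b))
print(outlier_fraction(x_a, x_b))
print(mean_squared_l2_error(x_a, x_b))
```

Comparing fractions against a threshold, rather than requiring exact equality, tolerates the occasional outlier encoding that the commented-out `np.testing` assertions would reject.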