faiss/demos/demo_qinco.py
Sergii Dymchenko 383b5d908c Use weights_only for load (#3796)
Summary:
`torch.load` without `weights_only` parameter is unsafe. Explicitly set `weights_only` to False only if you trust the data you load and full pickle functionality is needed, otherwise set `weights_only=True`.

If `weights_only=True` doesn't work for some cases, then explicit `weights_only=False` should be used.

Found with https://github.com/pytorch-labs/torchfix/

Pull Request resolved: https://github.com/facebookresearch/faiss/pull/3796

Reviewed By: asadoughi

Differential Revision: D61824340

Pulled By: kit1980

fbshipit-source-id: bc013d06d4f368f730ffee6898e75fd0b0ff1d40
2024-08-30 12:01:55 -07:00

78 lines
2.3 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
This demonstrates how to reproduce the QINCo paper results using the Faiss
QINCo implementation. The code loads the reference model because training
is not implemented in Faiss.
Prepare the data with
cd /tmp
# get the reference qinco code
git clone https://github.com/facebookresearch/Qinco.git
# get the data
wget https://dl.fbaipublicfiles.com/QINCo/datasets/bigann/bigann1M.bvecs
# get the model
wget https://dl.fbaipublicfiles.com/QINCo/models/bigann_8x8_L2.pt
"""
import numpy as np
from faiss.contrib.vecs_io import bvecs_mmap
import sys
import time
import torch
import faiss
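# make sure pickle deserialization will work: the checkpoint is unpickled as a
# full model object, so the reference Qinco model code must be importable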
sys.path.append("/tmp/Qinco")
import model_qinco
def _row_sq_dist(a, b):
    """Return the squared L2 distance between matching rows of *a* and *b*."""
    return ((a - b) ** 2).sum(1)


# Everything below is inference only, so gradient tracking is switched off.
with torch.no_grad():

    # weights_only=False goes through full pickle deserialization (the loaded
    # object is used directly as the model), so only load a trusted file.
    qinco = torch.load("/tmp/bigann_8x8_L2.pt", weights_only=False)
    qinco.eval()
    # print(qinco)

    # Pin both PyTorch and Faiss to one thread; the timing comments below
    # list both multi-thread and single-thread numbers.
    torch.set_num_threads(1)
    faiss.omp_set_num_threads(1)

    # Only the first 1000 entries of the dataset are used, converted to
    # float32 and divided by the model's db_scale before encoding.
    base_vectors = bvecs_mmap("/tmp/bigann1M.bvecs")[:1000].astype('float32')
    scaled_vectors = torch.from_numpy(base_vectors) / qinco.db_scale

    # --- reference PyTorch model: encode + decode round trip ---
    t0 = time.time()
    torch_codes, _ = qinco.encode(scaled_vectors)
    torch_decoded_scaled = qinco.decode(torch_codes)
    print(f"Pytorch encode {time.time() - t0:.3f} s")
    # multi-thread: 1.13s, single-thread: 7.744

    # Undo the db_scale division so the error is measured in input units.
    torch_decoded = torch_decoded_scaled.numpy() * qinco.db_scale

    mse = _row_sq_dist(torch_decoded, base_vectors).mean()
    print("MSE=", mse)  # = 14211.956, near the L=2 result in Fig 4 of the paper

    # --- Faiss implementation built from the loaded reference model ---
    faiss_qinco = faiss.QINCo(qinco)
    t0 = time.time()
    faiss_codes = faiss_qinco.encode(faiss.Tensor2D(scaled_vectors))
    faiss_decoded = faiss_qinco.decode(faiss_codes).numpy() * qinco.db_scale
    print(f"Faiss encode {time.time() - t0:.3f} s")
    # multi-thread: 3.2s, single thread: 7.019

    # these tests don't work because there are outlier encodings
    # np.testing.assert_array_equal(codes.numpy(), codes2.numpy())
    # np.testing.assert_allclose(x_decoded, x_decoded2)

    # Looser check instead: fewer than 1% of the code entries may differ ...
    code_mismatch_frac = (
        (torch_codes.numpy() != faiss_codes.numpy()).sum()
        / torch_codes.numel()
    )
    assert code_mismatch_frac < 0.01
    # ... and fewer than 1% of the decoded rows may be more than 1e-5 apart
    # (squared distance) between the two implementations.
    n_far_rows = (_row_sq_dist(torch_decoded, faiss_decoded) > 1e-5).sum()
    assert n_far_rows / len(base_vectors) < 0.01

    mse = _row_sq_dist(faiss_decoded, base_vectors).mean()
    print("MSE=", mse)  # = 14213.551