mirror of https://github.com/facebookresearch/faiss.git (synced 2025-06-03 21:54:02 +08:00)

Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2439 Index wrapper that performs rowwise normalization to [0,1], preserving the coefficients. This is a vector codec index only. Basically, this index performs a rowwise scaling to [0,1] of every row in an input dataset before calling subindex::train() and subindex::sa_encode(). sa_encode() call stores the scaling coefficients (scaler and minv) in the very beginning of every output code. The format: [scaler][minv][subindex::sa_encode() output] The de-scaling in sa_decode() is done using: output_rescaled = scaler * output + minv An additional ::train_inplace() function is provided in order to do an inplace scaling before calling subindex::train() and, thus, avoiding the cloning of the input dataset, but modifying the input dataset because of the scaling and the scaling back. Derived classes provide different data types for scaling coefficients. Currently, versions with fp16 and fp32 scaling coefficients are available. * fp16 version adds 4 extra bytes per encoded vector * fp32 version adds 8 extra bytes per encoded vector

Reviewed By: mdouze
Differential Revision: D38581012
fbshipit-source-id: d739878f1db62ac5ab9e0db3f84aeb2b70a1b6c0
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
|
|
|
|
import faiss
|
|
import unittest
|
|
|
|
from common_faiss_tests import get_dataset_2
|
|
|
|
|
|
class TestIndexRowwiseMinmax(unittest.TestCase):
    """Tests for the rowwise min-max normalizing codec wrappers.

    The "MinMax" / "MinMaxFP16" factory prefixes wrap a sub-codec and rescale
    every input row to [0, 1] before encoding, storing the per-row scaling
    coefficients (fp32 or fp16, respectively) at the start of each code.
    """

    def compare_train_vs_train_inplace(self, factory_key: str) -> None:
        """Verify .train() and .train_inplace() yield identical codes.

        Also checks that train_inplace() really modifies its input array
        (that is its documented contract) and that the codec's reconstruction
        error stays within a sane bound.

        :param factory_key: index_factory string for the codec under test,
            e.g. "MinMax,SQ8" or "MinMaxFP16,SQ8".
        """
        d = 96      # vector dimensionality
        nb = 1000   # number of database vectors to encode
        nq = 0      # no query vectors needed for a codec-only test
        nt = 2000   # number of training vectors

        xt, x, _ = get_dataset_2(d, nt, nb, nq)

        # Guard against a degenerate dataset; use a unittest assertion so the
        # check is not stripped under `python -O`.
        self.assertGreater(x.size, 0)

        codec = faiss.index_factory(d, factory_key)

        # Path 1: the regular .train(), which clones the training set
        # internally before scaling.
        codec.train(xt)
        codes_train = codec.sa_encode(x)

        decoded = codec.sa_decode(codes_train)

        # Path 2: .train_inplace() scales its argument in place, so train on
        # a copy to keep the original xt intact for the comparison below.
        xt_cloned = np.copy(xt)
        codec.train_inplace(xt_cloned)
        codes_train_inplace = codec.sa_encode(x)

        # Both training paths must produce byte-identical codes.
        n_diff = (codes_train != codes_train_inplace).sum()
        self.assertEqual(n_diff, 0)

        # train_inplace() must have mutated its input (scaling leaves
        # round-off residue even after scaling back).
        n_diff_xt = (xt_cloned != xt).sum()
        self.assertNotEqual(n_diff_xt, 0)

        # The round-trip reconstruction error must not be crazy; 0.6 is an
        # empirical bound for SQ8 on this dataset.
        reconstruction_err = ((x - decoded) ** 2).sum()
        self.assertLess(reconstruction_err, 0.6)

    def test_fp32(self) -> None:
        """fp32 scaling coefficients (8 extra bytes per encoded vector)."""
        self.compare_train_vs_train_inplace("MinMax,SQ8")

    def test_fp16(self) -> None:
        """fp16 scaling coefficients (4 extra bytes per encoded vector)."""
        self.compare_train_vs_train_inplace("MinMaxFP16,SQ8")