faiss/tests/test_rowwise_minmax.py
Alexandr Guzhva 1e4586a5a0 IndexRowwiseMinMax (#2439)
Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2439

An index wrapper that performs rowwise normalization of every input vector to [0, 1] and preserves the scaling coefficients. This is a vector codec index only.

This index scales every row of the input dataset to [0, 1] before calling subindex::train() and subindex::sa_encode(). The sa_encode() call stores the scaling coefficients (scaler and minv) at the very beginning of every output code. The format is:
    [scaler][minv][subindex::sa_encode() output]
The de-scaling in sa_decode() is done using:
    output_rescaled = scaler * output + minv
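
For illustration, here is a minimal NumPy sketch of the per-row scaling and de-scaling described above. It is not the actual implementation (the real encoding happens in C++ and also stores the coefficients in the code), and the guard against constant rows is an assumption:

    import numpy as np

    def rowwise_minmax_scale(x):
        # per-row minimum and range ("minv" and "scaler" in the code layout above)
        minv = x.min(axis=1, keepdims=True)
        scaler = x.max(axis=1, keepdims=True) - minv
        scaler[scaler == 0] = 1.0            # assumption: guard against constant rows
        scaled = (x - minv) / scaler         # every row now lies in [0, 1]
        return scaled, scaler, minv

    def rowwise_minmax_descale(scaled, scaler, minv):
        # matches the formula above: output_rescaled = scaler * output + minv
        return scaler * scaled + minv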

An additional ::train_inplace() function is provided that scales the data in place before calling subindex::train(). This avoids cloning the input dataset, but modifies it, because the data is scaled and then scaled back.
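
A hedged usage sketch (assuming a faiss build that includes this diff): train_inplace() avoids an internal copy of the training set but modifies it, so copy beforehand if the original values must stay untouched.

    import numpy as np
    import faiss

    d = 96
    xt = np.random.rand(2000, d).astype('float32')

    codec = faiss.index_factory(d, "MinMax,SQ8")
    xt_work = np.copy(xt)          # keep the original xt pristine
    codec.train_inplace(xt_work)   # xt_work is scaled and scaled back in place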

Derived classes provide different data types for the scaling coefficients. Currently, versions with fp16 and fp32 coefficients are available (see the code-size sketch after this list).
* fp16 version adds 4 extra bytes per encoded vector
* fp32 version adds 8 extra bytes per encoded vector
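
As a quick sanity check of this overhead, here is a sketch using the standalone codec API (not part of the diff; assumes a faiss build that includes this change):

    import numpy as np
    import faiss

    d = 96
    xt = np.random.rand(2000, d).astype('float32')

    for key in ["SQ8", "MinMax,SQ8", "MinMaxFP16,SQ8"]:
        codec = faiss.index_factory(d, key)
        codec.train(xt)
        print(key, codec.sa_code_size())
    # expected, per the overhead listed above: 96, 104 (+8), 100 (+4) bytes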

Reviewed By: mdouze

Differential Revision: D38581012

fbshipit-source-id: d739878f1db62ac5ab9e0db3f84aeb2b70a1b6c0
2022-09-05 06:59:41 -07:00


# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np

import faiss
import unittest

from common_faiss_tests import get_dataset_2


class TestIndexRowwiseMinmax(unittest.TestCase):
    def compare_train_vs_train_inplace(self, factory_key):
        d = 96
        nb = 1000
        nq = 0
        nt = 2000

        xt, x, _ = get_dataset_2(d, nt, nb, nq)

        assert x.size > 0

        codec = faiss.index_factory(d, factory_key)

        # use the regular .train()
        codec.train(xt)
        codes_train = codec.sa_encode(x)

        decoded = codec.sa_decode(codes_train)

        # use .train_inplace()
        xt_cloned = np.copy(xt)
        codec.train_inplace(xt_cloned)
        codes_train_inplace = codec.sa_encode(x)

        # compare .train and .train_inplace codes
        n_diff = (codes_train != codes_train_inplace).sum()
        self.assertEqual(n_diff, 0)

        # make sure that the array used for .train_inplace got affected
        n_diff_xt = (xt_cloned != xt).sum()
        self.assertNotEqual(n_diff_xt, 0)

        # make sure that the reconstruction error is not crazy
        reconstruction_err = ((x - decoded) ** 2).sum()
        print(reconstruction_err)

        self.assertLess(reconstruction_err, 0.6)

    def test_fp32(self) -> None:
        self.compare_train_vs_train_inplace("MinMax,SQ8")

    def test_fp16(self) -> None:
        self.compare_train_vs_train_inplace("MinMaxFP16,SQ8")