From a3fbf2d61c7d4cb24e1c3f08ba08f2a575196c30 Mon Sep 17 00:00:00 2001 From: Matthijs Douze Date: Fri, 4 Aug 2023 06:51:06 -0700 Subject: [PATCH] Better NaN handling (#2986) Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2986 A NaN vector is a vector with at least one NaN (not-a-number) entry. After discussion in the Faiss team we decided that: - training should throw an exception on NaN vectors - added NaN vectors should be ignored (never returned) - searched NaN vectors should return only -1s This diff implements this for a few common index types + adds relevant tests. Reviewed By: algoriddle Differential Revision: D48031390 fbshipit-source-id: 99e7786582e91950e3a53c1d8bcffdd00b6afd24 --- faiss/impl/HNSW.cpp | 25 +++-- faiss/impl/ResultHandler.h | 7 +- faiss/impl/ScalarQuantizer.cpp | 5 + tests/test_error_reporting.py | 182 +++++++++++++++++++++++++++++++++ tests/test_index.py | 104 ++----------------- 5 files changed, 215 insertions(+), 108 deletions(-) create mode 100644 tests/test_error_reporting.py diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp index 9fc201ea3..bef961353 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp @@ -9,6 +9,7 @@ #include +#include #include #include @@ -542,12 +543,11 @@ int search_from_candidates( for (int i = 0; i < candidates.size(); i++) { idx_t v1 = candidates.ids[i]; float d = candidates.dis[i]; - FAISS_ASSERT(v1 >= 0); + assert(v1 >= 0); if (!sel || sel->is_member(v1)) { - if (nres < k) { - faiss::maxheap_push(++nres, D, I, d, v1); - } else if (d < D[0]) { - faiss::maxheap_replace_top(nres, D, I, d, v1); + if (d < D[0]) { + faiss::maxheap_replace_top(k, D, I, d, v1); + nres++; } } vt.set(v1); @@ -612,10 +612,9 @@ int search_from_candidates( auto add_to_heap = [&](const size_t idx, const float dis) { if (!sel || sel->is_member(idx)) { - if (nres < k) { - faiss::maxheap_push(++nres, D, I, dis, idx); - } else if (dis < D[0]) { - faiss::maxheap_replace_top(nres, D, I, dis, idx); + if (dis < D[0]) { + faiss::maxheap_replace_top(k, D, I, dis, idx); + nres++; } } candidates.push(idx, dis); @@ -668,7 +667,7 @@ int search_from_candidates( stats.n3 += ndis; } - return nres; + return std::min(nres, k); } std::priority_queue search_from_candidate_unbounded( @@ -816,6 +815,11 @@ HNSWStats HNSW::search( // greedy search on upper levels storage_idx_t nearest = entry_point; float d_nearest = qdis(nearest); + if (!std::isfinite(d_nearest)) { + // means either the query or the entry point are NaN: in + // both cases we can only return -1 as a result + return stats; + } for (int level = max_level; level >= 1; level--) { greedy_update_nearest(*this, qdis, level, nearest, d_nearest); @@ -826,7 +830,6 @@ HNSWStats HNSW::search( MinimaxHeap candidates(ef); candidates.push(nearest, d_nearest); - search_from_candidates( *this, qdis, k, I, D, candidates, vt, stats, 0, 0, params); } else { diff --git a/faiss/impl/ResultHandler.h b/faiss/impl/ResultHandler.h index d096fbcfa..945f68cf9 100644 --- a/faiss/impl/ResultHandler.h +++ b/faiss/impl/ResultHandler.h @@ -445,8 +445,8 @@ struct SingleBestResultHandler { /// begin results for query # i void begin(const size_t current_idx) { this->current_idx = current_idx; - min_dis = HUGE_VALF; - min_idx = 0; + min_dis = C::neutral(); + min_idx = -1; } /// add one result for query i @@ -472,7 +472,8 @@ struct SingleBestResultHandler { this->i1 = i1; for (size_t i = i0; i < i1; i++) { - this->dis_tab[i] = HUGE_VALF; + this->dis_tab[i] = C::neutral(); + this->ids_tab[i] = -1; } } diff --git a/faiss/impl/ScalarQuantizer.cpp b/faiss/impl/ScalarQuantizer.cpp index 680a3bc05..b1da370e6 100644 --- a/faiss/impl/ScalarQuantizer.cpp +++ b/faiss/impl/ScalarQuantizer.cpp @@ -1075,6 +1075,11 @@ void ScalarQuantizer::set_derived_sizes() { } void ScalarQuantizer::train(size_t n, const float* x) { + for (size_t i = 0; i < n * d; i++) { + FAISS_THROW_IF_NOT_MSG( + std::isfinite(x[i]), "training data contains NaN or Inf"); + } + int bit_per_dim = qtype == QT_4bit_uniform ? 4 : qtype == QT_4bit ? 4 : qtype == QT_6bit ? 6 diff --git a/tests/test_error_reporting.py b/tests/test_error_reporting.py new file mode 100644 index 000000000..04a023a36 --- /dev/null +++ b/tests/test_error_reporting.py @@ -0,0 +1,182 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""This script tests a few failure cases of Faiss and whether they are handled +properly.""" + +import numpy as np +import unittest +import faiss + +from common_faiss_tests import get_dataset_2 +from faiss.contrib.datasets import SyntheticDataset + + +class TestValidIndexParams(unittest.TestCase): + + def test_IndexIVFPQ(self): + d = 32 + nb = 1000 + nt = 1500 + nq = 200 + + (xt, xb, xq) = get_dataset_2(d, nt, nb, nq) + + coarse_quantizer = faiss.IndexFlatL2(d) + index = faiss.IndexIVFPQ(coarse_quantizer, d, 32, 8, 8) + index.cp.min_points_per_centroid = 5 # quiet warning + index.train(xt) + index.add(xb) + + # invalid nprobe + index.nprobe = 0 + k = 10 + self.assertRaises(RuntimeError, index.search, xq, k) + + # invalid k + index.nprobe = 4 + k = -10 + self.assertRaises(AssertionError, index.search, xq, k) + + # valid params + index.nprobe = 4 + k = 10 + D, nns = index.search(xq, k) + + self.assertEqual(D.shape[0], nq) + self.assertEqual(D.shape[1], k) + + def test_IndexFlat(self): + d = 32 + nb = 1000 + nt = 0 + nq = 200 + + (xt, xb, xq) = get_dataset_2(d, nt, nb, nq) + index = faiss.IndexFlat(d, faiss.METRIC_L2) + + index.add(xb) + + # invalid k + k = -5 + self.assertRaises(AssertionError, index.search, xq, k) + + # valid k + k = 5 + D, I = index.search(xq, k) + + self.assertEqual(D.shape[0], nq) + self.assertEqual(D.shape[1], k) + + +class TestReconsException(unittest.TestCase): + + def test_recons_exception(self): + + d = 64 # dimension + nb = 1000 + rs = np.random.RandomState(1234) + xb = rs.rand(nb, d).astype('float32') + nlist = 10 + quantizer = faiss.IndexFlatL2(d) # the other index + index = faiss.IndexIVFFlat(quantizer, d, nlist) + index.train(xb) + index.add(xb) + index.make_direct_map() + + index.reconstruct(9) + + self.assertRaises( + RuntimeError, + index.reconstruct, 100001 + ) + + def test_reconstuct_after_add(self): + index = faiss.index_factory(10, 'IVF5,SQfp16') + index.train(faiss.randn((100, 10), 123)) + index.add(faiss.randn((100, 10), 345)) + index.make_direct_map() + index.add(faiss.randn((100, 10), 678)) + + # should not raise an exception + index.reconstruct(5) + print(index.ntotal) + index.reconstruct(150) + + +class TestNaN(unittest.TestCase): + """ NaN values handling is transparent: they don't produce results + but should not crash. The tests below cover a few common index types. + """ + + def do_test_train(self, factory_string): + """ NaN and Inf should raise an exception at train time """ + ds = SyntheticDataset(32, 200, 20, 10) + index = faiss.index_factory(ds.d, factory_string) + # try to train with NaNs + xt = ds.get_train().copy() + xt[:, ::4] = np.nan + self.assertRaises(RuntimeError, index.train, xt) + + def test_train_IVFSQ(self): + self.do_test_train("IVF10,SQ8") + + def test_train_IVFPQ(self): + self.do_test_train("IVF10,PQ4np") + + def test_train_SQ(self): + self.do_test_train("SQ8") + + def do_test_add(self, factory_string): + """ stored NaNs should not be returned at search time """ + ds = SyntheticDataset(32, 200, 20, 10) + index = faiss.index_factory(ds.d, factory_string) + if not index.is_trained: + index.train(ds.get_train()) + xb = ds.get_database() + xb[12, 3] = np.nan + index.add(xb) + D, I = index.search(ds.get_queries(), 20) + self.assertTrue(np.where(I == 12)[0].size == 0) + + def test_add_Flat(self): + self.do_test_add("Flat") + + def test_add_HNSW(self): + self.do_test_add("HNSW32,Flat") + + def xx_test_add_SQ8(self): + # this is expected to fail because: + # in ASAN mode, the float NaN -> int conversion crashes + # in opt mode it works but there is no way to encode the NaN, + # so the value cannot be ignored. + self.do_test_add("SQ8") + + def test_add_IVFFlat(self): + self.do_test_add("IVF10,Flat") + + def do_test_search(self, factory_string): + """ NaN query vectors should return -1 """ + ds = SyntheticDataset(32, 200, 20, 10) + index = faiss.index_factory(ds.d, factory_string) + if not index.is_trained: + index.train(ds.get_train()) + index.add(ds.get_database()) + xq = ds.get_queries() + xq[7, 3] = np.nan + D, I = index.search(ds.get_queries(), 20) + self.assertTrue(np.all(I[7] == -1)) + + def test_search_Flat(self): + self.do_test_search("Flat") + + def test_search_HNSW(self): + self.do_test_search("HNSW32,Flat") + + def test_search_IVFFlat(self): + self.do_test_search("IVF10,Flat") + + def test_search_SQ(self): + self.do_test_search("SQ8") diff --git a/tests/test_index.py b/tests/test_index.py index 0e828e08c..e850f5aab 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -4,8 +4,6 @@ # LICENSE file in the root directory of this source tree. """this is a basic test script for simple indices work""" -from __future__ import absolute_import, division, print_function -# no unicode_literals because it messes up in py2 import numpy as np import unittest @@ -13,16 +11,17 @@ import faiss import tempfile import os import re -import warnings from common_faiss_tests import get_dataset, get_dataset_2 + class TestModuleInterface(unittest.TestCase): def test_version_attribute(self): assert hasattr(faiss, '__version__') assert re.match('^\\d+\\.\\d+\\.\\d+$', faiss.__version__) + class TestIndexFlat(unittest.TestCase): def do_test(self, nq, metric_type=faiss.METRIC_L2, k=10): @@ -109,6 +108,14 @@ class TestIndexFlat(unittest.TestCase): def test_with_blas_reservoir_ip(self): self.do_test(200, faiss.METRIC_INNER_PRODUCT, k=150) + def test_noblas_1res(self): + self.do_test(10, k=1) + + def test_with_blas_1res(self): + self.do_test(200, k=1) + + def test_with_blas_1res_ip(self): + self.do_test(200, faiss.METRIC_INNER_PRODUCT, k=1) class TestIndexFlatL2(unittest.TestCase): def test_indexflat_l2_sync_norms_1(self): @@ -1007,41 +1014,6 @@ class TestShardReplicas(unittest.TestCase): index.remove_replica(index1) self.assertEqual(index.ntotal, 0) -class TestReconsException(unittest.TestCase): - - def test_recons_exception(self): - - d = 64 # dimension - nb = 1000 - rs = np.random.RandomState(1234) - xb = rs.rand(nb, d).astype('float32') - nlist = 10 - quantizer = faiss.IndexFlatL2(d) # the other index - index = faiss.IndexIVFFlat(quantizer, d, nlist) - index.train(xb) - index.add(xb) - index.make_direct_map() - - index.reconstruct(9) - - self.assertRaises( - RuntimeError, - index.reconstruct, 100001 - ) - - def test_reconstuct_after_add(self): - index = faiss.index_factory(10, 'IVF5,SQfp16') - index.train(faiss.randn((100, 10), 123)) - index.add(faiss.randn((100, 10), 345)) - index.make_direct_map() - index.add(faiss.randn((100, 10), 678)) - - # should not raise an exception - index.reconstruct(5) - print(index.ntotal) - index.reconstruct(150) - - class TestReconsHash(unittest.TestCase): def do_test(self, index_key): @@ -1113,62 +1085,6 @@ class TestReconsHash(unittest.TestCase): self.do_test("IVF5,PQ4x4np") -class TestValidIndexParams(unittest.TestCase): - - def test_IndexIVFPQ(self): - d = 32 - nb = 1000 - nt = 1500 - nq = 200 - - (xt, xb, xq) = get_dataset_2(d, nt, nb, nq) - - coarse_quantizer = faiss.IndexFlatL2(d) - index = faiss.IndexIVFPQ(coarse_quantizer, d, 32, 8, 8) - index.cp.min_points_per_centroid = 5 # quiet warning - index.train(xt) - index.add(xb) - - # invalid nprobe - index.nprobe = 0 - k = 10 - self.assertRaises(RuntimeError, index.search, xq, k) - - # invalid k - index.nprobe = 4 - k = -10 - self.assertRaises(AssertionError, index.search, xq, k) - - # valid params - index.nprobe = 4 - k = 10 - D, nns = index.search(xq, k) - - self.assertEqual(D.shape[0], nq) - self.assertEqual(D.shape[1], k) - - def test_IndexFlat(self): - d = 32 - nb = 1000 - nt = 0 - nq = 200 - - (xt, xb, xq) = get_dataset_2(d, nt, nb, nq) - index = faiss.IndexFlat(d, faiss.METRIC_L2) - - index.add(xb) - - # invalid k - k = -5 - self.assertRaises(AssertionError, index.search, xq, k) - - # valid k - k = 5 - D, I = index.search(xq, k) - - self.assertEqual(D.shape[0], nq) - self.assertEqual(D.shape[1], k) - class TestLargeRangeSearch(unittest.TestCase):