Better NaN handling (#2986)
Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2986 A NaN vector is a vector with at least one NaN (not-a-number) entry. After discussion in the Faiss team we decided that: - training should throw an exception on NaN vectors - added NaN vectors should be ignored (never returned) - searched NaN vectors should return only -1s This diff implements this for a few common index types + adds relevant tests. Reviewed By: algoriddle Differential Revision: D48031390 fbshipit-source-id: 99e7786582e91950e3a53c1d8bcffdd00b6afd24pull/2993/head
parent
a4ddb18605
commit
a3fbf2d61c
|
@ -9,6 +9,7 @@
|
|||
|
||||
#include <faiss/impl/HNSW.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <string>
|
||||
|
||||
#include <faiss/impl/AuxIndexStructures.h>
|
||||
|
@ -542,12 +543,11 @@ int search_from_candidates(
|
|||
for (int i = 0; i < candidates.size(); i++) {
|
||||
idx_t v1 = candidates.ids[i];
|
||||
float d = candidates.dis[i];
|
||||
FAISS_ASSERT(v1 >= 0);
|
||||
assert(v1 >= 0);
|
||||
if (!sel || sel->is_member(v1)) {
|
||||
if (nres < k) {
|
||||
faiss::maxheap_push(++nres, D, I, d, v1);
|
||||
} else if (d < D[0]) {
|
||||
faiss::maxheap_replace_top(nres, D, I, d, v1);
|
||||
if (d < D[0]) {
|
||||
faiss::maxheap_replace_top(k, D, I, d, v1);
|
||||
nres++;
|
||||
}
|
||||
}
|
||||
vt.set(v1);
|
||||
|
@ -612,10 +612,9 @@ int search_from_candidates(
|
|||
|
||||
auto add_to_heap = [&](const size_t idx, const float dis) {
|
||||
if (!sel || sel->is_member(idx)) {
|
||||
if (nres < k) {
|
||||
faiss::maxheap_push(++nres, D, I, dis, idx);
|
||||
} else if (dis < D[0]) {
|
||||
faiss::maxheap_replace_top(nres, D, I, dis, idx);
|
||||
if (dis < D[0]) {
|
||||
faiss::maxheap_replace_top(k, D, I, dis, idx);
|
||||
nres++;
|
||||
}
|
||||
}
|
||||
candidates.push(idx, dis);
|
||||
|
@ -668,7 +667,7 @@ int search_from_candidates(
|
|||
stats.n3 += ndis;
|
||||
}
|
||||
|
||||
return nres;
|
||||
return std::min(nres, k);
|
||||
}
|
||||
|
||||
std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
|
||||
|
@ -816,6 +815,11 @@ HNSWStats HNSW::search(
|
|||
// greedy search on upper levels
|
||||
storage_idx_t nearest = entry_point;
|
||||
float d_nearest = qdis(nearest);
|
||||
if (!std::isfinite(d_nearest)) {
|
||||
// means either the query or the entry point are NaN: in
|
||||
// both cases we can only return -1 as a result
|
||||
return stats;
|
||||
}
|
||||
|
||||
for (int level = max_level; level >= 1; level--) {
|
||||
greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
|
||||
|
@ -826,7 +830,6 @@ HNSWStats HNSW::search(
|
|||
MinimaxHeap candidates(ef);
|
||||
|
||||
candidates.push(nearest, d_nearest);
|
||||
|
||||
search_from_candidates(
|
||||
*this, qdis, k, I, D, candidates, vt, stats, 0, 0, params);
|
||||
} else {
|
||||
|
|
|
@ -445,8 +445,8 @@ struct SingleBestResultHandler {
|
|||
/// begin results for query # i
|
||||
void begin(const size_t current_idx) {
|
||||
this->current_idx = current_idx;
|
||||
min_dis = HUGE_VALF;
|
||||
min_idx = 0;
|
||||
min_dis = C::neutral();
|
||||
min_idx = -1;
|
||||
}
|
||||
|
||||
/// add one result for query i
|
||||
|
@ -472,7 +472,8 @@ struct SingleBestResultHandler {
|
|||
this->i1 = i1;
|
||||
|
||||
for (size_t i = i0; i < i1; i++) {
|
||||
this->dis_tab[i] = HUGE_VALF;
|
||||
this->dis_tab[i] = C::neutral();
|
||||
this->ids_tab[i] = -1;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1075,6 +1075,11 @@ void ScalarQuantizer::set_derived_sizes() {
|
|||
}
|
||||
|
||||
void ScalarQuantizer::train(size_t n, const float* x) {
|
||||
for (size_t i = 0; i < n * d; i++) {
|
||||
FAISS_THROW_IF_NOT_MSG(
|
||||
std::isfinite(x[i]), "training data contains NaN or Inf");
|
||||
}
|
||||
|
||||
int bit_per_dim = qtype == QT_4bit_uniform ? 4
|
||||
: qtype == QT_4bit ? 4
|
||||
: qtype == QT_6bit ? 6
|
||||
|
|
|
@ -0,0 +1,182 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""This script tests a few failure cases of Faiss and whether they are handled
|
||||
properly."""
|
||||
|
||||
import numpy as np
|
||||
import unittest
|
||||
import faiss
|
||||
|
||||
from common_faiss_tests import get_dataset_2
|
||||
from faiss.contrib.datasets import SyntheticDataset
|
||||
|
||||
|
||||
class TestValidIndexParams(unittest.TestCase):
|
||||
|
||||
def test_IndexIVFPQ(self):
|
||||
d = 32
|
||||
nb = 1000
|
||||
nt = 1500
|
||||
nq = 200
|
||||
|
||||
(xt, xb, xq) = get_dataset_2(d, nt, nb, nq)
|
||||
|
||||
coarse_quantizer = faiss.IndexFlatL2(d)
|
||||
index = faiss.IndexIVFPQ(coarse_quantizer, d, 32, 8, 8)
|
||||
index.cp.min_points_per_centroid = 5 # quiet warning
|
||||
index.train(xt)
|
||||
index.add(xb)
|
||||
|
||||
# invalid nprobe
|
||||
index.nprobe = 0
|
||||
k = 10
|
||||
self.assertRaises(RuntimeError, index.search, xq, k)
|
||||
|
||||
# invalid k
|
||||
index.nprobe = 4
|
||||
k = -10
|
||||
self.assertRaises(AssertionError, index.search, xq, k)
|
||||
|
||||
# valid params
|
||||
index.nprobe = 4
|
||||
k = 10
|
||||
D, nns = index.search(xq, k)
|
||||
|
||||
self.assertEqual(D.shape[0], nq)
|
||||
self.assertEqual(D.shape[1], k)
|
||||
|
||||
def test_IndexFlat(self):
|
||||
d = 32
|
||||
nb = 1000
|
||||
nt = 0
|
||||
nq = 200
|
||||
|
||||
(xt, xb, xq) = get_dataset_2(d, nt, nb, nq)
|
||||
index = faiss.IndexFlat(d, faiss.METRIC_L2)
|
||||
|
||||
index.add(xb)
|
||||
|
||||
# invalid k
|
||||
k = -5
|
||||
self.assertRaises(AssertionError, index.search, xq, k)
|
||||
|
||||
# valid k
|
||||
k = 5
|
||||
D, I = index.search(xq, k)
|
||||
|
||||
self.assertEqual(D.shape[0], nq)
|
||||
self.assertEqual(D.shape[1], k)
|
||||
|
||||
|
||||
class TestReconsException(unittest.TestCase):
|
||||
|
||||
def test_recons_exception(self):
|
||||
|
||||
d = 64 # dimension
|
||||
nb = 1000
|
||||
rs = np.random.RandomState(1234)
|
||||
xb = rs.rand(nb, d).astype('float32')
|
||||
nlist = 10
|
||||
quantizer = faiss.IndexFlatL2(d) # the other index
|
||||
index = faiss.IndexIVFFlat(quantizer, d, nlist)
|
||||
index.train(xb)
|
||||
index.add(xb)
|
||||
index.make_direct_map()
|
||||
|
||||
index.reconstruct(9)
|
||||
|
||||
self.assertRaises(
|
||||
RuntimeError,
|
||||
index.reconstruct, 100001
|
||||
)
|
||||
|
||||
def test_reconstuct_after_add(self):
|
||||
index = faiss.index_factory(10, 'IVF5,SQfp16')
|
||||
index.train(faiss.randn((100, 10), 123))
|
||||
index.add(faiss.randn((100, 10), 345))
|
||||
index.make_direct_map()
|
||||
index.add(faiss.randn((100, 10), 678))
|
||||
|
||||
# should not raise an exception
|
||||
index.reconstruct(5)
|
||||
print(index.ntotal)
|
||||
index.reconstruct(150)
|
||||
|
||||
|
||||
class TestNaN(unittest.TestCase):
|
||||
""" NaN values handling is transparent: they don't produce results
|
||||
but should not crash. The tests below cover a few common index types.
|
||||
"""
|
||||
|
||||
def do_test_train(self, factory_string):
|
||||
""" NaN and Inf should raise an exception at train time """
|
||||
ds = SyntheticDataset(32, 200, 20, 10)
|
||||
index = faiss.index_factory(ds.d, factory_string)
|
||||
# try to train with NaNs
|
||||
xt = ds.get_train().copy()
|
||||
xt[:, ::4] = np.nan
|
||||
self.assertRaises(RuntimeError, index.train, xt)
|
||||
|
||||
def test_train_IVFSQ(self):
|
||||
self.do_test_train("IVF10,SQ8")
|
||||
|
||||
def test_train_IVFPQ(self):
|
||||
self.do_test_train("IVF10,PQ4np")
|
||||
|
||||
def test_train_SQ(self):
|
||||
self.do_test_train("SQ8")
|
||||
|
||||
def do_test_add(self, factory_string):
|
||||
""" stored NaNs should not be returned at search time """
|
||||
ds = SyntheticDataset(32, 200, 20, 10)
|
||||
index = faiss.index_factory(ds.d, factory_string)
|
||||
if not index.is_trained:
|
||||
index.train(ds.get_train())
|
||||
xb = ds.get_database()
|
||||
xb[12, 3] = np.nan
|
||||
index.add(xb)
|
||||
D, I = index.search(ds.get_queries(), 20)
|
||||
self.assertTrue(np.where(I == 12)[0].size == 0)
|
||||
|
||||
def test_add_Flat(self):
|
||||
self.do_test_add("Flat")
|
||||
|
||||
def test_add_HNSW(self):
|
||||
self.do_test_add("HNSW32,Flat")
|
||||
|
||||
def xx_test_add_SQ8(self):
|
||||
# this is expected to fail because:
|
||||
# in ASAN mode, the float NaN -> int conversion crashes
|
||||
# in opt mode it works but there is no way to encode the NaN,
|
||||
# so the value cannot be ignored.
|
||||
self.do_test_add("SQ8")
|
||||
|
||||
def test_add_IVFFlat(self):
|
||||
self.do_test_add("IVF10,Flat")
|
||||
|
||||
def do_test_search(self, factory_string):
|
||||
""" NaN query vectors should return -1 """
|
||||
ds = SyntheticDataset(32, 200, 20, 10)
|
||||
index = faiss.index_factory(ds.d, factory_string)
|
||||
if not index.is_trained:
|
||||
index.train(ds.get_train())
|
||||
index.add(ds.get_database())
|
||||
xq = ds.get_queries()
|
||||
xq[7, 3] = np.nan
|
||||
D, I = index.search(ds.get_queries(), 20)
|
||||
self.assertTrue(np.all(I[7] == -1))
|
||||
|
||||
def test_search_Flat(self):
|
||||
self.do_test_search("Flat")
|
||||
|
||||
def test_search_HNSW(self):
|
||||
self.do_test_search("HNSW32,Flat")
|
||||
|
||||
def test_search_IVFFlat(self):
|
||||
self.do_test_search("IVF10,Flat")
|
||||
|
||||
def test_search_SQ(self):
|
||||
self.do_test_search("SQ8")
|
|
@ -4,8 +4,6 @@
|
|||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""this is a basic test script for simple indices work"""
|
||||
from __future__ import absolute_import, division, print_function
|
||||
# no unicode_literals because it messes up in py2
|
||||
|
||||
import numpy as np
|
||||
import unittest
|
||||
|
@ -13,16 +11,17 @@ import faiss
|
|||
import tempfile
|
||||
import os
|
||||
import re
|
||||
import warnings
|
||||
|
||||
from common_faiss_tests import get_dataset, get_dataset_2
|
||||
|
||||
|
||||
class TestModuleInterface(unittest.TestCase):
|
||||
|
||||
def test_version_attribute(self):
|
||||
assert hasattr(faiss, '__version__')
|
||||
assert re.match('^\\d+\\.\\d+\\.\\d+$', faiss.__version__)
|
||||
|
||||
|
||||
class TestIndexFlat(unittest.TestCase):
|
||||
|
||||
def do_test(self, nq, metric_type=faiss.METRIC_L2, k=10):
|
||||
|
@ -109,6 +108,14 @@ class TestIndexFlat(unittest.TestCase):
|
|||
def test_with_blas_reservoir_ip(self):
|
||||
self.do_test(200, faiss.METRIC_INNER_PRODUCT, k=150)
|
||||
|
||||
def test_noblas_1res(self):
|
||||
self.do_test(10, k=1)
|
||||
|
||||
def test_with_blas_1res(self):
|
||||
self.do_test(200, k=1)
|
||||
|
||||
def test_with_blas_1res_ip(self):
|
||||
self.do_test(200, faiss.METRIC_INNER_PRODUCT, k=1)
|
||||
|
||||
class TestIndexFlatL2(unittest.TestCase):
|
||||
def test_indexflat_l2_sync_norms_1(self):
|
||||
|
@ -1007,41 +1014,6 @@ class TestShardReplicas(unittest.TestCase):
|
|||
index.remove_replica(index1)
|
||||
self.assertEqual(index.ntotal, 0)
|
||||
|
||||
class TestReconsException(unittest.TestCase):
|
||||
|
||||
def test_recons_exception(self):
|
||||
|
||||
d = 64 # dimension
|
||||
nb = 1000
|
||||
rs = np.random.RandomState(1234)
|
||||
xb = rs.rand(nb, d).astype('float32')
|
||||
nlist = 10
|
||||
quantizer = faiss.IndexFlatL2(d) # the other index
|
||||
index = faiss.IndexIVFFlat(quantizer, d, nlist)
|
||||
index.train(xb)
|
||||
index.add(xb)
|
||||
index.make_direct_map()
|
||||
|
||||
index.reconstruct(9)
|
||||
|
||||
self.assertRaises(
|
||||
RuntimeError,
|
||||
index.reconstruct, 100001
|
||||
)
|
||||
|
||||
def test_reconstuct_after_add(self):
|
||||
index = faiss.index_factory(10, 'IVF5,SQfp16')
|
||||
index.train(faiss.randn((100, 10), 123))
|
||||
index.add(faiss.randn((100, 10), 345))
|
||||
index.make_direct_map()
|
||||
index.add(faiss.randn((100, 10), 678))
|
||||
|
||||
# should not raise an exception
|
||||
index.reconstruct(5)
|
||||
print(index.ntotal)
|
||||
index.reconstruct(150)
|
||||
|
||||
|
||||
class TestReconsHash(unittest.TestCase):
|
||||
|
||||
def do_test(self, index_key):
|
||||
|
@ -1113,62 +1085,6 @@ class TestReconsHash(unittest.TestCase):
|
|||
self.do_test("IVF5,PQ4x4np")
|
||||
|
||||
|
||||
class TestValidIndexParams(unittest.TestCase):
|
||||
|
||||
def test_IndexIVFPQ(self):
|
||||
d = 32
|
||||
nb = 1000
|
||||
nt = 1500
|
||||
nq = 200
|
||||
|
||||
(xt, xb, xq) = get_dataset_2(d, nt, nb, nq)
|
||||
|
||||
coarse_quantizer = faiss.IndexFlatL2(d)
|
||||
index = faiss.IndexIVFPQ(coarse_quantizer, d, 32, 8, 8)
|
||||
index.cp.min_points_per_centroid = 5 # quiet warning
|
||||
index.train(xt)
|
||||
index.add(xb)
|
||||
|
||||
# invalid nprobe
|
||||
index.nprobe = 0
|
||||
k = 10
|
||||
self.assertRaises(RuntimeError, index.search, xq, k)
|
||||
|
||||
# invalid k
|
||||
index.nprobe = 4
|
||||
k = -10
|
||||
self.assertRaises(AssertionError, index.search, xq, k)
|
||||
|
||||
# valid params
|
||||
index.nprobe = 4
|
||||
k = 10
|
||||
D, nns = index.search(xq, k)
|
||||
|
||||
self.assertEqual(D.shape[0], nq)
|
||||
self.assertEqual(D.shape[1], k)
|
||||
|
||||
def test_IndexFlat(self):
|
||||
d = 32
|
||||
nb = 1000
|
||||
nt = 0
|
||||
nq = 200
|
||||
|
||||
(xt, xb, xq) = get_dataset_2(d, nt, nb, nq)
|
||||
index = faiss.IndexFlat(d, faiss.METRIC_L2)
|
||||
|
||||
index.add(xb)
|
||||
|
||||
# invalid k
|
||||
k = -5
|
||||
self.assertRaises(AssertionError, index.search, xq, k)
|
||||
|
||||
# valid k
|
||||
k = 5
|
||||
D, I = index.search(xq, k)
|
||||
|
||||
self.assertEqual(D.shape[0], nq)
|
||||
self.assertEqual(D.shape[1], k)
|
||||
|
||||
|
||||
class TestLargeRangeSearch(unittest.TestCase):
|
||||
|
||||
|
|
Loading…
Reference in New Issue