Fix reconstruct bug when by_residual is false (#2298)
Summary: When reconstructing vectors with by_residual turned off, the distances increased greatly. This was because the reconstruct_from_offset function did not check whether the by_residual option was off and always added the centroid to the decoded vector. This change fixes the bug with a simple if statement (similar to https://github.com/facebookresearch/faiss/blob/main/faiss/IndexIVFPQ.cpp#L365). Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2298 Reviewed By: alexanderguzhva Differential Revision: D35746566 Pulled By: mdouze fbshipit-source-id: 50f98c7cc97c7936507573fe41b65a79ecdbc4ca pull/2291/head
parent
8ffed8c219
commit
b13f47a4da
|
@ -251,13 +251,18 @@ void IndexIVFScalarQuantizer::reconstruct_from_offset(
|
|||
int64_t list_no,
|
||||
int64_t offset,
|
||||
float* recons) const {
|
||||
std::vector<float> centroid(d);
|
||||
quantizer->reconstruct(list_no, centroid.data());
|
||||
|
||||
const uint8_t* code = invlists->get_single_code(list_no, offset);
|
||||
sq.decode(code, recons, 1);
|
||||
for (int i = 0; i < d; ++i) {
|
||||
recons[i] += centroid[i];
|
||||
|
||||
if (by_residual) {
|
||||
std::vector<float> centroid(d);
|
||||
quantizer->reconstruct(list_no, centroid.data());
|
||||
|
||||
sq.decode(code, recons, 1);
|
||||
for (int i = 0; i < d; ++i) {
|
||||
recons[i] += centroid[i];
|
||||
}
|
||||
} else {
|
||||
sq.decode(code, recons, 1);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -325,6 +325,31 @@ class TestScalarQuantizer(unittest.TestCase):
|
|||
# print(dis, D[i, j])
|
||||
assert abs(D[i, j] - dis) / dis < 1e-5
|
||||
|
||||
def test_reconstruct(self):
|
||||
self.do_reconstruct(True)
|
||||
|
||||
def test_reconstruct_no_residual(self):
|
||||
self.do_reconstruct(False)
|
||||
|
||||
def do_reconstruct(self, by_residual):
|
||||
d = 32
|
||||
xt, xb, xq = get_dataset_2(d, 100, 5, 5)
|
||||
|
||||
index = faiss.index_factory(d, "IVF10,SQ8")
|
||||
index.by_residual = by_residual
|
||||
index.train(xt)
|
||||
index.add(xb)
|
||||
index.nprobe = 10
|
||||
D, I = index.search(xq, 4)
|
||||
xb2 = index.reconstruct_n(0, index.ntotal)
|
||||
for i in range(5):
|
||||
for j in range(4):
|
||||
self.assertAlmostEqual(
|
||||
((xq[i] - xb2[I[i, j]]) ** 2).sum(),
|
||||
D[i, j],
|
||||
places=4
|
||||
)
|
||||
|
||||
|
||||
class TestRandom(unittest.TestCase):
|
||||
|
||||
|
|
Loading…
Reference in New Issue