Fix reconstruct bug when by_residual is false (#2298)

Summary:
When I reconstruct with by_residual turned off, the distance was greatly increased.
This is because the reconstruct_from_offset function did not check if the by_residual option was off.
I fix this bug with simple if statement.
(like this https://github.com/facebookresearch/faiss/blob/main/faiss/IndexIVFPQ.cpp#L365)

Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2298

Reviewed By: alexanderguzhva

Differential Revision: D35746566

Pulled By: mdouze

fbshipit-source-id: 50f98c7cc97c7936507573fe41b65a79ecdbc4ca
pull/2291/head
spectaclehong 2022-04-20 01:35:21 -07:00 committed by Facebook GitHub Bot
parent 8ffed8c219
commit b13f47a4da
2 changed files with 36 additions and 6 deletions

View File

@ -251,13 +251,18 @@ void IndexIVFScalarQuantizer::reconstruct_from_offset(
int64_t list_no,
int64_t offset,
float* recons) const {
std::vector<float> centroid(d);
quantizer->reconstruct(list_no, centroid.data());
const uint8_t* code = invlists->get_single_code(list_no, offset);
sq.decode(code, recons, 1);
for (int i = 0; i < d; ++i) {
recons[i] += centroid[i];
if (by_residual) {
std::vector<float> centroid(d);
quantizer->reconstruct(list_no, centroid.data());
sq.decode(code, recons, 1);
for (int i = 0; i < d; ++i) {
recons[i] += centroid[i];
}
} else {
sq.decode(code, recons, 1);
}
}

View File

@ -325,6 +325,31 @@ class TestScalarQuantizer(unittest.TestCase):
# print(dis, D[i, j])
assert abs(D[i, j] - dis) / dis < 1e-5
def test_reconstruct(self):
self.do_reconstruct(True)
def test_reconstruct_no_residual(self):
self.do_reconstruct(False)
def do_reconstruct(self, by_residual):
d = 32
xt, xb, xq = get_dataset_2(d, 100, 5, 5)
index = faiss.index_factory(d, "IVF10,SQ8")
index.by_residual = by_residual
index.train(xt)
index.add(xb)
index.nprobe = 10
D, I = index.search(xq, 4)
xb2 = index.reconstruct_n(0, index.ntotal)
for i in range(5):
for j in range(4):
self.assertAlmostEqual(
((xq[i] - xb2[I[i, j]]) ** 2).sum(),
D[i, j],
places=4
)
class TestRandom(unittest.TestCase):