99 lines
2.3 KiB
C++
99 lines
2.3 KiB
C++
/**
|
|
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
*
|
|
* This source code is licensed under the MIT license found in the
|
|
* LICENSE file in the root directory of this source tree.
|
|
*/
|
|
|
|
|
|
#include <iostream>
|
|
#include <vector>
|
|
#include <memory>
|
|
|
|
#include <gtest/gtest.h>
|
|
|
|
#include <faiss/impl/ProductQuantizer.h>
|
|
|
|
|
|
namespace {
|
|
|
|
const std::vector<uint64_t> random_vector(size_t s) {
|
|
std::vector<uint64_t> v(s, 0);
|
|
for (size_t i = 0; i < s; ++i) {
|
|
v[i] = rand();
|
|
}
|
|
|
|
return v;
|
|
}
|
|
|
|
} // namespace
|
|
|
|
|
|
TEST(PQEncoderGeneric, encode) {
|
|
const int nsubcodes = 97;
|
|
const int minbits = 1;
|
|
const int maxbits = 24;
|
|
const std::vector<uint64_t> values = random_vector(nsubcodes);
|
|
|
|
for(int nbits = minbits; nbits <= maxbits; ++nbits) {
|
|
std::cerr << "nbits = " << nbits << std::endl;
|
|
|
|
const uint64_t mask = (1ull << nbits) - 1;
|
|
std::unique_ptr<uint8_t[]> codes(
|
|
new uint8_t[(nsubcodes * maxbits + 7) / 8]
|
|
);
|
|
|
|
// NOTE(hoss): Necessary scope to ensure trailing bits are flushed to mem.
|
|
{
|
|
faiss::PQEncoderGeneric encoder(codes.get(), nbits);
|
|
for (const auto& v : values) {
|
|
encoder.encode(v & mask);
|
|
}
|
|
}
|
|
|
|
faiss::PQDecoderGeneric decoder(codes.get(), nbits);
|
|
for (int i = 0; i < nsubcodes; ++i) {
|
|
uint64_t v = decoder.decode();
|
|
EXPECT_EQ(values[i] & mask, v);
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
TEST(PQEncoder8, encode) {
|
|
const int nsubcodes = 100;
|
|
const std::vector<uint64_t> values = random_vector(nsubcodes);
|
|
const uint64_t mask = 0xFF;
|
|
std::unique_ptr<uint8_t[]> codes(new uint8_t[nsubcodes]);
|
|
|
|
faiss::PQEncoder8 encoder(codes.get(), 8);
|
|
for (const auto& v : values) {
|
|
encoder.encode(v & mask);
|
|
}
|
|
|
|
faiss::PQDecoder8 decoder(codes.get(), 8);
|
|
for (int i = 0; i < nsubcodes; ++i) {
|
|
uint64_t v = decoder.decode();
|
|
EXPECT_EQ(values[i] & mask, v);
|
|
}
|
|
}
|
|
|
|
|
|
TEST(PQEncoder16, encode) {
|
|
const int nsubcodes = 100;
|
|
const std::vector<uint64_t> values = random_vector(nsubcodes);
|
|
const uint64_t mask = 0xFFFF;
|
|
std::unique_ptr<uint8_t[]> codes(new uint8_t[2 * nsubcodes]);
|
|
|
|
faiss::PQEncoder16 encoder(codes.get(), 16);
|
|
for (const auto& v : values) {
|
|
encoder.encode(v & mask);
|
|
}
|
|
|
|
faiss::PQDecoder16 decoder(codes.get(), 16);
|
|
for (int i = 0; i < nsubcodes; ++i) {
|
|
uint64_t v = decoder.decode();
|
|
EXPECT_EQ(values[i] & mask, v);
|
|
}
|
|
}
|