10 #include "../../IndexBinaryFlat.h"
11 #include "../GpuIndexBinaryFlat.h"
12 #include "../StandardGpuResources.h"
13 #include "../utils/DeviceUtils.h"
14 #include "../test/TestUtils.h"
15 #include "../../utils.h"
16 #include <gtest/gtest.h>
// Checks that a GPU binary k-NN search result agrees with the CPU reference.
//
// cpuDist/cpuLabels and gpuDist/gpuLabels hold the distances and labels from
// the CPU and GPU searches. The loops below are bounded by numQuery and k,
// whose declarations fall on lines not visible in this listing; presumably
// they are the remaining parameters (verify against the full file).
//
// Checks made for each query row:
//  * each GPU distance equals the CPU distance at the same position;
//  * whenever the CPU distance changes, the new value must be strictly
//    greater than the previous one (EXPECT_LT);
//  * at each boundary between runs of equal distance, the labels collected so
//    far are compared as std::sets. Order is ignored, so tied results may
//    come back in a different order on CPU and GPU.
// NOTE(review): in the visible lines, the set comparison only runs when the
// distance changes, so the final run of equal distances in each row is never
// set-compared. This is presumably deliberate, because that run can be cut
// off at k and the surviving members may legitimately differ. Confirm.
20 void compareBinaryDist(
const std::vector<int>& cpuDist,
21 const std::vector<faiss::IndexBinary::idx_t>& cpuLabels,
22 const std::vector<int>& gpuDist,
23 const std::vector<faiss::IndexBinary::idx_t>& gpuLabels,
// One outer iteration per query row.
26 for (
int i = 0; i < numQuery; ++i) {
// Labels seen by each backend for the current run of equal distances. They
// are presumably cleared at each run boundary on lines not shown here.
32 std::set<faiss::IndexBinary::idx_t> cpuLabelSet;
33 std::set<faiss::IndexBinary::idx_t> gpuLabelSet;
// Walk the k results of this row. idx is computed on a line not shown here;
// presumably it is the flattened offset i * k + j. TODO confirm.
37 for (
int j = 0; j < k; ++j) {
// Seed curDist from the CPU distance. The condition guarding this assignment
// is on a line not visible in this listing.
41 curDist = cpuDist[idx];
// A different distance value marks the end of the current run of ties.
44 if (curDist != cpuDist[idx]) {
// CPU distances must increase strictly from one run to the next.
46 EXPECT_LT(curDist, cpuDist[idx]);
// The labels collected so far must match as sets (tie order may differ).
49 EXPECT_EQ(cpuLabelSet, gpuLabelSet);
50 curDist = cpuDist[idx];
// Record this position's labels for the current run.
55 cpuLabelSet.insert(cpuLabels[idx]);
56 gpuLabelSet.insert(gpuLabels[idx]);
// Distances must match exactly between CPU and GPU at this position.
59 EXPECT_EQ(cpuDist[idx], gpuDist[idx]);
// Runs one randomized CPU-vs-GPU comparison of binary flat (exhaustive) search.
//
// DimMultiple: the vector dimension is randVal(1, 20) * DimMultiple.
// Presumably this is the dimension in bits, which would explain why callers
// use multiples of 8 and 32. Confirm against the IndexBinary contract.
//
// The lines that create the GPU resources, `config`, `gpuIndex` and
// `cpuIndex` are not visible in this listing. Given the includes, gpuIndex is
// presumably a GpuIndexBinaryFlat and cpuIndex an IndexBinaryFlat. Verify in
// the full file.
64 template <
int DimMultiple>
65 void testGpuIndexBinaryFlat() {
// Place the GPU index on a randomly chosen device in [0, numDevices - 1].
70 config.
device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1);
73 int dims = faiss::gpu::randVal(1, 20) * DimMultiple;
// Random problem size: database vectors, query vectors, and results per query.
// NOTE(review): these ranges appear to allow numVecs < k. Confirm that the
// comparison behaves as intended when fewer than k real neighbours exist.
78 int numVecs = faiss::gpu::randVal(1, 20000);
79 int numQuery = faiss::gpu::randVal(1, 1000);
80 int k = faiss::gpu::randVal(1, 1024);
// Add the same random binary database to both indexes.
82 auto data = faiss::gpu::randBinaryVecs(numVecs, dims);
83 gpuIndex.add(numVecs, data.data());
84 cpuIndex.add(numVecs, data.data());
86 auto query = faiss::gpu::randBinaryVecs(numQuery, dims);
// CPU reference results: numQuery * k distances and labels.
88 std::vector<int> cpuDist(numQuery * k);
89 std::vector<faiss::IndexBinary::idx_t> cpuLabels(numQuery * k);
91 cpuIndex.search(numQuery,
// GPU results, sized the same as the CPU buffers.
97 std::vector<int> gpuDist(numQuery * k);
98 std::vector<faiss::IndexBinary::idx_t> gpuLabels(numQuery * k);
100 gpuIndex.search(numQuery,
// Check the GPU output against the CPU reference.
106 compareBinaryDist(cpuDist, cpuLabels,
// Runs the randomized binary flat comparison 4 times, with dimensions that are
// multiples of 8.
111 TEST(TestGpuIndexBinaryFlat, Test8) {
112 for (
int tries = 0; tries < 4; ++tries) {
113 testGpuIndexBinaryFlat<8>();
// Runs the randomized binary flat comparison 4 times, with dimensions that are
// multiples of 32.
117 TEST(TestGpuIndexBinaryFlat, Test32) {
118 for (
int tries = 0; tries < 4; ++tries) {
119 testGpuIndexBinaryFlat<32>();
// Test entry point. It initializes GoogleTest, sets a fixed seed (100) for the
// faiss::gpu test helpers, and runs all registered tests. The fixed seed is
// presumably there so the randVal/randBinaryVecs-driven tests are
// reproducible. Lines 125-126 and 128 of the original are not shown here.
123 int main(
int argc,
char** argv) {
124 testing::InitGoogleTest(&argc, argv);
127 faiss::gpu::setTestSeed(100);
129 return RUN_ALL_TESTS();
int device
GPU device on which the index is resident.