Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
TestUtils.cpp
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the CC-by-NC license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // Copyright 2004-present Facebook. All Rights Reserved.
10 
11 #include "../test/TestUtils.h"
12 #include "../../utils.h"
13 #include <cmath>
14 #include <gtest/gtest.h>
15 #include <set>
16 #include <sstream>
17 #include <time.h>
18 #include <unordered_map>
19 
20 namespace faiss { namespace gpu {
21 
// Symmetric relative error between two values: |a - b| divided by the
// mean of |a| and |b|.
//
// Identical inputs (both zero, or the same infinity) have no error by
// definition. They are handled up front because the general formula
// would evaluate to 0 / 0 (or inf - inf) = NaN for them, and a NaN
// fails every `<=` tolerance check made against the result.
inline float relativeError(float a, float b) {
  if (a == b) {
    return 0.0f;
  }

  return std::abs(a - b) / (0.5f * (std::abs(a) + std::abs(b)));
}
25 
// Seed shared by the random helpers in this file; it is also the seed
// handed to the faiss float_rand API. Within a test everything runs on
// a single thread, so an unsynchronized global is acceptable here.
long s_seed{1};
29 
30 void newTestSeed() {
31  struct timespec t;
32  clock_gettime(CLOCK_REALTIME, &t);
33 
34  setTestSeed(t.tv_nsec);
35 }
36 
37 void setTestSeed(long seed) {
38  printf("testing with random seed %ld\n", seed);
39 
40  srand48(seed);
41  s_seed = seed;
42 }
43 
44 int randVal(int a, int b) {
45  EXPECT_GE(a, 0);
46  EXPECT_LE(a, b);
47 
48  return a + (lrand48() % (b + 1 - a));
49 }
50 
51 bool randBool() {
52  return randSelect<bool>({true, false});
53 }
54 
55 std::vector<float> randVecs(size_t num, size_t dim) {
56  std::vector<float> v(num * dim);
57  static bool first = true;
58 
59  faiss::float_rand(v.data(), v.size(), s_seed);
60  // unfortunately we generate separate sets of vectors, and don't
61  // want the same values
62  ++s_seed;
63 
64  return v;
65 }
66 
67 void compareIndices(faiss::Index& refIndex,
68  faiss::Index& testIndex,
69  int numQuery, int dim, int k,
70  const std::string& configMsg,
71  float maxRelativeError,
72  float pctMaxDiff1,
73  float pctMaxDiffN) {
74  auto queries = faiss::gpu::randVecs(numQuery, dim);
75 
76  // Compare
77  std::vector<float> refDistance(numQuery * k, 0);
78  std::vector<faiss::Index::idx_t> refIndices(numQuery * k, -1);
79  refIndex.search(numQuery, queries.data(),
80  k, refDistance.data(), refIndices.data());
81 
82  std::vector<float> testDistance(numQuery * k, 0);
83  std::vector<faiss::Index::idx_t> testIndices(numQuery * k, -1);
84  testIndex.search(numQuery, queries.data(),
85  k, testDistance.data(), testIndices.data());
86 
87  faiss::gpu::compareLists(refDistance.data(),
88  refIndices.data(),
89  testDistance.data(),
90  testIndices.data(),
91  numQuery, k,
92  configMsg,
93  true, false, true,
94  maxRelativeError, pctMaxDiff1, pctMaxDiffN);
95 }
96 
// Reads element (i, j) of a dim1 x dim2 array stored contiguously in
// row-major order at `p`. Only dim2 (the row length) takes part in the
// offset computation; dim1 is accepted but not used.
template <typename T>
inline T lookup(const T* p, int i, int j, int dim1, int dim2) {
  static_cast<void>(dim1);

  const int flatOffset = i * dim2 + j;
  return *(p + flatOffset);
}
101 
102 void compareLists(const float* refDist,
103  const faiss::Index::idx_t* refInd,
104  const float* testDist,
105  const faiss::Index::idx_t* testInd,
106  int dim1, int dim2,
107  const std::string& configMsg,
108  bool printBasicStats, bool printDiffs, bool assertOnErr,
109  float maxRelativeError,
110  float pctMaxDiff1,
111  float pctMaxDiffN) {
112 
113  float maxAbsErr = 0.0f;
114  for (int i = 0; i < dim1 * dim2; ++i) {
115  maxAbsErr = std::max(maxAbsErr, std::abs(refDist[i] - testDist[i]));
116  }
117  int numResults = dim1 * dim2;
118 
119  // query -> {index -> result position}
120  std::vector<std::unordered_map<faiss::Index::idx_t, int>> refIndexMap;
121 
122  for (int query = 0; query < dim1; ++query) {
123  std::unordered_map<faiss::Index::idx_t, int> indices;
124 
125  for (int result = 0; result < dim2; ++result) {
126  indices[lookup(refInd, query, result, dim1, dim2)] = result;
127  }
128 
129  refIndexMap.emplace_back(std::move(indices));
130  }
131 
132  // See how far off the indices are
133  // Keep track of the difference for each entry
134  std::vector<std::vector<int>> indexDiffs;
135 
136  int diff1 = 0; // index differs by 1
137  int diffN = 0; // index differs by >1
138  int diffInf = 0; // index not found in the other
139  int nonUniqueIndices = 0;
140 
141  double avgDiff = 0.0;
142  int maxDiff = 0;
143  float maxRelErr = 0.0f;
144 
145  for (int query = 0; query < dim1; ++query) {
146  std::vector<int> diffs;
147  std::set<faiss::Index::idx_t> uniqueIndices;
148 
149  auto& indices = refIndexMap[query];
150 
151  for (int result = 0; result < dim2; ++result) {
152  auto t = lookup(testInd, query, result, dim1, dim2);
153 
154  // All indices reported within a query should be unique; this is
155  // a serious error if is otherwise the case
156  bool uniqueIndex = uniqueIndices.count(t) == 0;
157  if (assertOnErr) {
158  EXPECT_TRUE(uniqueIndex) << configMsg
159  << " " << query
160  << " " << result
161  << " " << t;
162  }
163 
164  if (!uniqueIndex) {
165  ++nonUniqueIndices;
166  } else {
167  uniqueIndices.insert(t);
168  }
169 
170  auto it = indices.find(t);
171  if (it != indices.end()) {
172  int diff = std::abs(result - it->second);
173  diffs.push_back(diff);
174 
175  if (diff == 1) {
176  ++diff1;
177  maxDiff = std::max(diff, maxDiff);
178  } else if (diff > 1) {
179  ++diffN;
180  maxDiff = std::max(diff, maxDiff);
181  }
182 
183  avgDiff += (double) diff;
184  } else {
185  ++diffInf;
186  diffs.push_back(-1);
187  // don't count this for maxDiff
188  }
189 
190  auto refD = lookup(refDist, query, result, dim1, dim2);
191  auto testD = lookup(testDist, query, result, dim1, dim2);
192 
193  float relErr = relativeError(refD, testD);
194 
195  if (assertOnErr) {
196  EXPECT_LE(relErr, maxRelativeError) << configMsg
197  << " (" << query << ", " << result
198  << ") refD: " << refD
199  << " testD: " << testD;
200  }
201 
202  maxRelErr = std::max(maxRelErr, relErr);
203  }
204 
205  indexDiffs.emplace_back(std::move(diffs));
206  }
207 
208  if (assertOnErr) {
209  EXPECT_LE((float) (diff1 + diffN + diffInf),
210  (float) numResults * pctMaxDiff1) << configMsg;
211 
212  // Don't count diffInf because that could be diff1 as far as we
213  // know
214  EXPECT_LE((float) diffN, (float) numResults * pctMaxDiffN) << configMsg;
215  }
216 
217  avgDiff /= (double) numResults;
218 
219  if (printBasicStats) {
220  if (!configMsg.empty()) {
221  printf("Config\n"
222  "----------------------------\n"
223  "%s\n",
224  configMsg.c_str());
225  }
226 
227  printf("Result error and differences\n"
228  "----------------------------\n"
229  "max abs diff %.7f rel diff %.7f\n"
230  "idx diff avg: %.5g max: %d\n"
231  "idx diff of 1: %d (%.3f%% of queries)\n"
232  "idx diff of >1: %d (%.3f%% of queries)\n"
233  "idx diff not found: %d (%.3f%% of queries)"
234  " [typically a last element inversion]\n"
235  "non-unique indices: %d (a serious error if >0)\n",
236  maxAbsErr, maxRelErr,
237  avgDiff, maxDiff,
238  diff1, 100.0f * (float) diff1 / (float) numResults,
239  diffN, 100.0f * (float) diffN / (float) numResults,
240  diffInf, 100.0f * (float) diffInf / (float) numResults,
241  nonUniqueIndices);
242  }
243 
244  if (printDiffs) {
245  printf("differences:\n");
246  printf("==================\n");
247  for (int query = 0; query < dim1; ++query) {
248  for (int result = 0; result < dim2; ++result) {
249  long refI = lookup(refInd, query, result, dim1, dim2);
250  long testI = lookup(testInd, query, result, dim1, dim2);
251 
252  if (refI != testI) {
253  float refD = lookup(refDist, query, result, dim1, dim2);
254  float testD = lookup(testDist, query, result, dim1, dim2);
255 
256  float maxDist = std::max(refD, testD);
257  float delta = std::abs(refD - testD);
258 
259  float relErr = delta / maxDist;
260 
261  if (refD == testD) {
262  printf("(%d, %d [%d]) (ref %ld tst %ld dist ==)\n",
263  query, result,
264  indexDiffs[query][result],
265  refI, testI);
266  } else {
267  printf("(%d, %d [%d]) (ref %ld tst %ld abs %.8f "
268  "rel %.8f ref %a tst %a)\n",
269  query, result,
270  indexDiffs[query][result],
271  refI, testI, delta, relErr, refD, testD);
272  }
273  }
274  }
275  }
276  }
277 }
278 
279 } }
long idx_t
all indices are this type
Definition: Index.h:62
virtual void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const =0