Faiss
IVFFlatScan.cu
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "IVFFlatScan.cuh"
#include "../GpuResources.h"
#include "IVFUtils.cuh"
#include "../utils/ConversionOperators.cuh"
#include "../utils/DeviceDefs.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/DeviceTensor.cuh"
#include "../utils/Float16.cuh"
#include "../utils/MathOperators.cuh"
#include "../utils/LoadStoreOperators.cuh"
#include "../utils/PtxUtils.cuh"
#include "../utils/Reductions.cuh"
#include "../utils/StaticUtils.h"
#include <thrust/host_vector.h>
#include <algorithm> // std::min
#include <limits>    // std::numeric_limits

namespace faiss { namespace gpu {

template <typename T>
inline __device__ typename Math<T>::ScalarType l2Distance(T a, T b) {
  a = Math<T>::sub(a, b);
  a = Math<T>::mul(a, a);
  return Math<T>::reduceAdd(a);
}

template <typename T>
inline __device__ typename Math<T>::ScalarType ipDistance(T a, T b) {
  return Math<T>::reduceAdd(Math<T>::mul(a, b));
}
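
// For scalar float these reduce to (a - b) * (a - b) and a * b; for a
// vector type such as float4, Math<T> applies the operation per lane and
// reduceAdd performs the horizontal sum across the lanes.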

// For list scanning, even if the input data is `half`, we perform all
// math in float32, because the code is memory b/w bound, and the
// added precision for accumulation is useful

/// The class that we use to provide scan specializations
template <int Dims, bool L2, typename T>
struct IVFFlatScan {
};

// Fallback implementation: works for any dimension size
template <bool L2, typename T>
struct IVFFlatScan<-1, L2, T> {
  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    extern __shared__ float smem[];
    T* vecs = (T*) vecData;

    for (int vec = 0; vec < numVecs; ++vec) {
      // Reduce in dist
      float dist = 0.0f;

      for (int d = threadIdx.x; d < dim; d += blockDim.x) {
        float vecVal = ConvertTo<float>::to(vecs[vec * dim + d]);
        float queryVal = query[d];
        float curDist;

        if (L2) {
          curDist = l2Distance(queryVal, vecVal);
        } else {
          curDist = ipDistance(queryVal, vecVal);
        }

        dist += curDist;
      }

      // Reduce distance within block
      dist = blockReduceAllSum<float, false, true>(dist, smem);

      if (threadIdx.x == 0) {
        distanceOut[vec] = dist;
      }
    }
  }
};
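
// In the fallback scan above, one thread block handles one (query, probe)
// pair: threads stride over the vector dimensions, accumulate partial
// distances in registers, and blockReduceAllSum combines them into a single
// distance per vector.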

// Exact-dimension implementation: works when # dims == blockDim.x
template <bool L2, typename T>
struct IVFFlatScan<0, L2, T> {
  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    extern __shared__ float smem[];
    T* vecs = (T*) vecData;

    float queryVal = query[threadIdx.x];

    constexpr int kUnroll = 4;
    int limit = utils::roundDown(numVecs, kUnroll);

    for (int i = 0; i < limit; i += kUnroll) {
      float vecVal[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vecVal[j] = ConvertTo<float>::to(vecs[(i + j) * dim + threadIdx.x]);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        if (L2) {
          vecVal[j] = l2Distance(queryVal, vecVal[j]);
        } else {
          vecVal[j] = ipDistance(queryVal, vecVal[j]);
        }
      }

      blockReduceAllSum<kUnroll, float, false, true>(vecVal, smem);

      if (threadIdx.x == 0) {
#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          distanceOut[i + j] = vecVal[j];
        }
      }
    }

    // Handle remainder
    for (int i = limit; i < numVecs; ++i) {
      float vecVal = ConvertTo<float>::to(vecs[i * dim + threadIdx.x]);

      if (L2) {
        vecVal = l2Distance(queryVal, vecVal);
      } else {
        vecVal = ipDistance(queryVal, vecVal);
      }

      vecVal = blockReduceAllSum<float, false, true>(vecVal, smem);

      if (threadIdx.x == 0) {
        distanceOut[i] = vecVal;
      }
    }
  }
};
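
// In the exact-dimension case above, blockDim.x == dim, so each thread owns
// exactly one dimension of the query. The main loop processes kUnroll = 4
// list vectors per iteration to keep more loads in flight, and the array
// form of blockReduceAllSum reduces all four partial distances at once.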

template <int Dims, bool L2, typename T>
__global__ void
ivfFlatScan(Tensor<float, 2, true> queries,
            Tensor<int, 2, true> listIds,
            void** allListData,
            int* listLengths,
            Tensor<int, 2, true> prefixSumOffsets,
            Tensor<float, 1, true> distance) {
  auto queryId = blockIdx.y;
  auto probeId = blockIdx.x;

  // This is where we start writing out data
  // We ensure that before the array (at offset -1), there is a 0 value
  int outBase = *(prefixSumOffsets[queryId][probeId].data() - 1);

  auto listId = listIds[queryId][probeId];
  // Safety guard in case NaNs in input cause no list ID to be generated
  if (listId == -1) {
    return;
  }

  auto query = queries[queryId].data();
  auto vecs = allListData[listId];
  auto numVecs = listLengths[listId];
  auto dim = queries.getSize(1);
  auto distanceOut = distance[outBase].data();

  IVFFlatScan<Dims, L2, T>::scan(query, vecs, numVecs, dim, distanceOut);
}
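
// The kernel runs on a 2D grid: blockIdx.y selects the query and blockIdx.x
// selects the probe (coarse centroid) for that query, so each block scans
// exactly one inverted list. The exclusive prefix sum over list lengths
// gives each block its write offset into the flat distance buffer.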

void
runIVFFlatScanTile(Tensor<float, 2, true>& queries,
                   Tensor<int, 2, true>& listIds,
                   thrust::device_vector<void*>& listData,
                   thrust::device_vector<void*>& listIndices,
                   IndicesOptions indicesOptions,
                   thrust::device_vector<int>& listLengths,
                   Tensor<char, 1, true>& thrustMem,
                   Tensor<int, 2, true>& prefixSumOffsets,
                   Tensor<float, 1, true>& allDistances,
                   Tensor<float, 3, true>& heapDistances,
                   Tensor<int, 3, true>& heapIndices,
                   int k,
                   bool l2Distance,
                   bool useFloat16,
                   Tensor<float, 2, true>& outDistances,
                   Tensor<long, 2, true>& outIndices,
                   cudaStream_t stream) {
  // Calculate offset lengths, so we know where to write out
  // intermediate results
  runCalcListOffsets(listIds, listLengths, prefixSumOffsets, thrustMem, stream);

  // Calculate distances for vectors within our chunk of lists
  constexpr int kMaxThreadsIVF = 512;

  // FIXME: if `half` and # dims is a multiple of 2, halve the
  // threadblock size

  int dim = queries.getSize(1);
  int numThreads = std::min(dim, kMaxThreadsIVF);

  auto grid = dim3(listIds.getSize(1),
                   listIds.getSize(0));
  auto block = dim3(numThreads);
  // All exact dim kernels are unrolled by 4, hence the `4`
  auto smem = sizeof(float) * utils::divUp(numThreads, kWarpSize) * 4;
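  // (blockReduceAllSum keeps one partial sum per warp in shared memory, so
  // the reservation is numWarps floats times the unroll factor.)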

#define RUN_IVF_FLAT(DIMS, L2, T)               \
  do {                                          \
    ivfFlatScan<DIMS, L2, T>                    \
      <<<grid, block, smem, stream>>>(          \
        queries,                                \
        listIds,                                \
        listData.data().get(),                  \
        listLengths.data().get(),               \
        prefixSumOffsets,                       \
        allDistances);                          \
  } while (0)

#ifdef FAISS_USE_FLOAT16

#define HANDLE_DIM_CASE(DIMS)                   \
  do {                                          \
    if (l2Distance) {                           \
      if (useFloat16) {                         \
        RUN_IVF_FLAT(DIMS, true, half);         \
      } else {                                  \
        RUN_IVF_FLAT(DIMS, true, float);        \
      }                                         \
    } else {                                    \
      if (useFloat16) {                         \
        RUN_IVF_FLAT(DIMS, false, half);        \
      } else {                                  \
        RUN_IVF_FLAT(DIMS, false, float);       \
      }                                         \
    }                                           \
  } while (0)
#else

#define HANDLE_DIM_CASE(DIMS)                   \
  do {                                          \
    if (l2Distance) {                           \
      if (useFloat16) {                         \
        FAISS_ASSERT(false);                    \
      } else {                                  \
        RUN_IVF_FLAT(DIMS, true, float);        \
      }                                         \
    } else {                                    \
      if (useFloat16) {                         \
        FAISS_ASSERT(false);                    \
      } else {                                  \
        RUN_IVF_FLAT(DIMS, false, float);       \
      }                                         \
    }                                           \
  } while (0)

#endif // FAISS_USE_FLOAT16

  if (dim <= kMaxThreadsIVF) {
    HANDLE_DIM_CASE(0);
  } else {
    HANDLE_DIM_CASE(-1);
  }

  CUDA_TEST_ERROR();

#undef HANDLE_DIM_CASE
#undef RUN_IVF_FLAT

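  // Two-pass k-selection: pass 1 selects the best k within each chunk of
  // probes for a query, writing one candidate set per chunk into
  // heapDistances / heapIndices; pass 2 merges those candidate sets into the
  // final top-k and translates list offsets back into user indices.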
  // k-select the output in chunks, to increase parallelism
  runPass1SelectLists(prefixSumOffsets,
                      allDistances,
                      listIds.getSize(1),
                      k,
                      !l2Distance, // L2 distance chooses smallest
                      heapDistances,
                      heapIndices,
                      stream);

  // k-select final output
  auto flatHeapDistances = heapDistances.downcastInner<2>();
  auto flatHeapIndices = heapIndices.downcastInner<2>();

  runPass2SelectLists(flatHeapDistances,
                      flatHeapIndices,
                      listIndices,
                      indicesOptions,
                      prefixSumOffsets,
                      listIds,
                      k,
                      !l2Distance, // L2 distance chooses smallest
                      outDistances,
                      outIndices,
                      stream);
}

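// Top-level entry point for the IVF flat scan. The caller supplies the
// coarse-quantizer assignments (listIds: one row of nprobe list ids per
// query), device pointers to each inverted list's vectors and indices, and
// the per-list lengths; this routine tiles the queries to bound temporary
// memory usage and double-buffers the tiles across two streams.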
void
runIVFFlatScan(Tensor<float, 2, true>& queries,
               Tensor<int, 2, true>& listIds,
               thrust::device_vector<void*>& listData,
               thrust::device_vector<void*>& listIndices,
               IndicesOptions indicesOptions,
               thrust::device_vector<int>& listLengths,
               int maxListLength,
               int k,
               bool l2Distance,
               bool useFloat16,
               // output
               Tensor<float, 2, true>& outDistances,
               // output
               Tensor<long, 2, true>& outIndices,
               GpuResources* res) {
  constexpr int kMinQueryTileSize = 8;
  constexpr int kMaxQueryTileSize = 128;
  constexpr int kThrustMemSize = 16384;

  int nprobe = listIds.getSize(1);

  auto& mem = res->getMemoryManagerCurrentDevice();
  auto stream = res->getDefaultStreamCurrentDevice();

  // Make a reservation for Thrust to do its dirty work (global memory
  // cross-block reduction space); hopefully this is large enough.
  DeviceTensor<char, 1, true> thrustMem1(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true> thrustMem2(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true>* thrustMem[2] =
    {&thrustMem1, &thrustMem2};

  // How much temporary storage is available?
  // If possible, we'd like to fit within the space available.
  size_t sizeAvailable = mem.getSizeAvailable();

  // We run two passes of heap selection
  // This is the size of the first-level heap passes
  constexpr int kNProbeSplit = 8;
  int pass2Chunks = std::min(nprobe, kNProbeSplit);

  size_t sizeForFirstSelectPass =
    pass2Chunks * k * (sizeof(float) + sizeof(int));

  // How much temporary storage we need per query
  size_t sizePerQuery =
    2 * // # streams
    ((nprobe * sizeof(int) + sizeof(int)) + // prefixSumOffsets
     nprobe * maxListLength * sizeof(float) + // allDistances
     sizeForFirstSelectPass);
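
  // Illustrative sizing only (hypothetical numbers, not taken from the
  // code): with nprobe = 32, maxListLength = 10000, k = 100, this is about
  // 2 * (132 B + 1.28 MB + 6.4 KB) ≈ 2.57 MB per query, so 512 MB of
  // temporary space would admit ~199 queries per tile before the
  // kMaxQueryTileSize clamp below applies.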

  int queryTileSize = (int) (sizeAvailable / sizePerQuery);

  if (queryTileSize < kMinQueryTileSize) {
    queryTileSize = kMinQueryTileSize;
  } else if (queryTileSize > kMaxQueryTileSize) {
    queryTileSize = kMaxQueryTileSize;
  }

  // FIXME: we should adjust queryTileSize to deal with this, since
  // indexing is in int32
  FAISS_ASSERT(queryTileSize * nprobe * maxListLength <
               std::numeric_limits<int>::max());

  // Temporary memory buffers
  // Make sure there is space prior to the start which will be 0, and
  // will handle the boundary condition without branches
  DeviceTensor<int, 1, true> prefixSumOffsetSpace1(
    mem, {queryTileSize * nprobe + 1}, stream);
  DeviceTensor<int, 1, true> prefixSumOffsetSpace2(
    mem, {queryTileSize * nprobe + 1}, stream);

  DeviceTensor<int, 2, true> prefixSumOffsets1(
    prefixSumOffsetSpace1[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true> prefixSumOffsets2(
    prefixSumOffsetSpace2[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true>* prefixSumOffsets[2] =
    {&prefixSumOffsets1, &prefixSumOffsets2};

  // Make sure the element before prefixSumOffsets is 0, since we
  // depend upon simple, boundary-less indexing to get proper results
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace1.data(),
                              0,
                              sizeof(int),
                              stream));
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace2.data(),
                              0,
                              sizeof(int),
                              stream));

  DeviceTensor<float, 1, true> allDistances1(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true> allDistances2(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true>* allDistances[2] =
    {&allDistances1, &allDistances2};

  DeviceTensor<float, 3, true> heapDistances1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true> heapDistances2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true>* heapDistances[2] =
    {&heapDistances1, &heapDistances2};

  DeviceTensor<int, 3, true> heapIndices1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true> heapIndices2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true>* heapIndices[2] =
    {&heapIndices1, &heapIndices2};

  auto streams = res->getAlternateStreamsCurrentDevice();
  streamWait(streams, {stream});

  int curStream = 0;

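  // Double-buffer the query tiles across the two alternate streams: one
  // tile's kernels are issued on streams[0] with buffer set 0 while the next
  // tile uses streams[1] with buffer set 1, ping-ponging until all queries
  // are processed.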
  for (int query = 0; query < queries.getSize(0); query += queryTileSize) {
    int numQueriesInTile =
      std::min(queryTileSize, queries.getSize(0) - query);

    auto prefixSumOffsetsView =
      prefixSumOffsets[curStream]->narrowOutermost(0, numQueriesInTile);

    auto listIdsView =
      listIds.narrowOutermost(query, numQueriesInTile);
    auto queryView =
      queries.narrowOutermost(query, numQueriesInTile);

    auto heapDistancesView =
      heapDistances[curStream]->narrowOutermost(0, numQueriesInTile);
    auto heapIndicesView =
      heapIndices[curStream]->narrowOutermost(0, numQueriesInTile);

    auto outDistanceView =
      outDistances.narrowOutermost(query, numQueriesInTile);
    auto outIndicesView =
      outIndices.narrowOutermost(query, numQueriesInTile);

    runIVFFlatScanTile(queryView,
                       listIdsView,
                       listData,
                       listIndices,
                       indicesOptions,
                       listLengths,
                       *thrustMem[curStream],
                       prefixSumOffsetsView,
                       *allDistances[curStream],
                       heapDistancesView,
                       heapIndicesView,
                       k,
                       l2Distance,
                       useFloat16,
                       outDistanceView,
                       outIndicesView,
                       streams[curStream]);

    curStream = (curStream + 1) % 2;
  }

  streamWait({stream}, streams);
}

} } // namespace