IVFFlatScan.cu
/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD+Patents license found in the
 * LICENSE file in the root directory of this source tree.
 */


#include "IVFFlatScan.cuh"
#include "../GpuResources.h"
#include "IVFUtils.cuh"
#include "../utils/ConversionOperators.cuh"
#include "../utils/DeviceDefs.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/DeviceTensor.cuh"
#include "../utils/Float16.cuh"
#include "../utils/MathOperators.cuh"
#include "../utils/LoadStoreOperators.cuh"
#include "../utils/PtxUtils.cuh"
#include "../utils/Reductions.cuh"
#include "../utils/StaticUtils.h"
#include <thrust/host_vector.h>
#include <algorithm>
#include <limits>

namespace faiss { namespace gpu {

template <typename T>
inline __device__ typename Math<T>::ScalarType l2Distance(T a, T b) {
  a = Math<T>::sub(a, b);
  a = Math<T>::mul(a, a);
  return Math<T>::reduceAdd(a);
}

template <typename T>
inline __device__ typename Math<T>::ScalarType ipDistance(T a, T b) {
  return Math<T>::reduceAdd(Math<T>::mul(a, b));
}
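
// Note: Math<T>::sub/mul operate elementwise and reduceAdd is a horizontal
// sum (for a vector type it returns sum(v_i); for a scalar it is the
// identity), so these helpers return the partial distance contributed by the
// components of a single load. In this file they are called with plain
// floats, so they reduce to (a - b) * (a - b) and a * b per dimension,
// accumulated by the caller.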

// For list scanning, even if the input data is `half`, we perform all
// math in float32, because the code is memory bandwidth bound, and the
// added precision for accumulation is useful.

/// The class that we use to provide scan specializations
template <int Dims, bool L2, typename T>
struct IVFFlatScan {
};
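
// Dispatch on Dims:
//   Dims == 0  -> specialization that assumes dim == blockDim.x (one thread
//                 per dimension); selected by the host code below when
//                 dim <= kMaxThreadsIVF.
//   Dims == -1 -> general fallback that strides the block's threads over the
//                 dimensions, used for larger dims.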

// Fallback implementation: works for any dimension size
template <bool L2, typename T>
struct IVFFlatScan<-1, L2, T> {
  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    extern __shared__ float smem[];
    T* vecs = (T*) vecData;

    for (int vec = 0; vec < numVecs; ++vec) {
      // Reduce in dist
      float dist = 0.0f;

      for (int d = threadIdx.x; d < dim; d += blockDim.x) {
        float vecVal = ConvertTo<float>::to(vecs[vec * dim + d]);
        float queryVal = query[d];
        float curDist;

        if (L2) {
          curDist = l2Distance(queryVal, vecVal);
        } else {
          curDist = ipDistance(queryVal, vecVal);
        }

        dist += curDist;
      }

      // Reduce distance within block
      dist = blockReduceAllSum<float, false, true>(dist, smem);

      if (threadIdx.x == 0) {
        distanceOut[vec] = dist;
      }
    }
  }
};

// Specialization: works when the # of dims == blockDim.x exactly
// (one thread per dimension)
template <bool L2, typename T>
struct IVFFlatScan<0, L2, T> {
  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    extern __shared__ float smem[];
    T* vecs = (T*) vecData;

    float queryVal = query[threadIdx.x];

    constexpr int kUnroll = 4;
    int limit = utils::roundDown(numVecs, kUnroll);

    for (int i = 0; i < limit; i += kUnroll) {
      float vecVal[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vecVal[j] = ConvertTo<float>::to(vecs[(i + j) * dim + threadIdx.x]);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        if (L2) {
          vecVal[j] = l2Distance(queryVal, vecVal[j]);
        } else {
          vecVal[j] = ipDistance(queryVal, vecVal[j]);
        }
      }

      blockReduceAllSum<kUnroll, float, false, true>(vecVal, smem);

      if (threadIdx.x == 0) {
#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          distanceOut[i + j] = vecVal[j];
        }
      }
    }

    // Handle remainder
    for (int i = limit; i < numVecs; ++i) {
      float vecVal = ConvertTo<float>::to(vecs[i * dim + threadIdx.x]);

      if (L2) {
        vecVal = l2Distance(queryVal, vecVal);
      } else {
        vecVal = ipDistance(queryVal, vecVal);
      }

      vecVal = blockReduceAllSum<float, false, true>(vecVal, smem);

      if (threadIdx.x == 0) {
        distanceOut[i] = vecVal;
      }
    }
  }
};

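// One block handles one (query, probe) pair: the host launch below uses
// gridDim.x = nprobe and gridDim.y = numQueries, so each block scans a single
// inverted list and writes its distances at that list's offset into the flat
// `distance` output.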
template <int Dims, bool L2, typename T>
__global__ void
ivfFlatScan(Tensor<float, 2, true> queries,
            Tensor<int, 2, true> listIds,
            void** allListData,
            int* listLengths,
            Tensor<int, 2, true> prefixSumOffsets,
            Tensor<float, 1, true> distance) {
  auto queryId = blockIdx.y;
  auto probeId = blockIdx.x;

  // This is where we start writing out data
  // We ensure that before the array (at offset -1), there is a 0 value
  int outBase = *(prefixSumOffsets[queryId][probeId].data() - 1);

  auto listId = listIds[queryId][probeId];
  // Safety guard in case NaNs in input cause no list ID to be generated
  if (listId == -1) {
    return;
  }

  auto query = queries[queryId].data();
  auto vecs = allListData[listId];
  auto numVecs = listLengths[listId];
  auto dim = queries.getSize(1);
  auto distanceOut = distance[outBase].data();

  IVFFlatScan<Dims, L2, T>::scan(query, vecs, numVecs, dim, distanceOut);
}

void
runIVFFlatScanTile(Tensor<float, 2, true>& queries,
                   Tensor<int, 2, true>& listIds,
                   thrust::device_vector<void*>& listData,
                   thrust::device_vector<void*>& listIndices,
                   IndicesOptions indicesOptions,
                   thrust::device_vector<int>& listLengths,
                   Tensor<char, 1, true>& thrustMem,
                   Tensor<int, 2, true>& prefixSumOffsets,
                   Tensor<float, 1, true>& allDistances,
                   Tensor<float, 3, true>& heapDistances,
                   Tensor<int, 3, true>& heapIndices,
                   int k,
                   bool l2Distance,
                   bool useFloat16,
                   Tensor<float, 2, true>& outDistances,
                   Tensor<long, 2, true>& outIndices,
                   cudaStream_t stream) {
  // Calculate offset lengths, so we know where to write out
  // intermediate results
  runCalcListOffsets(listIds, listLengths, prefixSumOffsets, thrustMem, stream);

  // Calculate distances for vectors within our chunk of lists
  constexpr int kMaxThreadsIVF = 512;

  // FIXME: if `half` and # dims is a multiple of 2, halve the
  // threadblock size

  int dim = queries.getSize(1);
  int numThreads = std::min(dim, kMaxThreadsIVF);

  auto grid = dim3(listIds.getSize(1),
                   listIds.getSize(0));
  auto block = dim3(numThreads);
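
  // blockReduceAllSum needs one float of shared memory per warp for each
  // value it reduces at once; the exact-dim specialization reduces kUnroll = 4
  // values per call, so the allocation below is warps-per-block * 4 floats.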
  // All exact dim kernels are unrolled by 4, hence the `4`
  auto smem = sizeof(float) * utils::divUp(numThreads, kWarpSize) * 4;

#define RUN_IVF_FLAT(DIMS, L2, T)               \
  do {                                          \
    ivfFlatScan<DIMS, L2, T>                    \
      <<<grid, block, smem, stream>>>(          \
        queries,                                \
        listIds,                                \
        listData.data().get(),                  \
        listLengths.data().get(),               \
        prefixSumOffsets,                       \
        allDistances);                          \
  } while (0)

#ifdef FAISS_USE_FLOAT16

#define HANDLE_DIM_CASE(DIMS)                   \
  do {                                          \
    if (l2Distance) {                           \
      if (useFloat16) {                         \
        RUN_IVF_FLAT(DIMS, true, half);         \
      } else {                                  \
        RUN_IVF_FLAT(DIMS, true, float);        \
      }                                         \
    } else {                                    \
      if (useFloat16) {                         \
        RUN_IVF_FLAT(DIMS, false, half);        \
      } else {                                  \
        RUN_IVF_FLAT(DIMS, false, float);       \
      }                                         \
    }                                           \
  } while (0)
#else

#define HANDLE_DIM_CASE(DIMS)                   \
  do {                                          \
    if (l2Distance) {                           \
      if (useFloat16) {                         \
        FAISS_ASSERT(false);                    \
      } else {                                  \
        RUN_IVF_FLAT(DIMS, true, float);        \
      }                                         \
    } else {                                    \
      if (useFloat16) {                         \
        FAISS_ASSERT(false);                    \
      } else {                                  \
        RUN_IVF_FLAT(DIMS, false, float);       \
      }                                         \
    }                                           \
  } while (0)

#endif // FAISS_USE_FLOAT16

  if (dim <= kMaxThreadsIVF) {
    HANDLE_DIM_CASE(0);
  } else {
    HANDLE_DIM_CASE(-1);
  }

  CUDA_TEST_ERROR();

#undef HANDLE_DIM_CASE
#undef RUN_IVF_FLAT

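  // Two-pass k-selection: pass 1 reduces each query's nprobe scanned lists to
  // a small number of per-chunk top-k heaps; pass 2 merges those heaps into
  // the final top-k per query and translates intermediate positions back to
  // user-visible indices via listIndices / indicesOptions.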
  // k-select the output in chunks, to increase parallelism
  runPass1SelectLists(prefixSumOffsets,
                      allDistances,
                      listIds.getSize(1),
                      k,
                      !l2Distance, // L2 distance chooses smallest
                      heapDistances,
                      heapIndices,
                      stream);

  // k-select final output
  auto flatHeapDistances = heapDistances.downcastInner<2>();
  auto flatHeapIndices = heapIndices.downcastInner<2>();

  runPass2SelectLists(flatHeapDistances,
                      flatHeapIndices,
                      listIndices,
                      indicesOptions,
                      prefixSumOffsets,
                      listIds,
                      k,
                      !l2Distance, // L2 distance chooses smallest
                      outDistances,
                      outIndices,
                      stream);
}

void
runIVFFlatScan(Tensor<float, 2, true>& queries,
               Tensor<int, 2, true>& listIds,
               thrust::device_vector<void*>& listData,
               thrust::device_vector<void*>& listIndices,
               IndicesOptions indicesOptions,
               thrust::device_vector<int>& listLengths,
               int maxListLength,
               int k,
               bool l2Distance,
               bool useFloat16,
               // output
               Tensor<float, 2, true>& outDistances,
               // output
               Tensor<long, 2, true>& outIndices,
               GpuResources* res) {
  constexpr int kMinQueryTileSize = 8;
  constexpr int kMaxQueryTileSize = 128;
  constexpr int kThrustMemSize = 16384;

  int nprobe = listIds.getSize(1);

  auto& mem = res->getMemoryManagerCurrentDevice();
  auto stream = res->getDefaultStreamCurrentDevice();

  // Make a reservation for Thrust to do its dirty work (global memory
  // cross-block reduction space); hopefully this is large enough.
  DeviceTensor<char, 1, true> thrustMem1(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true> thrustMem2(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true>* thrustMem[2] =
    {&thrustMem1, &thrustMem2};

  // How much temporary storage is available?
  // If possible, we'd like to fit within the space available.
  size_t sizeAvailable = mem.getSizeAvailable();

  // We run two passes of heap selection
  // This is the size of the first-level heap passes
  constexpr int kNProbeSplit = 8;
  int pass2Chunks = std::min(nprobe, kNProbeSplit);

  size_t sizeForFirstSelectPass =
    pass2Chunks * k * (sizeof(float) + sizeof(int));

  // How much temporary storage we need per query
  size_t sizePerQuery =
    2 * // # streams
    ((nprobe * sizeof(int) + sizeof(int)) + // prefixSumOffsets
     nprobe * maxListLength * sizeof(float) + // allDistances
     sizeForFirstSelectPass);
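
  // For a rough sense of scale (illustrative numbers only): with nprobe = 32,
  // maxListLength = 10000 and k = 100, this is
  // 2 * ((32 * 4 + 4) + 32 * 10000 * 4 + 8 * 100 * 8) bytes, i.e. roughly
  // 2.5 MB per query, dominated by the allDistances scratch.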

  int queryTileSize = (int) (sizeAvailable / sizePerQuery);

  if (queryTileSize < kMinQueryTileSize) {
    queryTileSize = kMinQueryTileSize;
  } else if (queryTileSize > kMaxQueryTileSize) {
    queryTileSize = kMaxQueryTileSize;
  }

  // FIXME: we should adjust queryTileSize to deal with this, since
  // indexing is in int32
  FAISS_ASSERT((size_t) queryTileSize * nprobe * maxListLength <
               (size_t) std::numeric_limits<int>::max());

  // Temporary memory buffers
  // Make sure there is space prior to the start which will be 0, and
  // will handle the boundary condition without branches
  DeviceTensor<int, 1, true> prefixSumOffsetSpace1(
    mem, {queryTileSize * nprobe + 1}, stream);
  DeviceTensor<int, 1, true> prefixSumOffsetSpace2(
    mem, {queryTileSize * nprobe + 1}, stream);

  DeviceTensor<int, 2, true> prefixSumOffsets1(
    prefixSumOffsetSpace1[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true> prefixSumOffsets2(
    prefixSumOffsetSpace2[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true>* prefixSumOffsets[2] =
    {&prefixSumOffsets1, &prefixSumOffsets2};

  // Make sure the element before prefixSumOffsets is 0, since we
  // depend upon simple, boundary-less indexing to get proper results
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace1.data(),
                              0,
                              sizeof(int),
                              stream));
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace2.data(),
                              0,
                              sizeof(int),
                              stream));
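
  // This zeroed leading element is what lets ivfFlatScan read
  // *(prefixSumOffsets[queryId][probeId].data() - 1) unconditionally: for the
  // first (query, probe) slot that read lands on the zero rather than on
  // out-of-bounds memory.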

  DeviceTensor<float, 1, true> allDistances1(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true> allDistances2(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true>* allDistances[2] =
    {&allDistances1, &allDistances2};

  DeviceTensor<float, 3, true> heapDistances1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true> heapDistances2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true>* heapDistances[2] =
    {&heapDistances1, &heapDistances2};

  DeviceTensor<int, 3, true> heapIndices1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true> heapIndices2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true>* heapIndices[2] =
    {&heapIndices1, &heapIndices2};

  auto streams = res->getAlternateStreamsCurrentDevice();
  streamWait(streams, {stream});

  int curStream = 0;

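  // Query tiles are processed round-robin on the two alternate streams, with
  // all per-tile temporaries (thrustMem, prefixSumOffsets, allDistances,
  // heapDistances / heapIndices) double-buffered so that work on one tile can
  // overlap with the next.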
  for (int query = 0; query < queries.getSize(0); query += queryTileSize) {
    int numQueriesInTile =
      std::min(queryTileSize, queries.getSize(0) - query);

    auto prefixSumOffsetsView =
      prefixSumOffsets[curStream]->narrowOutermost(0, numQueriesInTile);

    auto listIdsView =
      listIds.narrowOutermost(query, numQueriesInTile);
    auto queryView =
      queries.narrowOutermost(query, numQueriesInTile);

    auto heapDistancesView =
      heapDistances[curStream]->narrowOutermost(0, numQueriesInTile);
    auto heapIndicesView =
      heapIndices[curStream]->narrowOutermost(0, numQueriesInTile);

    auto outDistanceView =
      outDistances.narrowOutermost(query, numQueriesInTile);
    auto outIndicesView =
      outIndices.narrowOutermost(query, numQueriesInTile);

    runIVFFlatScanTile(queryView,
                       listIdsView,
                       listData,
                       listIndices,
                       indicesOptions,
                       listLengths,
                       *thrustMem[curStream],
                       prefixSumOffsetsView,
                       *allDistances[curStream],
                       heapDistancesView,
                       heapIndicesView,
                       k,
                       l2Distance,
                       useFloat16,
                       outDistanceView,
                       outIndicesView,
                       streams[curStream]);

    curStream = (curStream + 1) % 2;
  }

  streamWait({stream}, streams);
}

} } // namespace