Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
Distance.cu
1 
2 /**
3  * Copyright (c) 2015-present, Facebook, Inc.
4  * All rights reserved.
5  *
6  * This source code is licensed under the CC-by-NC license found in the
7  * LICENSE file in the root directory of this source tree.
8  */
9 
10 // Copyright 2004-present Facebook. All Rights Reserved.
11 
12 #include "Distance.cuh"
13 #include "BroadcastSum.cuh"
14 #include "L2Norm.cuh"
15 #include "L2Select.cuh"
16 #include "../../FaissAssert.h"
17 #include "../GpuResources.h"
18 #include "../utils/DeviceUtils.h"
19 #include "../utils/Limits.cuh"
20 #include "../utils/MatrixMult.cuh"
21 #include "../utils/BlockSelectKernel.cuh"
22 
23 #include <memory>
24 #include <thrust/fill.h>
25 #include <thrust/for_each.h>
26 #include <thrust/device_ptr.h>
27 #include <thrust/execution_policy.h>
28 
29 namespace faiss { namespace gpu {
30 
// Default number of query rows handled per tile by the distance routines
// below when the caller passes a non-positive tileSize (it is doubled there
// for element types narrower than 4 bytes).
constexpr int kDefaultTileSize{256};
32 
// Brute-force k-nearest-neighbor search under squared L2 distance.
//
// For every row of `queries`, finds the k rows of `centroids` with the
// smallest squared L2 distance and writes those distances and the matching
// centroid row indices into the corresponding row of `outDistances` /
// `outIndices`.
//
// The distance is expanded as ||q||^2 - 2 q.c + ||c||^2:
//   - -2 q.c is produced by a GEMM into a temporary
//     (tile rows, num centroids) buffer;
//   - ||c||^2 is added and the k smallest values are selected in a single
//     fused pass (runL2SelectMin), so the large buffer is written once by
//     the GEMM and read once by the selection;
//   - ||q||^2 is then added to the k selected values (runSumAlongRows),
//     unless ignoreOutDistances is set.
//
// Queries are processed in row tiles to bound temporary memory. Tiles
// alternate between two temporary buffers and two alternate streams. The
// alternate streams first wait on the default stream, and the default stream
// waits on them before returning, so all work is ordered after prior work on
// the default stream and before subsequent work on it.
//
// Parameters:
//   resources           - source of the device memory manager, the default
//                         and alternate streams, and the BLAS handle
//   centroids           - (num centroids, dim) vectors searched against
//   centroidsTransposed - optional (dim, num centroids) copy of `centroids`;
//                         when non-null it is fed to the GEMM untransposed
//   centroidNorms       - optional precomputed ||c||^2 per centroid; when
//                         null it is computed here on the default stream
//   queries             - (num queries, dim) query vectors
//   k                   - results per query; requires k <= num centroids
//                         (unless `centroids` is empty) and k <= 1024
//   outDistances        - (num queries, k) output distances
//   outIndices          - (num queries, k) output centroid row indices
//   ignoreOutDistances  - when true, ||q||^2 is not added, so outDistances
//                         hold ||c||^2 - 2 q.c; the ranking is unaffected
//                         because ||q||^2 is constant within a query row
//   tileSize            - max query rows per tile; <= 0 selects
//                         kDefaultTileSize (doubled when sizeof(T) < 4)
//
// If `centroids` has no elements, every output distance is set to
// Limits<T>::getMax() and every output index to -1.
template <typename T>
void runL2Distance(GpuResources* resources,
                   Tensor<T, 2, true>& centroids,
                   Tensor<T, 2, true>* centroidsTransposed,
                   Tensor<T, 1, true>* centroidNorms,
                   Tensor<T, 2, true>& queries,
                   int k,
                   Tensor<T, 2, true>& outDistances,
                   Tensor<int, 2, true>& outIndices,
                   bool ignoreOutDistances = false,
                   int tileSize = -1) {
  FAISS_ASSERT(outDistances.getSize(0) == queries.getSize(0));
  FAISS_ASSERT(outIndices.getSize(0) == queries.getSize(0));
  FAISS_ASSERT(outDistances.getSize(1) == k);
  FAISS_ASSERT(outIndices.getSize(1) == k);

  auto& mem = resources->getMemoryManagerCurrentDevice();
  auto defaultStream = resources->getDefaultStreamCurrentDevice();

  // If we're querying against a 0 sized set, just return empty results
  if (centroids.numElements() == 0) {
    thrust::fill(thrust::cuda::par.on(defaultStream),
                 outDistances.data(), outDistances.end(),
                 Limits<T>::getMax());

    thrust::fill(thrust::cuda::par.on(defaultStream),
                 outIndices.data(), outIndices.end(),
                 -1);

    return;
  }

  // Validate k before issuing any GPU work
  FAISS_ASSERT(k <= centroids.getSize(0));
  FAISS_ASSERT(k <= 1024); // select limitation

  int numQueries = queries.getSize(0);
  int numCentroids = centroids.getSize(0);

  // With no queries there is nothing to compute or write. Returning here
  // also avoids allocating zero-row temporaries and launching norm/select
  // work over empty tensors (a zero-sized grid is an invalid CUDA launch
  // configuration, should any helper not guard against it).
  if (numQueries == 0) {
    return;
  }

  // If ||c||^2 is not pre-computed, calculate it
  DeviceTensor<T, 1, true> cNorms;
  if (!centroidNorms) {
    cNorms = DeviceTensor<T, 1, true>(mem, {numCentroids}, defaultStream);
    runL2Norm(centroids, cNorms, true, defaultStream);
    centroidNorms = &cNorms;
  }

  // ||q||^2 for every query
  DeviceTensor<T, 1, true> queryNorms(mem, {numQueries}, defaultStream);
  runL2Norm(queries, queryNorms, true, defaultStream);

  // Allocating the whole (num queries, num centroids) distance matrix could
  // require too much memory, so queries are handled in row tiles.
  // FIXME: parameterize based on # of centroids and DeviceMemory
  int defaultTileSize =
    sizeof(T) < 4 ? kDefaultTileSize * 2 : kDefaultTileSize;
  tileSize = tileSize <= 0 ? defaultTileSize : tileSize;

  int maxQueriesPerIteration = std::min(tileSize, numQueries);

  // One temporary distance buffer per alternate stream, so that tiles in
  // flight on different streams never share a buffer
  DeviceTensor<T, 2, true> distanceBuf1(
    mem, {maxQueriesPerIteration, numCentroids}, defaultStream);
  DeviceTensor<T, 2, true> distanceBuf2(
    mem, {maxQueriesPerIteration, numCentroids}, defaultStream);
  DeviceTensor<T, 2, true>* distanceBufs[2] = {&distanceBuf1, &distanceBuf2};

  // The alternate streams must not start before the norms above are ready.
  // NOTE(review): streams[0] and streams[1] are indexed below, so this
  // assumes at least two alternate streams are provided -- TODO confirm.
  auto streams = resources->getAlternateStreamsCurrentDevice();
  streamWait(streams, {defaultStream});

  int curStream = 0;

  for (int i = 0; i < numQueries; i += maxQueriesPerIteration) {
    int numQueriesForIteration =
      std::min(maxQueriesPerIteration, numQueries - i);

    auto distanceBufView =
      distanceBufs[curStream]->narrowOutermost(0, numQueriesForIteration);
    auto queryView = queries.narrowOutermost(i, numQueriesForIteration);
    auto outDistanceView =
      outDistances.narrowOutermost(i, numQueriesForIteration);
    auto outIndexView = outIndices.narrowOutermost(i, numQueriesForIteration);
    auto queryNormView =
      queryNorms.narrowOutermost(i, numQueriesForIteration);

    // L2 distance is ||c||^2 - 2qc + ||q||^2

    // -2qc
    // (query id x dim) x (centroid id, dim)' = (query id, centroid id)
    runMatrixMult(distanceBufView, false,
                  queryView, false,
                  centroidsTransposed ? *centroidsTransposed : centroids,
                  centroidsTransposed ? false : true,
                  -2.0f, 0.0f,
                  resources->getBlasHandleCurrentDevice(),
                  streams[curStream]);

    // Fused kernel: adds ||c||^2 to -2qc and performs the k-selection, so
    // the large tile buffer is written once (gemm) and read once (here)
    runL2SelectMin(distanceBufView,
                   *centroidNorms,
                   outDistanceView,
                   outIndexView,
                   k,
                   streams[curStream]);

    if (!ignoreOutDistances) {
      // Add ||q||^2 to each of the k selected values of its query row:
      // top-k ||c||^2 - 2qc + ||q||^2 in the form (query id, k)
      runSumAlongRows(queryNormView, outDistanceView, streams[curStream]);
    }

    curStream = (curStream + 1) % 2;
  }

  // Have the desired ordering stream wait on the multi-stream
  streamWait({defaultStream}, streams);
}
165 
// Brute-force k-nearest-neighbor search by inner product (largest first).
//
// For every row of `queries`, computes the dot product with every row of
// `centroids` and writes the k largest values and their centroid row
// indices into the corresponding row of `outDistances` / `outIndices`.
//
// Queries are processed in row tiles to bound temporary memory: each tile's
// dot products are produced by a GEMM into a temporary
// (tile rows, num centroids) buffer and then reduced by runBlockSelect
// (selecting the largest values). Tiles alternate between two temporary
// buffers and two alternate streams. The alternate streams first wait on the
// default stream, and the default stream waits on them before returning.
//
// Parameters:
//   resources           - source of the device memory manager, the default
//                         and alternate streams, and the BLAS handle
//   centroids           - (num centroids, dim) vectors searched against
//   centroidsTransposed - optional (dim, num centroids) copy of `centroids`;
//                         when non-null it is fed to the GEMM untransposed
//   queries             - (num queries, dim) query vectors
//   k                   - results per query; requires k <= num centroids
//                         (unless `centroids` is empty) and k <= 1024
//   outDistances        - (num queries, k) output dot products
//   outIndices          - (num queries, k) output centroid row indices
//   tileSize            - max query rows per tile; <= 0 selects
//                         kDefaultTileSize (doubled when sizeof(T) < 4)
//
// If `centroids` has no elements, every output value is set to
// Limits<T>::getMax() and every output index to -1.
template <typename T>
void runIPDistance(GpuResources* resources,
                   Tensor<T, 2, true>& centroids,
                   Tensor<T, 2, true>* centroidsTransposed,
                   Tensor<T, 2, true>& queries,
                   int k,
                   Tensor<T, 2, true>& outDistances,
                   Tensor<int, 2, true>& outIndices,
                   int tileSize = -1) {
  FAISS_ASSERT(outDistances.getSize(0) == queries.getSize(0));
  FAISS_ASSERT(outIndices.getSize(0) == queries.getSize(0));
  FAISS_ASSERT(outDistances.getSize(1) == k);
  FAISS_ASSERT(outIndices.getSize(1) == k);

  auto& mem = resources->getMemoryManagerCurrentDevice();
  auto defaultStream = resources->getDefaultStreamCurrentDevice();

  // If we're querying against a 0 sized set, just return empty results.
  // NOTE(review): the selection below keeps the *largest* values, so
  // Limits<T>::getMax() is the best possible score here; a caller merging
  // these results with others may rank the -1 placeholders first. Kept
  // as-is for compatibility -- confirm whether getMin() was intended.
  if (centroids.numElements() == 0) {
    thrust::fill(thrust::cuda::par.on(defaultStream),
                 outDistances.data(), outDistances.end(),
                 Limits<T>::getMax());

    thrust::fill(thrust::cuda::par.on(defaultStream),
                 outIndices.data(), outIndices.end(),
                 -1);

    return;
  }

  // Validate k before issuing any GPU work
  FAISS_ASSERT(k <= centroids.getSize(0));
  FAISS_ASSERT(k <= 1024); // select limitation

  int numQueries = queries.getSize(0);
  int numCentroids = centroids.getSize(0);

  // With no queries there is nothing to compute or write. Returning here
  // also avoids allocating zero-row temporaries and launching selection
  // work over empty tensors (a zero-sized grid is an invalid CUDA launch
  // configuration, should any helper not guard against it).
  if (numQueries == 0) {
    return;
  }

  // Allocating the whole (num queries, num centroids) score matrix could
  // require too much memory, so queries are handled in row tiles.
  // FIXME: parameterize based on # of centroids and DeviceMemory
  int defaultTileSize =
    sizeof(T) < 4 ? kDefaultTileSize * 2 : kDefaultTileSize;
  tileSize = tileSize <= 0 ? defaultTileSize : tileSize;

  int maxQueriesPerIteration = std::min(tileSize, numQueries);

  // One temporary buffer per alternate stream, so that tiles in flight on
  // different streams never share a buffer
  DeviceTensor<T, 2, true> distanceBuf1(
    mem, {maxQueriesPerIteration, numCentroids}, defaultStream);
  DeviceTensor<T, 2, true> distanceBuf2(
    mem, {maxQueriesPerIteration, numCentroids}, defaultStream);
  DeviceTensor<T, 2, true>* distanceBufs[2] = {&distanceBuf1, &distanceBuf2};

  // The alternate streams must not start before prior default-stream work.
  // NOTE(review): streams[0] and streams[1] are indexed below, so this
  // assumes at least two alternate streams are provided -- TODO confirm.
  auto streams = resources->getAlternateStreamsCurrentDevice();
  streamWait(streams, {defaultStream});

  int curStream = 0;

  for (int i = 0; i < numQueries; i += maxQueriesPerIteration) {
    int numQueriesForIteration =
      std::min(maxQueriesPerIteration, numQueries - i);

    auto distanceBufView =
      distanceBufs[curStream]->narrowOutermost(0, numQueriesForIteration);
    auto queryView = queries.narrowOutermost(i, numQueriesForIteration);
    auto outDistanceView =
      outDistances.narrowOutermost(i, numQueriesForIteration);
    auto outIndexView = outIndices.narrowOutermost(i, numQueriesForIteration);

    // (query id x dim) x (centroid id, dim)' = (query id, centroid id)
    runMatrixMult(distanceBufView, false,
                  queryView, false,
                  centroidsTransposed ? *centroidsTransposed : centroids,
                  centroidsTransposed ? false : true,
                  1.0f, 0.0f,
                  resources->getBlasHandleCurrentDevice(),
                  streams[curStream]);

    // top-k of dot products (largest first)
    // (query id, top k centroids)
    runBlockSelect(distanceBufView,
                   outDistanceView,
                   outIndexView,
                   true, k, streams[curStream]);

    curStream = (curStream + 1) % 2;
  }

  // Have the desired ordering stream wait on the multi-stream
  streamWait({defaultStream}, streams);
}
261 
262 //
263 // Instantiations of the distance templates
264 //
265 
// Single-precision entry point for the inner-product k-nearest-neighbor
// search: forwards every argument unchanged to runIPDistance<float>.
void runIPDistance(GpuResources* resources,
                   Tensor<float, 2, true>& vectors,
                   Tensor<float, 2, true>* vectorsTransposed,
                   Tensor<float, 2, true>& queries,
                   int k,
                   Tensor<float, 2, true>& outDistances,
                   Tensor<int, 2, true>& outIndices,
                   int tileSize) {
  return runIPDistance<float>(resources, vectors, vectorsTransposed, queries,
                              k, outDistances, outIndices, tileSize);
}
284 
#ifdef FAISS_USE_FLOAT16
// Half-precision entry point for the inner-product k-nearest-neighbor
// search (only built when FAISS_USE_FLOAT16 is defined): forwards every
// argument unchanged to runIPDistance<half>.
void runIPDistance(GpuResources* resources,
                   Tensor<half, 2, true>& vectors,
                   Tensor<half, 2, true>* vectorsTransposed,
                   Tensor<half, 2, true>& queries,
                   int k,
                   Tensor<half, 2, true>& outDistances,
                   Tensor<int, 2, true>& outIndices,
                   int tileSize) {
  return runIPDistance<half>(resources, vectors, vectorsTransposed, queries,
                             k, outDistances, outIndices, tileSize);
}
#endif // FAISS_USE_FLOAT16
305 
// Single-precision entry point for the squared-L2 k-nearest-neighbor
// search: forwards every argument unchanged to runL2Distance<float>.
void runL2Distance(GpuResources* resources,
                   Tensor<float, 2, true>& vectors,
                   Tensor<float, 2, true>* vectorsTransposed,
                   Tensor<float, 1, true>* vectorNorms,
                   Tensor<float, 2, true>& queries,
                   int k,
                   Tensor<float, 2, true>& outDistances,
                   Tensor<int, 2, true>& outIndices,
                   bool ignoreOutDistances,
                   int tileSize) {
  return runL2Distance<float>(resources, vectors, vectorsTransposed,
                              vectorNorms, queries, k, outDistances,
                              outIndices, ignoreOutDistances, tileSize);
}
328 
#ifdef FAISS_USE_FLOAT16
// Half-precision entry point for the squared-L2 k-nearest-neighbor search
// (only built when FAISS_USE_FLOAT16 is defined): forwards every argument
// unchanged to runL2Distance<half>.
void runL2Distance(GpuResources* resources,
                   Tensor<half, 2, true>& vectors,
                   Tensor<half, 2, true>* vectorsTransposed,
                   Tensor<half, 1, true>* vectorNorms,
                   Tensor<half, 2, true>& queries,
                   int k,
                   Tensor<half, 2, true>& outDistances,
                   Tensor<int, 2, true>& outIndices,
                   bool ignoreOutDistances,
                   int tileSize) {
  return runL2Distance<half>(resources, vectors, vectorsTransposed,
                             vectorNorms, queries, k, outDistances,
                             outIndices, ignoreOutDistances, tileSize);
}
#endif // FAISS_USE_FLOAT16
353 
354 } } // namespace