BinaryDistance.cu
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "../utils/DeviceTensor.cuh"
#include "../utils/DeviceDefs.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/Select.cuh"

namespace faiss { namespace gpu {

// Number of warps that the kernel is instantiated with
constexpr int kWarps = 8;
constexpr int kLanes = kWarpSize;

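// Sentinel used to initialize the k-selection and assigned to out-of-bounds
// (query, vec) pairs, so that they can never appear among the real results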
constexpr int kMaxDistance = std::numeric_limits<int>::max();

// Performs a binary matrix multiplication, returning the lowest k results in
// `vecs` for each `query` in terms of Hamming distance (a fused kernel)
// Each warp calculates distance for a single query
template <int NumWarpQ,
          int NumThreadQ,
          typename BinaryType>
__launch_bounds__(kWarps * kLanes)
__global__ void binaryDistanceAnySize(const Tensor<BinaryType, 2, true> vecs,
                                      const Tensor<BinaryType, 2, true> query,
                                      Tensor<int, 2, true> outK,
                                      Tensor<int, 2, true> outV,
                                      int k) {
  // A matrix tile (query, k)
  __shared__ BinaryType queryTile[kWarps][kLanes + 1]; // avoid bank conflict

  // B matrix tile (vec, k)
  __shared__ BinaryType vecTile[kLanes][kLanes + 1]; // avoid bank conflict

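  // Per-warp k-selection of the smallest Hamming distances seen so far,
  // seeded with the kMaxDistance sentinel and id -1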
  WarpSelect<int, int, false, Comparator<int>,
             NumWarpQ, NumThreadQ, kWarps * kLanes>
    heap(kMaxDistance, -1, k);

  int warpId = threadIdx.y;
  int laneId = threadIdx.x;

  // Each warp handles a single query
  int warpQuery = blockIdx.x * kWarps + warpId;
  bool queryInBounds = warpQuery < query.getSize(0);

  // Each warp loops through the entire chunk of vectors
  for (int blockVec = 0; blockVec < vecs.getSize(0); blockVec += kLanes) {
    int threadDistance = 0;

    // Reduction dimension
    for (int blockK = 0; blockK < vecs.getSize(1); blockK += kLanes) {
      int laneK = blockK + laneId;
      bool kInBounds = laneK < vecs.getSize(1);

      queryTile[warpId][laneId] = queryInBounds && kInBounds ?
        query[warpQuery][laneK] : 0;

      // kWarps warps are responsible for loading 32 vecs
#pragma unroll
      for (int i = 0; i < kLanes / kWarps; ++i) {
        int warpVec = i * kWarps + warpId;
        int vec = blockVec + warpVec;
        bool vecInBounds = vec < vecs.getSize(0);

        vecTile[warpVec][laneId] = vecInBounds && kInBounds ?
          vecs[vec][laneK] : 0;
      }

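      // Make sure both tiles are fully populated before any lane reads them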
      __syncthreads();

      // Compare distances
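      // XOR plus popcount accumulates this tile's contribution to the Hamming
      // distance between the warp's query and this lane's vector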
#pragma unroll
      for (int i = 0; i < kLanes; ++i) {
        threadDistance += __popc(queryTile[warpId][i] ^ vecTile[laneId][i]);
      }

      __syncthreads();
    }

    // Lanes within a warp are different vec results against the same query
    // Only submit distances which represent real (query, vec) pairs
    bool valInBounds = queryInBounds && (blockVec + laneId < vecs.getSize(0));
    threadDistance = valInBounds ? threadDistance : kMaxDistance;
    int id = valInBounds ? blockVec + laneId : -1;

    heap.add(threadDistance, id);
  }

  heap.reduce();

  if (warpQuery < query.getSize(0)) {
    heap.writeOut(outK[warpQuery].data(),
                  outV[warpQuery].data(),
                  k);
  }
}

// Version of the kernel that avoids a loop over the reduction dimension, and
// thus avoids reloading the query vectors
template <int NumWarpQ,
          int NumThreadQ,
          typename BinaryType,
          int ReductionLimit = kLanes>
__global__ void
__launch_bounds__(kWarps * kLanes)
binaryDistanceLimitSize(const Tensor<BinaryType, 2, true> vecs,
                        const Tensor<BinaryType, 2, true> query,
                        Tensor<int, 2, true> outK,
                        Tensor<int, 2, true> outV,
                        int k) {
  // A matrix tile (query, k)
  __shared__ BinaryType queryTile[kWarps][kLanes + 1]; // avoid bank conflict

  // B matrix tile (vec, k)
  __shared__ BinaryType vecTile[kLanes][kLanes + 1]; // avoid bank conflict

  WarpSelect<int, int, false, Comparator<int>,
             NumWarpQ, NumThreadQ, kWarps * kLanes>
    heap(kMaxDistance, -1, k);

  int warpId = threadIdx.y;
  int laneId = threadIdx.x;

  // Each warp handles a single query
  int laneK = laneId;
  int warpQuery = blockIdx.x * kWarps + warpId;
  bool kInBounds = laneK < vecs.getSize(1);
  bool queryInBounds = warpQuery < query.getSize(0);

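  // The whole reduction dimension fits in a single tile, so each warp loads
  // its query words only once, outside the loop over vectors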
  queryTile[warpId][laneId] = queryInBounds && kInBounds ?
    query[warpQuery][laneK] : 0;

  // Each warp loops through the entire chunk of vectors
  for (int blockVec = 0; blockVec < vecs.getSize(0); blockVec += kLanes) {
    int threadDistance = 0;

    // kWarps warps are responsible for loading 32 vecs
#pragma unroll
    for (int i = 0; i < kLanes / kWarps; ++i) {
      int warpVec = i * kWarps + warpId;
      int vec = blockVec + warpVec;
      bool vecInBounds = vec < vecs.getSize(0);

      vecTile[warpVec][laneId] = vecInBounds && kInBounds ?
        vecs[vec][laneK] : 0;
    }

    __syncthreads();

    // Compare distances
#pragma unroll
    for (int i = 0; i < ReductionLimit; ++i) {
      threadDistance += __popc(queryTile[warpId][i] ^ vecTile[laneId][i]);
    }

    __syncthreads();

    // Lanes within a warp are different vec results against the same query
    // Only submit distances which represent real (query, vec) pairs
    bool valInBounds = queryInBounds && (blockVec + laneId < vecs.getSize(0));
    threadDistance = valInBounds ? threadDistance : kMaxDistance;
    int id = valInBounds ? blockVec + laneId : -1;

    heap.add(threadDistance, id);
  }

  heap.reduce();

  if (warpQuery < query.getSize(0)) {
    heap.writeOut(outK[warpQuery].data(),
                  outV[warpQuery].data(),
                  k);
  }
}

template <typename BinaryType>
void runBinaryDistanceAnySize(Tensor<BinaryType, 2, true>& vecs,
                              Tensor<BinaryType, 2, true>& query,
                              Tensor<int, 2, true>& outK,
                              Tensor<int, 2, true>& outV,
                              int k, cudaStream_t stream) {
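  // One warp per query: each block runs kWarps warps of kLanes threads and
  // covers kWarps queries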
  dim3 grid(utils::divUp(query.getSize(0), kWarps));
  dim3 block(kLanes, kWarps);

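  // Dispatch on k: choose WarpSelect template parameters
  // (NumWarpQ / NumThreadQ) sized for the requested k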
  if (k == 1) {
    binaryDistanceAnySize<1, 1, BinaryType>
      <<<grid, block, 0, stream>>>(
        vecs, query, outK, outV, k);
  } else if (k <= 32) {
    binaryDistanceAnySize<32, 2, BinaryType>
      <<<grid, block, 0, stream>>>(
        vecs, query, outK, outV, k);
  } else if (k <= 64) {
    binaryDistanceAnySize<64, 3, BinaryType>
      <<<grid, block, 0, stream>>>(
        vecs, query, outK, outV, k);
  } else if (k <= 128) {
    binaryDistanceAnySize<128, 3, BinaryType>
      <<<grid, block, 0, stream>>>(
        vecs, query, outK, outV, k);
  } else if (k <= 256) {
    binaryDistanceAnySize<256, 4, BinaryType>
      <<<grid, block, 0, stream>>>(
        vecs, query, outK, outV, k);
  } else if (k <= 512) {
    binaryDistanceAnySize<512, 8, BinaryType>
      <<<grid, block, 0, stream>>>(
        vecs, query, outK, outV, k);
  } else if (k <= 1024) {
    binaryDistanceAnySize<1024, 8, BinaryType>
      <<<grid, block, 0, stream>>>(
        vecs, query, outK, outV, k);
  }
#if GPU_MAX_SELECTION_K >= 2048
  else if (k <= 2048) {
    binaryDistanceAnySize<2048, 8, BinaryType>
      <<<grid, block, 0, stream>>>(
        vecs, query, outK, outV, k);
  }
#endif
}

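// Same k-based dispatch for the fixed-size kernel; callers guarantee that the
// entire binary code fits in at most ReductionLimit words of BinaryType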
template <typename BinaryType, int ReductionLimit>
void runBinaryDistanceLimitSize(Tensor<BinaryType, 2, true>& vecs,
                                Tensor<BinaryType, 2, true>& query,
                                Tensor<int, 2, true>& outK,
                                Tensor<int, 2, true>& outV,
                                int k, cudaStream_t stream) {
  dim3 grid(utils::divUp(query.getSize(0), kWarps));
  dim3 block(kLanes, kWarps);

  if (k == 1) {
    binaryDistanceLimitSize<1, 1, BinaryType, ReductionLimit>
      <<<grid, block, 0, stream>>>(
        vecs, query, outK, outV, k);
  } else if (k <= 32) {
    binaryDistanceLimitSize<32, 2, BinaryType, ReductionLimit>
      <<<grid, block, 0, stream>>>(
        vecs, query, outK, outV, k);
  } else if (k <= 64) {
    binaryDistanceLimitSize<64, 3, BinaryType, ReductionLimit>
      <<<grid, block, 0, stream>>>(
        vecs, query, outK, outV, k);
  } else if (k <= 128) {
    binaryDistanceLimitSize<128, 3, BinaryType, ReductionLimit>
      <<<grid, block, 0, stream>>>(
        vecs, query, outK, outV, k);
  } else if (k <= 256) {
    binaryDistanceLimitSize<256, 4, BinaryType, ReductionLimit>
      <<<grid, block, 0, stream>>>(
        vecs, query, outK, outV, k);
  } else if (k <= 512) {
    binaryDistanceLimitSize<512, 8, BinaryType, ReductionLimit>
      <<<grid, block, 0, stream>>>(
        vecs, query, outK, outV, k);
  } else if (k <= 1024) {
    binaryDistanceLimitSize<1024, 8, BinaryType, ReductionLimit>
      <<<grid, block, 0, stream>>>(
        vecs, query, outK, outV, k);
  }
#if GPU_MAX_SELECTION_K >= 2048
  else if (k <= 2048) {
    binaryDistanceLimitSize<2048, 8, BinaryType, ReductionLimit>
      <<<grid, block, 0, stream>>>(
        vecs, query, outK, outV, k);
  }
#endif
}

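// Computes, for each query, the k nearest vectors in `vecs` under Hamming
// distance over packed binary codes (8 dimensions per byte); distances are
// written to outK and the matching vector ids to outV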
void runBinaryDistance(Tensor<unsigned char, 2, true>& vecs,
                       Tensor<unsigned char, 2, true>& query,
                       Tensor<int, 2, true>& outK,
                       Tensor<int, 2, true>& outV,
                       int k, cudaStream_t stream) {
  FAISS_ASSERT(k <= GPU_MAX_SELECTION_K);
  FAISS_ASSERT(vecs.getSize(1) == query.getSize(1));

  FAISS_ASSERT(outK.getSize(1) == k);
  FAISS_ASSERT(outV.getSize(1) == k);

  // For the optimized uint32 kernel, we handle 32 * 8 = 256 max dims
  constexpr int kReductionLimit32 = 8;

  // For the optimized uint8 kernel, we handle 8 * 16 = 128 max dims
  constexpr int kReductionLimit8 = 16;

  // All other cases (large or small) go through the general kernel

  if (vecs.getSize(1) % sizeof(unsigned int) == 0 &&
      (vecs.getSize(1) / sizeof(unsigned int)) <= kReductionLimit32) {
    auto vecs32 = vecs.castResize<unsigned int>();
    auto query32 = query.castResize<unsigned int>();

    // Optimize for vectors with dimensions a multiple of 32 that are at most
    // 32 * kReductionLimit32 (256) dimensions in size
    runBinaryDistanceLimitSize<unsigned int, kReductionLimit32>(
      vecs32, query32, outK, outV, k, stream);

  } else if (vecs.getSize(1) <= kReductionLimit8) {
    // Optimize for vectors of at most kReductionLimit8 (16) bytes, i.e.
    // 8 * 16 = 128 dimensions, processed at uint8 granularity
    runBinaryDistanceLimitSize<unsigned char, kReductionLimit8>(
      vecs, query, outK, outV, k, stream);
  } else {
    // Arbitrary size kernel
    runBinaryDistanceAnySize<unsigned char>(
      vecs, query, outK, outV, k, stream);
  }
}

} } // namespace
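
// Illustrative usage (not part of this file): with device tensors of packed
// binary codes shaped [num x (d / 8)] for `vecs` and `query`, and output
// tensors shaped [numQuery x k], a caller would launch the search as
//   runBinaryDistance(vecs, query, outK, outV, k, stream);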