#include "L2Select.cuh"
#include "../../FaissAssert.h"

#include "../utils/DeviceDefs.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/MathOperators.cuh"
#include "../utils/Pair.cuh"
#include "../utils/Reductions.cuh"
#include "../utils/Select.cuh"
#include "../utils/Tensor.cuh"
#include "../utils/StaticUtils.h"
namespace faiss { namespace gpu {
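// The kernels below fuse the final step of a decomposed L2 distance
// computation with per-row k-selection: each entry of productDistances
// (one row per query) has the matching precomputed centroid term from
// centroidDistances added to it, and the row minimum (k == 1) or the k
// smallest values (k > 1) are written to outDistances / outIndices.
// In the usual ||q - c||^2 = ||q||^2 - 2 q.c + ||c||^2 decomposition,
// productDistances would typically hold the terms other than ||c||^2,
// but these kernels only assume that the two inputs sum to the distance
// being minimized.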
// L2 + select kernel for k == 1, implements re-use of ||c||^2
template <typename T, int kRowsPerBlock, int kBlockSize>
__global__ void l2SelectMin1(Tensor<T, 2, true> productDistances,
                             Tensor<T, 1, true> centroidDistances,
                             Tensor<T, 2, true> outDistances,
                             Tensor<int, 2, true> outIndices) {
  // Each block handles kRowsPerBlock rows of the distances (results)
  Pair<T, int> threadMin[kRowsPerBlock];
  __shared__ Pair<T, int> blockMin[kRowsPerBlock * (kBlockSize / kWarpSize)];

  T distance[kRowsPerBlock];
#pragma unroll
  for (int i = 0; i < kRowsPerBlock; ++i) {
    threadMin[i].k = Limits<T>::getMax();
    threadMin[i].v = -1;
  }
  // blockIdx.x: which chunk of rows we are responsible for updating
  int rowStart = blockIdx.x * kRowsPerBlock;

  // The last block may be responsible for fewer than kRowsPerBlock rows
  bool endRow = (blockIdx.x == gridDim.x - 1);

  if (endRow) {
    if (productDistances.getSize(0) % kRowsPerBlock == 0) {
      endRow = false;
    }
  }

  if (endRow) {
    // Handle the remaining rows one at a time
    for (int row = rowStart; row < productDistances.getSize(0); ++row) {
      for (int col = threadIdx.x; col < productDistances.getSize(1);
           col += blockDim.x) {
        distance[0] = Math<T>::add(centroidDistances[col],
                                   productDistances[row][col]);

        if (Math<T>::lt(distance[0], threadMin[0].k)) {
          threadMin[0].k = distance[0];
          threadMin[0].v = col;
        }
      }
      // Reduce within the block
      threadMin[0] =
        blockReduceAll<Pair<T, int>, Min<Pair<T, int> >, false, false>(
          threadMin[0], Min<Pair<T, int> >(), blockMin);
      if (threadIdx.x == 0) {
        outDistances[row][0] = threadMin[0].k;
        outIndices[row][0] = threadMin[0].v;
      }

      // so we can re-use the shared memory for the next row
      __syncthreads();

      threadMin[0].k = Limits<T>::getMax();
      threadMin[0].v = -1;
    }
  } else {
    for (int col = threadIdx.x; col < productDistances.getSize(1);
         col += blockDim.x) {
      T centroidDistance = centroidDistances[col];
#pragma unroll
      for (int row = 0; row < kRowsPerBlock; ++row) {
        distance[row] = productDistances[rowStart + row][col];
      }

#pragma unroll
      for (int row = 0; row < kRowsPerBlock; ++row) {
        distance[row] = Math<T>::add(distance[row], centroidDistance);
      }

#pragma unroll
      for (int row = 0; row < kRowsPerBlock; ++row) {
        if (Math<T>::lt(distance[row], threadMin[row].k)) {
          threadMin[row].k = distance[row];
          threadMin[row].v = col;
        }
      }
    }
    // Reduce within the block
    blockReduceAll<kRowsPerBlock, Pair<T, int>, Min<Pair<T, int> >,
                   false, false>(
      threadMin, Min<Pair<T, int> >(), blockMin);
    if (threadIdx.x == 0) {
#pragma unroll
      for (int row = 0; row < kRowsPerBlock; ++row) {
        outDistances[rowStart + row][0] = threadMin[row].k;
        outIndices[rowStart + row][0] = threadMin[row].v;
      }
    }
  }
}
// L2 + select kernel for k > 1, no re-use of ||c||^2
template <typename T, int NumWarpQ, int NumThreadQ, int ThreadsPerBlock>
__global__ void l2SelectMinK(Tensor<T, 2, true> productDistances,
                             Tensor<T, 1, true> centroidDistances,
                             Tensor<T, 2, true> outDistances,
                             Tensor<int, 2, true> outIndices,
                             int k, T initK) {
  // Each block handles a single row of the distances (results)
  constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;
  __shared__ T smemK[kNumWarps * NumWarpQ];
  __shared__ int smemV[kNumWarps * NumWarpQ];
  BlockSelect<T, int, false, Comparator<T>,
              NumWarpQ, NumThreadQ, ThreadsPerBlock>
    heap(initK, -1, smemK, smemV, k);
  int row = blockIdx.x;

  // Whole warps must participate in the selection
  int limit = utils::roundDown(productDistances.getSize(1), kWarpSize);
  int i = threadIdx.x;
  for (; i < limit; i += blockDim.x) {
    T v = Math<T>::add(centroidDistances[i],
                       productDistances[row][i]);
    heap.add(v, i);
  }
  // Handle the remainder if the number of columns is not a warp multiple
  if (i < productDistances.getSize(1)) {
    T v = Math<T>::add(centroidDistances[i],
                       productDistances[row][i]);
    heap.addThreadQ(v, i);
  }

  heap.reduce();
  for (int i = threadIdx.x; i < k; i += blockDim.x) {
    outDistances[row][i] = smemK[i];
    outIndices[row][i] = smemV[i];
  }
}
template <typename T>
void runL2SelectMin(Tensor<T, 2, true>& productDistances,
                    Tensor<T, 1, true>& centroidDistances,
                    Tensor<T, 2, true>& outDistances,
                    Tensor<int, 2, true>& outIndices,
                    int k,
                    cudaStream_t stream) {
  FAISS_ASSERT(productDistances.getSize(0) == outDistances.getSize(0));
  FAISS_ASSERT(productDistances.getSize(0) == outIndices.getSize(0));
  FAISS_ASSERT(centroidDistances.getSize(0) == productDistances.getSize(1));
  FAISS_ASSERT(outDistances.getSize(1) == k);
  FAISS_ASSERT(outIndices.getSize(1) == k);
  FAISS_ASSERT(k <= GPU_MAX_SELECTION_K);
  if (k == 1) {
    constexpr int kThreadsPerBlock = 256;
    constexpr int kRowsPerBlock = 8;

    auto block = dim3(kThreadsPerBlock);
    auto grid = dim3(utils::divUp(outDistances.getSize(0), kRowsPerBlock));
    l2SelectMin1<T, kRowsPerBlock, kThreadsPerBlock>
      <<<grid, block, 0, stream>>>(productDistances, centroidDistances,
                                   outDistances, outIndices);
  } else {
    auto grid = dim3(outDistances.getSize(0));
#define RUN_L2_SELECT(BLOCK, NUM_WARP_Q, NUM_THREAD_Q)                    \
    do {                                                                   \
      l2SelectMinK<T, NUM_WARP_Q, NUM_THREAD_Q, BLOCK>                     \
        <<<grid, BLOCK, 0, stream>>>(productDistances, centroidDistances,  \
                                     outDistances, outIndices,             \
                                     k, Limits<T>::getMax());              \
    } while (0)
    // block size 128 for everything <= 1024
    if (k <= 32) {
      RUN_L2_SELECT(128, 32, 2);
    } else if (k <= 64) {
      RUN_L2_SELECT(128, 64, 3);
    } else if (k <= 128) {
      RUN_L2_SELECT(128, 128, 3);
    } else if (k <= 256) {
      RUN_L2_SELECT(128, 256, 4);
    } else if (k <= 512) {
      RUN_L2_SELECT(128, 512, 8);
    } else if (k <= 1024) {
      RUN_L2_SELECT(128, 1024, 8);

#if GPU_MAX_SELECTION_K >= 2048
    } else if (k <= 2048) {
      // smaller block for less shared memory
      RUN_L2_SELECT(64, 2048, 8);
#endif

    } else {
      FAISS_ASSERT(false);
    }
  }

  CUDA_TEST_ERROR();
}
void runL2SelectMin(Tensor<float, 2, true>& productDistances,
                    Tensor<float, 1, true>& centroidDistances,
                    Tensor<float, 2, true>& outDistances,
                    Tensor<int, 2, true>& outIndices,
                    int k,
                    cudaStream_t stream) {
  runL2SelectMin<float>(productDistances, centroidDistances,
                        outDistances, outIndices, k, stream);
}
#ifdef FAISS_USE_FLOAT16
void runL2SelectMin(Tensor<half, 2, true>& productDistances,
                    Tensor<half, 1, true>& centroidDistances,
                    Tensor<half, 2, true>& outDistances,
                    Tensor<int, 2, true>& outIndices,
                    int k,
                    cudaStream_t stream) {
  runL2SelectMin<half>(productDistances, centroidDistances,
                       outDistances, outIndices, k, stream);
}
#endif

} } // namespace
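// Illustrative usage sketch: a caller that has already filled
// productDistances (per-query partial distances) and centroidDistances
// (per-centroid terms), with DeviceTensor buffers (e.g. from
// ../utils/DeviceTensor.cuh) sized as noted, might invoke the float entry
// point as:
//
//   faiss::gpu::runL2SelectMin(productDistances,   // [numQueries][numCentroids]
//                              centroidDistances,  // [numCentroids]
//                              outDistances,       // [numQueries][k]
//                              outIndices,         // [numQueries][k]
//                              k,
//                              stream);
//
// The output tensors must already be sized to [numQueries][k]; the asserts
// in runL2SelectMin enforce this and that k <= GPU_MAX_SELECTION_K.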