#include "L2Norm.cuh"
#include "../../FaissAssert.h"
#include "../utils/ConversionOperators.cuh"
#include "../utils/DeviceDefs.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/Float16.cuh"
#include "../utils/MathOperators.cuh"
#include "../utils/PtxUtils.cuh"
#include "../utils/StaticUtils.h"
#include "../utils/Reductions.cuh"

namespace faiss { namespace gpu {
// Input: (batch x dim)
// Output: (batch norm)
// T: the type we do the math in (e.g., float, half)
// TVec: the potentially vectorized type we load with (e.g., float4, half2)
template <typename T, typename TVec, typename IndexType,
          int RowTileSize, bool NormLoop, bool NormSquared>
__global__ void
l2NormRowMajor(Tensor<TVec, 2, true, IndexType> input,
               Tensor<T, 1, true, IndexType> output) {
  extern __shared__ char smemByte[];
  T* smem = (T*) smemByte;

  IndexType numWarps = utils::divUp(blockDim.x, kWarpSize);
  IndexType laneId = getLaneId();
  IndexType warpId = threadIdx.x / kWarpSize;

  bool lastRowTile = (blockIdx.x == (gridDim.x - 1));
  IndexType rowStart = RowTileSize * blockIdx.x;
  T rowNorm[RowTileSize];
  if (lastRowTile) {
    // We are handling the very end of the input matrix rows
    for (IndexType row = 0; row < input.getSize(0) - rowStart; ++row) {
      if (NormLoop) {
        rowNorm[0] = Math<T>::zero();

        for (IndexType col = threadIdx.x;
             col < input.getSize(1); col += blockDim.x) {
          TVec val = input[rowStart + row][col];
          val = Math<TVec>::mul(val, val);
          rowNorm[0] = Math<T>::add(rowNorm[0], Math<TVec>::reduceAdd(val));
        }
      } else {
        TVec val = input[rowStart + row][threadIdx.x];
        val = Math<TVec>::mul(val, val);
        rowNorm[0] = Math<TVec>::reduceAdd(val);
      }

      rowNorm[0] = warpReduceAllSum(rowNorm[0]);

      if (laneId == 0) {
        smem[row * numWarps + warpId] = rowNorm[0];
      }
    }
  } else {
    // We are guaranteed that all RowTileSize rows are available in
    // [rowStart, rowStart + RowTileSize)
    if (NormLoop) {
      // A single block of threads is not big enough to span each vector
      TVec tmp[RowTileSize];
#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        rowNorm[row] = Math<T>::zero();
      }

      for (IndexType col = threadIdx.x;
           col < input.getSize(1); col += blockDim.x) {
#pragma unroll
        for (int row = 0; row < RowTileSize; ++row) {
          tmp[row] = input[rowStart + row][col];
        }

#pragma unroll
        for (int row = 0; row < RowTileSize; ++row) {
          tmp[row] = Math<TVec>::mul(tmp[row], tmp[row]);
        }

#pragma unroll
        for (int row = 0; row < RowTileSize; ++row) {
          rowNorm[row] = Math<T>::add(rowNorm[row],
                                      Math<TVec>::reduceAdd(tmp[row]));
        }
      }
    } else {
      // A block of threads spans each vector exactly
      TVec tmp[RowTileSize];

#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        tmp[row] = input[rowStart + row][threadIdx.x];
      }

#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        tmp[row] = Math<TVec>::mul(tmp[row], tmp[row]);
      }

#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        rowNorm[row] = Math<TVec>::reduceAdd(tmp[row]);
      }
    }

    // Sum up all parts within each warp
#pragma unroll
    for (int row = 0; row < RowTileSize; ++row) {
      rowNorm[row] = warpReduceAllSum(rowNorm[row]);
    }
    if (laneId == 0) {
#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        smem[row * numWarps + warpId] = rowNorm[row];
      }
    }
  }

  __syncthreads();

  // Sum across warps
  if (warpId == 0) {
#pragma unroll
    for (int row = 0; row < RowTileSize; ++row) {
      rowNorm[row] = laneId < numWarps ?
        smem[row * numWarps + laneId] : Math<T>::zero();
    }

#pragma unroll
    for (int row = 0; row < RowTileSize; ++row) {
      rowNorm[row] = warpReduceAllSum(rowNorm[row]);
    }
    // Write out answer
    if (laneId == 0) {
#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        int outCol = rowStart + row;

        if (lastRowTile) {
          if (outCol < output.getSize(0)) {
            output[outCol] =
              NormSquared ? rowNorm[row] :
              ConvertTo<T>::to(
                sqrtf(ConvertTo<float>::to(rowNorm[row])));
          }
        } else {
          output[outCol] =
            NormSquared ? rowNorm[row] :
            ConvertTo<T>::to(
              sqrtf(ConvertTo<float>::to(rowNorm[row])));
        }
      }
    }
  }
}
// Input: (dim x batch)
// Output: (batch norm)
// Handles the case where `input` is column major. A single thread calculates
// the norm of each vector, instead of a block-wide reduction.
template <typename T, typename IndexType, bool NormSquared>
__global__ void
l2NormColMajor(Tensor<T, 2, true, IndexType> input,
               Tensor<T, 1, true, IndexType> output) {
  // A grid-stride loop over the batch (column) dimension
  for (IndexType batch = blockIdx.x * blockDim.x + threadIdx.x;
       batch < input.getSize(1);
       batch += gridDim.x * blockDim.x) {
    float sum = 0;

    for (IndexType dim = 0; dim < input.getSize(0); ++dim) {
      // do the math in float32, even if the data is float16
      float v = ConvertTo<float>::to(input[dim][batch]);
      sum += v * v;
    }

    if (!NormSquared) {
      sum = sqrtf(sum);
    }

    output[batch] = ConvertTo<T>::to(sum);
  }
}
template <typename T, typename TVec, typename IndexType>
void runL2Norm(Tensor<T, 2, true, IndexType>& input,
               bool inputRowMajor,
               Tensor<T, 1, true, IndexType>& output,
               bool normSquared,
               cudaStream_t stream) {
  IndexType maxThreads = (IndexType) getMaxThreadsCurrentDevice();
  constexpr int rowTileSize = 8;
#define RUN_L2_ROW_MAJOR(TYPE_T, TYPE_TVEC, INPUT)                              \
  do {                                                                          \
    if (normLoop) {                                                             \
      if (normSquared) {                                                        \
        l2NormRowMajor<TYPE_T, TYPE_TVEC, IndexType, rowTileSize, true, true>   \
          <<<grid, block, smem, stream>>>(INPUT, output);                       \
      } else {                                                                  \
        l2NormRowMajor<TYPE_T, TYPE_TVEC, IndexType, rowTileSize, true, false>  \
          <<<grid, block, smem, stream>>>(INPUT, output);                       \
      }                                                                         \
    } else {                                                                    \
      if (normSquared) {                                                        \
        l2NormRowMajor<TYPE_T, TYPE_TVEC, IndexType, rowTileSize, false, true>  \
          <<<grid, block, smem, stream>>>(INPUT, output);                       \
      } else {                                                                  \
        l2NormRowMajor<TYPE_T, TYPE_TVEC, IndexType, rowTileSize, false, false> \
          <<<grid, block, smem, stream>>>(INPUT, output);                       \
      }                                                                         \
    }                                                                           \
  } while (0)
  if (inputRowMajor) {
    //
    // Row-major kernel
    //

    if (input.template canCastResize<TVec>()) {
      // Can load using the vectorized type
      auto inputV = input.template castResize<TVec>();

      auto dim = inputV.getSize(1);
      bool normLoop = dim > maxThreads;
      auto numThreads = min(dim, maxThreads);

      auto grid = dim3(utils::divUp(inputV.getSize(0), rowTileSize));
      auto block = dim3(numThreads);

      auto smem = sizeof(T) * rowTileSize * utils::divUp(numThreads, kWarpSize);

      RUN_L2_ROW_MAJOR(T, TVec, inputV);
    } else {
      // Can't load using the vectorized type
      auto dim = input.getSize(1);
      bool normLoop = dim > maxThreads;
      auto numThreads = min(dim, maxThreads);

      auto grid = dim3(utils::divUp(input.getSize(0), rowTileSize));
      auto block = dim3(numThreads);

      auto smem = sizeof(T) * rowTileSize * utils::divUp(numThreads, kWarpSize);

      RUN_L2_ROW_MAJOR(T, T, input);
    }
  } else {
    //
    // Column-major kernel
    //

    // The kernel threads are fully independent, so just use a fixed-size block
    auto block = 128;
    // Cap the grid size at 2^16 since there is a grid size limit
    auto grid = (int)
      std::min(utils::divUp(input.getSize(1), (IndexType) block),
               (IndexType) 65536);

    if (normSquared) {
      l2NormColMajor<T, IndexType, true><<<grid, block, 0, stream>>>(
        input, output);
    } else {
      l2NormColMajor<T, IndexType, false><<<grid, block, 0, stream>>>(
        input, output);
    }
  }

#undef RUN_L2_ROW_MAJOR

  CUDA_TEST_ERROR();
}
void runL2Norm(Tensor<float, 2, true>& input,
               bool inputRowMajor,
               Tensor<float, 1, true>& output,
               bool normSquared,
               cudaStream_t stream) {
  if (input.canUseIndexType<int>()) {
    runL2Norm<float, float4, int>(
      input, inputRowMajor, output, normSquared, stream);
  } else {
    auto inputCast = input.castIndexType<long>();
    auto outputCast = output.castIndexType<long>();

    runL2Norm<float, float4, long>(
      inputCast, inputRowMajor, outputCast, normSquared, stream);
  }
}
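
// Usage sketch (illustrative only, not part of this file): assuming `vecs` is
// a row-major DeviceTensor<float, 2, true> of shape (numVecs x dim) and
// `norms` is a DeviceTensor<float, 1, true> of shape (numVecs), both resident
// on the current device, the squared L2 norm of every row could be computed
// with:
//
//   runL2Norm(vecs, true /* row major */, norms, true /* squared */, stream);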

#ifdef FAISS_USE_FLOAT16
void runL2Norm(Tensor<half, 2, true>& input,
               bool inputRowMajor,
               Tensor<half, 1, true>& output,
               bool normSquared,
               cudaStream_t stream) {
  if (input.canUseIndexType<int>()) {
    runL2Norm<half, half2, int>(
      input, inputRowMajor, output, normSquared, stream);
  } else {
    auto inputCast = input.castIndexType<long>();
    auto outputCast = output.castIndexType<long>();

    runL2Norm<half, half2, long>(
      inputCast, inputRowMajor, outputCast, normSquared, stream);
  }
}
#endif

} } // namespace

// Note: Math<TVec>::reduceAdd used above is declared in MathOperators.cuh as
// `static __device__ T reduceAdd(T v)`; for a vector type, this is a
// horizontal add, returning sum(v_i).
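//
// For illustration only -- a sketch (not part of faiss) of what that
// horizontal add looks like for float4; `reduceAddFloat4` is a hypothetical
// name used here purely for the example:
//
//   static inline __device__ float reduceAddFloat4(float4 v) {
//     // sum the four lanes of the vector
//     return v.x + v.y + v.z + v.w;
//   }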