13 #include "../../FaissAssert.h"
14 #include "../utils/ConversionOperators.cuh"
15 #include "../utils/DeviceDefs.cuh"
16 #include "../utils/DeviceUtils.h"
17 #include "../utils/Float16.cuh"
18 #include "../utils/MathOperators.cuh"
19 #include "../utils/PtxUtils.cuh"
20 #include "../utils/StaticUtils.h"
21 #include "../utils/Reductions.cuh"
23 namespace faiss {
namespace gpu {
// L2-norm kernel over the rows of `input`.
//
// Launch layout (set up by runL2Norm below):
//   grid:  divUp(numRows, RowTileSize) blocks along x; each block owns
//          RowTileSize consecutive rows (the last block may own fewer).
//   block: 1-D, spanning the column (dim) dimension.
//   smem:  sizeof(T) * RowTileSize * numWarps bytes dynamic shared memory,
//          holding one partial sum per (row-in-tile, warp) pair.
//
// Template parameters:
//   T / TVec:    scalar accumulator type and (possibly vectorized) load type
//   RowTileSize: rows handled per block
//   NormLoop:    true when one block pass is smaller than a row, so each
//                thread must loop over columns with a block-sized stride;
//                false when blockDim.x exactly covers a row
//   NormSquared: true -> write ||row||^2; false -> write ||row||
template <typename T, typename TVec,
          int RowTileSize, bool NormLoop, bool NormSquared>
__global__ void l2Norm(Tensor<TVec, 2, true> input,
                       Tensor<T, 1, true> output) {
  extern __shared__ char smemByte[]; // RowTileSize * numWarps elements of T
  T* smem = (T*) smemByte;

  int numWarps = utils::divUp(blockDim.x, kWarpSize);
  int laneId = getLaneId();
  int warpId = threadIdx.x / kWarpSize;

  // The last block may cover fewer than RowTileSize rows; it needs bounds
  // checks that interior blocks can skip.
  bool lastRowTile = (blockIdx.x == (gridDim.x - 1));
  int rowStart = RowTileSize * blockIdx.x;
  T rowNorm[RowTileSize];

  if (lastRowTile) {
    // Partial tile: process only the rows that actually exist, one at a
    // time, accumulating into rowNorm[0].
    for (int row = 0; row < input.getSize(0) - rowStart; ++row) {
      if (NormLoop) {
        rowNorm[0] = Math<T>::zero();

        for (int col = threadIdx.x;
             col < input.getSize(1); col += blockDim.x) {
          TVec val = input[rowStart + row][col];
          val = Math<TVec>::mul(val, val);
          rowNorm[0] = Math<T>::add(rowNorm[0], Math<TVec>::reduceAdd(val));
        }
      } else {
        // One load covers the whole row.
        TVec val = input[rowStart + row][threadIdx.x];
        val = Math<TVec>::mul(val, val);
        rowNorm[0] = Math<TVec>::reduceAdd(val);
      }

      rowNorm[0] = warpReduceAllSum(rowNorm[0]);
      if (laneId == 0) {
        smem[row * numWarps + warpId] = rowNorm[0];
      }
    }
  } else {
    // Full tile: all RowTileSize rows are in-bounds, so unroll across the
    // row tile for instruction-level parallelism.
    if (NormLoop) {
      // A single block pass does not span a whole row; stride over columns.
      TVec tmp[RowTileSize];

#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        rowNorm[row] = Math<T>::zero();
      }

      for (int col = threadIdx.x;
           col < input.getSize(1); col += blockDim.x) {
#pragma unroll
        for (int row = 0; row < RowTileSize; ++row) {
          tmp[row] = input[rowStart + row][col];
        }

#pragma unroll
        for (int row = 0; row < RowTileSize; ++row) {
          tmp[row] = Math<TVec>::mul(tmp[row], tmp[row]);
        }

#pragma unroll
        for (int row = 0; row < RowTileSize; ++row) {
          rowNorm[row] = Math<T>::add(rowNorm[row],
                                      Math<TVec>::reduceAdd(tmp[row]));
        }
      }
    } else {
      // The block exactly spans a row: one load per (thread, row).
      TVec tmp[RowTileSize];

#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        tmp[row] = input[rowStart + row][threadIdx.x];
      }

#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        tmp[row] = Math<TVec>::mul(tmp[row], tmp[row]);
      }

#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        rowNorm[row] = Math<TVec>::reduceAdd(tmp[row]);
      }
    }

    // Reduce each row's partial sum within the warp.
#pragma unroll
    for (int row = 0; row < RowTileSize; ++row) {
      rowNorm[row] = warpReduceAllSum(rowNorm[row]);
    }

    // Lane 0 of each warp publishes its per-row partials to shared memory.
    if (laneId == 0) {
#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        smem[row * numWarps + warpId] = rowNorm[row];
      }
    }
  }

  // All warps must have written their partials before warp 0 reads them.
  __syncthreads();

  // Warp 0 sums the per-warp partials and writes the final answer.
  if (warpId == 0) {
#pragma unroll
    for (int row = 0; row < RowTileSize; ++row) {
      rowNorm[row] = laneId < numWarps ?
        smem[row * numWarps + laneId] : Math<T>::zero();
    }

#pragma unroll
    for (int row = 0; row < RowTileSize; ++row) {
      rowNorm[row] = warpReduceAllSum(rowNorm[row]);
    }

    if (laneId == 0) {
#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        int outCol = rowStart + row;

        if (lastRowTile) {
          // Guard against rows past the end of the input.
          if (outCol < output.getSize(0)) {
            output[outCol] = NormSquared ?
              rowNorm[row] :
              ConvertTo<T>::to(
                sqrtf(ConvertTo<float>::to(rowNorm[row])));
          }
        } else {
          output[outCol] = NormSquared ?
            rowNorm[row] :
            ConvertTo<T>::to(
              sqrtf(ConvertTo<float>::to(rowNorm[row])));
        }
      }
    }
  }
}
// Computes the L2 norm (or squared norm, per `normSquared`) of each row of
// `input` into `output`, dispatching to the l2Norm kernel above on `stream`.
//
// T is the scalar element type; TVec is the vectorized load type to try
// first (used only when the row data can be reinterpreted as TVec).
// Requires input.getSize(0) == output.getSize(0).
template <typename T, typename TVec>
void runL2Norm(Tensor<T, 2, true>& input,
               Tensor<T, 1, true>& output,
               bool normSquared,
               cudaStream_t stream) {
  FAISS_ASSERT(input.getSize(0) == output.getSize(0));

  int maxThreads = getMaxThreadsCurrentDevice();
  constexpr int rowTileSize = 8;

  // Expand the (normLoop, normSquared) runtime flags into the four
  // compile-time kernel instantiations.
#define RUN_L2(TYPE_T, TYPE_TVEC, INPUT)                                \
  do {                                                                  \
    if (normLoop) {                                                     \
      if (normSquared) {                                                \
        l2Norm<TYPE_T, TYPE_TVEC, rowTileSize, true, true>              \
          <<<grid, block, smem, stream>>>(INPUT, output);               \
      } else {                                                          \
        l2Norm<TYPE_T, TYPE_TVEC, rowTileSize, true, false>             \
          <<<grid, block, smem, stream>>>(INPUT, output);               \
      }                                                                 \
    } else {                                                            \
      if (normSquared) {                                                \
        l2Norm<TYPE_T, TYPE_TVEC, rowTileSize, false, true>             \
          <<<grid, block, smem, stream>>>(INPUT, output);               \
      } else {                                                          \
        l2Norm<TYPE_T, TYPE_TVEC, rowTileSize, false, false>            \
          <<<grid, block, smem, stream>>>(INPUT, output);               \
      }                                                                 \
    }                                                                   \
  } while (0)

  if (input.template canCastResize<TVec>()) {
    // Rows are aligned/sized for vectorized loads.
    auto inputV = input.template castResize<TVec>();

    int dim = inputV.getSize(1);
    bool normLoop = dim > maxThreads; // must a block loop over columns?
    int numThreads = min(dim, maxThreads);

    auto grid = dim3(utils::divUp(inputV.getSize(0), rowTileSize));
    auto block = dim3(numThreads);

    // One partial sum per (row in tile, warp).
    auto smem =
      sizeof(T) * rowTileSize * utils::divUp(numThreads, kWarpSize);
    RUN_L2(T, TVec, inputV);
  } else {
    // Fall back to scalar loads.
    int dim = input.getSize(1);
    bool normLoop = dim > maxThreads;
    int numThreads = min(dim, maxThreads);

    auto grid = dim3(utils::divUp(input.getSize(0), rowTileSize));
    auto block = dim3(numThreads);

    auto smem =
      sizeof(T) * rowTileSize * utils::divUp(numThreads, kWarpSize);
    RUN_L2(T, T, input);
  }

#undef RUN_L2

  // Surface any launch-configuration errors from the kernel launch.
  CUDA_VERIFY(cudaGetLastError());
}
// float specialization: prefer float4 vectorized loads when the row data
// permits it (alignment and dim % 4 == 0, checked by canCastResize).
void runL2Norm(Tensor<float, 2, true>& input,
               Tensor<float, 1, true>& output,
               bool normSquared,
               cudaStream_t stream) {
  runL2Norm<float, float4>(input, output, normSquared, stream);
}
248 #ifdef FAISS_USE_FLOAT16
// half specialization: prefer half2 vectorized loads when the row data
// permits it (checked by canCastResize).
void runL2Norm(Tensor<half, 2, true>& input,
               Tensor<half, 1, true>& output,
               bool normSquared,
               cudaStream_t stream) {
  runL2Norm<half, half2>(input, output, normSquared, stream);
}
// NOTE(review): stray fragment from a header's documentation, kept for
// reference — Math<TVec>::reduceAdd is declared as
//   static __device__ T reduceAdd(T v)
// For a vector type, this is a horizontal add, returning sum(v_i).