#include "../../FaissAssert.h"
#include "../utils/ConversionOperators.cuh"
#include "../utils/DeviceDefs.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/Float16.cuh"
#include "../utils/MathOperators.cuh"
#include "../utils/PtxUtils.cuh"
#include "../utils/StaticUtils.h"
#include "../utils/Reductions.cuh"
namespace faiss { namespace gpu {
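// Computes the L2 norm (or squared L2 norm, if NormSquared) of each
// row of a (batch x dim) input, writing one value per row.
//
// T is the type the math is carried out in (e.g., float or half);
// TVec is the potentially vectorized type used for loads (e.g.,
// float4 or half2). Each block computes the norms of RowTileSize rows
// at a time; NormLoop is set when a single block-wide pass cannot
// cover an entire row, so threads must loop over columns.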
template <typename T, typename TVec,
          int RowTileSize, bool NormLoop, bool NormSquared>
__global__ void l2Norm(Tensor<TVec, 2, true> input,
                       Tensor<T, 1, true> output) {
  extern __shared__ char smemByte[]; // RowTileSize * numWarps elements of T
  T* smem = (T*) smemByte;
  int numWarps = utils::divUp(blockDim.x, kWarpSize);
  int laneId = getLaneId();
  int warpId = threadIdx.x / kWarpSize;
  bool lastRowTile = (blockIdx.x == (gridDim.x - 1));
  int rowStart = RowTileSize * blockIdx.x;
  T rowNorm[RowTileSize];

  if (lastRowTile) {
    // The last tile may cover fewer than RowTileSize rows; handle the
    // remaining rows one at a time
    for (int row = 0; row < input.getSize(0) - rowStart; ++row) {
      if (NormLoop) {
        rowNorm[0] = Math<T>::zero();

        for (int col = threadIdx.x;
             col < input.getSize(1); col += blockDim.x) {
          TVec val = input[rowStart + row][col];
          val = Math<TVec>::mul(val, val);
          rowNorm[0] = Math<T>::add(rowNorm[0], Math<TVec>::reduceAdd(val));
        }
      } else {
        TVec val = input[rowStart + row][threadIdx.x];
        val = Math<TVec>::mul(val, val);
        rowNorm[0] = Math<TVec>::reduceAdd(val);
      }

      rowNorm[0] = warpReduceAllSum(rowNorm[0]);

      if (laneId == 0) {
        smem[row * numWarps + warpId] = rowNorm[0];
      }
    }
  } else {
    // All RowTileSize rows in [rowStart, rowStart + RowTileSize) are
    // in bounds for this block
    if (NormLoop) {
      // A single pass of the block does not span the whole row
      TVec tmp[RowTileSize];

#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        rowNorm[row] = Math<T>::zero();
      }

      for (int col = threadIdx.x;
           col < input.getSize(1); col += blockDim.x) {
#pragma unroll
        for (int row = 0; row < RowTileSize; ++row) {
          tmp[row] = input[rowStart + row][col];
        }

#pragma unroll
        for (int row = 0; row < RowTileSize; ++row) {
          tmp[row] = Math<TVec>::mul(tmp[row], tmp[row]);
        }

#pragma unroll
        for (int row = 0; row < RowTileSize; ++row) {
          rowNorm[row] = Math<T>::add(rowNorm[row],
                                      Math<TVec>::reduceAdd(tmp[row]));
        }
      }
    } else {
      // The block exactly spans the row; a single load suffices
      TVec tmp[RowTileSize];

#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        tmp[row] = input[rowStart + row][threadIdx.x];
      }

#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        tmp[row] = Math<TVec>::mul(tmp[row], tmp[row]);
      }

#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        rowNorm[row] = Math<TVec>::reduceAdd(tmp[row]);
      }
    }
    // Sum each thread's partials within its warp
#pragma unroll
    for (int row = 0; row < RowTileSize; ++row) {
      rowNorm[row] = warpReduceAllSum(rowNorm[row]);
    }

    if (laneId == 0) {
#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        smem[row * numWarps + warpId] = rowNorm[row];
      }
    }
  }

  __syncthreads();
  // The first warp sums the per-warp partials from shared memory
  if (warpId == 0) {
#pragma unroll
    for (int row = 0; row < RowTileSize; ++row) {
      rowNorm[row] = laneId < numWarps ?
        smem[row * numWarps + laneId] : Math<T>::zero();
    }

#pragma unroll
    for (int row = 0; row < RowTileSize; ++row) {
      rowNorm[row] = warpReduceAllSum(rowNorm[row]);
    }
    // Write out the answer, taking the square root unless the squared
    // norm was requested
    if (laneId == 0) {
#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        int outCol = rowStart + row;

        if (lastRowTile) {
          if (outCol < output.getSize(0)) {
            output[outCol] =
              NormSquared ? rowNorm[row] :
              ConvertTo<T>::to(
                sqrtf(ConvertTo<float>::to(rowNorm[row])));
          }
        } else {
          output[outCol] =
            NormSquared ? rowNorm[row] :
            ConvertTo<T>::to(
              sqrtf(ConvertTo<float>::to(rowNorm[row])));
        }
      }
    }
  }
}
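// Host-side launcher. Chooses vectorized vs. scalar loads, sizes the
// grid so each block covers rowTileSize rows, and maps the runtime
// flags (normLoop, normSquared) onto the kernel's template parameters.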
template <typename T, typename TVec>
void runL2Norm(Tensor<T, 2, true>& input,
               Tensor<T, 1, true>& output,
               bool normSquared,
               cudaStream_t stream) {
  FAISS_ASSERT(input.getSize(0) == output.getSize(0));

  int maxThreads = getMaxThreadsCurrentDevice();
  constexpr int rowTileSize = 8;
#define RUN_L2(TYPE_T, TYPE_TVEC, INPUT)                       \
  do {                                                         \
    if (normLoop) {                                            \
      if (normSquared) {                                       \
        l2Norm<TYPE_T, TYPE_TVEC, rowTileSize, true, true>     \
          <<<grid, block, smem, stream>>>(INPUT, output);      \
      } else {                                                 \
        l2Norm<TYPE_T, TYPE_TVEC, rowTileSize, true, false>    \
          <<<grid, block, smem, stream>>>(INPUT, output);      \
      }                                                        \
    } else {                                                   \
      if (normSquared) {                                       \
        l2Norm<TYPE_T, TYPE_TVEC, rowTileSize, false, true>    \
          <<<grid, block, smem, stream>>>(INPUT, output);      \
      } else {                                                 \
        l2Norm<TYPE_T, TYPE_TVEC, rowTileSize, false, false>   \
          <<<grid, block, smem, stream>>>(INPUT, output);      \
      }                                                        \
    }                                                          \
  } while (0)
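// Dispatching the runtime flags to compile-time template parameters
// keeps the kernel's per-element inner loops free of runtime branches
// on them.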
  if (input.template canCastResize<TVec>()) {
    // The row length is a multiple of the vector width; use
    // vectorized loads
    auto inputV = input.template castResize<TVec>();

    int dim = inputV.getSize(1);
    bool normLoop = dim > maxThreads;
    int numThreads = min(dim, maxThreads);

    auto grid = dim3(utils::divUp(inputV.getSize(0), rowTileSize));
    auto block = dim3(numThreads);

    auto smem =
      sizeof(T) * rowTileSize * utils::divUp(numThreads, kWarpSize);

    RUN_L2(T, TVec, inputV);
  } else {
    // Fall back to scalar loads
    int dim = input.getSize(1);
    bool normLoop = dim > maxThreads;
    int numThreads = min(dim, maxThreads);

    auto grid = dim3(utils::divUp(input.getSize(0), rowTileSize));
    auto block = dim3(numThreads);

    auto smem =
      sizeof(T) * rowTileSize * utils::divUp(numThreads, kWarpSize);

    RUN_L2(T, T, input);
  }

#undef RUN_L2
}
void runL2Norm(Tensor<float, 2, true>& input,
               Tensor<float, 1, true>& output,
               bool normSquared,
               cudaStream_t stream) {
  runL2Norm<float, float4>(input, output, normSquared, stream);
}
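// A minimal usage sketch, for illustration only. It assumes the Tensor
// wrapper can be constructed over existing device memory with an
// initializer list of sizes; the pointer names are hypothetical:
//
//   // float* devVecs = ...;   // (num x dim), row-major, on device
//   // float* devNorms = ...;  // (num), on device
//   // Tensor<float, 2, true> input(devVecs, {num, dim});
//   // Tensor<float, 1, true> output(devNorms, {num});
//   // runL2Norm(input, output, true /* squared norm */, stream);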
#ifdef FAISS_USE_FLOAT16
void runL2Norm(Tensor<half, 2, true>& input,
               Tensor<half, 1, true>& output,
               bool normSquared,
               cudaStream_t stream) {
  runL2Norm<half, half2>(input, output, normSquared, stream);
}
#endif

} } // namespace gpu, namespace faiss
// For reference, the reduction primitive used above is declared in
// MathOperators.cuh as:
//
//   static __device__ T reduceAdd(T v)
//
// For a vector type, this is a horizontal add, returning sum(v_i).
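//
// A minimal sketch of such a specialization, for illustration only
// (the real definitions live in MathOperators.cuh; this is not the
// actual implementation):
//
//   template <>
//   struct Math<float4> {
//     static inline __device__ float reduceAdd(float4 v) {
//       // Horizontal add across the four vector lanes
//       return v.x + v.y + v.z + v.w;
//     }
//   };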