Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
VectorResidual.cu
1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #include "VectorResidual.cuh"
9 #include "../../FaissAssert.h"
10 #include "../utils/ConversionOperators.cuh"
11 #include "../utils/DeviceUtils.h"
12 #include "../utils/Tensor.cuh"
13 #include "../utils/StaticUtils.h"
14 #include <math_constants.h> // in CUDA SDK, for CUDART_NAN_F
15 
16 namespace faiss { namespace gpu {
17 
/// Kernel computing, for the vector handled by this block (row = blockIdx.x):
///   residuals[row][d] = vecs[row][d] - float(centroids[vecToCentroid[row]][d])
/// or NaN for every dimension d when vecToCentroid[row] == -1.
///
/// Launch layout: one block per row of `vecs`.
///  - LargeDim == false: the block must have exactly vecs.getSize(1) threads
///    (one thread per dimension; there is no bounds check on threadIdx.x).
///  - LargeDim == true: any block size; each thread strides over the
///    dimensions in steps of blockDim.x.
/// No shared memory is used.
template <typename CentroidT, bool LargeDim>
__global__ void calcResidual(Tensor<float, 2, true> vecs,
                             Tensor<CentroidT, 2, true> centroids,
                             Tensor<int, 1, true> vecToCentroid,
                             Tensor<float, 2, true> residuals) {
  const int row = blockIdx.x;
  const int assigned = vecToCentroid[row];
  auto out = residuals[row];

  if (assigned == -1) {
    // No centroid was assigned to this vector (it may contain NaNs), so its
    // residual is undefined: fill the whole output row with NaN.
    if (!LargeDim) {
      out[threadIdx.x] = CUDART_NAN_F;
      return;
    }

    for (int d = threadIdx.x; d < vecs.getSize(1); d += blockDim.x) {
      out[d] = CUDART_NAN_F;
    }
    return;
  }

  auto in = vecs[row];
  auto centroid = centroids[assigned];

  if (!LargeDim) {
    // One thread per dimension.
    out[threadIdx.x] =
        in[threadIdx.x] - ConvertTo<float>::to(centroid[threadIdx.x]);
    return;
  }

  // Fewer threads than dimensions: stride across the row.
  for (int d = threadIdx.x; d < vecs.getSize(1); d += blockDim.x) {
    out[d] = in[d] - ConvertTo<float>::to(centroid[d]);
  }
}
52 
/// Host-side launcher that computes, for every row i of `vecs`,
///   residuals[i] = vecs[i] - centroids[vecToCentroid[i]]
/// with centroid elements converted to float. Rows whose vecToCentroid entry
/// is -1 receive an all-NaN residual (handled inside the kernel).
///
/// @param vecs          2-D input vectors; vecs.getSize(1) is the dimension
/// @param centroids     2-D centroids (float or half) with the same dimension
/// @param vecToCentroid centroid index per row of `vecs` (-1 = unassigned)
/// @param residuals     2-D output, same sizes as `vecs`
/// @param stream        stream on which the kernel is launched
///
/// Empty input (zero rows or zero dimensions) is a no-op.
template <typename CentroidT>
void calcResidual(Tensor<float, 2, true>& vecs,
                  Tensor<CentroidT, 2, true>& centroids,
                  Tensor<int, 1, true>& vecToCentroid,
                  Tensor<float, 2, true>& residuals,
                  cudaStream_t stream) {
  FAISS_ASSERT(vecs.getSize(1) == centroids.getSize(1));
  FAISS_ASSERT(vecs.getSize(1) == residuals.getSize(1));
  FAISS_ASSERT(vecs.getSize(0) == vecToCentroid.getSize(0));
  FAISS_ASSERT(vecs.getSize(0) == residuals.getSize(0));

  // Nothing to compute. A launch with a zero-sized grid or block is an
  // invalid configuration (cudaErrorInvalidConfiguration), which the
  // post-launch error check below would otherwise report as a failure.
  if (vecs.getSize(0) == 0 || vecs.getSize(1) == 0) {
    return;
  }

  // One block per input vector.
  dim3 grid(vecs.getSize(0));

  // One thread per dimension when the dimension fits in a block; otherwise
  // use the LargeDim kernel variant, whose threads stride over dimensions.
  int maxThreads = getMaxThreadsCurrentDevice();
  bool largeDim = vecs.getSize(1) > maxThreads;
  dim3 block(std::min(vecs.getSize(1), maxThreads));

  if (largeDim) {
    calcResidual<CentroidT, true><<<grid, block, 0, stream>>>(
        vecs, centroids, vecToCentroid, residuals);
  } else {
    calcResidual<CentroidT, false><<<grid, block, 0, stream>>>(
        vecs, centroids, vecToCentroid, residuals);
  }

  CUDA_TEST_ERROR();
}
80 
/// Residual computation for float32 centroids:
///   residuals[i] = vecs[i] - centroids[vecToCentroid[i]]
/// with NaN rows for entries whose vecToCentroid value is -1.
/// The work is launched on `stream`.
void runCalcResidual(Tensor<float, 2, true>& vecs,
                     Tensor<float, 2, true>& centroids,
                     Tensor<int, 1, true>& vecToCentroid,
                     Tensor<float, 2, true>& residuals,
                     cudaStream_t stream) {
  // Dispatch to the templated launcher with CentroidT = float.
  calcResidual<float>(vecs, centroids, vecToCentroid, residuals, stream);
}
88 
#ifdef FAISS_USE_FLOAT16
/// Residual computation for float16 (half) centroids. Centroid elements are
/// converted to float before subtraction, so the output is float32:
///   residuals[i] = vecs[i] - float(centroids[vecToCentroid[i]])
/// with NaN rows for entries whose vecToCentroid value is -1.
/// The work is launched on `stream`.
void runCalcResidual(Tensor<float, 2, true>& vecs,
                     Tensor<half, 2, true>& centroids,
                     Tensor<int, 1, true>& vecToCentroid,
                     Tensor<float, 2, true>& residuals,
                     cudaStream_t stream) {
  // Dispatch to the templated launcher with CentroidT = half.
  calcResidual<half>(vecs, centroids, vecToCentroid, residuals, stream);
}
#endif
98 
99 } } // namespace