Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
VectorResidual.cu
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the CC-by-NC license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // Copyright 2004-present Facebook. All Rights Reserved.
10 #include "VectorResidual.cuh"
11 #include "../../FaissAssert.h"
12 #include "../utils/ConversionOperators.cuh"
13 #include "../utils/DeviceUtils.h"
14 #include "../utils/Tensor.cuh"
15 #include "../utils/StaticUtils.h"
16 #include <math_constants.h> // in CUDA SDK, for CUDART_NAN_F
17 
18 namespace faiss { namespace gpu {
19 
/// Per-row residual kernel:
///   residuals[r][d] = vecs[r][d] - float(centroids[vecToCentroid[r]][d])
///
/// Launch layout (as set up by the host-side calcResidual launcher):
///   - one block per row of `vecs`; blockIdx.x selects the row.
///   - LargeDim == false: blockDim.x must equal vecs.getSize(1), so each
///     thread owns exactly one component and no bounds check is done.
///   - LargeDim == true: threads loop over the components with a stride of
///     blockDim.x.
///
/// A centroid id of -1 marks a vector that could not be classified (e.g. it
/// contained NaNs); its residual row is filled with NaN. Other ids are used
/// as-is and are not range-checked here.
template <typename CentroidT, bool LargeDim>
__global__ void calcResidual(Tensor<float, 2, true> vecs,
                             Tensor<CentroidT, 2, true> centroids,
                             Tensor<int, 1, true> vecToCentroid,
                             Tensor<float, 2, true> residuals) {
  const int assigned = vecToCentroid[blockIdx.x];
  auto out = residuals[blockIdx.x];

  if (!LargeDim) {
    // Exactly one component per thread.
    const int d = threadIdx.x;

    if (assigned == -1) {
      out[d] = CUDART_NAN_F;
    } else {
      auto src = vecs[blockIdx.x];
      auto cent = centroids[assigned];
      out[d] = src[d] - ConvertTo<float>::to(cent[d]);
    }
    return;
  }

  // Dimension exceeds the block size: stride over the components.
  const int dim = vecs.getSize(1);

  if (assigned == -1) {
    for (int d = threadIdx.x; d < dim; d += blockDim.x) {
      out[d] = CUDART_NAN_F;
    }
    return;
  }

  auto src = vecs[blockIdx.x];
  auto cent = centroids[assigned];

  for (int d = threadIdx.x; d < dim; d += blockDim.x) {
    float centroidValue = ConvertTo<float>::to(cent[d]);
    out[d] = src[d] - centroidValue;
  }
}
54 
/// Host-side launcher for the residual kernel.
///
/// For every row i, computes
///   residuals[i] = vecs[i] - centroids[vecToCentroid[i]]
/// by enqueueing a kernel on `stream`. It does not synchronize the stream.
///
/// @param vecs          (n x d) input vectors
/// @param centroids     (k x d) centroids; elements are converted to float on
///                      the device before subtraction
/// @param vecToCentroid (n) centroid id assigned to each vector, or -1 when the
///                      vector could not be assigned (its residual row is NaN)
/// @param residuals     (n x d) output residuals
/// @param stream        stream on which the kernel is launched
template <typename CentroidT>
void calcResidual(Tensor<float, 2, true>& vecs,
                  Tensor<CentroidT, 2, true>& centroids,
                  Tensor<int, 1, true>& vecToCentroid,
                  Tensor<float, 2, true>& residuals,
                  cudaStream_t stream) {
  FAISS_ASSERT(vecs.getSize(1) == centroids.getSize(1));
  FAISS_ASSERT(vecs.getSize(1) == residuals.getSize(1));
  FAISS_ASSERT(vecs.getSize(0) == vecToCentroid.getSize(0));
  FAISS_ASSERT(vecs.getSize(0) == residuals.getSize(0));

  // Nothing to compute for an empty batch or zero-dimensional vectors. A
  // zero-sized grid or block is an invalid launch configuration, and the
  // error check after the launch would otherwise report it as a failure.
  if (vecs.getSize(0) == 0 || vecs.getSize(1) == 0) {
    return;
  }

  // One block per input row.
  dim3 grid(vecs.getSize(0));

  // If the dimension fits in a single block, use one thread per component.
  // Otherwise cap the block at the device limit and let the kernel stride.
  int maxThreads = getMaxThreadsCurrentDevice();
  bool largeDim = vecs.getSize(1) > maxThreads;
  dim3 block(std::min(vecs.getSize(1), maxThreads));

  if (largeDim) {
    calcResidual<CentroidT, true><<<grid, block, 0, stream>>>(
      vecs, centroids, vecToCentroid, residuals);
  } else {
    calcResidual<CentroidT, false><<<grid, block, 0, stream>>>(
      vecs, centroids, vecToCentroid, residuals);
  }

  CUDA_TEST_ERROR();
}
82 
/// Computes residuals[i] = vecs[i] - centroids[vecToCentroid[i]] for
/// full-precision (float) centroids, enqueued on `stream`.
/// Rows whose centroid id is -1 receive NaN residuals.
void runCalcResidual(Tensor<float, 2, true>& vecs,
                     Tensor<float, 2, true>& centroids,
                     Tensor<int, 1, true>& vecToCentroid,
                     Tensor<float, 2, true>& residuals,
                     cudaStream_t stream) {
  // Forward to the generic launcher instantiated for float centroids.
  return calcResidual<float>(vecs, centroids, vecToCentroid,
                             residuals, stream);
}
90 
91 #ifdef FAISS_USE_FLOAT16
/// Computes residuals[i] = vecs[i] - centroids[vecToCentroid[i]] for
/// half-precision centroids, enqueued on `stream`. Each centroid component
/// is converted to float on the device before the subtraction. Rows whose
/// centroid id is -1 receive NaN residuals.
void runCalcResidual(Tensor<float, 2, true>& vecs,
                     Tensor<half, 2, true>& centroids,
                     Tensor<int, 1, true>& vecToCentroid,
                     Tensor<float, 2, true>& residuals,
                     cudaStream_t stream) {
  // Forward to the generic launcher instantiated for half centroids.
  return calcResidual<half>(vecs, centroids, vecToCentroid,
                            residuals, stream);
}
99 #endif
100 
101 } } // namespace