Transpose.cuh
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */


#pragma once

#include "../../FaissAssert.h"
#include "Tensor.cuh"
#include "DeviceUtils.h"
#include "StaticUtils.h"
#include <cuda.h>

#include <algorithm>
#include <limits>
#include <utility>

namespace faiss { namespace gpu {

template <typename T, typename IndexT>
struct TensorInfo {
  static constexpr int kMaxDims = 8;

  T* data;
  IndexT sizes[kMaxDims];
  IndexT strides[kMaxDims];
  int dims;
};
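
// Illustrative note (added for exposition; not part of the original source):
// for a contiguous 2 x 3 x 4 tensor of floats, the TensorInfo<float, unsigned int>
// produced by getTensorInfo below holds dims = 3, sizes = {2, 3, 4} and
// strides = {12, 4, 1}, with strides measured in elements rather than bytes.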

template <typename T, typename IndexT, int Dim>
struct TensorInfoOffset {
  __device__ inline static unsigned int get(const TensorInfo<T, IndexT>& info,
                                             IndexT linearId) {
    IndexT offset = 0;

#pragma unroll
    for (int i = Dim - 1; i >= 0; --i) {
      IndexT curDimIndex = linearId % info.sizes[i];
      IndexT curDimOffset = curDimIndex * info.strides[i];

      offset += curDimOffset;

      if (i > 0) {
        linearId /= info.sizes[i];
      }
    }

    return offset;
  }
};
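
// Worked example (added for exposition; not part of the original source): with
// sizes = {2, 3} and strides = {3, 1} (a contiguous 2 x 3 tensor), linearId = 4
// decomposes innermost-first as 4 % 3 = 1 (contributing 1 * 1), then 4 / 3 = 1
// in dimension 0 (contributing 1 * 3), giving offset 4. runTransposeAny below
// swaps the input's sizes and strides at the transposed dimensions, so the same
// decomposition walks the input in transposed order while the output, handled
// by the -1 specialization, is addressed by the linear index directly.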

template <typename T, typename IndexT>
struct TensorInfoOffset<T, IndexT, -1> {
  __device__ inline static unsigned int get(const TensorInfo<T, IndexT>& info,
                                             IndexT linearId) {
    return linearId;
  }
};

template <typename T, typename IndexT, int Dim>
TensorInfo<T, IndexT> getTensorInfo(const Tensor<T, Dim, true>& t) {
  TensorInfo<T, IndexT> info;

  for (int i = 0; i < Dim; ++i) {
    info.sizes[i] = (IndexT) t.getSize(i);
    info.strides[i] = (IndexT) t.getStride(i);
  }

  info.data = t.data();
  info.dims = Dim;

  return info;
}

template <typename T, typename IndexT, int DimInput, int DimOutput>
__global__ void transposeAny(TensorInfo<T, IndexT> input,
                             TensorInfo<T, IndexT> output,
                             IndexT totalSize) {
  // Grid-stride loop over all elements
  for (IndexT i = blockIdx.x * blockDim.x + threadIdx.x;
       i < totalSize;
       i += gridDim.x * blockDim.x) {
    auto inputOffset = TensorInfoOffset<T, IndexT, DimInput>::get(input, i);
    auto outputOffset = TensorInfoOffset<T, IndexT, DimOutput>::get(output, i);

#if __CUDA_ARCH__ >= 350
    output.data[outputOffset] = __ldg(&input.data[inputOffset]);
#else
    output.data[outputOffset] = input.data[inputOffset];
#endif
  }
}

/// Performs an out-of-place transposition between any two dimensions.
/// Performance is best when the transposed dimensions are not the
/// innermost ones, since then reads and writes remain coalesced.
/// A shared-memory transposition could be added for the case where the
/// innermost dimensions are transposed, but it would require support for
/// arbitrary rectangular matrices.
/// This linearized implementation performs well enough, especially for
/// the cases we care about (outer-dimension transpositions).
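///
/// Example (illustrative sketch added for exposition; assumes `in`, `out` and
/// `stream` already exist, with `in` sized {batch, dim, k} and `out` sized
/// {dim, batch, k}):
///   runTransposeAny(in, 0, 1, out, stream);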
template <typename T, int Dim>
void runTransposeAny(Tensor<T, Dim, true>& in,
                     int dim1, int dim2,
                     Tensor<T, Dim, true>& out,
                     cudaStream_t stream) {
  static_assert(Dim <= TensorInfo<T, unsigned int>::kMaxDims,
                "too many dimensions");

  FAISS_ASSERT(dim1 != dim2);
  FAISS_ASSERT(dim1 < Dim && dim2 < Dim);

  int outSize[Dim];

  for (int i = 0; i < Dim; ++i) {
    outSize[i] = in.getSize(i);
  }

  std::swap(outSize[dim1], outSize[dim2]);

  for (int i = 0; i < Dim; ++i) {
    FAISS_ASSERT(out.getSize(i) == outSize[i]);
  }

  size_t totalSize = in.numElements();
  size_t block = std::min((size_t) getMaxThreadsCurrentDevice(), totalSize);

  if (totalSize <= (size_t) std::numeric_limits<int>::max()) {
    // div/mod seems faster with unsigned types
    auto inInfo = getTensorInfo<T, unsigned int, Dim>(in);
    auto outInfo = getTensorInfo<T, unsigned int, Dim>(out);

    std::swap(inInfo.sizes[dim1], inInfo.sizes[dim2]);
    std::swap(inInfo.strides[dim1], inInfo.strides[dim2]);

    auto grid = std::min(utils::divUp(totalSize, block), (size_t) 4096);

    transposeAny<T, unsigned int, Dim, -1>
      <<<grid, block, 0, stream>>>(inInfo, outInfo, totalSize);
  } else {
    auto inInfo = getTensorInfo<T, unsigned long, Dim>(in);
    auto outInfo = getTensorInfo<T, unsigned long, Dim>(out);

    std::swap(inInfo.sizes[dim1], inInfo.sizes[dim2]);
    std::swap(inInfo.strides[dim1], inInfo.strides[dim2]);

    auto grid = std::min(utils::divUp(totalSize, block), (size_t) 4096);

    transposeAny<T, unsigned long, Dim, -1>
      <<<grid, block, 0, stream>>>(inInfo, outInfo, totalSize);
  }
  CUDA_TEST_ERROR();
}

} } // namespace