#include "../../FaissAssert.h"
#include "DeviceUtils.h"
namespace faiss { namespace gpu {
template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ __device__
Tensor<T, Dim, Contig, IndexT, PtrTraits>::Tensor()
    : data_(nullptr) {
  static_assert(Dim > 0, "must have > 0 dimensions");

  // An empty tensor: every dimension has size 0 and unit stride
  for (int i = 0; i < Dim; ++i) {
    size_[i] = 0;
    stride_[i] = (IndexT) 1;
  }
}
template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ __device__
Tensor<T, Dim, Contig, IndexT, PtrTraits>&
Tensor<T, Dim, Contig, IndexT, PtrTraits>::operator=(
  Tensor<T, Dim, Contig, IndexT, PtrTraits>&& t) {
  // Steal the other tensor's state and leave it empty
  data_ = t.data_; t.data_ = nullptr;
  for (int i = 0; i < Dim; ++i) {
    stride_[i] = t.stride_[i]; t.stride_[i] = 0;
    size_[i] = t.size_[i]; t.size_[i] = 0;
  }

  return *this;
}
template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ __device__
Tensor<T, Dim, Contig, IndexT, PtrTraits>::Tensor(
  DataPtrType data, const IndexT sizes[Dim])
    : data_(data) {
  static_assert(Dim > 0, "must have > 0 dimensions");

  for (int i = 0; i < Dim; ++i) {
    size_[i] = sizes[i];
  }

  // Compute packed row-major strides from the sizes, innermost first
  stride_[Dim - 1] = (IndexT) 1;
  for (int i = Dim - 2; i >= 0; --i) {
    stride_[i] = stride_[i + 1] * sizes[i + 1];
  }
}
template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ __device__
Tensor<T, Dim, Contig, IndexT, PtrTraits>::Tensor(
  DataPtrType data, std::initializer_list<IndexT> sizes)
    : data_(data) {
  assert(sizes.size() == Dim);
  static_assert(Dim > 0, "must have > 0 dimensions");

  int i = 0;
  for (auto s : sizes) {
    size_[i++] = s;
  }

  // Compute packed row-major strides from the sizes, innermost first
  stride_[Dim - 1] = (IndexT) 1;
  for (int j = Dim - 2; j >= 0; --j) {
    stride_[j] = stride_[j + 1] * size_[j + 1];
  }
}
template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ __device__
Tensor<T, Dim, Contig, IndexT, PtrTraits>::Tensor(
  DataPtrType data, const IndexT sizes[Dim], const IndexT strides[Dim])
    : data_(data) {
  static_assert(Dim > 0, "must have > 0 dimensions");

  for (int i = 0; i < Dim; ++i) {
    size_[i] = sizes[i];
    stride_[i] = strides[i];
  }
}
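// Usage sketch (illustrative; `devPtr`, `numVecs` and `dims` are assumed
// names, not part of this file): wrapping an existing device allocation of
// numVecs * dims floats in a 2-D tensor. Strides are derived from the sizes
// by the constructors above, innermost dimension first.
//
//   float* devPtr = ...;
//   faiss::gpu::Tensor<float, 2, true> vecs(devPtr, {numVecs, dims});
//   // vecs.getSize(0) == numVecs, vecs.getSize(1) == dims
//   // vecs.getStride(0) == dims,  vecs.getStride(1) == 1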
template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ void
Tensor<T, Dim, Contig, IndexT, PtrTraits>::copyFrom(
  Tensor<T, Dim, Contig, IndexT, PtrTraits>& t,
  cudaStream_t stream) {
  static_assert(Contig, "only contiguous tensors handled");

  // Both tensors are contiguous, so checking the total element count
  // suffices for a size match
  FAISS_ASSERT(this->numElements() == t.numElements());

  if (t.numElements() > 0) {
    FAISS_ASSERT(this->data_);
    FAISS_ASSERT(t.data());

    int ourDev = getDeviceForAddress(this->data_);
    int tDev = getDeviceForAddress(t.data());

    if (tDev == -1) {
      // Source is host memory
      CUDA_VERIFY(cudaMemcpyAsync(this->data_,
                                  t.data(),
                                  this->getSizeInBytes(),
                                  ourDev == -1 ? cudaMemcpyHostToHost :
                                  cudaMemcpyHostToDevice,
                                  stream));
    } else {
      // Source is device memory
      CUDA_VERIFY(cudaMemcpyAsync(this->data_,
                                  t.data(),
                                  this->getSizeInBytes(),
                                  ourDev == -1 ? cudaMemcpyDeviceToHost :
                                  cudaMemcpyDeviceToDevice,
                                  stream));
    }
  }
}
template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ void
Tensor<T, Dim, Contig, IndexT, PtrTraits>::copyTo(
  Tensor<T, Dim, Contig, IndexT, PtrTraits>& t,
  cudaStream_t stream) {
  static_assert(Contig, "only contiguous tensors handled");

  // Both tensors are contiguous, so checking the total element count
  // suffices for a size match
  FAISS_ASSERT(this->numElements() == t.numElements());

  if (t.numElements() > 0) {
    FAISS_ASSERT(this->data_);
    FAISS_ASSERT(t.data());

    int ourDev = getDeviceForAddress(this->data_);
    int tDev = getDeviceForAddress(t.data());

    if (tDev == -1) {
      // Destination is host memory
      CUDA_VERIFY(cudaMemcpyAsync(t.data(),
                                  this->data_,
                                  this->getSizeInBytes(),
                                  ourDev == -1 ? cudaMemcpyHostToHost :
                                  cudaMemcpyDeviceToHost,
                                  stream));
    } else {
      // Destination is device memory
      CUDA_VERIFY(cudaMemcpyAsync(t.data(),
                                  this->data_,
                                  this->getSizeInBytes(),
                                  ourDev == -1 ? cudaMemcpyHostToDevice :
                                  cudaMemcpyDeviceToDevice,
                                  stream));
    }
  }
}
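// Usage sketch (illustrative; `hostData`, `devPtr`, `n`, `d` and `stream`
// are assumed names): copying between a host tensor and a device tensor.
// Both tensors must be contiguous and hold the same number of elements.
//
//   faiss::gpu::Tensor<float, 2, true> hostVecs(hostData, {n, d});
//   faiss::gpu::Tensor<float, 2, true> devVecs(devPtr, {n, d});
//   devVecs.copyFrom(hostVecs, stream); // host -> device
//   devVecs.copyTo(hostVecs, stream);   // device -> host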
template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
template <int OtherDim>
__host__ __device__ bool
Tensor<T, Dim, Contig, IndexT, PtrTraits>::isSame(
  const Tensor<T, OtherDim, Contig, IndexT, PtrTraits>& rhs) const {
  if (Dim != OtherDim) {
    return false;
  }

  for (int i = 0; i < Dim; ++i) {
    if (size_[i] != rhs.size_[i]) {
      return false;
    }

    if (stride_[i] != rhs.stride_[i]) {
      return false;
    }
  }

  return true;
}
template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
template <typename U>
__host__ __device__ Tensor<U, Dim, Contig, IndexT, PtrTraits>
Tensor<T, Dim, Contig, IndexT, PtrTraits>::cast() {
  static_assert(sizeof(U) == sizeof(T), "cast must be to same size object");

  return Tensor<U, Dim, Contig, IndexT, PtrTraits>(
    reinterpret_cast<U*>(data_), size_, stride_);
}
template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
template <typename U>
__host__ __device__ const Tensor<U, Dim, Contig, IndexT, PtrTraits>
Tensor<T, Dim, Contig, IndexT, PtrTraits>::cast() const {
  static_assert(sizeof(U) == sizeof(T), "cast must be to same size object");

  return Tensor<U, Dim, Contig, IndexT, PtrTraits>(
    reinterpret_cast<U*>(data_), size_, stride_);
}
template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
template <typename U>
__host__ __device__ Tensor<U, Dim, Contig, IndexT, PtrTraits>
Tensor<T, Dim, Contig, IndexT, PtrTraits>::castResize() {
  static_assert(sizeof(U) >= sizeof(T), "only handles greater sizes");
  constexpr int kMultiple = sizeof(U) / sizeof(T);

  assert(canCastResize<U>());

  IndexT newSize[Dim];
  IndexT newStride[Dim];

  // Outer dimensions keep their sizes; their strides shrink by kMultiple
  for (int i = 0; i < Dim - 1; ++i) {
    newSize[i] = size_[i];
    newStride[i] = stride_[i] / kMultiple;
  }

  // The innermost dimension keeps unit stride and shrinks by kMultiple
  newStride[Dim - 1] = 1;
  newSize[Dim - 1] = size_[Dim - 1] / kMultiple;

  return Tensor<U, Dim, Contig, IndexT, PtrTraits>(
    reinterpret_cast<U*>(data_), newSize, newStride);
}
template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
template <typename U>
__host__ __device__ const Tensor<U, Dim, Contig, IndexT, PtrTraits>
Tensor<T, Dim, Contig, IndexT, PtrTraits>::castResize() const {
  return const_cast<Tensor<T, Dim, Contig, IndexT, PtrTraits>*>(this)->
    castResize<U>();
}
template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
template <typename U>
__host__ __device__ bool
Tensor<T, Dim, Contig, IndexT, PtrTraits>::canCastResize() const {
  static_assert(sizeof(U) >= sizeof(T), "only handles greater sizes");
  constexpr int kMultiple = sizeof(U) / sizeof(T);

  // All outer strides must be divisible by the size ratio
  for (int i = 0; i < Dim - 1; ++i) {
    if (stride_[i] % kMultiple != 0) {
      return false;
    }
  }

  // The innermost dimension must be divisible by the size ratio and
  // must have unit stride
  if (size_[Dim - 1] % kMultiple != 0) {
    return false;
  }

  if (stride_[Dim - 1] != 1) {
    return false;
  }

  return true;
}
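// Usage sketch (illustrative; `devVecs` is the assumed contiguous
// Tensor<float, 2, true> from the sketch above): reinterpreting float data
// as float4 for vectorized access. canCastResize() verifies that the
// innermost dimension is a multiple of 4 with unit stride and that the
// outer strides divide evenly.
//
//   if (devVecs.canCastResize<float4>()) {
//     auto vecs4 = devVecs.castResize<float4>();
//     // vecs4.getSize(1) == devVecs.getSize(1) / 4
//   }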
template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ __device__ IndexT
Tensor<T, Dim, Contig, IndexT, PtrTraits>::numElements() const {
  long size = getSize(0);

  for (int i = 1; i < Dim; ++i) {
    size *= (long) getSize(i);
  }

  return (IndexT) size;
}
template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ __device__ bool
Tensor<T, Dim, Contig, IndexT, PtrTraits>::isContiguous() const {
  long prevSize = 1;

  // Walk from the innermost dimension outwards; each non-trivial
  // dimension's stride must equal the product of all inner sizes
  for (int i = Dim - 1; i >= 0; --i) {
    if (getSize(i) != (IndexT) 1) {
      if (getStride(i) == prevSize) {
        prevSize *= getSize(i);
      } else {
        return false;
      }
    }
  }

  return true;
}
template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ __device__ bool
Tensor<T, Dim, Contig, IndexT, PtrTraits>::isConsistentlySized(int i) const {
  if (i == 0 && getStride(i) > 0 && getSize(i) > 0) {
    return true;
  } else if ((i > 0) && (i < Dim) && (getStride(i) > 0) &&
             ((getStride(i - 1) / getStride(i)) >= getSize(i))) {
    return true;
  }

  return false;
}
template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ __device__ bool
Tensor<T, Dim, Contig, IndexT, PtrTraits>::isConsistentlySized() const {
  for (int i = 0; i < Dim; ++i) {
    if (!isConsistentlySized(i)) {
      return false;
    }
  }

  return true;
}
template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ __device__ bool
Tensor<T, Dim, Contig, IndexT, PtrTraits>::isContiguousDim(int i) const {
  return (i == Dim - 1) ||
    ((i < Dim - 1) &&
     ((getStride(i) / getStride(i + 1)) == getSize(i + 1)));
}
template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ __device__ Tensor<T, Dim, Contig, IndexT, PtrTraits>
Tensor<T, Dim, Contig, IndexT, PtrTraits>::transpose(int dim1,
                                                     int dim2) const {
  assert(dim1 >= 0 && dim1 < Dim);
  assert(dim2 >= 0 && dim2 < Dim);
  static_assert(!Contig, "cannot transpose contiguous arrays");

  IndexT newSize[Dim];
  IndexT newStride[Dim];

  for (int i = 0; i < Dim; ++i) {
    newSize[i] = size_[i];
    newStride[i] = stride_[i];
  }

  // Swap sizes and strides of the two dimensions; no data is moved
  IndexT tmp = newSize[dim1];
  newSize[dim1] = newSize[dim2];
  newSize[dim2] = tmp;

  tmp = newStride[dim1];
  newStride[dim1] = newStride[dim2];
  newStride[dim2] = tmp;

  return Tensor<T, Dim, Contig, IndexT, PtrTraits>(data_, newSize, newStride);
}
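// Usage sketch (illustrative; `devPtr`, `n`, `d` are assumed names):
// exchanging the two dimensions of a non-contiguous-typed tensor. Only
// sizes and strides are swapped; no data is moved.
//
//   faiss::gpu::Tensor<float, 2, false> a(devPtr, {n, d});
//   auto aT = a.transpose(0, 1);
//   // aT.getSize(0) == d, aT.getSize(1) == n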
template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
template <int NewDim>
__host__ __device__ Tensor<T, NewDim, Contig, IndexT, PtrTraits>
Tensor<T, Dim, Contig, IndexT, PtrTraits>::upcastOuter() {
  static_assert(NewDim > Dim, "Can only upcast to greater dim");

  IndexT newSize[NewDim];
  IndexT newStride[NewDim];

  int shift = NewDim - Dim;

  for (int i = 0; i < NewDim; ++i) {
    if (i < shift) {
      // Prepended dimensions are of size 1; their stride spans the
      // entire existing outermost dimension
      newSize[i] = (IndexT) 1;
      newStride[i] = size_[0] * stride_[0];
    } else {
      // Existing dimensions are shifted inwards
      newSize[i] = size_[i - shift];
      newStride[i] = stride_[i - shift];
    }
  }

  return Tensor<T, NewDim, Contig, IndexT, PtrTraits>(
    data_, newSize, newStride);
}
template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
template <int NewDim>
__host__ __device__ Tensor<T, NewDim, Contig, IndexT, PtrTraits>
Tensor<T, Dim, Contig, IndexT, PtrTraits>::upcastInner() {
  static_assert(NewDim > Dim, "Can only upcast to greater dim");

  IndexT newSize[NewDim];
  IndexT newStride[NewDim];

  for (int i = 0; i < NewDim; ++i) {
    if (i < Dim) {
      // Existing dimensions are copied over
      newSize[i] = size_[i];
      newStride[i] = stride_[i];
    } else {
      // Appended dimensions are of size 1 with unit stride
      newSize[i] = (IndexT) 1;
      newStride[i] = (IndexT) 1;
    }
  }

  return Tensor<T, NewDim, Contig, IndexT, PtrTraits>(
    data_, newSize, newStride);
}
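// Usage sketch (illustrative; `devVecs`, `n`, `d` as above): adding a
// leading size-1 dimension so a 2-D tensor can be passed to code expecting
// 3-D input.
//
//   auto vecs3 = devVecs.upcastOuter<3>();
//   // vecs3.getSize(0) == 1, vecs3.getSize(1) == n, vecs3.getSize(2) == d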
template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
template <int NewDim>
__host__ __device__ Tensor<T, NewDim, Contig, IndexT, PtrTraits>
Tensor<T, Dim, Contig, IndexT, PtrTraits>::downcastOuter() {
  static_assert(NewDim < Dim, "Can only downcast to lesser dim");

  // The outer dimensions being collapsed must have no padding
  for (int i = 0; i < Dim - NewDim; ++i) {
    bool cont = isContiguousDim(i);
    assert(cont);
  }

  IndexT newSize[NewDim];
  IndexT newStride[NewDim];

  int ignoredDims = Dim - NewDim;
  IndexT collapsedSize = 1;

  for (int i = 0; i < Dim; ++i) {
    if (i < ignoredDims) {
      // Accumulate the sizes of the collapsed outer dimensions
      collapsedSize *= getSize(i);
    } else {
      if (i == ignoredDims) {
        // The first retained dimension absorbs the collapsed sizes
        newSize[i - ignoredDims] = collapsedSize * getSize(i);
      } else {
        // Remaining dimensions are copied over
        newSize[i - ignoredDims] = getSize(i);
      }

      newStride[i - ignoredDims] = getStride(i);
    }
  }

  return Tensor<T, NewDim, Contig, IndexT, PtrTraits>(
    data_, newSize, newStride);
}
template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
template <int NewDim>
__host__ __device__ Tensor<T, NewDim, Contig, IndexT, PtrTraits>
Tensor<T, Dim, Contig, IndexT, PtrTraits>::downcastInner() {
  static_assert(NewDim < Dim, "Can only downcast to lesser dim");

  // The inner dimensions being collapsed must have no padding
  for (int i = NewDim; i < Dim; ++i) {
    assert(isContiguousDim(i));
  }

  IndexT newSize[NewDim];
  IndexT newStride[NewDim];

  IndexT collapsedSize = 1;

  for (int i = Dim - 1; i >= 0; --i) {
    if (i >= NewDim) {
      // Accumulate the sizes of the collapsed inner dimensions
      collapsedSize *= getSize(i);
    } else {
      if (i == NewDim - 1) {
        // The innermost retained dimension absorbs the collapsed sizes
        newSize[i] = collapsedSize * getSize(i);
        newStride[i] = getStride(Dim - 1);
      } else {
        newSize[i] = getSize(i);
        newStride[i] = getStride(i);
      }
    }
  }

  return Tensor<T, NewDim, Contig, IndexT, PtrTraits>(
    data_, newSize, newStride);
}
template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
template <int SubDim>
__host__ __device__ Tensor<T, SubDim, Contig, IndexT, PtrTraits>
Tensor<T, Dim, Contig, IndexT, PtrTraits>::view(DataPtrType at) {
  static_assert(SubDim >= 1 && SubDim < Dim,
                "can only create view of lesser dim");

  // The view covers the innermost SubDim dimensions
  IndexT viewSizes[SubDim];
  IndexT viewStrides[SubDim];

  for (int i = 0; i < SubDim; ++i) {
    viewSizes[i] = size_[Dim - SubDim + i];
    viewStrides[i] = stride_[Dim - SubDim + i];
  }

  return Tensor<T, SubDim, Contig, IndexT, PtrTraits>(
    at, viewSizes, viewStrides);
}
template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
template <int SubDim>
__host__ __device__ Tensor<T, SubDim, Contig, IndexT, PtrTraits>
Tensor<T, Dim, Contig, IndexT, PtrTraits>::view() {
  return view<SubDim>(data_);
}
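// Usage sketch (illustrative; `t3` is an assumed 3-D tensor of shape
// {batch, n, d}): view<2>() produces a 2-D {n, d} view over the innermost
// dimensions, starting at the tensor's own data pointer.
//
//   auto inner = t3.view<2>();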
template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ __device__ Tensor<T, Dim, Contig, IndexT, PtrTraits>
Tensor<T, Dim, Contig, IndexT, PtrTraits>::narrowOutermost(IndexT start,
                                                           IndexT size) {
  DataPtrType newData = data_;

  if (start > 0) {
    newData += start * stride_[0];
  }

  IndexT newSize[Dim];
  for (int i = 0; i < Dim; ++i) {
    if (i == 0) {
      assert(start + size <= size_[0]);
      newSize[i] = size;
    } else {
      newSize[i] = size_[i];
    }
  }

  return Tensor<T, Dim, Contig, IndexT, PtrTraits>(newData, newSize, stride_);
}
template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ __device__ Tensor<T, Dim, false, IndexT, PtrTraits>
Tensor<T, Dim, Contig, IndexT, PtrTraits>::narrow(int dim,
                                                  IndexT start,
                                                  IndexT size) {
  DataPtrType newData = data_;

  if (start > 0) {
    newData += start * stride_[dim];
  }

  IndexT newSize[Dim];
  for (int i = 0; i < Dim; ++i) {
    if (i == dim) {
      assert(start + size <= size_[dim]);
      newSize[i] = size;
    } else {
      newSize[i] = size_[i];
    }
  }

  // The result is typed as non-contiguous, since narrowing an inner
  // dimension can introduce padding
  return Tensor<T, Dim, false, IndexT, PtrTraits>(newData, newSize, stride_);
}
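// Usage sketch (illustrative; `devVecs`, `n`, `d` as above): selecting a
// sub-range along one dimension. narrowOutermost() restricts dimension 0;
// narrow() restricts an arbitrary dimension and returns a tensor typed as
// non-contiguous.
//
//   auto firstHalf = devVecs.narrowOutermost(0, n / 2);
//   auto leftCols  = devVecs.narrow(1, 0, d / 2);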
template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
template <int NewDim>
__host__ __device__ Tensor<T, NewDim, Contig, IndexT, PtrTraits>
Tensor<T, Dim, Contig, IndexT, PtrTraits>::view(
  std::initializer_list<IndexT> sizes) {
  static_assert(Contig, "on contiguous tensors only");

  assert(sizes.size() == NewDim);

  // The new view must cover exactly the same number of elements
  size_t curSize = numElements();
  size_t newSize = 1;

  for (auto s : sizes) {
    newSize *= (size_t) s;
  }

  assert(curSize == newSize);
  return Tensor<T, NewDim, Contig, IndexT, PtrTraits>(data(), sizes);
}

} } // namespace
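// Usage sketch (illustrative; `devVecs`, `n`, `d` as above): reshaping a
// contiguous tensor. The product of the new sizes must equal numElements().
//
//   auto flat = devVecs.view<1>({n * d});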