Tensor-inl.cuh
/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the CC-by-NC license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Copyright 2004-present Facebook. All Rights Reserved.

#include "../../FaissAssert.h"
#include "DeviceUtils.h"

namespace faiss { namespace gpu {

template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ __device__
Tensor<T, Dim, Contig, IndexT, PtrTraits>::Tensor()
    : data_(nullptr) {
  static_assert(Dim > 0, "must have > 0 dimensions");

  for (int i = 0; i < Dim; ++i) {
    size_[i] = 0;
    stride_[i] = (IndexT) 1;
  }
}

template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ __device__
Tensor<T, Dim, Contig, IndexT, PtrTraits>&
Tensor<T, Dim, Contig, IndexT, PtrTraits>::operator=(
  Tensor<T, Dim, Contig, IndexT, PtrTraits>&& t) {
  data_ = t.data_; t.data_ = nullptr;
  for (int i = 0; i < Dim; ++i) {
    stride_[i] = t.stride_[i]; t.stride_[i] = 0;
    size_[i] = t.size_[i]; t.size_[i] = 0;
  }

  return *this;
}

template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ __device__
Tensor<T, Dim, Contig, IndexT, PtrTraits>::
Tensor(DataPtrType data, const IndexT sizes[Dim])
    : data_(data) {
  static_assert(Dim > 0, "must have > 0 dimensions");

  for (int i = 0; i < Dim; ++i) {
    size_[i] = sizes[i];
  }

  stride_[Dim - 1] = (IndexT) 1;
  for (int i = Dim - 2; i >= 0; --i) {
    stride_[i] = stride_[i + 1] * sizes[i + 1];
  }
}

template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ __device__
Tensor<T, Dim, Contig, IndexT, PtrTraits>::
Tensor(DataPtrType data, std::initializer_list<IndexT> sizes)
    : data_(data) {
  assert(sizes.size() == Dim);
  static_assert(Dim > 0, "must have > 0 dimensions");

  int i = 0;
  for (auto s : sizes) {
    size_[i++] = s;
  }

  stride_[Dim - 1] = (IndexT) 1;
  for (int j = Dim - 2; j >= 0; --j) {
    stride_[j] = stride_[j + 1] * size_[j + 1];
  }
}


template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ __device__
Tensor<T, Dim, Contig, IndexT, PtrTraits>::Tensor(
  DataPtrType data, const IndexT sizes[Dim], const IndexT strides[Dim])
    : data_(data) {
  static_assert(Dim > 0, "must have > 0 dimensions");

  for (int i = 0; i < Dim; ++i) {
    size_[i] = sizes[i];
    stride_[i] = strides[i];
  }
}
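
// Usage sketch (an illustrative addition, not part of the original source):
// wrapping a raw device allocation of num * dim floats as a contiguous 2-d
// tensor. `devPtr` is assumed to come from cudaMalloc or a memory manager.
//
//   Tensor<float, 2, true> vecs(devPtr, {num, dim});
//   // vecs.getSize(0) == num, vecs.getSize(1) == dim, innermost stride == 1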

template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ void
Tensor<T, Dim, Contig, IndexT, PtrTraits>::copyFrom(
  Tensor<T, Dim, Contig, IndexT, PtrTraits>& t,
  cudaStream_t stream) {
  static_assert(Contig, "only contiguous tensors handled");

  // Size must be the same (since dimensions are checked and
  // contiguity is assumed, we need only check the total number of
  // elements)
  FAISS_ASSERT(this->numElements() == t.numElements());

  if (t.numElements() > 0) {
    FAISS_ASSERT(this->data_);
    FAISS_ASSERT(t.data());

    int ourDev = getDeviceForAddress(this->data_);
    int tDev = getDeviceForAddress(t.data());

    if (tDev == -1) {
      CUDA_VERIFY(cudaMemcpyAsync(this->data_,
                                  t.data(),
                                  this->getSizeInBytes(),
                                  ourDev == -1 ? cudaMemcpyHostToHost :
                                  cudaMemcpyHostToDevice,
                                  stream));
    } else {
      CUDA_VERIFY(cudaMemcpyAsync(this->data_,
                                  t.data(),
                                  this->getSizeInBytes(),
                                  ourDev == -1 ? cudaMemcpyDeviceToHost :
                                  cudaMemcpyDeviceToDevice,
                                  stream));
    }
  }
}

template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ void
Tensor<T, Dim, Contig, IndexT, PtrTraits>::copyTo(
  Tensor<T, Dim, Contig, IndexT, PtrTraits>& t,
  cudaStream_t stream) {
  static_assert(Contig, "only contiguous tensors handled");

  // Size must be the same (since dimensions are checked and
  // contiguity is assumed, we need only check the total number of
  // elements)
  FAISS_ASSERT(this->numElements() == t.numElements());

  if (t.numElements() > 0) {
    FAISS_ASSERT(this->data_);
    FAISS_ASSERT(t.data());

    int ourDev = getDeviceForAddress(this->data_);
    int tDev = getDeviceForAddress(t.data());

    if (tDev == -1) {
      CUDA_VERIFY(cudaMemcpyAsync(t.data(),
                                  this->data_,
                                  this->getSizeInBytes(),
                                  ourDev == -1 ? cudaMemcpyHostToHost :
                                  cudaMemcpyDeviceToHost,
                                  stream));
    } else {
      CUDA_VERIFY(cudaMemcpyAsync(t.data(),
                                  this->data_,
                                  this->getSizeInBytes(),
                                  ourDev == -1 ? cudaMemcpyHostToDevice :
                                  cudaMemcpyDeviceToDevice,
                                  stream));
    }
  }
}
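
// Usage sketch (an illustrative addition, not part of the original source):
// moving data between a host-resident and a device-resident tensor of the
// same total size. `hostVecs`, `devVecs`, and `stream` are assumed to exist.
//
//   devVecs.copyFrom(hostVecs, stream);  // host -> device
//   // ... run kernels on devVecs ...
//   devVecs.copyTo(hostVecs, stream);    // device -> host
//   // copies are asynchronous; synchronize `stream` before reading hostVecs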

template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
template <int OtherDim>
__host__ __device__ bool
Tensor<T, Dim, Contig, IndexT, PtrTraits>::isSame(
  const Tensor<T, OtherDim, Contig, IndexT, PtrTraits>& rhs) const {
  if (Dim != OtherDim) {
    return false;
  }

  for (int i = 0; i < Dim; ++i) {
    if (size_[i] != rhs.size_[i]) {
      return false;
    }

    if (!Contig) {
      if (stride_[i] != rhs.stride_[i]) {
        return false;
      }
    }
  }

  return true;
}

template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
template <typename U>
__host__ __device__ Tensor<U, Dim, Contig, IndexT, PtrTraits>
Tensor<T, Dim, Contig, IndexT, PtrTraits>::cast() {
  static_assert(sizeof(U) == sizeof(T), "cast must be to same size object");

  return Tensor<U, Dim, Contig, IndexT, PtrTraits>(
    reinterpret_cast<U*>(data_), size_, stride_);
}

template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
template <typename U>
__host__ __device__ const Tensor<U, Dim, Contig, IndexT, PtrTraits>
Tensor<T, Dim, Contig, IndexT, PtrTraits>::cast() const {
  static_assert(sizeof(U) == sizeof(T), "cast must be to same size object");

  return Tensor<U, Dim, Contig, IndexT, PtrTraits>(
    reinterpret_cast<U*>(data_), size_, stride_);
}

template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
template <typename U>
__host__ __device__ Tensor<U, Dim, Contig, IndexT, PtrTraits>
Tensor<T, Dim, Contig, IndexT, PtrTraits>::castResize() {
  static_assert(sizeof(U) >= sizeof(T), "only handles greater sizes");
  constexpr int kMultiple = sizeof(U) / sizeof(T);

  assert(canCastResize<U>());

  IndexT newSize[Dim];
  IndexT newStride[Dim];

  for (int i = 0; i < Dim - 1; ++i) {
    newSize[i] = size_[i];
    newStride[i] = stride_[i] / kMultiple;
  }

  newStride[Dim - 1] = 1; // this is the same as the old stride
  newSize[Dim - 1] = size_[Dim - 1] / kMultiple;

  return Tensor<U, Dim, Contig, IndexT, PtrTraits>(
    reinterpret_cast<U*>(data_), newSize, newStride);
}

template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
template <typename U>
__host__ __device__ const Tensor<U, Dim, Contig, IndexT, PtrTraits>
Tensor<T, Dim, Contig, IndexT, PtrTraits>::castResize() const {
  return const_cast<Tensor<T, Dim, Contig, IndexT, PtrTraits>*>(this)->
    castResize<U>();
}

template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
template <typename U>
__host__ __device__ bool
Tensor<T, Dim, Contig, IndexT, PtrTraits>::canCastResize() const {
  static_assert(sizeof(U) >= sizeof(T), "only handles greater sizes");
  constexpr int kMultiple = sizeof(U) / sizeof(T);

  // Check all outer strides
  for (int i = 0; i < Dim - 1; ++i) {
    if (stride_[i] % kMultiple != 0) {
      return false;
    }
  }

  // Check inner size
  if (size_[Dim - 1] % kMultiple != 0) {
    return false;
  }

  if (stride_[Dim - 1] != 1) {
    return false;
  }

  return true;
}
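
// Usage sketch (an illustrative addition, not part of the original source):
// reinterpreting a float tensor as float4 for vectorized loads, which
// requires the innermost size and all outer strides to be multiples of 4.
//
//   Tensor<float, 2, true> vecs(devPtr, {num, dim});
//   if (vecs.canCastResize<float4>()) {
//     auto vecs4 = vecs.castResize<float4>();  // shape {num, dim / 4}
//   }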

template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ __device__ IndexT
Tensor<T, Dim, Contig, IndexT, PtrTraits>::numElements() const {
  long size = getSize(0);

  for (int i = 1; i < Dim; ++i) {
    size *= getSize(i);
  }

  return size;
}

template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ __device__ bool
Tensor<T, Dim, Contig, IndexT, PtrTraits>::isContiguous() const {
  long prevSize = 1;

  for (int i = Dim - 1; i >= 0; --i) {
    if (getSize(i) != (IndexT) 1) {
      if (getStride(i) == prevSize) {
        prevSize *= getSize(i);
      } else {
        return false;
      }
    }
  }

  return true;
}

template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ __device__ bool
Tensor<T, Dim, Contig, IndexT, PtrTraits>::isConsistentlySized(int i) const {
  if (i == 0 && getStride(i) > 0 && getSize(i) > 0) {
    return true;
  } else if ((i > 0) && (i < Dim) && (getStride(i) > 0) &&
             ((getStride(i - 1) / getStride(i)) >= getSize(i))) {
    return true;
  }

  return false;
}

template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ __device__ bool
Tensor<T, Dim, Contig, IndexT, PtrTraits>::isConsistentlySized() const {
  for (int i = 0; i < Dim; ++i) {
    if (!isConsistentlySized(i)) {
      return false;
    }
  }

  return true;
}

template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ __device__ bool
Tensor<T, Dim, Contig, IndexT, PtrTraits>::isContiguousDim(int i) const {
  return (i == Dim - 1) || // just in case
         ((i < Dim - 1) &&
          ((getStride(i) / getStride(i + 1)) == getSize(i + 1)));
}

template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ __device__ Tensor<T, Dim, Contig, IndexT, PtrTraits>
Tensor<T, Dim, Contig, IndexT, PtrTraits>::transpose(int dim1,
                                                     int dim2) const {
  assert(dim1 >= 0 && dim1 < Dim);
  assert(dim2 >= 0 && dim2 < Dim);
  static_assert(!Contig, "cannot transpose contiguous arrays");

  IndexT newSize[Dim];
  IndexT newStride[Dim];

  for (int i = 0; i < Dim; ++i) {
    newSize[i] = size_[i];
    newStride[i] = stride_[i];
  }

  IndexT tmp = newSize[dim1];
  newSize[dim1] = newSize[dim2];
  newSize[dim2] = tmp;

  tmp = newStride[dim1];
  newStride[dim1] = newStride[dim2];
  newStride[dim2] = tmp;

  return Tensor<T, Dim, Contig, IndexT, PtrTraits>(data_, newSize, newStride);
}
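
// Usage sketch (an illustrative addition, not part of the original source):
// swapping the two dimensions of a non-contiguous (Contig = false) matrix
// view. Only sizes and strides are exchanged; no data is moved.
//
//   Tensor<float, 2, false> mat(devPtr, {rows, cols});
//   auto matT = mat.transpose(0, 1);  // shape {cols, rows}, strides swapped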

template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
template <int NewDim>
__host__ __device__ Tensor<T, NewDim, Contig, IndexT, PtrTraits>
Tensor<T, Dim, Contig, IndexT, PtrTraits>::upcastOuter() {
  // Can only create tensors of greater dimension
  static_assert(NewDim > Dim, "Can only upcast to greater dim");

  IndexT newSize[NewDim];
  IndexT newStride[NewDim];

  int shift = NewDim - Dim;

  for (int i = 0; i < NewDim; ++i) {
    if (i < shift) {
      // These are the extended dimensions
      newSize[i] = (IndexT) 1;
      newStride[i] = size_[0] * stride_[0];
    } else {
      // Shift the remaining dimensions
      newSize[i] = size_[i - shift];
      newStride[i] = stride_[i - shift];
    }
  }

  return Tensor<T, NewDim, Contig, IndexT, PtrTraits>(
    data_, newSize, newStride);
}

template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
template <int NewDim>
__host__ __device__ Tensor<T, NewDim, Contig, IndexT, PtrTraits>
Tensor<T, Dim, Contig, IndexT, PtrTraits>::upcastInner() {
  // Can only create tensors of greater dimension
  static_assert(NewDim > Dim, "Can only upcast to greater dim");

  IndexT newSize[NewDim];
  IndexT newStride[NewDim];

  for (int i = 0; i < NewDim; ++i) {
    if (i < Dim) {
      // Existing dimensions get copied over
      newSize[i] = size_[i];
      newStride[i] = stride_[i];
    } else {
      // Extended dimensions
      newSize[i] = (IndexT) 1;
      newStride[i] = (IndexT) 1;
    }
  }

  return Tensor<T, NewDim, Contig, IndexT, PtrTraits>(
    data_, newSize, newStride);
}
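
// Usage sketch (an illustrative addition, not part of the original source):
// adding a leading or trailing size-1 dimension to a 2-d tensor.
//
//   Tensor<float, 2, true> mat(devPtr, {rows, cols});
//   auto outer = mat.upcastOuter<3>();  // shape {1, rows, cols}
//   auto inner = mat.upcastInner<3>();  // shape {rows, cols, 1}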

template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
template <int NewDim>
__host__ __device__ Tensor<T, NewDim, Contig, IndexT, PtrTraits>
Tensor<T, Dim, Contig, IndexT, PtrTraits>::downcastOuter() {
  // Can only create tensors of lesser dimension
  static_assert(NewDim < Dim, "Can only downcast to lesser dim");

  // We can't downcast non-contiguous tensors, since it leaves
  // garbage data in the tensor. The tensor needs to be contiguous
  // in all of the dimensions we are collapsing (no padding in
  // them).
  for (int i = 0; i < Dim - NewDim; ++i) {
    bool cont = isContiguousDim(i);
    assert(cont);
  }

  IndexT newSize[NewDim];
  IndexT newStride[NewDim];

  int ignoredDims = Dim - NewDim;
  IndexT collapsedSize = 1;

  for (int i = 0; i < Dim; ++i) {
    if (i < ignoredDims) {
      // Collapse these dimensions
      collapsedSize *= getSize(i);
    } else {
      // Non-collapsed dimensions
      if (i == ignoredDims) {
        // This is the first non-collapsed dimension
        newSize[i - ignoredDims] = collapsedSize * getSize(i);
      } else {
        // Subsequent non-collapsed dimensions
        newSize[i - ignoredDims] = getSize(i);
      }

      newStride[i - ignoredDims] = getStride(i);
    }
  }

  return Tensor<T, NewDim, Contig, IndexT, PtrTraits>(
    data_, newSize, newStride);
}

template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
template <int NewDim>
__host__ __device__ Tensor<T, NewDim, Contig, IndexT, PtrTraits>
Tensor<T, Dim, Contig, IndexT, PtrTraits>::downcastInner() {
  // Can only create tensors of lesser dimension
  static_assert(NewDim < Dim, "Can only downcast to lesser dim");

  // We can't downcast non-contiguous tensors, since it leaves
  // garbage data in the tensor. The tensor needs to be contiguous
  // in all of the dimensions we are collapsing (no padding in
  // them).
  for (int i = NewDim; i < Dim; ++i) {
    assert(isContiguousDim(i));
  }

  IndexT newSize[NewDim];
  IndexT newStride[NewDim];

  IndexT collapsedSize = 1;

  for (int i = Dim - 1; i >= 0; --i) {
    if (i >= NewDim) {
      // Collapse these dimensions
      collapsedSize *= getSize(i);
    } else {
      // Non-collapsed dimensions
      if (i == NewDim - 1) {
        // This is the first non-collapsed dimension
        newSize[i] = collapsedSize * getSize(i);
        newStride[i] = getStride(Dim - 1);
      } else {
        // Subsequent non-collapsed dimensions
        newSize[i] = getSize(i);
        newStride[i] = getStride(i);
      }
    }
  }

  return Tensor<T, NewDim, Contig, IndexT, PtrTraits>(
    data_, newSize, newStride);
}
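
// Usage sketch (an illustrative addition, not part of the original source):
// collapsing dimensions of a contiguous 3-d tensor `t` of shape {a, b, c}.
//
//   auto flatOuter = t.downcastOuter<2>();  // shape {a * b, c}
//   auto flatInner = t.downcastInner<2>();  // shape {a, b * c}
//   auto flatAll   = t.downcastOuter<1>();  // shape {a * b * c}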

template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
template <int SubDim>
__host__ __device__ Tensor<T, SubDim, Contig, IndexT, PtrTraits>
Tensor<T, Dim, Contig, IndexT, PtrTraits>::view(DataPtrType at) {
  static_assert(SubDim >= 1 && SubDim < Dim,
                "can only create view of lesser dim");

  IndexT viewSizes[SubDim];
  IndexT viewStrides[SubDim];

  for (int i = 0; i < SubDim; ++i) {
    viewSizes[i] = size_[Dim - SubDim + i];
    viewStrides[i] = stride_[Dim - SubDim + i];
  }

  return Tensor<T, SubDim, Contig, IndexT, PtrTraits>(
    at, viewSizes, viewStrides);
}

template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
template <int SubDim>
__host__ __device__ Tensor<T, SubDim, Contig, IndexT, PtrTraits>
Tensor<T, Dim, Contig, IndexT, PtrTraits>::view() {
  return view<SubDim>(data_);
}
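
// Usage sketch (an illustrative addition, not part of the original source):
// taking a lower-dimensional view over the innermost dimensions. The
// pointer-taking overload allows the view to start at an arbitrary offset
// within the tensor's storage.
//
//   Tensor<float, 3, true> t(devPtr, {a, b, c});
//   auto inner = t.view<2>();  // 2-d view of shape {b, c} at t.data()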

template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ __device__ Tensor<T, Dim, Contig, IndexT, PtrTraits>
Tensor<T, Dim, Contig, IndexT, PtrTraits>::narrowOutermost(IndexT start,
                                                           IndexT size) {
  DataPtrType newData = data_;

  if (start > 0) {
    newData += start * stride_[0];
  }

  IndexT newSize[Dim];
  for (int i = 0; i < Dim; ++i) {
    if (i == 0) {
      assert(start + size <= size_[0]);
      newSize[i] = size;
    } else {
      newSize[i] = size_[i];
    }
  }

  return Tensor<T, Dim, Contig, IndexT, PtrTraits>(newData, newSize, stride_);
}

template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ __device__ Tensor<T, Dim, false, IndexT, PtrTraits>
Tensor<T, Dim, Contig, IndexT, PtrTraits>::narrow(int dim,
                                                  IndexT start,
                                                  IndexT size) {
  DataPtrType newData = data_;

  if (start > 0) {
    newData += start * stride_[dim];
  }

  IndexT newSize[Dim];
  for (int i = 0; i < Dim; ++i) {
    if (i == dim) {
      assert(start + size <= size_[dim]);
      newSize[i] = size;
    } else {
      newSize[i] = size_[i];
    }
  }

  // The narrowed tensor is not necessarily contiguous
  return Tensor<T, Dim, false, IndexT, PtrTraits>(newData, newSize, stride_);
}
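
// Usage sketch (an illustrative addition, not part of the original source):
// taking sub-ranges of a 2-d tensor. narrowOutermost preserves the Contig
// flag, while narrow on an inner dimension returns a non-contiguous type.
//
//   Tensor<float, 2, true> vecs(devPtr, {num, dim});
//   auto firstHalf = vecs.narrowOutermost(0, num / 2);  // rows [0, num / 2)
//   auto someCols  = vecs.narrow(1, 4, 8);              // columns [4, 12)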

template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
template <int NewDim>
__host__ __device__ Tensor<T, NewDim, true, IndexT, PtrTraits>
Tensor<T, Dim, Contig, IndexT, PtrTraits>::view(
  std::initializer_list<IndexT> sizes) {
  static_assert(Contig, "on contiguous tensors only");

  assert(sizes.size() == NewDim);

  // The total size of the new view must be the same as the total size
  // of the old view
  size_t curSize = numElements();

  size_t newSize = 1;

  for (auto s : sizes) {
    newSize *= s;
  }

  assert(curSize == newSize);
  return Tensor<T, NewDim, true, IndexT, PtrTraits>(data(), sizes);
}
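
// Usage sketch (an illustrative addition, not part of the original source):
// reshaping a contiguous tensor to a new dimensionality with the same total
// number of elements; the view aliases the same storage.
//
//   Tensor<float, 2, true> vecs(devPtr, {num, dim});
//   auto asCube = vecs.view<3>({num, dim / 4, 4});  // requires dim % 4 == 0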

} } // namespace