Tensor.cuh
/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD+Patents license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Copyright 2004-present Facebook. All Rights Reserved.

#pragma once

#include <assert.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <initializer_list>

/// Multi-dimensional array class for CUDA device and host usage.
/// Originally from Facebook's fbcunn, since added to the Torch GPU
/// library cutorch as well.

namespace faiss { namespace gpu {

/// Our tensor type
template <typename T,
          int Dim,
          bool Contig,
          typename IndexT,
          template <typename U> class PtrTraits>
class Tensor;

/// Type of a subspace of a tensor
namespace detail {
template <typename TensorType,
          int SubDim,
          template <typename U> class PtrTraits>
class SubTensor;
}

namespace traits {

template <typename T>
struct RestrictPtrTraits {
  typedef T* __restrict__ PtrType;
};

template <typename T>
struct DefaultPtrTraits {
  typedef T* PtrType;
};

}
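
// Illustrative sketch (not from the original header): the traits
// parameter only changes the pointer type the tensor stores. With
// DefaultPtrTraits the data pointer is a plain `float*`; with
// RestrictPtrTraits it is `float* __restrict__`, which lets the
// compiler assume no aliasing inside kernels:
//
//   Tensor<float, 2, true, int, traits::DefaultPtrTraits>  a; // float*
//   Tensor<float, 2, true, int, traits::RestrictPtrTraits> b; // float* __restrict__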

/**
   Templated multi-dimensional array that supports strided access of
   elements. Main access is through `operator[]`; e.g.,
   `tensor[x][y][z]`.

   - `T` is the contained type (e.g., `float`)
   - `Dim` is the tensor rank
   - If `Contig` is true, then the tensor is assumed to be contiguous,
     and only operations that make sense on contiguous arrays are
     allowed (e.g., no transpose). Strides are still calculated, but
     the innermost stride is assumed to be 1.
   - `IndexT` is the integer type used for size/stride arrays and for
     all indexing math. Default is `int`, but for large tensors `long`
     can be used instead.
   - `PtrTraits` are traits applied to our data pointer (`T*`). By
     default this is just `T*`, but `RestrictPtrTraits` can be used to
     apply `T* __restrict__` for alias-free analysis.
*/
template <typename T,
          int Dim,
          bool Contig = false,
          typename IndexT = int,
          template <typename U> class PtrTraits = traits::DefaultPtrTraits>
class Tensor {
 public:
  enum { NumDim = Dim };
  typedef T DataType;
  typedef IndexT IndexType;
  enum { IsContig = Contig };
  typedef typename PtrTraits<T>::PtrType DataPtrType;
  typedef Tensor<T, Dim, Contig, IndexT, PtrTraits> TensorType;

  /// Default constructor
  __host__ __device__ Tensor();

  /// Copy constructor
  __host__ __device__ Tensor(Tensor<T, Dim, Contig, IndexT, PtrTraits>& t)
    = default;

  /// Move constructor
  __host__ __device__ Tensor(Tensor<T, Dim, Contig, IndexT, PtrTraits>&& t)
    = default;

  /// Assignment
  __host__ __device__ Tensor<T, Dim, Contig, IndexT, PtrTraits>&
  operator=(Tensor<T, Dim, Contig, IndexT, PtrTraits>& t) = default;

  /// Move assignment
  __host__ __device__ Tensor<T, Dim, Contig, IndexT, PtrTraits>&
  operator=(Tensor<T, Dim, Contig, IndexT, PtrTraits>&& t);

  /// Constructor that calculates strides with no padding
  __host__ __device__ Tensor(DataPtrType data,
                             const IndexT sizes[Dim]);
  __host__ __device__ Tensor(DataPtrType data,
                             std::initializer_list<IndexT> sizes);
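
  // Illustrative sketch (assumes `devPtr` points to 128 * 64 floats
  // already resident on the device): both constructors compute packed
  // strides, so the tensor below has strides {64, 1}.
  //
  //   float* devPtr = ...;
  //   Tensor<float, 2> t(devPtr, {128, 64});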

  /// Constructor that takes arbitrary size/stride arrays.
  /// Errors if you attempt to pass non-contiguous strides to a
  /// contiguous tensor.
  __host__ __device__ Tensor(DataPtrType data,
                             const IndexT sizes[Dim],
                             const IndexT strides[Dim]);

  /// Copies a tensor into ourselves; sizes must match
  __host__ void copyFrom(Tensor<T, Dim, Contig, IndexT, PtrTraits>& t,
                         cudaStream_t stream);

  /// Copies ourselves into a tensor; sizes must match
  __host__ void copyTo(Tensor<T, Dim, Contig, IndexT, PtrTraits>& t,
                       cudaStream_t stream);
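
  // Illustrative sketch (assumes `src` and `dst` were constructed with
  // matching sizes and `stream` is a valid cudaStream_t); the copy is
  // issued asynchronously on the given stream:
  //
  //   src.copyTo(dst, stream);    // dst <- src
  //   dst.copyFrom(src, stream);  // same direction, other endpoint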

  /// Returns true if the two tensors are of the same dimensionality,
  /// size and stride.
  template <int OtherDim>
  __host__ __device__ bool
  isSame(const Tensor<T, OtherDim, Contig, IndexT, PtrTraits>& rhs) const;

  /// Cast to a tensor of a different type of the same size and
  /// stride. U and our type T must be of the same size
  template <typename U>
  __host__ __device__ Tensor<U, Dim, Contig, IndexT, PtrTraits> cast();

  /// Const version of `cast`
  template <typename U>
  __host__ __device__
  const Tensor<U, Dim, Contig, IndexT, PtrTraits> cast() const;

  /// Cast to a tensor of a different type which is potentially a
  /// different size than our type T. Tensor must be aligned and the
  /// innermost dimension must be a size that is a multiple of
  /// sizeof(U) / sizeof(T), and the stride of the innermost dimension
  /// must be contiguous. The stride of all outer dimensions must be a
  /// multiple of sizeof(U) / sizeof(T) as well.
  template <typename U>
  __host__ __device__ Tensor<U, Dim, Contig, IndexT, PtrTraits> castResize();

  /// Const version of `castResize`
  template <typename U>
  __host__ __device__ const Tensor<U, Dim, Contig, IndexT, PtrTraits>
  castResize() const;
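
  // Illustrative sketch (not from the original header): an [8][64]
  // float tensor can be reinterpreted as an [8][16] float4 tensor,
  // since 64 is a multiple of sizeof(float4) / sizeof(float) == 4.
  //
  //   Tensor<float, 2, true> t(ptr, {8, 64});
  //   if (t.canCastResize<float4>()) {
  //     auto t4 = t.castResize<float4>(); // sizes become {8, 16}
  //   }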

  /// Returns true if we can castResize() this tensor to the new type
  template <typename U>
  __host__ __device__ bool canCastResize() const;

  /// Attempts to cast this tensor to a tensor of a different IndexT.
  /// Fails if size or stride entries are not representable in the new
  /// IndexT.
  template <typename NewIndexT>
  __host__ Tensor<T, Dim, Contig, NewIndexT, PtrTraits>
  castIndexType() const;

  /// Returns true if we can castIndexType() this tensor to the new
  /// index type
  template <typename NewIndexT>
  __host__ bool canCastIndexType() const;

  /// Returns a raw pointer to the start of our data.
  __host__ __device__ inline DataPtrType data() {
    return data_;
  }

  /// Returns a raw pointer to the end of our data, assuming
  /// continuity
  __host__ __device__ inline DataPtrType end() {
    return data() + numElements();
  }

  /// Returns a raw pointer to the start of our data (const).
  __host__ __device__ inline
  const DataPtrType data() const {
    return data_;
  }

  /// Returns a raw pointer to the end of our data, assuming
  /// continuity (const)
  __host__ __device__ inline DataPtrType end() const {
    return data() + numElements();
  }

  /// Cast to a different datatype
  template <typename U>
  __host__ __device__ inline
  typename PtrTraits<U>::PtrType dataAs() {
    return reinterpret_cast<typename PtrTraits<U>::PtrType>(data_);
  }

  /// Cast to a different datatype
  template <typename U>
  __host__ __device__ inline
  const typename PtrTraits<const U>::PtrType dataAs() const {
    return reinterpret_cast<typename PtrTraits<const U>::PtrType>(data_);
  }

  /// Returns a read/write view of a portion of our tensor.
  __host__ __device__ inline
  detail::SubTensor<TensorType, Dim - 1, PtrTraits>
  operator[](IndexT);

  /// Returns a read/write view of a portion of our tensor (const).
  __host__ __device__ inline
  const detail::SubTensor<TensorType, Dim - 1, PtrTraits>
  operator[](IndexT) const;

  /// Returns the size of a given dimension, `[0, Dim - 1]`. No bounds
  /// checking.
  __host__ __device__ inline IndexT getSize(int i) const {
    return size_[i];
  }

  /// Returns the stride of a given dimension, `[0, Dim - 1]`. No bounds
  /// checking.
  __host__ __device__ inline IndexT getStride(int i) const {
    return stride_[i];
  }

  /// Returns the total number of elements contained within our data
  /// (product of `getSize(i)`)
  __host__ __device__ IndexT numElements() const;

  /// If we are contiguous, returns the total size in bytes of our
  /// data
  __host__ __device__ size_t getSizeInBytes() const {
    return (size_t) numElements() * sizeof(T);
  }

  /// Returns the size array.
  __host__ __device__ inline const IndexT* sizes() const {
    return size_;
  }

  /// Returns the stride array.
  __host__ __device__ inline const IndexT* strides() const {
    return stride_;
  }
  /// Returns true if there is no padding within the tensor and no
  /// re-ordering of the dimensions.
  /// ~~~
  /// (stride(i) == size(i + 1) * stride(i + 1)) && stride(dim - 1) == 1
  /// ~~~
  __host__ __device__ bool isContiguous() const;

  /// Returns whether a given dimension has only increasing stride
  /// from the previous dimension. A tensor that was permuted by
  /// exchanging size and stride only will fail this check.
  /// If `i == 0` just check `size > 0`. Returns `false` if `stride` is `<= 0`.
  __host__ __device__ bool isConsistentlySized(int i) const;

  // Returns whether at each dimension `stride <= size`.
  // If this is not the case then iterating once over the size space will
  // touch the same memory locations multiple times.
  __host__ __device__ bool isConsistentlySized() const;

  /// Returns true if the given dimension index has no padding
  __host__ __device__ bool isContiguousDim(int i) const;

  /// Returns a tensor of the same dimension after transposing the two
  /// dimensions given. Does not actually move elements; transposition
  /// is made by permuting the size/stride arrays.
  /// If the dimensions are not valid, asserts.
  __host__ __device__ Tensor<T, Dim, Contig, IndexT, PtrTraits>
  transpose(int dim1, int dim2) const;
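
  // Illustrative sketch: transposing swaps size/stride entries only.
  // A packed [2][3] tensor (strides {3, 1}) transposed over dims 0, 1
  // yields a [3][2] view with strides {1, 3}; no elements move.
  //
  //   auto tr = t.transpose(0, 1);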

  /// Upcast a tensor of dimension `D` to some tensor of dimension
  /// D' > D by padding the leading dimensions by 1
  /// e.g., upcasting a 2-d tensor `[2][3]` to a 4-d tensor `[1][1][2][3]`
  template <int NewDim>
  __host__ __device__ Tensor<T, NewDim, Contig, IndexT, PtrTraits>
  upcastOuter();

  /// Upcast a tensor of dimension `D` to some tensor of dimension
  /// D' > D by padding the lowest/most varying dimensions by 1
  /// e.g., upcasting a 2-d tensor `[2][3]` to a 4-d tensor `[2][3][1][1]`
  template <int NewDim>
  __host__ __device__ Tensor<T, NewDim, Contig, IndexT, PtrTraits>
  upcastInner();

  /// Downcast a tensor of dimension `D` to some tensor of dimension
  /// D' < D by collapsing the leading dimensions. asserts if there is
  /// padding on the leading dimensions.
  template <int NewDim>
  __host__ __device__ Tensor<T, NewDim, Contig, IndexT, PtrTraits>
  downcastOuter();

  /// Downcast a tensor of dimension `D` to some tensor of dimension
  /// D' < D by collapsing the innermost dimensions. asserts if there is
  /// padding on the innermost dimensions.
  template <int NewDim>
  __host__ __device__ Tensor<T, NewDim, Contig, IndexT, PtrTraits>
  downcastInner();

  /// Returns a tensor that is a view of the `SubDim`-dimensional slice
  /// of this tensor, starting at `at`.
  template <int SubDim>
  __host__ __device__ Tensor<T, SubDim, Contig, IndexT, PtrTraits>
  view(DataPtrType at);

  /// Returns a tensor that is a view of the `SubDim`-dimensional slice
  /// of this tensor, starting where our data begins
  template <int SubDim>
  __host__ __device__ Tensor<T, SubDim, Contig, IndexT, PtrTraits>
  view();

  /// Returns a tensor of the same dimension that is a view of the
  /// original tensor with the specified dimension restricted to the
  /// elements in the range [start, start + size)
  __host__ __device__ Tensor<T, Dim, Contig, IndexT, PtrTraits>
  narrowOutermost(IndexT start, IndexT size);

  /// Returns a tensor of the same dimension that is a view of the
  /// original tensor with the specified dimension restricted to the
  /// elements in the range [start, start + size).
  /// Can occur in an arbitrary dimension, and is possibly
  /// non-contiguous
  __host__ __device__ Tensor<T, Dim, false, IndexT, PtrTraits>
  narrow(int dim, IndexT start, IndexT size);
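
  // Illustrative sketch: narrowing dimension 1 of a [4][10] tensor to
  // [start, start + size) = [2, 7) yields a [4][5] view over the same
  // storage; the return type is the non-contiguous tensor since rows
  // are now padded.
  //
  //   auto n = t.narrow(1, 2, 5);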

  /// Returns a view of the given tensor expressed as a tensor of a
  /// different number of dimensions.
  /// Only works if we are contiguous.
  template <int NewDim>
  __host__ __device__ Tensor<T, NewDim, Contig, IndexT, PtrTraits>
  view(std::initializer_list<IndexT> sizes);
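
  // Illustrative sketch: a contiguous [2][3][4] tensor can be viewed
  // as [6][4] (or as [24]) because the element count is preserved and
  // packed strides can simply be recomputed.
  //
  //   auto flat = t.view<2>({6, 4});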

 protected:
  /// Raw pointer to where the tensor data begins
  DataPtrType data_;

  /// Array of strides (in sizeof(T) terms) per each dimension
  IndexT stride_[Dim];

  /// Size per each dimension
  IndexT size_[Dim];
};
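
// Illustrative usage sketch (not part of the original header). The
// kernel below is hypothetical; it shows the intended pattern of
// passing a Tensor by value to a kernel and indexing it with chained
// operator[]:
//
//   __global__ void scale(Tensor<float, 2, true> t, float v) {
//     int x = blockIdx.x * blockDim.x + threadIdx.x;
//     if (x < t.getSize(1)) {
//       t[blockIdx.y][x] = t[blockIdx.y][x] * v;
//     }
//   }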

// Utilities for checking a collection of tensors
namespace detail {

template <typename IndexType>
bool canCastIndexType() {
  return true;
}

template <typename IndexType, typename T, typename... U>
bool canCastIndexType(const T& arg, const U&... args) {
  return arg.template canCastIndexType<IndexType>() &&
    canCastIndexType<IndexType>(args...);
}

} // namespace detail

template <typename IndexType, typename... T>
bool canCastIndexType(const T&... args) {
  return detail::canCastIndexType<IndexType>(args...);
}
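
// Illustrative sketch: the variadic helper above folds the member
// canCastIndexType<NewIndexT>() check over any number of tensors,
// e.g. verifying that several long-indexed tensors can be downcast
// to `int` indexing in one call:
//
//   bool ok = canCastIndexType<int>(a, b, c);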

namespace detail {

/// Specialization for a view of a single value (0-dimensional)
template <typename TensorType, template <typename U> class PtrTraits>
class SubTensor<TensorType, 0, PtrTraits> {
 public:
  __host__ __device__ SubTensor<TensorType, 0, PtrTraits>
  operator=(typename TensorType::DataType val) {
    *data_ = val;
    return *this;
  }

  // operator T&
  __host__ __device__ operator typename TensorType::DataType&() {
    return *data_;
  }

  // const operator T& returning const T&
  __host__ __device__ operator const typename TensorType::DataType&() const {
    return *data_;
  }

  // operator& returning T*
  __host__ __device__ typename TensorType::DataType* operator&() {
    return data_;
  }

  // const operator& returning const T*
  __host__ __device__ const typename TensorType::DataType* operator&() const {
    return data_;
  }

  /// Returns a raw accessor to our slice.
  __host__ __device__ inline typename TensorType::DataPtrType data() {
    return data_;
  }

  /// Returns a raw accessor to our slice (const).
  __host__ __device__ inline
  const typename TensorType::DataPtrType data() const {
    return data_;
  }

  /// Cast to a different datatype.
  template <typename T>
  __host__ __device__ T& as() {
    return *dataAs<T>();
  }

  /// Cast to a different datatype (const).
  template <typename T>
  __host__ __device__ const T& as() const {
    return *dataAs<T>();
  }

  /// Cast to a different datatype
  template <typename T>
  __host__ __device__ inline
  typename PtrTraits<T>::PtrType dataAs() {
    return reinterpret_cast<typename PtrTraits<T>::PtrType>(data_);
  }

  /// Cast to a different datatype (const)
  template <typename T>
  __host__ __device__ inline
  typename PtrTraits<const T>::PtrType dataAs() const {
    return reinterpret_cast<typename PtrTraits<const T>::PtrType>(data_);
  }

  /// Use the texture cache for reads
  __device__ inline typename TensorType::DataType ldg() const {
#if __CUDA_ARCH__ >= 350
    return __ldg(data_);
#else
    return *data_;
#endif
  }

  /// Use the texture cache for reads; cast as a particular type
  template <typename T>
  __device__ inline T ldgAs() const {
#if __CUDA_ARCH__ >= 350
    return __ldg(dataAs<T>());
#else
    return as<T>();
#endif
  }

 protected:
  /// One dimension greater can create us
  friend class SubTensor<TensorType, 1, PtrTraits>;

  /// Our parent tensor can create us
  friend class Tensor<typename TensorType::DataType,
                      1,
                      TensorType::IsContig,
                      typename TensorType::IndexType,
                      PtrTraits>;

  __host__ __device__ inline SubTensor(
    TensorType& t,
    typename TensorType::DataPtrType data)
      : tensor_(t),
        data_(data) {
  }

  /// The tensor we're referencing
  TensorType& tensor_;

  /// Where our value is located
  typename TensorType::DataPtrType const data_;
};

/// A `SubDim`-rank slice of a parent Tensor
template <typename TensorType,
          int SubDim,
          template <typename U> class PtrTraits>
class SubTensor {
 public:
  /// Returns a view of the data located at our offset (the dimension
  /// `SubDim` - 1 tensor).
  __host__ __device__ inline
  SubTensor<TensorType, SubDim - 1, PtrTraits>
  operator[](typename TensorType::IndexType index) {
    if (TensorType::IsContig && SubDim == 1) {
      // Innermost dimension is stride 1 for contiguous arrays
      return SubTensor<TensorType, SubDim - 1, PtrTraits>(
        tensor_, data_ + index);
    } else {
      return SubTensor<TensorType, SubDim - 1, PtrTraits>(
        tensor_,
        data_ + index * tensor_.getStride(TensorType::NumDim - SubDim));
    }
  }

  /// Returns a view of the data located at our offset (the dimension
  /// `SubDim` - 1 tensor) (const).
  __host__ __device__ inline
  const SubTensor<TensorType, SubDim - 1, PtrTraits>
  operator[](typename TensorType::IndexType index) const {
    if (TensorType::IsContig && SubDim == 1) {
      // Innermost dimension is stride 1 for contiguous arrays
      return SubTensor<TensorType, SubDim - 1, PtrTraits>(
        tensor_, data_ + index);
    } else {
      return SubTensor<TensorType, SubDim - 1, PtrTraits>(
        tensor_,
        data_ + index * tensor_.getStride(TensorType::NumDim - SubDim));
    }
  }

  // operator& returning T*
  __host__ __device__ typename TensorType::DataType* operator&() {
    return data_;
  }

  // const operator& returning const T*
  __host__ __device__ const typename TensorType::DataType* operator&() const {
    return data_;
  }

  /// Returns a raw accessor to our slice.
  __host__ __device__ inline typename TensorType::DataPtrType data() {
    return data_;
  }

  /// Returns a raw accessor to our slice (const).
  __host__ __device__ inline
  const typename TensorType::DataPtrType data() const {
    return data_;
  }

  /// Cast to a different datatype.
  template <typename T>
  __host__ __device__ T& as() {
    return *dataAs<T>();
  }

  /// Cast to a different datatype (const).
  template <typename T>
  __host__ __device__ const T& as() const {
    return *dataAs<T>();
  }

  /// Cast to a different datatype
  template <typename T>
  __host__ __device__ inline
  typename PtrTraits<T>::PtrType dataAs() {
    return reinterpret_cast<typename PtrTraits<T>::PtrType>(data_);
  }

  /// Cast to a different datatype (const)
  template <typename T>
  __host__ __device__ inline
  typename PtrTraits<const T>::PtrType dataAs() const {
    return reinterpret_cast<typename PtrTraits<const T>::PtrType>(data_);
  }

  /// Use the texture cache for reads
  __device__ inline typename TensorType::DataType ldg() const {
#if __CUDA_ARCH__ >= 350
    return __ldg(data_);
#else
    return *data_;
#endif
  }

  /// Use the texture cache for reads; cast as a particular type
  template <typename T>
  __device__ inline T ldgAs() const {
#if __CUDA_ARCH__ >= 350
    return __ldg(dataAs<T>());
#else
    return as<T>();
#endif
  }

  /// Returns a tensor that is a view of the SubDim-dimensional slice
  /// of this tensor, starting where our data begins
  Tensor<typename TensorType::DataType,
         SubDim,
         TensorType::IsContig,
         typename TensorType::IndexType,
         PtrTraits> view() {
    return tensor_.template view<SubDim>(data_);
  }

 protected:
  /// One dimension greater can create us
  friend class SubTensor<TensorType, SubDim + 1, PtrTraits>;

  /// Our parent tensor can create us
  friend class
  Tensor<typename TensorType::DataType,
         TensorType::NumDim,
         TensorType::IsContig,
         typename TensorType::IndexType,
         PtrTraits>;

  __host__ __device__ inline SubTensor(
    TensorType& t,
    typename TensorType::DataPtrType data)
      : tensor_(t),
        data_(data) {
  }

  /// The tensor we're referencing
  TensorType& tensor_;

  /// The start of our sub-region
  typename TensorType::DataPtrType const data_;
};

} // namespace detail

template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ __device__ inline
detail::SubTensor<Tensor<T, Dim, Contig, IndexT, PtrTraits>,
                  Dim - 1, PtrTraits>
Tensor<T, Dim, Contig, IndexT, PtrTraits>::operator[](IndexT index) {
  return detail::SubTensor<TensorType, Dim - 1, PtrTraits>(
    detail::SubTensor<TensorType, Dim, PtrTraits>(
      *this, data_)[index]);
}

template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ __device__ inline
const detail::SubTensor<Tensor<T, Dim, Contig, IndexT, PtrTraits>,
                        Dim - 1, PtrTraits>
Tensor<T, Dim, Contig, IndexT, PtrTraits>::operator[](IndexT index) const {
  return detail::SubTensor<TensorType, Dim - 1, PtrTraits>(
    detail::SubTensor<TensorType, Dim, PtrTraits>(
      const_cast<TensorType&>(*this), data_)[index]);
}
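
// Illustrative note (not from the original header): a full access like
// `t[x][y][z]` first builds a rank-Dim SubTensor via the definitions
// above, then each successive operator[] peels one dimension off until
// the rank-0 specialization converts to a T& for reads or writes
// (in device code, for device-resident data):
//
//   Tensor<float, 3, true> t(ptr, {4, 5, 6});
//   float v = t[1][2][3]; // resolves through SubTensor ranks 2, 1, 0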

} } // namespace

#include "Tensor-inl.cuh"