forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
MatrixRef.h
109 lines (92 loc) · 2.83 KB
/
MatrixRef.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#pragma once
#include <ATen/Utils.h>
#include <c10/util/ArrayRef.h>
#include <vector>
namespace at {
/// MatrixRef - Like an ArrayRef, but with an extra recorded stride so that
/// we can easily view it as a multidimensional array.
///
/// Like ArrayRef, this class does not own the underlying data; it is expected
/// to be used in situations where the data resides in some other buffer.
///
/// This is intended to be trivially copyable, so it should be passed by
/// value.
///
/// For now, 2D only (so the copies are actually cheap, without having
/// to write a SmallVector class) and contiguous only (so we can
/// return non-strided ArrayRef on index).
///
/// P.S. dimension 0 indexes rows, dimension 1 indexes columns
template <typename T>
class MatrixRef {
 public:
  using size_type = size_t;

 private:
  /// Underlying ArrayRef (row-major, rows laid out back to back)
  ArrayRef<T> arr;
  /// Stride of dim 0 (outer dimension), i.e. the number of elements per row.
  /// Zero only when the matrix is empty (e.g. default-constructed).
  size_type stride0;
  // Stride of dim 1 is assumed to be 1

 public:
  /// @name Constructors
  /// @{

  /// Construct an empty MatrixRef.
  // `arr` is value-initialized (an empty view) rather than built from
  // nullptr: nullptr is not an "empty" marker for ArrayRef, and for pointer
  // element types it could bind to a single-element constructor, producing a
  // non-empty view of a dangling temporary.
  /*implicit*/ MatrixRef() : arr(), stride0(0) {}

  /// Construct a MatrixRef from an ArrayRef and outer stride.
  ///
  /// @param arr      contiguous storage holding the matrix rows back to back
  /// @param stride0  number of elements per row; must evenly divide
  ///                 arr.size(). May be 0 only when arr is empty.
  /// Fails via TORCH_CHECK if the stride is zero for non-empty data, or if
  /// arr.size() is not a multiple of stride0.
  /*implicit*/ MatrixRef(ArrayRef<T> arr, size_type stride0)
      : arr(arr), stride0(stride0) {
    // Reject a zero stride up front: `x % 0` below would be undefined
    // behavior rather than a reportable error.
    TORCH_CHECK(
        stride0 != 0 || arr.empty(),
        "MatrixRef: stride must be nonzero for non-empty ArrayRef of size ",
        arr.size());
    TORCH_CHECK(
        stride0 == 0 || arr.size() % stride0 == 0,
        "MatrixRef: ArrayRef size ",
        arr.size(),
        " not divisible by stride ",
        stride0);
  }
  /// @}

  /// @name Simple Operations
  /// @{

  /// empty - Check if the matrix is empty.
  bool empty() const {
    return arr.empty();
  }

  /// data - Pointer to the first element of the underlying storage.
  const T* data() const {
    return arr.data();
  }

  /// size - Get the size of a dimension.
  ///
  /// @param dim  0 for the number of rows, 1 for the number of columns
  /// @return     the extent of dimension `dim`
  /// Fails via TORCH_CHECK for any other `dim`.
  size_t size(size_t dim) const {
    if (dim == 0) {
      // An empty zero-stride matrix (e.g. default-constructed) has no rows;
      // avoid dividing by the zero stride.
      return stride0 == 0 ? 0 : arr.size() / stride0;
    } else if (dim == 1) {
      return stride0;
    } else {
      TORCH_CHECK(
          false,
          "MatrixRef: out of bounds dimension ",
          dim,
          "; expected 0 or 1");
    }
  }

  /// numel - Total number of elements in the matrix.
  size_t numel() const {
    return arr.size();
  }

  /// equals - Check for element-wise equality (same shape, same elements).
  bool equals(MatrixRef RHS) const {
    return stride0 == RHS.stride0 && arr.equals(RHS.arr);
  }
  /// @}

  /// @name Operator Overloads
  /// @{

  /// Row access: returns row `Index` as a non-owning ArrayRef of length
  /// stride0. Range validation, if any, is delegated to ArrayRef::slice.
  ArrayRef<T> operator[](size_t Index) const {
    return arr.slice(Index * stride0, stride0);
  }

  /// Disallow accidental assignment from a temporary.
  ///
  /// The declaration here is extra complicated so that "matrixRef = {}"
  /// continues to select the move assignment operator.
  template <typename U>
  std::enable_if_t<std::is_same_v<U, T>, MatrixRef<T>>& operator=(
      U&& Temporary) = delete;

  /// Disallow accidental assignment from a temporary.
  ///
  /// The declaration here is extra complicated so that "matrixRef = {}"
  /// continues to select the move assignment operator.
  template <typename U>
  std::enable_if_t<std::is_same_v<U, T>, MatrixRef<T>>& operator=(
      std::initializer_list<U>) = delete;
  /// @}
};
} // end namespace at