forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
IndexingUtils.h
133 lines (117 loc) · 4.39 KB
/
IndexingUtils.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#pragma once
#include <ATen/ExpandUtils.h>
#include <ATen/native/TensorIterator.h>
#include <limits>
namespace at { namespace native {
// Exported declaration only; the definition lives in another translation unit.
// `max_elem` defaults to std::numeric_limits<int32_t>::max().
// NOTE(review): judging by the name, this presumably reports whether all element
// offsets of `t` stay within `max_elem`, so that 32-bit index arithmetic is safe;
// confirm against the definition.
TORCH_API bool canUse32BitIndexMath(const at::Tensor &t, int64_t max_elem=std::numeric_limits<int32_t>::max());
// Reports a mask/tensor shape disagreement and never returns.
// `maskIdx` names the offending dimension of `mask`; `idx` names the dimension
// of `self` it was matched against. The error is raised through
// TORCH_CHECK_INDEX with a constant-false condition, so control never falls
// off the end (hence [[noreturn]]).
[[noreturn]] static void invalid_mask(
    const Tensor& self,
    int64_t idx,
    const Tensor& mask,
    int64_t maskIdx) {
  TORCH_CHECK_INDEX(
      false,
      "The shape of the mask ", mask.sizes(),
      " at index ", maskIdx,
      " does not match the shape of the indexed tensor ", self.sizes(),
      " at index ", idx);
}
static std::vector<Tensor> expandTensors(const Tensor & self, TensorList indices) {
// If indices come in as ByteTensor or BoolTensor (masks), expand them into the equivalent indexing by LongTensors
std::vector<Tensor> result;
for (const auto & index : indices) {
if (index.scalar_type() == kByte || index.scalar_type() == kBool) {
if (index.scalar_type() == kByte) {
TORCH_WARN("indexing with dtype torch.uint8 is now deprecated," \
" please use a dtype torch.bool instead.");
}
// The sizes of the ByteTensor mask or bool tensor must match the sizes of the
// corresponding dimensions in self
for (int64_t j = 0; j < index.dim(); j++) {
int64_t srcIdx = result.size() + j;
if (index.size(j) != self.size(srcIdx)) {
invalid_mask(self, srcIdx, index, j);
}
}
// Replace with nonzeros
auto nonzero = index.nonzero();
for (int64_t j = 0; j < index.dim(); j++) {
result.emplace_back(nonzero.select(1, j));
}
} else {
result.emplace_back(index);
}
}
return result;
}
// Validates the dtype of every defined entry of `indices`: each must be kLong,
// kByte or kBool. Undefined entries are ignored. Throws via TORCH_CHECK_INDEX
// at the first entry with any other dtype.
static void checkIndexTensorTypes(TensorList indices) {
  for (const Tensor& candidate : indices) {
    if (!candidate.defined()) {
      continue;
    }
    const auto dtype = candidate.scalar_type();
    const bool supported = dtype == kLong || dtype == kByte || dtype == kBool;
    TORCH_CHECK_INDEX(supported, "tensors used as indices must be long, byte or bool tensors");
  }
}
// Returns true when all defined (non-null) entries of `tl` form a single
// adjacent run, e.g. {null, a, b, null} -> true, {a, null, b} -> false.
// A list with no defined entries (including an empty list) is treated as
// trivially contiguous.
static bool hasContiguousSubspace(TensorList tl) {
  auto isDefined = [](const Tensor & tensor){ return tensor.defined(); };
  auto isNull = [](const Tensor & tensor){ return !tensor.defined(); };
  auto start = std::find_if(tl.begin(), tl.end(), isDefined);
  if (start == tl.end()) {
    // No defined entry at all. Without this early return, stop.base() below
    // would equal tl.begin(), making [start, stop.base()) a reversed range
    // (end() .. begin()), which is undefined behavior for std::find_if.
    return true;
  }
  // `stop` refers to the last defined entry; stop.base() is one past it.
  auto stop = std::find_if(tl.rbegin(), tl.rend(), isDefined);
  // Contiguous iff no null entry sits between the first and last defined ones.
  auto it = std::find_if(start, stop.base(), isNull);
  return it == stop.base();
}
// Transposes the tensor and indices together so that all the non-null indices
// index the first k dimensions of the tensor. Returns the transposed tensor
// and the reordered indices. For example:
// transposeToFront(tensor, {nullptr, a, nullptr, b})
// returns
// tensor.permute([1, 3, 0, 2]), {a, b, nullptr, nullptr}
//
// Defined indices keep their relative order, as do the non-indexed dimensions.
// Precondition: `indices` has at least self.dim() entries; entries past
// self.dim() are ignored. Entries are read with at() rather than the
// unchecked operator[], so a too-short list raises an error instead of
// reading past the end of the list.
static std::tuple<Tensor, std::vector<Tensor>>
transposeToFront(Tensor self, TensorList indices) {
  const int64_t ndim = self.dim();
  std::vector<int64_t> dims;
  std::vector<Tensor> transposedIndices;
  dims.reserve(ndim);
  transposedIndices.reserve(ndim);
  // First the indexed dimensions, carrying their index tensors along...
  for (int64_t i = 0; i < ndim; i++) {
    const Tensor& index = indices.at(i);
    if (index.defined()) {
      dims.push_back(i);
      transposedIndices.emplace_back(index);
    }
  }
  // ...then the non-indexed dimensions, each paired with an undefined tensor.
  for (int64_t i = 0; i < ndim; i++) {
    if (!indices.at(i).defined()) {
      dims.push_back(i);
      transposedIndices.emplace_back();
    }
  }
  return std::make_tuple(self.permute(dims), std::move(transposedIndices));
}
// Permutes `self` so that every dimension with a defined index comes first
// (in original order), followed by the non-indexed dimensions (in original
// order), and reorders `indices` to match (undefined tensors for the trailing,
// non-indexed dimensions).
// Returns {self.permute(dims), reordered indices, invPerm}, where
// invPerm[dims[k]] == k. Permuting the transposed tensor by invPerm therefore
// restores the original dimension order.
// Precondition: `indices` has at least self.dim() entries; entries past
// self.dim() are ignored. Entries are read with at() rather than the
// unchecked operator[], so a too-short list raises an error instead of
// reading past the end of the list.
inline std::tuple<Tensor, std::vector<Tensor>, std::vector<int64_t>>
transposeToFrontAndInvPerm(Tensor self, TensorList indices) {
  const int64_t ndim = self.dim();
  std::vector<int64_t> dims;
  std::vector<int64_t> invPerm;
  std::vector<Tensor> transposedIndices;
  dims.reserve(ndim);
  invPerm.resize(ndim);
  transposedIndices.reserve(ndim);
  // Indexed dimensions first, carrying their index tensors along.
  for (int64_t i = 0; i < ndim; i++) {
    const Tensor& index = indices.at(i);
    if (index.defined()) {
      dims.push_back(i);
      transposedIndices.emplace_back(index);
    }
  }
  // Then the non-indexed dimensions, each paired with an undefined tensor.
  for (int64_t i = 0; i < ndim; i++) {
    if (!indices.at(i).defined()) {
      dims.push_back(i);
      transposedIndices.emplace_back();
    }
  }
  // Every dimension was pushed exactly once, so dims is a full permutation of
  // [0, ndim) and can be inverted in place.
  for (int64_t i = 0; i < ndim; i++) {
    invPerm[dims[i]] = i;
  }
  return std::make_tuple(self.permute(dims), std::move(transposedIndices), std::move(invPerm));
}
// Bundles the state of an advanced-indexing operation. Only the constructor is
// declared here; its definition lives in another translation unit.
// NOTE(review): the field descriptions below are inferred from their names;
// confirm against the AdvancedIndex constructor definition.
struct AdvancedIndex {
  AdvancedIndex(const Tensor& src, TensorList indices);
  // Presumably the source tensor as prepared for indexing (possibly
  // permuted/restrided); confirm.
  Tensor src;
  // Presumably the processed index tensors; confirm.
  std::vector<Tensor> indices;
  // Presumably the sizes and strides of the indexed dimensions of `src`; confirm.
  DimVector indexed_sizes;
  DimVector indexed_strides;
  // Presumably the counts of non-indexed dimensions before and after the
  // indexed subspace; confirm.
  // NOTE(review): these have no in-class initializer, so they are
  // indeterminate unless the constructor assigns them.
  int64_t dims_before;
  int64_t dims_after;
};
}}