forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Batching.cpp
131 lines (122 loc) · 5.37 KB
/
Batching.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#include <ATen/BatchedTensorImpl.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/VmapTransforms.h>
namespace at { namespace native {
// Out-of-place: returns the result of `addBatchDim` applied to `self`, leaving
// `self` itself untouched (it is taken by const reference).
Tensor _add_batch_dim(const Tensor& self, int64_t batch_dim, int64_t level) {
  // Note the argument order: this function takes (batch_dim, level) but
  // `addBatchDim` is called with `level` first and `batch_dim` second.
  auto with_batch_dim = addBatchDim(self, level, batch_dim);
  return with_batch_dim;
}
// Returns true iff `maybeGetBatchedImpl(self)` yields a non-null impl and one
// of that impl's batch dims has a level equal to `level`.
static bool has_level(const Tensor& self, int64_t level) {
  const auto* batched = maybeGetBatchedImpl(self);
  if (batched == nullptr) {
    // No batched impl behind `self`, so it cannot carry any batch level.
    return false;
  }
  const auto bdims = batched->bdims();
  return std::any_of(
      bdims.begin(), bdims.end(),
      [level](const BatchDim& bdim) { return bdim.level() == level; });
}
// Returns a Tensor with batch dim with level `level` turned into a regular dimension,
// as well as a logical dim index of where said dimension is in the returned tensor.
// A call to this function is always followed by a call to `movedim`.
//
// Preconditions: A BatchDim with level `level` must exist inside `batched`.
//
// The reason why we want to return the index of where said dimension is in the returned
// tensor is because we want to keep track of which dimension used to be the batch
// dimension so that we can move it to the correct logical dimension specified by
// `out_dims` in vmap. For example, if we had
// >>> x = torch.randn(2, 3, 5)
// >>> vmap(lambda x: x, in_dims=0, out_dims=1)(x)
// then right when we are about to exit the vmap block, x is a BatchedTensor with a
// batch dimension at (physical) index 0. Note that the batch dimension doesn't
// always have to exist at (physical) index 0. When we undo the batch dimension,
// we want to move it to dimension 1 (as specified by out_dims). So we return the
// index at which the batch dim appears so that we can move it to the correct place
// later down the line via a call to `movedim`.
static std::pair<Tensor,int64_t> remove_existing_batch_dim(
const BatchedTensorImpl* batched, int64_t level) {
auto bdims = batched->bdims();
if (bdims.size() == 1) {
TORCH_INTERNAL_ASSERT(bdims[0].level() == level);
return std::make_pair(batched->value(), bdims[0].dim());
}
BatchDims new_bdims;
int64_t newly_exposed_physical_dim = -1;
new_bdims.reserve(bdims.size() - 1);
for (const auto& bdim : bdims) {
if (bdim.level() == level) {
newly_exposed_physical_dim = bdim.dim();
} else {
new_bdims.push_back(bdim);
}
}
// Because a BatchDim with level `level` must exist inside `batched,
// we should have found a `newly_exposed_logical_dim`.
TORCH_INTERNAL_ASSERT(newly_exposed_physical_dim != -1);
int64_t num_batch_dims_before_newly_exposed_physical_dim = std::count_if(
new_bdims.begin(), new_bdims.end(),
[&](const BatchDim& bdim) {
return bdim.dim() < newly_exposed_physical_dim;
});
int64_t newly_exposed_logical_dim =
newly_exposed_physical_dim - num_batch_dims_before_newly_exposed_physical_dim;
auto result_tensor = makeBatched(batched->value(), std::move(new_bdims));
return std::make_pair(std::move(result_tensor), newly_exposed_logical_dim);
}
// Poor man's version of np.moveaxis. Moves the dimension at `src` to position
// `dst` while preserving the relative order of the other existing dimensions.
// Both indices are first passed through `maybe_wrap_dim` against `self.dim()`.
// We should probably add np.moveaxis (it is more general) to PyTorch. (#36048)
// When we do, replace the following with it.
static Tensor movedim(const Tensor& self, int64_t src, int64_t dst) {
  const auto ndim = self.dim();
  const int64_t from = maybe_wrap_dim(src, ndim);
  const int64_t to = maybe_wrap_dim(dst, ndim);
  if (from == to) {
    // Nothing to move.
    return self;
  }

  // Build the permutation in one pass: slot `to` receives `from`, and every
  // other slot receives the next dimension (in increasing order) that is not
  // `from`.
  VmapDimVector order;
  order.reserve(ndim);
  int64_t next_kept = 0;
  for (int64_t pos = 0; pos < ndim; ++pos) {
    if (pos == to) {
      order.push_back(from);
      continue;
    }
    if (next_kept == from) {
      ++next_kept;
    }
    order.push_back(next_kept);
    ++next_kept;
  }
  return self.permute(order);
}
// Removes the batch dim with level `level` from `self`. If this causes the
// last batch dim to be removed from a BatchedTensor, then this returns a
// regular Tensor.
//
// If the `level` of the batch dim to remove does not exist in `self`, then we
// add the batch dim in. This can happen if `self` didn't interact with a tensor
// inside the vmap level, for example,
//     self = torch.randn(3)
//     y = torch.randn(5)
//     out = vmap(lambda x: vmap(lambda y: x)(y))(self)
//     assert out.shape == (3, 5)
// Inside the inner vmap, `x` is a BatchedTensor with a single batch dimension
// corresponding to the *outer* vmap level and it doesn't have any dimensions that
// correspond to the inner vmap level so we need to create one for the user.
//
// `out_dim` controls where we should put the batch dimension in the output tensor.
// In both cases it is resolved by `movedim`, i.e. it is passed through
// `maybe_wrap_dim` against the number of dimensions of the output tensor.
Tensor _remove_batch_dim(const Tensor& self, int64_t level, int64_t batch_size, int64_t out_dim) {
  if (!has_level(self, level)) {
    // `expand` can only introduce new dimensions at the front of the shape, so
    // create the batch dimension at index 0 and then move it to `out_dim`.
    // Inserting `batch_size` directly at `out_dim` would mis-align the
    // existing sizes whenever `out_dim != 0`, and would index outside of
    // `expanded_sizes` for a negative or too-large `out_dim`.
    auto self_sizes = self.sizes();
    VmapDimVector expanded_sizes(self_sizes.begin(), self_sizes.end());
    expanded_sizes.insert(expanded_sizes.begin(), batch_size);
    return movedim(self.expand(expanded_sizes), /*src=*/0, /*dst=*/out_dim);
  }

  // Must be batched if has_level(self, /*any_level*/)
  const auto* batched = maybeGetBatchedImpl(self);
  TORCH_INTERNAL_ASSERT(batched != nullptr);

  // Turn the batch dim into a regular dimension, then relocate that dimension
  // to where the caller asked for it.
  Tensor self_without_bdim;
  int64_t newly_exposed_logical_dim;
  std::tie(self_without_bdim, newly_exposed_logical_dim) = remove_existing_batch_dim(batched, level);
  return movedim(self_without_bdim, newly_exposed_logical_dim, out_dim);
}
} // namespace native
} // namespace at