// FunctionalStorageImpl.cpp (forked from pytorch/pytorch)
#include <ATen/FunctionalStorageImpl.h>
#include <ATen/EmptyTensor.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <c10/core/CPUAllocator.h>
#include <c10/util/Exception.h>
#include <vector>
namespace at {
namespace functionalization {
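// Return a ViewMeta identical to this one but targeting out_idx
// (relevant for view ops that produce multiple outputs).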
ViewMeta ViewMeta::to_out_idx(int64_t out_idx) {
  if (out_idx == this->out_index) return *this;
  return ViewMeta(forward_fn, reverse_fn, out_idx);
}

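// The base tensor wrapped by an Alias is always a plain (non-functional) tensor.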
Alias::Alias(const at::Tensor& base) {
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(base));
  base_ = base;
}

const at::Tensor& Alias::base() const {
  return base_;
}

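// Queue up an update coming from one of the views. The update is only
// materialized later, when apply_updates() runs; each call also bumps
// the generation counter.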
void Alias::add_update(const at::Tensor& updated_val, const std::vector<ViewMeta>& metas) {
  updates_.push_back({updated_val, metas});
  generation_++;
}

// Note [Functionalization: Alias Removal Part 2]
// See Note [Functionalization: Alias Removal] for more details.
// This function applies a single update from one of the views to the Alias object.
// We start out with <original_base> and <mutated_view>, and our goal is to end up with <mutated_base>.
// Consider this program:
//
// base = ...
// a = base.view1()
// b = a.view2()
// c = b.view3()
// c.add_(3)
//
// Then the functionalization pass will queue an update as follows:
//
// update.new_val = c # the updated value of c
// update.view_metas = [view1_meta, view2_meta, view3_meta]
//
// Syncing any of a, b or c will eventually call apply_update() on the alias, and the following will run:
//
// tmp_values = [base, a, b] # NB: c is not necessary
// t = update.new_val
// t = view3_inverse(b, t, 0) # 0 is output index, these are all single output views so it's 0
// t = view2_inverse(a, t, 0)
// t = view1_inverse(base, t, 0) # t now represents the updated alias.
// alias.base_ = t
const Tensor apply_update(const Alias::Update& update, const Tensor& base) {
  at::Tensor t = update.new_val;
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
  if (update.view_metas.size() == 0) return t;

  std::vector<at::Tensor> tmp_values({base});
  for (size_t i = 0; i < update.view_metas.size() - 1; ++i) {
    at::Tensor next_view = update.view_metas[i].forward_fn(tmp_values.back(), update.view_metas[i].out_index);
    // NB: We only actually need tmp_values for ops like select/slice/diagonal/squeeze/as_strided.
    // All of these ops require additional information to recover the sizes of the original tensor.
    // If we needed to, we could probably apply that optimization and only bother computing tmp_values
    // for the view ops that require it.
    tmp_values.push_back(std::move(next_view));
  }
  for (int i = update.view_metas.size() - 1; i >= 0; --i) {
    int64_t out_idx = update.view_metas[i].out_index;
    // Each view inverse is implemented in ViewInverses.cpp.
    t = update.view_metas[i].reverse_fn(tmp_values[i], t, out_idx);
  }
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
  return t;
}

bool Alias::apply_updates() {
  // N.B.: none of the tensors used in this function should be FunctionalTensorWrappers at this point.
  // The only reason we currently need the TLS exclude guard here is because of functorch's DynamicLayer stack.
  // It adds the Functionalize key into TLS before redispatching to the functionalization kernels,
  // which means that we need to explicitly exclude it here before doing any other work underneath the pass.
  at::AutoDispatchSkipFunctionalize guard;
  bool any_updates = updates_.size() > 0;
  for (auto& update_data: updates_) {
    base_ = apply_update(update_data, base_);
  }
  updates_.clear();
  return any_updates;
}

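// Compute the number of bytes to report for the functional storage,
// with special handling for tensors that have symbolic sizes/strides.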
c10::SymInt get_nbytes(const Tensor& value) {
  if (value.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {
    // Today, the two implementations of SymInt are in Python (proxy tensor)
    // and lazy tensor (LTC/XLA).
    // LTC hasn't implemented SymInt support yet though (torch::lazy::SymIntNodeImpl).
    // Once it does, we should remove this check.
    if (value.key_set().has(c10::DispatchKey::Python)) {
      return value.storage().sym_nbytes();
    }
  }
  // XLA storage objects also do not properly track nbytes.
  return at::detail::computeStorageNbytes(value.sizes(), value.strides(), value.dtype().itemsize(), value.storage_offset());
}

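// NB: the base StorageImpl is constructed with a null DataPtr and the meta
// allocator; the functional storage itself only tracks aliasing/update
// metadata and is never used to hold real data.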
FunctionalStorageImpl::FunctionalStorageImpl(const Tensor& value)
  : c10::StorageImpl(
      c10::StorageImpl::use_byte_size_t(),
      get_nbytes(value),
      DataPtr{nullptr, value.device()},
      GetAllocator(kMeta),
      /*resizeable=*/true
    ),
    alias_(Alias(value))
  {}

void FunctionalStorageImpl::add_update(const Tensor& updated_val, const std::vector<ViewMeta>& view_metas) {
  alias_.add_update(updated_val, view_metas);
}

bool FunctionalStorageImpl::apply_updates() {
  return alias_.apply_updates();
}

const Tensor& FunctionalStorageImpl::base() {
  return alias_.base();
}

size_t FunctionalStorageImpl::generation() const {
  return alias_.generation();
}

} // namespace functionalization
} // namespace at