forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathNestedIntSymNodeImpl.h
187 lines (159 loc) · 5.83 KB
/
NestedIntSymNodeImpl.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
#pragma once
#include <c10/core/ConstantSymNodeImpl.h>
#include <c10/core/SymNodeImpl.h>
#include <c10/macros/Export.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <c10/util/intrusive_ptr.h>
#include <cstdint>
#include <string>
namespace c10 {
// The motivating usecase for this is to represent the ragged size structure
// of a jagged tensor [B, [s_0, s_1, s_2], D] as a single integer j0. This
// allows us to simply return [B, j0, D] if someone queries for the size of our
// tensor.
//
// Morally we define comparison between two nested ints to return true if
// that comparison holds for all corresponding elements of the arrays they
// represent. Comparison between a nested int and a plain int is defined
// similarly.
//
// To simulate this desired behavior but also avoid the O(N) cost of checking,
// we associate each raggedness pattern with an integer "id" that can be used as
// a proxy to evaluate equality. We also constrain the range of values a nested
// int may take so as to enable inequality checks.
//
// We also support a positive integer scalar "coeff" that is used for computing
// strides. For example, given a [B, j0, D] tensor, it can be strided in two
// different ways: [D * j0, D, 1] and [j0, 1, sum(j0)]. The coeff is used to
// differentiate the two cases.
//
// During tracing the strides of the outputs need to be a function of the size
// and strides of the inputs so it is important that NestedIntSymNode itself is
// able to express this.
// SymNodeImpl specialization for a "nested int": an integer-typed node
// identified by `val_` and scaled by `coeff_`. It prints as `j<val>` (or
// `<coeff>*j<val>` when the coefficient is not 1). Relational operators and
// `mul` are declared here and defined out of line; the remaining arithmetic
// and logical operators are rejected with an error.
class TORCH_API NestedIntSymNodeImpl : public SymNodeImpl {
 public:
  // CAUTION: you should probably not be constructing these directly; please
  // use the higher-level API in python instead (TODO: actually introduce
  // that).
  //
  // `val` is the identifier reported by nested_int() and printed after the
  // "j" prefix by str(); `coeff` is the scalar multiplier reported by
  // nested_int_coeff(). Neither argument is validated here.
  explicit NestedIntSymNodeImpl(int64_t val, int64_t coeff)
      : val_(val), coeff_(coeff) {}

  // Unconditionally returns false.
  bool bool_() override {
    return false;
  }

  // Type predicates: a nested int reports itself as an int (and only an int).
  bool is_int() override {
    return true;
  }

  bool is_float() override {
    return false;
  }

  bool is_bool() override {
    return false;
  }

  bool is_nested_int() const override {
    return true;
  }

  // Unconditionally returns true.
  bool has_hint() override {
    return true;
  }

  // Wraps a plain integer as a constant SymNode (ConstantSymNodeImpl<int64_t>).
  c10::SymNode wrap_int(int64_t num) override {
    return SymNode(c10::make_intrusive<ConstantSymNodeImpl<int64_t>>(num));
  }

  // The guard_* / int_ accessors below always raise via TORCH_CHECK(false, ...);
  // `file` and `line` are accepted to satisfy the base-class signature and are
  // not used.
  int64_t guard_int(const char* file, int64_t line) override {
    TORCH_CHECK(false, "guard_int is not supported by NestedIntSymNode");
  }

  double guard_float(const char* file, int64_t line) override {
    TORCH_CHECK(false, "not a float");
  }

  bool guard_bool(const char* file, int64_t line) override {
    TORCH_CHECK(false, "not a bool");
  }

  int64_t int_() override {
    TORCH_CHECK(false, "int_ is not supported by NestedIntSymNode");
  }

  // Returns "j<val>" when the coefficient is 1, otherwise "<coeff>*j<val>".
  std::string str() override {
    if (coeff_ == 1) {
      return "j" + std::to_string(val_);
    }
    return std::to_string(coeff_) + "*j" + std::to_string(val_);
  }

  // NOTE [ Inequalities with nested int ]
  //
  // The semantics of nested int when it comes to relations is that it is
  // treated as an integer known to be within a certain range,
  //
  //     j0 \in [2, int64_t::max]
  //
  // allowing us to answer queries like j0 >= 1 (True), and j0 == 0 (False).
  // This is a useful default range for the raggedness pattern of a jagged
  // tensor (1) since sizes are non-negative, and (2) we need to get past 0/1
  // specialization checks.
  //
  // [ Indeterminate inequalities error out ]
  //
  // Given the semantic defined above, certain relations like j0 < 3 are thus
  // indeterminable. In our impl today, evaluating such relations errors out.
  //
  // It may seem convenient to just define indeterminate relations to return
  // False, but the implementation we maintain in parallel using sympy does not
  // allow this.
  //
  // Sympy only allows overriding of Ge. The other relations (Lt, Gt, Le) are,
  // by consequence, all derived from Ge e.g., Lt(a, b) := !Ge(a, b). This
  // means that if we define the indeterminate j0 >= 3 to be False, the also
  // indeterminate j0 < 3 will be evaluated to be True!
  //
  // [ Coefficients are assumed positive ]
  //
  // For the purpose of computing inequalities, we consider the coefficient of
  // the nested int to be a positive integer.
  //
  // Thus, no modifications are needed to the logic since
  // j0 >= k implies coeff * j0 >= k
  //
  c10::SymNode eq(const c10::SymNode& other) override;
  c10::SymNode ne(const c10::SymNode& other) override;
  c10::SymNode ge(const c10::SymNode& other) override;
  c10::SymNode gt(const c10::SymNode& other) override;
  c10::SymNode lt(const c10::SymNode& other) override;
  c10::SymNode le(const c10::SymNode& other) override;
  c10::SymNode mul(const c10::SymNode& other) override;

  // Returns the identifier this node was constructed with (always engaged).
  std::optional<int64_t> nested_int() override {
    return val_;
  }

  // Returns the coefficient this node was constructed with (always engaged).
  std::optional<int64_t> nested_int_coeff() override {
    return coeff_;
  }

  // Unconditionally returns false.
  bool is_symbolic() override {
    return false;
  }

  c10::SymNode clone() override;

// Binary operators that a nested int does not implement: each generated
// override raises with the operator's name in the message.
#define DEFINE_BINARY_NOT_SUPPORTED(name)                           \
  c10::SymNode name(const c10::SymNode& other) override {           \
    TORCH_CHECK(false, #name " not supported by NestedIntSymNode"); \
  }

  DEFINE_BINARY_NOT_SUPPORTED(add)
  DEFINE_BINARY_NOT_SUPPORTED(sub)
  DEFINE_BINARY_NOT_SUPPORTED(truediv)
  DEFINE_BINARY_NOT_SUPPORTED(pow)
  DEFINE_BINARY_NOT_SUPPORTED(floordiv)
  DEFINE_BINARY_NOT_SUPPORTED(mod)
  DEFINE_BINARY_NOT_SUPPORTED(sym_min)
  DEFINE_BINARY_NOT_SUPPORTED(sym_max)
  DEFINE_BINARY_NOT_SUPPORTED(sym_and)
  DEFINE_BINARY_NOT_SUPPORTED(sym_or)

#undef DEFINE_BINARY_NOT_SUPPORTED

// Unary operators that a nested int does not implement: each generated
// override raises with the operator's name in the message.
#define DEFINE_NOT_SUPPORTED(name)                                     \
  c10::SymNode name() override {                                       \
    TORCH_CHECK(false, #name " is not supported by NestedIntSymNode"); \
  }

  DEFINE_NOT_SUPPORTED(sym_not)
  DEFINE_NOT_SUPPORTED(ceil)
  DEFINE_NOT_SUPPORTED(floor)
  DEFINE_NOT_SUPPORTED(neg)
  DEFINE_NOT_SUPPORTED(sym_float)

#undef DEFINE_NOT_SUPPORTED

 private:
  // Identifier printed after "j" by str() and returned by nested_int().
  int64_t val_;
  // Scalar multiplier returned by nested_int_coeff().
  int64_t coeff_;
};
} // namespace c10