-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsymengine-0.10.1.diff
204 lines (191 loc) · 7.42 KB
/
symengine-0.10.1.diff
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
Only in workdir/symengine-0.10.1: build
Only in workdir/symengine-0.10.1: cmake-build-debug
Only in workdir/symengine-0.10.1: .idea
diff -ur workdir/symengine-0.10.1-patched/symengine/basic.h workdir/symengine-0.10.1/symengine/basic.h
--- workdir/symengine-0.10.1-patched/symengine/basic.h 2023-03-23 17:41:51.000000000 +0100
+++ workdir/symengine-0.10.1/symengine/basic.h 2024-04-28 15:01:30.750175803 +0200
@@ -282,7 +282,7 @@
// Common subexpression elimination of symbolic expressions
// Return a vector of replacement pairs and a vector of reduced exprs
void cse(vec_pair &replacements, vec_basic &reduced_exprs,
- const vec_basic &exprs);
+ const vec_basic &exprs, const std::function<RCP<const Basic>()> &next_symbol_gen= nullptr);
/*! This `<<` overloaded function simply calls `p.__str__`, so it allows any
Basic
diff -ur workdir/symengine-0.10.1-patched/symengine/cse.cpp workdir/symengine-0.10.1/symengine/cse.cpp
--- workdir/symengine-0.10.1-patched/symengine/cse.cpp 2023-03-23 17:41:51.000000000 +0100
+++ workdir/symengine-0.10.1/symengine/cse.cpp 2024-04-28 15:01:30.750175803 +0200
@@ -5,12 +5,14 @@
#include <symengine/visitor.h>
#include <queue>
+#include <functional>
namespace SymEngine
{
umap_basic_basic opt_cse(const vec_basic &exprs);
void tree_cse(vec_pair &replacements, vec_basic &reduced_exprs,
- const vec_basic &exprs, umap_basic_basic &opt_subs);
+ const vec_basic &exprs, umap_basic_basic &opt_subs,
+ const std::function<RCP<const Basic>()> &next_symbol_gen);
class FuncArgTracker
{
@@ -426,15 +428,18 @@
set_basic &excluded_symbols;
vec_pair &replacements;
unsigned next_symbol_index = 0;
+ const std::function<RCP<const Basic>()> next_symbol_gen;
public:
using TransformVisitor::bvisit;
using TransformVisitor::result_;
RebuildVisitor(umap_basic_basic &subs_, umap_basic_basic &opt_subs_,
set_basic &to_eliminate_, set_basic &excluded_symbols_,
- vec_pair &replacements_)
+ vec_pair &replacements_,
+ const std::function<RCP<const Basic>()> &next_symbol_gen)
: subs(subs_), opt_subs(opt_subs_), to_eliminate(to_eliminate_),
- excluded_symbols(excluded_symbols_), replacements(replacements_)
+ excluded_symbols(excluded_symbols_), replacements(replacements_),
+ next_symbol_gen(next_symbol_gen)
{
}
RCP<const Basic> apply(const RCP<const Basic> &orig_expr) override
@@ -465,12 +470,17 @@
}
RCP<const Basic> next_symbol()
{
- RCP<const Basic> sym = symbol("x" + to_string(next_symbol_index));
- next_symbol_index++;
- if (excluded_symbols.find(sym) == excluded_symbols.end()) {
- return sym;
+ if (next_symbol_gen== nullptr) {
+ RCP<const Symbol> sym = symbol("x" + to_string(next_symbol_index));
+ sym->set_link_flag(true);
+ next_symbol_index++;
+ if (excluded_symbols.find(sym) == excluded_symbols.end()) {
+ return sym;
+ } else {
+ return next_symbol();
+ }
} else {
- return next_symbol();
+ return next_symbol_gen();
}
};
void bvisit(const FunctionSymbol &x)
@@ -493,7 +503,8 @@
};
void tree_cse(vec_pair &replacements, vec_basic &reduced_exprs,
- const vec_basic &exprs, umap_basic_basic &opt_subs)
+ const vec_basic &exprs, umap_basic_basic &opt_subs,
+ const std::function<RCP<const Basic>()> &next_symbol_gen)
{
set_basic to_eliminate;
set_basic seen_subexp;
@@ -536,7 +547,8 @@
umap_basic_basic subs;
RebuildVisitor rebuild_visitor(subs, opt_subs, to_eliminate,
- excluded_symbols, replacements);
+ excluded_symbols, replacements,
+ next_symbol_gen);
for (auto &e : exprs) {
auto reduced_e = rebuild_visitor.apply(e);
@@ -545,12 +557,14 @@
}
void cse(vec_pair &replacements, vec_basic &reduced_exprs,
- const vec_basic &exprs)
+ const vec_basic &exprs,
+ const std::function<RCP<const Basic>()> &next_symbol_gen)
{
// Find other optimization opportunities.
umap_basic_basic opt_subs = opt_cse(exprs);
// Main CSE algorithm.
- tree_cse(replacements, reduced_exprs, exprs, opt_subs);
+ tree_cse(replacements, reduced_exprs, exprs, opt_subs, next_symbol_gen);
}
+
} // namespace SymEngine
diff -ur workdir/symengine-0.10.1-patched/symengine/serialize-cereal.h workdir/symengine-0.10.1/symengine/serialize-cereal.h
--- workdir/symengine-0.10.1-patched/symengine/serialize-cereal.h 2023-03-23 17:41:51.000000000 +0100
+++ workdir/symengine-0.10.1/symengine/serialize-cereal.h 2024-04-28 15:09:45.473526228 +0200
@@ -41,7 +41,7 @@
template <class Archive>
inline void save_basic(Archive &ar, const Symbol &b)
{
- ar(b.__str__());
+ ar(b.__str__(), b.is_link(), b.get_local_symbol_index(), b.get_target_tensor_id());
}
template <class Archive>
inline void save_basic(Archive &ar, const Mul &b)
@@ -333,8 +333,14 @@
RCP<const Basic> load_basic(Archive &ar, RCP<const Symbol> &)
{
std::string name;
- ar(name);
- return symbol(name);
+ bool isLink;
+ size_t localSymbolIndex, targetTensorId;
+ ar(name, isLink, localSymbolIndex, targetTensorId);
+ auto s = symbol(name);
+ s->set_link_flag(isLink);
+ s->set_local_symbol_index(localSymbolIndex);
+ s->set_target_tensor_id(targetTensorId);
+ return s;
}
template <class Archive>
RCP<const Basic> load_basic(Archive &ar, RCP<const Mul> &)
diff -ur workdir/symengine-0.10.1-patched/symengine/symbol.cpp workdir/symengine-0.10.1/symengine/symbol.cpp
--- workdir/symengine-0.10.1-patched/symengine/symbol.cpp 2023-03-23 17:41:51.000000000 +0100
+++ workdir/symengine-0.10.1/symengine/symbol.cpp 2024-04-28 15:06:38.310186413 +0200
@@ -3,6 +3,30 @@
namespace SymEngine
{
+//! DeepFusion Patch:
+bool Symbol::is_link() const{
+ return isLink;
+}
+
+void Symbol::set_link_flag(bool flag) const{
+ isLink = flag;
+}
+
+size_t Symbol::get_local_symbol_index() const{
+ return localSymbolIndex;
+}
+
+void Symbol::set_local_symbol_index(size_t index) const {
+ localSymbolIndex = index;
+}
+
+size_t Symbol::get_target_tensor_id() const{
+ return targetTensorId;
+}
+
+void Symbol::set_target_tensor_id(size_t id) const{
+ targetTensorId = id;
+}
Symbol::Symbol(const std::string &name)
: name_{name} {SYMENGINE_ASSIGN_TYPEID()}
diff -ur workdir/symengine-0.10.1-patched/symengine/symbol.h workdir/symengine-0.10.1/symengine/symbol.h
--- workdir/symengine-0.10.1-patched/symengine/symbol.h 2023-03-23 17:41:51.000000000 +0100
+++ workdir/symengine-0.10.1/symengine/symbol.h 2024-04-28 15:06:38.300186412 +0200
@@ -17,8 +17,22 @@
//! name of Symbol
std::string name_;
+ //! DeepFusion Patch:
+ mutable bool isLink=false;
+ mutable size_t localSymbolIndex=0;
+ mutable size_t targetTensorId=0;
+
public:
IMPLEMENT_TYPEID(SYMENGINE_SYMBOL)
+
+ //! DeepFusion Patch:
+ bool is_link() const;
+ void set_link_flag(bool flag) const;
+ size_t get_local_symbol_index() const;
+ void set_local_symbol_index(size_t index) const;
+ size_t get_target_tensor_id() const;
+ void set_target_tensor_id(size_t id) const;
+
//! Symbol Constructor
explicit Symbol(const std::string &name);
//! \return Size of the hash