-
Notifications
You must be signed in to change notification settings - Fork 0
/
ppl_graph_codegen.py
executable file
·321 lines (274 loc) · 13.5 KB
/
ppl_graph_codegen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
#
# This file is part of PyFOPPL, an implementation of a First Order Probabilistic Programming Language in Python.
#
# License: MIT (see LICENSE.txt)
#
# 12. Mar 2018, Tobias Kohn
# 07. May 2018, Tobias Kohn
#
import datetime
import importlib
from ..graphs import *
from ..ppl_ast import *
class GraphCodeGenerator(object):
    """
    In contrast to the more general code generator `CodeGenerator`, this class creates the code for a graph-based
    model. The output of the method `generate_model_code()` is therefore the code of a class `Model` with functions
    such as `gen_log_prob()` or `gen_prior_samples()`, including all necessary imports.

    You want to change this class if you need additional (or adapted) methods in your model-class.

    Usage:
        ```
        graph = ...    # <- actually generated by the graph-factory/generator
        graph_code_gen = GraphCodeGenerator(graph.nodes, state_object='state', imports='import distributions as dist')
        code = graph_code_gen.generate_model_code()
        my_globals = {}
        exec(code, my_globals)
        Model = my_globals['Model']
        model = Model(graph.vertices, graph.arcs, graph.data, graph.conditionals)
        ```
        The state-object specifies the actual name of the dictionary/map that holds the state, i. e. all the
        variables. When any state-object is given, the generated code reads, say, `state['x']` instead of purely `x`.

    Hacking:
        The `generate_model_code`-method uses three fixed methods to generate the code for `__init__`, `__repr__` as
        well as the doc-string: `_generate_doc_string`, `_generate_init_method`, `_generate_repr_method`. After that,
        it scans the object instance of `GraphCodeGenerator` for public methods, and assumes that each method returns
        the code for the respective method.

        Say, for instance, you wanted your Model-class to have a method `get_all_nodes` with the following code:
        ```
        def get_all_nodes(self):
            return set.union(self.vertices, self.conditionals)
        ```
        You then add the following method to the `GraphCodeGenerator` and be done.
        ```
        def get_all_nodes(self):
            return "return set.union(self.vertices, self.conditionals)"
        ```
        If, on the other hand, you need some additional parameters/arguments for your method, then you should return
        a tuple with the first element being the parameters as a string, and the second element being the code as
        before.
        ```
        def get_all_nodes(self):
            return "param1, param2", "return set.union(self.vertices, self.conditionals)"
        ```
        Of course, you do not need to actually change this class, but you can derive a new class from it, if you wish.
    """

    def __init__(self, nodes: list, state_object: Optional[str]=None, imports: Optional[str]=None):
        """
        :param nodes:        The graph nodes for which code is emitted, in the order they are emitted.
        :param state_object: Name of the dict holding the variables in the generated code (`None` for plain names).
        :param imports:      Import statements (newline separated) to put at the top of the generated code.
        """
        self.nodes = nodes
        self.state_object = state_object
        self.imports = imports
        # Name of the variable collecting the condition bits; `None` disables the bit-vector code.
        self.bit_vector_name = None
        # Text appended to each generated `log_prob` statement; set by `_complete_imports`.
        self.logpdf_suffix = None

    def _complete_imports(self, imports: str):
        """
        Scans the given import statements and returns an additional import line that binds `dist`, if none of the
        given statements already does so. Returns the empty string if nothing has to be added.

        Side effect: sets `self.logpdf_suffix` to `''` if `torch` or `numpy` is imported.
        """
        if imports != '':
            has_dist = False
            uses_numpy = False
            uses_torch = False
            uses_pyfo = False
            for s in imports.split('\n'):
                s = s.strip()
                if s.endswith(' dist') or s.endswith('.dist'):
                    has_dist = True
                # Strip the leading keyword (and the following blank) to get at the module name
                if s.startswith('from'):
                    s = s[5:]
                elif s.startswith('import'):
                    s = s[7:]
                # The top-level module name is the leading run of letters
                i = 0
                while i < len(s) and 'A' <= s[i].upper() <= 'Z':
                    i += 1
                m = s[:i]
                uses_numpy = uses_numpy or m == 'numpy'
                uses_torch = uses_torch or m == 'torch'
                uses_pyfo = uses_pyfo or m == 'pyfo'

            if uses_torch or uses_numpy:
                self.logpdf_suffix = ''

            if not has_dist:
                if uses_torch and uses_pyfo:
                    return 'import pyfo.distributions as dist\n'
                else:
                    return 'import torch.distributions as dist\n'
        return ''

    def generate_model_code(self, *,
                            class_name: str='Model',
                            base_class: str='',
                            imports: str='') -> str:
        """
        Returns the complete source code of the model class as a single string.

        :param class_name: Name of the generated class.
        :param base_class: Base class of the generated class. A dotted name such as `package.module.Base` is
                           turned into `from package.module import Base` if the module can be imported.
        :param imports:    Additional import statements; they are placed after the imports given to `__init__`.
        :return:           The generated source code (tabs are used for indentation).
        """
        if self.imports is not None:
            imports = self.imports + "\n" + imports
        if base_class is None:
            base_class = ''

        if '.' in base_class:
            idx = base_class.rindex('.')
            base_module = base_class[:idx]
            try:
                importlib.import_module(base_module)
                base_class = base_class[idx+1:]
                imports = "from {} import {}\n".format(base_module, base_class) + imports
            except Exception:
                # Best effort: if the module cannot be imported, the dotted name is kept verbatim.
                # `Exception` (rather than a bare `except`) lets KeyboardInterrupt/SystemExit propagate.
                pass

        # try:
        #     graph_module = 'pyppl.aux.graph_plots'
        #     m = importlib.import_module(graph_module)
        #     names = [n for n in dir(m) if not n.startswith('_')]
        #     if len(names) > 1:
        #         names = [n for n in names if n[0].isupper()]
        #     if len(names) == 1:
        #         if base_class != '':
        #             base_class += ', '
        #         base_class += '_' + names[0]
        #         imports = "from {} import {} as _{}\n".format(graph_module, names[0], names[0]) + imports
        # except ModuleNotFoundError:
        #     pass

        imports = self._complete_imports(imports) + imports

        result = ["# {}".format(datetime.datetime.now()),
                  imports,
                  "class {}({}):".format(class_name, base_class)]

        doc_str = self._generate_doc_string()
        if doc_str is not None and doc_str != '':
            result.append('\t"""\n\t{}\n\t"""'.format(doc_str.replace('\n', '\n\t')))
        result.append('')

        init_method = self._generate_init_method()
        if init_method is not None:
            result.append('\t' + init_method.replace('\n', '\n\t'))

        repr_method = self._generate_repr_method()
        if repr_method is not None:
            result.append('\t' + repr_method.replace('\n', '\n\t'))

        # Every public callable of the generator (apart from this method itself) provides the body of the
        # equally named method in the generated class.
        methods = [x for x in dir(self) if not x.startswith('_') and x != 'generate_model_code']
        for method_name in methods:
            method = getattr(self, method_name)
            if callable(method):
                code = method()
                if isinstance(code, tuple) and len(code) == 2:
                    args, code = code
                    args = 'self, ' + args
                else:
                    args = 'self'
                code = code.replace('\n', '\n\t\t')
                result.append("\tdef {}({}):\n\t\t{}\n".format(method_name, args, code))

        return '\n'.join(result)

    def _generate_doc_string(self):
        """Returns the doc-string of the generated class (empty means: no doc-string)."""
        return ''

    def _generate_init_method(self):
        """Returns the full code (including the `def`-line) of the `__init__` method of the generated class."""
        return "def __init__(self, vertices: set, arcs: set, data: set, conditionals: set):\n" \
               "\tsuper().__init__()\n" \
               "\tself.vertices = vertices\n" \
               "\tself.arcs = arcs\n" \
               "\tself.data = data\n" \
               "\tself.conditionals = conditionals\n"

    def _generate_repr_method(self):
        """Returns the full code (including the `def`-line) of the `__repr__` method of the generated class."""
        s = "def __repr__(self):\n" \
            "\tV = '\\n'.join(sorted([repr(v) for v in self.vertices]))\n" \
            "\tA = ', '.join(['({}, {})'.format(u.name, v.name) for (u, v) in self.arcs]) if len(self.arcs) > 0 else ' -'\n" \
            "\tC = '\\n'.join(sorted([repr(v) for v in self.conditionals])) if len(self.conditionals) > 0 else ' -'\n" \
            "\tD = '\\n'.join([repr(u) for u in self.data]) if len(self.data) > 0 else ' -'\n" \
            "\tgraph = 'Vertices V:\\n{V}\\nArcs A:\\n {A}\\n\\nConditions C:\\n{C}\\n\\nData D:\\n{D}\\n'.format(V=V, A=A, C=C, D=D)\n" \
            "\tgraph = '#Vertices: {}, #Arcs: {}\\n'.format(len(self.vertices), len(self.arcs)) + graph\n" \
            "\treturn graph\n"
        return s

    # --- The public methods below each return the *body* of the equally named method of the generated class ---

    def get_vertices(self):
        return "return self.vertices"

    def get_vertices_names(self):
        return "return [v.name for v in self.vertices]"

    def get_arcs(self):
        return "return self.arcs"

    def get_arcs_names(self):
        return "return [(u.name, v.name) for (u, v) in self.arcs]"

    def get_conditions(self):
        return "return self.conditionals"

    def gen_cond_vars(self):
        return "return [c.name for c in self.conditionals]"

    def gen_if_vars(self):
        return "return [v.name for v in self.vertices if v.is_conditional and v.is_sampled and v.is_continuous]"

    def gen_cont_vars(self):
        return "return [v.name for v in self.vertices if v.is_continuous and not v.is_conditional and v.is_sampled]"

    def gen_disc_vars(self):
        return "return [v.name for v in self.vertices if v.is_discrete and v.is_sampled]"

    def get_vars(self):
        return "return [v.name for v in self.vertices if v.is_sampled]"

    def is_torch_imported(self):
        # Diagnostic helper: the generated method only prints; it refers to the name `torch`, which is only
        # available if the generated module's imports bind it.
        return "import sys \nprint('torch' in sys.modules) \nprint(torch.__version__) \nprint(type(torch.tensor)) \nimport inspect \nprint(inspect.getfile(torch))"

    def _gen_code(self, buffer: list, code_for_vertex, *, want_data_node: bool=True, flags=None):
        """
        Appends the code for all nodes in `self.nodes` to `buffer`. An entry of the buffer may span several lines.

        :param buffer:          The list of code snippets to be extended in place.
        :param code_for_vertex: Callback `(name, node)` returning the code for a vertex as a string or as a list
                                of strings; any other return value is ignored.
        :param want_data_node:  If `False`, no code is emitted for data nodes.
        :param flags:           Optional keyword arguments handed on to the `get_code`-method of vertices.
        """
        distribution = None
        state = self.state_object
        vertex_flags = flags if flags is not None else {}
        if self.bit_vector_name is not None:
            # Initialise the bit-vector collecting the results of the conditions
            if state is not None:
                buffer.append("{}['{}'] = 0".format(state, self.bit_vector_name))
            else:
                buffer.append("{} = 0".format(self.bit_vector_name))

        for node in self.nodes:
            name = node.name
            if state is not None:
                name = "{}['{}']".format(state, name)

            if isinstance(node, Vertex):
                code = "dst_ = {}".format(node.get_code(**vertex_flags))
                # Do not emit the same distribution assignment twice in a row
                if code != distribution:
                    buffer.append(code)
                    distribution = code
                code = code_for_vertex(name, node)
                if isinstance(code, str):
                    buffer.append(code)
                elif isinstance(code, list):
                    buffer += code

            elif isinstance(node, ConditionNode) and self.bit_vector_name is not None:
                bit_vector = "{}['{}']".format(state, self.bit_vector_name) if state is not None else self.bit_vector_name
                # NOTE: this is a two-line entry; consumers of the buffer must indent each line
                code = "_c = {}\n{} = _c".format(node.get_code(), name)
                buffer.append(code)
                buffer.append("{} |= {} if _c else 0".format(bit_vector, node.bit_index))

            elif want_data_node or not isinstance(node, DataNode):
                code = "{} = {}".format(name, node.get_code())
                buffer.append(code)

    def gen_log_prob(self):
        """
        Returns the parameter list (`'state'`) and the body of the generated `gen_log_prob` method. The body
        accumulates the log-probabilities of all vertices inside a `try`-block and prints a warning if a
        `ValueError` or `RuntimeError` occurs.
        """
        def code_for_vertex(name: str, node: 'Vertex'):
            cond_code = node.get_cond_code(state_object=self.state_object)
            if cond_code is not None:
                # The extra tab is kept from the original; presumably `cond_code` is an `if ...:` header
                # (verify against `Vertex.get_cond_code`). The statement ends up inside the `if` either way.
                result = cond_code + "\tlog_prob = log_prob + dst_.log_prob({})".format(name)
            else:
                result = "log_prob = log_prob + dst_.log_prob({})".format(name)
            if self.logpdf_suffix is not None:
                result = result + self.logpdf_suffix
            return result

        logpdf_code = ["log_prob = 0"]
        self._gen_code(logpdf_code, code_for_vertex=code_for_vertex, want_data_node=False)
        logpdf_code.append("return log_prob")
        # return 'state', '\n'.join(logpdf_code)

        # Indent *every* line of the body (entries may span several lines), so that all of them end up
        # inside the `try`-block.
        body = '\n'.join(logpdf_code).replace('\n', '\n\t')
        code = ["try:\n\t" + body,
                "\nexcept(ValueError, RuntimeError) as e:\n\tprint('****Warning: Target density is ill-defined****')"]
        return 'state', ''.join(code)

    # def gen_log_prob_transformed(self):
    #     def code_for_vertex(name: str, node: Vertex):
    #         cond_code = node.get_cond_code(state_object=self.state_object)
    #         if cond_code is not None:
    #             result = cond_code + "log_prob = log_prob + dst_.log_prob({})".format(name)
    #         else:
    #             result = "log_prob = log_prob + dst_.log_prob({})".format(name)
    #         if self.logpdf_suffix is not None:
    #             result += self.logpdf_suffix
    #         return result
    #     # Note to self : To change suffix for torch or numpy look at line 87-88 in compiled imports (above)
    #     logpdf_code = ["log_prob = 0"]
    #     self._gen_code(logpdf_code, code_for_vertex=code_for_vertex, want_data_node=False, flags={'transformed': True})
    #     logpdf_code.append("return log_prob.sum()")
    #     return 'state', '\n'.join(logpdf_code)

    def gen_prior_samples(self):
        """
        Returns the body of the generated `gen_prior_samples` method: observed vertices are assigned their
        observation, all other vertices are sampled from their distribution. If a state-object is used, the
        generated method creates and returns it.
        """
        def code_for_vertex(name: str, node: 'Vertex'):
            if node.has_observation:
                return "{} = {}".format(name, node.observation)
            sample_size = node.sample_size
            if sample_size is not None and sample_size > 1:
                return "{} = dst_.sample(sample_size={})".format(name, sample_size)
            else:
                return "{} = dst_.sample()".format(name)

        state = self.state_object
        sample_code = []
        if state is not None:
            sample_code.append(state + " = {}")
        self._gen_code(sample_code, code_for_vertex=code_for_vertex, want_data_node=True)
        if state is not None:
            sample_code.append("return " + state)
        return '\n'.join(sample_code)

    def gen_cond_bit_vector(self):
        """Returns the parameter list (`'state'`) and the body of the generated `gen_cond_bit_vector` method."""
        code = "result = 0\n" \
               "for cond in self.conditionals:\n" \
               "\tresult = cond.update_bit_vector(state, result)\n" \
               "return result"
        return 'state', code