-
Notifications
You must be signed in to change notification settings - Fork 0
/
bf2onnx.py
257 lines (240 loc) · 10.2 KB
/
bf2onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
# Brainfuck to ONNX Compiler
import onnx
import onnx.helper as oh
import typing as T
from collections import defaultdict
class IRBuilder:
    """Accumulates ONNX nodes for a stack of (possibly nested) graphs.

    Attributes:
        cur_nodes: nodes emitted so far for the graph currently being built.
        nodes_stack: node lists of the enclosing graphs that were suspended
            by `enter_graph`, innermost last.
        used_names: every name handed out by `generate_name`.
        name_count: per-base-name counter used to build numeric suffixes.
    """

    def __init__(self):
        self.cur_nodes = []
        self.nodes_stack = []
        self.used_names = set()
        self.name_count = defaultdict(int)

    def emit(self, op: str, *, inputs, outputs, **kwargs):
        """Create a node with `oh.make_node`, append it to the current graph and return it."""
        self.cur_nodes.append(oh.make_node(op, inputs=inputs, outputs=outputs, **kwargs))
        return self.cur_nodes[-1]

    def pop_graph(self, name: str, *, inputs, outputs, **kwargs):
        """Finish the current graph and resume the enclosing one, if any.

        The nodes collected so far are wrapped with `oh.make_graph`; the
        builder then continues with the most recently suspended node list,
        or with a fresh empty list when nothing was suspended.
        """
        finished = self.cur_nodes
        self.cur_nodes = self.nodes_stack.pop() if self.nodes_stack else []
        return oh.make_graph(finished, name, inputs=inputs, outputs=outputs, **kwargs)

    def enter_graph(self):
        """Suspend the current node list and start collecting nodes for a nested graph."""
        suspended, self.cur_nodes = self.cur_nodes, []
        self.nodes_stack.append(suspended)

    def generate_name(self, name):
        """Return `name`, or `name_<k>` for the first unused k, and mark it as used."""
        counters = self.name_count
        candidate = name
        while candidate in self.used_names:
            counters[name] += 1
            candidate = f'{name}_{counters[name]}'
        self.used_names.add(candidate)
        return candidate
def parse(bf: str) -> list:
    """Parse Brainfuck source into a nested instruction tree.

    Each of the commands ``+ - < > . ,`` becomes a one-character string.
    Each ``[...]`` loop becomes a nested list holding the parsed loop body.
    Every other character is treated as a comment and skipped.

    The source is scanned once with an explicit stack of open loops, so
    parsing is linear in the input length and nesting depth is not limited
    by Python's recursion limit.

    Args:
        bf: Brainfuck program text.

    Returns:
        The list of top-level instructions.

    Raises:
        ValueError: ``'input is not exhausted after parsing'`` when a ``]``
            appears with no open loop, ``'parse error'`` when a ``[`` is
            never closed.
    """
    root: list = []
    # open_loops[-1] is the instruction list currently being filled;
    # open_loops[0] is always the top-level program.
    open_loops = [root]
    for ch in bf:
        if ch in '+-<>.,':
            open_loops[-1].append(ch)
        elif ch == '[':
            body: list = []
            open_loops[-1].append(body)
            open_loops.append(body)
        elif ch == ']':
            if len(open_loops) == 1:
                # A closing bracket at top level has nothing to close.
                raise ValueError('input is not exhausted after parsing')
            open_loops.pop()
        # anything else is a comment character
    if len(open_loops) > 1:
        raise ValueError('parse error')
    return root
# Maybe the names should be managed by the builder
def _emit_cell_read(builder, mem_name: str, pointer_name: str, succ_name: str, value_name: str) -> None:
    """Emit nodes that slice the cell under the pointer out of the memory tensor.

    `succ_name` receives pointer + 1 and `value_name` receives
    ``mem[pointer:pointer + 1]``.
    """
    builder.emit('Add', inputs=[pointer_name, 'one_i32'], outputs=[succ_name])
    builder.emit('Slice', inputs=[mem_name, pointer_name, succ_name], outputs=[value_name])


def compile_single_bf_inst(builder, inst, mem_name: str, pointer_name: str, output_name: str) -> tuple[str, str, str]:
    """Emit the ONNX nodes for one Brainfuck instruction.

    The emitted nodes reference the tensors ``'one_i32'``, ``'zero_i32'`` and
    ``'mem_end'`` by name; these are not created here and must be defined by
    the graph the nodes end up in (or an enclosing graph).

    Args:
        builder: node builder providing `emit`, `generate_name`,
            `enter_graph` and `pop_graph`.
        inst: a single-character command string, or a list of instructions
            representing the body of a ``[...]`` loop.
        mem_name: name of the tensor holding the memory before the instruction.
        pointer_name: name of the tensor holding the data pointer before it.
        output_name: name of the tensor holding the output produced so far.

    Returns:
        ``(mem, pointer, output)`` tensor names valid after the instruction.

    Raises:
        NotImplementedError: for any instruction other than a loop or one of
            ``> < + - .`` (in particular ``,``).
    """
    if isinstance(inst, list):
        # Initial loop condition: is the current cell non-zero?
        iszero = builder.generate_name('iszero')
        isnonzero = builder.generate_name('isnonzero')
        value = builder.generate_name('value')
        succ_ptr = builder.generate_name('succ_ptr')
        _emit_cell_read(builder, mem_name, pointer_name, succ_ptr, value)
        builder.emit('Equal', inputs=[value, 'zero_i32'], outputs=[iszero])
        builder.emit('Not', inputs=[iszero], outputs=[isnonzero])

        # Everything emitted from here until pop_graph belongs to the loop body.
        builder.enter_graph()
        # Generated (not hard-coded) so that a loop nested inside another
        # loop body does not reuse the names of the enclosing body's inputs.
        iter_count = builder.generate_name('iter_count')
        cond_in = builder.generate_name('cond_in')
        mem_inner = mem = builder.generate_name('mem')
        ptr_inner = ptr = builder.generate_name('pointer')
        out_inner = out = builder.generate_name('output')
        for sub_inst in inst:
            mem, ptr, out = compile_single_bf_inst(builder, sub_inst, mem, ptr, out)

        # Continue condition, evaluated on the state *after* the body ran.
        # The slice end must be the successor of the body's final pointer,
        # not the pointer computed before the loop was entered.
        iszero_inner = builder.generate_name('iszero_inner')
        isnonzero_inner = builder.generate_name('isnonzero_inner')
        value_inner = builder.generate_name('value_inner')
        succ_ptr_inner = builder.generate_name('succ_ptr_inner')
        _emit_cell_read(builder, mem, ptr, succ_ptr_inner, value_inner)
        builder.emit('Equal', inputs=[value_inner, 'zero_i32'], outputs=[iszero_inner])
        builder.emit('Not', inputs=[iszero_inner], outputs=[isnonzero_inner])

        loop_body = builder.pop_graph(
            builder.generate_name('loop_body'),
            inputs=[
                oh.make_tensor_value_info(iter_count, oh.TensorProto.INT64, []),
                oh.make_tensor_value_info(cond_in, oh.TensorProto.BOOL, [1]),
                oh.make_tensor_value_info(mem_inner, oh.TensorProto.INT32, [None]),
                oh.make_tensor_value_info(ptr_inner, oh.TensorProto.INT32, [1]),
                oh.make_tensor_value_info(out_inner, oh.TensorProto.INT32, [None]),
            ],
            outputs=[
                oh.make_tensor_value_info(isnonzero_inner, oh.TensorProto.BOOL, [None]),  ## ?
                oh.make_tensor_value_info(mem, oh.TensorProto.INT32, [None]),
                oh.make_tensor_value_info(ptr, oh.TensorProto.INT32, [1]),
                oh.make_tensor_value_info(out, oh.TensorProto.INT32, [None]),
            ]
        )
        mem_after_loop = builder.generate_name('mem_after_loop')
        ptr_after_loop = builder.generate_name('ptr_after_loop')
        out_after_loop = builder.generate_name('out_after_loop')
        builder.emit(
            'Loop',
            # '' leaves the trip-count input unset; the loop is driven by the
            # condition alone. Memory, pointer and output are loop-carried.
            inputs=['', isnonzero, mem_name, pointer_name, output_name],
            outputs=[mem_after_loop, ptr_after_loop, out_after_loop],
            body=loop_body
        )
        return mem_after_loop, ptr_after_loop, out_after_loop

    if inst == '>' or inst == '<':
        ptr = builder.generate_name('pointer')
        step_op = 'Add' if inst == '>' else 'Sub'
        builder.emit(step_op, inputs=[pointer_name, 'one_i32'], outputs=[ptr])
        return mem_name, ptr, output_name

    if inst == '+' or inst == '-':
        value = builder.generate_name('value')
        succ_ptr = builder.generate_name('succ_ptr')
        _emit_cell_read(builder, mem_name, pointer_name, succ_ptr, value)
        mod_value = builder.generate_name('succ_value')
        cell_op = 'Add' if inst == '+' else 'Sub'
        builder.emit(cell_op, inputs=[value, 'one_i32'], outputs=[mod_value])
        # Rebuild memory as mem[:ptr] ++ [new value] ++ mem[ptr + 1:mem_end].
        left = builder.generate_name('left')
        right = builder.generate_name('right')
        builder.emit('Slice', inputs=[mem_name, 'zero_i32', pointer_name], outputs=[left])
        builder.emit('Slice', inputs=[mem_name, succ_ptr, 'mem_end'], outputs=[right])
        mem = builder.generate_name('mem')
        builder.emit('Concat', inputs=[left, mod_value, right], outputs=[mem], axis=0)
        return mem, pointer_name, output_name

    if inst == '.':
        # Append the current cell to the output tensor.
        char = builder.generate_name('char')
        succ_ptr = builder.generate_name('succ_ptr')
        _emit_cell_read(builder, mem_name, pointer_name, succ_ptr, char)
        out = builder.generate_name('output')
        builder.emit('Concat', inputs=[output_name, char], outputs=[out], axis=0)
        return mem_name, pointer_name, out

    raise NotImplementedError(f'inst {inst} not implemented')
def compile_bf_to_onnx(tree):
    """Compile a parsed Brainfuck instruction tree into an ONNX model.

    The graph has no inputs. It starts from constant tensors for the memory
    (30000 INT32 zeros), the data pointer, an empty output and the helper
    scalars ``zero_i32`` / ``one_i32`` / ``mem_end``, then chains the nodes
    emitted by `compile_single_bf_inst` for every instruction. The single
    graph output is the final output tensor.

    Args:
        tree: list of top-level instructions (command strings and nested
            lists for loops).

    Returns:
        The model built by `oh.make_model` with default-domain opset 19.
    """
    MEMORY_SIZE = 30000
    builder = IRBuilder()

    mem = builder.generate_name('mem')
    mem_end = 'mem_end'  # builder.generate_name('mem_end')
    # The graph output claims the unsuffixed name first; the working output
    # tensor therefore gets a suffixed one.
    final_output = builder.generate_name('output')
    output = builder.generate_name('output')
    pointer = builder.generate_name('pointer')

    # (node output name, tensor name, dims, values); all tensors are INT32.
    # NOTE(review): the pointer starts at 1, not 0 — confirm this is intended.
    constants = (
        (mem, "mem_init", (MEMORY_SIZE,), [0]*MEMORY_SIZE),
        (mem_end, "mem_end_const", (1,), [MEMORY_SIZE]),
        (output, "output_init", (0,), []),
        (pointer, "pointer_init", (1,), [1]),
        ('zero_i32', "zero_i32_const", (1,), [0]),
        ('one_i32', "one_i32_const", (1,), [1]),
    )
    for node_output, tensor_name, dims, vals in constants:
        tensor = oh.make_tensor(
            name=tensor_name,
            data_type=oh.TensorProto.INT32,
            dims=dims,
            vals=vals,
        )
        builder.emit('Constant', inputs=[], outputs=[node_output], value=tensor)

    # Thread the (memory, pointer, output) tensor names through every instruction.
    state = (mem, pointer, output)
    for inst in tree:
        state = compile_single_bf_inst(builder, inst, *state)
    builder.emit('Identity', inputs=[state[2]], outputs=[final_output])

    graph_outputs = [oh.make_tensor_value_info(final_output, oh.TensorProto.INT32, [None])]
    graph = builder.pop_graph('prog', inputs=[], outputs=graph_outputs)
    return oh.make_model(graph=graph, opset_imports=[oh.make_opsetid("", 19)])
def run(model):
    """Execute a compiled model with onnxruntime and return its output as bytes.

    The model is serialized, loaded into an `onnxruntime.InferenceSession`
    and run with no inputs, fetching the tensor named ``'output'``. The
    fetched values are converted with `bytes(...)`, which raises
    `ValueError` for any value outside ``0..255``.
    """
    import onnxruntime

    serialized = model.SerializeToString()
    session = onnxruntime.InferenceSession(serialized)
    fetched = session.run(['output'], {})
    cells = fetched[0].tolist()
    return bytes(cells)
def main():
    """Command-line entry point: compile a Brainfuck program to ONNX.

    Exactly one of ``--bf`` (inline program) or ``--file`` (program file)
    must be supplied; otherwise the help text and an error are written to
    stderr and the process exits with status 1. The compiled model is
    checked with `onnx.checker.check_model` and then either run (``--run``,
    output bytes written to stdout), written to ``--out``, or printed in
    textual form.
    """
    import sys
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--out', '-o', help='ONNX output path. If not given the textual form of the model is printed to stdout')
    parser.add_argument('--run', action='store_true', help='Run the compiled model rather than writing it')
    parser.add_argument('--bf', help='Brainfuck program')
    parser.add_argument('--file', '-f', help='Brainfuck program file')
    args = parser.parse_args()

    # Both missing/empty or both present are equally invalid.
    if (not args.bf) == (not args.file):
        print(parser.format_help(), file=sys.stderr)
        print('Exactly one of --bf and --file must be given', file=sys.stderr)
        sys.exit(1)

    if args.bf:
        bf = args.bf
    else:
        with open(args.file, 'rb') as fp:
            # The parser skips every non-command character, so undecodable
            # bytes in comments must not abort compilation; they are turned
            # into replacement characters that the parser then ignores.
            bf = fp.read().decode('utf-8', errors='replace')

    model = compile_bf_to_onnx(parse(bf))
    onnx.checker.check_model(model)

    if args.run:
        result = run(model)
        sys.stdout.buffer.write(result)
    elif args.out:
        with open(args.out, 'wb') as fp:
            fp.write(model.SerializeToString())
    else:
        print(onnx.printer.to_text(model))


if __name__ == '__main__':
    main()