-
Notifications
You must be signed in to change notification settings - Fork 0
/
pybfjit.py
251 lines (226 loc) · 7.91 KB
/
pybfjit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
#!/usr/bin/env python3
import sys
import ctypes
if sys.platform == 'win32':
    # Win32 type aliases used to declare the VirtualProtect prototype.
    LPVOID = ctypes.c_void_p
    HANDLE = LPVOID
    SIZE_T = ctypes.c_size_t
    DWORD = ctypes.c_uint32
    LPDWORD = ctypes.POINTER(DWORD)
    PDWORD = LPDWORD

    def error_if_zero(result, func, args):
        """ctypes ``errcheck`` hook: raise an ``OSError`` when *result* is falsy.

        The error code is taken from the ctypes-private copy of the Win32
        last-error value, which is only maintained for libraries loaded with
        ``use_last_error=True`` (as ``kernel32`` is below).
        """
        if not result:
            raise ctypes.WinError(ctypes.get_last_error())
        return result

    # Memory-protection flags accepted by VirtualProtect.
    PAGE_NOACCESS = 0x01
    PAGE_READONLY = 0x02
    PAGE_READWRITE = 0x04
    PAGE_WRITECOPY = 0x08
    PAGE_EXECUTE = 0x10
    PAGE_EXECUTE_READ = 0x20
    PAGE_EXECUTE_READWRITE = 0x40
    PAGE_EXECUTE_WRITECOPY = 0x80
    PAGE_GUARD = 0x100
    PAGE_NOCACHE = 0x200
    PAGE_WRITECOMBINE = 0x400

    # A private library handle keeps the prototype below off the shared
    # ``ctypes.windll.kernel32`` object, and ``use_last_error=True`` makes
    # ctypes snapshot the last-error value right after the call returns, so
    # the code reported by error_if_zero cannot be clobbered in between.
    _kernel32 = ctypes.WinDLL('kernel32', use_last_error=True)
    _VirtualProtect = _kernel32.VirtualProtect
    _VirtualProtect.argtypes = [LPVOID, SIZE_T, DWORD, PDWORD]
    _VirtualProtect.restype = bool
    _VirtualProtect.errcheck = error_if_zero

    # Receives the previous protection flags on every VirtualProtect call.
    flOldProtect = DWORD(0)

    def VirtualProtect(lpAddress, dwSize, flNewProtect):
        """Change the protection of a memory range and return the old flags.

        Raises ``OSError`` (through ``error_if_zero``) when the call fails.
        """
        _VirtualProtect(lpAddress, dwSize, flNewProtect, ctypes.byref(flOldProtect))
        return flOldProtect.value

    def make_memory_executable(buffer):
        """Make the memory backing the ctypes *buffer* read/write/execute."""
        VirtualProtect(buffer, ctypes.sizeof(buffer), PAGE_EXECUTE_READWRITE)
else:
    # use_errno=True lets a failed mprotect be reported with its real errno.
    libc = ctypes.CDLL(
        'libc.{}'.format('dylib' if sys.platform == 'darwin' else 'so.6'),
        use_errno=True,
    )
    PAGESIZE = libc.getpagesize()

    # Protection bits accepted by mprotect.
    PROT_NONE = 0x0
    PROT_READ = 0x1
    PROT_WRITE = 0x2
    PROT_EXEC = 0x4

    mprotect = libc.mprotect
    mprotect.restype = ctypes.c_int
    mprotect.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int]

    def make_memory_executable(buffer):
        """Make every page overlapped by the ctypes *buffer* read/write/execute.

        mprotect works on whole pages, so the range is widened to page
        boundaries; anything else living on those pages becomes executable
        as well.

        Raises:
            OSError: if mprotect fails; ``errno`` carries the C error code.
        """
        import os

        page_mask = ~(PAGESIZE - 1)  # assumes the page size is a power of 2
        start = ctypes.addressof(buffer)
        page_start = start & page_mask
        page_end = (start + len(buffer) + PAGESIZE - 1) & page_mask
        status = mprotect(page_start, page_end - page_start,
                          PROT_READ | PROT_WRITE | PROT_EXEC)
        if status == -1:
            code = ctypes.get_errno()
            raise OSError(code, 'mprotect failed: {}'.format(os.strerror(code)))
def create_executable_buffer(init, size=None):
    """Allocate a ctypes string buffer from *init* and make it executable.

    Args:
        init: forwarded to ``ctypes.create_string_buffer`` as its first
            argument.
        size: explicit buffer size. When omitted and *init* is ``bytes``,
            the size defaults to ``len(init)``, so the buffer holds exactly
            those bytes.

    Returns:
        The ctypes buffer, after ``make_memory_executable`` has been applied
        to it.
    """
    length = size
    if length is None and isinstance(init, bytes):
        length = len(init)
    exec_buf = ctypes.create_string_buffer(init, length)
    make_memory_executable(exec_buf)
    return exec_buf
def i8_to_bytes(n):
    """Return the low 8 bits of *n* as a single byte."""
    low_byte = n & 0xFF
    return bytes((low_byte,))
def i16_to_bytes(n):
    """Return the low 16 bits of *n* as 2 bytes, least significant first."""
    return bytes((n >> shift) & 0xFF for shift in (0, 8))
def i32_to_bytes(n):
    """Return the low 32 bits of *n* as 4 bytes, least significant first."""
    return bytes((n >> shift) & 0xFF for shift in (0, 8, 16, 24))
def i64_to_bytes(n):
    """Return the low 64 bits of *n* as 8 bytes, least significant first."""
    return bytes((n >> shift) & 0xFF for shift in range(0, 64, 8))
class Assembler():
    """Accumulate machine-code chunks, with label-relative fix-ups.

    Code is appended either as literal ``bytes`` (``emit``) or as a *thunk*
    (``emit_thunk``): a callable reserving a fixed number of bytes that is
    turned into real bytes once the label it refers to has a position.

    Attributes:
        chunks: emitted pieces in order; ``bytes``, or a pending thunk.
        poses: start offset of each entry of ``chunks``.
        pos: offset at which the next chunk will start.
        labels: offset of each label, or ``None`` while not yet placed.
        label_refs: per label, indices into ``chunks`` of pending thunks.
    """

    def __init__(self):
        self.chunks = []
        self.labels = []
        self.label_refs = []
        self.pos = 0
        self.poses = []
        # chunk index -> number of bytes reserved for a still-pending thunk
        self._thunk_sizes = {}

    def emit(self, raw):
        """Append the literal ``bytes`` *raw* at the current position."""
        assert isinstance(raw, bytes)
        self.chunks.append(raw)
        self.poses.append(self.pos)
        self.pos += len(raw)

    def emit_thunk(self, label, nbytes, thunk):
        """Reserve *nbytes* for code that depends on the position of *label*.

        ``thunk(chunk_offset, label_offset)`` must return exactly *nbytes*
        bytes. It is called immediately if *label* is already placed,
        otherwise when ``put_label`` places it.
        """
        index = len(self.chunks)
        self.chunks.append(thunk)
        self.poses.append(self.pos)
        self.pos += nbytes
        self._thunk_sizes[index] = nbytes
        self.label_refs[label].append(index)
        if self.labels[label] is not None:
            self.resolve(label)

    def resolve(self, label):
        """Replace every pending thunk that references the placed *label*.

        Raises:
            ValueError: if a thunk yields a different number of bytes than
                it reserved; every later offset would otherwise be wrong.
        """
        pos = self.labels[label]
        assert pos is not None
        for r in self.label_refs[label]:
            encoded = self.chunks[r](self.poses[r], pos)
            reserved = self._thunk_sizes.pop(r, None)
            if reserved is not None and len(encoded) != reserved:
                raise ValueError(
                    'thunk at offset {} reserved {} bytes but produced {}'
                    .format(self.poses[r], reserved, len(encoded)))
            self.chunks[r] = encoded
        self.label_refs[label] = []

    def label(self):
        """Create a new, not yet placed label and return its id."""
        self.labels.append(None)
        self.label_refs.append([])
        return len(self.labels)-1

    def put_label(self, label):
        """Place *label* at the current position and resolve its references.

        A label may only be placed once.
        """
        assert self.labels[label] is None
        self.labels[label] = self.pos
        self.resolve(label)

    def assemble(self):
        """Return all chunks concatenated into one ``bytes`` object.

        Raises:
            ValueError: if some referenced label was never placed.
        """
        dangling = [label for label, refs in enumerate(self.label_refs) if refs]
        if dangling:
            raise ValueError(
                'cannot assemble: label(s) {} referenced but never placed'
                .format(dangling))
        return b''.join(self.chunks)
def parse(string):
    """Parse Brainfuck source text into a nested-list syntax tree.

    The commands ``+ - > < . ,`` become one-character strings, every
    ``[...]`` loop becomes a nested list, and all other characters are
    ignored.

    The scan is iterative with an explicit stack of open loops, so the
    length of the program is not limited by Python's recursion depth.

    Args:
        string: the program text.

    Returns:
        A tuple ``(ast, rem)``. ``rem`` is ``''`` when the whole input was
        consumed; if a ``]`` without a matching ``[`` is found, parsing
        stops there and ``rem`` is the input from that ``]`` onwards.

    Raises:
        ValueError: if a ``[`` is never closed.
    """
    root = []
    open_loops = [root]  # innermost open loop body is last
    for index, c in enumerate(string):
        if c == '[':
            body = []
            open_loops[-1].append(body)
            open_loops.append(body)
        elif c == ']':
            if len(open_loops) == 1:
                # Stray ']' at top level: hand the rest back to the caller.
                return (root, string[index:])
            open_loops.pop()
        elif c in '+-><.,':
            open_loops[-1].append(c)
    if len(open_loops) > 1:
        raise ValueError("unmatched '[' in Brainfuck source")
    return (root, '')
def inc_while(pred, lis, pos):
    """Return the first index at or after *pos* where *pred* fails.

    Scans ``lis`` forward from *pos* and stops at the first element for
    which ``pred`` is falsy, or at ``len(lis)`` if there is none.
    """
    while True:
        if pos >= len(lis) or not pred(lis[pos]):
            return pos
        pos += 1
def optimize(ast):
    """Yield the nodes of *ast* with runs of counter-acting commands folded.

    A maximal run of ``+``/``-`` becomes ``('+', net)`` and a maximal run of
    ``>``/``<`` becomes ``('>', net)``, where ``net`` is the signed balance
    of the run; a run that cancels out completely yields nothing. Nested
    lists (loops) are optimized recursively and yielded as lists. Every
    other node is yielded unchanged.
    """
    run_kinds = (('+', '-'), ('>', '<'))
    index = 0
    while index < len(ast):
        node = ast[index]
        pair = next((kind for kind in run_kinds if node in kind), None)
        if pair is None:
            yield list(optimize(node)) if isinstance(node, list) else node
            index += 1
            continue
        # Consume the whole run, tracking its net effect.
        forward = pair[0]
        net = 0
        while index < len(ast) and ast[index] in pair:
            net += 1 if ast[index] == forward else -1
            index += 1
        if net != 0:
            yield (forward, net)
def invoke(asm, func):
    """Emit an indirect call through the ctypes function object *func*.

    The emitted code loads ``ctypes.addressof(func)`` into rax and calls
    through the pointer stored at that address.
    """
    slot = ctypes.addressof(func)
    # Little-endian 64-bit immediate for the mov below.
    imm64 = bytes((slot >> shift) & 0xFF for shift in range(0, 64, 8))
    for piece in (
        b'\x48\xb8',  # mov rax, imm64
        imm64,
        b'\xff\x10',  # call qword [rax]
    ):
        asm.emit(piece)
if sys.platform == 'win32':
    libc = ctypes.cdll.msvcrt

    def invoke_putchar(asm):
        """Emit code that calls ``putchar`` with the byte at [rbx]."""
        load_cell = b'\x0f\xbe\x0b'          # movsx ecx, byte [rbx]
        # shadow space required in windows calling convention
        open_shadow = b'\x48\x83\xec\x20'    # sub rsp, 32
        close_shadow = b'\x48\x83\xc4\x20'   # add rsp, 32
        asm.emit(load_cell)
        asm.emit(open_shadow)
        invoke(asm, libc.putchar)
        asm.emit(close_shadow)

    def invoke_getchar(asm):
        """Emit code that calls ``getchar`` and stores al into [rbx]."""
        # getchar takes no argument; this load is emitted only so the
        # generated code stays the same as before.
        load_cell = b'\x0f\xbe\x0b'          # movsx ecx, byte [rbx]
        # shadow space required in windows calling convention
        open_shadow = b'\x48\x83\xec\x20'    # sub rsp, 32
        close_shadow = b'\x48\x83\xc4\x20'   # add rsp, 32
        store_cell = b'\x88\x03'             # mov byte [rbx], al
        asm.emit(load_cell)
        asm.emit(open_shadow)
        invoke(asm, libc.getchar)
        asm.emit(close_shadow)
        asm.emit(store_cell)
else:
    def invoke_putchar(asm):
        """Emit code that calls ``putchar`` with the byte at [rbx]."""
        load_cell = b'\x0f\xbe\x3b'          # movsx edi, byte [rbx]
        asm.emit(load_cell)
        invoke(asm, libc.putchar)

    def invoke_getchar(asm):
        """Emit code that calls ``getchar`` and stores al into [rbx]."""
        store_cell = b'\x88\x03'             # mov byte [rbx], al
        invoke(asm, libc.getchar)
        asm.emit(store_cell)
def compile_chunk(ast):
    """Compile a sequence of optimized nodes to position-independent code.

    Handles nested lists (loops), ``('+', n)`` and ``('>', n)`` tuples and
    the ``'.'`` / ``','`` commands. The generated code addresses the current
    cell through rbx, and all jumps are relative, so the result can be
    embedded at any offset of a larger chunk.

    Returns:
        The machine code as ``bytes``.

    Raises:
        Exception: for any node that is none of the kinds above.
    """
    def jump_if_zero(cur, target):
        # je rel32, relative to the end of this 6-byte instruction
        return b'\x0f\x84' + i32_to_bytes(target - cur - 6)

    def jump_back(cur, target):
        # jmp rel32, relative to the end of this 5-byte instruction
        return b'\xe9' + i32_to_bytes(target - cur - 5)

    asm = Assembler()
    for node in ast:
        if isinstance(node, list):
            # loop: test the cell, skip past the body when it is zero,
            # otherwise run the body and jump back to the test
            loop_head = asm.label()
            asm.put_label(loop_head)
            asm.emit(b'\x80\x3b\x00')  # cmp byte [rbx], 0
            loop_exit = asm.label()
            asm.emit_thunk(loop_exit, 6, jump_if_zero)
            asm.emit(compile_chunk(node))
            asm.emit_thunk(loop_head, 5, jump_back)
            asm.put_label(loop_exit)
            continue

        tag = node[0] if isinstance(node, tuple) else node
        if isinstance(node, tuple) and tag == '+':
            asm.emit(b'\x80\x03')  # add byte [rbx], ...
            asm.emit(i8_to_bytes(node[1]))
        elif isinstance(node, tuple) and tag == '>':
            asm.emit(b'\x48\x81\xc3')  # add rbx, ...
            asm.emit(i32_to_bytes(node[1]))
        elif node == '.':
            invoke_putchar(asm)
        elif node == ',':
            invoke_getchar(asm)
        else:
            raise Exception('unhandled node: {}'.format(repr(node)))
    return asm.assemble()
def compile_bf(ast, bufaddr):
    """Compile a whole program into a callable machine-code function.

    The generated function saves rbx, points it at *bufaddr* (the address
    of the cell memory), runs the code for *ast*, restores rbx and returns.

    Returns:
        The machine code as ``bytes``.
    """
    prologue = (
        b'\x53',                 # push rbx
        b'\x48\xbb',             # mov rbx, ...
        i64_to_bytes(bufaddr),
    )
    epilogue = (
        b'\x5b',                 # pop rbx
        b'\xc3',                 # ret
    )
    asm = Assembler()
    for piece in prologue:
        asm.emit(piece)
    asm.emit(compile_chunk(ast))
    for piece in epilogue:
        asm.emit(piece)
    return asm.assemble()
def main(argv):
    """Compile and run a Brainfuck program.

    The program is read from the file named by ``argv[0]``, or from standard
    input when *argv* is empty. The generated machine code is also written
    to a file named ``obj`` in the current directory before it is executed.

    Args:
        argv: command-line arguments without the program name.

    Returns:
        ``0`` after the program has run, ``1`` if the source contains a
        ``]`` without a matching ``[``.
    """
    if len(argv) == 0:
        raw = sys.stdin.buffer.read()
    else:
        with open(argv[0], 'rb') as fp:
            raw = fp.read()
    # Every character other than the eight commands is ignored by the
    # parser, so decode with latin-1 (defined for every byte value) rather
    # than rejecting sources that carry non-ASCII comments.
    bf = raw.decode('latin-1')
    ast, rem = parse(bf)
    if len(rem) != 0:
        print("error: unmatched ']' at offset {}".format(len(bf) - len(rem)),
              file=sys.stderr)
        return 1
    # Cell memory for the program; its address is baked into the code.
    mem = ctypes.create_string_buffer(30000)
    code = compile_bf(optimize(ast), ctypes.addressof(mem))
    with open('obj', 'wb') as fp:
        fp.write(code)
    # buf and mem stay referenced until the generated code has returned.
    buf = create_executable_buffer(code)
    ctypes.CFUNCTYPE(None)(ctypes.addressof(buf))()
    return 0
if __name__ == '__main__':
    # Run as a script: main()'s return value becomes the exit status.
    exit_status = main(sys.argv[1:])
    raise SystemExit(exit_status)