Skip to content

Commit

Permalink
relocate data sections to end
Browse files Browse the repository at this point in the history
This is important because, in the EVM, data placed immediately before
regular (valid) code can be interpreted as PUSH instructions and mangle
that valid code.
  • Loading branch information
charles-cooper committed Jul 13, 2023
1 parent 92ad539 commit 4be2f3d
Showing 1 changed file with 39 additions and 29 deletions.
68 changes: 39 additions & 29 deletions vyper/ir/compile_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,14 +513,10 @@ def _height_of(witharg):
o.extend(["_mem_deploy_start"]) # stack: len mem_ofst
o.extend(["RETURN"])

# add a symbol before the subcode so that _sym_runtime_begin will
# resolve during assembly
o.append(runtime_begin)

# since the asm data structures are very primitive, to make sure
# assembly_to_evm is able to calculate data offsets correctly,
# we pass the memsize via magic opcodes to the subcode
subcode = [_RuntimeHeader(memsize)] + subcode
subcode = [_RuntimeHeader(runtime_begin, memsize)] + subcode

# append the runtime code after the ctor code
# `append(...)` call here is intentional.
Expand Down Expand Up @@ -664,7 +660,7 @@ def _height_of(witharg):
)

elif code.value == "data":
data_node = [_DATA]
data_node = [_DataHeader("_sym_" + code.args[0].value)]

for c in code.args[1:]:
if isinstance(c.value, int):
Expand All @@ -678,10 +674,8 @@ def _height_of(witharg):
else:
raise ValueError(f"Invalid data: {type(c)} {c}")

o = ["_sym_" + code.args[0].value]
# intentionally return a sublist.
o.append(data_node)
return o
return [data_node]

# jump to a symbol, and push variable # of arguments onto stack
elif code.value == "goto":
Expand Down Expand Up @@ -796,7 +790,7 @@ def _prune_unreachable_code(assembly):
instr = assembly[i][-1]

if assembly[i] in _TERMINAL_OPS and not (
is_symbol(assembly[i + 1]) and is_symbol_map_indicator(assembly[i + 2])
is_symbol(assembly[i + 1]) or isinstance(assembly[i+1], list)
):
changed = True
del assembly[i + 1]
Expand Down Expand Up @@ -913,7 +907,7 @@ def _merge_iszero(assembly):
# this helper function tells us if we want to add the previous instruction
# to the symbol map.
def is_symbol_map_indicator(asm_node):
return asm_node == "JUMPDEST" or isinstance(asm_node, list)
return asm_node == "JUMPDEST"


def _prune_unused_jumpdests(assembly):
Expand All @@ -926,7 +920,7 @@ def _prune_unused_jumpdests(assembly):
if is_symbol(assembly[i]) and not is_symbol_map_indicator(assembly[i + 1]):
used_jumpdests.add(assembly[i])

if isinstance(assembly[i], list) and assembly[i][0] == _DATA:
if isinstance(assembly[i], list) and isinstance(assembly[i][0], _DataHeader):
# add symbols used in data sections as they are likely
# used for a jumptable.
for t in assembly[i]:
Expand Down Expand Up @@ -1009,7 +1003,7 @@ def adjust_pc_maps(pc_maps, ofst):

def _data_to_evm(assembly, symbol_map):
ret = bytearray()
assert assembly[0] == _DATA
assert isinstance(assembly[0], _DataHeader)
for item in assembly[1:]:
if is_symbol(item):
symbol = symbol_map[item].to_bytes(SYMBOL_SIZE, "big")
Expand All @@ -1020,38 +1014,41 @@ def _data_to_evm(assembly, symbol_map):
ret.extend(item)
else:
raise ValueError(f"invalid data {type(item)} {item}")

return ret


# predict what length of an assembly [data] node will be in bytecode
def _length_of_data(assembly):
ret = 0
for i in assembly:
if is_symbol(i):
assert isinstance(assembly[0], _DataHeader)
for item in assembly[1:]:
if is_symbol(item):
ret += SYMBOL_SIZE
elif isinstance(i, int):
assert 0 <= i < 256, f"invalid data byte {i}"
elif isinstance(item, int):
assert 0 <= item < 256, f"invalid data byte {i}"
ret += 1
elif isinstance(i, bytes):
ret += len(i)
elif isinstance(item, bytes):
ret += len(item)
else:
raise ValueError(f"invalid data {type(i)} {i}")
return ret
raise ValueError(f"invalid data {type(item)} {item}")

return ret


class _RuntimeHeader:
def __init__(self, ctor_mem_size):
def __init__(self, label, ctor_mem_size):
self.label = label
self.ctor_mem_size = ctor_mem_size

def __repr__(self):
return f"<RUNTIME mem @{self.ctor_mem_size}>"
return f"<RUNTIME {self.label} mem @{self.ctor_mem_size}>"

class _DataHeader:
def __init__(self, label):
self.label = label
def __repr__(self):
return "DATA"

_DATA = _DataHeader()
return f"DATA {self.label}"


def assembly_to_evm(assembly, pc_ofst=0, insert_vyper_signature=False):
Expand Down Expand Up @@ -1110,6 +1107,17 @@ def assembly_to_evm(assembly, pc_ofst=0, insert_vyper_signature=False):
if runtime_code_end is not None:
mem_ofst_size = calc_mem_ofst_size(runtime_code_end + max_mem_ofst)

# relocate all data segments to the end, otherwise data could be
# interpreted as PUSH instructions and mangle otherwise valid jumpdests
data_segments = []
non_data_segments = []
for t in assembly:
if isinstance(t, list) and isinstance(t[0], _DataHeader):
data_segments.append(t)
else:
non_data_segments.append(t)
assembly = non_data_segments + data_segments

# go through the code, resolving symbolic locations
# (i.e. JUMPDEST locations) to actual code locations
for i, item in enumerate(assembly):
Expand Down Expand Up @@ -1153,13 +1161,15 @@ def assembly_to_evm(assembly, pc_ofst=0, insert_vyper_signature=False):
# [_OFST, _mem_foo, bar] -> PUSHN (foo+bar)
pc -= 1
elif isinstance(item, list) and isinstance(item[0], _RuntimeHeader):
symbol_map[item[0].label] = pc
# add source map for all items in the runtime map
t = adjust_pc_maps(runtime_map, pc)
for key in line_number_map:
line_number_map[key].update(t[key])
pc += len(runtime_code)
elif isinstance(item, list) and item[0] == _DATA:
pc += _length_of_data(item[1:])
elif isinstance(item, list) and isinstance(item[0], _DataHeader):
symbol_map[item[0].label] = pc
pc += _length_of_data(item)
else:
pc += 1

Expand Down Expand Up @@ -1216,7 +1226,7 @@ def assembly_to_evm(assembly, pc_ofst=0, insert_vyper_signature=False):
ret.append(SWAP_OFFSET + int(item[4:]))
elif isinstance(item, list) and isinstance(item[0], _RuntimeHeader):
ret.extend(runtime_code)
elif isinstance(item, list) and item[0] == _DATA:
elif isinstance(item, list) and isinstance(item[0], _DataHeader):
ret.extend(_data_to_evm(item, symbol_map))
else: # pragma: no cover
# unreachable
Expand Down

0 comments on commit 4be2f3d

Please sign in to comment.