Skip to content

Commit

Permalink
get dense jumptable working (at least compiles)
Browse files Browse the repository at this point in the history
  • Loading branch information
charles-cooper committed Jul 13, 2023
1 parent f99cd2d commit 134107c
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 31 deletions.
6 changes: 3 additions & 3 deletions vyper/codegen/function_definitions/external_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ def handler_for(calldata_kwargs, default_kwargs):

ret.append(["goto", func_t._ir_info.external_function_base_entry_label])

method_id = util.method_id_int(abi_sig)
label = f"{func_t._ir_info.ir_identifier}{method_id}"
ret = ["label", label, ["var_list"], ret]
#method_id = util.method_id_int(abi_sig)
#label = f"{func_t._ir_info.ir_identifier}{method_id}"
#ret = ["label", label, ["var_list"], ret]

# return something we can turn into ExternalFuncIR
return abi_sig, calldata_min_size, ret
Expand Down
7 changes: 7 additions & 0 deletions vyper/codegen/ir_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,13 @@ def _check(condition, err):

self.valency = 1
self._gas = 5
elif isinstance(self.value, bytes):
# a literal bytes value, probably inside a "data" node.
_check(len(self.args) == 0, "bytes can't have arguments")

self.valency = 0
self._gas = 0

elif isinstance(self.value, str):
# Opcodes and pseudo-opcodes (e.g. clamp)
if self.value.upper() in get_ir_opcodes():
Expand Down
9 changes: 6 additions & 3 deletions vyper/codegen/jumptable.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ class Signature:
class Bucket:
bucket_id: int
magic: int
signatures: list[int]
method_ids: list[int]

@property
def image(self):
return _image_of([s for s in self.signatures], self.magic)
return _image_of([s for s in self.method_ids], self.magic)

@property
def bucket_size(self):
return len(self.signatures)
return len(self.method_ids)


_PRIMES = []
Expand Down Expand Up @@ -114,6 +114,9 @@ def _dense_jumptable_info(method_ids, n_buckets):
START_BUCKET_SIZE = 5


# this is expensive! for 80 methods, it costs about 350ms, and the cost
# is probably linear in the number of methods.
# see _bench_perfect()
def generate_dense_jumptable_info(signatures):
method_ids = [method_id_int(sig) for sig in signatures]
n = len(signatures)
Expand Down
64 changes: 45 additions & 19 deletions vyper/codegen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ def _annotated_method_id(abi_sig):
return IRnode(method_id, annotation=annotation)


#def label_for_entry_point(abi_sig, entry_point):
# method_id = method_id_int(abi_sig)
# return f"{entry_point.func_t._ir_info.ir_identifier}{method_id}"
def label_for_entry_point(abi_sig, entry_point):
    """Build the IR label naming the entry point for `abi_sig`.

    The label is the IR identifier of the entry point's function
    (`entry_point.func_t._ir_info.ir_identifier`) immediately followed by
    the method id that `method_id_int` computes from `abi_sig`.
    """
    selector = method_id_int(abi_sig)
    ir_identifier = entry_point.func_t._ir_info.ir_identifier
    return f"{ir_identifier}{selector}"


# TODO: probably dead code
Expand Down Expand Up @@ -101,22 +101,28 @@ def _ir_for_internal_function(func_ast, *args, **kwargs):
def _selector_section_dense(external_functions, global_ctx):
function_irs = []
entry_points = {} # map from ABI sigs to ir code
sig_of = {} # reverse map from method ids to abi sig

for code in external_functions:
func_ir = generate_ir_for_function(code, global_ctx, skip_nonpayable_check=True)
for abi_sig, entry_point in func_ir.entry_points.items():
assert abi_sig not in entry_points
entry_points[abi_sig] = entry_point
sig_of[method_id_int(abi_sig)] = abi_sig
# stick function common body into final entry point to save a jump
entry_point.ir_node.append(func_ir.common_ir)
ir_node = IRnode.from_list(["seq", entry_point.ir_node, func_ir.common_ir])
entry_point.ir_node = ir_node

for entry_point in entry_points.values():
function_irs.append(IRnode.from_list(entry_point.ir_node))
for abi_sig, entry_point in entry_points.items():
label = label_for_entry_point(abi_sig, entry_point)
ir_node = ["label", label, ["var_list"], entry_point.ir_node]
function_irs.append(IRnode.from_list(ir_node))

jumptable_info = jumptable.generate_dense_jumptable_info(entry_points.keys())
n_buckets = len(jumptable_info)

# 2 bytes for bucket magic, 2 bytes for bucket location
# TODO: the bucket header can be made smaller if the largest bucket magic is <= 255
SZ_BUCKET_HEADER = 4

selector_section = ["seq"]
Expand All @@ -132,15 +138,15 @@ def _selector_section_dense(external_functions, global_ctx):
assert dst >= 0

# memory is PROBABLY 0, but just be paranoid.
selector_section.append(["mstore", 0, 0])
selector_section.append(["assert", ["eq", "msize", 0]])
selector_section.append(["codecopy", dst, bucket_hdr_location, SZ_BUCKET_HEADER])

# figure out the minimum number of bytes we can use to encode
# min_calldatasize in function info
largest_mincalldatasize = max(f.min_calldatasize for f in entry_points.values())
variable_bytes_needed = (largest_mincalldatasize.bit_length() + 7) // 8
FN_METADATA_BYTES = (largest_mincalldatasize.bit_length() + 7) // 8

func_info_size = 4 + 2 + variable_bytes_needed
func_info_size = 4 + 2 + FN_METADATA_BYTES
# grab function info. 4 bytes for method id, 2 bytes for label,
# 1-3 bytes (packed) for: expected calldatasize, is payable bit
# NOTE: might be able to improve codesize if we use variable # of bytes
Expand All @@ -163,8 +169,8 @@ def _selector_section_dense(external_functions, global_ctx):
selector_section.append(b1.resolve(["codecopy", dst, func_info_location, func_info_size]))

func_info = IRnode.from_list(["mload", 0])
variable_bytes_mask = 2 ** (variable_bytes_needed * 8) - 1
calldatasize_mask = variable_bytes_mask - 1 # ex. 0xFFFE
fn_metadata_mask = 2 ** (FN_METADATA_BYTES * 8) - 1
calldatasize_mask = fn_metadata_mask - 1 # ex. 0xFFFE
with func_info.cache_when_complex("func_info") as (b1, func_info):
x = ["seq"]

Expand All @@ -176,9 +182,9 @@ def _selector_section_dense(external_functions, global_ctx):

# method id <4 bytes> | label <2 bytes> | func info <1-3 bytes>

label_bits_ofst = variable_bytes_needed * 8
label_bits_ofst = FN_METADATA_BYTES * 8
function_label = ["and", 0xFFFF, shr(label_bits_ofst, func_info)]
method_id_bits_ofst = (variable_bytes_needed + 3) * 8
method_id_bits_ofst = (FN_METADATA_BYTES + 3) * 8
function_method_id = shr(method_id_bits_ofst, func_info)

# check method id is right, if not then fallback.
Expand All @@ -193,9 +199,32 @@ def _selector_section_dense(external_functions, global_ctx):
bad_calldatasize = ["lt", "calldatasize", expected_calldatasize]
failed_entry_conditions = ["or", bad_callvalue, bad_calldatasize]
x.append(["assert", ["iszero", failed_entry_conditions]])
x.append(["goto", function_label])
x.append(["jump", function_label])
selector_section.append(b1.resolve(x))

bucket_headers = ["data", "BUCKET_HEADERS"]

for bucket_id, bucket in jumptable_info.items():
bucket_headers.append(["symbol", f"bucket_{bucket_id}"])
bucket_headers.append(bucket.magic.to_bytes(2, "big"))

selector_section.append(bucket_headers)

for bucket_id, bucket in jumptable_info.items():
function_infos = ["data", f"bucket_{bucket_id}"]
for method_id in bucket.method_ids:
abi_sig = sig_of[method_id]
entry_point = entry_points[abi_sig]

method_id_bytes = method_id.to_bytes(4, "big")
symbol = ["symbol", label_for_entry_point(abi_sig, entry_point)]
func_metadata_int = entry_point.min_calldatasize | int(not entry_point.func_t.is_payable)
func_metadata = func_metadata_int.to_bytes(FN_METADATA_BYTES, "big")

function_infos.extend([method_id_bytes, symbol, func_metadata])

selector_section.append(function_infos)

runtime = [
"seq",
["with", "_calldata_method_id", shr(224, ["calldataload", 0]), selector_section],
Expand All @@ -218,7 +247,6 @@ def _selector_section_sparse(external_functions, global_ctx):
# payable functions, nonpayable functions, fallback function, internal_functions
default_function = next((f for f in external_functions if _is_fallback(f)), None)

function_irs = []
entry_points = {} # map from ABI sigs to ir code
sig_of = {} # map from method ids back to signatures

Expand Down Expand Up @@ -254,7 +282,7 @@ def _selector_section_sparse(external_functions, global_ctx):
assert dst >= 0

# memory is PROBABLY 0, but just be paranoid.
selector_section.append(["mstore", 0, 0])
selector_section.append(["assert", ["eq", "msize", 0]])
selector_section.append(["codecopy", dst, bucket_hdr_location, SZ_BUCKET_HEADER])

jumpdest = IRnode.from_list(["mload", 0])
Expand Down Expand Up @@ -322,8 +350,6 @@ def _selector_section_sparse(external_functions, global_ctx):
["with", "_calldata_method_id", shr(224, ["calldataload", 0]), selector_section],
]

ret.extend(function_irs)

return ret


Expand Down Expand Up @@ -424,7 +450,7 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> tuple[IRnode, IRnode]:

# XXX: AWAITING MCOPY PR
# dense vs sparse global overhead is amortized after about 4 methods
dense = False # if core._opt_codesize() and len(external_functions) > 4:
dense = False # if core._opt_codesize() and len(external_functions) > 4:
if dense:
selector_section = _selector_section_dense(external_functions, global_ctx)
else:
Expand Down
16 changes: 10 additions & 6 deletions vyper/ir/compile_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,12 +667,11 @@ def _height_of(witharg):
data_node = [_DATA]

for c in code.args[1:]:
if isinstance(c, int):
if isinstance(c.value, int):
assert 0 <= c < 256, f"invalid data byte {c}"
data_node.append(c)
elif isinstance(c, bytes):
as_ints = list(c) # list(b"1234") -> [49, 50, 51, 52]
data_node.extend(as_ints)
data_node.append(c.value)
elif isinstance(c.value, bytes):
data_node.append(c.value)
elif isinstance(c, IRnode):
assert c.value == "symbol"
data_node.extend(_compile_to_assembly(c, withargs, existing_labels, break_dest, height))
Expand Down Expand Up @@ -1017,6 +1016,9 @@ def _data_to_evm(assembly, symbol_map):
ret.extend(symbol)
elif isinstance(item, int):
ret.append(item)
elif isinstance(item, bytes):
as_ints = list(item)
ret.extend(as_ints)
else:
raise ValueError(f"invalid data {type(item)} {item}")
return ret
Expand All @@ -1031,6 +1033,8 @@ def _length_of_data(assembly):
elif isinstance(i, int):
assert 0 <= i < 256, f"invalid data byte {i}"
ret += 1
elif isinstance(i, bytes):
ret += len(i)
else:
raise ValueError(f"invalid data {type(i)} {i}")
return ret
Expand Down Expand Up @@ -1150,7 +1154,7 @@ def assembly_to_evm(assembly, pc_ofst=0, insert_vyper_signature=False):
# [_OFST, _sym_foo, bar] -> PUSH2 (foo+bar)
# [_OFST, _mem_foo, bar] -> PUSHN (foo+bar)
pc -= 1
elif isinstance(item, list) and isinstance(item[0], RuntimeHeader):
elif isinstance(item, list) and isinstance(item[0], _RuntimeHeader):
# add source map for all items in the runtime map
t = adjust_pc_maps(runtime_map, pc)
for key in line_number_map:
Expand Down

0 comments on commit 134107c

Please sign in to comment.