From 2d82a74937edeed5e9d4c0c8cecd78a0d70530fa Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Mon, 17 Jun 2024 04:10:01 -0700 Subject: [PATCH] feat[test]: add more coverage to `abi_decode` fuzzer tests (#4153) fuzz with `unwrap_tuple=False` add fuzzing for structs follow up to 69e5c0541a9b23 --- .../builtins/codegen/test_abi_decode_fuzz.py | 124 +++++++++++++++--- vyper/semantics/types/user.py | 11 +- 2 files changed, 115 insertions(+), 20 deletions(-) diff --git a/tests/functional/builtins/codegen/test_abi_decode_fuzz.py b/tests/functional/builtins/codegen/test_abi_decode_fuzz.py index d12b2cde7e..e215002446 100644 --- a/tests/functional/builtins/codegen/test_abi_decode_fuzz.py +++ b/tests/functional/builtins/codegen/test_abi_decode_fuzz.py @@ -18,12 +18,12 @@ IntegerT, SArrayT, StringT, + StructT, TupleT, VyperType, _get_primitive_types, _get_sequence_types, ) -from vyper.semantics.types.shortcuts import UINT256_T from .abi_decode import DecodeError, spec_decode @@ -39,7 +39,7 @@ continue type_ctors.append(t) -complex_static_ctors = [SArrayT, TupleT] +complex_static_ctors = [SArrayT, TupleT, StructT] complex_dynamic_ctors = [DArrayT] leaf_ctors = [t for t in type_ctors if t not in _get_sequence_types().values()] static_leaf_ctors = [t for t in leaf_ctors if t._is_prim_word] @@ -50,10 +50,12 @@ @st.composite # max type nesting -def vyper_type(draw, nesting=3, skip=None): +def vyper_type(draw, nesting=3, skip=None, source_fragments=None): assert nesting >= 0 skip = skip or [] + if source_fragments is None: + source_fragments = [] st_leaves = st.one_of(st.sampled_from(dynamic_leaf_ctors), st.sampled_from(static_leaf_ctors)) st_complex = st.one_of( @@ -71,39 +73,52 @@ def vyper_type(draw, nesting=3, skip=None): # note: maybe st.deferred is good here, we could define it with # mutual recursion def _go(skip=skip): - return draw(vyper_type(nesting=nesting - 1, skip=skip)) + _, typ = draw(vyper_type(nesting=nesting - 1, skip=skip, source_fragments=source_fragments)) + return typ + + def finalize(typ): + return source_fragments, typ if t in (BytesT, StringT): # arbitrary max_value bound = draw(st.integers(min_value=1, max_value=1024)) - return t(bound) + return finalize(t(bound)) if t == SArrayT: subtype = _go(skip=[TupleT, BytesT, StringT]) bound = draw(st.integers(min_value=1, max_value=6)) - return t(subtype, bound) + return finalize(t(subtype, bound)) if t == DArrayT: subtype = _go(skip=[TupleT]) bound = draw(st.integers(min_value=1, max_value=16)) - return t(subtype, bound) + return finalize(t(subtype, bound)) if t == TupleT: # zero-length tuples are not allowed in vyper n = draw(st.integers(min_value=1, max_value=6)) subtypes = [_go() for _ in range(n)] - return TupleT(subtypes) + return finalize(TupleT(subtypes)) + + if t == StructT: + n = draw(st.integers(min_value=1, max_value=6)) + subtypes = {f"x{i}": _go() for i in range(n)} + _id = len(source_fragments) # poor man's unique id + name = f"MyStruct{_id}" + typ = StructT(name, subtypes) + source_fragments.append(typ.def_source_str()) + return finalize(StructT(name, subtypes)) if t in (BoolT, AddressT): - return t() + return finalize(t()) if t == IntegerT: signed = draw(st.booleans()) bits = 8 * draw(st.integers(min_value=1, max_value=32)) - return t(signed, bits) + return finalize(t(signed, bits)) if t == BytesM_T: m = draw(st.integers(min_value=1, max_value=32)) - return t(m) + return finalize(t(m)) raise RuntimeError("unreachable") @@ -116,6 +131,9 @@ def _go(t): if isinstance(typ, TupleT): return tuple(_go(item_t) for item_t in typ.member_types) + if isinstance(typ, StructT): + return tuple(_go(item_t) for item_t in typ.tuple_members()) + if isinstance(typ, SArrayT): return [_go(typ.value_type) for _ in range(typ.length)] @@ -294,6 +312,13 @@ def _finalize(): # little trick to save re-typing the arguments num_dynamic_types = sum(s.num_dynamic_types for s in substats) return _finalize() + if isinstance(typ, StructT): + substats = [_type_stats(t) for t in typ.tuple_members()] + nesting = 1 + max(s.nesting for s in substats) + breadth = max(len(typ.member_types), *[s.breadth for s in substats]) + num_dynamic_types = sum(s.num_dynamic_types for s in substats) + return _finalize() + if isinstance(typ, DArrayT): substat = _type_stats(typ.value_type) nesting = 1 + substat.nesting @@ -332,8 +357,8 @@ def payload_copier(get_contract_from_ir): @pytest.mark.parametrize("_n", list(range(PARALLELISM))) @hp.given(typ=vyper_type()) @hp.settings(max_examples=100, **_settings) -@hp.example(typ=DArrayT(DArrayT(UINT256_T, 2), 2)) -def test_abi_decode_fuzz(_n, typ, get_contract, tx_failed, payload_copier): +def test_abi_decode_fuzz(_n, typ, get_contract, tx_failed, payload_copier, env): + source_fragments, typ = typ # import time # t0 = time.time() # print("ENTER", typ) @@ -350,12 +375,13 @@ def test_abi_decode_fuzz(_n, typ, get_contract, tx_failed, payload_copier): # by bytes length check at function entry type_bound = wrapped_type.abi_type.size_bound() buffer_bound = type_bound + MAX_MUTATIONS - type_str = repr(typ) # annotation in vyper code - # TODO: intrinsic decode from staticcall/extcall - # TODO: _abi_decode from other sources (staticcall/extcall?) - # TODO: dirty the buffer - # TODO: check unwrap_tuple=False + + preamble = "\n\n".join(source_fragments) + type_str = str(typ) # annotation in vyper code + code = f""" +{preamble} + @external def run(xs: Bytes[{buffer_bound}]) -> {type_str}: ret: {type_str} = abi_decode(xs, {type_str}) @@ -375,6 +401,13 @@ def run3(xs: Bytes[{buffer_bound}], copier: Foo) -> {type_str}: assert len(xs) <= {type_bound} return (extcall copier.bar(xs)) """ + try: + c = get_contract(code) + except EvmError as e: + if env.contract_size_limit_error in str(e): + hp.assume(False) + # print(code) + hp.note(code) c = get_contract(code) @hp.given(data=payload_from(wrapped_type)) @@ -382,7 +415,6 @@ def run3(xs: Bytes[{buffer_bound}], copier: Foo) -> {type_str}: def _fuzz(data): hp.note(f"type: {typ}") hp.note(f"abi_t: {wrapped_type.abi_type.selector_name()}") - hp.note(code) hp.note(data.hex()) try: @@ -414,3 +446,57 @@ def _fuzz(data): # t1 = time.time() # print(f"elapsed {t1 - t0}s") + + +@pytest.mark.parametrize("_n", list(range(PARALLELISM))) +@hp.given(typ=vyper_type()) +@hp.settings(max_examples=100, **_settings) +def test_abi_decode_no_wrap_fuzz(_n, typ, get_contract, tx_failed, env): + source_fragments, typ = typ + # import time + # t0 = time.time() + # print("ENTER", typ) + + stats = _type_stats(typ) + hp.target(stats.num_dynamic_types) + + # add max_mutations bytes worth of padding so we don't just get caught + # by bytes length check at function entry + type_bound = typ.abi_type.size_bound() + buffer_bound = type_bound + MAX_MUTATIONS + + type_str = str(typ) # annotation in vyper code + preamble = "\n\n".join(source_fragments) + + code = f""" +{preamble} + +@external +def run(xs: Bytes[{buffer_bound}]) -> {type_str}: + ret: {type_str} = abi_decode(xs, {type_str}, unwrap_tuple=False) + return ret + """ + try: + c = get_contract(code) + except EvmError as e: + if env.contract_size_limit_error in str(e): + hp.assume(False) + + @hp.given(data=payload_from(typ)) + @hp.settings(max_examples=100, **_settings) + def _fuzz(data): + hp.note(code) + hp.note(data.hex()) + try: + expected = spec_decode(typ, data) + hp.note(f"expected {expected}") + assert expected == c.run(data) + except DecodeError: + hp.note("expect failure") + with tx_failed(EvmError): + c.run(data) + + _fuzz() + + # t1 = time.time() + # print(f"elapsed {t1 - t0}s") diff --git a/vyper/semantics/types/user.py b/vyper/semantics/types/user.py index a6ee646e62..ca8e99bc92 100644 --- a/vyper/semantics/types/user.py +++ b/vyper/semantics/types/user.py @@ -371,8 +371,11 @@ def from_StructDef(cls, base_node: vy_ast.StructDef) -> "StructT": return cls(struct_name, members, ast_def=base_node) + def __str__(self): + return f"{self._id}" + def __repr__(self): - return f"{self._id} declaration object" + return f"{self._id} {self.members}" def _try_fold(self, node): if len(node.args) != 1: @@ -384,6 +387,12 @@ def _try_fold(self, node): # it can't be reduced, but this lets upstream code know it's constant return node + def def_source_str(self): + ret = f"struct {self._id}:\n" + for k, v in self.member_types.items(): + ret += f" {k}: {v}\n" + return ret + @property def size_in_bytes(self): return sum(i.size_in_bytes for i in self.member_types.values())