diff --git a/tests/conftest.py b/tests/conftest.py index 4b3d90f65a..31c72246bd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -196,7 +196,7 @@ def env(gas_limit, evm_version, evm_backend, tracing, account_keys) -> BaseEnv: ) -@pytest.fixture +@pytest.fixture(scope="module") def get_contract_from_ir(env, optimize): def ir_compiler(ir, *args, **kwargs): ir = IRnode.from_list(ir) diff --git a/tests/evm_backends/base_env.py b/tests/evm_backends/base_env.py index a8ab4d2367..1ea3dba328 100644 --- a/tests/evm_backends/base_env.py +++ b/tests/evm_backends/base_env.py @@ -30,7 +30,7 @@ class ExecutionResult: gas_used: int -class EvmError(RuntimeError): +class EvmError(Exception): """Exception raised when a call fails.""" @@ -205,6 +205,16 @@ def out_of_gas_error(self) -> str: """Expected error message when user runs out of gas""" raise NotImplementedError # must be implemented by subclasses + @property + def contract_size_limit_error(self) -> str: + """Expected error message when contract is over codesize limit""" + raise NotImplementedError # must be implemented by subclasses + + @property + def initcode_size_limit_error(self) -> str: + """Expected error message when contract is over codesize limit""" + raise NotImplementedError # must be implemented by subclasses + def _compile( source_code: str, diff --git a/tests/evm_backends/revm_env.py b/tests/evm_backends/revm_env.py index 5c8b8aba08..d5a7570f96 100644 --- a/tests/evm_backends/revm_env.py +++ b/tests/evm_backends/revm_env.py @@ -11,6 +11,8 @@ class RevmEnv(BaseEnv): invalid_opcode_error = "InvalidFEOpcode" out_of_gas_error = "OutOfGas" + contract_size_limit_error = "CreateContractSizeLimit" + initcode_size_limit_error = "CreateInitCodeSizeLimit" def __init__( self, diff --git a/tests/functional/builtins/codegen/abi_decode.py b/tests/functional/builtins/codegen/abi_decode.py new file mode 100644 index 0000000000..9e10b862d5 --- /dev/null +++ b/tests/functional/builtins/codegen/abi_decode.py @@ -0,0 +1,148 @@ +from typing import TYPE_CHECKING, Iterable + +from eth_utils import to_checksum_address + +from vyper.abi_types import ( + ABI_Address, + ABI_Bool, + ABI_Bytes, + ABI_BytesM, + ABI_DynamicArray, + ABI_GIntM, + ABI_StaticArray, + ABI_String, + ABI_Tuple, + ABIType, +) +from vyper.utils import int_bounds, unsigned_to_signed + +if TYPE_CHECKING: + from vyper.semantics.types import VyperType + + +class DecodeError(Exception): + pass + + +def _strict_slice(payload, start, length): + if start < 0: + raise DecodeError(f"OOB {start}") + + end = start + length + if end > len(payload): + raise DecodeError(f"OOB {start} + {length} (=={end}) > {len(payload)}") + return payload[start:end] + + +def _read_int(payload, ofst): + return int.from_bytes(_strict_slice(payload, ofst, 32)) + + +# vyper abi_decode spec implementation +def spec_decode(typ: "VyperType", payload: bytes): + abi_t = typ.abi_type + + lo, hi = abi_t.static_size(), abi_t.size_bound() + if not (lo <= len(payload) <= hi): + raise DecodeError(f"bad payload size {lo}, {len(payload)}, {hi}") + + return _decode_r(abi_t, 0, payload) + + +def _decode_r(abi_t: ABIType, current_offset: int, payload: bytes): + if isinstance(abi_t, ABI_Tuple): + return tuple(_decode_multi_r(abi_t.subtyps, current_offset, payload)) + + if isinstance(abi_t, ABI_StaticArray): + n = abi_t.m_elems + subtypes = [abi_t.subtyp] * n + return _decode_multi_r(subtypes, current_offset, payload) + + if isinstance(abi_t, ABI_DynamicArray): + bound = abi_t.elems_bound + + n = _read_int(payload, current_offset) + if n > bound: + raise DecodeError("Dynarray too large") + + # offsets in dynarray start from after the length word + current_offset += 32 + subtypes = [abi_t.subtyp] * n + return _decode_multi_r(subtypes, current_offset, payload) + + # sanity check + assert not abi_t.is_complex_type() + + if isinstance(abi_t, ABI_Bytes): + bound = abi_t.bytes_bound + length = _read_int(payload, current_offset) + if length > bound: + raise DecodeError("bytes too large") + + current_offset += 32 # size of length word + ret = _strict_slice(payload, current_offset, length) + + # abi string doesn't actually define string decoder, so we + # just bytecast the output + if isinstance(abi_t, ABI_String): + # match eth-stdlib, since that's what we check against + ret = ret.decode(errors="surrogateescape") + + return ret + + # sanity check + assert not abi_t.is_dynamic() + + if isinstance(abi_t, ABI_GIntM): + ret = _read_int(payload, current_offset) + + # handle signedness + if abi_t.signed: + ret = unsigned_to_signed(ret, 256, strict=True) + + # bounds check + lo, hi = int_bounds(signed=abi_t.signed, bits=abi_t.m_bits) + if not (lo <= ret <= hi): + u = "" if abi_t.signed else "u" + raise DecodeError(f"invalid {u}int{abi_t.m_bits}") + + if isinstance(abi_t, ABI_Address): + return to_checksum_address(ret.to_bytes(20, "big")) + + if isinstance(abi_t, ABI_Bool): + if ret not in (0, 1): + raise DecodeError("invalid bool") + return ret + + return ret + + if isinstance(abi_t, ABI_BytesM): + ret = _strict_slice(payload, current_offset, 32) + m = abi_t.m_bytes + assert 1 <= m <= 32 # internal sanity check + # BytesM is right-padded with zeroes + if ret[m:] != b"\x00" * (32 - m): + raise DecodeError(f"invalid bytes{m}") + return ret[:m] + + raise RuntimeError("unreachable") + + +def _decode_multi_r(types: Iterable[ABIType], outer_offset: int, payload: bytes) -> list: + ret = [] + static_ofst = outer_offset + + for sub_t in types: + if sub_t.is_dynamic(): + # "head" terminology from abi spec + head = _read_int(payload, static_ofst) + ofst = outer_offset + head + else: + ofst = static_ofst + + item = _decode_r(sub_t, ofst, payload) + + ret.append(item) + static_ofst += sub_t.embedded_static_size() + + return ret diff --git a/tests/functional/builtins/codegen/test_abi_decode_fuzz.py b/tests/functional/builtins/codegen/test_abi_decode_fuzz.py new file mode 100644 index 0000000000..d12b2cde7e --- /dev/null +++ b/tests/functional/builtins/codegen/test_abi_decode_fuzz.py @@ -0,0 +1,416 @@ +from dataclasses import dataclass + +import hypothesis as hp +import hypothesis.strategies as st +import pytest +from eth.codecs import abi + +from tests.evm_backends.base_env import EvmError +from vyper.codegen.core import calculate_type_for_external_return, needs_external_call_wrap +from vyper.semantics.types import ( + AddressT, + BoolT, + BytesM_T, + BytesT, + DArrayT, + DecimalT, + HashMapT, + IntegerT, + SArrayT, + StringT, + TupleT, + VyperType, + _get_primitive_types, + _get_sequence_types, +) +from vyper.semantics.types.shortcuts import UINT256_T + +from .abi_decode import DecodeError, spec_decode + +pytestmark = pytest.mark.fuzzing + +type_ctors = [] +for t in _get_primitive_types().values(): + if t == HashMapT or t == DecimalT(): + continue + if isinstance(t, VyperType): + t = t.__class__ + if t in type_ctors: + continue + type_ctors.append(t) + +complex_static_ctors = [SArrayT, TupleT] +complex_dynamic_ctors = [DArrayT] +leaf_ctors = [t for t in type_ctors if t not in _get_sequence_types().values()] +static_leaf_ctors = [t for t in leaf_ctors if t._is_prim_word] +dynamic_leaf_ctors = [BytesT, StringT] + +MAX_MUTATIONS = 33 + + +@st.composite +# max type nesting +def vyper_type(draw, nesting=3, skip=None): + assert nesting >= 0 + + skip = skip or [] + + st_leaves = st.one_of(st.sampled_from(dynamic_leaf_ctors), st.sampled_from(static_leaf_ctors)) + st_complex = st.one_of( + st.sampled_from(complex_dynamic_ctors), st.sampled_from(complex_static_ctors) + ) + + if nesting == 0: + st_type = st_leaves + else: + st_type = st.one_of(st_complex, st_leaves) + + # filter here is a bit of a kludge, would be better to improve sampling + t = draw(st_type.filter(lambda t: t not in skip)) + + # note: maybe st.deferred is good here, we could define it with + # mutual recursion + def _go(skip=skip): + return draw(vyper_type(nesting=nesting - 1, skip=skip)) + + if t in (BytesT, StringT): + # arbitrary max_value + bound = draw(st.integers(min_value=1, max_value=1024)) + return t(bound) + + if t == SArrayT: + subtype = _go(skip=[TupleT, BytesT, StringT]) + bound = draw(st.integers(min_value=1, max_value=6)) + return t(subtype, bound) + if t == DArrayT: + subtype = _go(skip=[TupleT]) + bound = draw(st.integers(min_value=1, max_value=16)) + return t(subtype, bound) + + if t == TupleT: + # zero-length tuples are not allowed in vyper + n = draw(st.integers(min_value=1, max_value=6)) + subtypes = [_go() for _ in range(n)] + return TupleT(subtypes) + + if t in (BoolT, AddressT): + return t() + + if t == IntegerT: + signed = draw(st.booleans()) + bits = 8 * draw(st.integers(min_value=1, max_value=32)) + return t(signed, bits) + + if t == BytesM_T: + m = draw(st.integers(min_value=1, max_value=32)) + return t(m) + + raise RuntimeError("unreachable") + + +@st.composite +def data_for_type(draw, typ): + def _go(t): + return draw(data_for_type(t)) + + if isinstance(typ, TupleT): + return tuple(_go(item_t) for item_t in typ.member_types) + + if isinstance(typ, SArrayT): + return [_go(typ.value_type) for _ in range(typ.length)] + + if isinstance(typ, DArrayT): + n = draw(st.integers(min_value=0, max_value=typ.length)) + return [_go(typ.value_type) for _ in range(n)] + + if isinstance(typ, StringT): + # technically the ABI spec doesn't say string has to be valid utf-8, + # but eth-stdlib won't encode invalid utf-8 + return draw(st.text(max_size=typ.length)) + + if isinstance(typ, BytesT): + return draw(st.binary(max_size=typ.length)) + + if isinstance(typ, IntegerT): + lo, hi = typ.ast_bounds + return draw(st.integers(min_value=lo, max_value=hi)) + + if isinstance(typ, BytesM_T): + return draw(st.binary(min_size=typ.length, max_size=typ.length)) + + if isinstance(typ, BoolT): + return draw(st.booleans()) + + if isinstance(typ, AddressT): + ret = draw(st.binary(min_size=20, max_size=20)) + return "0x" + ret.hex() + + raise RuntimeError("unreachable") + + +def _sort2(x, y): + if x > y: + return y, x + return x, y + + +@st.composite +def _mutate(draw, payload, max_mutations=MAX_MUTATIONS): + # do point+bulk mutations, + # add/edit/delete/splice/flip up to max_mutations. + if len(payload) == 0: + return + + ret = bytearray(payload) + + # for add/edit, the new byte is any character, but we bias it towards + # bytes already in the payload. + st_any_byte = st.integers(min_value=0, max_value=255) + payload_nonzeroes = list(x for x in payload if x != 0) + if len(payload_nonzeroes) > 0: + st_existing_byte = st.sampled_from(payload) + st_byte = st.one_of(st_existing_byte, st_any_byte) + else: + st_byte = st_any_byte + + # add, edit, delete, word, splice, flip + possible_actions = "adwww" + actions = draw(st.lists(st.sampled_from(possible_actions), max_size=MAX_MUTATIONS)) + + for action in actions: + if len(ret) == 0: + # bail out. could we maybe be smarter, like only add here? + break + + # for the mutation position, we can use any index in the payload, + # but we bias it towards indices of nonzero bytes. + st_any_ix = st.integers(min_value=0, max_value=len(ret) - 1) + nonzero_indexes = [i for i, s in enumerate(ret) if s != 0] + if len(nonzero_indexes) > 0: + st_nonzero_ix = st.sampled_from(nonzero_indexes) + st_ix = st.one_of(st_any_ix, st_nonzero_ix) + else: + st_ix = st_any_ix + + ix = draw(st_ix) + + if action == "a": + ret.insert(ix, draw(st_byte)) + elif action == "e": + ret[ix] = draw(st_byte) + elif action == "d": + ret.pop(ix) + elif action == "w": + # splice word + st_uint256 = st.integers(min_value=0, max_value=2**256 - 1) + + # valid pointers, but maybe *just* out of bounds + st_poison = st.integers(min_value=-2 * len(ret), max_value=2 * len(ret)).map( + lambda x: x % (2**256) + ) + word = draw(st.one_of(st_poison, st_uint256)) + ret[ix - 31 : ix + 1] = word.to_bytes(32) + elif action == "s": + ix2 = draw(st_ix) + ix, ix2 = _sort2(ix, ix2) + ix2 += 1 + # max splice is 64 bytes, due to MAX_BUFFER_SIZE limitation in st.binary + ix2 = ix + (ix2 % 64) + length = ix2 - ix + substr = draw(st.binary(min_size=length, max_size=length)) + ret[ix:ix2] = substr + elif action == "f": + ix2 = draw(st_ix) + ix, ix2 = _sort2(ix, ix2) + ix2 += 1 + for i in range(ix, ix2): + # flip the bits in the byte + ret[i] = 255 ^ ret[i] + else: + raise RuntimeError("unreachable") + + return bytes(ret) + + +@st.composite +def payload_from(draw, typ): + data = draw(data_for_type(typ)) + schema = typ.abi_type.selector_name() + payload = abi.encode(schema, data) + + return draw(_mutate(payload)) + + +_settings = dict( + report_multiple_bugs=False, + # verbosity=hp.Verbosity.verbose, + suppress_health_check=( + hp.HealthCheck.data_too_large, + hp.HealthCheck.too_slow, + hp.HealthCheck.large_base_example, + ), + phases=( + hp.Phase.explicit, + hp.Phase.reuse, + hp.Phase.generate, + hp.Phase.target, + # Phase.shrink, # can force long waiting for examples + # Phase.explain, # not helpful here + ), +) + + +@dataclass(frozen=True) +class _TypeStats: + nesting: int = 0 + num_dynamic_types: int = 0 # number of dynamic types in the type + breadth: int = 0 # e.g. int16[50] has higher breadth than int16[1] + width: int = 0 # size of type + + +def _type_stats(typ: VyperType) -> _TypeStats: + def _finalize(): # little trick to save re-typing the arguments + width = typ.memory_bytes_required + return _TypeStats( + nesting=nesting, num_dynamic_types=num_dynamic_types, breadth=breadth, width=width + ) + + if typ._is_prim_word: + nesting = 0 + breadth = 1 + num_dynamic_types = 0 + return _finalize() + + if isinstance(typ, (BytesT, StringT)): + nesting = 0 + breadth = 1 # idk + num_dynamic_types = 1 + return _finalize() + + if isinstance(typ, TupleT): + substats = [_type_stats(t) for t in typ.member_types] + nesting = 1 + max(s.nesting for s in substats) + breadth = max(typ.length, *[s.breadth for s in substats]) + num_dynamic_types = sum(s.num_dynamic_types for s in substats) + return _finalize() + + if isinstance(typ, DArrayT): + substat = _type_stats(typ.value_type) + nesting = 1 + substat.nesting + breadth = max(typ.count, substat.breadth) + num_dynamic_types = 1 + substat.num_dynamic_types + return _finalize() + + if isinstance(typ, SArrayT): + substat = _type_stats(typ.value_type) + nesting = 1 + substat.nesting + breadth = max(typ.count, substat.breadth) + num_dynamic_types = substat.num_dynamic_types + return _finalize() + + raise RuntimeError("unreachable") + + +@pytest.fixture(scope="module") +def payload_copier(get_contract_from_ir): + # some contract which will return the buffer passed to it + # note: hardcode the location of the bytestring + ir = [ + "with", + "length", + ["calldataload", 36], + ["seq", ["calldatacopy", 0, 68, "length"], ["return", 0, "length"]], + ] + return get_contract_from_ir(["deploy", 0, ir, 0]) + + +PARALLELISM = 1 # increase on fuzzer box + + +# NOTE: this is a heavy test. 100 types * 100 payloads per type can take +# 3-4minutes on a regular CPU core. +@pytest.mark.parametrize("_n", list(range(PARALLELISM))) +@hp.given(typ=vyper_type()) +@hp.settings(max_examples=100, **_settings) +@hp.example(typ=DArrayT(DArrayT(UINT256_T, 2), 2)) +def test_abi_decode_fuzz(_n, typ, get_contract, tx_failed, payload_copier): + # import time + # t0 = time.time() + # print("ENTER", typ) + + wrapped_type = calculate_type_for_external_return(typ) + + stats = _type_stats(typ) + # for k, v in asdict(stats).items(): + # event(k, v) + hp.target(stats.num_dynamic_types) + # hp.target(typ.abi_type.is_dynamic() + typ.abi_type.is_complex_type())) + + # add max_mutations bytes worth of padding so we don't just get caught + # by bytes length check at function entry + type_bound = wrapped_type.abi_type.size_bound() + buffer_bound = type_bound + MAX_MUTATIONS + type_str = repr(typ) # annotation in vyper code + # TODO: intrinsic decode from staticcall/extcall + # TODO: _abi_decode from other sources (staticcall/extcall?) + # TODO: dirty the buffer + # TODO: check unwrap_tuple=False + code = f""" +@external +def run(xs: Bytes[{buffer_bound}]) -> {type_str}: + ret: {type_str} = abi_decode(xs, {type_str}) + return ret + +interface Foo: + def foo(xs: Bytes[{buffer_bound}]) -> {type_str}: view # STATICCALL + def bar(xs: Bytes[{buffer_bound}]) -> {type_str}: nonpayable # CALL + +@external +def run2(xs: Bytes[{buffer_bound}], copier: Foo) -> {type_str}: + assert len(xs) <= {type_bound} + return staticcall copier.foo(xs) + +@external +def run3(xs: Bytes[{buffer_bound}], copier: Foo) -> {type_str}: + assert len(xs) <= {type_bound} + return (extcall copier.bar(xs)) + """ + c = get_contract(code) + + @hp.given(data=payload_from(wrapped_type)) + @hp.settings(max_examples=100, **_settings) + def _fuzz(data): + hp.note(f"type: {typ}") + hp.note(f"abi_t: {wrapped_type.abi_type.selector_name()}") + hp.note(code) + hp.note(data.hex()) + + try: + expected = spec_decode(wrapped_type, data) + + # unwrap if necessary + if needs_external_call_wrap(typ): + assert isinstance(expected, tuple) + (expected,) = expected + + hp.note(f"expected {expected}") + assert expected == c.run(data) + assert expected == c.run2(data, payload_copier.address) + assert expected == c.run3(data, payload_copier.address) + + except DecodeError: + # note EvmError includes reverts *and* exceptional halts. + # we can get OOG during abi decoding due to how + # `_abi_payload_size()` works + hp.note("expect failure") + with tx_failed(EvmError): + c.run(data) + with tx_failed(EvmError): + c.run2(data, payload_copier.address) + with tx_failed(EvmError): + c.run3(data, payload_copier.address) + + _fuzz() + + # t1 = time.time() + # print(f"elapsed {t1 - t0}s") diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index ff0f801d74..9a0a08097c 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -1169,8 +1169,12 @@ def clamp_bytestring(ir_node, hi=None): if hi is not None: assert t.maxlen < 2**64 # sanity check - # note: this add does not risk arithmetic overflow because + # NOTE: this add does not risk arithmetic overflow because # length is bounded by maxlen. + # however(!) _abi_payload_size can OOG, since it loads the word + # at `ir_node` to find the length of the bytearray, which could + # be out-of-bounds. + # if we didn't get OOG, we could overflow in `add`. item_end = add_ofst(ir_node, _abi_payload_size(ir_node)) len_check = ["seq", ["assert", ["le", item_end, hi]], len_check] @@ -1189,8 +1193,12 @@ def clamp_dyn_array(ir_node, hi=None): if hi is not None: assert t.count < 2**64 # sanity check - # note: this add does not risk arithmetic overflow because + # NOTE: this add does not risk arithmetic overflow because # length is bounded by count * elemsize. + # however(!) _abi_payload_size can OOG, since it loads the word + # at `ir_node` to find the length of the bytearray, which could + # be out-of-bounds. + # if we didn't get OOG, we could overflow in `add`. item_end = add_ofst(ir_node, _abi_payload_size(ir_node)) # if the subtype is dynamic, the length check is performed in diff --git a/vyper/semantics/types/subscriptable.py b/vyper/semantics/types/subscriptable.py index c392ff48b1..4068d815d2 100644 --- a/vyper/semantics/types/subscriptable.py +++ b/vyper/semantics/types/subscriptable.py @@ -334,7 +334,10 @@ def __init__(self, member_types: Tuple[VyperType, ...]) -> None: self.key_type = UINT256_T # API Compatibility def __repr__(self): - return "(" + ", ".join(repr(t) for t in self.member_types) + ")" + if len(self.member_types) == 1: + (t,) = self.member_types + return f"({t},)" + return "(" + ", ".join(f"{t}" for t in self.member_types) + ")" @property def length(self):