diff --git a/ape_alchemy/provider.py b/ape_alchemy/provider.py index a1a78df..4149d3c 100644 --- a/ape_alchemy/provider.py +++ b/ape_alchemy/provider.py @@ -18,8 +18,11 @@ from requests import HTTPError from web3 import HTTPProvider, Web3 from web3.exceptions import ContractLogicError as Web3ContractLogicError +from web3.exceptions import ExtraDataLengthError from web3.gas_strategies.rpc import rpc_gas_price_strategy from web3.middleware import geth_poa_middleware +from web3.middleware import geth_poa_middleware as ExtraDataToPOAMiddleware +from web3.middleware.validation import MAX_EXTRADATA_LENGTH from web3.types import RPCEndpoint from .exceptions import AlchemyFeatureNotAvailable, AlchemyProviderError, MissingProjectKeyError @@ -119,19 +122,44 @@ def connection_str(self) -> str: def connect(self): self._web3 = Web3(HTTPProvider(self.uri)) + is_poa = None try: # Any chain that *began* as PoA needs the middleware for pre-merge blocks base = 8453 optimism = 10 polygon = 137 + polygon_amoy = 80002 - if self._web3.eth.chain_id in (base, optimism, polygon): + if self._web3.eth.chain_id in (base, optimism, polygon, polygon_amoy): self._web3.middleware_onion.inject(geth_poa_middleware, layer=0) + is_poa = True self._web3.eth.set_gas_price_strategy(rpc_gas_price_strategy) except Exception as err: raise ProviderError(f"Failed to connect to Alchemy.\n{repr(err)}") from err + if is_poa is None: + # Check if is PoA but just wasn't as such yet. + # NOTE: We have to check both earliest and latest + # because if the chain was _ever_ PoA, we need + # this middleware. + for option in ("earliest", "latest"): + try: + block = self.web3.eth.get_block(option) # type: ignore[arg-type] + except ExtraDataLengthError: + is_poa = True + break + else: + is_poa = ( + "proofOfAuthorityData" in block + or len(block.get("extraData", "")) > MAX_EXTRADATA_LENGTH + ) + if is_poa: + break + + if is_poa and ExtraDataToPOAMiddleware not in self.web3.middleware_onion: + self.web3.middleware_onion.inject(ExtraDataToPOAMiddleware, layer=0) + def disconnect(self): self._web3 = None diff --git a/tests/test_integration.py b/tests/test_integration.py index b5a4797..5ac0f3b 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -21,6 +21,7 @@ def test_http(provider): assert provider.http_uri.startswith("https") assert provider.get_balance(ZERO_ADDRESS) > 0 assert provider.get_block(0) + assert provider.get_block("latest") def test_ws(provider):