Skip to content

Commit

Permalink
perf: flake8 type checking
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey committed Oct 29, 2024
1 parent b064700 commit 28861cc
Show file tree
Hide file tree
Showing 9 changed files with 61 additions and 40 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ repos:
rev: 7.1.1
hooks:
- id: flake8
additional_dependencies: [flake8-breakpoint, flake8-print, flake8-pydantic, flake8-type-checking]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.13.0
Expand Down
16 changes: 9 additions & 7 deletions ape_safe/accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@
import os
from collections.abc import Iterable, Iterator, Mapping
from pathlib import Path
from typing import Any, Dict, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Dict, Optional, Union, cast

from ape.api import AccountAPI, AccountContainerAPI, ReceiptAPI, TransactionAPI
from ape.api.address import BaseAddress
from ape.api.networks import ForkedNetworkAPI
from ape.cli import select_account
from ape.contracts import ContractInstance
from ape.exceptions import ContractNotFoundError, ProviderNotConnectedError
from ape.logging import logger
from ape.managers.accounts import AccountManager, TestAccountManager
Expand Down Expand Up @@ -37,6 +35,10 @@
)
from ape_safe.utils import get_safe_tx_hash, order_by_signer

if TYPE_CHECKING:
from ape.api.address import BaseAddress
from ape.contracts import ContractInstance


class SafeContainer(AccountContainerAPI):
_accounts: Dict[str, "SafeAccount"] = {}
Expand Down Expand Up @@ -215,7 +217,7 @@ def address(self) -> AddressType:
return ecosystem.decode_address(self.account_file["address"])

@cached_property
def contract(self) -> ContractInstance:
def contract(self) -> "ContractInstance":
safe_contract = self.chain_manager.contracts.instance_at(self.address)
if self.fallback_handler:
contract_signatures = {x.signature for x in safe_contract.contract_type.abi}
Expand All @@ -236,7 +238,7 @@ def contract(self) -> ContractInstance:
return safe_contract

@cached_property
def fallback_handler(self) -> Optional[ContractInstance]:
def fallback_handler(self) -> Optional["ContractInstance"]:
slot = keccak(text="fallback_manager.handler.address")
value = self.provider.get_storage(self.address, slot)
address = self.network_manager.ecosystem.decode_address(value[-20:])
Expand Down Expand Up @@ -408,7 +410,7 @@ def create_execute_transaction(
*exec_args, encoded_signatures, **txn_options
)

def compute_prev_signer(self, signer: Union[str, AddressType, BaseAddress]) -> AddressType:
def compute_prev_signer(self, signer: Union[str, AddressType, "BaseAddress"]) -> AddressType:
"""
Sometimes it's handy to have "previous owner" for ownership change operations,
this function makes it easy to calculate.
Expand Down Expand Up @@ -465,7 +467,7 @@ def estimate_gas_cost(self, **kwargs) -> int:
)

def _preapproved_signature(
self, signer: Union[AddressType, BaseAddress, str]
self, signer: Union[AddressType, "BaseAddress", str]
) -> MessageSignature:
# Get the Safe-style "preapproval" signature type, which is a sentinel value used to denote
# when a signer approved via some other method, such as `approveHash` or being `msg.sender`
Expand Down
22 changes: 12 additions & 10 deletions ape_safe/client/base.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from abc import ABC, abstractmethod
from collections.abc import Iterator
from functools import cached_property
from typing import Optional, Union
from typing import TYPE_CHECKING, Optional, Union

import certifi
import requests
import urllib3
from ape.types import AddressType, MessageSignature
from requests import Response
from requests.adapters import HTTPAdapter

from ape_safe.client.types import (
Expand All @@ -21,6 +19,10 @@
)
from ape_safe.exceptions import ClientResponseError

if TYPE_CHECKING:
from ape.types import AddressType, MessageSignature
from requests import Response

DEFAULT_HEADERS = {
"Accept": "application/json",
"Content-Type": "application/json",
Expand Down Expand Up @@ -48,19 +50,19 @@ def get_confirmations(self, safe_tx_hash: SafeTxID) -> Iterator[SafeTxConfirmati

@abstractmethod
def post_transaction(
self, safe_tx: SafeTx, signatures: dict[AddressType, MessageSignature], **kwargs
self, safe_tx: SafeTx, signatures: dict["AddressType", "MessageSignature"], **kwargs
): ...

@abstractmethod
def post_signatures(
self,
safe_tx_or_hash: Union[SafeTx, SafeTxID],
signatures: dict[AddressType, MessageSignature],
signatures: dict["AddressType", "MessageSignature"],
): ...

@abstractmethod
def estimate_gas_cost(
self, receiver: AddressType, value: int, data: bytes, operation: int = 0
self, receiver: "AddressType", value: int, data: bytes, operation: int = 0
) -> Optional[int]: ...

"""Shared methods"""
Expand All @@ -71,7 +73,7 @@ def get_transactions(
starting_nonce: int = 0,
ending_nonce: Optional[int] = None,
filter_by_ids: Optional[set[SafeTxID]] = None,
filter_by_missing_signers: Optional[set[AddressType]] = None,
filter_by_missing_signers: Optional[set["AddressType"]] = None,
) -> Iterator[SafeApiTxData]:
"""
confirmed: Confirmed if True, not confirmed if False, both if None
Expand Down Expand Up @@ -126,17 +128,17 @@ def session(self) -> requests.Session:
session.mount("https://", adapter)
return session

def _get(self, url: str) -> Response:
def _get(self, url: str) -> "Response":
return self._request("GET", url)

def _post(self, url: str, json: Optional[dict] = None, **kwargs) -> Response:
def _post(self, url: str, json: Optional[dict] = None, **kwargs) -> "Response":
return self._request("POST", url, json=json, **kwargs)

@cached_property
def _http(self):
return urllib3.PoolManager(ca_certs=certifi.where())

def _request(self, method: str, url: str, json: Optional[dict] = None, **kwargs) -> Response:
def _request(self, method: str, url: str, json: Optional[dict] = None, **kwargs) -> "Response":
# NOTE: paged requests include full url already
if url.startswith(f"{self.transaction_service_url}/api/v1/"):
api_url = url
Expand Down
20 changes: 11 additions & 9 deletions ape_safe/client/mock.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from collections.abc import Iterator
from datetime import datetime, timezone
from typing import Optional, Union, cast
from typing import TYPE_CHECKING, Optional, Union, cast

from ape.contracts import ContractInstance
from ape.types import AddressType, MessageSignature
from ape.utils import ZERO_ADDRESS, ManagerAccessMixin
from eth_utils import keccak
from hexbytes import HexBytes
Expand All @@ -20,9 +18,13 @@
)
from ape_safe.utils import get_safe_tx_hash

if TYPE_CHECKING:
from ape.contracts import ContractInstance
from ape.types import AddressType, MessageSignature


class MockSafeClient(BaseSafeClient, ManagerAccessMixin):
def __init__(self, contract: ContractInstance):
def __init__(self, contract: "ContractInstance"):
self.contract = contract
self.transactions: dict[SafeTxID, SafeApiTxData] = {}
self.transactions_by_nonce: dict[int, list[SafeTxID]] = {}
Expand All @@ -47,13 +49,13 @@ def safe_details(self) -> SafeDetails:
)

@property
def guard(self) -> AddressType:
def guard(self) -> "AddressType":
return (
self.contract.getGuard() if "getGuard" in self.contract._view_methods_ else ZERO_ADDRESS
)

@property
def modules(self) -> list[AddressType]:
def modules(self) -> list["AddressType"]:
return self.contract.getModules() if "getModules" in self.contract._view_methods_ else []

def get_next_nonce(self) -> int:
Expand All @@ -73,7 +75,7 @@ def get_confirmations(self, safe_tx_hash: SafeTxID) -> Iterator[SafeTxConfirmati
yield from safe_tx_data.confirmations

def post_transaction(
self, safe_tx: SafeTx, signatures: dict[AddressType, MessageSignature], **kwargs
self, safe_tx: SafeTx, signatures: dict["AddressType", "MessageSignature"], **kwargs
):
safe_tx_data = UnexecutedTxData.from_safe_tx(safe_tx, self.safe_details.threshold)
safe_tx_data.confirmations.extend(
Expand All @@ -95,7 +97,7 @@ def post_transaction(
def post_signatures(
self,
safe_tx_or_hash: Union[SafeTx, SafeTxID],
signatures: dict[AddressType, MessageSignature],
signatures: dict["AddressType", "MessageSignature"],
):
for signer, signature in signatures.items():
safe_tx_id = (
Expand All @@ -114,6 +116,6 @@ def post_signatures(
)

def estimate_gas_cost(
self, receiver: AddressType, value: int, data: bytes, operation: int = 0
self, receiver: "AddressType", value: int, data: bytes, operation: int = 0
) -> Optional[int]:
return None # Estimate gas normally
14 changes: 8 additions & 6 deletions ape_safe/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from contextlib import ContextDecorator
from typing import Optional
from typing import TYPE_CHECKING, Optional

from ape.exceptions import AccountsError, ApeException, ContractLogicError, SignatureError
from ape.types import AddressType
from requests import Response

if TYPE_CHECKING:
from ape.types import AddressType
from requests import Response


class ApeSafeException(ApeException):
Expand All @@ -17,7 +19,7 @@ class ApeSafeError(ApeSafeException, AccountsError):


class NotASigner(ApeSafeException):
def __init__(self, signer: AddressType):
def __init__(self, signer: "AddressType"):
super().__init__(f"{signer} is not a valid signer.")


Expand Down Expand Up @@ -122,14 +124,14 @@ def __init__(self, message: str):


class ClientResponseError(SafeClientException):
def __init__(self, endpoint_url: str, response: Response, message: Optional[str] = None):
def __init__(self, endpoint_url: str, response: "Response", message: Optional[str] = None):
self.endpoint_url = endpoint_url
self.response = response
message = message or f"Exception when calling '{endpoint_url}':\n{response.text}"
super().__init__(message)


class MultisigTransactionNotFoundError(ClientResponseError):
def __init__(self, tx_hash: str, endpoint_url: str, response: Response):
def __init__(self, tx_hash: str, endpoint_url: str, response: "Response"):
message = f"Multisig transaction '{tx_hash}' not found."
super().__init__(endpoint_url, response, message=message)
15 changes: 9 additions & 6 deletions ape_safe/multisend.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
from importlib.resources import files
from io import BytesIO
from typing import TYPE_CHECKING

from ape import convert
from ape.api import ReceiptAPI, TransactionAPI
from ape.contracts.base import ContractInstance, ContractTransactionHandler
from ape.types import AddressType, HexBytes
from ape.utils import ManagerAccessMixin, cached_property
from eth_abi.packed import encode_packed
from ethpm_types import PackageManifest

from ape_safe.exceptions import UnsupportedChainError, ValueRequired

if TYPE_CHECKING:
from ape.api import ReceiptAPI, TransactionAPI
from ape.contracts.base import ContractInstance, ContractTransactionHandler

MULTISEND_CALL_ONLY_ADDRESSES = (
"0x40A2aCCbd92BCA938b02010E17A5b8929b49130D", # MultiSend Call Only v1.3.0
"0xA1dabEF33b3B82c7814B6D82A79e50F4AC44102B", # MultiSend Call Only v1.3.0 (EIP-155)
Expand Down Expand Up @@ -80,7 +83,7 @@ def multisend():
)

@cached_property
def contract(self) -> ContractInstance:
def contract(self) -> "ContractInstance":
for address in MULTISEND_CALL_ONLY_ADDRESSES:
if self.provider.get_code(address) == MULTISEND_CALL_ONLY.get_runtime_bytecode():
return self.chain_manager.contracts.instance_at(
Expand All @@ -90,7 +93,7 @@ def contract(self) -> ContractInstance:
raise UnsupportedChainError()

@property
def handler(self) -> ContractTransactionHandler:
def handler(self) -> "ContractTransactionHandler":
return self.contract.multiSend

def add(
Expand Down Expand Up @@ -149,7 +152,7 @@ def encoded_calls(self):
for call in self.calls
]

def __call__(self, **txn_kwargs) -> ReceiptAPI:
def __call__(self, **txn_kwargs) -> "ReceiptAPI":
"""
Execute the MultiSend transaction. The transaction will broadcast again every time
the ``Transaction`` object is called.
Expand All @@ -169,7 +172,7 @@ def __call__(self, **txn_kwargs) -> ReceiptAPI:
txn_kwargs["operation"] = 1
return self.handler(b"".join(self.encoded_calls), **txn_kwargs)

def as_transaction(self, **txn_kwargs) -> TransactionAPI:
def as_transaction(self, **txn_kwargs) -> "TransactionAPI":
"""
Encode the MultiSend transaction as a ``TransactionAPI`` object, but do not execute it.
Expand Down
7 changes: 5 additions & 2 deletions ape_safe/utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from collections.abc import Mapping
from typing import TYPE_CHECKING, cast

from ape.types import AddressType, MessageSignature
from eip712.messages import calculate_hash
from eth_utils import to_int

if TYPE_CHECKING:
from ape.types import AddressType, MessageSignature

from ape_safe.client.types import SafeTxID


def order_by_signer(signatures: Mapping[AddressType, MessageSignature]) -> list[MessageSignature]:
def order_by_signer(
signatures: Mapping["AddressType", "MessageSignature"]
) -> list["MessageSignature"]:
# NOTE: Must order signatures in ascending order of signer address (converted to int)
return list(signatures[signer] for signer in sorted(signatures, key=lambda a: to_int(hexstr=a)))

Expand Down
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
[flake8]
max-line-length = 100
ignore = E704,W503,PYD002,TC003,TC006
exclude =
venv*
.eggs
docs
build
type-checking-pydantic-enabled = True
4 changes: 4 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
"types-requests", # Needed for mypy type shed
"types-setuptools", # Needed for mypy type shed
"flake8>=7.1.1,<8", # Style linter
"flake8-breakpoint>=1.1.0,<2", # Detect breakpoints left in code
"flake8-print>=4.0.1,<5", # Detect print statements left in code
"flake8-pydantic", # For detecting issues with Pydantic models
"flake8-type-checking", # Detect imports to move in/out of type-checking blocks
"isort>=5.13.2,<6", # Import sorting linter
"mdformat>=0.7.18,<0.8", # Docs formatter and linter
"mdformat-pyproject>=0.0.1", # Allows configuring in pyproject.toml
Expand Down

0 comments on commit 28861cc

Please sign in to comment.