Skip to content

Commit

Permalink
fix: signature
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey committed Dec 4, 2023
1 parent cc0aeb6 commit 7575821
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 17 deletions.
25 changes: 22 additions & 3 deletions ape_safe/accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
SafeClientException,
handle_safe_logic_error,
)
from ape_safe.utils import get_safe_tx_hash, order_by_signer
from ape_safe.utils import get_safe_tx_hash, hash_message, order_by_signer


class SafeContainer(AccountContainerAPI):
Expand Down Expand Up @@ -678,9 +678,11 @@ def add_signatures(
safe_tx_hash = _get_safe_tx_id(safe_tx, confirmations)

for signer in signers:
signature = signer.sign_message(safe_tx.signable_message)
data_hash = hash_message(safe_tx_hash)
signature = signer.sign_message(data_hash) # type: ignore
if signature:
signatures[signer.address] = signature
signature_adjusted = adjust_v_in_signature(signature)
signatures[signer.address] = signature_adjusted

if signatures:
self.client.post_signatures(safe_tx_hash, signatures)
Expand All @@ -696,3 +698,20 @@ def _get_safe_tx_id(safe_tx: SafeTx, confirmations: List[SafeTxConfirmation]) ->
return value

raise ApeSafeError("Failed to get transaction hash.")


def adjust_v_in_signature(signature: MessageSignature) -> MessageSignature:
MIN_VALID_V_VALUE_FOR_SAFE_ECDSA = 27
v = signature.v

if v < MIN_VALID_V_VALUE_FOR_SAFE_ECDSA:
v += MIN_VALID_V_VALUE_FOR_SAFE_ECDSA

# Add 4 because we signed with the prefix.
v += 4

return MessageSignature(
v=v,
r=signature.r,
s=signature.s,
)
9 changes: 3 additions & 6 deletions ape_safe/client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from datetime import datetime
from functools import reduce
from typing import Dict, Iterator, Optional, Union, cast
Expand Down Expand Up @@ -142,13 +143,9 @@ def post_signatures(
safe_tx_hash = cast(SafeTxID, HexBytes(safe_tx_hash).hex())
url = f"multisig-transactions/{safe_tx_hash}/confirmations"
signature = HexBytes(b"".join([x.encode_rsv() for x in order_by_signer(signatures)])).hex()

# from gnosis.safe.safe_signature import SafeSignature
# parsed_signatures = SafeSignature.parse_signature(signature, safe_tx_hash)
# breakpoint()

data_str = json.dumps({"signature": signature})
try:
self._post(url, json={"signature": signature})
self._post(url, data=data_str)
except ClientResponseError as err:
if "The requested resource was not found on this server" in err.response.text:
raise MultisigTransactionNotFoundError(safe_tx_hash, url, err.response) from err
Expand Down
21 changes: 14 additions & 7 deletions ape_safe/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import requests
from ape.types import AddressType, MessageSignature
from requests import Response
from requests.adapters import HTTPAdapter

from ape_safe.client.types import (
ExecutedTxData,
Expand All @@ -17,6 +18,8 @@
)
from ape_safe.exceptions import ActionNotPerformedError, ClientResponseError

# from requests import Response


class BaseSafeClient(ABC):
def __init__(self, transaction_service_url: str):
Expand Down Expand Up @@ -100,7 +103,7 @@ def get_transactions(
@cached_property
def session(self) -> requests.Session:
session = requests.Session()
adapter = requests.adapters.HTTPAdapter(
adapter = HTTPAdapter(
pool_connections=10, # Doing all the connections to the same url
pool_maxsize=100, # Number of concurrent connections
pool_block=False,
Expand All @@ -113,23 +116,27 @@ def _get(self, url: str) -> Response:
return self._request("GET", url)

def _post(self, url: str, json: Optional[Dict] = None, **kwargs) -> Response:
json = json or {}
if "origin" not in json and isinstance(json, dict):
if json is not None and "origin" not in json and isinstance(json, dict):
json["origin"] = "ApeWorX/ape-safe"

if "headers" not in kwargs:
kwargs["headers"] = {"Content-type": "application/json"}

return self._request("POST", url, json=json, **kwargs)

def _request(self, method: str, url: str, json: Optional[Dict] = None, **kwargs) -> Response:
api_url = f"{self.transaction_service_url}/api/v1/{url}"
do_fail = not kwargs.pop("allow_failure", False)

if "timeout" not in kwargs:
kwargs["timeout"] = 10

headers = kwargs.get("headers", {})

# Add default headers
default_headers = {
"Accept": "application/json",
"Content-Type": "application/json",
}
kwargs["headers"] = {**default_headers, **headers}
response = self.session.request(method, api_url, json=json, **kwargs)
do_fail = not kwargs.get("allow_failure", False)

if method != response.request.method and do_fail:
# Handle weird Safe API behavior where it doesn't do the right thing.
Expand Down
65 changes: 64 additions & 1 deletion ape_safe/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import TYPE_CHECKING, List, Mapping, cast

from ape.types import AddressType, MessageSignature
from eth_utils import keccak, to_int
from eth_typing import HexStr
from eth_utils import add_0x_prefix, keccak, to_int
from hexbytes import HexBytes

if TYPE_CHECKING:
Expand All @@ -20,3 +21,65 @@ def get_safe_tx_hash(safe_tx) -> "SafeTxID":
return cast(
"SafeTxID", HexBytes(keccak(b"".join([bytes.fromhex("19"), *safe_tx.signable_message])))
)


def to_int_array(value) -> List[int]:
value_hex = HexBytes(value).hex()
value_int = int(value_hex, 16)

result: List[int] = []
while value_int:
result.insert(0, value_int & 0xFF)
value_int = value_int // 256

if len(result) == 0:
result.append(0)

return result


def to_utf8_bytes(value: str) -> List[int]:
result = []
i = 0
while i < len(value):
c = ord(value[i])

if c < 0x80:
result.append(c)

elif c < 0x800:
result.append((c >> 6) | 0xC0)
result.append((c & 0x3F) | 0x80)

elif 0xD800 <= c <= 0xDBFF:
i += 1
c2 = ord(value[i])

if i >= len(value) or not (0xDC00 <= c2 <= 0xDFFF):
raise ValueError("Invalid UTF-8 string")

# Surrogate Pair
pair = 0x10000 + ((c & 0x03FF) << 10) + (c2 & 0x03FF)
result.append((pair >> 18) | 0xF0)
result.append(((pair >> 12) & 0x3F) | 0x80)
result.append(((pair >> 6) & 0x3F) | 0x80)
result.append((pair & 0x3F) | 0x80)

else:
result.append((c >> 12) | 0xE0)
result.append(((c >> 6) & 0x3F) | 0x80)
result.append((c & 0x3F) | 0x80)

i += 1

return result


def hash_message(message: str) -> str:
message_array = to_int_array(message)
message_prefix = "\x19Ethereum Signed Message:\n"
prefix_bytes = to_utf8_bytes(message_prefix)
length_bytes = to_utf8_bytes(f"{len(message_array)}")
full_array = prefix_bytes + length_bytes + message_array
result = keccak(bytearray(full_array)).hex()
return add_0x_prefix(cast(HexStr, result))

0 comments on commit 7575821

Please sign in to comment.