Skip to content

Commit

Permalink
feat: hexstr
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey committed Oct 11, 2023
1 parent c4087f9 commit 6c9b758
Show file tree
Hide file tree
Showing 11 changed files with 255 additions and 101 deletions.
14 changes: 12 additions & 2 deletions eth_pydantic_types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
from .address import Address
from .hash import Hash4, Hash8, Hash16, Hash20, Hash32, Hash64
from .hexbytes import HexBytes
from .hex import HexBytes, HexStr

__all__ = ["Address", "Hash4", "Hash8", "Hash16", "Hash20", "Hash32", "Hash64", "HexBytes"]
__all__ = [
"Address",
"Hash4",
"Hash8",
"Hash16",
"Hash20",
"Hash32",
"Hash64",
"HexBytes",
"HexStr",
]
18 changes: 18 additions & 0 deletions eth_pydantic_types/_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import Any, Callable

from pydantic_core import PydanticCustomError

# NOTE: We use the factory approach because PydanticCustomError is a final class.
# That is also why this module is internal.


def EthPydanticTypesException(fn: Callable, invalid_tag: str, **kwargs):
return PydanticCustomError(fn.__name__, f"Invalid {invalid_tag}", kwargs)


def HexValueError(value: Any):
return EthPydanticTypesException(HexValueError, "hex value", value=value)


def SizeError(size: Any, value: Any):
return EthPydanticTypesException(SizeError, "size of value", value=value)
26 changes: 16 additions & 10 deletions eth_pydantic_types/address.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Optional
from typing import Any, ClassVar, Optional, Tuple

from eth_utils import is_checksum_address, to_checksum_address
from pydantic_core import CoreSchema
Expand All @@ -8,20 +8,32 @@
with_info_before_validator_function,
)

from eth_pydantic_types.hexbytes import HexBytes
from eth_pydantic_types.hex import BaseHexStr, HexBytes
from eth_pydantic_types.validators import validate_address_size

ADDRESS_PATTERN = "^0x[a-fA-F0-9]{40}$"

class Address(str):

def address_schema():
return str_schema(min_length=42, max_length=42, pattern=ADDRESS_PATTERN)


class Address(BaseHexStr):
"""
Use for address-types. Validates as a checksummed address. Left-pads zeroes
if necessary.
"""

_SCHEMA_PATTERN: ClassVar[str] = ADDRESS_PATTERN
_SCHEMA_EXAMPLES: ClassVar[Tuple[str, ...]] = (
"0x0000000000000000000000000000000000000000", # empty address
"0x1e59ce931B4CFea3fe4B875411e280e173cB7A9C",
)

def __get_pydantic_core_schema__(self, *args, **kwargs) -> CoreSchema:
schema = with_info_before_validator_function(
self._validate_address,
str_schema(min_length=42, max_length=42, pattern="^0x[a-fA-F0-9]{40}$"),
address_schema(),
)
return schema

Expand All @@ -37,9 +49,3 @@ def _validate_address(cls, value: Any, info: Optional[ValidationInfo] = None) ->
number_padded = validate_address_size(number, 40)
value = f"0x{number_padded}"
return to_checksum_address(value)

def __int__(self) -> int:
return int(self, 16)

def __bytes__(self) -> HexBytes:
return HexBytes(self)
2 changes: 1 addition & 1 deletion eth_pydantic_types/hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
with_info_before_validator_function,
)

from eth_pydantic_types.hexbytes import HexBytes
from eth_pydantic_types.hex import HexBytes
from eth_pydantic_types.serializers import hex_serializer
from eth_pydantic_types.validators import validate_bytes_size

Expand Down
99 changes: 99 additions & 0 deletions eth_pydantic_types/hex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from typing import Any, ClassVar, Optional, Tuple, Union

from hexbytes import HexBytes as BaseHexBytes
from pydantic_core import CoreSchema
from pydantic_core.core_schema import (
ValidationInfo,
bytes_schema,
no_info_before_validator_function,
str_schema,
with_info_before_validator_function,
)

from eth_pydantic_types._error import HexValueError
from eth_pydantic_types.serializers import hex_serializer


class HexBytes(BaseHexBytes):
"""
Use when receiving ``hexbytes.HexBytes`` values. Includes
a pydantic validator and serializer.
"""

def __get_pydantic_core_schema__(self, *args, **kwargs) -> CoreSchema:
schema = with_info_before_validator_function(self._validate_hexbytes, bytes_schema())
schema["serialization"] = hex_serializer
return schema

@classmethod
def fromhex(cls, hex_str: str) -> "HexBytes":
value = hex_str[2:] if hex_str.startswith("0x") else hex_str
return super().fromhex(value)

@classmethod
def _validate_hexbytes(cls, value: Any, info: Optional[ValidationInfo] = None) -> BaseHexBytes:
return BaseHexBytes(value)


class BaseHexStr(str):
_SCHEMA_PATTERN: ClassVar[str] = "^0x([0-9a-f][0-9a-f])*$"
_SCHEMA_EXAMPLES: ClassVar[Tuple[str, ...]] = (
"0x", # empty bytes
"0xd4",
"0xd4e5",
"0xd4e56740",
"0xd4e56740f876aef8",
"0xd4e56740f876aef8c010b86a40d5f567",
"0xd4e56740f876aef8c010b86a40d5f56745a118d0906a34e69aec8c0db1cb8fa3",
)

@classmethod
def __get_pydantic_json_schema__(cls, core_schema, handler):
json_schema = handler(core_schema)
json_schema.update(
format="binary", pattern=cls._SCHEMA_PATTERN, examples=list(cls._SCHEMA_EXAMPLES)
)
return json_schema

@classmethod
def from_bytes(cls, data: bytes) -> "HexStr":
hex_str = data.hex()
return cls(hex_str if hex_str.startswith("0x") else hex_str)

def __int__(self) -> int:
return int(self, 16)

def __bytes__(self) -> bytes:
return bytes.fromhex(self[2:])


class HexStr(BaseHexStr):
"""A hex string value, typically from a hash."""

def __get_pydantic_core_schema__(cls, *args, **kwargs):
return no_info_before_validator_function(cls.validate_hex, str_schema())

@classmethod
def validate_hex(cls, data: Union[bytes, str, int]):
if isinstance(data, bytes):
return cls.from_bytes(data)

elif isinstance(data, str):
return cls._validate_hex_str(data)

elif isinstance(data, int):
return BaseHexBytes(data).hex()

raise HexValueError(data)

@classmethod
def _validate_hex_str(cls, data: str) -> str:
hex_value = (data[2:] if data.startswith("0x") else data).lower()
if set(hex_value) - set("1234567890abcdef"):
raise HexValueError(data)

# Missing zero padding.
if len(hex_value) % 2 != 0:
hex_value = f"0{hex_value}"

return f"0x{hex_value}"
32 changes: 0 additions & 32 deletions eth_pydantic_types/hexbytes.py

This file was deleted.

5 changes: 3 additions & 2 deletions eth_pydantic_types/validators.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Any, Callable, Dict, Optional, Sized, TypeVar, cast

from pydantic import WithJsonSchema
from pydantic_core import PydanticCustomError
from pydantic_core.core_schema import bytes_schema

from eth_pydantic_types._error import SizeError

__SIZED_T = TypeVar("__SIZED_T", bound=Sized)


Expand All @@ -21,7 +22,7 @@ def validate_size(value: __SIZED_T, size: int, coerce: Optional[Callable] = None
elif coerce:
return validate_size(coerce(value), size)

raise PydanticCustomError("value_size", "Invalid size of value", {"size": size, "value": value})
raise SizeError(size, value)


def validate_bytes_size(value: bytes, size: int) -> bytes:
Expand Down
16 changes: 10 additions & 6 deletions tests/test_address.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pydantic import BaseModel, ValidationError

from eth_pydantic_types.address import Address
from eth_pydantic_types.hexbytes import HexBytes
from eth_pydantic_types.hex import HexBytes

# NOTE: This address purposely is the wrong length (missing left zero),
# not checksummed, and not 0x prefixed.
Expand Down Expand Up @@ -43,11 +43,15 @@ def test_invalid_address(address):

def test_schema():
actual = Model.model_json_schema()
for name, prop in actual["properties"].items():
assert prop["maxLength"] == 42
assert prop["minLength"] == 42
assert prop["type"] == "string"
assert prop["pattern"] == "^0x[a-fA-F0-9]{40}$"
prop = actual["properties"]["address"]
assert prop["maxLength"] == 42
assert prop["minLength"] == 42
assert prop["type"] == "string"
assert prop["pattern"] == "^0x[a-fA-F0-9]{40}$"
assert prop["examples"] == [
"0x0000000000000000000000000000000000000000",
"0x1e59ce931B4CFea3fe4B875411e280e173cB7A9C",
]


def test_model_dump():
Expand Down
2 changes: 1 addition & 1 deletion tests/test_hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pydantic import BaseModel, ValidationError

from eth_pydantic_types.hash import Hash8, Hash16, Hash32, Hash64
from eth_pydantic_types.hexbytes import HexBytes
from eth_pydantic_types.hex import HexBytes


class Model(BaseModel):
Expand Down
95 changes: 95 additions & 0 deletions tests/test_hex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import pytest
from hexbytes import HexBytes as BaseHexBytes
from pydantic import BaseModel, ValidationError

from eth_pydantic_types.hex import HexBytes, HexStr


class BytesModel(BaseModel):
value: HexBytes


class StrModel(BaseModel):
value: HexStr


@pytest.mark.parametrize("value", ("0xa", 10, b"\n"))
def test_hexbytes(value):
actual = BytesModel(value=value)

# The end result, the value is a hexbytes.HexBytes
assert actual.value == BaseHexBytes(value)
assert actual.value.hex() == "0x0a"
assert isinstance(actual.value, bytes)
assert isinstance(actual.value, BaseHexBytes)


def test_invalid_hexbytes():
with pytest.raises(ValidationError):
BytesModel(value="foo")


def test_hexbytes_fromhex(bytes32str):
actual_with_0x = HexBytes.fromhex(bytes32str)
actual_without_0x = HexBytes.fromhex(bytes32str[2:])
expected = HexBytes(bytes32str)
assert actual_with_0x == actual_without_0x == expected


def test_hexbytes_schema():
actual = BytesModel.model_json_schema()
for name, prop in actual["properties"].items():
assert prop["type"] == "string"
assert prop["format"] == "binary"


def test_hexbytes_model_dump(bytes32str):
model = BytesModel(value=bytes32str)
actual = model.model_dump()
expected = {"value": "0x9b70bd98ccb5b6434c2ead14d68d15f392435a06ff469f8d1f8cf38b2ae0b0e2"}
assert actual == expected


@pytest.mark.parametrize("value", ("0xa", 10, HexBytes(10)))
def test_hexstr(value):
actual = StrModel(value=value)

# The end result, the value is a str
assert actual.value == "0x0a"
assert isinstance(actual.value, str)


def test_invalid_hexstr():
with pytest.raises(ValidationError):
StrModel(value="foo")


def test_hexstr_conversions():
model = StrModel(value="0x123")
assert int(model.value, 16) == 291
assert bytes.fromhex(model.value[2:]) == b"\x01#"


def test_hexstr_schema():
actual = StrModel.model_json_schema()
properties = actual["properties"]
assert len(properties) == 1
prop = properties["value"]
assert prop["type"] == "string"
assert prop["format"] == "binary"
assert prop["examples"] == [
"0x",
"0xd4",
"0xd4e5",
"0xd4e56740",
"0xd4e56740f876aef8",
"0xd4e56740f876aef8c010b86a40d5f567",
"0xd4e56740f876aef8c010b86a40d5f56745a118d0906a34e69aec8c0db1cb8fa3",
]


def test_hexstr_model_dump(bytes32str):
model = StrModel(value=bytes32str)
actual = model.model_dump()
expected = {"value": "0x9b70bd98ccb5b6434c2ead14d68d15f392435a06ff469f8d1f8cf38b2ae0b0e2"}
assert actual == expected
Loading

0 comments on commit 6c9b758

Please sign in to comment.