Skip to content

Commit

Permalink
Merge pull request #2 from antazoey/refactor/address-base
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey authored Oct 26, 2023
2 parents a95501f + 9a4f7be commit f6cda09
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 47 deletions.
3 changes: 2 additions & 1 deletion eth_pydantic_types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .address import Address
from .address import Address, AddressType
from .bip122 import Bip122Uri
from .hash import (
HashBytes4,
Expand All @@ -18,6 +18,7 @@

__all__ = [
"Address",
"AddressType",
"Bip122Uri",
"HashBytes4",
"HashBytes8",
Expand Down
49 changes: 24 additions & 25 deletions eth_pydantic_types/address.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
from typing import Any, ClassVar, Optional, Tuple
from typing import Any, ClassVar, Optional, Tuple, cast

from eth_typing import ChecksumAddress
from eth_utils import is_checksum_address, to_checksum_address
from pydantic_core import CoreSchema
from pydantic_core.core_schema import (
ValidationInfo,
str_schema,
with_info_before_validator_function,
)
from pydantic_core.core_schema import ValidationInfo, str_schema
from typing_extensions import Annotated

from eth_pydantic_types.hex import BaseHexStr, HexBytes
from eth_pydantic_types.validators import validate_address_size
from eth_pydantic_types.hash import HashStr20

ADDRESS_PATTERN = "^0x[a-fA-F0-9]{40}$"

Expand All @@ -18,7 +14,7 @@ def address_schema():
return str_schema(min_length=42, max_length=42, pattern=ADDRESS_PATTERN)


class Address(BaseHexStr):
class Address(HashStr20):
"""
Use for address-types. Validates as a checksummed address. Left-pads zeroes
if necessary.
Expand All @@ -32,22 +28,25 @@ class Address(BaseHexStr):
"0x1e59ce931B4CFea3fe4B875411e280e173cB7A9C",
)

def __get_pydantic_core_schema__(self, *args, **kwargs) -> CoreSchema:
schema = with_info_before_validator_function(
self._validate_address,
address_schema(),
)
return schema
@classmethod
def __eth_pydantic_validate__(cls, value: Any, info: Optional[ValidationInfo] = None) -> str:
value = super().__eth_pydantic_validate__(value)
return cls.to_checksum_address(value)

@classmethod
def _validate_address(cls, value: Any, info: Optional[ValidationInfo] = None) -> str:
if isinstance(value, str) and is_checksum_address(value):
return value
def to_checksum_address(cls, value: str) -> ChecksumAddress:
return (
cast(ChecksumAddress, value)
if is_checksum_address(value)
else to_checksum_address(value)
)


elif not isinstance(value, str):
value = HexBytes(value).hex()
"""
A type that can be used in place of ``eth_typing.ChecksumAddress``.
number = value[2:] if value.startswith("0x") else value
number_padded = validate_address_size(number)
value = f"0x{number_padded}"
return to_checksum_address(value)
**NOTE**: We are unable to subclass ``eth_typing.ChecksumAddress``
in :class:`~eth_pydantic_types.address.Address` because it is
a NewType; that is why we offer this annotated approach.
"""
AddressType = Annotated[ChecksumAddress, Address]
7 changes: 4 additions & 3 deletions eth_pydantic_types/bip122.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,15 @@ def __get_pydantic_json_schema__(cls, core_schema, handler):
json_schema.update(examples=[example], pattern=pattern)
return json_schema

def __get_pydantic_core_schema__(self, *args, **kwargs) -> CoreSchema:
@classmethod
def __get_pydantic_core_schema__(cls, value, handler=None) -> CoreSchema:
return with_info_before_validator_function(
self._validate,
value.__eth_pydantic_validate__,
str_schema(),
)

@classmethod
def _validate(cls, value: Any, info: Optional[ValidationInfo] = None) -> str:
def __eth_pydantic_validate__(cls, value: Any, info: Optional[ValidationInfo] = None) -> str:
if not value.startswith(cls.prefix):
raise Bip122UriFormatError(value)

Expand Down
19 changes: 12 additions & 7 deletions eth_pydantic_types/hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,19 @@ class HashBytes(HexBytes):
schema_pattern: ClassVar[str] = _get_hash_pattern(1)
schema_examples: ClassVar[Tuple[str, ...]] = _get_hash_examples(1)

def __get_pydantic_core_schema__(self, *args, **kwargs) -> CoreSchema:
@classmethod
def __get_pydantic_core_schema__(cls, value, handler=None) -> CoreSchema:
schema = with_info_before_validator_function(
self._validate_hash, bytes_schema(max_length=self.size, min_length=self.size)
cls.__eth_pydantic_validate__,
bytes_schema(max_length=cls.size, min_length=cls.size),
)
schema["serialization"] = hex_serializer
return schema

@classmethod
def _validate_hash(cls, value: Any, info: Optional[ValidationInfo] = None) -> bytes:
def __eth_pydantic_validate__(
cls, value: Any, info: Optional[ValidationInfo] = None
) -> HexBytes:
return cls(cls.validate_size(HexBytes(value)))

@classmethod
Expand All @@ -64,14 +68,15 @@ class HashStr(BaseHexStr):
schema_pattern: ClassVar[str] = _get_hash_pattern(1)
schema_examples: ClassVar[Tuple[str, ...]] = _get_hash_examples(1)

def __get_pydantic_core_schema__(self, *args, **kwargs) -> CoreSchema:
str_size = self.size * 2 + 2
@classmethod
def __get_pydantic_core_schema__(cls, value, handler=None) -> CoreSchema:
str_size = cls.size * 2 + 2
return with_info_before_validator_function(
self._validate_hash, str_schema(max_length=str_size, min_length=str_size)
cls.__eth_pydantic_validate__, str_schema(max_length=str_size, min_length=str_size)
)

@classmethod
def _validate_hash(cls, value: Any, info: Optional[ValidationInfo] = None) -> str:
def __eth_pydantic_validate__(cls, value: Any, info: Optional[ValidationInfo] = None) -> str:
hex_str = cls.validate_hex(value)
hex_value = hex_str[2:] if hex_str.startswith("0x") else hex_str
sized_value = cls.validate_size(hex_value)
Expand Down
22 changes: 17 additions & 5 deletions eth_pydantic_types/hex.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ class HexBytes(BaseHexBytes, BaseHex):
a pydantic validator and serializer.
"""

def __get_pydantic_core_schema__(self, *args, **kwargs) -> CoreSchema:
schema = with_info_before_validator_function(self._validate_hexbytes, bytes_schema())
@classmethod
def __get_pydantic_core_schema__(cls, value, handle=None) -> CoreSchema:
schema = with_info_before_validator_function(cls.__eth_pydantic_validate__, bytes_schema())
schema["serialization"] = hex_serializer
return schema

Expand All @@ -55,11 +56,21 @@ def fromhex(cls, hex_str: str) -> "HexBytes":
return super().fromhex(value)

@classmethod
def _validate_hexbytes(cls, value: Any, info: Optional[ValidationInfo] = None) -> BaseHexBytes:
def __eth_pydantic_validate__(
cls, value: Any, info: Optional[ValidationInfo] = None
) -> BaseHexBytes:
return BaseHexBytes(value)


class BaseHexStr(str, BaseHex):
@classmethod
def __get_pydantic_core_schema__(cls, value, handler=None):
return no_info_before_validator_function(cls.__eth_pydantic_validate__, str_schema())

@classmethod
def __eth_pydantic_validate__(cls, value):
return value # Override.

@classmethod
def from_bytes(cls, data: bytes) -> "BaseHexStr":
hex_str = data.hex()
Expand Down Expand Up @@ -88,8 +99,9 @@ def __bytes__(self) -> bytes:
class HexStr(BaseHexStr):
"""A hex string value, typically from a hash."""

def __get_pydantic_core_schema__(cls, *args, **kwargs):
return no_info_before_validator_function(cls.validate_hex, str_schema())
@classmethod
def __eth_pydantic_validate__(cls, value):
return cls.validate_hex(value)

@classmethod
def from_bytes(cls, data: bytes) -> "HexStr":
Expand Down
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,11 @@
include_package_data=True,
install_requires=[
"hexbytes>=0.3.0,<1",
"eth-utils>=2.2.0,<3",
"eth-hash[pycryptodome]>=0.5.2,<1",
"eth-utils>=2.2.0,<3",
"eth-typing>=3.5.0,<4",
"pydantic>=2.4.2,<3",
"typing_extensions>=4.8.0,<5",
],
python_requires=">=3.8,<4",
extras_require=extras_require,
Expand Down
12 changes: 7 additions & 5 deletions tests/test_address.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
from pydantic import BaseModel, ValidationError

from eth_pydantic_types.address import Address
from eth_pydantic_types.address import Address, AddressType
from eth_pydantic_types.hex import HexBytes

# NOTE: This address purposely is the wrong length (missing left zero),
Expand All @@ -12,6 +12,7 @@

class Model(BaseModel):
address: Address
address_type: AddressType


@pytest.fixture
Expand All @@ -31,14 +32,15 @@ def checksum_address():
),
)
def test_address(address, checksum_address):
actual = Model(address=address)
actual = Model(address=address, address_type=address)
assert actual.address == checksum_address
assert actual.address_type == checksum_address


@pytest.mark.parametrize("address", ("foo", -35, "0x" + ("F" * 100)))
def test_invalid_address(address):
with pytest.raises(ValidationError):
Model(address=address)
Model(address=address, address_type=address)


def test_schema():
Expand All @@ -57,7 +59,7 @@ def test_schema():


def test_model_dump():
model = Model(address=ADDRESS)
model = Model(address=ADDRESS, address_type=ADDRESS)
actual = model.model_dump()
expected = {"address": CHECKSUM_ADDRESS}
expected = {"address": CHECKSUM_ADDRESS, "address_type": CHECKSUM_ADDRESS}
assert actual == expected

0 comments on commit f6cda09

Please sign in to comment.