Skip to content

Commit

Permalink
Merge pull request #6 from antazoey/fix/leadingzeroes
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey authored Dec 15, 2023
2 parents a3b9c90 + 2011eea commit 1a66905
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 6 deletions.
25 changes: 19 additions & 6 deletions eth_pydantic_types/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,33 @@ def validate_size(value: __SIZED_T, size: int, coerce: Optional[Callable] = None


def validate_bytes_size(value: bytes, size: int) -> bytes:
return validate_size(value, size, coerce=lambda v: _left_pad_bytes(v, size))
return validate_size(value, size, coerce=lambda v: _coerce_hexbytes_size(v, size))


def validate_address_size(value: str) -> str:
return validate_str_size(value, 40)


def validate_str_size(value: str, size: int) -> str:
return validate_size(value, size, coerce=lambda v: _left_pad_str(v, size))
return validate_size(value, size, coerce=lambda v: _coerce_hexstr_size(v, size))


def _left_pad_str(val: str, length: int) -> str:
return "0" * (length - len(val)) + val if len(val) < length else val
def _coerce_hexstr_size(val: str, length: int) -> str:
val = val.replace("0x", "") if val.startswith("0x") else val
if len(val) == length:
return val

val_stripped = val.lstrip("0")
num_zeroes = max(0, length - len(val_stripped))
zeroes = "0" * num_zeroes
return f"{zeroes}{val_stripped}"

def _left_pad_bytes(val: bytes, num_bytes: int) -> bytes:
return b"\x00" * (num_bytes - len(val)) + val if len(val) < num_bytes else val

def _coerce_hexbytes_size(val: bytes, num_bytes: int) -> bytes:
if len(val) == num_bytes:
return val

val_stripped = val.lstrip(b"\x00")
num_zeroes = max(0, num_bytes - len(val_stripped))
zeroes = b"\x00" * num_zeroes
return zeroes + val_stripped
18 changes: 18 additions & 0 deletions tests/test_hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from eth_pydantic_types.hash import (
HashBytes8,
HashBytes16,
HashBytes20,
HashBytes32,
HashBytes64,
HashStr8,
Expand All @@ -17,6 +18,7 @@
class Model(BaseModel):
valuebytes8: HashBytes8
valuebytes16: HashBytes16
valuebytes20: HashBytes20
valuebytes32: HashBytes32
valuebytes64: HashBytes64
valuestr8: HashStr8
Expand All @@ -29,6 +31,7 @@ def from_single(cls, value):
return cls(
valuebytes8=value,
valuebytes16=value,
valuebytes20=value,
valuebytes32=value,
valuebytes64=value,
valuestr8=value,
Expand All @@ -54,6 +57,7 @@ def test_hash(value):
model = Model.from_single(value)
assert len(model.valuebytes8) == 8
assert len(model.valuebytes16) == 16
assert len(model.valuebytes20) == 20
assert len(model.valuebytes32) == 32
assert len(model.valuebytes64) == 64
assert len(model.valuestr8) == 18
Expand All @@ -76,6 +80,19 @@ def test_invalid_hash(value):
Model.from_single(value)


def test_hash_removes_leading_zeroes_if_needed():
address = "0x000000000000000000000000cafac3dd18ac6c6e92c921884f9e4176737c052c"

class MyModel(BaseModel):
my_address: HashBytes20

# Test both str and bytes for input.
for addr in (address, HexBytes(address)):
model = MyModel(my_address=addr)
assert len(model.my_address) == 20
assert model.my_address == HexBytes("0xcafac3dd18ac6c6e92c921884f9e4176737c052c")


def test_schema():
actual = Model.model_json_schema()
for name, prop in actual["properties"].items():
Expand Down Expand Up @@ -108,6 +125,7 @@ def test_model_dump(bytes32str):
"valuebytes8": "0x0000000000000005",
"valuestr8": "0x0000000000000005",
"valuebytes16": "0x00000000000000000000000000000005",
"valuebytes20": "0x0000000000000000000000000000000000000005",
"valuestr16": "0x00000000000000000000000000000005",
"valuebytes32": "0x0000000000000000000000000000000000000000000000000000000000000005",
"valuestr32": "0x0000000000000000000000000000000000000000000000000000000000000005",
Expand Down
14 changes: 14 additions & 0 deletions tests/test_hex.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from hexbytes import HexBytes as BaseHexBytes
from pydantic import BaseModel, ValidationError

from eth_pydantic_types import HashStr20
from eth_pydantic_types.hex import HexBytes, HexStr


Expand Down Expand Up @@ -115,3 +116,16 @@ def test_from_bytes():
value = b"\xb7\xfc\xef\x7f\xe7E\xf2\xa9U`\xff_U\x0e;\x8f"
actual = HexStr.from_bytes(value)
assert actual.startswith("0x")


def test_hex_removes_leading_zeroes_if_needed():
address = "0x000000000000000000000000cafac3dd18ac6c6e92c921884f9e4176737c052c"

class MyModel(BaseModel):
my_address: HashStr20

# Test both str and bytes for input.
for addr in (address, HexBytes(address)):
model = MyModel(my_address=addr)
assert len(model.my_address) == 42
assert model.my_address == "0xcafac3dd18ac6c6e92c921884f9e4176737c052c"

0 comments on commit 1a66905

Please sign in to comment.