diff --git a/src/erc7730/common/json.py b/src/erc7730/common/json.py new file mode 100644 index 0000000..24a893d --- /dev/null +++ b/src/erc7730/common/json.py @@ -0,0 +1,66 @@ +import json +from pathlib import Path +from typing import Any + + +def read_jsons_with_includes(paths: list[Path]) -> Any: + """ + Read a list JSON files, recursively inlining any included files. + + Keys from the calling file override those of the included file (the included file defines defaults). + Keys from the later files in the list override those of the first files. + + Note: + - circular includes are not detected and will result in a stack overflow. + - "includes" key can only be used at root level of an object. + """ + result: dict[str, Any] = {} + for path in paths: + # noinspection PyTypeChecker + result = _merge_dicts(result, read_json_with_includes(path)) + return result + + +def read_json_with_includes(path: Path) -> Any: + """ + Read a JSON file, recursively inlining any included files. + + Keys from the calling file override those of the included file (the included file defines defaults). + + If include is a list, files are included in other they are defined, with later files overriding previous files. + + Note: + - circular includes are not detected and will result in a stack overflow. + - "includes" key can only be used at root level of an object. + """ + result: Any + with open(path) as f: + result = json.load(f) + if isinstance(result, dict) and (includes := result.pop("includes", None)) is not None: + if isinstance(includes, list): + parent = read_jsons_with_includes(paths=[path.parent / p for p in includes]) + else: + # noinspection PyTypeChecker + parent = read_json_with_includes(path.parent / includes) + result = _merge_dicts(parent, result) + return result + + +def _merge_dicts(d1: dict[str, Any], d2: dict[str, Any]) -> dict[str, Any]: + """ + Merge d1 and d2, with priority to d2. + + Recursively called when dicts are encountered. + + This function assumes that if the same field is in both dicts, then the types must be the same. + """ + merged = {} + for key, val1 in d1.items(): + if (val2 := d2.get(key)) is not None: + if isinstance(val1, dict): + merged[key] = _merge_dicts(d1=val1, d2=val2) + else: + merged[key] = val2 + else: + merged[key] = val1 + return {**d2, **merged} diff --git a/src/erc7730/common/pydantic.py b/src/erc7730/common/pydantic.py index b792dca..cc6d6ab 100644 --- a/src/erc7730/common/pydantic.py +++ b/src/erc7730/common/pydantic.py @@ -4,22 +4,23 @@ from pydantic import BaseModel +from erc7730.common.json import read_json_with_includes + M = TypeVar("M", bound=BaseModel) -def model_from_json_file(path: Path, model: type[M]) -> M: +def model_from_json_file_with_includes(path: Path, model: type[M]) -> M: """ - Load a Pydantic model from a JSON file. + Load a Pydantic model from a JSON file., including includes references """ - with open(path) as f: - return model.model_validate_json(f.read(), strict=True) + return model.model_validate(read_json_with_includes(path), strict=False) -def model_from_json_file_or_none(path: Path, model: type[M]) -> M | None: +def model_from_json_file_with_includes_or_none(path: Path, model: type[M]) -> M | None: """ Load a Pydantic model from a JSON file, or None if file does not exist. """ - return model_from_json_file(path, model) if os.path.isfile(path) else None + return model_from_json_file_with_includes(path, model) if os.path.isfile(path) else None def json_file_from_model(model: type[M], obj: M) -> str: diff --git a/src/erc7730/linter/lint.py b/src/erc7730/linter/lint.py index 2851a47..ac737a7 100644 --- a/src/erc7730/linter/lint.py +++ b/src/erc7730/linter/lint.py @@ -1,7 +1,7 @@ from pathlib import Path from erc7730 import ERC_7730_REGISTRY_CALLDATA_PREFIX, ERC_7730_REGISTRY_EIP712_PREFIX -from erc7730.common.pydantic import model_from_json_file +from erc7730.common.pydantic import model_from_json_file_with_includes from erc7730.linter import ERC7730Linter from erc7730.linter.linter_base import MultiLinter from erc7730.linter.linter_transaction_type_classifier_ai import ClassifyTransactionTypeLinter @@ -46,7 +46,7 @@ def adder(output: ERC7730Linter.Output) -> None: out(output.model_copy(update={"file": path})) try: - descriptor = model_from_json_file(path, ERC7730Descriptor) + descriptor = model_from_json_file_with_includes(path, ERC7730Descriptor) descriptor = resolve_external_references(descriptor) linter.lint(descriptor, adder) except Exception as e: diff --git a/src/erc7730/mapper/mapper.py b/src/erc7730/mapper/mapper.py index 245764d..192ab47 100644 --- a/src/erc7730/mapper/mapper.py +++ b/src/erc7730/mapper/mapper.py @@ -1,15 +1,17 @@ -"""from pydantic import AnyUrl +from pydantic import AnyUrl from erc7730.common.pydantic import model_from_json_bytes -from erc7730.model.context import EIP712JsonSchema, EIP712Context, EIP712, Domain, EIP712Domain -from erc7730.model.erc7730_descriptor import ERC7730Descriptor, EIP712Context +from erc7730.model.context import EIP712JsonSchema, EIP712Context, EIP712, Domain, NameType +from erc7730.model.erc7730_descriptor import ERC7730Descriptor from erc7730.model.display import ( + CallDataParameters, Display, + FieldDescription, + NestedFields, + NftNameParameters, Reference, FieldFormat, Field, Format, - DateParameters, - DateEncoding, TokenAmountParameters, ) from eip712 import ( @@ -30,93 +32,108 @@ def to_eip712_mapper(erc7730: ERC7730Descriptor) -> EIP712DAppDescriptor | list[ context = erc7730.context if context is not None and isinstance(context, EIP712Context): domain = context.eip712.domain - if domain is None: - exceptions.append(Exception(f"no domain defined for {context.eip712}")) # type: ignore - else: - if domain.chainId is None: - exceptions.append(Exception(f"chain id is None for {domain}")) # type: ignore - elif domain.verifyingContract is None: - exceptions.append(Exception(f"verifying contract is None for {domain}")) # type: ignore + chain_id = None + contract_address = None + name = "" + if domain is not None and domain.name is not None: + name = domain.name + if domain is not None and domain.chainId is not None: + chain_id = domain.chainId + contract_address = domain.verifyingContract + if chain_id is None: + if context.eip712.deployments is not None and context.eip712.deployments.root.__len__() > 0: + chain_id = context.eip712.deployments.root[0].chainId else: - schema = dict[str, str]() - schs = context.eip712.schemas - if schs is not None: - for item in schs: - sch = None - if isinstance(item, EIP712JsonSchema): - sch = item - else: - try: - response = requests.get(item.__str__()) - sch = model_from_json_bytes(response.content, model=EIP712JsonSchema) - except Exception as e: - exceptions.append(e) # type: ignore - if sch is not None: - for key in sch.types: - for d in sch.types[key]: - schema[key + "." + d.name] = d.type - chain_id = domain.chainId - contract_address = domain.verifyingContract - name = "" - if domain.name is not None: - name = domain.name - display = erc7730.display - contracts = list[EIP712ContractDescriptor]() - if display is not None: - for primaryType in display.formats: - format = display.formats[primaryType] - messages = list[EIP712MessageDescriptor]() - if format.fields is not None: - eip712Fields = parseFields(display, primaryType, list[EIP712Field](), format.fields) - messages.append( - EIP712MessageDescriptor( - schema=schema, mapper=EIP712Mapper(label=primaryType, fields=eip712Fields) - ) - ) - contracts.append( - EIP712ContractDescriptor(address=contract_address, contractName=name, messages=messages) - ) - return EIP712DAppDescriptor(blockchainName="ethereum", chainId=chain_id, name=name, contracts=contracts) + exceptions.append(Exception(f"chain id is None for {domain}")) + if contract_address is None: + if domain is None or domain.verifyingContract is None: + if context.eip712.deployments is not None and context.eip712.deployments.root.__len__() > 0: + contract_address = context.eip712.deployments.root[0].address + else: + exceptions.append(Exception(f"verifying contract is None for {domain}")) + schema = dict[str, str]() + schs = context.eip712.schemas + if schs is not None: + for item in schs: + sch = None + if isinstance(item, EIP712JsonSchema): + sch = item + else: + try: + response = requests.get(item.__str__()) + sch = model_from_json_bytes(response.content, model=EIP712JsonSchema) + except Exception as e: + exceptions.append(e) + if sch is not None: + for key in sch.types: + for d in sch.types[key]: + schema[sch.primaryType + "." + key] = d.type + display = erc7730.display + contracts = list[EIP712ContractDescriptor]() + if display is not None: + for key in display.formats: + format = display.formats[key] + messages = list[EIP712MessageDescriptor]() + eip712Fields = list[EIP712Field]() + if format.fields is not None: + for field in format.fields: + eip712Fields.extend(parseField(display, field)) + messages.append( + EIP712MessageDescriptor(schema=schema, mapper=EIP712Mapper(label=key, fields=eip712Fields)) + ) + if contract_address is not None: + contracts.append( + EIP712ContractDescriptor(address=contract_address, contractName=name, messages=messages) + ) + if chain_id is not None: + return EIP712DAppDescriptor(blockchainName="ethereum", chainId=chain_id, name=name, contracts=contracts) + else: + exceptions.append(Exception(f"no chain id for {erc7730}")) else: - exceptions.append(Exception(f"context for {erc7730} is None or is not EIP712")) # type: ignore + exceptions.append(Exception(f"context for {erc7730} is None or is not EIP712")) return exceptions -def parseFields(display: Display, primaryType: str, output: list[EIP712Field], fields: Field) -> list[EIP712Field]: - for _, field in fields: - if isinstance(field, Reference): - # get field from definition section - if display.definitions is not None: - for _, f in display.definitions[field.ref]: - parseField(primaryType, output, f) - elif isinstance(field, StructFormats): - parseFields(display, primaryType, output, field.fields) - elif isinstance(field, Field): - parseField(primaryType, output, field) +def parseField(display: Display, field: Field) -> list[EIP712Field]: + output = list[EIP712Field]() + fieldRoot = field.root + if isinstance(fieldRoot, Reference): + # get field from definition section + if display.definitions is not None: + f = display.definitions[fieldRoot.ref] + output.append(parseFieldDescription(f)) + elif isinstance(fieldRoot, NestedFields): + for f in fieldRoot.fields: # type: ignore + output.extend(parseField(display, field=f)) # type: ignore + else: + output.append(parseFieldDescription(fieldRoot)) return output -def parseField(primaryType: str, output: list[EIP712Field], field: Field) -> list[EIP712Field]: +def parseFieldDescription(field: FieldDescription) -> EIP712Field: name = field.label assetPath = None fieldFormat = None match field.format: case FieldFormat.NFT_NAME: - assetPath = field.collectionPath - case FieldFormat.TOKEN_NAME: - if field.tokenAmountParameters is not None: - assetPath = field.tokenAmountParameters.tokenPath + if field.params is not None and isinstance(field.params, NftNameParameters): + assetPath = field.params.collectionPath + case FieldFormat.TOKEN_AMOUNT: + if field.params is not None and isinstance(field.params, TokenAmountParameters): + assetPath = field.params.tokenPath fieldFormat = EIP712Format.AMOUNT - case FieldFormat.ALLOWANCE_AMOUNT: - if field.allowanceAmountParameters is not None: - assetPath = field.allowanceAmountParameters.tokenPath + case FieldFormat.CALL_DATA: + if field.params is not None and isinstance(field.params, CallDataParameters): + assetPath = field.params.calleePath + case FieldFormat.AMOUNT: fieldFormat = EIP712Format.AMOUNT case FieldFormat.DATE: fieldFormat = EIP712Format.DATETIME + case FieldFormat.RAW: + fieldFormat = EIP712Format.RAW case _: pass - output.append(EIP712Field(path=primaryType, label=name, assetPath=assetPath, format=fieldFormat, coinRef=None)) - return output + return EIP712Field(path=field.path, label=name, assetPath=assetPath, format=fieldFormat, coinRef=None) def to_erc7730_mapper(eip712DappDescriptor: EIP712DAppDescriptor) -> ERC7730Descriptor: @@ -132,42 +149,43 @@ def to_erc7730_mapper(eip712DappDescriptor: EIP712DAppDescriptor) -> ERC7730Desc formats = dict[str, Format]() schemas = list[EIP712JsonSchema | AnyUrl]() for contract in eip712DappDescriptor.contracts: - types = dict[str, list[EIP712Domain]]() + types = dict[str, list[NameType]]() for message in contract.messages: mapper = message.mapper - fields = dict[str, Reference | Field | StructFormats]() - eip712Domains = list[EIP712Domain]() + namesTypes = list[NameType]() + fields = list[Field]() + for key in message.schema_: + namesTypes.append(NameType(name=key, type=message.schema_[key])) for item in mapper.fields: - dateParameters = None - tokenAmountParameters = None - if item.format is not None: - match item.format: - case EIP712Format.AMOUNT: - if item.assetPath is not None: - tokenAmountParameters = TokenAmountParameters(tokenPath=item.assetPath) - case EIP712Format.DATETIME: - dateParameters = DateParameters(encoding=DateEncoding.TIMESTAMP) - case _: - pass - fields[item.label] = Field( - sources=None, - collectionPath=None, - tokenAmountParameters=tokenAmountParameters, - allowanceAmountParameters=None, - percentageParameters=None, - dateParameters=dateParameters, - enumParameters=None, - ) - eip712Domains.append(EIP712Domain(name=item.label, type=item.format.name)) - types[mapper.label] = eip712Domains + label = item.label + path = item.path + match item.format: + case EIP712Format.AMOUNT: + if item.assetPath is not None: + params = TokenAmountParameters(tokenPath=item.assetPath) + field = FieldDescription( + label=label, format=FieldFormat.TOKEN_AMOUNT, params=params, path=path + ) + else: + field = FieldDescription(label=label, format=FieldFormat.AMOUNT, params=None, path=path) + + case EIP712Format.DATETIME: + field = FieldDescription(label=label, format=FieldFormat.DATE, params=None, path=path) + case _: + field = FieldDescription(label=label, format=FieldFormat.RAW, params=None, path=path) + fields.append(Field(root=field)) formats[mapper.label] = Format( - id=None, intent=None, fields=Field(root=fields), required=None, screens=None + id=None, + intent=None, + fields=fields, + required=None, + screens=None, # type: ignore ) + types[mapper.label] = namesTypes schemas.append(EIP712JsonSchema(primaryType=mapper.label, types=types)) eip712 = EIP712(domain=domain, schemas=schemas) context = EIP712Context(eip712=eip712) display = Display(definitions=None, formats=formats) metadata = Metadata(owner=None, info=None, token=None, constants=None, enums=None) - return ERC7730Descriptor(context=context, includes=None, metadata=metadata, display=display) -""" + return ERC7730Descriptor(context=context, metadata=metadata, display=display) diff --git a/src/erc7730/model/abi.py b/src/erc7730/model/abi.py index 19c0502..e579523 100644 --- a/src/erc7730/model/abi.py +++ b/src/erc7730/model/abi.py @@ -35,8 +35,6 @@ class Function(BaseLibraryModel): inputs: list[InputOutput] | None = None outputs: list[InputOutput] | None = None stateMutability: StateMutability | None = None - constant: bool | None = None - payable: bool | None = None gas: int | None = None signature: str | None = None @@ -47,8 +45,6 @@ class Constructor(BaseLibraryModel): inputs: list[InputOutput] | None = None outputs: list[InputOutput] | None = None stateMutability: StateMutability | None = None - constant: bool | None = None - payable: bool | None = None gas: int | None = None signature: str | None = None @@ -59,8 +55,6 @@ class Receive(BaseLibraryModel): inputs: list[InputOutput] | None = None outputs: list[InputOutput] | None = None stateMutability: StateMutability | None = None - constant: bool | None = None - payable: bool | None = None gas: int | None = None signature: str | None = None @@ -71,8 +65,6 @@ class Fallback(BaseLibraryModel): inputs: list[InputOutput] | None = None outputs: list[InputOutput] | None = None stateMutability: StateMutability | None = None - constant: bool | None = None - payable: bool | None = None gas: int | None = None signature: str | None = None diff --git a/src/erc7730/model/context.py b/src/erc7730/model/context.py index 7d2e422..deab452 100644 --- a/src/erc7730/model/context.py +++ b/src/erc7730/model/context.py @@ -1,6 +1,6 @@ from enum import Enum from typing import ForwardRef, Union, Optional -from pydantic import AnyUrl, RootModel, field_validator +from pydantic import AnyUrl, RootModel, field_validator, Field from erc7730.model.base import BaseLibraryModel from erc7730.model.types import ContractAddress, Id @@ -105,8 +105,8 @@ class ContractBinding(BaseLibraryModel): class ContractContext(ContractBinding): - id: Optional[Id] = None + id: Optional[Id] = Field(None, alias="$id") class EIP712Context(EIP712DomainBinding): - id: Optional[Id] = None + id: Optional[Id] = Field(None, alias="$id") diff --git a/src/erc7730/model/display.py b/src/erc7730/model/display.py index ebc4d5e..e84fd08 100644 --- a/src/erc7730/model/display.py +++ b/src/erc7730/model/display.py @@ -1,11 +1,11 @@ from erc7730.model.base import BaseLibraryModel from erc7730.model.types import Id from typing import Annotated, Any, Dict, ForwardRef, Optional, Union -from enum import StrEnum +from enum import Enum from pydantic import Discriminator, RootModel, Field as PydanticField, Tag -class Source(StrEnum): +class Source(str, Enum): WALLET = "wallet" ENS = "ens" CONTRACT = "contract" @@ -13,7 +13,7 @@ class Source(StrEnum): COLLECTION = "collection" -class FieldFormat(StrEnum): +class FieldFormat(str, Enum): RAW = "raw" ADDRESS_NAME = "addressName" CALL_DATA = "calldata" @@ -42,7 +42,7 @@ class TokenAmountParameters(BaseLibraryModel): message: Optional[str] = None -class DateEncoding(StrEnum): +class DateEncoding(str, Enum): BLOCKHEIGHT = "blockheight" TIMESTAMP = "timestamp" @@ -51,7 +51,7 @@ class DateParameters(BaseLibraryModel): encoding: DateEncoding -class AddressNameType(StrEnum): +class AddressNameType(str, Enum): WALLET = "wallet" EOA = "eoa" CONTRACT = "contract" @@ -59,7 +59,7 @@ class AddressNameType(StrEnum): NFT = "nft" -class AddressNameSources(StrEnum): +class AddressNameSources(str, Enum): LOCAL = "local" ENS = "ens" @@ -79,7 +79,7 @@ class NftNameParameters(BaseLibraryModel): class UnitParameters(BaseLibraryModel): - base: int + base: str decimals: Optional[int] = None prefix: Optional[bool] = None @@ -88,19 +88,57 @@ class EnumParameters(BaseLibraryModel): field_ref: str = PydanticField(alias="$ref") +def get_param_discriminator(v: Any) -> str | None: + if isinstance(v, dict): + if v.get("tokenPath") is not None: + return "token_amount" + if v.get("collectionPath") is not None: + return "nft_name" + if v.get("encoding") is not None: + return "date" + if v.get("base") is not None: + return "unit" + if v.get("$ref") is not None: + return "enum" + if v.get("type") is not None or v.get("sources") is not None: + return "address_name" + if v.get("selector") is not None or v.get("calleePath") is not None: + return "call_data" + return None + if getattr(v, "tokenPath", None) is not None: + return "token_amount" + if getattr(v, "encoding", None) is not None: + return "date" + if getattr(v, "collectionPath", None) is not None: + return "nft_name" + if getattr(v, "base", None) is not None: + return "unit" + if getattr(v, "$ref", None) is not None: + return "enum" + if getattr(v, "type", None) is not None: + return "address_name" + if getattr(v, "selector", None) is not None: + return "call_data" + return None + + class FieldDescription(BaseLibraryModel): + path: Optional[str] = None field_id: Optional[Id] = PydanticField(None, alias="$id") label: str format: FieldFormat params: Optional[ - Union[ - AddressNameParameters, - CallDataParameters, - TokenAmountParameters, - NftNameParameters, - DateParameters, - UnitParameters, - EnumParameters, + Annotated[ + Union[ + Annotated[AddressNameParameters, Tag("address_name")], + Annotated[CallDataParameters, Tag("call_data")], + Annotated[TokenAmountParameters, Tag("token_amount")], + Annotated[NftNameParameters, Tag("nft_name")], + Annotated[DateParameters, Tag("date")], + Annotated[UnitParameters, Tag("unit")], + Annotated[EnumParameters, Tag("enum")], + ], + Discriminator(get_param_discriminator), ] ] = None @@ -109,15 +147,22 @@ class NestedFields(FieldsParent): fields: Optional[list[ForwardRef("Field")]] = None # type: ignore -def get_discriminator_value(v: Any) -> str: +def get_discriminator_value(v: Any) -> str | None: if isinstance(v, dict): - if v.get("$ref") is not None: - return "reference" if v.get("label") is not None and v.get("format") is not None: return "field_description" if v.get("fields") is not None: return "nested_fields" - return "" + if v.get("$ref") is not None: + return "reference" + return None + if getattr(v, "label", None) is not None and getattr(v, "format") is not None: + return "field_description" + if getattr(v, "fields", None) is not None: + return "nested_fields" + if getattr(v, "ref", None) is not None: + return "reference" + return None class Field( diff --git a/src/erc7730/model/erc7730_descriptor.py b/src/erc7730/model/erc7730_descriptor.py index 5687be0..3e65208 100644 --- a/src/erc7730/model/erc7730_descriptor.py +++ b/src/erc7730/model/erc7730_descriptor.py @@ -21,6 +21,5 @@ class ERC7730Descriptor(BaseLibraryModel): ) field_schema: Optional[str] = Field(None, alias="$schema") context: Optional[Union[ContractContext, EIP712Context]] = None - includes: Optional[str] = None metadata: Optional[Metadata] = None display: Optional[Display] = None diff --git a/src/erc7730/model/metadata.py b/src/erc7730/model/metadata.py index 0c90c3d..756eccb 100644 --- a/src/erc7730/model/metadata.py +++ b/src/erc7730/model/metadata.py @@ -1,13 +1,12 @@ from datetime import datetime -from pydantic import AnyUrl from erc7730.model.base import BaseLibraryModel from typing import Union, Optional, Dict class OwnerInfo(BaseLibraryModel): - legalName: Optional[str] = None + legalName: str lastUpdate: Optional[datetime] - url: AnyUrl + url: str class TokenInfo(BaseLibraryModel): diff --git a/src/erc7730/model/utils.py b/src/erc7730/model/utils.py index 8530af6..8f15432 100644 --- a/src/erc7730/model/utils.py +++ b/src/erc7730/model/utils.py @@ -19,8 +19,6 @@ def resolve_external_references(descriptor: ERC7730Descriptor) -> ERC7730Descrip def _resolve_external_references_eip712(descriptor: ERC7730Descriptor) -> ERC7730Descriptor: schemas: list[EIP712JsonSchema | AnyUrl] = descriptor.context.eip712.schemas # type:ignore schemas_resolved = [] - if schemas is None: - raise ValueError("Missing EIP-712 message schemas") for schema in schemas: if isinstance(schemas, AnyUrl): resp = requests.get(_adapt_uri(schema)) # type:ignore @@ -42,8 +40,6 @@ def _resolve_external_references_eip712(descriptor: ERC7730Descriptor) -> ERC773 def _resolve_external_references_contract(descriptor: ERC7730Descriptor) -> ERC7730Descriptor: abis: AnyUrl | list[ABI] = descriptor.context.contract.abi # type:ignore - if abis is None: - raise ValueError("Missing contract ABI") if isinstance(abis, AnyUrl): resp = requests.get(_adapt_uri(abis)) # type:ignore resp.raise_for_status() diff --git a/tests/src/erc7730/common/client/__init__.py b/tests/src/erc7730/client/__init__.py similarity index 100% rename from tests/src/erc7730/common/client/__init__.py rename to tests/src/erc7730/client/__init__.py diff --git a/tests/src/erc7730/common/client/test_etherscan_client.py b/tests/src/erc7730/client/test_etherscan_client.py similarity index 100% rename from tests/src/erc7730/common/client/test_etherscan_client.py rename to tests/src/erc7730/client/test_etherscan_client.py diff --git a/tests/src/erc7730/test_datamodel.py b/tests/src/erc7730/common/test_datamodel.py similarity index 67% rename from tests/src/erc7730/test_datamodel.py rename to tests/src/erc7730/common/test_datamodel.py index 9aef8aa..ac7b324 100644 --- a/tests/src/erc7730/test_datamodel.py +++ b/tests/src/erc7730/common/test_datamodel.py @@ -1,5 +1,6 @@ from pathlib import Path -from erc7730.common.pydantic import model_from_json_file_or_none, json_file_from_model +from erc7730.common.json import read_json_with_includes +from erc7730.common.pydantic import model_from_json_file_with_includes_or_none, json_file_from_model from erc7730.model.erc7730_descriptor import ERC7730Descriptor import pytest import glob @@ -13,7 +14,8 @@ @pytest.mark.parametrize("file", files) def test_from_erc7730(file: str) -> None: - model_erc7730 = model_from_json_file_or_none(Path(file), ERC7730Descriptor) + original_dict_with_includes = read_json_with_includes(Path(file)) + model_erc7730 = model_from_json_file_with_includes_or_none(Path(file), ERC7730Descriptor) assert model_erc7730 is not None json_str_from_model = json_file_from_model(ERC7730Descriptor, model_erc7730) json_from_model = json.loads(json_str_from_model) @@ -21,3 +23,4 @@ def test_from_erc7730(file: str) -> None: validate(instance=json_from_model, schema=schema) except exceptions.ValidationError as ex: pytest.fail(f"Invalid schema for serialized data from {file}: {ex}") + assert json_from_model == original_dict_with_includes diff --git a/tests/src/erc7730/mapper/__init__.py b/tests/src/erc7730/mapper/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/src/erc7730/mapper/test_mapper.py b/tests/src/erc7730/mapper/test_mapper.py new file mode 100644 index 0000000..b07fbb2 --- /dev/null +++ b/tests/src/erc7730/mapper/test_mapper.py @@ -0,0 +1,23 @@ +from pathlib import Path +from erc7730.common.pydantic import model_from_json_file_with_includes_or_none +from erc7730.model.erc7730_descriptor import ERC7730Descriptor +from erc7730.mapper.mapper import to_eip712_mapper, to_erc7730_mapper +from eip712 import EIP712DAppDescriptor +import pytest +import glob + +inputs = glob.glob("clear-signing-erc7730-registry/registry/*/eip712*.json") + + +@pytest.mark.parametrize("input", inputs) +def test_roundtrip(input: str) -> None: + erc7730Descriptor = model_from_json_file_with_includes_or_none(Path(input), ERC7730Descriptor) + assert erc7730Descriptor is not None + assert isinstance(erc7730Descriptor, ERC7730Descriptor) + eip712DappDescriptor = to_eip712_mapper(erc7730Descriptor) + assert eip712DappDescriptor is not None + assert isinstance(eip712DappDescriptor, EIP712DAppDescriptor) + newErc7730Descriptor = to_erc7730_mapper(eip712DappDescriptor) + assert newErc7730Descriptor is not None + if erc7730Descriptor.context is not None and erc7730Descriptor.context.eip712.domain is not None: # type: ignore + assert newErc7730Descriptor.context.eip712.domain.name == erc7730Descriptor.context.eip712.domain.name # type: ignore diff --git a/tests/src/erc7730/test_mapper.py b/tests/src/erc7730/test_mapper.py deleted file mode 100644 index 9d37554..0000000 --- a/tests/src/erc7730/test_mapper.py +++ /dev/null @@ -1,28 +0,0 @@ -# from pathlib import Path -# from erc7730.common.pydantic import model_from_json_file_or_none -# from erc7730.model.erc7730_descriptor import ERC7730Descriptor -# from erc7730.mapper.mapper import to_eip712_mapper, to_erc7730_mapper -# mport pytest -# import glob - -# files = glob.glob('clear-signing-erc7730-registry/registry/*/eip712*.json') - -"""@pytest.mark.parametrize("file", files) -def roundtrip(file: str) -> None: - erc7730Descriptor = model_from_json_file_or_none(Path(file), ERC7730Descriptor) - assert erc7730Descriptor is not None - eip712DappDescriptor = to_eip712_mapper(erc7730Descriptor) - assert eip712DappDescriptor is not None - newErc7730Descriptor = to_erc7730_mapper(eip712DappDescriptor) - assert newErc7730Descriptor is not None - assert newErc7730Descriptor == erc7730Descriptor - -def test_to_eip712_mapper() -> None: - uniswap_eip712_cs_descriptor = model_from_json_file_or_none(Path("clear-signing-erc7730-registry/registry/uniswap/eip712-permit2.json"), ERC7730Descriptor) - assert uniswap_eip712_cs_descriptor is not None - eip712DappDescriptor = to_eip712_mapper(uniswap_eip712_cs_descriptor) - assert eip712DappDescriptor is not None - assert eip712DappDescriptor.chain_id == 1 - assert eip712DappDescriptor.name == "Permit2" - assert eip712DappDescriptor.contracts.__len__() == 2 - assert eip712DappDescriptor.contracts[0].messages.__len__() == 1"""