Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Load includes into json + rountrip test with content check + fix mappers #15

Merged
merged 1 commit into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions src/erc7730/common/json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import json
from pathlib import Path
from typing import Any


def read_jsons_with_includes(paths: list[Path]) -> Any:
"""
Read a list JSON files, recursively inlining any included files.

Keys from the calling file override those of the included file (the included file defines defaults).
Keys from the later files in the list override those of the first files.

Note:
- circular includes are not detected and will result in a stack overflow.
- "includes" key can only be used at root level of an object.
"""
result: dict[str, Any] = {}
for path in paths:
# noinspection PyTypeChecker
result = _merge_dicts(result, read_json_with_includes(path))
return result


def read_json_with_includes(path: Path) -> Any:
"""
Read a JSON file, recursively inlining any included files.

Keys from the calling file override those of the included file (the included file defines defaults).

If include is a list, files are included in other they are defined, with later files overriding previous files.

Note:
- circular includes are not detected and will result in a stack overflow.
- "includes" key can only be used at root level of an object.
"""
result: Any
with open(path) as f:
result = json.load(f)
if isinstance(result, dict) and (includes := result.pop("includes", None)) is not None:
if isinstance(includes, list):
parent = read_jsons_with_includes(paths=[path.parent / p for p in includes])
else:
# noinspection PyTypeChecker
parent = read_json_with_includes(path.parent / includes)
result = _merge_dicts(parent, result)
return result


def _merge_dicts(d1: dict[str, Any], d2: dict[str, Any]) -> dict[str, Any]:
"""
Merge d1 and d2, with priority to d2.

Recursively called when dicts are encountered.

This function assumes that if the same field is in both dicts, then the types must be the same.
"""
merged = {}
for key, val1 in d1.items():
if (val2 := d2.get(key)) is not None:
if isinstance(val1, dict):
merged[key] = _merge_dicts(d1=val1, d2=val2)
else:
merged[key] = val2
else:
merged[key] = val1
return {**d2, **merged}
13 changes: 7 additions & 6 deletions src/erc7730/common/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,23 @@

from pydantic import BaseModel

from erc7730.common.json import read_json_with_includes

M = TypeVar("M", bound=BaseModel)


def model_from_json_file(path: Path, model: type[M]) -> M:
def model_from_json_file_with_includes(path: Path, model: type[M]) -> M:
"""
Load a Pydantic model from a JSON file.
Load a Pydantic model from a JSON file., including includes references
"""
with open(path) as f:
return model.model_validate_json(f.read(), strict=True)
return model.model_validate(read_json_with_includes(path), strict=False)


def model_from_json_file_or_none(path: Path, model: type[M]) -> M | None:
def model_from_json_file_with_includes_or_none(path: Path, model: type[M]) -> M | None:
"""
Load a Pydantic model from a JSON file, or None if file does not exist.
"""
return model_from_json_file(path, model) if os.path.isfile(path) else None
return model_from_json_file_with_includes(path, model) if os.path.isfile(path) else None


def json_file_from_model(model: type[M], obj: M) -> str:
Expand Down
4 changes: 2 additions & 2 deletions src/erc7730/linter/lint.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pathlib import Path

from erc7730 import ERC_7730_REGISTRY_CALLDATA_PREFIX, ERC_7730_REGISTRY_EIP712_PREFIX
from erc7730.common.pydantic import model_from_json_file
from erc7730.common.pydantic import model_from_json_file_with_includes
from erc7730.linter import ERC7730Linter
from erc7730.linter.linter_base import MultiLinter
from erc7730.linter.linter_transaction_type_classifier_ai import ClassifyTransactionTypeLinter
Expand Down Expand Up @@ -46,7 +46,7 @@ def adder(output: ERC7730Linter.Output) -> None:
out(output.model_copy(update={"file": path}))

try:
descriptor = model_from_json_file(path, ERC7730Descriptor)
descriptor = model_from_json_file_with_includes(path, ERC7730Descriptor)
descriptor = resolve_external_references(descriptor)
linter.lint(descriptor, adder)
except Exception as e:
Expand Down
220 changes: 119 additions & 101 deletions src/erc7730/mapper/mapper.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
"""from pydantic import AnyUrl
from pydantic import AnyUrl
from erc7730.common.pydantic import model_from_json_bytes
from erc7730.model.context import EIP712JsonSchema, EIP712Context, EIP712, Domain, EIP712Domain
from erc7730.model.erc7730_descriptor import ERC7730Descriptor, EIP712Context
from erc7730.model.context import EIP712JsonSchema, EIP712Context, EIP712, Domain, NameType
from erc7730.model.erc7730_descriptor import ERC7730Descriptor
from erc7730.model.display import (
CallDataParameters,
Display,
FieldDescription,
NestedFields,
NftNameParameters,
Reference,
FieldFormat,
Field,
Format,
DateParameters,
DateEncoding,
TokenAmountParameters,
)
from eip712 import (
Expand All @@ -30,93 +32,108 @@ def to_eip712_mapper(erc7730: ERC7730Descriptor) -> EIP712DAppDescriptor | list[
context = erc7730.context
if context is not None and isinstance(context, EIP712Context):
domain = context.eip712.domain
if domain is None:
exceptions.append(Exception(f"no domain defined for {context.eip712}")) # type: ignore
else:
if domain.chainId is None:
exceptions.append(Exception(f"chain id is None for {domain}")) # type: ignore
elif domain.verifyingContract is None:
exceptions.append(Exception(f"verifying contract is None for {domain}")) # type: ignore
chain_id = None
contract_address = None
name = ""
if domain is not None and domain.name is not None:
name = domain.name
if domain is not None and domain.chainId is not None:
chain_id = domain.chainId
contract_address = domain.verifyingContract
if chain_id is None:
if context.eip712.deployments is not None and context.eip712.deployments.root.__len__() > 0:
chain_id = context.eip712.deployments.root[0].chainId
else:
schema = dict[str, str]()
schs = context.eip712.schemas
if schs is not None:
for item in schs:
sch = None
if isinstance(item, EIP712JsonSchema):
sch = item
else:
try:
response = requests.get(item.__str__())
sch = model_from_json_bytes(response.content, model=EIP712JsonSchema)
except Exception as e:
exceptions.append(e) # type: ignore
if sch is not None:
for key in sch.types:
for d in sch.types[key]:
schema[key + "." + d.name] = d.type
chain_id = domain.chainId
contract_address = domain.verifyingContract
name = ""
if domain.name is not None:
name = domain.name
display = erc7730.display
contracts = list[EIP712ContractDescriptor]()
if display is not None:
for primaryType in display.formats:
format = display.formats[primaryType]
messages = list[EIP712MessageDescriptor]()
if format.fields is not None:
eip712Fields = parseFields(display, primaryType, list[EIP712Field](), format.fields)
messages.append(
EIP712MessageDescriptor(
schema=schema, mapper=EIP712Mapper(label=primaryType, fields=eip712Fields)
)
)
contracts.append(
EIP712ContractDescriptor(address=contract_address, contractName=name, messages=messages)
)
return EIP712DAppDescriptor(blockchainName="ethereum", chainId=chain_id, name=name, contracts=contracts)
exceptions.append(Exception(f"chain id is None for {domain}"))
if contract_address is None:
if domain is None or domain.verifyingContract is None:
if context.eip712.deployments is not None and context.eip712.deployments.root.__len__() > 0:
contract_address = context.eip712.deployments.root[0].address
else:
exceptions.append(Exception(f"verifying contract is None for {domain}"))
schema = dict[str, str]()
schs = context.eip712.schemas
if schs is not None:
for item in schs:
sch = None
if isinstance(item, EIP712JsonSchema):
sch = item
else:
try:
response = requests.get(item.__str__())
sch = model_from_json_bytes(response.content, model=EIP712JsonSchema)
except Exception as e:
exceptions.append(e)
if sch is not None:
for key in sch.types:
for d in sch.types[key]:
schema[sch.primaryType + "." + key] = d.type
display = erc7730.display
contracts = list[EIP712ContractDescriptor]()
if display is not None:
for key in display.formats:
format = display.formats[key]
messages = list[EIP712MessageDescriptor]()
eip712Fields = list[EIP712Field]()
if format.fields is not None:
for field in format.fields:
eip712Fields.extend(parseField(display, field))
messages.append(
EIP712MessageDescriptor(schema=schema, mapper=EIP712Mapper(label=key, fields=eip712Fields))
)
if contract_address is not None:
contracts.append(
EIP712ContractDescriptor(address=contract_address, contractName=name, messages=messages)
)
if chain_id is not None:
return EIP712DAppDescriptor(blockchainName="ethereum", chainId=chain_id, name=name, contracts=contracts)
else:
exceptions.append(Exception(f"no chain id for {erc7730}"))
else:
exceptions.append(Exception(f"context for {erc7730} is None or is not EIP712")) # type: ignore
exceptions.append(Exception(f"context for {erc7730} is None or is not EIP712"))
return exceptions


def parseFields(display: Display, primaryType: str, output: list[EIP712Field], fields: Field) -> list[EIP712Field]:
for _, field in fields:
if isinstance(field, Reference):
# get field from definition section
if display.definitions is not None:
for _, f in display.definitions[field.ref]:
parseField(primaryType, output, f)
elif isinstance(field, StructFormats):
parseFields(display, primaryType, output, field.fields)
elif isinstance(field, Field):
parseField(primaryType, output, field)
def parseField(display: Display, field: Field) -> list[EIP712Field]:
output = list[EIP712Field]()
fieldRoot = field.root
if isinstance(fieldRoot, Reference):
# get field from definition section
if display.definitions is not None:
f = display.definitions[fieldRoot.ref]
output.append(parseFieldDescription(f))
elif isinstance(fieldRoot, NestedFields):
for f in fieldRoot.fields: # type: ignore
output.extend(parseField(display, field=f)) # type: ignore
else:
output.append(parseFieldDescription(fieldRoot))
return output


def parseField(primaryType: str, output: list[EIP712Field], field: Field) -> list[EIP712Field]:
def parseFieldDescription(field: FieldDescription) -> EIP712Field:
name = field.label
assetPath = None
fieldFormat = None
match field.format:
case FieldFormat.NFT_NAME:
assetPath = field.collectionPath
case FieldFormat.TOKEN_NAME:
if field.tokenAmountParameters is not None:
assetPath = field.tokenAmountParameters.tokenPath
if field.params is not None and isinstance(field.params, NftNameParameters):
assetPath = field.params.collectionPath
case FieldFormat.TOKEN_AMOUNT:
if field.params is not None and isinstance(field.params, TokenAmountParameters):
assetPath = field.params.tokenPath
fieldFormat = EIP712Format.AMOUNT
case FieldFormat.ALLOWANCE_AMOUNT:
if field.allowanceAmountParameters is not None:
assetPath = field.allowanceAmountParameters.tokenPath
case FieldFormat.CALL_DATA:
if field.params is not None and isinstance(field.params, CallDataParameters):
assetPath = field.params.calleePath
case FieldFormat.AMOUNT:
fieldFormat = EIP712Format.AMOUNT
case FieldFormat.DATE:
fieldFormat = EIP712Format.DATETIME
case FieldFormat.RAW:
fieldFormat = EIP712Format.RAW
case _:
pass
output.append(EIP712Field(path=primaryType, label=name, assetPath=assetPath, format=fieldFormat, coinRef=None))
return output
return EIP712Field(path=field.path, label=name, assetPath=assetPath, format=fieldFormat, coinRef=None)


def to_erc7730_mapper(eip712DappDescriptor: EIP712DAppDescriptor) -> ERC7730Descriptor:
Expand All @@ -132,42 +149,43 @@ def to_erc7730_mapper(eip712DappDescriptor: EIP712DAppDescriptor) -> ERC7730Desc
formats = dict[str, Format]()
schemas = list[EIP712JsonSchema | AnyUrl]()
for contract in eip712DappDescriptor.contracts:
types = dict[str, list[EIP712Domain]]()
types = dict[str, list[NameType]]()
for message in contract.messages:
mapper = message.mapper
fields = dict[str, Reference | Field | StructFormats]()
eip712Domains = list[EIP712Domain]()
namesTypes = list[NameType]()
fields = list[Field]()
for key in message.schema_:
namesTypes.append(NameType(name=key, type=message.schema_[key]))
for item in mapper.fields:
dateParameters = None
tokenAmountParameters = None
if item.format is not None:
match item.format:
case EIP712Format.AMOUNT:
if item.assetPath is not None:
tokenAmountParameters = TokenAmountParameters(tokenPath=item.assetPath)
case EIP712Format.DATETIME:
dateParameters = DateParameters(encoding=DateEncoding.TIMESTAMP)
case _:
pass
fields[item.label] = Field(
sources=None,
collectionPath=None,
tokenAmountParameters=tokenAmountParameters,
allowanceAmountParameters=None,
percentageParameters=None,
dateParameters=dateParameters,
enumParameters=None,
)
eip712Domains.append(EIP712Domain(name=item.label, type=item.format.name))
types[mapper.label] = eip712Domains
label = item.label
path = item.path
match item.format:
case EIP712Format.AMOUNT:
if item.assetPath is not None:
params = TokenAmountParameters(tokenPath=item.assetPath)
field = FieldDescription(
label=label, format=FieldFormat.TOKEN_AMOUNT, params=params, path=path
)
else:
field = FieldDescription(label=label, format=FieldFormat.AMOUNT, params=None, path=path)

case EIP712Format.DATETIME:
field = FieldDescription(label=label, format=FieldFormat.DATE, params=None, path=path)
case _:
field = FieldDescription(label=label, format=FieldFormat.RAW, params=None, path=path)
fields.append(Field(root=field))
formats[mapper.label] = Format(
id=None, intent=None, fields=Field(root=fields), required=None, screens=None
id=None,
intent=None,
fields=fields,
required=None,
screens=None, # type: ignore
)
types[mapper.label] = namesTypes
schemas.append(EIP712JsonSchema(primaryType=mapper.label, types=types))

eip712 = EIP712(domain=domain, schemas=schemas)
context = EIP712Context(eip712=eip712)
display = Display(definitions=None, formats=formats)
metadata = Metadata(owner=None, info=None, token=None, constants=None, enums=None)
return ERC7730Descriptor(context=context, includes=None, metadata=metadata, display=display)
"""
return ERC7730Descriptor(context=context, metadata=metadata, display=display)
Loading
Loading