Skip to content

Commit

Permalink
test!: test roundtrip serialisation against strict + lax schema (#982)
Browse files Browse the repository at this point in the history
We also include some ops in testing, and apply required fixes.

Note that we intend to replace this `rstest::case` style testing with
`proptest`.

BREAKING CHANGE: serialisation schema
  • Loading branch information
doug-q authored May 2, 2024
1 parent 81e9602 commit 954b2cb
Show file tree
Hide file tree
Showing 12 changed files with 5,404 additions and 193 deletions.
41 changes: 27 additions & 14 deletions hugr-py/src/hugr/serialization/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from abc import ABC
from typing import Any, Literal, cast

from pydantic import BaseModel, Field, RootModel
from pydantic import Field, RootModel

from . import tys
from .tys import (
Expand All @@ -15,12 +15,15 @@
TypeRow,
SumType,
TypeBound,
ConfiguredBaseModel,
classes as tys_classes,
model_rebuild as tys_model_rebuild,
)

NodeID = int


class BaseOp(ABC, BaseModel):
class BaseOp(ABC, ConfiguredBaseModel):
"""Base class for ops that store their node's input/output types"""

# Parent node index of node the op belongs to, used only at serialization time
Expand Down Expand Up @@ -84,7 +87,7 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
CustomConst = Any # TODO


class ExtensionValue(BaseModel):
class ExtensionValue(ConfiguredBaseModel):
"""An extension constant value, that can check it is of a given [CustomType]."""

c: Literal["Extension"] = Field("Extension", title="ValueTag")
Expand All @@ -96,7 +99,7 @@ class Config:
}


class FunctionValue(BaseModel):
class FunctionValue(ConfiguredBaseModel):
"""A higher-order function value."""

c: Literal["Function"] = Field("Function", title="ValueTag")
Expand All @@ -108,7 +111,7 @@ class Config:
}


class TupleValue(BaseModel):
class TupleValue(ConfiguredBaseModel):
"""A constant tuple value."""

c: Literal["Tuple"] = Field("Tuple", title="ValueTag")
Expand All @@ -120,7 +123,7 @@ class Config:
}


class SumValue(BaseModel):
class SumValue(ConfiguredBaseModel):
"""A Sum variant
For any Sum type where this value meets the type of the variant indicated by the tag
Expand Down Expand Up @@ -263,10 +266,11 @@ class Call(DataflowOp):
"""

op: Literal["Call"] = "Call"
signature: FunctionType = Field(default_factory=FunctionType.empty)
func_sig: PolyFuncType = Field(default_factory=FunctionType.empty)
type_args: list[tys.TypeArg] = Field(default_factory=list)
instantiation: FunctionType = Field(default_factory=FunctionType.empty)

def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
# The constE edge comes after the value inputs
fun_ty = in_types[-1]
assert isinstance(fun_ty, PolyFuncType)
poly_func = cast(PolyFuncType, fun_ty)
Expand Down Expand Up @@ -495,6 +499,12 @@ class AliasDecl(BaseOp):
bound: TypeBound


class AliasDefn(BaseOp):
op: Literal["AliasDefn"] = "AliasDefn"
name: str
definition: Type


class OpType(RootModel):
"""A constant operation."""

Expand Down Expand Up @@ -523,6 +533,7 @@ class OpType(RootModel):
| Lift
| DFG
| AliasDecl
| AliasDefn
) = Field(discriminator="op")


Expand All @@ -547,10 +558,12 @@ class OpDef(BaseOp, populate_by_name=True):

# Now that all classes are defined, we need to update the ForwardRefs in all type
# annotations. We use some inspect magic to find all classes defined in this file.
classes = inspect.getmembers(
sys.modules[__name__],
lambda member: inspect.isclass(member) and member.__module__ == __name__,
classes = (
inspect.getmembers(
sys.modules[__name__],
lambda member: inspect.isclass(member) and member.__module__ == __name__,
)
+ tys_classes
)
for _, c in classes:
if issubclass(c, BaseModel):
c.model_rebuild()

tys_model_rebuild(dict(classes))
11 changes: 9 additions & 2 deletions hugr-py/src/hugr/serialization/serial_hugr.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Any, Literal

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, ConfigDict

from .ops import NodeID, OpType
from .ops import NodeID, OpType, classes as ops_classes
from .tys import model_rebuild
import hugr

Port = tuple[NodeID, int | None] # (node, offset)
Expand Down Expand Up @@ -34,6 +35,12 @@ def get_version(cls) -> str:
"""Return the version of the schema."""
return cls(nodes=[], edges=[]).version

@classmethod
def _pydantic_rebuild(cls, config: ConfigDict = ConfigDict(), **kwargs):
my_classes = dict(ops_classes)
my_classes[cls.__name__] = cls
model_rebuild(my_classes, config=config, **kwargs)

class Config:
title = "Hugr"
json_schema_extra = {
Expand Down
29 changes: 18 additions & 11 deletions hugr-py/src/hugr/serialization/testing_hugr.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,30 @@
from typing import Literal, Optional
from pydantic import BaseModel
from .tys import Type, SumType, PolyFuncType
from .ops import Value
from pydantic import ConfigDict
from typing import Literal
from .tys import Type, SumType, PolyFuncType, ConfiguredBaseModel, model_rebuild
from .ops import Value, OpType, classes as ops_classes


class TestingHugr(BaseModel):
"""A serializable representation of a Hugr Type, SumType, PolyFuncType, or
Value. Intended for testing only."""
class TestingHugr(ConfiguredBaseModel):
"""A serializable representation of a Hugr Type, SumType, PolyFuncType,
Value, OpType. Intended for testing only."""

version: Literal["v1"] = "v1"
typ: Optional[Type] = None
sum_type: Optional[SumType] = None
poly_func_type: Optional[PolyFuncType] = None
value: Optional[Value] = None
typ: Type | None = None
sum_type: SumType | None = None
poly_func_type: PolyFuncType | None = None
value: Value | None = None
optype: OpType | None = None

@classmethod
def get_version(cls) -> str:
"""Return the version of the schema."""
return cls().version

@classmethod
def _pydantic_rebuild(cls, config: ConfigDict = ConfigDict(), **kwargs):
my_classes = dict(ops_classes)
my_classes[cls.__name__] = cls
model_rebuild(my_classes, config=config, **kwargs)

class Config:
title = "HugrTesting"
Loading

0 comments on commit 954b2cb

Please sign in to comment.