Skip to content

Commit

Permalink
feat: to/from json for extension/package (#1575)
Browse files Browse the repository at this point in the history
Closes #1523
  • Loading branch information
ss2165 authored Oct 11, 2024
1 parent 051de71 commit f8bf61a
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 2 deletions.
12 changes: 11 additions & 1 deletion hugr-py/src/hugr/_serialization/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pydantic as pd
from pydantic_extra_types.semantic_version import SemanticVersion # noqa: TCH002

from hugr.hugr.base import Hugr
from hugr.utils import deser_it

from .ops import Value
Expand Down Expand Up @@ -156,5 +157,14 @@ class Package(ConfiguredBaseModel):
def get_version(cls) -> str:
return serialization_version()

def deserialize(self) -> package.Package:
return package.Package(
modules=[Hugr._from_serial(m) for m in self.modules],
extensions=[e.deserialize() for e in self.extensions],
)


from hugr import ext # noqa: E402
from hugr import ( # noqa: E402
ext,
package,
)
16 changes: 16 additions & 0 deletions hugr-py/src/hugr/ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,22 @@ def _to_serial(self) -> ext_s.Extension:
operations={k: v._to_serial() for k, v in self.operations.items()},
)

def to_json(self) -> str:
"""Serialize the extension to a JSON string."""
return self._to_serial().model_dump_json()

@classmethod
def from_json(cls, json_str: str) -> Extension:
"""Deserialize a JSON string to a Extension object.
Args:
json_str: The JSON string representing a Extension.
Returns:
The deserialized Extension object.
"""
return ext_s.Extension.model_validate_json(json_str).deserialize()

def add_op_def(self, op_def: OpDef) -> OpDef:
"""Add an operation definition to the extension.
Expand Down
12 changes: 12 additions & 0 deletions hugr-py/src/hugr/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,18 @@ def _to_serial(self) -> ext_s.Package:
def to_json(self) -> str:
return self._to_serial().model_dump_json()

@classmethod
def from_json(cls, json_str: str) -> Package:
"""Deserialize a JSON string to a Package object.
Args:
json_str: The JSON string representing a Package.
Returns:
The deserialized Package object.
"""
return ext_s.Package.model_validate_json(json_str).deserialize()


@dataclass(frozen=True)
class PackagePointer:
Expand Down
10 changes: 9 additions & 1 deletion hugr-py/tests/serialization/test_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
TypeDef,
TypeDefBound,
)
from hugr._serialization.ops import Module, OpType
from hugr._serialization.serial_hugr import SerialHugr, serialization_version
from hugr._serialization.tys import (
FunctionType,
Expand Down Expand Up @@ -109,6 +110,8 @@ def test_extension():
dumped_json = ext.model_dump_json()

assert Extension.model_validate_json(dumped_json) == ext
hugr_ext = ext.deserialize()
assert hugr_ext.from_json(hugr_ext.to_json()) == hugr_ext


def test_package():
Expand All @@ -123,9 +126,14 @@ def test_package():
operations={},
)
ext_load = Extension.model_validate_json(EXAMPLE)

package = Package(
extensions=[ext, ext_load], modules=[SerialHugr(nodes=[], edges=[])]
extensions=[ext, ext_load],
modules=[SerialHugr(nodes=[OpType(root=Module(parent=0))], edges=[])],
)

package_load = Package.model_validate_json(package.model_dump_json())
assert package == package_load

hugr_package = package.deserialize()
assert hugr_package.from_json(hugr_package.to_json()) == hugr_package

0 comments on commit f8bf61a

Please sign in to comment.