From 0c3fd1e3bf541e336529526aa022858d03b11b9c Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 11 Oct 2024 14:27:04 +0100 Subject: [PATCH 1/3] feat!: define wrappers around package that point to internals Closes #1561 So far just top level things, can add inner ones (cfg, blocks, etc.) later. BREAKING CHANGE: `Package` moved to new `hugr.package` module --- hugr-py/src/hugr/ext.py | 26 +----- hugr-py/src/hugr/package.py | 130 ++++++++++++++++++++++++++++++ hugr-py/tests/conftest.py | 3 +- hugr-py/tests/test_cond_loop.py | 2 +- hugr-py/tests/test_custom.py | 3 +- hugr-py/tests/test_package.py | 45 +++++++++++ hugr-py/tests/test_tracked_dfg.py | 2 +- 7 files changed, 182 insertions(+), 29 deletions(-) create mode 100644 hugr-py/src/hugr/package.py create mode 100644 hugr-py/tests/test_package.py diff --git a/hugr-py/src/hugr/ext.py b/hugr-py/src/hugr/ext.py index a2205c974..6bed102d6 100644 --- a/hugr-py/src/hugr/ext.py +++ b/hugr-py/src/hugr/ext.py @@ -1,4 +1,4 @@ -"""HUGR extensions and packages.""" +"""HUGR extensions.""" from __future__ import annotations @@ -20,7 +20,6 @@ "OpDef", "ExtensionValue", "Extension", - "Package", "Version", ] @@ -456,26 +455,3 @@ def get_extension(self, name: ExtensionId) -> Extension: return self.extensions[name] except KeyError as e: raise self.ExtensionNotFound(name) from e - - -@dataclass -class Package: - """A package of HUGR modules and extensions. - - - The HUGRs may refer to the included extensions or those not included. - """ - - #: HUGR modules in the package. - modules: list[Hugr] - #: Extensions included in the package. - extensions: list[Extension] = field(default_factory=list) - - def _to_serial(self) -> ext_s.Package: - return ext_s.Package( - modules=[m._to_serial() for m in self.modules], - extensions=[e._to_serial() for e in self.extensions], - ) - - def to_json(self) -> str: - return self._to_serial().model_dump_json() diff --git a/hugr-py/src/hugr/package.py b/hugr-py/src/hugr/package.py new file mode 100644 index 000000000..f9667f7ab --- /dev/null +++ b/hugr-py/src/hugr/package.py @@ -0,0 +1,130 @@ +"""HUGR package and pointed package interfaces.""" + +from dataclasses import dataclass, field +from typing import Generic, TypeVar, cast + +import hugr._serialization.extension as ext_s +from hugr.ext import Extension +from hugr.hugr.base import Hugr +from hugr.hugr.node_port import Node +from hugr.ops import FuncDecl, FuncDefn, Op + +__all__ = [ + "Package", + "PackagePointer", + "ModulePointer", + "ExtensionPointer", + "NodePointer", + "FuncDeclPointer", + "FuncDefnPointer", +] + + +@dataclass +class Package: + """A package of HUGR modules and extensions. + + + The HUGRs may refer to the included extensions or those not included. + """ + + #: HUGR modules in the package. + modules: list[Hugr] + #: Extensions included in the package. + extensions: list[Extension] = field(default_factory=list) + + def _to_serial(self) -> ext_s.Package: + return ext_s.Package( + modules=[m._to_serial() for m in self.modules], + extensions=[e._to_serial() for e in self.extensions], + ) + + def to_json(self) -> str: + return self._to_serial().model_dump_json() + + +@dataclass +class PackagePointer: + """Classes that point to packages and their inner contents.""" + + package: Package + + def get_package(self) -> Package: + """Get the package pointed to.""" + return self.package + + +@dataclass +class ModulePointer(PackagePointer): + """Pointer to a module in a package.""" + + module_index: int + + def module(self) -> Hugr: + """Hugr definition of the module.""" + return self.package.modules[self.module_index] + + def to_executable_package(self) -> "ExecutablePackage": + """Create an executable package from a module containing a main function. + + Raises: + StopIteration: If the module does not contain a main function. + """ + module = self.module() + main_node = next( + n + for n in module.children() + if isinstance((f_def := module[n].op), FuncDefn) and f_def.f_name == "main" + ) + + return ExecutablePackage(self.package, self.module_index, main_node) + + +@dataclass +class ExtensionPointer(PackagePointer): + """Pointer to an extension in a package.""" + + extension_index: int + + def extension(self) -> Extension: + """Extension definition.""" + return self.package.extensions[self.extension_index] + + +OpType = TypeVar("OpType", bound=Op) + + +@dataclass +class NodePointer(Generic[OpType], ModulePointer): + """Pointer to a node in a module.""" + + node: Node + + def node_op(self) -> OpType: + """Get the operation of the node.""" + return cast(OpType, self.module()[self.node].op) + + +@dataclass +class FuncDeclPointer(NodePointer[FuncDecl]): + """Pointer to a function declaration in a module.""" + + def func_decl(self) -> FuncDecl: + """Function declaration.""" + return self.node_op() + + +@dataclass +class FuncDefnPointer(NodePointer[FuncDefn]): + """Pointer to a function definition in a module.""" + + def func_defn(self) -> FuncDefn: + """Function definition.""" + return self.node_op() + + +@dataclass +class ExecutablePackage(FuncDefnPointer): + def entry_point_node(self) -> Node: + """Get the entry point node of the package.""" + return self.node diff --git a/hugr-py/tests/conftest.py b/hugr-py/tests/conftest.py index 914546dce..c2496fea9 100644 --- a/hugr-py/tests/conftest.py +++ b/hugr-py/tests/conftest.py @@ -20,6 +20,7 @@ from syrupy.assertion import SnapshotAssertion from hugr.ops import ComWire + from hugr.package import Package QUANTUM_EXT = ext.Extension("pytest.quantum,", ext.Version(0, 1, 0)) QUANTUM_EXT.add_op_def( @@ -133,7 +134,7 @@ def mermaid(h: Hugr): def validate( - h: Hugr | ext.Package, + h: Hugr | Package, *, roundtrip: bool = True, snap: SnapshotAssertion | None = None, diff --git a/hugr-py/tests/test_cond_loop.py b/hugr-py/tests/test_cond_loop.py index b49ea69c6..73fd55381 100644 --- a/hugr-py/tests/test_cond_loop.py +++ b/hugr-py/tests/test_cond_loop.py @@ -3,7 +3,7 @@ from hugr import ops, tys, val from hugr.build.cond_loop import Conditional, ConditionalError, TailLoop from hugr.build.dfg import Dfg -from hugr.ext import Package +from hugr.package import Package from hugr.std.int import INT_T, IntVal from .conftest import QUANTUM_EXT, H, Measure, validate diff --git a/hugr-py/tests/test_custom.py b/hugr-py/tests/test_custom.py index dcaaee8bc..48f57de7a 100644 --- a/hugr-py/tests/test_custom.py +++ b/hugr-py/tests/test_custom.py @@ -6,6 +6,7 @@ from hugr.build.dfg import Dfg from hugr.hugr import Hugr, Node from hugr.ops import AsExtOp, Custom, ExtOp +from hugr.package import Package from hugr.std.float import FLOAT_T from hugr.std.float import FLOAT_TYPES_EXTENSION as FLOAT_EXT from hugr.std.int import INT_OPS_EXTENSION, INT_TYPES_EXTENSION, DivMod, int_t @@ -56,7 +57,7 @@ def test_stringly_typed(): n = dfg.add(StringlyOp("world")()) dfg.set_outputs() assert dfg.hugr[n].op == StringlyOp("world") - validate(ext.Package([dfg.hugr], [STRINGLY_EXT])) + validate(Package([dfg.hugr], [STRINGLY_EXT])) new_h = Hugr._from_serial(dfg.hugr._to_serial()) diff --git a/hugr-py/tests/test_package.py b/hugr-py/tests/test_package.py new file mode 100644 index 000000000..0ce9717aa --- /dev/null +++ b/hugr-py/tests/test_package.py @@ -0,0 +1,45 @@ +from hugr import tys +from hugr.build.function import Module +from hugr.package import ( + FuncDeclPointer, + FuncDefnPointer, + ModulePointer, + Package, + PackagePointer, +) + +from .conftest import validate + + +def test_package(): + mod = Module() + f_id = mod.define_function("id", [tys.Qubit]) + f_id.set_outputs(f_id.input_node[0]) + + mod2 = Module() + f_id_decl = mod2.declare_function( + "id", tys.PolyFuncType([], tys.FunctionType([tys.Qubit], [tys.Qubit])) + ) + f_main = mod2.define_main([tys.Qubit]) + q = f_main.input_node[0] + call = f_main.call(f_id_decl, q) + f_main.set_outputs(call) + + package = Package([mod.hugr, mod2.hugr]) + validate(package) + + p = PackagePointer(package) + assert p.get_package() == package + + m = ModulePointer(package, 1) + assert m.module() == mod2.hugr + + f = FuncDeclPointer(package, 1, f_id_decl) + assert f.func_decl() == mod2.hugr[f_id_decl].op + + f = FuncDefnPointer(package, 0, f_id.to_node()) + + assert f.func_defn() == mod.hugr[f_id.to_node()].op + + main = m.to_executable_package() + assert main.entry_point_node() == f_main.to_node() diff --git a/hugr-py/tests/test_tracked_dfg.py b/hugr-py/tests/test_tracked_dfg.py index cac9ec0ee..5157e527c 100644 --- a/hugr-py/tests/test_tracked_dfg.py +++ b/hugr-py/tests/test_tracked_dfg.py @@ -2,7 +2,7 @@ from hugr import tys from hugr.build.tracked_dfg import TrackedDfg -from hugr.ext import Package +from hugr.package import Package from hugr.std.float import FLOAT_T, FloatVal from hugr.std.logic import Not From cd49613b63485617ce66fe2517b7e38bec8e8761 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 11 Oct 2024 15:12:16 +0100 Subject: [PATCH 2/3] use properties --- hugr-py/src/hugr/package.py | 18 ++++++++++-------- hugr-py/tests/test_package.py | 10 +++++----- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/hugr-py/src/hugr/package.py b/hugr-py/src/hugr/package.py index f9667f7ab..bea5d8a69 100644 --- a/hugr-py/src/hugr/package.py +++ b/hugr-py/src/hugr/package.py @@ -49,10 +49,6 @@ class PackagePointer: package: Package - def get_package(self) -> Package: - """Get the package pointed to.""" - return self.package - @dataclass class ModulePointer(PackagePointer): @@ -60,6 +56,7 @@ class ModulePointer(PackagePointer): module_index: int + @property def module(self) -> Hugr: """Hugr definition of the module.""" return self.package.modules[self.module_index] @@ -70,7 +67,7 @@ def to_executable_package(self) -> "ExecutablePackage": Raises: StopIteration: If the module does not contain a main function. """ - module = self.module() + module = self.module main_node = next( n for n in module.children() @@ -86,6 +83,7 @@ class ExtensionPointer(PackagePointer): extension_index: int + @property def extension(self) -> Extension: """Extension definition.""" return self.package.extensions[self.extension_index] @@ -100,31 +98,35 @@ class NodePointer(Generic[OpType], ModulePointer): node: Node + @property def node_op(self) -> OpType: """Get the operation of the node.""" - return cast(OpType, self.module()[self.node].op) + return cast(OpType, self.module[self.node].op) @dataclass class FuncDeclPointer(NodePointer[FuncDecl]): """Pointer to a function declaration in a module.""" + @property def func_decl(self) -> FuncDecl: """Function declaration.""" - return self.node_op() + return self.node_op @dataclass class FuncDefnPointer(NodePointer[FuncDefn]): """Pointer to a function definition in a module.""" + @property def func_defn(self) -> FuncDefn: """Function definition.""" - return self.node_op() + return self.node_op @dataclass class ExecutablePackage(FuncDefnPointer): + @property def entry_point_node(self) -> Node: """Get the entry point node of the package.""" return self.node diff --git a/hugr-py/tests/test_package.py b/hugr-py/tests/test_package.py index 0ce9717aa..aa038eaf1 100644 --- a/hugr-py/tests/test_package.py +++ b/hugr-py/tests/test_package.py @@ -29,17 +29,17 @@ def test_package(): validate(package) p = PackagePointer(package) - assert p.get_package() == package + assert p.package == package m = ModulePointer(package, 1) - assert m.module() == mod2.hugr + assert m.module == mod2.hugr f = FuncDeclPointer(package, 1, f_id_decl) - assert f.func_decl() == mod2.hugr[f_id_decl].op + assert f.func_decl == mod2.hugr[f_id_decl].op f = FuncDefnPointer(package, 0, f_id.to_node()) - assert f.func_defn() == mod.hugr[f_id.to_node()].op + assert f.func_defn == mod.hugr[f_id.to_node()].op main = m.to_executable_package() - assert main.entry_point_node() == f_main.to_node() + assert main.entry_point_node == f_main.to_node() From 477a365e6ace81f53a8db75fccd85fd89ba6959e Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 11 Oct 2024 15:22:17 +0100 Subject: [PATCH 3/3] review changes --- hugr-py/src/hugr/package.py | 98 +++++++++++++++++++++++++++---------- 1 file changed, 73 insertions(+), 25 deletions(-) diff --git a/hugr-py/src/hugr/package.py b/hugr-py/src/hugr/package.py index bea5d8a69..f7c5a9fc7 100644 --- a/hugr-py/src/hugr/package.py +++ b/hugr-py/src/hugr/package.py @@ -1,14 +1,18 @@ """HUGR package and pointed package interfaces.""" +from __future__ import annotations + from dataclasses import dataclass, field -from typing import Generic, TypeVar, cast +from typing import TYPE_CHECKING, Generic, TypeVar, cast import hugr._serialization.extension as ext_s -from hugr.ext import Extension -from hugr.hugr.base import Hugr -from hugr.hugr.node_port import Node from hugr.ops import FuncDecl, FuncDefn, Op +if TYPE_CHECKING: + from hugr.ext import Extension + from hugr.hugr.base import Hugr + from hugr.hugr.node_port import Node + __all__ = [ "Package", "PackagePointer", @@ -20,7 +24,7 @@ ] -@dataclass +@dataclass(frozen=True) class Package: """A package of HUGR modules and extensions. @@ -43,17 +47,24 @@ def to_json(self) -> str: return self._to_serial().model_dump_json() -@dataclass +@dataclass(frozen=True) class PackagePointer: """Classes that point to packages and their inner contents.""" + #: Package pointed to. package: Package -@dataclass +@dataclass(frozen=True) class ModulePointer(PackagePointer): - """Pointer to a module in a package.""" + """Pointer to a module in a package. + Args: + package: Package pointed to. + module_index: Index of the module in the package. + """ + + #: Index of the module in the package. module_index: int @property @@ -61,26 +72,36 @@ def module(self) -> Hugr: """Hugr definition of the module.""" return self.package.modules[self.module_index] - def to_executable_package(self) -> "ExecutablePackage": + def to_executable_package(self) -> ExecutablePackage: """Create an executable package from a module containing a main function. Raises: - StopIteration: If the module does not contain a main function. + ValueError: If the module does not contain a main function. """ module = self.module - main_node = next( - n - for n in module.children() - if isinstance((f_def := module[n].op), FuncDefn) and f_def.f_name == "main" - ) - + try: + main_node = next( + n + for n in module.children() + if isinstance((f_def := module[n].op), FuncDefn) + and f_def.f_name == "main" + ) + except StopIteration as e: + msg = "Module does not contain a main function" + raise ValueError(msg) from e return ExecutablePackage(self.package, self.module_index, main_node) -@dataclass +@dataclass(frozen=True) class ExtensionPointer(PackagePointer): - """Pointer to an extension in a package.""" + """Pointer to an extension in a package. + + Args: + package: Package pointed to. + extension_index: Index of the extension in the package. + """ + #: Index of the extension in the package. extension_index: int @property @@ -92,10 +113,17 @@ def extension(self) -> Extension: OpType = TypeVar("OpType", bound=Op) -@dataclass +@dataclass(frozen=True) class NodePointer(Generic[OpType], ModulePointer): - """Pointer to a node in a module.""" + """Pointer to a node in a module. + Args: + package: Package pointed to. + module_index: Index of the module in the package. + node: Node pointed to + """ + + #: Node pointed to. node: Node @property @@ -104,9 +132,15 @@ def node_op(self) -> OpType: return cast(OpType, self.module[self.node].op) -@dataclass +@dataclass(frozen=True) class FuncDeclPointer(NodePointer[FuncDecl]): - """Pointer to a function declaration in a module.""" + """Pointer to a function declaration in a module. + + Args: + package: Package pointed to. + module_index: Index of the module in the package. + node: Node containing the function declaration. + """ @property def func_decl(self) -> FuncDecl: @@ -114,9 +148,15 @@ def func_decl(self) -> FuncDecl: return self.node_op -@dataclass +@dataclass(frozen=True) class FuncDefnPointer(NodePointer[FuncDefn]): - """Pointer to a function definition in a module.""" + """Pointer to a function definition in a module. + + Args: + package: Package pointed to. + module_index: Index of the module in the package. + node: Node containing the function definition + """ @property def func_defn(self) -> FuncDefn: @@ -124,8 +164,16 @@ def func_defn(self) -> FuncDefn: return self.node_op -@dataclass +@dataclass(frozen=True) class ExecutablePackage(FuncDefnPointer): + """PackagePointer with a defined entrypoint node. + + Args: + package: Package pointed to. + module_index: Index of the module in the package. + node: Node containing the entry point function definition. + """ + @property def entry_point_node(self) -> Node: """Get the entry point node of the package."""