diff --git a/hugr-py/src/hugr/ext.py b/hugr-py/src/hugr/ext.py index a2205c974..6bed102d6 100644 --- a/hugr-py/src/hugr/ext.py +++ b/hugr-py/src/hugr/ext.py @@ -1,4 +1,4 @@ -"""HUGR extensions and packages.""" +"""HUGR extensions.""" from __future__ import annotations @@ -20,7 +20,6 @@ "OpDef", "ExtensionValue", "Extension", - "Package", "Version", ] @@ -456,26 +455,3 @@ def get_extension(self, name: ExtensionId) -> Extension: return self.extensions[name] except KeyError as e: raise self.ExtensionNotFound(name) from e - - -@dataclass -class Package: - """A package of HUGR modules and extensions. - - - The HUGRs may refer to the included extensions or those not included. - """ - - #: HUGR modules in the package. - modules: list[Hugr] - #: Extensions included in the package. - extensions: list[Extension] = field(default_factory=list) - - def _to_serial(self) -> ext_s.Package: - return ext_s.Package( - modules=[m._to_serial() for m in self.modules], - extensions=[e._to_serial() for e in self.extensions], - ) - - def to_json(self) -> str: - return self._to_serial().model_dump_json() diff --git a/hugr-py/src/hugr/package.py b/hugr-py/src/hugr/package.py new file mode 100644 index 000000000..f7c5a9fc7 --- /dev/null +++ b/hugr-py/src/hugr/package.py @@ -0,0 +1,180 @@ +"""HUGR package and pointed package interfaces.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Generic, TypeVar, cast + +import hugr._serialization.extension as ext_s +from hugr.ops import FuncDecl, FuncDefn, Op + +if TYPE_CHECKING: + from hugr.ext import Extension + from hugr.hugr.base import Hugr + from hugr.hugr.node_port import Node + +__all__ = [ + "Package", + "PackagePointer", + "ModulePointer", + "ExtensionPointer", + "NodePointer", + "FuncDeclPointer", + "FuncDefnPointer", +] + + +@dataclass(frozen=True) +class Package: + """A package of HUGR modules and extensions. + + + The HUGRs may refer to the included extensions or those not included. + """ + + #: HUGR modules in the package. + modules: list[Hugr] + #: Extensions included in the package. + extensions: list[Extension] = field(default_factory=list) + + def _to_serial(self) -> ext_s.Package: + return ext_s.Package( + modules=[m._to_serial() for m in self.modules], + extensions=[e._to_serial() for e in self.extensions], + ) + + def to_json(self) -> str: + return self._to_serial().model_dump_json() + + +@dataclass(frozen=True) +class PackagePointer: + """Classes that point to packages and their inner contents.""" + + #: Package pointed to. + package: Package + + +@dataclass(frozen=True) +class ModulePointer(PackagePointer): + """Pointer to a module in a package. + + Args: + package: Package pointed to. + module_index: Index of the module in the package. + """ + + #: Index of the module in the package. + module_index: int + + @property + def module(self) -> Hugr: + """Hugr definition of the module.""" + return self.package.modules[self.module_index] + + def to_executable_package(self) -> ExecutablePackage: + """Create an executable package from a module containing a main function. + + Raises: + ValueError: If the module does not contain a main function. + """ + module = self.module + try: + main_node = next( + n + for n in module.children() + if isinstance((f_def := module[n].op), FuncDefn) + and f_def.f_name == "main" + ) + except StopIteration as e: + msg = "Module does not contain a main function" + raise ValueError(msg) from e + return ExecutablePackage(self.package, self.module_index, main_node) + + +@dataclass(frozen=True) +class ExtensionPointer(PackagePointer): + """Pointer to an extension in a package. + + Args: + package: Package pointed to. + extension_index: Index of the extension in the package. + """ + + #: Index of the extension in the package. + extension_index: int + + @property + def extension(self) -> Extension: + """Extension definition.""" + return self.package.extensions[self.extension_index] + + +OpType = TypeVar("OpType", bound=Op) + + +@dataclass(frozen=True) +class NodePointer(Generic[OpType], ModulePointer): + """Pointer to a node in a module. + + Args: + package: Package pointed to. + module_index: Index of the module in the package. + node: Node pointed to + """ + + #: Node pointed to. + node: Node + + @property + def node_op(self) -> OpType: + """Get the operation of the node.""" + return cast(OpType, self.module[self.node].op) + + +@dataclass(frozen=True) +class FuncDeclPointer(NodePointer[FuncDecl]): + """Pointer to a function declaration in a module. + + Args: + package: Package pointed to. + module_index: Index of the module in the package. + node: Node containing the function declaration. + """ + + @property + def func_decl(self) -> FuncDecl: + """Function declaration.""" + return self.node_op + + +@dataclass(frozen=True) +class FuncDefnPointer(NodePointer[FuncDefn]): + """Pointer to a function definition in a module. + + Args: + package: Package pointed to. + module_index: Index of the module in the package. + node: Node containing the function definition + """ + + @property + def func_defn(self) -> FuncDefn: + """Function definition.""" + return self.node_op + + +@dataclass(frozen=True) +class ExecutablePackage(FuncDefnPointer): + """PackagePointer with a defined entrypoint node. + + Args: + package: Package pointed to. + module_index: Index of the module in the package. + node: Node containing the entry point function definition. + """ + + @property + def entry_point_node(self) -> Node: + """Get the entry point node of the package.""" + return self.node diff --git a/hugr-py/tests/conftest.py b/hugr-py/tests/conftest.py index 914546dce..c2496fea9 100644 --- a/hugr-py/tests/conftest.py +++ b/hugr-py/tests/conftest.py @@ -20,6 +20,7 @@ from syrupy.assertion import SnapshotAssertion from hugr.ops import ComWire + from hugr.package import Package QUANTUM_EXT = ext.Extension("pytest.quantum,", ext.Version(0, 1, 0)) QUANTUM_EXT.add_op_def( @@ -133,7 +134,7 @@ def mermaid(h: Hugr): def validate( - h: Hugr | ext.Package, + h: Hugr | Package, *, roundtrip: bool = True, snap: SnapshotAssertion | None = None, diff --git a/hugr-py/tests/test_cond_loop.py b/hugr-py/tests/test_cond_loop.py index b49ea69c6..73fd55381 100644 --- a/hugr-py/tests/test_cond_loop.py +++ b/hugr-py/tests/test_cond_loop.py @@ -3,7 +3,7 @@ from hugr import ops, tys, val from hugr.build.cond_loop import Conditional, ConditionalError, TailLoop from hugr.build.dfg import Dfg -from hugr.ext import Package +from hugr.package import Package from hugr.std.int import INT_T, IntVal from .conftest import QUANTUM_EXT, H, Measure, validate diff --git a/hugr-py/tests/test_custom.py b/hugr-py/tests/test_custom.py index dcaaee8bc..48f57de7a 100644 --- a/hugr-py/tests/test_custom.py +++ b/hugr-py/tests/test_custom.py @@ -6,6 +6,7 @@ from hugr.build.dfg import Dfg from hugr.hugr import Hugr, Node from hugr.ops import AsExtOp, Custom, ExtOp +from hugr.package import Package from hugr.std.float import FLOAT_T from hugr.std.float import FLOAT_TYPES_EXTENSION as FLOAT_EXT from hugr.std.int import INT_OPS_EXTENSION, INT_TYPES_EXTENSION, DivMod, int_t @@ -56,7 +57,7 @@ def test_stringly_typed(): n = dfg.add(StringlyOp("world")()) dfg.set_outputs() assert dfg.hugr[n].op == StringlyOp("world") - validate(ext.Package([dfg.hugr], [STRINGLY_EXT])) + validate(Package([dfg.hugr], [STRINGLY_EXT])) new_h = Hugr._from_serial(dfg.hugr._to_serial()) diff --git a/hugr-py/tests/test_package.py b/hugr-py/tests/test_package.py new file mode 100644 index 000000000..aa038eaf1 --- /dev/null +++ b/hugr-py/tests/test_package.py @@ -0,0 +1,45 @@ +from hugr import tys +from hugr.build.function import Module +from hugr.package import ( + FuncDeclPointer, + FuncDefnPointer, + ModulePointer, + Package, + PackagePointer, +) + +from .conftest import validate + + +def test_package(): + mod = Module() + f_id = mod.define_function("id", [tys.Qubit]) + f_id.set_outputs(f_id.input_node[0]) + + mod2 = Module() + f_id_decl = mod2.declare_function( + "id", tys.PolyFuncType([], tys.FunctionType([tys.Qubit], [tys.Qubit])) + ) + f_main = mod2.define_main([tys.Qubit]) + q = f_main.input_node[0] + call = f_main.call(f_id_decl, q) + f_main.set_outputs(call) + + package = Package([mod.hugr, mod2.hugr]) + validate(package) + + p = PackagePointer(package) + assert p.package == package + + m = ModulePointer(package, 1) + assert m.module == mod2.hugr + + f = FuncDeclPointer(package, 1, f_id_decl) + assert f.func_decl == mod2.hugr[f_id_decl].op + + f = FuncDefnPointer(package, 0, f_id.to_node()) + + assert f.func_defn == mod.hugr[f_id.to_node()].op + + main = m.to_executable_package() + assert main.entry_point_node == f_main.to_node() diff --git a/hugr-py/tests/test_tracked_dfg.py b/hugr-py/tests/test_tracked_dfg.py index cac9ec0ee..5157e527c 100644 --- a/hugr-py/tests/test_tracked_dfg.py +++ b/hugr-py/tests/test_tracked_dfg.py @@ -2,7 +2,7 @@ from hugr import tys from hugr.build.tracked_dfg import TrackedDfg -from hugr.ext import Package +from hugr.package import Package from hugr.std.float import FLOAT_T, FloatVal from hugr.std.logic import Not