Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: define wrappers around package that point to internals #1573

Merged
merged 3 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 1 addition & 25 deletions hugr-py/src/hugr/ext.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""HUGR extensions and packages."""
"""HUGR extensions."""

from __future__ import annotations

Expand All @@ -20,7 +20,6 @@
"OpDef",
"ExtensionValue",
"Extension",
"Package",
"Version",
]

Expand Down Expand Up @@ -456,26 +455,3 @@ def get_extension(self, name: ExtensionId) -> Extension:
return self.extensions[name]
except KeyError as e:
raise self.ExtensionNotFound(name) from e


@dataclass
class Package:
"""A package of HUGR modules and extensions.


The HUGRs may refer to the included extensions or those not included.
"""

#: HUGR modules in the package.
modules: list[Hugr]
#: Extensions included in the package.
extensions: list[Extension] = field(default_factory=list)

def _to_serial(self) -> ext_s.Package:
return ext_s.Package(
modules=[m._to_serial() for m in self.modules],
extensions=[e._to_serial() for e in self.extensions],
)

def to_json(self) -> str:
return self._to_serial().model_dump_json()
180 changes: 180 additions & 0 deletions hugr-py/src/hugr/package.py
aborgna-q marked this conversation as resolved.
Show resolved Hide resolved
aborgna-q marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
"""HUGR package and pointed package interfaces."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Generic, TypeVar, cast

import hugr._serialization.extension as ext_s
from hugr.ops import FuncDecl, FuncDefn, Op

if TYPE_CHECKING:
from hugr.ext import Extension
from hugr.hugr.base import Hugr
from hugr.hugr.node_port import Node

__all__ = [
"Package",
"PackagePointer",
"ModulePointer",
"ExtensionPointer",
"NodePointer",
"FuncDeclPointer",
"FuncDefnPointer",
]


@dataclass(frozen=True)
class Package:
"""A package of HUGR modules and extensions.


The HUGRs may refer to the included extensions or those not included.
"""

#: HUGR modules in the package.
modules: list[Hugr]
#: Extensions included in the package.
extensions: list[Extension] = field(default_factory=list)

def _to_serial(self) -> ext_s.Package:
return ext_s.Package(
modules=[m._to_serial() for m in self.modules],
extensions=[e._to_serial() for e in self.extensions],
)

def to_json(self) -> str:
return self._to_serial().model_dump_json()


@dataclass(frozen=True)
class PackagePointer:
"""Classes that point to packages and their inner contents."""

#: Package pointed to.
package: Package


@dataclass(frozen=True)
class ModulePointer(PackagePointer):
"""Pointer to a module in a package.

Args:
package: Package pointed to.
module_index: Index of the module in the package.
"""

#: Index of the module in the package.
module_index: int

@property
def module(self) -> Hugr:
"""Hugr definition of the module."""
return self.package.modules[self.module_index]
Comment on lines +71 to +73
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be a property (same for all other methods in this PR)


def to_executable_package(self) -> ExecutablePackage:
"""Create an executable package from a module containing a main function.

Raises:
ValueError: If the module does not contain a main function.
"""
module = self.module
try:
main_node = next(
n
for n in module.children()
if isinstance((f_def := module[n].op), FuncDefn)
and f_def.f_name == "main"
)
except StopIteration as e:
msg = "Module does not contain a main function"
raise ValueError(msg) from e

Check warning on line 91 in hugr-py/src/hugr/package.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/package.py#L89-L91

Added lines #L89 - L91 were not covered by tests
return ExecutablePackage(self.package, self.module_index, main_node)


@dataclass(frozen=True)
class ExtensionPointer(PackagePointer):
"""Pointer to an extension in a package.

Args:
package: Package pointed to.
extension_index: Index of the extension in the package.
"""

#: Index of the extension in the package.
extension_index: int

@property
def extension(self) -> Extension:
"""Extension definition."""
return self.package.extensions[self.extension_index]

Check warning on line 110 in hugr-py/src/hugr/package.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/package.py#L110

Added line #L110 was not covered by tests


OpType = TypeVar("OpType", bound=Op)


@dataclass(frozen=True)
class NodePointer(Generic[OpType], ModulePointer):
"""Pointer to a node in a module.

Args:
package: Package pointed to.
module_index: Index of the module in the package.
node: Node pointed to
"""

#: Node pointed to.
node: Node

@property
def node_op(self) -> OpType:
"""Get the operation of the node."""
return cast(OpType, self.module[self.node].op)


@dataclass(frozen=True)
class FuncDeclPointer(NodePointer[FuncDecl]):
"""Pointer to a function declaration in a module.

Args:
package: Package pointed to.
module_index: Index of the module in the package.
node: Node containing the function declaration.
"""

@property
def func_decl(self) -> FuncDecl:
"""Function declaration."""
return self.node_op


@dataclass(frozen=True)
class FuncDefnPointer(NodePointer[FuncDefn]):
"""Pointer to a function definition in a module.

Args:
package: Package pointed to.
module_index: Index of the module in the package.
node: Node containing the function definition
"""

@property
def func_defn(self) -> FuncDefn:
"""Function definition."""
return self.node_op


@dataclass(frozen=True)
class ExecutablePackage(FuncDefnPointer):
ss2165 marked this conversation as resolved.
Show resolved Hide resolved
"""PackagePointer with a defined entrypoint node.

Args:
package: Package pointed to.
module_index: Index of the module in the package.
node: Node containing the entry point function definition.
"""

@property
def entry_point_node(self) -> Node:
"""Get the entry point node of the package."""
return self.node
3 changes: 2 additions & 1 deletion hugr-py/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from syrupy.assertion import SnapshotAssertion

from hugr.ops import ComWire
from hugr.package import Package

QUANTUM_EXT = ext.Extension("pytest.quantum,", ext.Version(0, 1, 0))
QUANTUM_EXT.add_op_def(
Expand Down Expand Up @@ -133,7 +134,7 @@ def mermaid(h: Hugr):


def validate(
h: Hugr | ext.Package,
h: Hugr | Package,
*,
roundtrip: bool = True,
snap: SnapshotAssertion | None = None,
Expand Down
2 changes: 1 addition & 1 deletion hugr-py/tests/test_cond_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from hugr import ops, tys, val
from hugr.build.cond_loop import Conditional, ConditionalError, TailLoop
from hugr.build.dfg import Dfg
from hugr.ext import Package
from hugr.package import Package
from hugr.std.int import INT_T, IntVal

from .conftest import QUANTUM_EXT, H, Measure, validate
Expand Down
3 changes: 2 additions & 1 deletion hugr-py/tests/test_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from hugr.build.dfg import Dfg
from hugr.hugr import Hugr, Node
from hugr.ops import AsExtOp, Custom, ExtOp
from hugr.package import Package
from hugr.std.float import FLOAT_T
from hugr.std.float import FLOAT_TYPES_EXTENSION as FLOAT_EXT
from hugr.std.int import INT_OPS_EXTENSION, INT_TYPES_EXTENSION, DivMod, int_t
Expand Down Expand Up @@ -56,7 +57,7 @@ def test_stringly_typed():
n = dfg.add(StringlyOp("world")())
dfg.set_outputs()
assert dfg.hugr[n].op == StringlyOp("world")
validate(ext.Package([dfg.hugr], [STRINGLY_EXT]))
validate(Package([dfg.hugr], [STRINGLY_EXT]))

new_h = Hugr._from_serial(dfg.hugr._to_serial())

Expand Down
45 changes: 45 additions & 0 deletions hugr-py/tests/test_package.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from hugr import tys
from hugr.build.function import Module
from hugr.package import (
FuncDeclPointer,
FuncDefnPointer,
ModulePointer,
Package,
PackagePointer,
)

from .conftest import validate


def test_package():
mod = Module()
f_id = mod.define_function("id", [tys.Qubit])
f_id.set_outputs(f_id.input_node[0])

mod2 = Module()
f_id_decl = mod2.declare_function(
"id", tys.PolyFuncType([], tys.FunctionType([tys.Qubit], [tys.Qubit]))
)
f_main = mod2.define_main([tys.Qubit])
q = f_main.input_node[0]
call = f_main.call(f_id_decl, q)
f_main.set_outputs(call)

package = Package([mod.hugr, mod2.hugr])
validate(package)

p = PackagePointer(package)
assert p.package == package

m = ModulePointer(package, 1)
assert m.module == mod2.hugr

f = FuncDeclPointer(package, 1, f_id_decl)
assert f.func_decl == mod2.hugr[f_id_decl].op

f = FuncDefnPointer(package, 0, f_id.to_node())

assert f.func_defn == mod.hugr[f_id.to_node()].op

main = m.to_executable_package()
assert main.entry_point_node == f_main.to_node()
2 changes: 1 addition & 1 deletion hugr-py/tests/test_tracked_dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from hugr import tys
from hugr.build.tracked_dfg import TrackedDfg
from hugr.ext import Package
from hugr.package import Package
from hugr.std.float import FLOAT_T, FloatVal
from hugr.std.logic import Not

Expand Down
Loading