diff --git a/hugr-py/src/hugr/_cfg.py b/hugr-py/src/hugr/_cfg.py index a1ed1417c..bde255b32 100644 --- a/hugr-py/src/hugr/_cfg.py +++ b/hugr-py/src/hugr/_cfg.py @@ -6,7 +6,8 @@ from ._dfg import _DfBase from ._exceptions import NoSiblingAncestor, NotInSameCfg, MismatchedExit -from ._hugr import Hugr, Node, ParentBuilder, ToNode, Wire +from ._hugr import Hugr, ParentBuilder +from ._node_port import Node, Wire, ToNode from ._tys import TypeRow, Type import hugr._val as val diff --git a/hugr-py/src/hugr/_cond_loop.py b/hugr-py/src/hugr/_cond_loop.py index ebf7f3f14..6433a0f21 100644 --- a/hugr-py/src/hugr/_cond_loop.py +++ b/hugr-py/src/hugr/_cond_loop.py @@ -5,7 +5,9 @@ import hugr._ops as ops from ._dfg import _DfBase -from ._hugr import Hugr, Node, ParentBuilder, ToNode, Wire +from ._hugr import Hugr, ParentBuilder +from ._node_port import Node, Wire, ToNode + from ._tys import Sum, TypeRow diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index b47c81993..2f3a10eec 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -12,10 +12,11 @@ import hugr._ops as ops import hugr._val as val -from hugr._tys import Type, TypeRow, get_first_sum +from hugr._tys import Type, TypeRow, get_first_sum, FunctionType, TypeArg, FunctionKind from ._exceptions import NoSiblingAncestor -from ._hugr import Hugr, Node, OutPort, ParentBuilder, ToNode, Wire +from ._hugr import Hugr, ParentBuilder +from ._node_port import Node, OutPort, Wire, ToNode if TYPE_CHECKING: from ._cfg import Cfg @@ -164,6 +165,28 @@ def load(self, const: ToNode | val.Value) -> Node: return load + def call( + self, + func: ToNode, + *args: Wire, + instantiation: FunctionType | None = None, + type_args: list[TypeArg] | None = None, + ) -> Node: + f_op = self.hugr[func] + f_kind = f_op.op.port_kind(func.out(0)) + match f_kind: + case FunctionKind(sig): + signature = sig + case _: + raise ValueError("Expected 'func' to be a function") + call_op = ops.Call(signature, instantiation, type_args) + call_n = self.hugr.add_node(call_op, self.parent_node, call_op.num_out) + self.hugr.add_link(func.out(0), call_n.inp(call_op.function_port_offset())) + + self._wire_up(call_n, args) + + return call_n + def _wire_up(self, node: Node, ports: Iterable[Wire]) -> TypeRow: tys = [self._wire_up_port(node, i, p) for i, p in enumerate(ports)] if isinstance(op := self.hugr[node].op, ops.PartialOp): diff --git a/hugr-py/src/hugr/_function.py b/hugr-py/src/hugr/_function.py new file mode 100644 index 000000000..d8ffa578b --- /dev/null +++ b/hugr-py/src/hugr/_function.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from dataclasses import dataclass + +import hugr._ops as ops +import hugr._val as val + +from ._dfg import _DfBase +from hugr._node_port import Node +from ._hugr import Hugr +from ._tys import TypeRow, TypeParam, PolyFuncType + + +@dataclass +class Function(_DfBase[ops.FuncDefn]): + def __init__( + self, + name: str, + input_types: TypeRow, + type_params: list[TypeParam] | None = None, + ) -> None: + root_op = ops.FuncDefn(name, input_types, type_params or []) + super().__init__(root_op) + + +@dataclass +class Module: + hugr: Hugr + + def __init__(self) -> None: + self.hugr = Hugr(ops.Module()) + + def define_function( + self, + name: str, + input_types: TypeRow, + type_params: list[TypeParam] | None = None, + ) -> Function: + parent_op = ops.FuncDefn(name, input_types, type_params or []) + return Function.new_nested(parent_op, self.hugr) + + def define_main(self, input_types: TypeRow) -> Function: + return self.define_function("main", input_types) + + def declare_function(self, name: str, signature: PolyFuncType) -> Node: + return self.hugr.add_node(ops.FuncDecl(name, signature), self.hugr.root) + + def add_const(self, value: val.Value) -> Node: + return self.hugr.add_node(ops.Const(value), self.hugr.root) diff --git a/hugr-py/src/hugr/_hugr.py b/hugr-py/src/hugr/_hugr.py index a13cb3d1d..7ee6831fa 100644 --- a/hugr-py/src/hugr/_hugr.py +++ b/hugr-py/src/hugr/_hugr.py @@ -2,12 +2,9 @@ from collections.abc import Mapping from dataclasses import dataclass, field, replace -from enum import Enum from typing import ( - ClassVar, Generic, Iterable, - Iterator, Protocol, TypeVar, cast, @@ -15,11 +12,11 @@ Type as PyType, ) -from typing_extensions import Self -from hugr._ops import Op, DataflowOp, Const -from hugr._tys import Type, Kind +from hugr._ops import Op, DataflowOp, Const, Call +from hugr._tys import Type, Kind, ValueKind from hugr._val import Value +from hugr._node_port import Direction, InPort, OutPort, ToNode, Node, _SubPort from hugr.serialization.ops import OpType as SerialOp from hugr.serialization.serial_hugr import SerialHugr from hugr.utils import BiMap @@ -27,96 +24,6 @@ from ._exceptions import ParentBeforeChild -class Direction(Enum): - INCOMING = 0 - OUTGOING = 1 - - -@dataclass(frozen=True, eq=True, order=True) -class _Port: - node: Node - offset: int - direction: ClassVar[Direction] - - -@dataclass(frozen=True, eq=True, order=True) -class InPort(_Port): - direction: ClassVar[Direction] = Direction.INCOMING - - -class Wire(Protocol): - def out_port(self) -> OutPort: ... - - -@dataclass(frozen=True, eq=True, order=True) -class OutPort(_Port, Wire): - direction: ClassVar[Direction] = Direction.OUTGOING - - def out_port(self) -> OutPort: - return self - - -class ToNode(Wire, Protocol): - def to_node(self) -> Node: ... - - @overload - def __getitem__(self, index: int) -> OutPort: ... - @overload - def __getitem__(self, index: slice) -> Iterator[OutPort]: ... - @overload - def __getitem__(self, index: tuple[int, ...]) -> Iterator[OutPort]: ... - - def __getitem__( - self, index: int | slice | tuple[int, ...] - ) -> OutPort | Iterator[OutPort]: - return self.to_node()._index(index) - - def out_port(self) -> "OutPort": - return OutPort(self.to_node(), 0) - - def inp(self, offset: int) -> InPort: - return InPort(self.to_node(), offset) - - def out(self, offset: int) -> OutPort: - return OutPort(self.to_node(), offset) - - def port(self, offset: int, direction: Direction) -> InPort | OutPort: - if direction == Direction.INCOMING: - return self.inp(offset) - else: - return self.out(offset) - - -@dataclass(frozen=True, eq=True, order=True) -class Node(ToNode): - idx: int - _num_out_ports: int | None = field(default=None, compare=False) - - def _index( - self, index: int | slice | tuple[int, ...] - ) -> OutPort | Iterator[OutPort]: - match index: - case int(index): - if self._num_out_ports is not None: - if index >= self._num_out_ports: - raise IndexError("Index out of range") - return self.out(index) - case slice(): - start = index.start or 0 - stop = index.stop or self._num_out_ports - if stop is None: - raise ValueError( - "Stop must be specified when number of outputs unknown" - ) - step = index.step or 1 - return (self[i] for i in range(start, stop, step)) - case tuple(xs): - return (self[i] for i in xs) - - def to_node(self) -> Node: - return self - - @dataclass() class NodeData: op: Op @@ -131,25 +38,15 @@ def to_serial(self, node: Node, hugr: Hugr) -> SerialOp: return SerialOp(root=o) # type: ignore[arg-type] +_SO = _SubPort[OutPort] +_SI = _SubPort[InPort] + P = TypeVar("P", InPort, OutPort) K = TypeVar("K", InPort, OutPort) OpVar = TypeVar("OpVar", bound=Op) OpVar2 = TypeVar("OpVar2", bound=Op) -@dataclass(frozen=True, eq=True, order=True) -class _SubPort(Generic[P]): - port: P - sub_offset: int = 0 - - def next_sub_offset(self) -> Self: - return replace(self, sub_offset=self.sub_offset + 1) - - -_SO = _SubPort[OutPort] -_SI = _SubPort[InPort] - - class ParentBuilder(ToNode, Protocol[OpVar]): hugr: Hugr[OpVar] parent_node: Node @@ -360,6 +257,10 @@ def port_type(self, port: InPort | OutPort) -> Type | None: op = self[port.node].op if isinstance(op, DataflowOp): return op.port_type(port) + if isinstance(op, Call) and isinstance(port, OutPort): + kind = self.port_kind(port) + if isinstance(kind, ValueKind): + return kind.ty return None def insert_hugr(self, hugr: Hugr, parent: ToNode | None = None) -> dict[Node, Node]: diff --git a/hugr-py/src/hugr/_node_port.py b/hugr-py/src/hugr/_node_port.py new file mode 100644 index 000000000..23e10291a --- /dev/null +++ b/hugr-py/src/hugr/_node_port.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +from dataclasses import dataclass, field, replace +from enum import Enum +from typing import ( + ClassVar, + Iterator, + Protocol, + overload, + TypeVar, + Generic, +) +from typing_extensions import Self + + +class Direction(Enum): + INCOMING = 0 + OUTGOING = 1 + + +@dataclass(frozen=True, eq=True, order=True) +class _Port: + node: Node + offset: int + direction: ClassVar[Direction] + + +@dataclass(frozen=True, eq=True, order=True) +class InPort(_Port): + direction: ClassVar[Direction] = Direction.INCOMING + + +class Wire(Protocol): + def out_port(self) -> OutPort: ... + + +@dataclass(frozen=True, eq=True, order=True) +class OutPort(_Port, Wire): + direction: ClassVar[Direction] = Direction.OUTGOING + + def out_port(self) -> OutPort: + return self + + +class ToNode(Wire, Protocol): + def to_node(self) -> Node: ... + + @overload + def __getitem__(self, index: int) -> OutPort: ... + @overload + def __getitem__(self, index: slice) -> Iterator[OutPort]: ... + @overload + def __getitem__(self, index: tuple[int, ...]) -> Iterator[OutPort]: ... + + def __getitem__( + self, index: int | slice | tuple[int, ...] + ) -> OutPort | Iterator[OutPort]: + return self.to_node()._index(index) + + def out_port(self) -> "OutPort": + return OutPort(self.to_node(), 0) + + def inp(self, offset: int) -> InPort: + return InPort(self.to_node(), offset) + + def out(self, offset: int) -> OutPort: + return OutPort(self.to_node(), offset) + + def port(self, offset: int, direction: Direction) -> InPort | OutPort: + if direction == Direction.INCOMING: + return self.inp(offset) + else: + return self.out(offset) + + +@dataclass(frozen=True, eq=True, order=True) +class Node(ToNode): + idx: int + _num_out_ports: int | None = field(default=None, compare=False) + + def _index( + self, index: int | slice | tuple[int, ...] + ) -> OutPort | Iterator[OutPort]: + match index: + case int(index): + if self._num_out_ports is not None: + if index >= self._num_out_ports: + raise IndexError("Index out of range") + return self.out(index) + case slice(): + start = index.start or 0 + stop = index.stop or self._num_out_ports + if stop is None: + raise ValueError( + "Stop must be specified when number of outputs unknown" + ) + step = index.step or 1 + return (self[i] for i in range(start, stop, step)) + case tuple(xs): + return (self[i] for i in xs) + + def to_node(self) -> Node: + return self + + +P = TypeVar("P", InPort, OutPort) + + +@dataclass(frozen=True, eq=True, order=True) +class _SubPort(Generic[P]): + port: P + sub_offset: int = 0 + + def next_sub_offset(self) -> Self: + return replace(self, sub_offset=self.sub_offset + 1) diff --git a/hugr-py/src/hugr/_ops.py b/hugr-py/src/hugr/_ops.py index 2241eff1d..5bc63c79a 100644 --- a/hugr-py/src/hugr/_ops.py +++ b/hugr-py/src/hugr/_ops.py @@ -6,11 +6,21 @@ import hugr.serialization.ops as sops from hugr.utils import ser_it import hugr._tys as tys +from hugr._node_port import Node, InPort, OutPort, Wire import hugr._val as val from ._exceptions import IncompleteOp if TYPE_CHECKING: - from hugr._hugr import Hugr, Node, Wire, InPort, OutPort + from hugr._hugr import Hugr + + +@dataclass +class InvalidPort(Exception): + port: InPort | OutPort + + @property + def msg(self) -> str: + return f"Invalid port {self.port}" @runtime_checkable @@ -24,6 +34,14 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> BaseOp: ... def port_kind(self, port: InPort | OutPort) -> tys.Kind: ... +def _sig_port_type(sig: tys.FunctionType, port: InPort | OutPort) -> tys.Type: + from hugr._hugr import Direction + + if port.direction == Direction.INCOMING: + return sig.input[port.offset] + return sig.output[port.offset] + + @runtime_checkable class DataflowOp(Op, Protocol): def outer_signature(self) -> tys.FunctionType: ... @@ -34,12 +52,7 @@ def port_kind(self, port: InPort | OutPort) -> tys.Kind: return tys.ValueKind(self.port_type(port)) def port_type(self, port: InPort | OutPort) -> tys.Type: - from hugr._hugr import Direction - - sig = self.outer_signature() - if port.direction == Direction.INCOMING: - return sig.input[port.offset] - return sig.output[port.offset] + return _sig_port_type(self.outer_signature(), port) def __call__(self, *args) -> Command: return Command(self, list(args)) @@ -358,7 +371,11 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Const: ) def port_kind(self, port: InPort | OutPort) -> tys.Kind: - return tys.ConstKind(self.val.type_()) + match port: + case OutPort(_, 0): + return tys.ConstKind(self.val.type_()) + case _: + raise InvalidPort(port) @dataclass @@ -377,6 +394,15 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.LoadConstant: def outer_signature(self) -> tys.FunctionType: return tys.FunctionType(input=[], output=[self.type_()]) + def port_kind(self, port: InPort | OutPort) -> tys.Kind: + match port: + case InPort(_, 0): + return tys.ConstKind(self.type_()) + case OutPort(_, 0): + return tys.ValueKind(self.type_()) + case _: + raise InvalidPort(port) + @dataclass() class Conditional(DataflowOp): @@ -416,15 +442,12 @@ def nth_inputs(self, n: int) -> tys.TypeRow: class Case(DfParentOp): inputs: tys.TypeRow _outputs: tys.TypeRow | None = None + num_out: int | None = 0 @property def outputs(self) -> tys.TypeRow: return _check_complete(self._outputs) - @property - def num_out(self) -> int | None: - return 0 - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Case: return sops.Case( parent=parent.idx, signature=self.inner_signature().to_serial() @@ -434,7 +457,7 @@ def inner_signature(self) -> tys.FunctionType: return tys.FunctionType(self.inputs, self.outputs) def port_kind(self, port: InPort | OutPort) -> tys.Kind: - raise NotImplementedError("Case nodes have no external ports.") + raise InvalidPort(port) def _set_out_types(self, types: tys.TypeRow) -> None: self._outputs = types @@ -485,3 +508,132 @@ def _set_out_types(self, types: tys.TypeRow) -> None: def _inputs(self) -> tys.TypeRow: return self.just_inputs + self.rest + + +@dataclass +class FuncDefn(DfParentOp): + name: str + inputs: tys.TypeRow + params: list[tys.TypeParam] = field(default_factory=list) + _outputs: tys.TypeRow | None = None + num_out: int | None = 1 + + @property + def outputs(self) -> tys.TypeRow: + return _check_complete(self._outputs) + + @property + def signature(self) -> tys.PolyFuncType: + return tys.PolyFuncType( + self.params, tys.FunctionType(self.inputs, self.outputs) + ) + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.FuncDefn: + return sops.FuncDefn( + parent=parent.idx, + name=self.name, + signature=self.signature.to_serial(), + ) + + def inner_signature(self) -> tys.FunctionType: + return self.signature.body + + def _set_out_types(self, types: tys.TypeRow) -> None: + self._outputs = types + + def _inputs(self) -> tys.TypeRow: + return self.inputs + + def port_kind(self, port: InPort | OutPort) -> tys.Kind: + match port: + case OutPort(_, 0): + return tys.FunctionKind(self.signature) + case _: + raise InvalidPort(port) + + +@dataclass +class FuncDecl(Op): + name: str + signature: tys.PolyFuncType + num_out: int | None = 0 + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.FuncDecl: + return sops.FuncDecl( + parent=parent.idx, + name=self.name, + signature=self.signature.to_serial(), + ) + + def port_kind(self, port: InPort | OutPort) -> tys.Kind: + match port: + case OutPort(_, 0): + return tys.FunctionKind(self.signature) + case _: + raise InvalidPort(port) + + +@dataclass +class Module(Op): + num_out: int | None = 0 + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Module: + return sops.Module(parent=parent.idx) + + def port_kind(self, port: InPort | OutPort) -> tys.Kind: + raise InvalidPort(port) + + +class NoConcreteFunc(Exception): + pass + + +@dataclass +class Call(Op): + signature: tys.PolyFuncType + instantiation: tys.FunctionType + type_args: list[tys.TypeArg] + + def __init__( + self, + signature: tys.PolyFuncType, + instantiation: tys.FunctionType | None = None, + type_args: list[tys.TypeArg] | None = None, + ) -> None: + self.signature = signature + if len(signature.params) == 0: + self.instantiation = signature.body + self.type_args = [] + + else: + # TODO substitute type args into signature to get instantiation + if instantiation is None: + raise NoConcreteFunc("Missing instantiation for polymorphic function.") + type_args = type_args or [] + + if len(signature.params) != len(type_args): + raise NoConcreteFunc("Mismatched number of type arguments.") + self.instantiation = instantiation + self.type_args = type_args + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Call: + return sops.Call( + parent=parent.idx, + func_sig=self.signature.to_serial(), + type_args=ser_it(self.type_args), + instantiation=self.instantiation.to_serial(), + ) + + @property + def num_out(self) -> int | None: + return len(self.signature.body.output) + + def function_port_offset(self) -> int: + return len(self.signature.body.input) + + def port_kind(self, port: InPort | OutPort) -> tys.Kind: + match port: + case InPort(_, offset) if offset == self.function_port_offset(): + return tys.FunctionKind(self.signature) + case _: + return tys.ValueKind(_sig_port_type(self.instantiation, port)) diff --git a/hugr-py/src/hugr/_tys.py b/hugr-py/src/hugr/_tys.py index 48dd43b08..6e2e3584e 100644 --- a/hugr-py/src/hugr/_tys.py +++ b/hugr-py/src/hugr/_tys.py @@ -36,6 +36,9 @@ def to_serial(self) -> stys.BaseType: ... def to_serial_root(self) -> stys.Type: return stys.Type(root=self.to_serial()) # type: ignore[arg-type] + def type_arg(self) -> TypeTypeArg: + return TypeTypeArg(self) + TypeRow = list[Type] diff --git a/hugr-py/src/hugr/serialization/ops.py b/hugr-py/src/hugr/serialization/ops.py index d2fedf027..a81ca95ba 100644 --- a/hugr-py/src/hugr/serialization/ops.py +++ b/hugr-py/src/hugr/serialization/ops.py @@ -58,6 +58,9 @@ class Module(BaseOp): op: Literal["Module"] = "Module" + def deserialize(self) -> _ops.Module: + return _ops.Module() + class FuncDefn(BaseOp): """A function definition. Children nodes are the body of the definition.""" @@ -67,6 +70,12 @@ class FuncDefn(BaseOp): name: str signature: PolyFuncType + def deserialize(self) -> _ops.FuncDefn: + poly_func = self.signature.deserialize() + return _ops.FuncDefn( + self.name, inputs=poly_func.body.input, _outputs=poly_func.body.output + ) + class FuncDecl(BaseOp): """External function declaration, linked at runtime.""" @@ -75,6 +84,9 @@ class FuncDecl(BaseOp): name: str signature: PolyFuncType + def deserialize(self) -> _ops.FuncDecl: + return _ops.FuncDecl(self.name, self.signature.deserialize()) + class CustomConst(ConfiguredBaseModel): c: str @@ -298,6 +310,13 @@ class Call(DataflowOp): } ) + def deserialize(self) -> _ops.Call: + return _ops.Call( + self.func_sig.deserialize(), + self.instantiation.deserialize(), + deser_it(self.type_args), + ) + class CallIndirect(DataflowOp): """Call a function indirectly. diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 14a44d12b..cad65436d 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -3,13 +3,16 @@ import subprocess import os import pathlib -from hugr._hugr import Hugr, Node, Wire, _SubPort +from hugr._node_port import Node, Wire, _SubPort + +from hugr._hugr import Hugr from hugr._dfg import Dfg, _ancestral_sibling -from hugr._ops import Custom, Command +from hugr._ops import Custom, Command, NoConcreteFunc import hugr._ops as ops from hugr.serialization import SerialHugr import hugr._tys as tys import hugr._val as val +from hugr._function import Module import pytest import json @@ -313,3 +316,43 @@ def test_vals(val: val.Value): d.set_outputs(d.load(val)) _validate(d.hugr) + + +def test_poly_function() -> None: + mod = Module() + f_id = mod.declare_function( + "id", + tys.PolyFuncType( + [tys.TypeTypeParam(tys.TypeBound.Any)], + tys.FunctionType.endo([tys.Variable(0, tys.TypeBound.Any)]), + ), + ) + + f_main = mod.define_main([tys.Qubit]) + q = f_main.input_node[0] + with pytest.raises(NoConcreteFunc, match="Missing instantiation"): + f_main.call(f_id, q) + call = f_main.call( + f_id, + q, + # for now concrete instantiations have to be provided. + instantiation=tys.FunctionType.endo([tys.Qubit]), + type_args=[tys.Qubit.type_arg()], + ) + f_main.set_outputs(call) + + _validate(mod.hugr, True) + + +def test_mono_function() -> None: + mod = Module() + f_id = mod.define_function("id", [tys.Qubit]) + f_id.set_outputs(f_id.input_node[0]) + + f_main = mod.define_main([tys.Qubit]) + q = f_main.input_node[0] + # monomorphic functions don't need instantiation specified + call = f_main.call(f_id, q) + f_main.set_outputs(call) + + _validate(mod.hugr, True)