Skip to content

Commit

Permalink
feat(hugr-py): builder for function definition/declaration and call (#…
Browse files Browse the repository at this point in the history
…1212)

Closes #1211 
review commits separately, first is just a trivial refactor moving code
in to its own file

---------

Co-authored-by: Agustín Borgna <[email protected]>
  • Loading branch information
ss2165 and aborgna-q authored Jun 21, 2024
1 parent 43569a4 commit af062ea
Show file tree
Hide file tree
Showing 10 changed files with 436 additions and 128 deletions.
3 changes: 2 additions & 1 deletion hugr-py/src/hugr/_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

from ._dfg import _DfBase
from ._exceptions import NoSiblingAncestor, NotInSameCfg, MismatchedExit
from ._hugr import Hugr, Node, ParentBuilder, ToNode, Wire
from ._hugr import Hugr, ParentBuilder
from ._node_port import Node, Wire, ToNode
from ._tys import TypeRow, Type
import hugr._val as val

Expand Down
4 changes: 3 additions & 1 deletion hugr-py/src/hugr/_cond_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import hugr._ops as ops

from ._dfg import _DfBase
from ._hugr import Hugr, Node, ParentBuilder, ToNode, Wire
from ._hugr import Hugr, ParentBuilder
from ._node_port import Node, Wire, ToNode

from ._tys import Sum, TypeRow


Expand Down
27 changes: 25 additions & 2 deletions hugr-py/src/hugr/_dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@

import hugr._ops as ops
import hugr._val as val
from hugr._tys import Type, TypeRow, get_first_sum
from hugr._tys import Type, TypeRow, get_first_sum, FunctionType, TypeArg, FunctionKind

from ._exceptions import NoSiblingAncestor
from ._hugr import Hugr, Node, OutPort, ParentBuilder, ToNode, Wire
from ._hugr import Hugr, ParentBuilder
from ._node_port import Node, OutPort, Wire, ToNode

if TYPE_CHECKING:
from ._cfg import Cfg
Expand Down Expand Up @@ -164,6 +165,28 @@ def load(self, const: ToNode | val.Value) -> Node:

return load

def call(
self,
func: ToNode,
*args: Wire,
instantiation: FunctionType | None = None,
type_args: list[TypeArg] | None = None,
) -> Node:
f_op = self.hugr[func]
f_kind = f_op.op.port_kind(func.out(0))
match f_kind:
case FunctionKind(sig):
signature = sig
case _:
raise ValueError("Expected 'func' to be a function")
call_op = ops.Call(signature, instantiation, type_args)
call_n = self.hugr.add_node(call_op, self.parent_node, call_op.num_out)
self.hugr.add_link(func.out(0), call_n.inp(call_op.function_port_offset()))

self._wire_up(call_n, args)

return call_n

def _wire_up(self, node: Node, ports: Iterable[Wire]) -> TypeRow:
tys = [self._wire_up_port(node, i, p) for i, p in enumerate(ports)]
if isinstance(op := self.hugr[node].op, ops.PartialOp):
Expand Down
49 changes: 49 additions & 0 deletions hugr-py/src/hugr/_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from __future__ import annotations

from dataclasses import dataclass

import hugr._ops as ops
import hugr._val as val

from ._dfg import _DfBase
from hugr._node_port import Node
from ._hugr import Hugr
from ._tys import TypeRow, TypeParam, PolyFuncType


@dataclass
class Function(_DfBase[ops.FuncDefn]):
def __init__(
self,
name: str,
input_types: TypeRow,
type_params: list[TypeParam] | None = None,
) -> None:
root_op = ops.FuncDefn(name, input_types, type_params or [])
super().__init__(root_op)


@dataclass
class Module:
hugr: Hugr

def __init__(self) -> None:
self.hugr = Hugr(ops.Module())

def define_function(
self,
name: str,
input_types: TypeRow,
type_params: list[TypeParam] | None = None,
) -> Function:
parent_op = ops.FuncDefn(name, input_types, type_params or [])
return Function.new_nested(parent_op, self.hugr)

def define_main(self, input_types: TypeRow) -> Function:
return self.define_function("main", input_types)

def declare_function(self, name: str, signature: PolyFuncType) -> Node:
return self.hugr.add_node(ops.FuncDecl(name, signature), self.hugr.root)

def add_const(self, value: val.Value) -> Node:
return self.hugr.add_node(ops.Const(value), self.hugr.root)
119 changes: 10 additions & 109 deletions hugr-py/src/hugr/_hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,121 +2,28 @@

from collections.abc import Mapping
from dataclasses import dataclass, field, replace
from enum import Enum
from typing import (
ClassVar,
Generic,
Iterable,
Iterator,
Protocol,
TypeVar,
cast,
overload,
Type as PyType,
)

from typing_extensions import Self

from hugr._ops import Op, DataflowOp, Const
from hugr._tys import Type, Kind
from hugr._ops import Op, DataflowOp, Const, Call
from hugr._tys import Type, Kind, ValueKind
from hugr._val import Value
from hugr._node_port import Direction, InPort, OutPort, ToNode, Node, _SubPort
from hugr.serialization.ops import OpType as SerialOp
from hugr.serialization.serial_hugr import SerialHugr
from hugr.utils import BiMap

from ._exceptions import ParentBeforeChild


class Direction(Enum):
INCOMING = 0
OUTGOING = 1


@dataclass(frozen=True, eq=True, order=True)
class _Port:
node: Node
offset: int
direction: ClassVar[Direction]


@dataclass(frozen=True, eq=True, order=True)
class InPort(_Port):
direction: ClassVar[Direction] = Direction.INCOMING


class Wire(Protocol):
def out_port(self) -> OutPort: ...


@dataclass(frozen=True, eq=True, order=True)
class OutPort(_Port, Wire):
direction: ClassVar[Direction] = Direction.OUTGOING

def out_port(self) -> OutPort:
return self


class ToNode(Wire, Protocol):
def to_node(self) -> Node: ...

@overload
def __getitem__(self, index: int) -> OutPort: ...
@overload
def __getitem__(self, index: slice) -> Iterator[OutPort]: ...
@overload
def __getitem__(self, index: tuple[int, ...]) -> Iterator[OutPort]: ...

def __getitem__(
self, index: int | slice | tuple[int, ...]
) -> OutPort | Iterator[OutPort]:
return self.to_node()._index(index)

def out_port(self) -> "OutPort":
return OutPort(self.to_node(), 0)

def inp(self, offset: int) -> InPort:
return InPort(self.to_node(), offset)

def out(self, offset: int) -> OutPort:
return OutPort(self.to_node(), offset)

def port(self, offset: int, direction: Direction) -> InPort | OutPort:
if direction == Direction.INCOMING:
return self.inp(offset)
else:
return self.out(offset)


@dataclass(frozen=True, eq=True, order=True)
class Node(ToNode):
idx: int
_num_out_ports: int | None = field(default=None, compare=False)

def _index(
self, index: int | slice | tuple[int, ...]
) -> OutPort | Iterator[OutPort]:
match index:
case int(index):
if self._num_out_ports is not None:
if index >= self._num_out_ports:
raise IndexError("Index out of range")
return self.out(index)
case slice():
start = index.start or 0
stop = index.stop or self._num_out_ports
if stop is None:
raise ValueError(
"Stop must be specified when number of outputs unknown"
)
step = index.step or 1
return (self[i] for i in range(start, stop, step))
case tuple(xs):
return (self[i] for i in xs)

def to_node(self) -> Node:
return self


@dataclass()
class NodeData:
op: Op
Expand All @@ -131,25 +38,15 @@ def to_serial(self, node: Node, hugr: Hugr) -> SerialOp:
return SerialOp(root=o) # type: ignore[arg-type]


_SO = _SubPort[OutPort]
_SI = _SubPort[InPort]

P = TypeVar("P", InPort, OutPort)
K = TypeVar("K", InPort, OutPort)
OpVar = TypeVar("OpVar", bound=Op)
OpVar2 = TypeVar("OpVar2", bound=Op)


@dataclass(frozen=True, eq=True, order=True)
class _SubPort(Generic[P]):
port: P
sub_offset: int = 0

def next_sub_offset(self) -> Self:
return replace(self, sub_offset=self.sub_offset + 1)


_SO = _SubPort[OutPort]
_SI = _SubPort[InPort]


class ParentBuilder(ToNode, Protocol[OpVar]):
hugr: Hugr[OpVar]
parent_node: Node
Expand Down Expand Up @@ -360,6 +257,10 @@ def port_type(self, port: InPort | OutPort) -> Type | None:
op = self[port.node].op
if isinstance(op, DataflowOp):
return op.port_type(port)
if isinstance(op, Call) and isinstance(port, OutPort):
kind = self.port_kind(port)
if isinstance(kind, ValueKind):
return kind.ty
return None

def insert_hugr(self, hugr: Hugr, parent: ToNode | None = None) -> dict[Node, Node]:
Expand Down
115 changes: 115 additions & 0 deletions hugr-py/src/hugr/_node_port.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from __future__ import annotations

from dataclasses import dataclass, field, replace
from enum import Enum
from typing import (
ClassVar,
Iterator,
Protocol,
overload,
TypeVar,
Generic,
)
from typing_extensions import Self


class Direction(Enum):
INCOMING = 0
OUTGOING = 1


@dataclass(frozen=True, eq=True, order=True)
class _Port:
node: Node
offset: int
direction: ClassVar[Direction]


@dataclass(frozen=True, eq=True, order=True)
class InPort(_Port):
direction: ClassVar[Direction] = Direction.INCOMING


class Wire(Protocol):
def out_port(self) -> OutPort: ...


@dataclass(frozen=True, eq=True, order=True)
class OutPort(_Port, Wire):
direction: ClassVar[Direction] = Direction.OUTGOING

def out_port(self) -> OutPort:
return self


class ToNode(Wire, Protocol):
def to_node(self) -> Node: ...

@overload
def __getitem__(self, index: int) -> OutPort: ...
@overload
def __getitem__(self, index: slice) -> Iterator[OutPort]: ...
@overload
def __getitem__(self, index: tuple[int, ...]) -> Iterator[OutPort]: ...

def __getitem__(
self, index: int | slice | tuple[int, ...]
) -> OutPort | Iterator[OutPort]:
return self.to_node()._index(index)

def out_port(self) -> "OutPort":
return OutPort(self.to_node(), 0)

def inp(self, offset: int) -> InPort:
return InPort(self.to_node(), offset)

def out(self, offset: int) -> OutPort:
return OutPort(self.to_node(), offset)

def port(self, offset: int, direction: Direction) -> InPort | OutPort:
if direction == Direction.INCOMING:
return self.inp(offset)
else:
return self.out(offset)


@dataclass(frozen=True, eq=True, order=True)
class Node(ToNode):
idx: int
_num_out_ports: int | None = field(default=None, compare=False)

def _index(
self, index: int | slice | tuple[int, ...]
) -> OutPort | Iterator[OutPort]:
match index:
case int(index):
if self._num_out_ports is not None:
if index >= self._num_out_ports:
raise IndexError("Index out of range")
return self.out(index)
case slice():
start = index.start or 0
stop = index.stop or self._num_out_ports
if stop is None:
raise ValueError(
"Stop must be specified when number of outputs unknown"
)
step = index.step or 1
return (self[i] for i in range(start, stop, step))
case tuple(xs):
return (self[i] for i in xs)

def to_node(self) -> Node:
return self


P = TypeVar("P", InPort, OutPort)


@dataclass(frozen=True, eq=True, order=True)
class _SubPort(Generic[P]):
port: P
sub_offset: int = 0

def next_sub_offset(self) -> Self:
return replace(self, sub_offset=self.sub_offset + 1)
Loading

0 comments on commit af062ea

Please sign in to comment.