Skip to content

Commit

Permalink
feat(py): Set node output# for dfgs and cfgs (#1400)
Browse files Browse the repository at this point in the history
Updates the parent node with the output port count as soon as a builder
knows its output.

This lets us do things like
```python
dfg = Dfg(tys.Bool)
dfg.set_outputs(*dfg.inputs())
wires: list[Wire] = list(dfg)
```

---------

Co-authored-by: Seyon Sivarajah <[email protected]>
  • Loading branch information
aborgna-q and ss2165 authored Aug 8, 2024
1 parent e88910b commit c5d1a74
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 2 deletions.
14 changes: 13 additions & 1 deletion hugr-py/src/hugr/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from typing_extensions import Self

from hugr import ops, val
from hugr import ops, tys, val

from .dfg import _DfBase
from .exceptions import MismatchedExit, NoSiblingAncestor, NotInSameCfg
Expand All @@ -22,6 +22,15 @@
class Block(_DfBase[ops.DataflowBlock]):
"""Builder class for a basic block in a HUGR control flow graph."""

def set_outputs(self, *outputs: Wire) -> None:
super().set_outputs(*outputs)

assert len(outputs) > 0
branching = outputs[0]
branch_type = self.hugr.port_type(branching.out_port())
assert isinstance(branch_type, tys.Sum)
self._set_parent_output_count(len(branch_type.variant_rows))

def set_block_outputs(self, branching: Wire, *other_outputs: Wire) -> None:
self.set_outputs(branching, *other_outputs)

Expand Down Expand Up @@ -249,3 +258,6 @@ def branch_exit(self, src: Wire) -> None:
else:
self._exit_op._cfg_outputs = out_types
self.parent_op._outputs = out_types
self.parent_node = self.hugr._update_node_outs(
self.parent_node, len(out_types)
)
13 changes: 12 additions & 1 deletion hugr-py/src/hugr/cond_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
from typing_extensions import Self

from hugr import ops
from hugr.tys import Sum

from .dfg import _DfBase
from .hugr import Hugr, ParentBuilder

if TYPE_CHECKING:
from .node_port import Node, ToNode, Wire
from .tys import Sum, TypeRow
from .tys import TypeRow


class Case(_DfBase[ops.Case]):
Expand Down Expand Up @@ -215,6 +216,16 @@ def __init__(self, just_inputs: TypeRow, rest: TypeRow) -> None:
root_op = ops.TailLoop(just_inputs, rest)
super().__init__(root_op)

def set_outputs(self, *outputs: Wire) -> None:
super().set_outputs(*outputs)

assert len(outputs) > 0
sum_wire = outputs[0]
sum_type = self.hugr.port_type(sum_wire.out_port())
assert isinstance(sum_type, Sum)
assert len(sum_type.variant_rows) == 2
self._set_parent_output_count(len(sum_type.variant_rows[1]) + len(outputs) - 1)

def set_loop_outputs(self, sum_wire: Wire, *rest: Wire) -> None:
"""Set the outputs of the loop body. The first wire must be the sum type
that controls loop termination.
Expand Down
16 changes: 16 additions & 0 deletions hugr-py/src/hugr/dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,18 @@ def set_outputs(self, *args: Wire) -> None:
self._wire_up(self.output_node, args)
self.parent_op._set_out_types(self._output_op().types)

def _set_parent_output_count(self, count: int) -> None:
"""Set the final number of output ports on the parent operation.
Args:
count: The number of output ports.
Example:
>>> dfg = Dfg(tys.Bool)
>>> dfg._set_parent_output_count(2)
"""
self.parent_node = self.hugr._update_node_outs(self.parent_node, count)

def add_state_order(self, src: Node, dst: Node) -> None:
"""Add a state order link between two nodes.
Expand Down Expand Up @@ -631,6 +643,10 @@ def __init__(self, *input_types: tys.Type) -> None:
parent_op = ops.DFG(list(input_types), None)
super().__init__(parent_op)

def set_outputs(self, *outputs: Wire) -> None:
super().set_outputs(*outputs)
self._set_parent_output_count(len(outputs))


def _ancestral_sibling(h: Hugr, src: Node, tgt: Node) -> Node | None:
"""Find the ancestor of `tgt` that is a sibling of `src`, if one exists."""
Expand Down
14 changes: 14 additions & 0 deletions hugr-py/src/hugr/hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,20 @@ def _add_node(
self[parent].children.append(node)
return node

def _update_node_outs(self, node: Node, num_outs: int | None) -> Node:
"""Update the number of outgoing ports for a node.
Returns:
The updated node.
"""
self[node]._num_outs = num_outs or 0
node = replace(node, _num_out_ports=num_outs)
parent = self[node].parent
if parent is not None:
pos = self[parent].children.index(node)
self[parent].children[pos] = node
return node

def add_node(
self,
op: Op,
Expand Down

0 comments on commit c5d1a74

Please sign in to comment.