diff --git a/hugr-py/src/hugr/cfg.py b/hugr-py/src/hugr/cfg.py index 1d9b364a0..7c46c9abd 100644 --- a/hugr-py/src/hugr/cfg.py +++ b/hugr-py/src/hugr/cfg.py @@ -8,7 +8,7 @@ from typing_extensions import Self -from hugr import ops, val +from hugr import ops, tys, val from .dfg import _DfBase from .exceptions import MismatchedExit, NoSiblingAncestor, NotInSameCfg @@ -22,6 +22,15 @@ class Block(_DfBase[ops.DataflowBlock]): """Builder class for a basic block in a HUGR control flow graph.""" + def set_outputs(self, *outputs: Wire) -> None: + super().set_outputs(*outputs) + + assert len(outputs) > 0 + branching = outputs[0] + branch_type = self.hugr.port_type(branching.out_port()) + assert isinstance(branch_type, tys.Sum) + self._set_parent_output_count(len(branch_type.variant_rows)) + def set_block_outputs(self, branching: Wire, *other_outputs: Wire) -> None: self.set_outputs(branching, *other_outputs) @@ -249,3 +258,6 @@ def branch_exit(self, src: Wire) -> None: else: self._exit_op._cfg_outputs = out_types self.parent_op._outputs = out_types + self.parent_node = self.hugr._update_node_outs( + self.parent_node, len(out_types) + ) diff --git a/hugr-py/src/hugr/cond_loop.py b/hugr-py/src/hugr/cond_loop.py index 171de3a0e..a62f89c0a 100644 --- a/hugr-py/src/hugr/cond_loop.py +++ b/hugr-py/src/hugr/cond_loop.py @@ -11,13 +11,14 @@ from typing_extensions import Self from hugr import ops +from hugr.tys import Sum from .dfg import _DfBase from .hugr import Hugr, ParentBuilder if TYPE_CHECKING: from .node_port import Node, ToNode, Wire - from .tys import Sum, TypeRow + from .tys import TypeRow class Case(_DfBase[ops.Case]): @@ -215,6 +216,16 @@ def __init__(self, just_inputs: TypeRow, rest: TypeRow) -> None: root_op = ops.TailLoop(just_inputs, rest) super().__init__(root_op) + def set_outputs(self, *outputs: Wire) -> None: + super().set_outputs(*outputs) + + assert len(outputs) > 0 + sum_wire = outputs[0] + sum_type = self.hugr.port_type(sum_wire.out_port()) + assert isinstance(sum_type, Sum) + assert len(sum_type.variant_rows) == 2 + self._set_parent_output_count(len(sum_type.variant_rows[1]) + len(outputs) - 1) + def set_loop_outputs(self, sum_wire: Wire, *rest: Wire) -> None: """Set the outputs of the loop body. The first wire must be the sum type that controls loop termination. diff --git a/hugr-py/src/hugr/dfg.py b/hugr-py/src/hugr/dfg.py index a06df03c0..f1ea17cd0 100644 --- a/hugr-py/src/hugr/dfg.py +++ b/hugr-py/src/hugr/dfg.py @@ -472,6 +472,18 @@ def set_outputs(self, *args: Wire) -> None: self._wire_up(self.output_node, args) self.parent_op._set_out_types(self._output_op().types) + def _set_parent_output_count(self, count: int) -> None: + """Set the final number of output ports on the parent operation. + + Args: + count: The number of output ports. + + Example: + >>> dfg = Dfg(tys.Bool) + >>> dfg._set_parent_output_count(2) + """ + self.parent_node = self.hugr._update_node_outs(self.parent_node, count) + def add_state_order(self, src: Node, dst: Node) -> None: """Add a state order link between two nodes. @@ -631,6 +643,10 @@ def __init__(self, *input_types: tys.Type) -> None: parent_op = ops.DFG(list(input_types), None) super().__init__(parent_op) + def set_outputs(self, *outputs: Wire) -> None: + super().set_outputs(*outputs) + self._set_parent_output_count(len(outputs)) + def _ancestral_sibling(h: Hugr, src: Node, tgt: Node) -> Node | None: """Find the ancestor of `tgt` that is a sibling of `src`, if one exists.""" diff --git a/hugr-py/src/hugr/hugr.py b/hugr-py/src/hugr/hugr.py index f1eda8e8a..99392e15e 100644 --- a/hugr-py/src/hugr/hugr.py +++ b/hugr-py/src/hugr/hugr.py @@ -173,6 +173,20 @@ def _add_node( self[parent].children.append(node) return node + def _update_node_outs(self, node: Node, num_outs: int | None) -> Node: + """Update the number of outgoing ports for a node. + + Returns: + The updated node. + """ + self[node]._num_outs = num_outs or 0 + node = replace(node, _num_out_ports=num_outs) + parent = self[node].parent + if parent is not None: + pos = self[parent].children.index(node) + self[parent].children[pos] = node + return node + def add_node( self, op: Op,