diff --git a/hugr-py/src/hugr/build/dfg.py b/hugr-py/src/hugr/build/dfg.py index 8af4ae1ad..48a710df7 100644 --- a/hugr-py/src/hugr/build/dfg.py +++ b/hugr-py/src/hugr/build/dfg.py @@ -510,7 +510,10 @@ def add_state_order(self, src: Node, dst: Node) -> None: [Node(2)] """ # adds edge to the right of all existing edges - self.hugr.add_link(src.out(-1), dst.inp(-1)) + source = src.out(-1) + target = dst.inp(-1) + if not self.hugr.has_link(source, target): + self.hugr.add_link(source, target) def load( self, const: ToNode | val.Value, const_parent: ToNode | None = None @@ -616,6 +619,12 @@ def _wire_up(self, node: Node, ports: Iterable[Wire]) -> tys.TypeRow: tys = [self._wire_up_port(node, i, p) for i, p in enumerate(ports)] if isinstance(op := self.hugr[node].op, ops._PartialOp): op._set_in_types(tys) + if isinstance(op, ops.DataflowOp): + # Update the node's input and output port count + sig = op.outer_signature() + self.hugr._update_port_count( + node, num_inps=len(sig.input), num_outs=len(sig.output) + ) return tys def _get_dataflow_type(self, wire: Wire) -> tys.Type: diff --git a/hugr-py/src/hugr/hugr/base.py b/hugr-py/src/hugr/hugr/base.py index 979021057..14e40e795 100644 --- a/hugr-py/src/hugr/hugr/base.py +++ b/hugr-py/src/hugr/hugr/base.py @@ -182,12 +182,31 @@ def _update_node_outs(self, node: Node, num_outs: int | None) -> Node: Returns: The updated node. """ - self[node]._num_outs = num_outs or 0 - node = replace(node, _num_out_ports=num_outs) - parent = self[node].parent - if parent is not None: - pos = self[parent].children.index(node) - self[parent].children[pos] = node + return self._update_port_count(node, num_outs=num_outs) + + def _update_port_count( + self, node: Node, *, num_inps: int | None = None, num_outs: int | None + ) -> Node: + """Update the number of incoming and outgoing ports for a node. + + If `num_inps` or `num_outs` is None, the corresponding count is not updated. + + Returns: + The updated node. + """ + if num_inps is None and num_outs is None: + return node + + if num_inps is not None: + self[node]._num_inps = num_inps + if num_outs is not None: + self[node]._num_outs = num_outs + node = replace(node, _num_out_ports=num_outs) + parent = self[node].parent + if parent is not None: + pos = self[parent].children.index(node) + self[parent].children[pos] = node + return node def add_node( @@ -284,6 +303,24 @@ def _unused_sub_offset(self, port: P) -> _SubPort[P]: sub_port = sub_port.next_sub_offset() return sub_port + def has_link(self, src: OutPort, dst: InPort) -> bool: + """Check if there is a link between two ports. + + Args: + src: Source port. + dst: Destination port. + + Returns: + True if there is a link, False otherwise. + + Examples: + >>> df = dfg.Dfg(tys.Bool) + >>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(0)) + >>> df.hugr.has_link(df.input_node.out(0), df.output_node.inp(0)) + True + """ + return dst in self.linked_ports(src) + def add_link(self, src: OutPort, dst: InPort) -> None: """Add a link (edge) between two nodes to the HUGR, from an outgoing port to an incoming port. @@ -604,14 +641,11 @@ def _serialize_link( ) def _constrain_offset(self, p: P) -> PortOffset: - # negative offsets are used to refer to the last port + # An offset of -1 is a special case, indicating an order edge, + # not counted in the number of ports. if p.offset < 0: - match p.direction: - case Direction.INCOMING: - current = self.num_incoming(p.node) - case Direction.OUTGOING: - current = self.num_outgoing(p.node) - offset = current + p.offset + 1 + assert p.offset == -1, "Only order edges are allowed with offset < 0" + offset = self.num_ports(p.node, p.direction) else: offset = p.offset diff --git a/hugr-py/src/hugr/ops.py b/hugr-py/src/hugr/ops.py index a97470adf..8c5b845a8 100644 --- a/hugr-py/src/hugr/ops.py +++ b/hugr-py/src/hugr/ops.py @@ -72,6 +72,11 @@ def name(self) -> str: def _sig_port_type(sig: tys.FunctionType, port: InPort | OutPort) -> tys.Type: + """Get the type of the given dataflow port given the signature of the operation.""" + if port.offset == -1: + # Order port + msg = "Order port has no type." + raise ValueError(msg) if port.direction == Direction.INCOMING: return sig.input[port.offset] return sig.output[port.offset] diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 3f7145a87..2607a4773 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -316,3 +316,17 @@ def test_alias() -> None: _dcl = mod.add_alias_decl("my_bool", tys.TypeBound.Copyable) validate(mod.hugr) + + +# https://github.com/CQCL/hugr/issues/1625 +def test_dfg_unpack() -> None: + dfg = Dfg(tys.Tuple(tys.Bool, tys.Bool)) + bool1, _unused_bool2 = dfg.add_op(ops.UnpackTuple(), *dfg.inputs()) + cond = dfg.add_conditional(bool1) + with cond.add_case(0) as case: + case.set_outputs(bool1) + with cond.add_case(1) as case: + case.set_outputs(bool1) + dfg.set_outputs(*cond.outputs()) + + validate(dfg.hugr)