Skip to content

Commit

Permalink
fix: Update number of ports for PartialOps, and sanitize orderd edges
Browse files Browse the repository at this point in the history
  • Loading branch information
aborgna-q committed Nov 5, 2024
1 parent c70d68a commit ed74dad
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 15 deletions.
6 changes: 6 additions & 0 deletions hugr-py/src/hugr/build/dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,12 @@ def _wire_up(self, node: Node, ports: Iterable[Wire]) -> tys.TypeRow:
tys = [self._wire_up_port(node, i, p) for i, p in enumerate(ports)]
if isinstance(op := self.hugr[node].op, ops._PartialOp):
op._set_in_types(tys)
if isinstance(op, ops.DataflowOp):
# Update the node's input and output port count
sig = op.outer_signature()
self.hugr._update_port_count(
node, num_inps=len(sig.input), num_outs=len(sig.output)
)
return tys

def _get_dataflow_type(self, wire: Wire) -> tys.Type:
Expand Down
44 changes: 30 additions & 14 deletions hugr-py/src/hugr/hugr/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,31 @@ def _update_node_outs(self, node: Node, num_outs: int | None) -> Node:
Returns:
The updated node.
"""
self[node]._num_outs = num_outs or 0
node = replace(node, _num_out_ports=num_outs)
parent = self[node].parent
if parent is not None:
pos = self[parent].children.index(node)
self[parent].children[pos] = node
return self._update_port_count(node, num_outs=num_outs)

def _update_port_count(
self, node: Node, *, num_inps: int | None = None, num_outs: int | None
) -> Node:
"""Update the number of incoming and outgoing ports for a node.
If `num_inps` or `num_outs` is None, the corresponding count is not updated.
Returns:
The updated node.
"""
if num_inps is None and num_outs is None:
return node

if num_inps is not None:
self[node]._num_inps = num_inps
if num_outs is not None:
self[node]._num_outs = num_outs
node = replace(node, _num_out_ports=num_outs)
parent = self[node].parent
if parent is not None:
pos = self[parent].children.index(node)
self[parent].children[pos] = node

return node

def add_node(
Expand Down Expand Up @@ -297,7 +316,7 @@ def has_link(self, src: OutPort, dst: InPort) -> bool:
Examples:
>>> df = dfg.Dfg(tys.Bool)
>>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(0))
>>> df.hugr.is_linked(df.input_node.out(0), df.output_node.inp(0))
>>> df.hugr.has_link(df.input_node.out(0), df.output_node.inp(0))
True
"""
return dst in self.linked_ports(src)
Expand Down Expand Up @@ -622,14 +641,11 @@ def _serialize_link(
)

def _constrain_offset(self, p: P) -> PortOffset:
# negative offsets are used to refer to the last port
# An offset of -1 is a special case, indicating an order edge,
# not counted in the number of ports.
if p.offset < 0:
match p.direction:
case Direction.INCOMING:
current = self.num_incoming(p.node)
case Direction.OUTGOING:
current = self.num_outgoing(p.node)
offset = current + p.offset + 1
assert p.offset == -1, "Only order edges are allowed with offset < 0"
offset = self.num_ports(p.node, p.direction)
else:
offset = p.offset

Expand Down
2 changes: 1 addition & 1 deletion hugr-py/src/hugr/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def name(self) -> str:
return str(self)


def _sig_port_type(sig: tys.FunctionType, port: InPort | OutPort) -> tys.Type | None:
def _sig_port_type(sig: tys.FunctionType, port: InPort | OutPort) -> tys.Type:
"""Get the type of the given dataflow port given the signature of the operation."""
if port.offset == -1:
# Order port
Expand Down

0 comments on commit ed74dad

Please sign in to comment.