Skip to content

Commit

Permalink
fix(py): Invalid node indexing (#1457)
Browse files Browse the repository at this point in the history
Fixes #1454, and other failing cases
  • Loading branch information
aborgna-q authored Aug 21, 2024
1 parent 26f05b5 commit d6edcd7
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 4 deletions.
44 changes: 40 additions & 4 deletions hugr-py/src/hugr/node_port.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,24 +164,60 @@ def _index(
) -> OutPort | Iterator[OutPort]:
match index:
case PortOffset(index):
if self._num_out_ports is not None and index >= self._num_out_ports:
msg = "Index out of range"
raise IndexError(msg)
index = self._normalize_index(index)
return self.out(index)
case slice():
start = index.start or 0
stop = index.stop or self._num_out_ports
stop = index.stop if index.stop is not None else self._num_out_ports
if stop is None:
msg = (
f"{self} does not have a fixed number of output ports. "
"Iterating over all output ports is not supported."
)
raise ValueError(msg)

start = self._normalize_index(start)
stop = self._normalize_index(stop, allow_eq_len=True)
step = index.step or 1

return (self[i] for i in range(start, stop, step))
case tuple(xs):
return (self[i] for i in xs)

def _normalize_index(self, index: int, allow_eq_len: bool = False) -> int:
"""Given an index passed to `__getitem__`, normalize it to be within the
range of output ports.
Args:
index: index to normalize.
allow_eq_len: whether to allow the index to be equal to the number of
output ports.
Returns:
Normalized index.
Raises:
IndexError: if the index is out of range.
"""
msg = f"Index {index} out of range"

if self._num_out_ports is not None:
if index > self._num_out_ports:
raise IndexError(msg)
if index == self._num_out_ports and not allow_eq_len:
raise IndexError(msg)
if index < -self._num_out_ports:
raise IndexError(msg)
else:
if index < 0:
raise IndexError(msg)

if index >= 0:
return index
else:
assert self._num_out_ports is not None
return self._num_out_ports + index

def to_node(self) -> Node:
return self

Expand Down
37 changes: 37 additions & 0 deletions hugr-py/tests/test_nodes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import pytest

from hugr.node_port import Node, OutPort


def test_index():
n = Node(0, _num_out_ports=3)
assert n[0] == OutPort(n, 0)
assert n[1] == OutPort(n, 1)
assert n[2] == OutPort(n, 2)
assert n[-1] == OutPort(n, 2)

with pytest.raises(IndexError, match="Index 3 out of range"):
_ = n[3]

with pytest.raises(IndexError, match="Index -8 out of range"):
_ = n[-8]


def test_slices():
n = Node(0, _num_out_ports=3)
all_ports = [OutPort(n, i) for i in range(3)]

assert list(n) == all_ports
assert list(n[:0]) == []
assert list(n[0:0]) == []
assert list(n[0:1]) == [OutPort(n, 0)]
assert list(n[1:2]) == [OutPort(n, 1)]
assert list(n[:]) == all_ports
assert list(n[0:]) == all_ports
assert list(n[:3]) == all_ports
assert list(n[0:3]) == all_ports
assert list(n[-1:]) == [OutPort(n, 2)]
assert list(n[-3:]) == all_ports

with pytest.raises(IndexError, match="Index -4 out of range"):
_ = n[-4:]

0 comments on commit d6edcd7

Please sign in to comment.