Skip to content

Commit

Permalink
[IR] Implement node and num_nodes on Graph (#1516)
Browse files Browse the repository at this point in the history
- `node()` to get node by index or name
- `num_nodes()` to obtain the node counts
  • Loading branch information
justinchuby authored May 8, 2024
1 parent af6afd1 commit 7be2c00
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 3 deletions.
49 changes: 47 additions & 2 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,7 +721,10 @@ def __init__(self, value: str | None) -> None:
value: The value of the dimension. It should not be an int.
"""
if isinstance(value, int):
raise TypeError("The value of a SymbolicDim cannot be an int")
raise TypeError(
"The value of a SymbolicDim cannot be an int. "
"If you are creating a Shape, use int directly instead of SymbolicDim."
)
self._value = value

def __eq__(self, other: object) -> bool:
Expand Down Expand Up @@ -1717,6 +1720,48 @@ def _set_node_graph_to_self_and_assign_names(self, node: Node) -> Node:
node.graph = self
return node

def node(self, index_or_name: int | str, /) -> Node:
"""Get a node by index or name.
This is an O(n) operation. Getting nodes on the ends of the graph (0 or -1) is O(1).
.. note::
If you need repeated random access, consider turning it into a list with ``list(graph)`` .
Or a dictionary for repeated access by name: ``{node.name for node in graph}`` .
When a name is provided and if there are multiple nodes with the same name,
the first node with the name is returned.
Args:
index_or_name: The index or name of the node.
Returns:
The node if found.
Raises:
IndexError: If the index is out of range.
ValueError: If the node with the given name is not found.
"""
# NOTE: This is a method specific to Graph, not required by the protocol unless proven
if isinstance(index_or_name, int):
return self[index_or_name]
for node in self:
if node.name == index_or_name:
return node
raise ValueError(f"Node with name '{index_or_name}' not found.")

def num_nodes(self) -> int:
"""Get the number of nodes in the graph in O(1) time.
Note that this method returns the number of nodes this graph directly contains.
It does not count nodes in subgraphs.
This is an alias for ``len(graph)``. Use this if you prefer a more descriptive
name for readability.
"""
# NOTE: This is a method specific to Graph, not required by the protocol unless proven
return len(self)

# Mutation methods
def append(self, node: Node, /) -> None:
"""Append a node to the graph in O(1) time.
Expand All @@ -1743,7 +1788,7 @@ def extend(self, nodes: Iterable[Node], /) -> None:
self._nodes.extend(nodes)

def remove(self, nodes: Node | Iterable[Node], /, safe: bool = False) -> None:
"""Remove nodes from the graph in O(#num of nodes) time.
"""Remove nodes from the graph in O(#num of nodes to remove) time.
If any errors are raise, to ensure the graph is not left in an inconsistent state,
the graph is not modified.
Expand Down
22 changes: 21 additions & 1 deletion onnxscript/ir/_core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,9 @@ class GraphTest(unittest.TestCase):
def setUp(self) -> None:
self.v0 = _core.Input(name="v0")
self.v1 = _core.Input(name="v1")
self.node = _core.Node("", "Add", inputs=(self.v0, self.v1), num_outputs=1)
self.node = _core.Node(
"", "Add", inputs=(self.v0, self.v1), num_outputs=1, name="node_add"
)
self.graph = _core.Graph(
(self.v0, self.v1),
self.node.outputs,
Expand All @@ -664,6 +666,24 @@ def test_initialize(self):
def test_it_is_iterable_of_nodes(self):
self.assertEqual(list(self.graph), [self.node])

def test_node_returns_node_by_name(self):
self.assertIs(self.graph.node("node_add"), self.node)

def test_node_returns_node_by_index(self):
self.assertIs(self.graph.node(0), self.node)

def test_node_raises_when_node_does_not_exist(self):
with self.assertRaisesRegex(ValueError, "not found"):
self.graph.node("non_existent")

def test_node_raises_when_index_out_of_range(self):
with self.assertRaises(IndexError):
self.graph.node(1)

def test_num_nodes_returns_the_count_of_nodes(self):
self.assertEqual(self.graph.num_nodes(), 1)
self.assertEqual(self.graph.num_nodes(), len(self.graph))

def test_metadata(self):
self.graph.meta["test"] = 1
self.assertEqual(self.graph.meta["test"], 1)
Expand Down

0 comments on commit 7be2c00

Please sign in to comment.