From 7be2c00b60831bd69dba3f7a02a4f69c4a17dab3 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 8 May 2024 11:56:58 -0700 Subject: [PATCH] [IR] Implement `node` and `num_nodes` on Graph (#1516) - `node()` to get node by index or name - `num_nodes()` to obtain the node counts --- onnxscript/ir/_core.py | 49 +++++++++++++++++++++++++++++++++++-- onnxscript/ir/_core_test.py | 22 ++++++++++++++++- 2 files changed, 68 insertions(+), 3 deletions(-) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index d788ec51a..6f81598e1 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -721,7 +721,10 @@ def __init__(self, value: str | None) -> None: value: The value of the dimension. It should not be an int. """ if isinstance(value, int): - raise TypeError("The value of a SymbolicDim cannot be an int") + raise TypeError( + "The value of a SymbolicDim cannot be an int. " + "If you are creating a Shape, use int directly instead of SymbolicDim." + ) self._value = value def __eq__(self, other: object) -> bool: @@ -1717,6 +1720,48 @@ def _set_node_graph_to_self_and_assign_names(self, node: Node) -> Node: node.graph = self return node + def node(self, index_or_name: int | str, /) -> Node: + """Get a node by index or name. + + This is an O(n) operation. Getting nodes on the ends of the graph (0 or -1) is O(1). + + .. note:: + If you need repeated random access, consider turning it into a list with ``list(graph)`` . + Or a dictionary for repeated access by name: ``{node.name for node in graph}`` . + + When a name is provided and if there are multiple nodes with the same name, + the first node with the name is returned. + + Args: + index_or_name: The index or name of the node. + + Returns: + The node if found. + + Raises: + IndexError: If the index is out of range. + ValueError: If the node with the given name is not found. + """ + # NOTE: This is a method specific to Graph, not required by the protocol unless proven + if isinstance(index_or_name, int): + return self[index_or_name] + for node in self: + if node.name == index_or_name: + return node + raise ValueError(f"Node with name '{index_or_name}' not found.") + + def num_nodes(self) -> int: + """Get the number of nodes in the graph in O(1) time. + + Note that this method returns the number of nodes this graph directly contains. + It does not count nodes in subgraphs. + + This is an alias for ``len(graph)``. Use this if you prefer a more descriptive + name for readability. + """ + # NOTE: This is a method specific to Graph, not required by the protocol unless proven + return len(self) + # Mutation methods def append(self, node: Node, /) -> None: """Append a node to the graph in O(1) time. @@ -1743,7 +1788,7 @@ def extend(self, nodes: Iterable[Node], /) -> None: self._nodes.extend(nodes) def remove(self, nodes: Node | Iterable[Node], /, safe: bool = False) -> None: - """Remove nodes from the graph in O(#num of nodes) time. + """Remove nodes from the graph in O(#num of nodes to remove) time. If any errors are raise, to ensure the graph is not left in an inconsistent state, the graph is not modified. diff --git a/onnxscript/ir/_core_test.py b/onnxscript/ir/_core_test.py index 99e88d65d..07c3301c0 100644 --- a/onnxscript/ir/_core_test.py +++ b/onnxscript/ir/_core_test.py @@ -646,7 +646,9 @@ class GraphTest(unittest.TestCase): def setUp(self) -> None: self.v0 = _core.Input(name="v0") self.v1 = _core.Input(name="v1") - self.node = _core.Node("", "Add", inputs=(self.v0, self.v1), num_outputs=1) + self.node = _core.Node( + "", "Add", inputs=(self.v0, self.v1), num_outputs=1, name="node_add" + ) self.graph = _core.Graph( (self.v0, self.v1), self.node.outputs, @@ -664,6 +666,24 @@ def test_initialize(self): def test_it_is_iterable_of_nodes(self): self.assertEqual(list(self.graph), [self.node]) + def test_node_returns_node_by_name(self): + self.assertIs(self.graph.node("node_add"), self.node) + + def test_node_returns_node_by_index(self): + self.assertIs(self.graph.node(0), self.node) + + def test_node_raises_when_node_does_not_exist(self): + with self.assertRaisesRegex(ValueError, "not found"): + self.graph.node("non_existent") + + def test_node_raises_when_index_out_of_range(self): + with self.assertRaises(IndexError): + self.graph.node(1) + + def test_num_nodes_returns_the_count_of_nodes(self): + self.assertEqual(self.graph.num_nodes(), 1) + self.assertEqual(self.graph.num_nodes(), len(self.graph)) + def test_metadata(self): self.graph.meta["test"] = 1 self.assertEqual(self.graph.meta["test"], 1)