-
Notifications
You must be signed in to change notification settings - Fork 57
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create a recursive graph iterator and use it to refactor UnusedFuncti…
…onRemover (#1565) - Create `traversal.py` for graph traversal utilities and implemented `RecursiveGraphIterator`. Expose `traversal` to the `ir` module. Fixes #1556 - Remove `NodeTransformer` because `RecursiveGraphIterator` is more flexible. - Refactor remove_unused_function.py to use `RecursiveGraphIterator`
- Loading branch information
1 parent
a6843da
commit c41ded5
Showing
6 changed files
with
212 additions
and
115 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# ------------------------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
# -------------------------------------------------------------------------- | ||
"""Utilities for traversing the IR graph.""" | ||
|
||
from __future__ import annotations | ||
|
||
__all__ = [ | ||
"RecursiveGraphIterator", | ||
] | ||
|
||
from typing import Callable, Iterator, Reversible | ||
|
||
from typing_extensions import Self | ||
|
||
from onnxscript.ir import _core, _enums | ||
|
||
|
||
class RecursiveGraphIterator(Iterator[_core.Node], Reversible[_core.Node]): | ||
def __init__( | ||
self, | ||
graph: _core.Graph | _core.Function | _core.GraphView, | ||
*, | ||
recursive: Callable[[_core.Node], bool] | None = None, | ||
reverse: bool = False, | ||
): | ||
"""Iterate over the nodes in the graph, recursively visiting subgraphs. | ||
Args: | ||
graph: The graph to traverse. | ||
recursive: A callback that determines whether to recursively visit the subgraphs | ||
contained in a node. If not provided, all nodes in subgraphs are visited. | ||
reverse: Whether to iterate in reverse order. | ||
""" | ||
self._graph = graph | ||
self._recursive = recursive | ||
self._reverse = reverse | ||
self._iterator = self._recursive_node_iter(graph) | ||
|
||
def __iter__(self) -> Self: | ||
self._iterator = self._recursive_node_iter(self._graph) | ||
return self | ||
|
||
def __next__(self) -> _core.Node: | ||
return next(self._iterator) | ||
|
||
def _recursive_node_iter( | ||
self, graph: _core.Graph | _core.Function | _core.GraphView | ||
) -> Iterator[_core.Node]: | ||
iterable = reversed(graph) if self._reverse else graph | ||
for node in iterable: # type: ignore[union-attr] | ||
yield node | ||
if self._recursive is not None and not self._recursive(node): | ||
continue | ||
yield from self._iterate_subgraphs(node) | ||
|
||
def _iterate_subgraphs(self, node: _core.Node): | ||
for attr in node.attributes.values(): | ||
if not isinstance(attr, _core.Attr): | ||
continue | ||
if attr.type == _enums.AttributeType.GRAPH: | ||
yield from RecursiveGraphIterator( | ||
attr.value, | ||
recursive=self._recursive, | ||
reverse=self._reverse, | ||
) | ||
elif attr.type == _enums.AttributeType.GRAPHS: | ||
graphs = reversed(attr.value) if self._reverse else attr.value | ||
for graph in graphs: | ||
yield from RecursiveGraphIterator( | ||
graph, | ||
recursive=self._recursive, | ||
reverse=self._reverse, | ||
) | ||
|
||
def __reversed__(self) -> Iterator[_core.Node]: | ||
return RecursiveGraphIterator( | ||
self._graph, | ||
recursive=self._recursive, | ||
reverse=not self._reverse, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
# ------------------------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
# -------------------------------------------------------------------------- | ||
from __future__ import annotations | ||
|
||
import unittest | ||
|
||
import parameterized | ||
|
||
from onnxscript import ir | ||
from onnxscript.ir import traversal | ||
|
||
|
||
class RecursiveGraphIteratorTest(unittest.TestCase): | ||
def setUp(self): | ||
self.graph = ir.Graph( | ||
[], | ||
[], | ||
nodes=[ | ||
ir.Node("", "Node1", []), | ||
ir.Node("", "Node2", []), | ||
ir.Node( | ||
"", | ||
"If", | ||
[], | ||
attributes=[ | ||
ir.AttrGraph( | ||
"then_branch", | ||
ir.Graph( | ||
[], | ||
[], | ||
nodes=[ir.Node("", "Node3", []), ir.Node("", "Node4", [])], | ||
name="then_graph", | ||
), | ||
), | ||
ir.AttrGraph( | ||
"else_branch", | ||
ir.Graph( | ||
[], | ||
[], | ||
nodes=[ir.Node("", "Node5", []), ir.Node("", "Node6", [])], | ||
name="else_graph", | ||
), | ||
), | ||
], | ||
), | ||
], | ||
name="main_graph", | ||
) | ||
|
||
@parameterized.parameterized.expand( | ||
[ | ||
("forward", False, ("Node1", "Node2", "If", "Node3", "Node4", "Node5", "Node6")), | ||
("reversed", True, ("If", "Node4", "Node3", "Node6", "Node5", "Node2", "Node1")), | ||
] | ||
) | ||
def test_recursive_graph_iterator(self, _: str, reverse: bool, expected: tuple[str, ...]): | ||
iterator = traversal.RecursiveGraphIterator(self.graph) | ||
if reverse: | ||
iterator = reversed(iterator) | ||
nodes = list(iterator) | ||
self.assertEqual(tuple(node.op_type for node in nodes), expected) | ||
|
||
@parameterized.parameterized.expand( | ||
[ | ||
("forward", False, ("Node1", "Node2", "If")), | ||
("reversed", True, ("If", "Node2", "Node1")), | ||
] | ||
) | ||
def test_recursive_graph_iterator_recursive_controls_recursive_behavior( | ||
self, _: str, reverse: bool, expected: list[str] | ||
): | ||
nodes = list( | ||
traversal.RecursiveGraphIterator( | ||
self.graph, recursive=lambda node: node.op_type != "If", reverse=reverse | ||
) | ||
) | ||
self.assertEqual(tuple(node.op_type for node in nodes), expected) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters