Skip to content

Commit

Permalink
Address PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
gramalingam committed Jul 24, 2024
1 parent ee10468 commit 1b8a71e
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 6 deletions.
2 changes: 1 addition & 1 deletion onnxscript/rewriter/generic_pattern_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def apply_pattern(op, x, pos_ids, axis, **_):
self.assertEqual(expected, [n.op_type for n in rewriten_model.graph.node])
out = buffer.getvalue()
# TODO(Rama): What is this assertion testing? Is it to check that `verbose` is working?
if self.matcher_algo == generic_pattern.GenericPatternMatcher:
if self.matcher_algo is generic_pattern.GenericPatternMatcher:
self.assertIn("[GenericPatternMatcher.match", out)

def test_rotary_embedding_onnxscript(self):
Expand Down
21 changes: 16 additions & 5 deletions onnxscript/rewriter/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import (
Any,
Callable,
Iterable,
Iterator,
List,
MutableSequence,
Expand Down Expand Up @@ -906,7 +907,7 @@ def match(
node: ir.Node,
verbose: int = 0,
) -> MatchResult:
pass
"""Match the pattern against the subgraph ending at the given node."""

def __str__(self) -> str:
return str(self.pattern)
Expand Down Expand Up @@ -1034,11 +1035,13 @@ def _match_node_output(self, pattern_value: NodeOutputPattern, value: ir.Value)
return self._match_node(pattern_value.producer(), node)

def _init_match(self, verbose: int) -> None:
"""Initialize the match state. Invoked before starting a new match."""
self._verbose = verbose
self._matched: dict[NodePattern, ir.Node] = {}
self._match: MatchResult = MatchResult()

def _get_output_values(self) -> list[ir.Value] | None:
"""Get values bound to the output variables of the pattern."""
output_values: list[ir.Value] = []
unbound_values: list[str] = []
for j, value_pattern in enumerate(self.pattern.outputs):
Expand Down Expand Up @@ -1095,9 +1098,9 @@ def _match_single_output_node(
match.outputs.extend(output_values)
return match

def _multi_match(self, candidate: dict[NodePattern, ir.Node]) -> MatchResult:
def _multi_match(self, candidate: Iterable[ir.Node]) -> MatchResult:
match = self._match
for pattern_node, node in candidate.items():
for pattern_node, node in zip(self.pattern.output_nodes, candidate):
if not self._match_node(pattern_node, node):
return match
output_values = self._get_output_values()
Expand All @@ -1121,6 +1124,15 @@ def match(
self._init_match(verbose)
return self._match_single_output_node(model, graph_or_function, node)
else:
# Note: This is a potentially expensive algorithm for matching patterns with
# multiple output nodes. For patterns with N output nodes, we try all possible
# combinations of N nodes from the graph, and check if they match the pattern.
# The first node is fixed to the node argument in this method call. We do
# some simple filtering by restricting the candidates for each remaining
# output nodes to graph nodes with the same op_type as the corresponding pattern
# node. For now, this is intended to be a simple, but robust, implementation
# that can be used for debugging and testing. The GenericPatternMatcher is a
# more sophisticated implementation, but incomplete.
pattern_output_nodes = self.pattern.output_nodes
op_to_nodes: dict[tuple[str, str, str], list[ir.Node]] = {}
for n in graph_or_function:
Expand All @@ -1136,9 +1148,8 @@ def get_nodes(pattern_node):
candidates = [iter([node])] + [get_nodes(pn) for pn in pattern_output_nodes[1:]]
match = None
for combination in itertools.product(*candidates):
candidate = dict(zip(pattern_output_nodes, combination))
self._init_match(verbose)
match = self._multi_match(candidate)
match = self._multi_match(combination)
if match:
return match
if match is None:
Expand Down

0 comments on commit 1b8a71e

Please sign in to comment.