From 19f1126af9697e7917f10e1dec4fe86dd209a34d Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Thu, 25 Jul 2024 16:17:45 -0700 Subject: [PATCH] Extend basic matcher to handle multiple-output-nodes (#1734) This PR extends the basic matcher to handle multiple output nodes. This provides an alternative to the generic-matcher algorithm, which is incomplete and fails in some circumstances. This can also be useful in debugging match-failures (when it is unclear if the failure is valid or due to limitations of the matching algorithm). The drawback is that this algorithm can, in some cases, be expensive, especially when the number of output-nodes is large and the graph size is large. (So far, however, we haven't encountered patterns with more than 2 output-nodes.) --- onnxscript/rewriter/generic_pattern.py | 10 +- onnxscript/rewriter/generic_pattern_test.py | 40 +++-- onnxscript/rewriter/pattern.py | 190 ++++++++++++++++---- 3 files changed, 184 insertions(+), 56 deletions(-) diff --git a/onnxscript/rewriter/generic_pattern.py b/onnxscript/rewriter/generic_pattern.py index d0daf2e06..2926f5964 100644 --- a/onnxscript/rewriter/generic_pattern.py +++ b/onnxscript/rewriter/generic_pattern.py @@ -298,7 +298,7 @@ def _match_backward( return self.none(starting_node, inspect.currentframe().f_lineno) for graph_input, pattern_input in zip(graph_node.inputs, pattern_node.inputs): - if len(list(graph_input.uses())) != len(list(pattern_input.uses())): + if len(graph_input.uses()) != len(pattern_input.uses()): self._hint( "BACKWARD: one input is used outside the pattern", "-- pattern", @@ -423,12 +423,12 @@ def _match_values_forward( return match_count if len(free) < len(pattern_node_users_not_matched): # Not enough successors to match the remaining patterns. - return self.none(node, inspect.currentframe().f_lineno) + return self.none(starting_node, inspect.currentframe().f_lineno) if len(pattern_node_users_not_matched) == len(free) == 1: # Only one option again. graph_node = free[0] if pattern_node_users_not_matched[0].op_identifier() != graph_node.op_identifier(): - return self.none(node, inspect.currentframe().f_lineno) + return self.none(starting_node, inspect.currentframe().f_lineno) key = pattern_node_users_not_matched[0] if self.verbose >= 10: @@ -461,11 +461,11 @@ def _match_values_forward( "-- model-matched", pattern_node_users_matched, ) - return self.none(node, inspect.currentframe().f_lineno) + return self.none(starting_node, inspect.currentframe().f_lineno) for k, v in ec.items(): if gc[k] < v: # Not enough types to match. - return self.none(node, inspect.currentframe().f_lineno) + return self.none(starting_node, inspect.currentframe().f_lineno) # At this stage, we know matching the types is possible. # We first mark whatever is possible. diff --git a/onnxscript/rewriter/generic_pattern_test.py b/onnxscript/rewriter/generic_pattern_test.py index d65f01c8d..db0e2a638 100644 --- a/onnxscript/rewriter/generic_pattern_test.py +++ b/onnxscript/rewriter/generic_pattern_test.py @@ -12,6 +12,7 @@ import onnx.parser import onnx.reference import onnxruntime as ort +import parameterized from onnxscript import ir from onnxscript.rewriter import generic_pattern, pattern @@ -19,6 +20,13 @@ FLOAT = onnx.TensorProto.FLOAT +@parameterized.parameterized_class( + ("matcher_algo",), + [ + (generic_pattern.GenericPatternMatcher,), + (pattern.SimplePatternMatcher,), + ], +) class GenericPatternTest(unittest.TestCase): def _range(self, *shape, bias: float | None = None): n = np.prod(shape) @@ -48,7 +56,7 @@ def validate_mapping(context, x, y, z, **_) -> bool: match_pattern, apply_pattern, validate_mapping, - generic_pattern.GenericPatternMatcher, + self.matcher_algo, ) class AddAdd(onnx.reference.op_run.OpRun): @@ -128,7 +136,7 @@ def validate_mapping(context, **_) -> bool: match_pattern, apply_pattern, validate_mapping, - generic_pattern.GenericPatternMatcher, + self.matcher_algo, verbose=10, ) @@ -256,11 +264,7 @@ def match_pattern(op, x): def apply_pattern(op, x, **_): return op.SinCos(x, domain="com.microsoft", outputs=2) - rule = pattern.RewriteRule( - match_pattern, - apply_pattern, - matcher=generic_pattern.GenericPatternMatcher, - ) + rule = pattern.RewriteRule(match_pattern, apply_pattern, matcher=self.matcher_algo) model_proto = onnx.parser.parse_model( """ @@ -281,8 +285,10 @@ def apply_pattern(op, x, **_): self.assertEqual(len(graph.node), 2) self.assertEqual(graph.node[0].op_type, "SinCos") - @unittest.skip("Input variable reuse not supported yet") def test_shared_root_value_extra_use(self): + if self.matcher_algo is generic_pattern.GenericPatternMatcher: + raise unittest.SkipTest("GenericPatternMatcher does not support extra uses yet.") + def match_pattern(op, x): t1 = op.Sin(x) t2 = op.Cos(x) @@ -294,7 +300,7 @@ def apply_pattern(op, x, **_): rule = pattern.RewriteRule( match_pattern, apply_pattern, - matcher=generic_pattern.GenericPatternMatcher, + matcher=self.matcher_algo, ) model_proto = onnx.parser.parse_model( """ @@ -314,7 +320,7 @@ def apply_pattern(op, x, **_): rule.apply_to_model(ir_model) graph = ir_model.graph self.assertEqual(len(graph), 3) - self.assertEqual(graph.node[0].op_type, "SinCos") + self.assertEqual(graph.node(0).op_type, "SinCos") def test_rotary_embedding(self): # The test work on a model if it has the expected name. @@ -367,7 +373,7 @@ def apply_pattern(op, x, pos_ids, axis, **_): match_pattern, apply_pattern, validate_mapping, - generic_pattern.GenericPatternMatcher, + self.matcher_algo, verbose=10, ) @@ -389,7 +395,8 @@ def apply_pattern(op, x, pos_ids, axis, **_): self.assertEqual(expected, [n.op_type for n in rewriten_model.graph.node]) out = buffer.getvalue() # TODO(Rama): What is this assertion testing? Is it to check that `verbose` is working? - self.assertIn("[GenericPatternMatcher.match", out) + if self.matcher_algo is generic_pattern.GenericPatternMatcher: + self.assertIn("[GenericPatternMatcher.match", out) def test_rotary_embedding_onnxscript(self): # The test work on a model if it has the expected name. @@ -432,7 +439,7 @@ def rotary_apply_pattern(op, x, pos_ids, axis, **_): rotary_match_pattern, rotary_apply_pattern, validate_rotary_mapping, - generic_pattern.GenericPatternMatcher, + self.matcher_algo, verbose=10, ) @@ -454,7 +461,8 @@ def rotary_apply_pattern(op, x, pos_ids, axis, **_): self.assertEqual(expected, [n.op_type for n in rewriten_model.graph.node]) out = buffer.getvalue() # TODO(justinchuby): Remove this assert - capturing stdout is not robust - self.assertIn("[GenericPatternMatcher.match", out) + if self.matcher_algo is generic_pattern.GenericPatternMatcher: + self.assertIn("[GenericPatternMatcher.match", out) def test_rotary_emb_file_onnxscript(self): # The test work on a model if it has the expected name. @@ -504,7 +512,7 @@ def rotary_apply_pattern(op, x, pos_ids, axis): rotary_match_pattern, rotary_apply_pattern, validate_rotary_mapping, - generic_pattern.GenericPatternMatcher, + self.matcher_algo, verbose=10, ) @@ -561,7 +569,7 @@ def transpose_transpose_apply_pattern(op, X, XT: ir.Value, Y, **_): transpose_transpose_pattern, transpose_transpose_apply_pattern, transpose_transpose_check, - generic_pattern.GenericPatternMatcher, + self.matcher_algo, verbose=0, ) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 04c1ffd13..4c388c6ae 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -10,6 +10,7 @@ from typing import ( Any, Callable, + Iterable, Iterator, List, MutableSequence, @@ -665,21 +666,25 @@ def __init__( self._nodes = _nodes_in_pattern(outputs) # Check if all outputs are produced by the same node. - output_node = None - for i, value_pattern in enumerate(outputs): + output_nodes: set[NodePattern] = set() + for value_pattern in outputs: if not isinstance(value_pattern, ValuePattern): raise TypeError( f"Invalid type {type(value_pattern)} for graph pattern output." ) - if not isinstance(value_pattern, NodeOutputPattern) or ( - value_pattern.output_index != i - ): - output_node = None - elif i == 0: - output_node = value_pattern.producer() - elif value_pattern.producer() is not output_node: - output_node = None - self._output_node = output_node + if isinstance(value_pattern, Constant): + raise NotImplementedError( + "Constant values are not allowed as graph pattern outputs." + ) + if isinstance(value_pattern, NodeOutputPattern): + output_nodes.add(value_pattern.producer()) + self.output_nodes: list[NodePattern] = list(output_nodes) + + @property + def output_node(self) -> NodePattern: + if len(self.output_nodes) != 1: + raise ValueError("GraphPattern does not have unique output node.") + return self.output_nodes[0] def node(self, index: int) -> NodePattern: return self._nodes[index] @@ -706,18 +711,18 @@ def __reversed__(self) -> Iterator[NodePattern]: @property def has_single_output_node(self) -> bool: - return self._output_node is not None + return len(self.output_nodes) == 1 @property def num_outputs(self) -> int: return len(self._outputs) def commute(self) -> Sequence[GraphPattern]: - if self._output_node is None: + if not self.has_single_output_node: raise NotImplementedError( "Cannot commute a graph pattern with multiple output nodes." ) - nodes = self._output_node.commute() + nodes = self.output_node.commute() return [ GraphPattern( self._inputs, [NodeOutputPattern(n, i) for i in range(self.num_outputs)] @@ -762,15 +767,18 @@ def pattern(op, x: Var, shape1: Var, shape2: Var): return GraphPattern(pattern_inputs, pattern_outputs) -def _valid_to_replace(matched_nodes: Sequence[ir.Node]) -> bool: - """Check that values computed by the matched_nodes, except for the last one, are used only by the matched_nodes.""" +def _valid_to_replace( + matched_nodes: Sequence[ir.Node], output_values: Sequence[ir.Value] +) -> bool: + """Check that values computed by the matched_nodes, except for output_values, are used only by the matched_nodes.""" # * Must check that all values matched by pattern are used only by pattern, # except for the value that is replaced. # * Must ensure that replacement subgraph does not use any of the deleted # (intermediate) values. (Not necessary for now. Guaranteed.) - deleted_nodes = matched_nodes[:-1] - for n in deleted_nodes: + for n in matched_nodes: for v in n.outputs: + if v in output_values: + continue if v.is_graph_output(): # value is an output-value of the graph/function. return False @@ -899,7 +907,7 @@ def match( node: ir.Node, verbose: int = 0, ) -> MatchResult: - pass + """Match the pattern against the subgraph ending at the given node.""" def __str__(self) -> str: return str(self.pattern) @@ -907,9 +915,6 @@ def __str__(self) -> str: class SimplePatternMatcher(PatternMatcher): def __init__(self, pattern: GraphPattern) -> None: - assert ( - pattern.has_single_output_node - ), "SimplePatternMatcher only supports patterns with a single output node." super().__init__(pattern) def fail(self, reason: str) -> bool: @@ -1029,37 +1034,152 @@ def _match_node_output(self, pattern_value: NodeOutputPattern, value: ir.Value) ) return self._match_node(pattern_value.producer(), node) - def match( + def _init_match(self, verbose: int) -> None: + """Initialize the match state. Invoked before starting a new match.""" + self._verbose = verbose + self._matched: dict[NodePattern, ir.Node] = {} + self._match: MatchResult = MatchResult() + + def _get_output_values(self) -> list[ir.Value] | None: + """Get values bound to the output variables of the pattern.""" + output_values: list[ir.Value] = [] + unbound_values: list[str] = [] + for j, value_pattern in enumerate(self.pattern.outputs): + if value_pattern.name is not None: + if value_pattern.name in self._match.bindings: + output_values.append(self._match.bindings[value_pattern.name]) + else: + unbound_values.append(value_pattern.name) + elif isinstance(value_pattern, NodeOutputPattern): + i = value_pattern.output_index + node = value_pattern.producer() + if node in self._matched: + output_values.append(self._matched[node].outputs[i]) + else: + unbound_values.append(f"output_{j}") + elif isinstance(value_pattern, Constant): + raise NotImplementedError("Constant values as return-values not supported.") + if unbound_values: + self._match.fail(f"Error: Output values not found: {unbound_values}") + return None + return output_values + + def _match_single_output_node( self, model: ir.Model, graph_or_function: ir.Graph | ir.Function, node: ir.Node, - verbose: int = 0, ) -> MatchResult: del model del graph_or_function - self._verbose = verbose - self._matched: dict[NodePattern, ir.Node] = {} - self._match: MatchResult = MatchResult() pattern = self.pattern match = self._match - if len(node.outputs) != pattern.num_outputs: + + if not pattern.has_single_output_node: return match.fail( - f"Number of node outputs mismatch: expected {pattern.num_outputs}, got {len(node.outputs)}." + "Internal Error: SimplePatternMatcher should not be used for patterns with multiple output nodes." ) - if pattern._output_node is None: + + if not self._match_node(pattern.output_node, node): + return match + + output_values = self._get_output_values() + if output_values is None: + return match + if not _valid_to_replace(match.nodes, output_values): + return match.fail("Matched nodes have other uses preventing replacement.") + + if len(node.outputs) != pattern.num_outputs: return match.fail( - "Internal Error: SimplePatternMatcher should not be used for patterns with multiple output nodes." + f"Number of node outputs mismatch: expected {pattern.num_outputs}, got {len(node.outputs)}." ) - if self._match_node(pattern._output_node, node): - if not _valid_to_replace(match.nodes): - return match.fail("Matched nodes have other uses preventing replacement.") + match.outputs.extend(output_values) + return match + + def _multi_match(self, candidate: Iterable[ir.Node]) -> MatchResult: + """Find a match for a pattern with multiple output nodes. + + For a pattern with K output nodes, the input candidate should specify K nodes + in the graph that will be matched against the pattern output nodes. - match.outputs.extend(node.outputs) + Args: + candidate: An iterable of nodes that will be matched against the pattern output nodes. + """ + match = self._match + for pattern_node, node in zip(self.pattern.output_nodes, candidate): + if not self._match_node(pattern_node, node): + return match + output_values = self._get_output_values() + if output_values is None: + return match + + if not _valid_to_replace(match.nodes, output_values): + return match.fail("Matched nodes have other uses preventing replacement.") + + match.outputs.extend(output_values) return match + def match( + self, + model: ir.Model, + graph_or_function: ir.Graph | ir.Function, + node: ir.Node, + verbose: int = 0, + ) -> MatchResult: + """Match the pattern against the subgraph ending at the given node. + + For patterns with multiple output nodes, the given node is matched + against the first output node in the pattern. For the remaining + output nodes in the pattern, we use a brute-force algorithm that + enumerates all possible combinations of nodes from the graph (with + a filter based on op-type). + + TODO: Consider omitting parameters model and graph_or_function. With + the new IR, the graph can be obtained from the node, and the model is + not used. But this is a shared abstract method of the Matcher interface, + so other matcher implementation also needs to be updated. More importantly, + matching in the presence of subgraphs (control-flow) can introduce some + complications which require careful consideration. + """ + + if self.pattern.has_single_output_node: + self._init_match(verbose) + return self._match_single_output_node(model, graph_or_function, node) + else: + # Note: This is a potentially expensive algorithm for matching patterns with + # multiple output nodes. For patterns with N output nodes, we try all possible + # combinations of N nodes from the graph, and check if they match the pattern. + # The first node is fixed to the node argument in this method call. We do + # some simple filtering by restricting the candidates for each remaining + # output nodes to graph nodes with the same op_type as the corresponding pattern + # node. For now, this is intended to be a simple, but robust, implementation + # that can be used for debugging and testing. The GenericPatternMatcher is a + # more sophisticated implementation, but incomplete. + pattern_output_nodes = self.pattern.output_nodes + op_to_nodes: dict[tuple[str, str, str], list[ir.Node]] = {} + for n in graph_or_function: + op_to_nodes.setdefault(n.op_identifier(), []).append(n) + all_nodes = iter(graph_or_function) + + def get_nodes(pattern_node): + id = pattern_node.op_identifier() + if id is None: + return all_nodes + return op_to_nodes.get(id, []) + + candidates = [iter([node])] + [get_nodes(pn) for pn in pattern_output_nodes[1:]] + match = None + for combination in itertools.product(*candidates): + self._init_match(verbose) + match = self._multi_match(combination) + if match: + return match + if match is None: + return MatchResult().fail("No match found.") + return match + class RewriteRule: def __init__(