diff --git a/onnxscript/rewriter/generic_pattern.py b/onnxscript/rewriter/generic_pattern.py index d0daf2e06..2926f5964 100644 --- a/onnxscript/rewriter/generic_pattern.py +++ b/onnxscript/rewriter/generic_pattern.py @@ -298,7 +298,7 @@ def _match_backward( return self.none(starting_node, inspect.currentframe().f_lineno) for graph_input, pattern_input in zip(graph_node.inputs, pattern_node.inputs): - if len(list(graph_input.uses())) != len(list(pattern_input.uses())): + if len(graph_input.uses()) != len(pattern_input.uses()): self._hint( "BACKWARD: one input is used outside the pattern", "-- pattern", @@ -423,12 +423,12 @@ def _match_values_forward( return match_count if len(free) < len(pattern_node_users_not_matched): # Not enough successors to match the remaining patterns. - return self.none(node, inspect.currentframe().f_lineno) + return self.none(starting_node, inspect.currentframe().f_lineno) if len(pattern_node_users_not_matched) == len(free) == 1: # Only one option again. graph_node = free[0] if pattern_node_users_not_matched[0].op_identifier() != graph_node.op_identifier(): - return self.none(node, inspect.currentframe().f_lineno) + return self.none(starting_node, inspect.currentframe().f_lineno) key = pattern_node_users_not_matched[0] if self.verbose >= 10: @@ -461,11 +461,11 @@ def _match_values_forward( "-- model-matched", pattern_node_users_matched, ) - return self.none(node, inspect.currentframe().f_lineno) + return self.none(starting_node, inspect.currentframe().f_lineno) for k, v in ec.items(): if gc[k] < v: # Not enough types to match. - return self.none(node, inspect.currentframe().f_lineno) + return self.none(starting_node, inspect.currentframe().f_lineno) # At this stage, we know matching the types is possible. # We first mark whatever is possible. diff --git a/onnxscript/rewriter/generic_pattern_test.py b/onnxscript/rewriter/generic_pattern_test.py index d65f01c8d..db0e2a638 100644 --- a/onnxscript/rewriter/generic_pattern_test.py +++ b/onnxscript/rewriter/generic_pattern_test.py @@ -12,6 +12,7 @@ import onnx.parser import onnx.reference import onnxruntime as ort +import parameterized from onnxscript import ir from onnxscript.rewriter import generic_pattern, pattern @@ -19,6 +20,13 @@ FLOAT = onnx.TensorProto.FLOAT +@parameterized.parameterized_class( + ("matcher_algo",), + [ + (generic_pattern.GenericPatternMatcher,), + (pattern.SimplePatternMatcher,), + ], +) class GenericPatternTest(unittest.TestCase): def _range(self, *shape, bias: float | None = None): n = np.prod(shape) @@ -48,7 +56,7 @@ def validate_mapping(context, x, y, z, **_) -> bool: match_pattern, apply_pattern, validate_mapping, - generic_pattern.GenericPatternMatcher, + self.matcher_algo, ) class AddAdd(onnx.reference.op_run.OpRun): @@ -128,7 +136,7 @@ def validate_mapping(context, **_) -> bool: match_pattern, apply_pattern, validate_mapping, - generic_pattern.GenericPatternMatcher, + self.matcher_algo, verbose=10, ) @@ -256,11 +264,7 @@ def match_pattern(op, x): def apply_pattern(op, x, **_): return op.SinCos(x, domain="com.microsoft", outputs=2) - rule = pattern.RewriteRule( - match_pattern, - apply_pattern, - matcher=generic_pattern.GenericPatternMatcher, - ) + rule = pattern.RewriteRule(match_pattern, apply_pattern, matcher=self.matcher_algo) model_proto = onnx.parser.parse_model( """ @@ -281,8 +285,10 @@ def apply_pattern(op, x, **_): self.assertEqual(len(graph.node), 2) self.assertEqual(graph.node[0].op_type, "SinCos") - @unittest.skip("Input variable reuse not supported yet") def test_shared_root_value_extra_use(self): + if self.matcher_algo is generic_pattern.GenericPatternMatcher: + raise unittest.SkipTest("GenericPatternMatcher does not support extra uses yet.") + def match_pattern(op, x): t1 = op.Sin(x) t2 = op.Cos(x) @@ -294,7 +300,7 @@ def apply_pattern(op, x, **_): rule = pattern.RewriteRule( match_pattern, apply_pattern, - matcher=generic_pattern.GenericPatternMatcher, + matcher=self.matcher_algo, ) model_proto = onnx.parser.parse_model( """ @@ -314,7 +320,7 @@ def apply_pattern(op, x, **_): rule.apply_to_model(ir_model) graph = ir_model.graph self.assertEqual(len(graph), 3) - self.assertEqual(graph.node[0].op_type, "SinCos") + self.assertEqual(graph.node(0).op_type, "SinCos") def test_rotary_embedding(self): # The test work on a model if it has the expected name. @@ -367,7 +373,7 @@ def apply_pattern(op, x, pos_ids, axis, **_): match_pattern, apply_pattern, validate_mapping, - generic_pattern.GenericPatternMatcher, + self.matcher_algo, verbose=10, ) @@ -389,7 +395,8 @@ def apply_pattern(op, x, pos_ids, axis, **_): self.assertEqual(expected, [n.op_type for n in rewriten_model.graph.node]) out = buffer.getvalue() # TODO(Rama): What is this assertion testing? Is it to check that `verbose` is working? - self.assertIn("[GenericPatternMatcher.match", out) + if self.matcher_algo is generic_pattern.GenericPatternMatcher: + self.assertIn("[GenericPatternMatcher.match", out) def test_rotary_embedding_onnxscript(self): # The test work on a model if it has the expected name. @@ -432,7 +439,7 @@ def rotary_apply_pattern(op, x, pos_ids, axis, **_): rotary_match_pattern, rotary_apply_pattern, validate_rotary_mapping, - generic_pattern.GenericPatternMatcher, + self.matcher_algo, verbose=10, ) @@ -454,7 +461,8 @@ def rotary_apply_pattern(op, x, pos_ids, axis, **_): self.assertEqual(expected, [n.op_type for n in rewriten_model.graph.node]) out = buffer.getvalue() # TODO(justinchuby): Remove this assert - capturing stdout is not robust - self.assertIn("[GenericPatternMatcher.match", out) + if self.matcher_algo is generic_pattern.GenericPatternMatcher: + self.assertIn("[GenericPatternMatcher.match", out) def test_rotary_emb_file_onnxscript(self): # The test work on a model if it has the expected name. @@ -504,7 +512,7 @@ def rotary_apply_pattern(op, x, pos_ids, axis): rotary_match_pattern, rotary_apply_pattern, validate_rotary_mapping, - generic_pattern.GenericPatternMatcher, + self.matcher_algo, verbose=10, ) @@ -561,7 +569,7 @@ def transpose_transpose_apply_pattern(op, X, XT: ir.Value, Y, **_): transpose_transpose_pattern, transpose_transpose_apply_pattern, transpose_transpose_check, - generic_pattern.GenericPatternMatcher, + self.matcher_algo, verbose=0, ) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 04c1ffd13..4c388c6ae 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -10,6 +10,7 @@ from typing import ( Any, Callable, + Iterable, Iterator, List, MutableSequence, @@ -665,21 +666,25 @@ def __init__( self._nodes = _nodes_in_pattern(outputs) # Check if all outputs are produced by the same node. - output_node = None - for i, value_pattern in enumerate(outputs): + output_nodes: set[NodePattern] = set() + for value_pattern in outputs: if not isinstance(value_pattern, ValuePattern): raise TypeError( f"Invalid type {type(value_pattern)} for graph pattern output." ) - if not isinstance(value_pattern, NodeOutputPattern) or ( - value_pattern.output_index != i - ): - output_node = None - elif i == 0: - output_node = value_pattern.producer() - elif value_pattern.producer() is not output_node: - output_node = None - self._output_node = output_node + if isinstance(value_pattern, Constant): + raise NotImplementedError( + "Constant values are not allowed as graph pattern outputs." + ) + if isinstance(value_pattern, NodeOutputPattern): + output_nodes.add(value_pattern.producer()) + self.output_nodes: list[NodePattern] = list(output_nodes) + + @property + def output_node(self) -> NodePattern: + if len(self.output_nodes) != 1: + raise ValueError("GraphPattern does not have unique output node.") + return self.output_nodes[0] def node(self, index: int) -> NodePattern: return self._nodes[index] @@ -706,18 +711,18 @@ def __reversed__(self) -> Iterator[NodePattern]: @property def has_single_output_node(self) -> bool: - return self._output_node is not None + return len(self.output_nodes) == 1 @property def num_outputs(self) -> int: return len(self._outputs) def commute(self) -> Sequence[GraphPattern]: - if self._output_node is None: + if not self.has_single_output_node: raise NotImplementedError( "Cannot commute a graph pattern with multiple output nodes." ) - nodes = self._output_node.commute() + nodes = self.output_node.commute() return [ GraphPattern( self._inputs, [NodeOutputPattern(n, i) for i in range(self.num_outputs)] @@ -762,15 +767,18 @@ def pattern(op, x: Var, shape1: Var, shape2: Var): return GraphPattern(pattern_inputs, pattern_outputs) -def _valid_to_replace(matched_nodes: Sequence[ir.Node]) -> bool: - """Check that values computed by the matched_nodes, except for the last one, are used only by the matched_nodes.""" +def _valid_to_replace( + matched_nodes: Sequence[ir.Node], output_values: Sequence[ir.Value] +) -> bool: + """Check that values computed by the matched_nodes, except for output_values, are used only by the matched_nodes.""" # * Must check that all values matched by pattern are used only by pattern, # except for the value that is replaced. # * Must ensure that replacement subgraph does not use any of the deleted # (intermediate) values. (Not necessary for now. Guaranteed.) - deleted_nodes = matched_nodes[:-1] - for n in deleted_nodes: + for n in matched_nodes: for v in n.outputs: + if v in output_values: + continue if v.is_graph_output(): # value is an output-value of the graph/function. return False @@ -899,7 +907,7 @@ def match( node: ir.Node, verbose: int = 0, ) -> MatchResult: - pass + """Match the pattern against the subgraph ending at the given node.""" def __str__(self) -> str: return str(self.pattern) @@ -907,9 +915,6 @@ def __str__(self) -> str: class SimplePatternMatcher(PatternMatcher): def __init__(self, pattern: GraphPattern) -> None: - assert ( - pattern.has_single_output_node - ), "SimplePatternMatcher only supports patterns with a single output node." super().__init__(pattern) def fail(self, reason: str) -> bool: @@ -1029,37 +1034,152 @@ def _match_node_output(self, pattern_value: NodeOutputPattern, value: ir.Value) ) return self._match_node(pattern_value.producer(), node) - def match( + def _init_match(self, verbose: int) -> None: + """Initialize the match state. Invoked before starting a new match.""" + self._verbose = verbose + self._matched: dict[NodePattern, ir.Node] = {} + self._match: MatchResult = MatchResult() + + def _get_output_values(self) -> list[ir.Value] | None: + """Get values bound to the output variables of the pattern.""" + output_values: list[ir.Value] = [] + unbound_values: list[str] = [] + for j, value_pattern in enumerate(self.pattern.outputs): + if value_pattern.name is not None: + if value_pattern.name in self._match.bindings: + output_values.append(self._match.bindings[value_pattern.name]) + else: + unbound_values.append(value_pattern.name) + elif isinstance(value_pattern, NodeOutputPattern): + i = value_pattern.output_index + node = value_pattern.producer() + if node in self._matched: + output_values.append(self._matched[node].outputs[i]) + else: + unbound_values.append(f"output_{j}") + elif isinstance(value_pattern, Constant): + raise NotImplementedError("Constant values as return-values not supported.") + if unbound_values: + self._match.fail(f"Error: Output values not found: {unbound_values}") + return None + return output_values + + def _match_single_output_node( self, model: ir.Model, graph_or_function: ir.Graph | ir.Function, node: ir.Node, - verbose: int = 0, ) -> MatchResult: del model del graph_or_function - self._verbose = verbose - self._matched: dict[NodePattern, ir.Node] = {} - self._match: MatchResult = MatchResult() pattern = self.pattern match = self._match - if len(node.outputs) != pattern.num_outputs: + + if not pattern.has_single_output_node: return match.fail( - f"Number of node outputs mismatch: expected {pattern.num_outputs}, got {len(node.outputs)}." + "Internal Error: SimplePatternMatcher should not be used for patterns with multiple output nodes." ) - if pattern._output_node is None: + + if not self._match_node(pattern.output_node, node): + return match + + output_values = self._get_output_values() + if output_values is None: + return match + if not _valid_to_replace(match.nodes, output_values): + return match.fail("Matched nodes have other uses preventing replacement.") + + if len(node.outputs) != pattern.num_outputs: return match.fail( - "Internal Error: SimplePatternMatcher should not be used for patterns with multiple output nodes." + f"Number of node outputs mismatch: expected {pattern.num_outputs}, got {len(node.outputs)}." ) - if self._match_node(pattern._output_node, node): - if not _valid_to_replace(match.nodes): - return match.fail("Matched nodes have other uses preventing replacement.") + match.outputs.extend(output_values) + return match + + def _multi_match(self, candidate: Iterable[ir.Node]) -> MatchResult: + """Find a match for a pattern with multiple output nodes. + + For a pattern with K output nodes, the input candidate should specify K nodes + in the graph that will be matched against the pattern output nodes. - match.outputs.extend(node.outputs) + Args: + candidate: An iterable of nodes that will be matched against the pattern output nodes. + """ + match = self._match + for pattern_node, node in zip(self.pattern.output_nodes, candidate): + if not self._match_node(pattern_node, node): + return match + output_values = self._get_output_values() + if output_values is None: + return match + + if not _valid_to_replace(match.nodes, output_values): + return match.fail("Matched nodes have other uses preventing replacement.") + + match.outputs.extend(output_values) return match + def match( + self, + model: ir.Model, + graph_or_function: ir.Graph | ir.Function, + node: ir.Node, + verbose: int = 0, + ) -> MatchResult: + """Match the pattern against the subgraph ending at the given node. + + For patterns with multiple output nodes, the given node is matched + against the first output node in the pattern. For the remaining + output nodes in the pattern, we use a brute-force algorithm that + enumerates all possible combinations of nodes from the graph (with + a filter based on op-type). + + TODO: Consider omitting parameters model and graph_or_function. With + the new IR, the graph can be obtained from the node, and the model is + not used. But this is a shared abstract method of the Matcher interface, + so other matcher implementation also needs to be updated. More importantly, + matching in the presence of subgraphs (control-flow) can introduce some + complications which require careful consideration. + """ + + if self.pattern.has_single_output_node: + self._init_match(verbose) + return self._match_single_output_node(model, graph_or_function, node) + else: + # Note: This is a potentially expensive algorithm for matching patterns with + # multiple output nodes. For patterns with N output nodes, we try all possible + # combinations of N nodes from the graph, and check if they match the pattern. + # The first node is fixed to the node argument in this method call. We do + # some simple filtering by restricting the candidates for each remaining + # output nodes to graph nodes with the same op_type as the corresponding pattern + # node. For now, this is intended to be a simple, but robust, implementation + # that can be used for debugging and testing. The GenericPatternMatcher is a + # more sophisticated implementation, but incomplete. + pattern_output_nodes = self.pattern.output_nodes + op_to_nodes: dict[tuple[str, str, str], list[ir.Node]] = {} + for n in graph_or_function: + op_to_nodes.setdefault(n.op_identifier(), []).append(n) + all_nodes = iter(graph_or_function) + + def get_nodes(pattern_node): + id = pattern_node.op_identifier() + if id is None: + return all_nodes + return op_to_nodes.get(id, []) + + candidates = [iter([node])] + [get_nodes(pn) for pn in pattern_output_nodes[1:]] + match = None + for combination in itertools.product(*candidates): + self._init_match(verbose) + match = self._multi_match(combination) + if match: + return match + if match is None: + return MatchResult().fail("No match found.") + return match + class RewriteRule: def __init__(