diff --git a/onnxscript/rewriter/generic_pattern.py b/onnxscript/rewriter/generic_pattern.py index 71d650e1b..51957ff47 100644 --- a/onnxscript/rewriter/generic_pattern.py +++ b/onnxscript/rewriter/generic_pattern.py @@ -300,7 +300,13 @@ def _match_backward( # TODO(rama): Handle constant-pattern pattern_pred = pattern_value.producer() if pattern_pred is None: - # pattern_pred is None means the pattern ends here. + # pattern_pred is None means the pattern backward search ends here. + result = self._match_values_forward( + starting_node, matched, stack, graph_value, pattern_value + ) + if result is None: + return result + match_count += result continue graph_pred = graph_value.producer() if graph_pred is None: @@ -328,6 +334,158 @@ def _match_backward( print(f"[GenericPatternMatcher._match_backward] add {match_count} nodes") return match_count + def _match_values_forward( + self, + starting_node: ir.Node, + matched: dict[orp.NodePattern, ir.Node], + stack: list[orp.NodePattern], + graph_value: ir.Value, + pattern_value: orp.ValuePattern, + ) -> int | None: + """ + Matches forward. + + Args: + starting_node: root node (the node the match begins with, used only for debugging) + matched: nodes of the pattern matched as already matched + stack: next node to look into + graph_value: value coming from the graph + pattern_value: pattern value coming from the pattern + + Returns: + number of matched nodes to continue, None or False to indicate a failed match + """ + match_count = 0 + graph_node_users = [user for user, _ in graph_value.uses()] + pattern_node_users = [user for user, _ in pattern_value.uses()] + if not pattern_node_users: + # The pattern has no node forward, the matching stops. + return match_count + if len(graph_node_users) < len(pattern_node_users): + # Not enough node in the graph to match the pattern. A match is not possible + return self.none(starting_node, inspect.currentframe().f_lineno) + + # Here comes the fun part, there is the same number of successors or more + # nodes in the graph to match with the pattern. + # And we have to handle the nodes already matched as found. + # Hopefully, there is only one option. + + if len(graph_node_users) == len(pattern_node_users) == 1: + # Let's deal with the simple case + if graph_node_users[0].op_identifier() != pattern_node_users[0].op_identifier(): + return self.none(starting_node, inspect.currentframe().f_lineno) + + node = pattern_node_users[0] + if node not in matched: + if self.verbose >= 10: + print( + f"[GenericPatternMatcher._match_values_forward]{self.print_match(graph_node_users[0], pattern_node_users[0])}" + ) + matched[node] = graph_node_users[0] + stack.append(node) + match_count += 1 + return match_count + + # Let's remove the nodes already matched. + pattern_node_users_not_matched = [ + unmatched_node + for unmatched_node in pattern_node_users + if unmatched_node not in matched + ] + pattern_node_users_matched = [ + matched[matched_node] + for matched_node in pattern_node_users + if matched_node in matched + ] + assert len(pattern_node_users_matched) + len(pattern_node_users_not_matched) == len( + pattern_node_users + ), ( + f"pattern_node_users_not_matched={pattern_node_users_not_matched}, " + f"pattern_node_users_matched={pattern_node_users_matched}, " + f"pattern_node_users={pattern_node_users}, " + f"matched={matched}" + ) + free = list(set(graph_node_users) - set(pattern_node_users_matched)) + if not pattern_node_users_not_matched: + # Everything is already matched. + return match_count + if len(free) < len(pattern_node_users_not_matched): + # Not enough successors to match the remaining patterns. + return self.none(node, inspect.currentframe().f_lineno) + if len(pattern_node_users_not_matched) == len(free) == 1: + # Only one option again. + graph_node = free[0] + if pattern_node_users_not_matched[0].op_identifier() != graph_node.op_identifier(): + return self.none(node, inspect.currentframe().f_lineno) + + key = pattern_node_users_not_matched[0] + if self.verbose >= 10: + print( + f"[GenericPatternMatcher._match_values_forward] {self.print_match(graph_node, pattern_node_users_not_matched[0])}" + ) + matched[key] = graph_node + stack.append(key) + match_count += 1 + return match_count + + # And now another fun part, let's try to handle the case when + # there is only one option, matching on node type only returns one + # option. + expected_op_type = [_.op_identifier() for _ in pattern_node_users_not_matched] + got_op_type = [_.op_identifier() for _ in free] + + ec = collections.Counter(expected_op_type) + gc = collections.Counter(got_op_type) + if len(ec) != len(gc) or set(ec) != set(gc): + # unique operator types is different. + self._hint( + "FORWARD: unique operator types are different", + "-- pattern", + ec, + pattern_value, + "-- model", + gc, + graph_value, + "-- model-matched", + pattern_node_users_matched, + ) + return self.none(node, inspect.currentframe().f_lineno) + for k, v in ec.items(): + if gc[k] < v: + # Not enough types to match. + return self.none(node, inspect.currentframe().f_lineno) + + # At this stage, we know matching the types is possible. + # We first mark whatever is possible. + ptype_to_node = {_.op_identifier(): _ for _ in pattern_node_users_not_matched} + gtype_to_node = {_.op_identifier(): _ for _ in free} + missing = [] + for k, v in ec.items(): + if gc[k] == v == 1: + key = id(ptype_to_node[k]) + if key not in matched: + if self.verbose >= 10: + print( + f"[GenericPatternMatcher._match_values_forward] match " + f"{self.print_match(gtype_to_node[k], ptype_to_node[k])}" + ) + matched[key] = gtype_to_node[k] + stack.append(key) + match_count += 1 + else: + missing.append(k) + + if not missing: + return match_count + + # At this stage, there are mutiple options for matching. We can: + # 1. make assumptions and continue + # 2. mark the node as incomplete matching, we could end up stuck anyway. + raise NotImplementedError( + f"There are more than one option, this will be implemented later, " + f"ec={ec}, gc={gc}" + ) + def _match_forward( self, starting_node: ir.Node, @@ -364,141 +522,13 @@ def _match_forward( return self.none(starting_node, inspect.currentframe().f_lineno) for graph_output, pattern_output in zip(graph_node.outputs, pattern_node.outputs): - graph_node_users = [user for user, _ in graph_output.uses()] - pattern_node_users = [user for user, _ in pattern_output.uses()] - if not pattern_node_users: - # The pattern has no node forward, the matching stops. - continue - if len(graph_node_users) < len(pattern_node_users): - # Not enough node in the graph to match the pattern. A match is not possible - return self.none(starting_node, inspect.currentframe().f_lineno) - - # Here comes the fun part, there is the same number of successors or more - # nodes in the graph to match with the pattern. - # And we have to handle the nodes already matched as found. - # Hopefully, there is only one option. - - if len(graph_node_users) == len(pattern_node_users) == 1: - # Let's deal with the simple case - if ( - graph_node_users[0].op_identifier() - != pattern_node_users[0].op_identifier() - ): - return self.none(starting_node, inspect.currentframe().f_lineno) - - node = pattern_node_users[0] - if node not in matched: - if self.verbose >= 10: - print( - f"[GenericPatternMatcher._match_forward]{self.print_match(graph_node_users[0], pattern_node_users[0])}" - ) - matched[node] = graph_node_users[0] - stack.append(node) - match_count += 1 - continue - - # Let's remove the nodes already matched. - pattern_node_users_not_matched = [ - unmatched_node - for unmatched_node in pattern_node_users - if unmatched_node not in matched - ] - pattern_node_users_matched = [ - matched[matched_node] - for matched_node in pattern_node_users - if matched_node in matched - ] - assert len(pattern_node_users_matched) + len( - pattern_node_users_not_matched - ) == len(pattern_node_users), ( - f"pattern_node_users_not_matched={pattern_node_users_not_matched}, " - f"pattern_node_users_matched={pattern_node_users_matched}, " - f"pattern_node_users={pattern_node_users}, " - f"matched={matched}" + result = self._match_values_forward( + starting_node, matched, stack, graph_output, pattern_output ) - free = list(set(graph_node_users) - set(pattern_node_users_matched)) - if not pattern_node_users_not_matched: - # Everything is already matched. - continue - if len(free) < len(pattern_node_users_not_matched): - # Not enough successors to match the remaining patterns. - return self.none(node, inspect.currentframe().f_lineno) - if len(pattern_node_users_not_matched) == len(free) == 1: - # Only one option again. - graph_node = free[0] - if ( - pattern_node_users_not_matched[0].op_identifier() - != graph_node.op_identifier() - ): - return self.none(node, inspect.currentframe().f_lineno) - - key = pattern_node_users_not_matched[0] - if self.verbose >= 10: - print( - f"[GenericPatternMatcher._match_forward] {self.print_match(graph_node, pattern_node_users_not_matched[0])}" - ) - matched[key] = graph_node - stack.append(key) - match_count += 1 - continue - - # And now another fun part, let's try to handle the case when - # there is only one option, matching on node type only returns one - # option. - expected_op_type = [_.op_identifier() for _ in pattern_node_users_not_matched] - got_op_type = [_.op_identifier() for _ in free] - - ec = collections.Counter(expected_op_type) - gc = collections.Counter(got_op_type) - if len(ec) != len(gc) or set(ec) != set(gc): - # unique operator types is different. - self._hint( - "FORWARD: unique operator types are different", - "-- pattern", - ec, - pattern_node, - "-- model", - gc, - graph_node, - "-- model-matched", - pattern_node_users_matched, - ) - return self.none(node, inspect.currentframe().f_lineno) - for k, v in ec.items(): - if gc[k] < v: - # Not enough types to match. - return self.none(node, inspect.currentframe().f_lineno) - - # At this stage, we know matching the types is possible. - # We first mark whatever is possible. - ptype_to_node = {_.op_identifier(): _ for _ in pattern_node_users_not_matched} - gtype_to_node = {_.op_identifier(): _ for _ in free} - missing = [] - for k, v in ec.items(): - if gc[k] == v == 1: - key = id(ptype_to_node[k]) - if key not in matched: - if self.verbose >= 10: - print( - f"[GenericPatternMatcher._match_forward] match " - f"{self.print_match(gtype_to_node[k], ptype_to_node[k])}" - ) - matched[key] = gtype_to_node[k] - stack.append(key) - match_count += 1 - else: - missing.append(k) - - if not missing: - continue + if result is None: + return result + match_count += result - # At this stage, there are mutiple options for matching. We can: - # 1. make assumptions and continue - # 2. mark the node as incomplete matching, we could end up stuck anyway. - raise NotImplementedError( - f"There are more than one option, this will be implemented later, " - f"ec={ec}, gc={gc}" - ) if self.verbose > 5 and match_count > 0: print(f"[GenericPatternMatcher._match_forward] add {match_count} nodes") return match_count diff --git a/onnxscript/rewriter/generic_pattern_test.py b/onnxscript/rewriter/generic_pattern_test.py index 174468cda..04a7f4f69 100644 --- a/onnxscript/rewriter/generic_pattern_test.py +++ b/onnxscript/rewriter/generic_pattern_test.py @@ -9,6 +9,7 @@ import numpy as np import onnx +import onnx.parser import onnx.reference import onnxruntime as ort @@ -246,6 +247,40 @@ def get_rotary_model(self): ) return model + def test_shared_root_value_test(self): + def match_pattern(op, x): + t1 = op.Sin(x) + t2 = op.Cos(x) + return t1, t2 + + def apply_pattern(op, x, **_): + return op.SinCos(x, domain="com.microsoft", outputs=2) + + rule = pattern.RewriteRule( + match_pattern, + apply_pattern, + matcher=generic_pattern.GenericPatternMatcher, + ) + model_proto = onnx.parser.parse_model( + """ + + agraph (float[N] y) => (float[N] z) + { + temp1 = Sin(y) + temp2 = Cos(y) + z = Add(temp1, temp2) + } + """ + ) + onnx.checker.check_model(model_proto) + model = onnx.shape_inference.infer_shapes(model_proto) + ir_model = ir.serde.deserialize_model(model) + rule.apply_to_model(ir_model) + rewritten_model = ir.serde.serialize_model(ir_model) + graph = rewritten_model.graph + self.assertEqual(len(graph.node), 2) + self.assertEqual(graph.node[0].op_type, "SinCos") + def test_rotary_embedding(self): # The test work on a model if it has the expected name. # A dummy model is used if not present (not implemented yet).