diff --git a/onnxscript/rewriter/generic_pattern.py b/onnxscript/rewriter/generic_pattern.py index d0daf2e06..ac7c4c892 100644 --- a/onnxscript/rewriter/generic_pattern.py +++ b/onnxscript/rewriter/generic_pattern.py @@ -298,7 +298,15 @@ def _match_backward( return self.none(starting_node, inspect.currentframe().f_lineno) for graph_input, pattern_input in zip(graph_node.inputs, pattern_node.inputs): - if len(list(graph_input.uses())) != len(list(pattern_input.uses())): + # Intermediate values in the pattern must have the same number of uses + # in the graph for a valid match. By design, this ensures patterns where + # intermediate values have extra uses in the graph do NOT match. + # However, pattern-input-values may have extra uses. This is because + # pattern-inputs will not be removed by pattern-replacement, but + # intermediate values will be removed. + if pattern_input.producer() is not None and len(list(graph_input.uses())) != len( + list(pattern_input.uses()) + ): self._hint( "BACKWARD: one input is used outside the pattern", "-- pattern", @@ -423,12 +431,12 @@ def _match_values_forward( return match_count if len(free) < len(pattern_node_users_not_matched): # Not enough successors to match the remaining patterns. - return self.none(node, inspect.currentframe().f_lineno) + return self.none(starting_node, inspect.currentframe().f_lineno) if len(pattern_node_users_not_matched) == len(free) == 1: # Only one option again. 
graph_node = free[0] if pattern_node_users_not_matched[0].op_identifier() != graph_node.op_identifier(): - return self.none(node, inspect.currentframe().f_lineno) + return self.none(starting_node, inspect.currentframe().f_lineno) key = pattern_node_users_not_matched[0] if self.verbose >= 10: @@ -461,11 +469,11 @@ def _match_values_forward( "-- model-matched", pattern_node_users_matched, ) - return self.none(node, inspect.currentframe().f_lineno) + return self.none(starting_node, inspect.currentframe().f_lineno) for k, v in ec.items(): if gc[k] < v: # Not enough types to match. - return self.none(node, inspect.currentframe().f_lineno) + return self.none(starting_node, inspect.currentframe().f_lineno) # At this stage, we know matching the types is possible. # We first mark whatever is possible.