Skip to content

Commit

Permalink
Address ruff warning
Browse files Browse the repository at this point in the history
  • Loading branch information
gramalingam committed Jul 18, 2024
1 parent 56fe8e9 commit 35c3b40
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions onnxscript/rewriter/generic_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,15 @@ def _match_backward(
return self.none(starting_node, inspect.currentframe().f_lineno)

for graph_input, pattern_input in zip(graph_node.inputs, pattern_node.inputs):
if len(list(graph_input.uses())) != len(list(pattern_input.uses())):
# Intermediate values in the pattern must have the same number of uses
# in the graph for a valid match. By design, this esnures patterns where
# intermediate values have extra uses in the graph do NOT match.
# However, pattern-input-values may have extra uses. This is because
# pattern-inputs will not be removed by pattern-replacement, but
# intermediate values will be removed.
if pattern_input.producer() is not None and len(list(graph_input.uses())) != len(
list(pattern_input.uses())
):
self._hint(
"BACKWARD: one input is used outside the pattern",
"-- pattern",
Expand Down Expand Up @@ -423,12 +431,12 @@ def _match_values_forward(
return match_count
if len(free) < len(pattern_node_users_not_matched):
# Not enough successors to match the remaining patterns.
return self.none(node, inspect.currentframe().f_lineno)
return self.none(starting_node, inspect.currentframe().f_lineno)

Check warning on line 434 in onnxscript/rewriter/generic_pattern.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/generic_pattern.py#L434

Added line #L434 was not covered by tests
if len(pattern_node_users_not_matched) == len(free) == 1:
# Only one option again.
graph_node = free[0]
if pattern_node_users_not_matched[0].op_identifier() != graph_node.op_identifier():
return self.none(node, inspect.currentframe().f_lineno)
return self.none(starting_node, inspect.currentframe().f_lineno)

Check warning on line 439 in onnxscript/rewriter/generic_pattern.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/generic_pattern.py#L439

Added line #L439 was not covered by tests

key = pattern_node_users_not_matched[0]
if self.verbose >= 10:
Expand Down Expand Up @@ -461,11 +469,11 @@ def _match_values_forward(
"-- model-matched",
pattern_node_users_matched,
)
return self.none(node, inspect.currentframe().f_lineno)
return self.none(starting_node, inspect.currentframe().f_lineno)

Check warning on line 472 in onnxscript/rewriter/generic_pattern.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/generic_pattern.py#L472

Added line #L472 was not covered by tests
for k, v in ec.items():
if gc[k] < v:
# Not enough types to match.
return self.none(node, inspect.currentframe().f_lineno)
return self.none(starting_node, inspect.currentframe().f_lineno)

Check warning on line 476 in onnxscript/rewriter/generic_pattern.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/generic_pattern.py#L476

Added line #L476 was not covered by tests

# At this stage, we know matching the types is possible.
# We first mark whatever is possible.
Expand Down

0 comments on commit 35c3b40

Please sign in to comment.