Skip to content

Commit

Permalink
Extend basic matcher to handle multiple-output-nodes (#1734)
Browse files Browse the repository at this point in the history
This PR extends the basic matcher to handle multiple output nodes. This
provides an alternative to the generic-matcher algorithm, which is
incomplete and fails in some circumstances. This can also be useful in
debugging match-failures (when it is unclear if the failure is valid or
due to limitations of the matching algorithm). The drawback is that this
algorithm can, in some cases, be expensive, especially when the number
of output-nodes is large and the graph size is large. (So far, however,
we haven't encountered patterns with more than 2 output-nodes.)
  • Loading branch information
gramalingam authored Jul 25, 2024
1 parent 937558f commit 19f1126
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 56 deletions.
10 changes: 5 additions & 5 deletions onnxscript/rewriter/generic_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def _match_backward(
return self.none(starting_node, inspect.currentframe().f_lineno)

for graph_input, pattern_input in zip(graph_node.inputs, pattern_node.inputs):
if len(list(graph_input.uses())) != len(list(pattern_input.uses())):
if len(graph_input.uses()) != len(pattern_input.uses()):
self._hint(
"BACKWARD: one input is used outside the pattern",
"-- pattern",
Expand Down Expand Up @@ -423,12 +423,12 @@ def _match_values_forward(
return match_count
if len(free) < len(pattern_node_users_not_matched):
# Not enough successors to match the remaining patterns.
return self.none(node, inspect.currentframe().f_lineno)
return self.none(starting_node, inspect.currentframe().f_lineno)
if len(pattern_node_users_not_matched) == len(free) == 1:
# Only one option again.
graph_node = free[0]
if pattern_node_users_not_matched[0].op_identifier() != graph_node.op_identifier():
return self.none(node, inspect.currentframe().f_lineno)
return self.none(starting_node, inspect.currentframe().f_lineno)

key = pattern_node_users_not_matched[0]
if self.verbose >= 10:
Expand Down Expand Up @@ -461,11 +461,11 @@ def _match_values_forward(
"-- model-matched",
pattern_node_users_matched,
)
return self.none(node, inspect.currentframe().f_lineno)
return self.none(starting_node, inspect.currentframe().f_lineno)
for k, v in ec.items():
if gc[k] < v:
# Not enough types to match.
return self.none(node, inspect.currentframe().f_lineno)
return self.none(starting_node, inspect.currentframe().f_lineno)

# At this stage, we know matching the types is possible.
# We first mark whatever is possible.
Expand Down
40 changes: 24 additions & 16 deletions onnxscript/rewriter/generic_pattern_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,21 @@
import onnx.parser
import onnx.reference
import onnxruntime as ort
import parameterized

from onnxscript import ir
from onnxscript.rewriter import generic_pattern, pattern

FLOAT = onnx.TensorProto.FLOAT


@parameterized.parameterized_class(
("matcher_algo",),
[
(generic_pattern.GenericPatternMatcher,),
(pattern.SimplePatternMatcher,),
],
)
class GenericPatternTest(unittest.TestCase):
def _range(self, *shape, bias: float | None = None):
n = np.prod(shape)
Expand Down Expand Up @@ -48,7 +56,7 @@ def validate_mapping(context, x, y, z, **_) -> bool:
match_pattern,
apply_pattern,
validate_mapping,
generic_pattern.GenericPatternMatcher,
self.matcher_algo,
)

class AddAdd(onnx.reference.op_run.OpRun):
Expand Down Expand Up @@ -128,7 +136,7 @@ def validate_mapping(context, **_) -> bool:
match_pattern,
apply_pattern,
validate_mapping,
generic_pattern.GenericPatternMatcher,
self.matcher_algo,
verbose=10,
)

Expand Down Expand Up @@ -256,11 +264,7 @@ def match_pattern(op, x):
def apply_pattern(op, x, **_):
return op.SinCos(x, domain="com.microsoft", outputs=2)

rule = pattern.RewriteRule(
match_pattern,
apply_pattern,
matcher=generic_pattern.GenericPatternMatcher,
)
rule = pattern.RewriteRule(match_pattern, apply_pattern, matcher=self.matcher_algo)
model_proto = onnx.parser.parse_model(
"""
<ir_version: 7, opset_import: [ "" : 17]>
Expand All @@ -281,8 +285,10 @@ def apply_pattern(op, x, **_):
self.assertEqual(len(graph.node), 2)
self.assertEqual(graph.node[0].op_type, "SinCos")

@unittest.skip("Input variable reuse not supported yet")
def test_shared_root_value_extra_use(self):
if self.matcher_algo is generic_pattern.GenericPatternMatcher:
raise unittest.SkipTest("GenericPatternMatcher does not support extra uses yet.")

def match_pattern(op, x):
t1 = op.Sin(x)
t2 = op.Cos(x)
Expand All @@ -294,7 +300,7 @@ def apply_pattern(op, x, **_):
rule = pattern.RewriteRule(
match_pattern,
apply_pattern,
matcher=generic_pattern.GenericPatternMatcher,
matcher=self.matcher_algo,
)
model_proto = onnx.parser.parse_model(
"""
Expand All @@ -314,7 +320,7 @@ def apply_pattern(op, x, **_):
rule.apply_to_model(ir_model)
graph = ir_model.graph
self.assertEqual(len(graph), 3)
self.assertEqual(graph.node[0].op_type, "SinCos")
self.assertEqual(graph.node(0).op_type, "SinCos")

def test_rotary_embedding(self):
# The test work on a model if it has the expected name.
Expand Down Expand Up @@ -367,7 +373,7 @@ def apply_pattern(op, x, pos_ids, axis, **_):
match_pattern,
apply_pattern,
validate_mapping,
generic_pattern.GenericPatternMatcher,
self.matcher_algo,
verbose=10,
)

Expand All @@ -389,7 +395,8 @@ def apply_pattern(op, x, pos_ids, axis, **_):
self.assertEqual(expected, [n.op_type for n in rewriten_model.graph.node])
out = buffer.getvalue()
# TODO(Rama): What is this assertion testing? Is it to check that `verbose` is working?
self.assertIn("[GenericPatternMatcher.match", out)
if self.matcher_algo is generic_pattern.GenericPatternMatcher:
self.assertIn("[GenericPatternMatcher.match", out)

def test_rotary_embedding_onnxscript(self):
# The test work on a model if it has the expected name.
Expand Down Expand Up @@ -432,7 +439,7 @@ def rotary_apply_pattern(op, x, pos_ids, axis, **_):
rotary_match_pattern,
rotary_apply_pattern,
validate_rotary_mapping,
generic_pattern.GenericPatternMatcher,
self.matcher_algo,
verbose=10,
)

Expand All @@ -454,7 +461,8 @@ def rotary_apply_pattern(op, x, pos_ids, axis, **_):
self.assertEqual(expected, [n.op_type for n in rewriten_model.graph.node])
out = buffer.getvalue()
# TODO(justinchuby): Remove this assert - capturing stdout is not robust
self.assertIn("[GenericPatternMatcher.match", out)
if self.matcher_algo is generic_pattern.GenericPatternMatcher:
self.assertIn("[GenericPatternMatcher.match", out)

def test_rotary_emb_file_onnxscript(self):
# The test work on a model if it has the expected name.
Expand Down Expand Up @@ -504,7 +512,7 @@ def rotary_apply_pattern(op, x, pos_ids, axis):
rotary_match_pattern,
rotary_apply_pattern,
validate_rotary_mapping,
generic_pattern.GenericPatternMatcher,
self.matcher_algo,
verbose=10,
)

Expand Down Expand Up @@ -561,7 +569,7 @@ def transpose_transpose_apply_pattern(op, X, XT: ir.Value, Y, **_):
transpose_transpose_pattern,
transpose_transpose_apply_pattern,
transpose_transpose_check,
generic_pattern.GenericPatternMatcher,
self.matcher_algo,
verbose=0,
)

Expand Down
Loading

0 comments on commit 19f1126

Please sign in to comment.