Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend basic matcher to handle multiple-output-nodes #1734

Merged
merged 14 commits into from
Jul 25, 2024
Merged
10 changes: 5 additions & 5 deletions onnxscript/rewriter/generic_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@
return self.none(starting_node, inspect.currentframe().f_lineno)

for graph_input, pattern_input in zip(graph_node.inputs, pattern_node.inputs):
if len(list(graph_input.uses())) != len(list(pattern_input.uses())):
if len(graph_input.uses()) != len(pattern_input.uses()):
self._hint(
"BACKWARD: one input is used outside the pattern",
"-- pattern",
Expand Down Expand Up @@ -423,12 +423,12 @@
return match_count
if len(free) < len(pattern_node_users_not_matched):
# Not enough successors to match the remaining patterns.
return self.none(node, inspect.currentframe().f_lineno)
return self.none(starting_node, inspect.currentframe().f_lineno)

Check warning on line 426 in onnxscript/rewriter/generic_pattern.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/generic_pattern.py#L426

Added line #L426 was not covered by tests
if len(pattern_node_users_not_matched) == len(free) == 1:
# Only one option again.
graph_node = free[0]
if pattern_node_users_not_matched[0].op_identifier() != graph_node.op_identifier():
return self.none(node, inspect.currentframe().f_lineno)
return self.none(starting_node, inspect.currentframe().f_lineno)

Check warning on line 431 in onnxscript/rewriter/generic_pattern.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/generic_pattern.py#L431

Added line #L431 was not covered by tests

key = pattern_node_users_not_matched[0]
if self.verbose >= 10:
Expand Down Expand Up @@ -461,11 +461,11 @@
"-- model-matched",
pattern_node_users_matched,
)
return self.none(node, inspect.currentframe().f_lineno)
return self.none(starting_node, inspect.currentframe().f_lineno)

Check warning on line 464 in onnxscript/rewriter/generic_pattern.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/generic_pattern.py#L464

Added line #L464 was not covered by tests
for k, v in ec.items():
if gc[k] < v:
# Not enough types to match.
return self.none(node, inspect.currentframe().f_lineno)
return self.none(starting_node, inspect.currentframe().f_lineno)

Check warning on line 468 in onnxscript/rewriter/generic_pattern.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/generic_pattern.py#L468

Added line #L468 was not covered by tests

# At this stage, we know matching the types is possible.
# We first mark whatever is possible.
Expand Down
40 changes: 24 additions & 16 deletions onnxscript/rewriter/generic_pattern_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,21 @@
import onnx.parser
import onnx.reference
import onnxruntime as ort
import parameterized

from onnxscript import ir
from onnxscript.rewriter import generic_pattern, pattern

FLOAT = onnx.TensorProto.FLOAT


@parameterized.parameterized_class(
("matcher_algo",),
[
(generic_pattern.GenericPatternMatcher,),
(pattern.SimplePatternMatcher,),
],
)
class GenericPatternTest(unittest.TestCase):
def _range(self, *shape, bias: float | None = None):
n = np.prod(shape)
Expand Down Expand Up @@ -48,7 +56,7 @@ def validate_mapping(context, x, y, z, **_) -> bool:
match_pattern,
apply_pattern,
validate_mapping,
generic_pattern.GenericPatternMatcher,
self.matcher_algo,
)

class AddAdd(onnx.reference.op_run.OpRun):
Expand Down Expand Up @@ -128,7 +136,7 @@ def validate_mapping(context, **_) -> bool:
match_pattern,
apply_pattern,
validate_mapping,
generic_pattern.GenericPatternMatcher,
self.matcher_algo,
verbose=10,
)

Expand Down Expand Up @@ -256,11 +264,7 @@ def match_pattern(op, x):
def apply_pattern(op, x, **_):
return op.SinCos(x, domain="com.microsoft", outputs=2)

rule = pattern.RewriteRule(
match_pattern,
apply_pattern,
matcher=generic_pattern.GenericPatternMatcher,
)
rule = pattern.RewriteRule(match_pattern, apply_pattern, matcher=self.matcher_algo)
model_proto = onnx.parser.parse_model(
"""
<ir_version: 7, opset_import: [ "" : 17]>
Expand All @@ -281,8 +285,10 @@ def apply_pattern(op, x, **_):
self.assertEqual(len(graph.node), 2)
self.assertEqual(graph.node[0].op_type, "SinCos")

@unittest.skip("Input variable reuse not supported yet")
def test_shared_root_value_extra_use(self):
if self.matcher_algo is generic_pattern.GenericPatternMatcher:
raise unittest.SkipTest("GenericPatternMatcher does not support extra uses yet.")

def match_pattern(op, x):
t1 = op.Sin(x)
t2 = op.Cos(x)
Expand All @@ -294,7 +300,7 @@ def apply_pattern(op, x, **_):
rule = pattern.RewriteRule(
match_pattern,
apply_pattern,
matcher=generic_pattern.GenericPatternMatcher,
matcher=self.matcher_algo,
gramalingam marked this conversation as resolved.
Show resolved Hide resolved
)
model_proto = onnx.parser.parse_model(
"""
Expand All @@ -314,7 +320,7 @@ def apply_pattern(op, x, **_):
rule.apply_to_model(ir_model)
graph = ir_model.graph
self.assertEqual(len(graph), 3)
self.assertEqual(graph.node[0].op_type, "SinCos")
self.assertEqual(graph.node(0).op_type, "SinCos")

def test_rotary_embedding(self):
# The test work on a model if it has the expected name.
Expand Down Expand Up @@ -367,7 +373,7 @@ def apply_pattern(op, x, pos_ids, axis, **_):
match_pattern,
apply_pattern,
validate_mapping,
generic_pattern.GenericPatternMatcher,
self.matcher_algo,
verbose=10,
)

Expand All @@ -389,7 +395,8 @@ def apply_pattern(op, x, pos_ids, axis, **_):
self.assertEqual(expected, [n.op_type for n in rewriten_model.graph.node])
out = buffer.getvalue()
# TODO(Rama): What is this assertion testing? Is it to check that `verbose` is working?
self.assertIn("[GenericPatternMatcher.match", out)
if self.matcher_algo is generic_pattern.GenericPatternMatcher:
self.assertIn("[GenericPatternMatcher.match", out)

def test_rotary_embedding_onnxscript(self):
# The test work on a model if it has the expected name.
Expand Down Expand Up @@ -432,7 +439,7 @@ def rotary_apply_pattern(op, x, pos_ids, axis, **_):
rotary_match_pattern,
rotary_apply_pattern,
validate_rotary_mapping,
generic_pattern.GenericPatternMatcher,
self.matcher_algo,
verbose=10,
)

Expand All @@ -454,7 +461,8 @@ def rotary_apply_pattern(op, x, pos_ids, axis, **_):
self.assertEqual(expected, [n.op_type for n in rewriten_model.graph.node])
out = buffer.getvalue()
# TODO(justinchuby): Remove this assert - capturing stdout is not robust
self.assertIn("[GenericPatternMatcher.match", out)
if self.matcher_algo is generic_pattern.GenericPatternMatcher:
self.assertIn("[GenericPatternMatcher.match", out)

def test_rotary_emb_file_onnxscript(self):
# The test work on a model if it has the expected name.
Expand Down Expand Up @@ -504,7 +512,7 @@ def rotary_apply_pattern(op, x, pos_ids, axis):
rotary_match_pattern,
rotary_apply_pattern,
validate_rotary_mapping,
generic_pattern.GenericPatternMatcher,
self.matcher_algo,
verbose=10,
)

Expand Down Expand Up @@ -561,7 +569,7 @@ def transpose_transpose_apply_pattern(op, X, XT: ir.Value, Y, **_):
transpose_transpose_pattern,
transpose_transpose_apply_pattern,
transpose_transpose_check,
generic_pattern.GenericPatternMatcher,
self.matcher_algo,
verbose=0,
)

Expand Down
Loading
Loading