Skip to content

Commit

Permalink
Fix multi-output pattern-matcher bug (#1620)
Browse files Browse the repository at this point in the history
See the unit test below for an example of a pattern not handled by the matcher.
Basically, match-forward needs to be done for "values" as well, in
addition to "nodes".
  • Loading branch information
gramalingam authored Jun 18, 2024
1 parent 7f7fd74 commit 1108e7d
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 134 deletions.
298 changes: 164 additions & 134 deletions onnxscript/rewriter/generic_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,13 @@ def _match_backward(
# TODO(rama): Handle constant-pattern
pattern_pred = pattern_value.producer()
if pattern_pred is None:
# pattern_pred is None means the pattern ends here.
# pattern_pred is None means the pattern backward search ends here.
result = self._match_values_forward(
starting_node, matched, stack, graph_value, pattern_value
)
if result is None:
return result
match_count += result
continue
graph_pred = graph_value.producer()
if graph_pred is None:
Expand Down Expand Up @@ -328,6 +334,158 @@ def _match_backward(
print(f"[GenericPatternMatcher._match_backward] add {match_count} nodes")
return match_count

def _match_values_forward(
    self,
    starting_node: ir.Node,
    matched: dict[orp.NodePattern, ir.Node],
    stack: list[orp.NodePattern],
    graph_value: ir.Value,
    pattern_value: orp.ValuePattern,
) -> int | None:
    """Match the consumers of a pattern value against the consumers of a graph value.

    Every pattern node that gets paired with a graph node here is recorded in
    ``matched`` and pushed on ``stack`` so the caller can keep exploring from it.

    Args:
        starting_node: root node (the node the match begins with, used only
            for reporting a failed match)
        matched: pattern nodes already matched, mapped to their graph node;
            updated in place
        stack: pattern nodes still to be explored; updated in place
        graph_value: value coming from the graph
        pattern_value: value coming from the pattern

    Returns:
        The number of pattern nodes newly matched by this call, or the result
        of ``self.none(...)`` when the match fails (callers test it with
        ``is None``).

    Raises:
        NotImplementedError: if several graph consumers with the same operator
            identifier could match the same pattern consumers.
    """
    match_count = 0
    graph_node_users = [user for user, _ in graph_value.uses()]
    pattern_node_users = [user for user, _ in pattern_value.uses()]
    if not pattern_node_users:
        # The pattern has no node forward, the matching stops.
        return match_count
    if len(graph_node_users) < len(pattern_node_users):
        # Not enough nodes in the graph to match the pattern. A match is not possible.
        return self.none(starting_node, inspect.currentframe().f_lineno)

    # From here on the graph has at least as many consumers as the pattern,
    # and some of the pattern consumers may already be matched.

    if len(graph_node_users) == len(pattern_node_users) == 1:
        # Simple case: a single consumer on each side.
        if graph_node_users[0].op_identifier() != pattern_node_users[0].op_identifier():
            return self.none(starting_node, inspect.currentframe().f_lineno)

        node = pattern_node_users[0]
        if node not in matched:
            if self.verbose >= 10:
                print(
                    f"[GenericPatternMatcher._match_values_forward]{self.print_match(graph_node_users[0], pattern_node_users[0])}"
                )
            matched[node] = graph_node_users[0]
            stack.append(node)
            match_count += 1
        return match_count

    # Split the pattern consumers into those still to match and the graph
    # nodes already bound to the matched ones.
    pattern_node_users_not_matched = [
        unmatched_node
        for unmatched_node in pattern_node_users
        if unmatched_node not in matched
    ]
    pattern_node_users_matched = [
        matched[matched_node]
        for matched_node in pattern_node_users
        if matched_node in matched
    ]
    assert len(pattern_node_users_matched) + len(pattern_node_users_not_matched) == len(
        pattern_node_users
    ), (
        f"pattern_node_users_not_matched={pattern_node_users_not_matched}, "
        f"pattern_node_users_matched={pattern_node_users_matched}, "
        f"pattern_node_users={pattern_node_users}, "
        f"matched={matched}"
    )
    # Graph consumers that are not yet bound to any pattern consumer.
    free = list(set(graph_node_users) - set(pattern_node_users_matched))
    if not pattern_node_users_not_matched:
        # Everything is already matched.
        return match_count
    # NOTE: the failure paths below report `starting_node`; no other node
    # variable is bound on these paths.
    if len(free) < len(pattern_node_users_not_matched):
        # Not enough successors to match the remaining patterns.
        return self.none(starting_node, inspect.currentframe().f_lineno)
    if len(pattern_node_users_not_matched) == len(free) == 1:
        # Only one option again.
        graph_node = free[0]
        if pattern_node_users_not_matched[0].op_identifier() != graph_node.op_identifier():
            return self.none(starting_node, inspect.currentframe().f_lineno)

        key = pattern_node_users_not_matched[0]
        if self.verbose >= 10:
            print(
                f"[GenericPatternMatcher._match_values_forward] {self.print_match(graph_node, pattern_node_users_not_matched[0])}"
            )
        matched[key] = graph_node
        stack.append(key)
        match_count += 1
        return match_count

    # Several candidates remain: try to pair them using the operator
    # identifier alone, which works when each identifier occurs exactly once.
    expected_op_type = [_.op_identifier() for _ in pattern_node_users_not_matched]
    got_op_type = [_.op_identifier() for _ in free]

    ec = collections.Counter(expected_op_type)
    gc = collections.Counter(got_op_type)
    if len(ec) != len(gc) or set(ec) != set(gc):
        # The sets of unique operator types differ.
        self._hint(
            "FORWARD: unique operator types are different",
            "-- pattern",
            ec,
            pattern_value,
            "-- model",
            gc,
            graph_value,
            "-- model-matched",
            pattern_node_users_matched,
        )
        return self.none(starting_node, inspect.currentframe().f_lineno)
    for k, v in ec.items():
        if gc[k] < v:
            # Not enough types to match.
            return self.none(starting_node, inspect.currentframe().f_lineno)

    # At this stage, we know matching the types is possible.
    # We first mark whatever is unambiguous.
    ptype_to_node = {_.op_identifier(): _ for _ in pattern_node_users_not_matched}
    gtype_to_node = {_.op_identifier(): _ for _ in free}
    missing = []
    for k, v in ec.items():
        if gc[k] == v == 1:
            # Key by the pattern node itself (not its id()) so the entry is
            # consistent with every other insertion into `matched`/`stack`.
            key = ptype_to_node[k]
            if key not in matched:
                if self.verbose >= 10:
                    print(
                        f"[GenericPatternMatcher._match_values_forward] match "
                        f"{self.print_match(gtype_to_node[k], ptype_to_node[k])}"
                    )
                matched[key] = gtype_to_node[k]
                stack.append(key)
                match_count += 1
        else:
            missing.append(k)

    if not missing:
        return match_count

    # At this stage, there are multiple options for matching. We can:
    # 1. make assumptions and continue
    # 2. mark the node as incomplete matching, we could end up stuck anyway.
    raise NotImplementedError(
        f"There are more than one option, this will be implemented later, "
        f"ec={ec}, gc={gc}"
    )

def _match_forward(
self,
starting_node: ir.Node,
Expand Down Expand Up @@ -364,141 +522,13 @@ def _match_forward(
return self.none(starting_node, inspect.currentframe().f_lineno)

for graph_output, pattern_output in zip(graph_node.outputs, pattern_node.outputs):
graph_node_users = [user for user, _ in graph_output.uses()]
pattern_node_users = [user for user, _ in pattern_output.uses()]
if not pattern_node_users:
# The pattern has no node forward, the matching stops.
continue
if len(graph_node_users) < len(pattern_node_users):
# Not enough node in the graph to match the pattern. A match is not possible
return self.none(starting_node, inspect.currentframe().f_lineno)

# Here comes the fun part, there is the same number of successors or more
# nodes in the graph to match with the pattern.
# And we have to handle the nodes already matched as found.
# Hopefully, there is only one option.

if len(graph_node_users) == len(pattern_node_users) == 1:
# Let's deal with the simple case
if (
graph_node_users[0].op_identifier()
!= pattern_node_users[0].op_identifier()
):
return self.none(starting_node, inspect.currentframe().f_lineno)

node = pattern_node_users[0]
if node not in matched:
if self.verbose >= 10:
print(
f"[GenericPatternMatcher._match_forward]{self.print_match(graph_node_users[0], pattern_node_users[0])}"
)
matched[node] = graph_node_users[0]
stack.append(node)
match_count += 1
continue

# Let's remove the nodes already matched.
pattern_node_users_not_matched = [
unmatched_node
for unmatched_node in pattern_node_users
if unmatched_node not in matched
]
pattern_node_users_matched = [
matched[matched_node]
for matched_node in pattern_node_users
if matched_node in matched
]
assert len(pattern_node_users_matched) + len(
pattern_node_users_not_matched
) == len(pattern_node_users), (
f"pattern_node_users_not_matched={pattern_node_users_not_matched}, "
f"pattern_node_users_matched={pattern_node_users_matched}, "
f"pattern_node_users={pattern_node_users}, "
f"matched={matched}"
result = self._match_values_forward(
starting_node, matched, stack, graph_output, pattern_output
)
free = list(set(graph_node_users) - set(pattern_node_users_matched))
if not pattern_node_users_not_matched:
# Everything is already matched.
continue
if len(free) < len(pattern_node_users_not_matched):
# Not enough successors to match the remaining patterns.
return self.none(node, inspect.currentframe().f_lineno)
if len(pattern_node_users_not_matched) == len(free) == 1:
# Only one option again.
graph_node = free[0]
if (
pattern_node_users_not_matched[0].op_identifier()
!= graph_node.op_identifier()
):
return self.none(node, inspect.currentframe().f_lineno)

key = pattern_node_users_not_matched[0]
if self.verbose >= 10:
print(
f"[GenericPatternMatcher._match_forward] {self.print_match(graph_node, pattern_node_users_not_matched[0])}"
)
matched[key] = graph_node
stack.append(key)
match_count += 1
continue

# And now another fun part, let's try to handle the case when
# there is only one option, matching on node type only returns one
# option.
expected_op_type = [_.op_identifier() for _ in pattern_node_users_not_matched]
got_op_type = [_.op_identifier() for _ in free]

ec = collections.Counter(expected_op_type)
gc = collections.Counter(got_op_type)
if len(ec) != len(gc) or set(ec) != set(gc):
# unique operator types is different.
self._hint(
"FORWARD: unique operator types are different",
"-- pattern",
ec,
pattern_node,
"-- model",
gc,
graph_node,
"-- model-matched",
pattern_node_users_matched,
)
return self.none(node, inspect.currentframe().f_lineno)
for k, v in ec.items():
if gc[k] < v:
# Not enough types to match.
return self.none(node, inspect.currentframe().f_lineno)

# At this stage, we know matching the types is possible.
# We first mark whatever is possible.
ptype_to_node = {_.op_identifier(): _ for _ in pattern_node_users_not_matched}
gtype_to_node = {_.op_identifier(): _ for _ in free}
missing = []
for k, v in ec.items():
if gc[k] == v == 1:
key = id(ptype_to_node[k])
if key not in matched:
if self.verbose >= 10:
print(
f"[GenericPatternMatcher._match_forward] match "
f"{self.print_match(gtype_to_node[k], ptype_to_node[k])}"
)
matched[key] = gtype_to_node[k]
stack.append(key)
match_count += 1
else:
missing.append(k)

if not missing:
continue
if result is None:
return result
match_count += result

# At this stage, there are mutiple options for matching. We can:
# 1. make assumptions and continue
# 2. mark the node as incomplete matching, we could end up stuck anyway.
raise NotImplementedError(
f"There are more than one option, this will be implemented later, "
f"ec={ec}, gc={gc}"
)
if self.verbose > 5 and match_count > 0:
print(f"[GenericPatternMatcher._match_forward] add {match_count} nodes")
return match_count
Expand Down
35 changes: 35 additions & 0 deletions onnxscript/rewriter/generic_pattern_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import numpy as np
import onnx
import onnx.parser
import onnx.reference
import onnxruntime as ort

Expand Down Expand Up @@ -246,6 +247,40 @@ def get_rotary_model(self):
)
return model

def test_shared_root_value_test(self):
    """Rewrite Sin(x) and Cos(x), which share the same input value, into SinCos.

    The pattern has two outputs produced by two nodes that consume the same
    root value. After applying the rule the graph is expected to hold two
    nodes, the first of which is ``SinCos``.
    """

    def match_pattern(op, x):
        # Both nodes consume the same pattern input `x`.
        return op.Sin(x), op.Cos(x)

    def apply_pattern(op, x, **_):
        return op.SinCos(x, domain="com.microsoft", outputs=2)

    model_text = """
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[N] y) => (float[N] z)
{
temp1 = Sin(y)
temp2 = Cos(y)
z = Add(temp1, temp2)
}
"""
    proto = onnx.parser.parse_model(model_text)
    onnx.checker.check_model(proto)
    inferred = onnx.shape_inference.infer_shapes(proto)

    sincos_rule = pattern.RewriteRule(
        match_pattern,
        apply_pattern,
        matcher=generic_pattern.GenericPatternMatcher,
    )
    ir_model = ir.serde.deserialize_model(inferred)
    sincos_rule.apply_to_model(ir_model)

    rewritten_nodes = ir.serde.serialize_model(ir_model).graph.node
    self.assertEqual(len(rewritten_nodes), 2)
    self.assertEqual(rewritten_nodes[0].op_type, "SinCos")

def test_rotary_embedding(self):
# The test work on a model if it has the expected name.
# A dummy model is used if not present (not implemented yet).
Expand Down

0 comments on commit 1108e7d

Please sign in to comment.