Skip to content

Commit

Permalink
Support optional attribute checking in matcher (#1629)
Browse files Browse the repository at this point in the history
Extend matcher to allow users to specify whether all attributes must be
exactly as in pattern. Change default-value to allow extra-attributes in
actual node, not specified in pattern.
  • Loading branch information
gramalingam authored Jul 2, 2024
1 parent 3244e92 commit e824285
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 9 deletions.
3 changes: 2 additions & 1 deletion onnxscript/ir/_convenience.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,8 @@ def convert_attributes(
"""
attributes: list[_core.Attr | _core.RefAttr] = []
for name, attr in attrs.items():
attributes.append(convert_attribute(name, attr))
if attr is not None:
attributes.append(convert_attribute(name, attr))
return attributes


Expand Down
51 changes: 43 additions & 8 deletions onnxscript/rewriter/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ def __call__(
domain: str | None = None,
version: int | None = None,
outputs: int | list[str | None] = 1,
_allow_other_attributes: bool | None = None,
**kwargs,
):
if version is not None:
Expand All @@ -228,7 +229,9 @@ def __call__(
raise ValueError("outputs must be an int or a list[str|None].")
inputs = [_to_value_pattern(x) for x in args]
attributes = {name: _to_attr_pattern(value) for (name, value) in kwargs.items()}
node_pattern = NodePattern(opset_pattern, self.op_name, inputs, attributes, outputs)
node_pattern = NodePattern(
opset_pattern, self.op_name, inputs, attributes, outputs, _allow_other_attributes
)
output_values = node_pattern.outputs
# Unpack outputs if there is only one output, the common case.
if len(output_values) == 1:
Expand Down Expand Up @@ -424,6 +427,15 @@ class NodePattern:
This differs from a NodeOutputPattern in that it matches against a node (which
may produce 1 or more outputs), whereas a NodeOutputPattern matches against
a specific output of a node.
Args:
domain: pattern to match against the domain of the node.
op: pattern or string constant to match against the op_type of the node.
inputs: sequence of ValuePatterns (or constants) to match against the inputs of the node.
attributes: dictionary of attribute patterns to match against the attributes of the node.
outputs: specifies pattern-variable-name for outputs (or None)
allow_other_attributes: specifies whether other attributes (not mentioned in `attributes`)
are allowed in the node.
"""

def __init__(
Expand All @@ -433,11 +445,16 @@ def __init__(
inputs: Sequence[int | float | ValuePattern | None],
attributes: dict[str, AttrPattern],
outputs: Sequence[str | None],
allow_other_attributes: bool | None,
):
if allow_other_attributes is None:
# Default behavior: allow other unmatched attributes in the node.
allow_other_attributes = True
self.domain = domain
self.op = StringConstantPattern(op) if isinstance(op, str) else op
self.inputs = [_to_value_pattern(x) for x in inputs]
self.attributes = attributes
self.allow_other_attributes = allow_other_attributes
# In the common case, domain and op are constants, which can be used to optimize matching.
if isinstance(op, str) and domain.domain_name is not None:
# TODO(rama): support overloaded operators.
Expand Down Expand Up @@ -497,10 +514,11 @@ def matches(self, node: ir.Node, match: MatchResult) -> MatchResult:
if not match.bind(attr_pattern.name, attr_value):
return match

for name in node.attributes:
# TODO: Support matching default nodes for attributes.
if name not in self.attributes:
return match.fail(f"Attribute {name} not expected in node.")
if not self.allow_other_attributes:
for name in node.attributes:
# TODO: Support matching default nodes for attributes.
if name not in self.attributes:
return match.fail(f"Attribute {name} not expected in node.")

return match

Expand All @@ -524,7 +542,14 @@ def enumerate_inputs(inputs, index):
inputs.extend(swapped)
outputs = [value.name for value in self.outputs]
return [
NodePattern(self.domain, self.op, input, self.attributes, outputs)
NodePattern(
self.domain,
self.op,
input,
self.attributes,
outputs,
self.allow_other_attributes,
)
for input in inputs
]

Expand Down Expand Up @@ -961,11 +986,15 @@ def _match_node(self, pattern_node: NodePattern, node: ir.Node) -> bool:
if not self._match_value(previous_node_output_pattern, arg_value):
return False

for i, output_value_pattern in enumerate(pattern_node.outputs):
if not self._bind_value(output_value_pattern, node.outputs[i]):
return False

match.nodes.append(node)
return True

def _match_value(self, pattern_value: ValuePattern, value: ir.Value) -> bool:
"""Match an IR value against a ValuePattern instance."""
def _bind_value(self, pattern_value: ValuePattern, value: ir.Value) -> bool:
"""Bind a ValuePattern var to ir Value."""
if pattern_value.name is not None:
match = self._match
if pattern_value.name in match.bindings:
Expand All @@ -974,6 +1003,12 @@ def _match_value(self, pattern_value: ValuePattern, value: ir.Value) -> bool:
return True
return self.fail(f"Variable {pattern_value.name} is bound to multiple values.")
match.bindings[pattern_value.name] = value
return True

def _match_value(self, pattern_value: ValuePattern, value: ir.Value) -> bool:
"""Match an IR value against a ValuePattern instance."""
if not self._bind_value(pattern_value, value):
return False

if isinstance(pattern_value, NodeOutputPattern):
return self._match_node_output(pattern_value, value)
Expand Down
52 changes: 52 additions & 0 deletions onnxscript/rewriter/pattern_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,58 @@ def double(op, x):
)
onnx.checker.check_model(ir.serde.serialize_model(model))

def test_optional_attribute(self):
"""Test rules with optional attributes."""

def concat_pattern(op, x, y):
seq = op.SequenceConstruct(x, y)
result = op.ConcatFromSequence(seq, outputs=["result"])
return result

def concat(op, x, y, result: ir.Value):
node = result.producer()
assert node is not None
axis = node.attributes.get("axis", None)
return op.Concat(x, y, axis=axis)

rule = pattern.RewriteRule(concat_pattern, concat)

# Case 1: a model with attribute axis present
model_proto = onnx.parser.parse_model(
"""
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[N] x, float[N] y) => (float[M] z)
{
t = SequenceConstruct (x, y)
z = ConcatFromSequence <axis=0> (t)
}
"""
)
model = ir.serde.deserialize_model(model_proto)
count = rule.apply_to_model(model)
self.assertEqual(count, 1)
self.assertEqual(len(model.graph), 1)
self.assertEqual(model.graph[0].op_type, "Concat")
self.assertEqual(model.graph[0].attributes["axis"].value, 0)

# Case 2: a model with attribute axis absent
model_proto = onnx.parser.parse_model(
"""
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[N] x, float[N] y) => (float[M] z)
{
t = SequenceConstruct (x, y)
z = ConcatFromSequence (t)
}
"""
)
model = ir.serde.deserialize_model(model_proto)
count = rule.apply_to_model(model)
self.assertEqual(count, 1)
self.assertEqual(len(model.graph), 1)
self.assertEqual(model.graph[0].op_type, "Concat")
self.assertNotIn("axis", model.graph[0].attributes)


if __name__ == "__main__":
unittest.main()

0 comments on commit e824285

Please sign in to comment.