Skip to content

Commit

Permalink
Address PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
gramalingam committed May 2, 2024
1 parent 98df760 commit 0b39aff
Showing 1 changed file with 9 additions and 27 deletions.
36 changes: 9 additions & 27 deletions onnxscript/rewriter/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
class Pattern(Protocol, Generic[T]):

Check failure

Code scanning / lintrunner

MYPY/misc Error

Invariant type variable "T" used in protocol where contravariant one is expected To disable, use # type: ignore[misc]
"""This is essentially a Predicate[T], that is, a Callable[[T], bool] bound to the name "matches"."""

def matches(self, item: T) -> bool: ...
def matches(self, item: T) -> bool:
...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.

Check warning on line 34 in onnxscript/rewriter/pattern.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/pattern.py#L34

Added line #L34 was not covered by tests


class StringConstantPattern(Pattern[str]):
Expand Down Expand Up @@ -90,6 +91,7 @@ def matches(self, attr: ir.Attr | ir.RefAttr) -> bool:


def _to_attr_pattern(value: AttrPattern | ValuePattern | SupportedAttrTypes) -> AttrPattern:
"""Represents promotion of values allowed as keyword-arguments in a pattern-builder call to an AttrPattern."""
if isinstance(value, AttrPattern):
return value

Check warning on line 96 in onnxscript/rewriter/pattern.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/pattern.py#L96

Added line #L96 was not covered by tests
if type(value) == ValuePattern:
Expand Down Expand Up @@ -138,10 +140,10 @@ def domain_prefix(cls, domain: str) -> OpsetPatternBuilder:
def matches(self, domain):
return self.domain_pattern.matches(domain)

def __getattr__(self, name: str) -> Any:
def __getattr__(self, name: str) -> OpPatternBuilder:
return OpPatternBuilder(self, StringConstantPattern(name))

def submodule(self, name: str) -> Any:
def submodule(self, name: str) -> OpPatternBuilder:
"""This method is used to match against submodule ops with prefix."""
return OpPatternBuilder(self, PrefixPattern(name))

Expand Down Expand Up @@ -264,6 +266,10 @@ def nodes(self) -> MutableSequence[ir.Node]:
return self.matched_nodes

def bind(self, var: str, value: Any) -> bool:
"""Binds a pattern variable name to a value from the matched IR.
Returns True if the binding is successful, False otherwise (when the binding is inconsistent).
"""
if var in self.bindings:
# TODO(rama): Use appropriate equality-check here.
if self.bindings[var] == value:
Expand Down Expand Up @@ -369,30 +375,6 @@ def __init__(
self.inputs = [_to_value_pattern(x) for x in inputs]
self.attributes = attributes

def matches(self, node: ir.Node) -> bool:
"""Examine if the IR node matches the self pattern."""
if not self.domain.matches(node.domain):
return False
if not self.op.matches(node.op_type):
return False
match = MatchResult(success=True)

# Sub-graphs not handled.
for name, attr_pattern in self.attributes.items():
attr_value = node.attributes.get(name)
if attr_value is None:
return False
if not attr_pattern.matches(attr_value):
return False
if attr_pattern.name is not None:
if not match.bind(attr_pattern.name, attr_value):
return match
for name in node.attributes:
# TODO: Support matching default nodes for attributes.
if name not in self.attributes:
return False
return True

def matches_node(self, node: ir.Node) -> MatchResult:
"""Examine if the IR node matches the self pattern."""
if not self.domain.matches(node.domain):
Expand Down

0 comments on commit 0b39aff

Please sign in to comment.