Skip to content

Commit

Permalink
Minor cleanup for multi-output-matcher verbose trace output (#1523)
Browse files Browse the repository at this point in the history
Add `__str__` methods to pattern objects and use them in the
multi-output-matcher trace output.
  • Loading branch information
gramalingam authored May 10, 2024
1 parent 9153dda commit 32bcd06
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 14 deletions.
49 changes: 38 additions & 11 deletions onnxscript/rewriter/generic_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,29 @@ def _to_match_result(pmr: PatternMatchResult) -> orp.MatchResult:
return result


def _value_to_str(value: ir.Value | orp.ValuePattern) -> str:
return value.name if value.name is not None else "anonymous:" + str(id(value))


def _opt_value_to_str(value: ir.Value | orp.ValuePattern | None) -> str:
return _value_to_str(value) if value is not None else "None"


def _node_to_str(node: ir.Node | orp.NodePattern) -> str:
inputs = ", ".join(_opt_value_to_str(input) for input in node.inputs)
outputs = ", ".join(_opt_value_to_str(output) for output in node.outputs)
op_type = node.op_type
domain = str(node.domain)
qualified_op = f"{domain}.{op_type}" if domain else op_type
return f"{outputs} = {qualified_op}({inputs})"


# def _pattern_node_to_str(node: orp.NodePattern) -> str:
# inputs = ", ".join(_opt_value_to_str(input) for input in node.inputs)
# outputs = ", ".join(_opt_value_to_str(output) for output in node.outputs)
# return f"{outputs} = {node.op_type}({inputs})"


class GenericPatternMatcher(orp.PatternMatcher):
"""
Implements a pattern optimization for quick experimentation.
Expand Down Expand Up @@ -178,16 +201,16 @@ def none(
else:
msg2 = ""
print(
f"[{self.__class__.__name__}.match] NONE - line: {lineno}:"
f"[{self.__class__.__name__}.match] Match failed at line: {lineno}:"
f"{os.path.split(self.__class__.__module__)[-1]}, "
f"op_type={node.op_type}{msg}{msg2}"
)
return None

def print_match(self, graph_node: ir.Node, pattern_node: orp.NodePattern) -> str:
s1 = f"{graph_node.op_type}({graph_node.inputs})"
s2 = f"{pattern_node.op_type}({pattern_node.inputs})"
return f"match {s1} with {s2} (pattern)"
s1 = _node_to_str(graph_node)
s2 = _node_to_str(pattern_node)
return f"match {s1} with pattern: {s2}"

def _debug_print(self) -> str:
if not hasattr(self, "_debug"):
Expand All @@ -201,7 +224,7 @@ def _s(s: str) -> str:
def _p(n: ir.Node, full: bool = False) -> str:
if full:
return str(n)
return f"{n.op_type}({', '.join([str(input) for input in n.inputs])})"
return _node_to_str(n)

rows = []
for k, v in sorted(self._debug.items()):
Expand All @@ -221,6 +244,8 @@ def _p(n: ir.Node, full: bool = False) -> str:
if k == "hint":
rows.append(f"--hint--: {v[0]}") # type: ignore[arg-type]
for i in v[1:]:
if isinstance(i, str):
rows.append(" " + i)
if isinstance(i, ir.Node):
rows.append(" " + _p(i, full=True))
continue
Expand Down Expand Up @@ -282,9 +307,9 @@ def _match_backward(
self._hint(
"BACKWARD: different node types",
"--pattern",
pattern_pred,
_node_to_str(pattern_pred),
"-- model",
graph_pred,
_node_to_str(graph_pred),
)
return self.none(starting_node, inspect.currentframe().f_lineno)
# matching backward
Expand Down Expand Up @@ -495,13 +520,15 @@ def match(
return self.none()

if self.verbose > 5:
print(f"[GenericPatternMatcher.match] starts with {node}")
print(
f"[GenericPatternMatcher.match] Matching started at node: {_node_to_str(node)}"
)
if self.verbose >= 10:
print(f"[GenericPatternMatcher.match] match pattern {self!r}")
print(f"[GenericPatternMatcher.match] match pattern {self}")

all_pattern_nodes = set(self.pattern)
matched: dict[ir.Node, ir.Node] = {last_pattern_node: node}
stack: list[ir.Node] = [last_pattern_node]
matched: dict[orp.NodePattern, ir.Node] = {last_pattern_node: node}
stack: list[orp.NodePattern] = [last_pattern_node]
iteration = 0

if self.verbose > 5:
Expand Down
44 changes: 41 additions & 3 deletions onnxscript/rewriter/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ def __init__(self, value: str):
def matches(self, item: str) -> bool:
return item == self._value

def __str__(self) -> str:
return self._value


class PrefixPattern(Pattern[str]):
"""Matches strings with a given prefix."""
Expand All @@ -51,6 +54,9 @@ def __init__(self, value: str) -> None:
def matches(self, value: str) -> bool:
return value.startswith(self._value)

def __str__(self) -> str:
return f"{self._value}*"


class AttrPattern(Pattern[Union[ir.Attr, ir.RefAttr]]):
"""Base class for an attribute pattern. Matches any attribute value by default."""
Expand All @@ -65,6 +71,9 @@ def name(self) -> str | None:
def matches(self, attr: ir.Attr | ir.RefAttr) -> bool:
return True

def __str__(self) -> str:
return self._name if self._name is not None else "anonymous:" + str(id(self))


# TODO: Support tensors. Align with usage elsewhere.
SupportedAttrTypes = Union[
Expand All @@ -91,6 +100,9 @@ def __init__(self, value: SupportedAttrTypes):
def matches(self, attr: ir.Attr | ir.RefAttr) -> bool:
return isinstance(attr, ir.Attr) and attr.value == self._value

def __str__(self) -> str:
return str(self._value)


def _to_attr_pattern(value: AttrPattern | ValuePattern | SupportedAttrTypes) -> AttrPattern:
"""Represents promotion of values allowed as keyword-arguments in a pattern-builder call to an AttrPattern."""
Expand Down Expand Up @@ -152,6 +164,9 @@ def submodule(self, name: str) -> OpPatternBuilder:
"""This method is used to match against submodule ops with prefix."""
return OpPatternBuilder(self, PrefixPattern(name))

def __str__(self) -> str:
return str(self._domain_pattern)


onnxop = OpsetPatternBuilder("")

Expand Down Expand Up @@ -396,6 +411,9 @@ def __rtruediv__(self, other):
def __pow__(self, other):
return onnxop.Pow(self, other)

def __str__(self) -> str:
return self._name if self._name is not None else "anonymous:" + str(id(self))


class NodePattern:
"""Represents a pattern that matches against a Node.
Expand Down Expand Up @@ -435,14 +453,22 @@ def __init__(
if value is not None:
value.append_use(self, index)

def __str__(self) -> str:
inputs = ", ".join(str(v) for v in self.inputs)
outputs = ", ".join(str(v) for v in self.outputs)
attributes = ", ".join(f"{k}={v}" for k, v in self.attributes.items())
op = str(self.op)
domain = str(self.domain)
qualified_op = f"{domain}.{op}" if domain else op
inputs_and_attributes = f"{inputs}, {attributes}" if attributes else inputs
return f"{outputs} = {qualified_op} ({inputs_and_attributes})"

def op_identifier(self) -> Tuple[str, str, str] | None:
return self._op_identifier

@property
def op_type(self) -> str:
if self._op_identifier is not None:
return self._op_identifier[1]
return "unknown" # used primarily for debugging
return str(self.op)

def matches(self, node: ir.Node) -> bool:
"""Matches the pattern represented by self against a node.
Expand Down Expand Up @@ -603,6 +629,9 @@ def matches(self, value: ir.Value):
def commute(self) -> list[ValuePattern]:
return [self]

def __str__(self) -> str:
return str(self._value)


def _nodes_in_pattern(outputs: Sequence[ValuePattern]) -> list[NodePattern]:
"""Returns all nodes used in a pattern, given the outputs of the pattern."""
Expand Down Expand Up @@ -696,6 +725,12 @@ def commute(self) -> Sequence[GraphPattern]:
for n in nodes
]

def __str__(self) -> str:
inputs = ", ".join(str(v) for v in self._inputs)
outputs = ", ".join(str(v) for v in self._outputs)
nodes = "\n ".join(str(n) for n in self._nodes)
return f"pattern ({inputs}) {{\n {nodes}\n return {outputs}\n}}"


def _to_graph_pattern(pattern_constructor: Callable) -> GraphPattern:
"""Convert a pattern-construction function to a GraphPattern.
Expand Down Expand Up @@ -866,6 +901,9 @@ def match(
) -> MatchResult:
pass

def __str__(self) -> str:
return str(self.pattern)


class SimplePatternMatcher(PatternMatcher):
def __init__(self, pattern: GraphPattern) -> None:
Expand Down

0 comments on commit 32bcd06

Please sign in to comment.