Skip to content

Commit

Permalink
Refactor builder out as an utility
Browse files Browse the repository at this point in the history
  • Loading branch information
gramalingam committed Aug 2, 2024
1 parent 14f88d3 commit 33471a0
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 56 deletions.
45 changes: 44 additions & 1 deletion onnxscript/rewriter/_tape.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from __future__ import annotations

from typing import Iterable, Mapping, Sequence
from typing import Any, Iterable, Mapping, Optional, Sequence, Tuple

from onnxscript import ir
from onnxscript.ir import _convenience
Expand Down Expand Up @@ -59,3 +59,46 @@ def op_multi_output(
self._nodes.append(node)

return node.outputs


# A type representing the domains/versions used in creating nodes in IR.
UsedOpsets = Sequence[Tuple[str, Optional[int]]]


class Builder(Tape):
"""An extension of the tape that provides a more convenient API for constructing the IR."""

def __init__(self):
super().__init__()
self._used_opsets: UsedOpsets = []

def __getattr__(self, op_type: str) -> Any:
return lambda *args, **kwargs: self._make_node(op_type, args, kwargs)

def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str, Any]):
domain = kwargs.pop("_domain", "")
version = kwargs.pop("_version", None)
outputs = kwargs.pop("_outputs", 1)
if isinstance(outputs, Sequence):
num_outputs = len(outputs)
else:
assert isinstance(outputs, int)
num_outputs = outputs

self._used_opsets.append((domain, version))

Check failure

Code scanning / lintrunner

MYPY/attr-defined Error

"Sequence[tuple[str, int | None]]" has no attribute "append" To disable, use # type: ignore[attr-defined]
if num_outputs == 1:
value = super().op(op_type, inputs=inputs, attributes=kwargs, domain=domain)
if isinstance(outputs, Sequence):
value.name = outputs[0]
return value
values = super().op_multi_output(
op_type, inputs=inputs, attributes=kwargs, domain=domain, num_outputs=num_outputs
)
if isinstance(outputs, Sequence):
for value, name in zip(values, outputs):
value.name = name
return values

@property
def used_opsets(self) -> UsedOpsets:
return self._used_opsets
57 changes: 2 additions & 55 deletions onnxscript/rewriter/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@
Callable,
Iterable,
Iterator,
List,
MutableSequence,
Optional,
Protocol,
Sequence,
Tuple,
Expand Down Expand Up @@ -818,58 +816,7 @@ def _valid_to_replace(
return True


# A type representing the domains/versions used in creating a replacement subgraph
UsedOpsets = List[Tuple[str, Optional[int]]]


class RewriterContext:
"""Context parameter used to build the replacement pattern."""

# TODO(justinchuby): Merge with the rest of pattern building methods
def __init__(self):
self._tape = _tape.Tape()
self._used_opsets: UsedOpsets = []

def __getattr__(self, op_type: str) -> Any:
return lambda *args, **kwargs: self._make_node(op_type, args, kwargs)

def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str, Any]):
# TODO(rama): some of the following logic should move into the tape.
domain = kwargs.pop("_domain", "")
version = kwargs.pop("_version", None)
outputs = kwargs.pop("_outputs", 1)
if isinstance(outputs, Sequence):
num_outputs = len(outputs)
else:
assert isinstance(outputs, int)
num_outputs = outputs

self._used_opsets.append((domain, version))
if num_outputs == 1:
value = self._tape.op(op_type, inputs=inputs, attributes=kwargs, domain=domain)
if isinstance(outputs, Sequence):
value.name = outputs[0]
return value
values = self._tape.op_multi_output(
op_type, inputs=inputs, attributes=kwargs, domain=domain, num_outputs=num_outputs
)
if isinstance(outputs, Sequence):
for value, name in zip(values, outputs):
value.name = name
return values

@property
def nodes(self) -> Sequence[ir.Node]:
# TODO(rama): The current tape-based implementation will not track nodes added
# via overloaded operators, eg., `x + y`. One possible way to fix this is to
# have values/nodes know which tape they belong to (instead of a graph/function).
# However, it is unclear we need this feature for rewriting: we could also
# identify the nodes to be inserted from the replacement values (by tracing back).
return self._tape.nodes

@property
def used_opsets(self) -> UsedOpsets:
return self._used_opsets
RewriterContext = _tape.Builder


@dataclasses.dataclass
Expand All @@ -879,7 +826,7 @@ class ReplacementSubgraph:
match: MatchResult
new_outputs: Sequence[ir.Value]
new_nodes: Sequence[ir.Node]
used_opsets: UsedOpsets
used_opsets: _tape.UsedOpsets


def always_true(*args, **kwargs) -> bool:
Expand Down

0 comments on commit 33471a0

Please sign in to comment.