diff --git a/onnxscript/rewriter/_tape.py b/onnxscript/rewriter/_tape.py index 8ebed05fa..d757ec45e 100644 --- a/onnxscript/rewriter/_tape.py +++ b/onnxscript/rewriter/_tape.py @@ -7,7 +7,7 @@ from __future__ import annotations -from typing import Iterable, Mapping, Sequence +from typing import Any, Iterable, Mapping, Optional, Sequence, Tuple from onnxscript import ir from onnxscript.ir import _convenience @@ -59,3 +59,46 @@ def op_multi_output( self._nodes.append(node) return node.outputs + + +# A type representing the domains/versions used in creating nodes in IR. +UsedOpsets = Sequence[Tuple[str, Optional[int]]] + + +class Builder(Tape): + """An extension of the tape that provides a more convenient API for constructing the IR.""" + + def __init__(self): + super().__init__() + self._used_opsets: UsedOpsets = [] + + def __getattr__(self, op_type: str) -> Any: + return lambda *args, **kwargs: self._make_node(op_type, args, kwargs) + + def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str, Any]): + domain = kwargs.pop("_domain", "") + version = kwargs.pop("_version", None) + outputs = kwargs.pop("_outputs", 1) + if isinstance(outputs, Sequence): + num_outputs = len(outputs) + else: + assert isinstance(outputs, int) + num_outputs = outputs + + self._used_opsets.append((domain, version)) + if num_outputs == 1: + value = super().op(op_type, inputs=inputs, attributes=kwargs, domain=domain) + if isinstance(outputs, Sequence): + value.name = outputs[0] + return value + values = super().op_multi_output( + op_type, inputs=inputs, attributes=kwargs, domain=domain, num_outputs=num_outputs + ) + if isinstance(outputs, Sequence): + for value, name in zip(values, outputs): + value.name = name + return values + + @property + def used_opsets(self) -> UsedOpsets: + return self._used_opsets diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 87544874d..454c7419f 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -13,9 +13,7 @@ Callable, Iterable, Iterator, - List, MutableSequence, - Optional, Protocol, Sequence, Tuple, @@ -818,58 +816,7 @@ def _valid_to_replace( return True -# A type representing the domains/versions used in creating a replacement subgraph -UsedOpsets = List[Tuple[str, Optional[int]]] - - -class RewriterContext: - """Context parameter used to build the replacement pattern.""" - - # TODO(justinchuby): Merge with the rest of pattern building methods - def __init__(self): - self._tape = _tape.Tape() - self._used_opsets: UsedOpsets = [] - - def __getattr__(self, op_type: str) -> Any: - return lambda *args, **kwargs: self._make_node(op_type, args, kwargs) - - def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str, Any]): - # TODO(rama): some of the following logic should move into the tape. - domain = kwargs.pop("_domain", "") - version = kwargs.pop("_version", None) - outputs = kwargs.pop("_outputs", 1) - if isinstance(outputs, Sequence): - num_outputs = len(outputs) - else: - assert isinstance(outputs, int) - num_outputs = outputs - - self._used_opsets.append((domain, version)) - if num_outputs == 1: - value = self._tape.op(op_type, inputs=inputs, attributes=kwargs, domain=domain) - if isinstance(outputs, Sequence): - value.name = outputs[0] - return value - values = self._tape.op_multi_output( - op_type, inputs=inputs, attributes=kwargs, domain=domain, num_outputs=num_outputs - ) - if isinstance(outputs, Sequence): - for value, name in zip(values, outputs): - value.name = name - return values - - @property - def nodes(self) -> Sequence[ir.Node]: - # TODO(rama): The current tape-based implementation will not track nodes added - # via overloaded operators, eg., `x + y`. One possible way to fix this is to - # have values/nodes know which tape they belong to (instead of a graph/function). - # However, it is unclear we need this feature for rewriting: we could also - # identify the nodes to be inserted from the replacement values (by tracing back). - return self._tape.nodes - - @property - def used_opsets(self) -> UsedOpsets: - return self._used_opsets +RewriterContext = _tape.Builder @dataclasses.dataclass @@ -879,7 +826,7 @@ class ReplacementSubgraph: match: MatchResult new_outputs: Sequence[ir.Value] new_nodes: Sequence[ir.Node] - used_opsets: UsedOpsets + used_opsets: _tape.UsedOpsets def always_true(*args, **kwargs) -> bool: