diff --git a/onnxscript/rewriter/cast_constant_of_shape.py b/onnxscript/rewriter/cast_constant_of_shape.py index bd58af933..34656ff19 100644 --- a/onnxscript/rewriter/cast_constant_of_shape.py +++ b/onnxscript/rewriter/cast_constant_of_shape.py @@ -9,7 +9,6 @@ from onnxscript import ir from onnxscript.rewriter import pattern -op = pattern.onnxop logger = logging.getLogger(__name__) diff --git a/onnxscript/rewriter/gemm_to_matmul_add.py b/onnxscript/rewriter/gemm_to_matmul_add.py index 0b9ee373b..bff77839f 100644 --- a/onnxscript/rewriter/gemm_to_matmul_add.py +++ b/onnxscript/rewriter/gemm_to_matmul_add.py @@ -3,8 +3,6 @@ from onnxscript.rewriter import pattern from onnxscript.rewriter.broadcast_to_matmul import check_if_not_need_reshape -op = pattern.onnxop - # Pattern to match against def reshape_gemm_reshape_pattern(op, input_a, input_b, input_c, shape_a, shape_c): diff --git a/onnxscript/rewriter/llama_rule_sets.py b/onnxscript/rewriter/llama_rule_sets.py index 1adb03e16..0d163d0a2 100644 --- a/onnxscript/rewriter/llama_rule_sets.py +++ b/onnxscript/rewriter/llama_rule_sets.py @@ -11,8 +11,6 @@ import onnxscript.rewriter.no_op as no_op import onnxscript.rewriter.pattern as orp -op = orp.onnxop - class CastIdentity(orp.RewriteRuleAsClass): """Replaces ``Cast(., to=to)`` by ``Identity`` if possible.""" diff --git a/onnxscript/rewriter/no_op.py b/onnxscript/rewriter/no_op.py index 95c3e2434..7a4b00798 100644 --- a/onnxscript/rewriter/no_op.py +++ b/onnxscript/rewriter/no_op.py @@ -2,8 +2,6 @@ # Licensed under the MIT License. from onnxscript.rewriter import pattern -op = pattern.onnxop - # TODO: Support 1-D constant tensors # https://github.com/microsoft/onnx-rewriter/issues/186 diff --git a/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py b/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py index 3a4444dbb..65496ec8b 100644 --- a/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py +++ b/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py @@ -6,8 +6,6 @@ import onnxscript.rewriter.pattern as orp -op = orp.onnxop - class FusedMatMulDiv1(orp.RewriteRuleAsClass): """Replaces ``MatMul + Div`` by FusedMatMul.""" diff --git a/onnxscript/rewriter/onnxruntime/softmax.py b/onnxscript/rewriter/onnxruntime/softmax.py index 12ad97672..f1d6df7b6 100644 --- a/onnxscript/rewriter/onnxruntime/softmax.py +++ b/onnxscript/rewriter/onnxruntime/softmax.py @@ -9,7 +9,6 @@ from onnxscript import ir from onnxscript.rewriter import pattern -op = pattern.onnxop logger = logging.getLogger(__name__) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 6f3613e5f..87544874d 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -3,6 +3,7 @@ from __future__ import annotations import abc +import contextlib import dataclasses import inspect import itertools @@ -35,7 +36,19 @@ class Pattern(Protocol[T]): # type: ignore[misc] def matches(self, item: T) -> bool: ... -class StringConstantPattern(Pattern[str]): +class StringPattern(abc.ABC, Pattern[str]): + """Abstract base class for string patterns.""" + + @abc.abstractmethod + def matches(self, item: str) -> bool: + pass + + @abc.abstractmethod + def __str__(self) -> str: + pass + + +class StringConstantPattern(StringPattern): """Matches strings with given value.""" def __init__(self, value: str): @@ -47,8 +60,11 @@ def matches(self, item: str) -> bool: def __str__(self) -> str: return self._value + def value(self) -> str: + return self._value -class PrefixPattern(Pattern[str]): + +class PrefixPattern(StringPattern): """Matches strings with a given prefix.""" def __init__(self, value: str) -> None: @@ -129,8 +145,8 @@ def _to_attr_pattern(value: AttrPattern | ValuePattern | SupportedAttrTypes) -> raise TypeError(f"Cannot convert {type(value)} to AttrPattern") -class OpsetPatternBuilder(Pattern[str]): - """Represents an opset pattern. +class OpsetPatternBuilder: + """Represents an opset pattern and a pattern builder. (i) It is used to create a NodePattern (via OpPatternBuilder). Example usage: @@ -141,24 +157,21 @@ class OpsetPatternBuilder(Pattern[str]): Here, `op` is an instance of OpsetPatternBuilder and `op.Matmul` is an instance of OpPatternBuilder, and `op.Matmul(x, y)` is an instance of NodePattern. - (ii) An opset pattern is also matched against the actual opset domain used in the + (ii) It contains a domain pattern matched against the actual opset domain used in the input model. """ - def __init__(self, domain: Pattern[str] | str) -> None: + def __init__(self, domain: StringPattern | str, record: bool = False) -> None: if isinstance(domain, str): - self._domain_name: str | None = domain - self._domain_pattern: Pattern[str] = StringConstantPattern(domain) + domain = StringConstantPattern(domain) + self._domain_pattern = domain + if record: + self._nodes: list[NodePattern] | None = [] else: - self._domain_name = None - self._domain_pattern = domain - - @property - def domain_name(self) -> str | None: - return self._domain_name + self._nodes = None - def matches(self, domain): - return self._domain_pattern.matches(domain) + def domain_pattern(self) -> StringPattern: + return self._domain_pattern def __getattr__(self, op_name: str) -> OpPatternBuilder: return OpPatternBuilder(self, op_name) @@ -170,10 +183,17 @@ def submodule(self, name: str) -> OpPatternBuilder: def __str__(self) -> str: return str(self._domain_pattern) + def add_node(self, node: NodePattern) -> None: + if self._nodes is not None: + self._nodes.append(node) -onnxop = OpsetPatternBuilder("") + def nodes(self) -> Sequence[NodePattern]: + if self._nodes is None: + raise ValueError("Nodes were not recorded.") + return self._nodes -msft_op = OpsetPatternBuilder("com.microsoft") + +onnxop = OpsetPatternBuilder("") torch_module_op = OpsetPatternBuilder(PrefixPattern("pkg.torch")) @@ -194,10 +214,10 @@ class OpPatternBuilder: def __init__( self, - opset_pattern: OpsetPatternBuilder, + pattern_builder: OpsetPatternBuilder, op_name: str | Pattern[str], ) -> None: - self.opset_pattern = opset_pattern + self.pattern_builder = pattern_builder self.op_name = op_name def __call__( @@ -215,9 +235,9 @@ def __call__( "Version restrictions should be handled by rewrite rules." ) if _domain is None: - opset_pattern = self.opset_pattern + opset_pattern = self.pattern_builder.domain_pattern() elif isinstance(_domain, str): - opset_pattern = OpsetPatternBuilder(_domain) + opset_pattern = StringConstantPattern(_domain) else: # TODO(rama): allow OpsetPatternBuilder as _domain. raise TypeError("_domain must be a string.") @@ -233,6 +253,7 @@ def __call__( node_pattern = NodePattern( opset_pattern, self.op_name, inputs, attributes, _outputs, _allow_other_attributes ) + self.pattern_builder.add_node(node_pattern) output_values = node_pattern.outputs # Unpack outputs if there is only one output, the common case. if len(output_values) == 1: @@ -354,6 +375,18 @@ def extend(self, other: MatchResult | bool): self._matched_nodes.extend(other._matched_nodes) # type: ignore[attr-defined] +_pattern_builder: OpsetPatternBuilder = onnxop + + +@contextlib.contextmanager +def pattern_builder(builder: OpsetPatternBuilder): + global _pattern_builder + prev_builder = _pattern_builder + _pattern_builder = builder + yield + _pattern_builder = prev_builder + + class ValuePattern: """Base class for all patterns that match against IR values. @@ -366,6 +399,10 @@ def __init__(self, name: str | None) -> None: # Note: uses will be computed only when the full graph-pattern is constructed. self._uses: list[tuple[NodePattern, int]] = [] + def clone(self, node_map: dict[NodePattern, NodePattern]) -> ValuePattern: + del node_map + return ValuePattern(self._name) + @property def name(self) -> str | None: return self._name @@ -382,41 +419,32 @@ def append_use(self, node: NodePattern, index: int): def __repr__(self) -> str: return f"ValuePattern({self._name!r})" - def commute(self) -> Sequence[ValuePattern]: - """Return a list of commuted patterns. - - This is used to handle commutative operations like addition and multiplication. - A single pattern is converted into a list of equivalent patterns by swapping - the parameters of commutative operations. - """ - return [self] - def __add__(self, other): - return onnxop.Add(self, other) + return _pattern_builder.Add(self, other) def __radd__(self, other): - return onnxop.Add(other, self) + return _pattern_builder.Add(other, self) def __sub__(self, other): - return onnxop.Sub(self, other) + return _pattern_builder.Sub(self, other) def __rsub__(self, other): - return onnxop.Sub(other, self) + return _pattern_builder.Sub(other, self) def __mul__(self, other): - return onnxop.Mul(self, other) + return _pattern_builder.Mul(self, other) def __rmul__(self, other): - return onnxop.Mul(other, self) + return _pattern_builder.Mul(other, self) def __truediv__(self, other): - return onnxop.Div(self, other) + return _pattern_builder.Div(self, other) def __rtruediv__(self, other): - return onnxop.Div(other, self) + return _pattern_builder.Div(other, self) def __pow__(self, other): - return onnxop.Pow(self, other) + return _pattern_builder.Pow(self, other) def __str__(self) -> str: return self._name if self._name is not None else "anonymous:" + str(id(self)) @@ -441,7 +469,7 @@ class NodePattern: def __init__( self, - domain: OpsetPatternBuilder, + domain: StringPattern, op: str | Pattern[str], inputs: Sequence[int | float | ValuePattern | None], attributes: dict[str, AttrPattern], @@ -457,11 +485,11 @@ def __init__( self.attributes = attributes self.allow_other_attributes = allow_other_attributes # In the common case, domain and op are constants, which can be used to optimize matching. - if isinstance(op, str) and domain.domain_name is not None: + if isinstance(op, str) and isinstance(domain, StringConstantPattern): # TODO(rama): support overloaded operators. overload = "" self._op_identifier: tuple[str, str, str] | None = ( - domain.domain_name, + domain.value(), op, overload, ) @@ -523,36 +551,19 @@ def matches(self, node: ir.Node, match: MatchResult) -> MatchResult: return match - def commute(self) -> Sequence[NodePattern]: - list_of_lists = [ - [None] if pattern is None else pattern.commute() for pattern in self.inputs - ] # type: ignore[attr-defined] - - def enumerate_inputs(inputs, index): - if index >= len(inputs): - yield [] - else: - for pattern in inputs[index]: - for rest in enumerate_inputs(inputs, index + 1): - yield [pattern, *rest] - - inputs = list(enumerate_inputs(list_of_lists, 0)) - if self.domain.matches("") and (self.op.matches("Add") or self.op.matches("Mul")): - # TODO: handle cases where number of inputs is not 2. - swapped = [[x[1], x[0]] for x in inputs] - inputs.extend(swapped) + def clone(self, node_map: dict[NodePattern, NodePattern], swap: bool) -> NodePattern: + inputs = [(v.clone(node_map) if v is not None else None) for v in self.inputs] + if swap: + assert ( + len(inputs) == 2 + ), "Internal error: commutative swap applies only to binary ops." + inputs = [inputs[1], inputs[0]] outputs = [value.name for value in self.outputs] - return [ - NodePattern( - self.domain, - self.op, - input, - self.attributes, - outputs, - self.allow_other_attributes, - ) - for input in inputs - ] + copied = NodePattern( + self.domain, self.op, inputs, self.attributes, outputs, self.allow_other_attributes + ) + node_map[self] = copied + return copied class NodeOutputPattern(ValuePattern): @@ -569,17 +580,14 @@ def __init__( self._producer = producer self._output_index = output_index + def clone(self, node_map: dict[NodePattern, NodePattern]) -> NodeOutputPattern: + return node_map[self._producer].outputs[self._output_index] + # return NodeOutputPattern(node_map[self._producer], self._output_index, self._name) + @property def output_index(self) -> int: return self._output_index - def commute(self) -> Sequence[ValuePattern]: - # TODO - return [ - NodeOutputPattern(pattern, self._output_index, self.name) - for pattern in self._producer.commute() - ] - def producer(self) -> NodePattern: return self._producer @@ -598,6 +606,10 @@ def __init__( self._rel_tol = rel_tol self._abs_tol = abs_tol + def clone(self, node_map: dict[NodePattern, NodePattern]) -> Constant: + del node_map + return Constant(self._value, self._rel_tol, self._abs_tol) + @property def value(self) -> int | float: return self._value @@ -629,9 +641,6 @@ def matches(self, value: ir.Value, match: MatchResult) -> MatchResult: # used elsewhere. return match - def commute(self) -> list[ValuePattern]: - return [self] - def __str__(self) -> str: return str(self._value) @@ -657,13 +666,16 @@ class GraphPattern: """Represents a pattern that can be matched against a subgraph.""" def __init__( - self, inputs: Sequence[ValuePattern], outputs: Sequence[ValuePattern] + self, + inputs: Sequence[ValuePattern], + outputs: Sequence[ValuePattern], + nodes: Sequence[NodePattern], ) -> None: self._inputs = inputs self._outputs = outputs if len(outputs) == 0: raise ValueError("GraphPattern must have at least one output") - self._nodes = _nodes_in_pattern(outputs) + self._nodes = nodes # _nodes_in_pattern(outputs) # Check if all outputs are produced by the same node. output_nodes: set[NodePattern] = set() @@ -718,17 +730,33 @@ def num_outputs(self) -> int: return len(self._outputs) def commute(self) -> Sequence[GraphPattern]: - if not self.has_single_output_node: - raise NotImplementedError( - "Cannot commute a graph pattern with multiple output nodes." - ) - nodes = self.output_node.commute() - return [ - GraphPattern( - self._inputs, [NodeOutputPattern(n, i) for i in range(self.num_outputs)] - ) - for n in nodes - ] + def commute_node(node: NodePattern) -> Iterable[bool]: + if node.op_identifier() == ("", "Add", "") or node.op_identifier() == ( + "", + "Mul", + "", + ): + # Try with and without swapping inputs. + return [False, True] + # No swapping of inputs + return [False] + + iteration_space = [commute_node(node) for node in self._nodes] + + def copy_graph(swap_list: Iterable[bool]) -> GraphPattern: + if not any(swap_list): + # No need to swap inputs of any node + return self + # Create a copy of the graph, with swapped inputs for the nodes that need it. + node_map: dict[NodePattern, NodePattern] = {} + new_inputs = [v.clone(node_map) for v in self._inputs] + new_nodes = [ + node.clone(node_map, swap) for node, swap in zip(self._nodes, swap_list) + ] + new_outputs = [v.clone(node_map) for v in self._outputs] + return GraphPattern(new_inputs, new_outputs, new_nodes) + + return [copy_graph(swap_list) for swap_list in itertools.product(*iteration_space)] def __str__(self) -> str: inputs = ", ".join(str(v) for v in self._inputs) @@ -758,13 +786,15 @@ def pattern(op, x: Var, shape1: Var, shape2: Var): """ _pattern_vars = inspect.signature(pattern_constructor).parameters pattern_inputs = [Var(v) for v in _pattern_vars][1:] # Skip the first parameter - pattern_outputs = pattern_constructor(onnxop, *pattern_inputs) + builder = OpsetPatternBuilder("", record=True) + with pattern_builder(builder): + pattern_outputs = pattern_constructor(builder, *pattern_inputs) # TODO(rama): classify inputs as value/attribute vars # Returned value could be a single ValuePattern or a list of ValuePatterns. # Normalize representation to a list of ValuePatterns. if isinstance(pattern_outputs, ValuePattern): pattern_outputs = [pattern_outputs] - return GraphPattern(pattern_inputs, pattern_outputs) + return GraphPattern(pattern_inputs, pattern_outputs, builder.nodes()) def _valid_to_replace( diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index 31985db5a..5385a5233 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -421,5 +421,18 @@ def concat(op, x, y, result: ir.Value): self.assertNotIn("axis", model.graph[0].attributes) +class PatternBuilderTest(unittest.TestCase): + def test_pattern_builder_context(self): + builder = pattern.OpsetPatternBuilder("", True) + with pattern.pattern_builder(builder): + x = builder.Op1() + y = builder.Op2(x) + z = x + y + w = builder.Op3(z) + _ = z * w + ops = [x.op_type for x in builder.nodes()] + self.assertEqual(ops, ["Op1", "Op2", "Add", "Op3", "Mul"]) + + if __name__ == "__main__": unittest.main()