From 1522a650c2d4e23c47038f7a0ec4bfb6415fcb98 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 31 Jul 2024 18:26:01 -0700 Subject: [PATCH 01/11] Partial implementation of pattern builder context --- onnxscript/rewriter/cast_constant_of_shape.py | 2 +- onnxscript/rewriter/gemm_to_matmul_add.py | 2 +- onnxscript/rewriter/llama_rule_sets.py | 2 +- onnxscript/rewriter/no_op.py | 2 +- .../onnxruntime/fused_matmul_rule_sets.py | 2 +- onnxscript/rewriter/onnxruntime/softmax.py | 2 +- onnxscript/rewriter/pattern.py | 77 ++++++++++++------- onnxscript/rewriter/pattern_test.py | 10 +++ 8 files changed, 64 insertions(+), 35 deletions(-) diff --git a/onnxscript/rewriter/cast_constant_of_shape.py b/onnxscript/rewriter/cast_constant_of_shape.py index bd58af933..a8c6dba26 100644 --- a/onnxscript/rewriter/cast_constant_of_shape.py +++ b/onnxscript/rewriter/cast_constant_of_shape.py @@ -9,7 +9,7 @@ from onnxscript import ir from onnxscript.rewriter import pattern -op = pattern.onnxop +# op = pattern.onnxop logger = logging.getLogger(__name__) diff --git a/onnxscript/rewriter/gemm_to_matmul_add.py b/onnxscript/rewriter/gemm_to_matmul_add.py index 0b9ee373b..b27f3c77d 100644 --- a/onnxscript/rewriter/gemm_to_matmul_add.py +++ b/onnxscript/rewriter/gemm_to_matmul_add.py @@ -3,7 +3,7 @@ from onnxscript.rewriter import pattern from onnxscript.rewriter.broadcast_to_matmul import check_if_not_need_reshape -op = pattern.onnxop +# op = pattern.onnxop # Pattern to match against diff --git a/onnxscript/rewriter/llama_rule_sets.py b/onnxscript/rewriter/llama_rule_sets.py index 1adb03e16..9d96a64ed 100644 --- a/onnxscript/rewriter/llama_rule_sets.py +++ b/onnxscript/rewriter/llama_rule_sets.py @@ -11,7 +11,7 @@ import onnxscript.rewriter.no_op as no_op import onnxscript.rewriter.pattern as orp -op = orp.onnxop +# op = orp.onnxop class CastIdentity(orp.RewriteRuleAsClass): diff --git a/onnxscript/rewriter/no_op.py b/onnxscript/rewriter/no_op.py index 95c3e2434..46426a9aa 100644 --- a/onnxscript/rewriter/no_op.py +++ b/onnxscript/rewriter/no_op.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. from onnxscript.rewriter import pattern -op = pattern.onnxop +# op = pattern.onnxop # TODO: Support 1-D constant tensors # https://github.com/microsoft/onnx-rewriter/issues/186 diff --git a/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py b/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py index 3a4444dbb..adb168713 100644 --- a/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py +++ b/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py @@ -6,7 +6,7 @@ import onnxscript.rewriter.pattern as orp -op = orp.onnxop +# op = orp.onnxop class FusedMatMulDiv1(orp.RewriteRuleAsClass): diff --git a/onnxscript/rewriter/onnxruntime/softmax.py b/onnxscript/rewriter/onnxruntime/softmax.py index 12ad97672..0b38f5347 100644 --- a/onnxscript/rewriter/onnxruntime/softmax.py +++ b/onnxscript/rewriter/onnxruntime/softmax.py @@ -9,7 +9,7 @@ from onnxscript import ir from onnxscript.rewriter import pattern -op = pattern.onnxop +# op = pattern.onnxop logger = logging.getLogger(__name__) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 6f3613e5f..525f821d5 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -3,6 +3,7 @@ from __future__ import annotations import abc +import contextlib import dataclasses import inspect import itertools @@ -35,7 +36,18 @@ class Pattern(Protocol[T]): # type: ignore[misc] def matches(self, item: T) -> bool: ... -class StringConstantPattern(Pattern[str]): +class StringPattern(abc.ABC, Pattern[str]): + """Abstract base class for string patterns.""" + + @abc.abstractmethod + def matches(self, item: str) -> bool: + pass + + @abc.abstractmethod + def __str__(self) -> str: + pass + +class StringConstantPattern(StringPattern): """Matches strings with given value.""" def __init__(self, value: str): @@ -47,8 +59,11 @@ def matches(self, item: str) -> bool: def __str__(self) -> str: return self._value + def value(self) -> str: + return self._value + -class PrefixPattern(Pattern[str]): +class PrefixPattern(StringPattern): """Matches strings with a given prefix.""" def __init__(self, value: str) -> None: @@ -145,20 +160,14 @@ class OpsetPatternBuilder(Pattern[str]): input model. """ - def __init__(self, domain: Pattern[str] | str) -> None: + def __init__(self, domain: StringPattern | str) -> None: if isinstance(domain, str): - self._domain_name: str | None = domain self._domain_pattern: Pattern[str] = StringConstantPattern(domain) else: - self._domain_name = None self._domain_pattern = domain - @property - def domain_name(self) -> str | None: - return self._domain_name - - def matches(self, domain): - return self._domain_pattern.matches(domain) + def domain_pattern(self) -> StringPattern: + return self._domain_pattern def __getattr__(self, op_name: str) -> OpPatternBuilder: return OpPatternBuilder(self, op_name) @@ -173,9 +182,9 @@ def __str__(self) -> str: onnxop = OpsetPatternBuilder("") -msft_op = OpsetPatternBuilder("com.microsoft") +# msft_op = OpsetPatternBuilder("com.microsoft") -torch_module_op = OpsetPatternBuilder(PrefixPattern("pkg.torch")) +# torch_module_op = OpsetPatternBuilder(PrefixPattern("pkg.torch")) class OpPatternBuilder: @@ -194,10 +203,10 @@ class OpPatternBuilder: def __init__( self, - opset_pattern: OpsetPatternBuilder, + pattern_builder: OpsetPatternBuilder, op_name: str | Pattern[str], ) -> None: - self.opset_pattern = opset_pattern + self.pattern_builder = pattern_builder self.op_name = op_name def __call__( @@ -215,9 +224,9 @@ def __call__( "Version restrictions should be handled by rewrite rules." ) if _domain is None: - opset_pattern = self.opset_pattern + opset_pattern = self.pattern_builder.domain_pattern() elif isinstance(_domain, str): - opset_pattern = OpsetPatternBuilder(_domain) + opset_pattern = StringConstantPattern(_domain) else: # TODO(rama): allow OpsetPatternBuilder as _domain. raise TypeError("_domain must be a string.") @@ -353,6 +362,15 @@ def extend(self, other: MatchResult | bool): assert self._matched_nodes is not None, "_matched_nodes should not be None." self._matched_nodes.extend(other._matched_nodes) # type: ignore[attr-defined] +_pattern_builder : OpPatternBuilder | None = None + +@contextlib.contextmanager +def pattern_builder(rewriter_context: RewriterContext): + global _pattern_builder + prev_builder = _pattern_builder + _pattern_builder = rewriter_context + yield + _pattern_builder = prev_builder class ValuePattern: """Base class for all patterns that match against IR values. @@ -392,31 +410,31 @@ def commute(self) -> Sequence[ValuePattern]: return [self] def __add__(self, other): - return onnxop.Add(self, other) + return _pattern_builder.Add(self, other) def __radd__(self, other): - return onnxop.Add(other, self) + return _pattern_builder.Add(other, self) def __sub__(self, other): - return onnxop.Sub(self, other) + return _pattern_builder.Sub(self, other) def __rsub__(self, other): - return onnxop.Sub(other, self) + return _pattern_builder.Sub(other, self) def __mul__(self, other): - return onnxop.Mul(self, other) + return _pattern_builder.Mul(self, other) def __rmul__(self, other): - return onnxop.Mul(other, self) + return _pattern_builder.Mul(other, self) def __truediv__(self, other): - return onnxop.Div(self, other) + return _pattern_builder.Div(self, other) def __rtruediv__(self, other): - return onnxop.Div(other, self) + return _pattern_builder.Div(other, self) def __pow__(self, other): - return onnxop.Pow(self, other) + return _pattern_builder.Pow(self, other) def __str__(self) -> str: return self._name if self._name is not None else "anonymous:" + str(id(self)) @@ -441,7 +459,7 @@ class NodePattern: def __init__( self, - domain: OpsetPatternBuilder, + domain: StringPattern, op: str | Pattern[str], inputs: Sequence[int | float | ValuePattern | None], attributes: dict[str, AttrPattern], @@ -457,11 +475,11 @@ def __init__( self.attributes = attributes self.allow_other_attributes = allow_other_attributes # In the common case, domain and op are constants, which can be used to optimize matching. - if isinstance(op, str) and domain.domain_name is not None: + if isinstance(op, str) and isinstance(domain, StringConstantPattern): # TODO(rama): support overloaded operators. overload = "" self._op_identifier: tuple[str, str, str] | None = ( - domain.domain_name, + domain.value, op, overload, ) @@ -872,6 +890,7 @@ def __init__(self, function) -> None: def get_replacement(self, match: MatchResult) -> ReplacementSubgraph | None: context = RewriterContext() + # with pattern_builder(context): new_outputs = self._function(context, **match.bindings) if new_outputs is None: return None # Failed to create replacement subgraph diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index 31985db5a..dbffa2ec9 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -420,6 +420,16 @@ def concat(op, x, y, result: ir.Value): self.assertEqual(model.graph[0].op_type, "Concat") self.assertNotIn("axis", model.graph[0].attributes) +class PatternBuilderTest(unittest.TestCase): + def test_pattern_builder_context(self): + builder = pattern.RewriterContext() + with pattern.pattern_builder(builder): + x = builder.Op1() + y = builder.Op2(x) + z = x + y + w = builder.Op3(z) + ops = [x.op_type for x in builder.nodes] + self.assertEqual(ops, ["Op1", "Op2", "Add", "Op3"]) if __name__ == "__main__": unittest.main() From 97719302c0c81c292fce9c068e413681502bad4a Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 1 Aug 2024 11:57:51 -0700 Subject: [PATCH 02/11] Some fixes to builder context --- onnxscript/rewriter/pattern.py | 100 +++++++++++++++++++++++++--- onnxscript/rewriter/pattern_test.py | 7 +- 2 files changed, 93 insertions(+), 14 deletions(-) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 525f821d5..3b73d5f88 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -144,8 +144,8 @@ def _to_attr_pattern(value: AttrPattern | ValuePattern | SupportedAttrTypes) -> raise TypeError(f"Cannot convert {type(value)} to AttrPattern") -class OpsetPatternBuilder(Pattern[str]): - """Represents an opset pattern. +class OpsetPatternBuilder: + """Represents an opset pattern and a pattern builder. (i) It is used to create a NodePattern (via OpPatternBuilder). Example usage: @@ -156,15 +156,18 @@ class OpsetPatternBuilder(Pattern[str]): Here, `op` is an instance of OpsetPatternBuilder and `op.Matmul` is an instance of OpPatternBuilder, and `op.Matmul(x, y)` is an instance of NodePattern. - (ii) An opset pattern is also matched against the actual opset domain used in the + (ii) It contains a domain pattern matched against the actual opset domain used in the input model. """ - def __init__(self, domain: StringPattern | str) -> None: + def __init__(self, domain: StringPattern | str, record: bool = False) -> None: if isinstance(domain, str): - self._domain_pattern: Pattern[str] = StringConstantPattern(domain) + domain = StringConstantPattern(domain) + self._domain_pattern = domain + if record: + self._nodes: list[NodePattern] | None = [] else: - self._domain_pattern = domain + self._nodes = None def domain_pattern(self) -> StringPattern: return self._domain_pattern @@ -179,6 +182,15 @@ def submodule(self, name: str) -> OpPatternBuilder: def __str__(self) -> str: return str(self._domain_pattern) + def add_node(self, node: NodePattern) -> None: + if self._nodes is not None: + self._nodes.append(node) + + def nodes(self) -> Sequence[NodePattern]: + if self._nodes is None: + raise ValueError("Nodes were not recorded.") + return self._nodes + onnxop = OpsetPatternBuilder("") @@ -242,6 +254,7 @@ def __call__( node_pattern = NodePattern( opset_pattern, self.op_name, inputs, attributes, _outputs, _allow_other_attributes ) + self.pattern_builder.add_node(node_pattern) output_values = node_pattern.outputs # Unpack outputs if there is only one output, the common case. if len(output_values) == 1: @@ -362,7 +375,7 @@ def extend(self, other: MatchResult | bool): assert self._matched_nodes is not None, "_matched_nodes should not be None." self._matched_nodes.extend(other._matched_nodes) # type: ignore[attr-defined] -_pattern_builder : OpPatternBuilder | None = None +_pattern_builder : OpPatternBuilder = onnxop @contextlib.contextmanager def pattern_builder(rewriter_context: RewriterContext): @@ -572,6 +585,36 @@ def enumerate_inputs(inputs, index): for input in inputs ] + def commute2(self) -> Sequence[NodePattern]: + list_of_lists = [ + [None] if pattern is None else pattern.commute() for pattern in self.inputs + ] # type: ignore[attr-defined] + + def enumerate_inputs(inputs, index): + if index >= len(inputs): + yield [] + else: + for pattern in inputs[index]: + for rest in enumerate_inputs(inputs, index + 1): + yield [pattern, *rest] + + inputs = list(enumerate_inputs(list_of_lists, 0)) + if self.domain.matches("") and (self.op.matches("Add") or self.op.matches("Mul")): + # TODO: handle cases where number of inputs is not 2. + swapped = [[x[1], x[0]] for x in inputs] + inputs.extend(swapped) + outputs = [value.name for value in self.outputs] + return [ + NodePattern( + self.domain, + self.op, + input, + self.attributes, + outputs, + self.allow_other_attributes, + ) + for input in inputs + ] class NodeOutputPattern(ValuePattern): """Represents a pattern that matches against a specific output of a Node. @@ -675,13 +718,13 @@ class GraphPattern: """Represents a pattern that can be matched against a subgraph.""" def __init__( - self, inputs: Sequence[ValuePattern], outputs: Sequence[ValuePattern] + self, inputs: Sequence[ValuePattern], outputs: Sequence[ValuePattern], nodes: Sequence[NodePattern] ) -> None: self._inputs = inputs self._outputs = outputs if len(outputs) == 0: raise ValueError("GraphPattern must have at least one output") - self._nodes = _nodes_in_pattern(outputs) + self._nodes = nodes # _nodes_in_pattern(outputs) # Check if all outputs are produced by the same node. output_nodes: set[NodePattern] = set() @@ -736,6 +779,39 @@ def num_outputs(self) -> int: return len(self._outputs) def commute(self) -> Sequence[GraphPattern]: + def _commute_node(node: NodePattern) -> Sequence[NodePattern]: + return node.commute() + + list_of_lists = [ + [None] if pattern is None else pattern.commute() for pattern in self.inputs + ] # type: ignore[attr-defined] + + def enumerate_inputs(inputs, index): + if index >= len(inputs): + yield [] + else: + for pattern in inputs[index]: + for rest in enumerate_inputs(inputs, index + 1): + yield [pattern, *rest] + + inputs = list(enumerate_inputs(list_of_lists, 0)) + if self.domain.matches("") and (self.op.matches("Add") or self.op.matches("Mul")): + # TODO: handle cases where number of inputs is not 2. + swapped = [[x[1], x[0]] for x in inputs] + inputs.extend(swapped) + outputs = [value.name for value in self.outputs] + return [ + NodePattern( + self.domain, + self.op, + input, + self.attributes, + outputs, + self.allow_other_attributes, + ) + for input in inputs + ] + if not self.has_single_output_node: raise NotImplementedError( "Cannot commute a graph pattern with multiple output nodes." @@ -776,13 +852,15 @@ def pattern(op, x: Var, shape1: Var, shape2: Var): """ _pattern_vars = inspect.signature(pattern_constructor).parameters pattern_inputs = [Var(v) for v in _pattern_vars][1:] # Skip the first parameter - pattern_outputs = pattern_constructor(onnxop, *pattern_inputs) + builder = OpsetPatternBuilder("", record=True) + with pattern_builder(builder): + pattern_outputs = pattern_constructor(builder, *pattern_inputs) # TODO(rama): classify inputs as value/attribute vars # Returned value could be a single ValuePattern or a list of ValuePatterns. # Normalize representation to a list of ValuePatterns. if isinstance(pattern_outputs, ValuePattern): pattern_outputs = [pattern_outputs] - return GraphPattern(pattern_inputs, pattern_outputs) + return GraphPattern(pattern_inputs, pattern_outputs, builder.nodes()) def _valid_to_replace( diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index dbffa2ec9..1171ef8f7 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -422,14 +422,15 @@ def concat(op, x, y, result: ir.Value): class PatternBuilderTest(unittest.TestCase): def test_pattern_builder_context(self): - builder = pattern.RewriterContext() + builder = pattern.OpsetPatternBuilder("", True) with pattern.pattern_builder(builder): x = builder.Op1() y = builder.Op2(x) z = x + y w = builder.Op3(z) - ops = [x.op_type for x in builder.nodes] - self.assertEqual(ops, ["Op1", "Op2", "Add", "Op3"]) + t = z * w + ops = [x.op_type for x in builder.nodes()] + self.assertEqual(ops, ["Op1", "Op2", "Add", "Op3", "Mul"]) if __name__ == "__main__": unittest.main() From 340aa3dc18c28d9535388946e7fdfeea6084e2e9 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 1 Aug 2024 15:13:30 -0700 Subject: [PATCH 03/11] Cleanup commute implementation --- onnxscript/rewriter/pattern.py | 138 +++++++++++++--------------- onnxscript/rewriter/pattern_test.py | 2 + 2 files changed, 64 insertions(+), 76 deletions(-) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 3b73d5f88..a755b565c 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -45,7 +45,8 @@ def matches(self, item: str) -> bool: @abc.abstractmethod def __str__(self) -> str: - pass + pass + class StringConstantPattern(StringPattern): """Matches strings with given value.""" @@ -375,7 +376,9 @@ def extend(self, other: MatchResult | bool): assert self._matched_nodes is not None, "_matched_nodes should not be None." self._matched_nodes.extend(other._matched_nodes) # type: ignore[attr-defined] -_pattern_builder : OpPatternBuilder = onnxop + +_pattern_builder: OpPatternBuilder = onnxop + @contextlib.contextmanager def pattern_builder(rewriter_context: RewriterContext): @@ -385,6 +388,7 @@ def pattern_builder(rewriter_context: RewriterContext): yield _pattern_builder = prev_builder + class ValuePattern: """Base class for all patterns that match against IR values. @@ -397,6 +401,10 @@ def __init__(self, name: str | None) -> None: # Note: uses will be computed only when the full graph-pattern is constructed. self._uses: list[tuple[NodePattern, int]] = [] + def clone(self, node_map: dict[NodePattern, NodePattern]) -> ValuePattern: + del node_map + return ValuePattern(self._name) + @property def name(self) -> str | None: return self._name @@ -585,36 +593,20 @@ def enumerate_inputs(inputs, index): for input in inputs ] - def commute2(self) -> Sequence[NodePattern]: - list_of_lists = [ - [None] if pattern is None else pattern.commute() for pattern in self.inputs - ] # type: ignore[attr-defined] - - def enumerate_inputs(inputs, index): - if index >= len(inputs): - yield [] - else: - for pattern in inputs[index]: - for rest in enumerate_inputs(inputs, index + 1): - yield [pattern, *rest] - - inputs = list(enumerate_inputs(list_of_lists, 0)) - if self.domain.matches("") and (self.op.matches("Add") or self.op.matches("Mul")): - # TODO: handle cases where number of inputs is not 2. - swapped = [[x[1], x[0]] for x in inputs] - inputs.extend(swapped) + def clone(self, node_map: dict[NodePattern, NodePattern], swap: bool) -> NodePattern: + inputs = [v.clone(node_map) for v in self.inputs] + if swap: + assert ( + len(inputs) == 2 + ), "Internal error: commutative swap applies only to binary ops." + inputs = [inputs[1], inputs[0]] outputs = [value.name for value in self.outputs] - return [ - NodePattern( - self.domain, - self.op, - input, - self.attributes, - outputs, - self.allow_other_attributes, - ) - for input in inputs - ] + copy = NodePattern( + self.domain, self.op, inputs, self.attributes, outputs, self.allow_other_attributes + ) + node_map[self] = copy + return copy + class NodeOutputPattern(ValuePattern): """Represents a pattern that matches against a specific output of a Node. @@ -630,6 +622,10 @@ def __init__( self._producer = producer self._output_index = output_index + def clone(self, node_map: dict[NodePattern, NodePattern]) -> NodeOutputPattern: + return node_map[self._producer].outputs[self._output_index] + # return NodeOutputPattern(node_map[self._producer], self._output_index, self._name) + @property def output_index(self) -> int: return self._output_index @@ -659,6 +655,10 @@ def __init__( self._rel_tol = rel_tol self._abs_tol = abs_tol + def clone(self, node_map: dict[NodePattern, NodePattern]) -> Constant: + del node_map + return Constant(self._value, self._rel_tol, self._abs_tol) + @property def value(self) -> int | float: return self._value @@ -718,7 +718,10 @@ class GraphPattern: """Represents a pattern that can be matched against a subgraph.""" def __init__( - self, inputs: Sequence[ValuePattern], outputs: Sequence[ValuePattern], nodes: Sequence[NodePattern] + self, + inputs: Sequence[ValuePattern], + outputs: Sequence[ValuePattern], + nodes: Sequence[NodePattern], ) -> None: self._inputs = inputs self._outputs = outputs @@ -779,50 +782,33 @@ def num_outputs(self) -> int: return len(self._outputs) def commute(self) -> Sequence[GraphPattern]: - def _commute_node(node: NodePattern) -> Sequence[NodePattern]: - return node.commute() - - list_of_lists = [ - [None] if pattern is None else pattern.commute() for pattern in self.inputs - ] # type: ignore[attr-defined] - - def enumerate_inputs(inputs, index): - if index >= len(inputs): - yield [] - else: - for pattern in inputs[index]: - for rest in enumerate_inputs(inputs, index + 1): - yield [pattern, *rest] - - inputs = list(enumerate_inputs(list_of_lists, 0)) - if self.domain.matches("") and (self.op.matches("Add") or self.op.matches("Mul")): - # TODO: handle cases where number of inputs is not 2. - swapped = [[x[1], x[0]] for x in inputs] - inputs.extend(swapped) - outputs = [value.name for value in self.outputs] - return [ - NodePattern( - self.domain, - self.op, - input, - self.attributes, - outputs, - self.allow_other_attributes, - ) - for input in inputs - ] - - if not self.has_single_output_node: - raise NotImplementedError( - "Cannot commute a graph pattern with multiple output nodes." - ) - nodes = self.output_node.commute() - return [ - GraphPattern( - self._inputs, [NodeOutputPattern(n, i) for i in range(self.num_outputs)] - ) - for n in nodes - ] + def commute_node(node: NodePattern) -> Iterable[bool]: + if node.op_identifier() == ("", "Add", "") or node.op_identifier() == ( + "", + "Mul", + "", + ): + # Try with and without swapping inputs. + return [False, True] + # No swapping of inputs + return [False] + + iteration_space = [commute_node(node) for node in self._nodes] + + def copy_graph(swap_list: Iterable[bool]) -> GraphPattern: + if not any(swap_list): + # No need to swap inputs of any node + return self + # Create a copy of the graph, with swapped inputs for the nodes that need it. + node_map: dict[NodePattern, NodePattern] = {} + new_inputs = [v.clone(node_map) for v in self._inputs] + new_nodes = [ + node.clone(node_map, swap) for node, swap in zip(self._nodes, swap_list) + ] + new_outputs = [v.clone(node_map) for v in self._outputs] + return GraphPattern(new_inputs, new_outputs, new_nodes) + + return [copy_graph(swap_list) for swap_list in itertools.product(*iteration_space)] def __str__(self) -> str: inputs = ", ".join(str(v) for v in self._inputs) diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index 1171ef8f7..0f375beed 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -420,6 +420,7 @@ def concat(op, x, y, result: ir.Value): self.assertEqual(model.graph[0].op_type, "Concat") self.assertNotIn("axis", model.graph[0].attributes) + class PatternBuilderTest(unittest.TestCase): def test_pattern_builder_context(self): builder = pattern.OpsetPatternBuilder("", True) @@ -432,5 +433,6 @@ def test_pattern_builder_context(self): ops = [x.op_type for x in builder.nodes()] self.assertEqual(ops, ["Op1", "Op2", "Add", "Op3", "Mul"]) + if __name__ == "__main__": unittest.main() From df97a912e6405a79acfb3c5e5bb3afc6ba73b096 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 1 Aug 2024 15:27:47 -0700 Subject: [PATCH 04/11] Fix some issues --- onnxscript/rewriter/pattern.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index a755b565c..02b8087b4 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -197,7 +197,7 @@ def nodes(self) -> Sequence[NodePattern]: # msft_op = OpsetPatternBuilder("com.microsoft") -# torch_module_op = OpsetPatternBuilder(PrefixPattern("pkg.torch")) +torch_module_op = OpsetPatternBuilder(PrefixPattern("pkg.torch")) class OpPatternBuilder: @@ -500,7 +500,7 @@ def __init__( # TODO(rama): support overloaded operators. overload = "" self._op_identifier: tuple[str, str, str] | None = ( - domain.value, + domain.value(), op, overload, ) From 34dc9ff88373babdfb2076e8b57118c97e74443f Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 1 Aug 2024 15:45:48 -0700 Subject: [PATCH 05/11] Remove unused commute implementation --- onnxscript/rewriter/pattern.py | 50 ---------------------------------- 1 file changed, 50 deletions(-) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 02b8087b4..f4dafc6a0 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -421,15 +421,6 @@ def append_use(self, node: NodePattern, index: int): def __repr__(self) -> str: return f"ValuePattern({self._name!r})" - def commute(self) -> Sequence[ValuePattern]: - """Return a list of commuted patterns. - - This is used to handle commutative operations like addition and multiplication. - A single pattern is converted into a list of equivalent patterns by swapping - the parameters of commutative operations. - """ - return [self] - def __add__(self, other): return _pattern_builder.Add(self, other) @@ -562,37 +553,6 @@ def matches(self, node: ir.Node, match: MatchResult) -> MatchResult: return match - def commute(self) -> Sequence[NodePattern]: - list_of_lists = [ - [None] if pattern is None else pattern.commute() for pattern in self.inputs - ] # type: ignore[attr-defined] - - def enumerate_inputs(inputs, index): - if index >= len(inputs): - yield [] - else: - for pattern in inputs[index]: - for rest in enumerate_inputs(inputs, index + 1): - yield [pattern, *rest] - - inputs = list(enumerate_inputs(list_of_lists, 0)) - if self.domain.matches("") and (self.op.matches("Add") or self.op.matches("Mul")): - # TODO: handle cases where number of inputs is not 2. - swapped = [[x[1], x[0]] for x in inputs] - inputs.extend(swapped) - outputs = [value.name for value in self.outputs] - return [ - NodePattern( - self.domain, - self.op, - input, - self.attributes, - outputs, - self.allow_other_attributes, - ) - for input in inputs - ] - def clone(self, node_map: dict[NodePattern, NodePattern], swap: bool) -> NodePattern: inputs = [v.clone(node_map) for v in self.inputs] if swap: @@ -630,13 +590,6 @@ def clone(self, node_map: dict[NodePattern, NodePattern]) -> NodeOutputPattern: def output_index(self) -> int: return self._output_index - def commute(self) -> Sequence[ValuePattern]: - # TODO - return [ - NodeOutputPattern(pattern, self._output_index, self.name) - for pattern in self._producer.commute() - ] - def producer(self) -> NodePattern: return self._producer @@ -690,9 +643,6 @@ def matches(self, value: ir.Value, match: MatchResult) -> MatchResult: # used elsewhere. return match - def commute(self) -> list[ValuePattern]: - return [self] - def __str__(self) -> str: return str(self._value) From a78ed8bd382f77db183fc41a06f99cd87b98d830 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 1 Aug 2024 15:48:30 -0700 Subject: [PATCH 06/11] Remove unused onnxop import --- onnxscript/rewriter/onnxruntime/softmax.py | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxscript/rewriter/onnxruntime/softmax.py b/onnxscript/rewriter/onnxruntime/softmax.py index 0b38f5347..f1d6df7b6 100644 --- a/onnxscript/rewriter/onnxruntime/softmax.py +++ b/onnxscript/rewriter/onnxruntime/softmax.py @@ -9,7 +9,6 @@ from onnxscript import ir from onnxscript.rewriter import pattern -# op = pattern.onnxop logger = logging.getLogger(__name__) From 0526a14ccc7499995919c2b40617fd0c02585b48 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 1 Aug 2024 15:50:31 -0700 Subject: [PATCH 07/11] Remove unused onnxop import --- onnxscript/rewriter/cast_constant_of_shape.py | 1 - onnxscript/rewriter/gemm_to_matmul_add.py | 2 -- onnxscript/rewriter/llama_rule_sets.py | 2 -- onnxscript/rewriter/no_op.py | 2 -- onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py | 2 -- onnxscript/rewriter/pattern_test.py | 2 +- 6 files changed, 1 insertion(+), 10 deletions(-) diff --git a/onnxscript/rewriter/cast_constant_of_shape.py b/onnxscript/rewriter/cast_constant_of_shape.py index a8c6dba26..34656ff19 100644 --- a/onnxscript/rewriter/cast_constant_of_shape.py +++ b/onnxscript/rewriter/cast_constant_of_shape.py @@ -9,7 +9,6 @@ from onnxscript import ir from onnxscript.rewriter import pattern -# op = pattern.onnxop logger = logging.getLogger(__name__) diff --git a/onnxscript/rewriter/gemm_to_matmul_add.py b/onnxscript/rewriter/gemm_to_matmul_add.py index b27f3c77d..bff77839f 100644 --- a/onnxscript/rewriter/gemm_to_matmul_add.py +++ b/onnxscript/rewriter/gemm_to_matmul_add.py @@ -3,8 +3,6 @@ from onnxscript.rewriter import pattern from onnxscript.rewriter.broadcast_to_matmul import check_if_not_need_reshape -# op = pattern.onnxop - # Pattern to match against def reshape_gemm_reshape_pattern(op, input_a, input_b, input_c, shape_a, shape_c): diff --git a/onnxscript/rewriter/llama_rule_sets.py b/onnxscript/rewriter/llama_rule_sets.py index 9d96a64ed..0d163d0a2 100644 --- a/onnxscript/rewriter/llama_rule_sets.py +++ b/onnxscript/rewriter/llama_rule_sets.py @@ -11,8 +11,6 @@ import onnxscript.rewriter.no_op as no_op import onnxscript.rewriter.pattern as orp -# op = orp.onnxop - class CastIdentity(orp.RewriteRuleAsClass): """Replaces ``Cast(., to=to)`` by ``Identity`` if possible.""" diff --git a/onnxscript/rewriter/no_op.py b/onnxscript/rewriter/no_op.py index 46426a9aa..7a4b00798 100644 --- a/onnxscript/rewriter/no_op.py +++ b/onnxscript/rewriter/no_op.py @@ -2,8 +2,6 @@ # Licensed under the MIT License. from onnxscript.rewriter import pattern -# op = pattern.onnxop - # TODO: Support 1-D constant tensors # https://github.com/microsoft/onnx-rewriter/issues/186 diff --git a/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py b/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py index adb168713..65496ec8b 100644 --- a/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py +++ b/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py @@ -6,8 +6,6 @@ import onnxscript.rewriter.pattern as orp -# op = orp.onnxop - class FusedMatMulDiv1(orp.RewriteRuleAsClass): """Replaces ``MatMul + Div`` by FusedMatMul.""" diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index 0f375beed..5385a5233 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -429,7 +429,7 @@ def test_pattern_builder_context(self): y = builder.Op2(x) z = x + y w = builder.Op3(z) - t = z * w + _ = z * w ops = [x.op_type for x in builder.nodes()] self.assertEqual(ops, ["Op1", "Op2", "Add", "Op3", "Mul"]) From 67527db0db195c4e88f13cffbe75e5b795f6769c Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 1 Aug 2024 16:13:36 -0700 Subject: [PATCH 08/11] Remove commented code --- onnxscript/rewriter/pattern.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index f4dafc6a0..e54e0f472 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -195,8 +195,6 @@ def nodes(self) -> Sequence[NodePattern]: onnxop = OpsetPatternBuilder("") -# msft_op = OpsetPatternBuilder("com.microsoft") - torch_module_op = OpsetPatternBuilder(PrefixPattern("pkg.torch")) @@ -904,7 +902,6 @@ def __init__(self, function) -> None: def get_replacement(self, match: MatchResult) -> ReplacementSubgraph | None: context = RewriterContext() - # with pattern_builder(context): new_outputs = self._function(context, **match.bindings) if new_outputs is None: return None # Failed to create replacement subgraph From 0599d7f03859bf064be9953090b45893a2c813e8 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 1 Aug 2024 16:55:12 -0700 Subject: [PATCH 09/11] Fix typo --- onnxscript/rewriter/pattern.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index e54e0f472..32f01f337 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -375,7 +375,7 @@ def extend(self, other: MatchResult | bool): self._matched_nodes.extend(other._matched_nodes) # type: ignore[attr-defined] -_pattern_builder: OpPatternBuilder = onnxop +_pattern_builder: OpsetPatternBuilder = onnxop @contextlib.contextmanager From efb99eaa4598505a81e366d8fe4c2d0d4b3567d0 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 2 Aug 2024 07:50:25 -0700 Subject: [PATCH 10/11] Fix lint warnings --- onnxscript/rewriter/pattern.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 32f01f337..d58cb7201 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -379,10 +379,10 @@ def extend(self, other: MatchResult | bool): @contextlib.contextmanager -def pattern_builder(rewriter_context: RewriterContext): +def pattern_builder(builder: OpsetPatternBuilder): global _pattern_builder prev_builder = _pattern_builder - _pattern_builder = rewriter_context + _pattern_builder = builder yield _pattern_builder = prev_builder @@ -552,7 +552,7 @@ def matches(self, node: ir.Node, match: MatchResult) -> MatchResult: return match def clone(self, node_map: dict[NodePattern, NodePattern], swap: bool) -> NodePattern: - inputs = [v.clone(node_map) for v in self.inputs] + inputs = [(v.clone(node_map) if v is not None else None) for v in self.inputs] if swap: assert ( len(inputs) == 2 From 4a2f31171815d19eb32fe80c5c9c55627e0de112 Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Fri, 2 Aug 2024 10:48:53 -0700 Subject: [PATCH 11/11] Update onnxscript/rewriter/pattern.py Co-authored-by: Justin Chu --- onnxscript/rewriter/pattern.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index d58cb7201..87544874d 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -559,11 +559,11 @@ def clone(self, node_map: dict[NodePattern, NodePattern], swap: bool) -> NodePat ), "Internal error: commutative swap applies only to binary ops." inputs = [inputs[1], inputs[0]] outputs = [value.name for value in self.outputs] - copy = NodePattern( + copied = NodePattern( self.domain, self.op, inputs, self.attributes, outputs, self.allow_other_attributes ) - node_map[self] = copy - return copy + node_map[self] = copied + return copied class NodeOutputPattern(ValuePattern):