diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 7e38651db..aee57615e 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -4,254 +4,156 @@ import inspect import itertools import math -from typing import Any, Callable, List, MutableSequence, Optional, Sequence, Tuple +from typing import ( + Any, + Callable, + List, + MutableSequence, + Optional, + Protocol, + Sequence, + Tuple, + TypeVar, + Union, +) -import numpy as np import onnx -import onnx.numpy_helper -import onnx.printer from onnxscript import ir from onnxscript.ir import _convenience from onnxscript.rewriter import _ir_utils, _tape -# Overview of the pattern module: The classes below are used to define both -# patterns (that we search for) and replacements for rewrite rules. -# The matches() method of a pattern is used to check if an IR component -# matches the pattern. -# TODO: Ensure that all matches() methods have same type signature (where -# appropriate). +T = TypeVar("T") -class PythonPattern: - def __init__(self, value: int | str | Sequence, name: str | None = None) -> None: - self._value = value - self._name = name - - @property - def value(self) -> int | str | Sequence: - return self._value - - @property - def name(self) -> str | None: - return self._name - - def matches(self, value: int | str | Sequence) -> bool: - return value == self.value - - -class StringConstantPattern: - def __init__(self, value: str, name: str) -> None: - self._value = value - self._name = name - - @property - def value(self) -> str: - return self._value - - @property - def name(self) -> str: - return self._name - - def matches(self, attr: ir.AttrString) -> bool: - return attr.value == self.value - - -class IntConstantPattern: - def __init__(self, value: int, name: str) -> None: - self._value = value - self._name = name - - @property - def value(self) -> int: - return self._value +class Pattern(Protocol[T]): # type: ignore[misc] + """This is essentially a Predicate[T], that is, a Callable[[T], bool] bound to the name "matches".""" - @property - def name(self) -> str: - return self._name + def matches(self, item: T) -> bool: ... - def matches(self, attr: ir.AttrInt64) -> bool: - return attr.value == self.value +class StringConstantPattern(Pattern[str]): + """Matches strings with given value.""" -class ListConstantPattern: - def __init__(self, value: Sequence[int | str | float], name: str) -> None: + def __init__(self, value: str): self._value = value - self._name = name - - @property - def value(self) -> Sequence[int | str | float]: - return self._value - - @property - def name(self) -> str: - return self._name - def matches(self, attr: ir.AttrFloat32s | ir.AttrInt64s | ir.AttrStrings) -> bool: - # TODO: Need more data points to determine if this is the right way to compare lists. - return attr.value == self.value + def matches(self, item: str) -> bool: + return item == self._value -class PrefixPattern: - """This pattern is used to simplify submodule opset pattern matching.""" +class PrefixPattern(Pattern[str]): + """Matches strings with a given prefix.""" def __init__(self, value: str) -> None: self._value = value - @property - def value(self) -> str: - return self._value - def matches(self, value: str) -> bool: - return value.startswith(self.value) + return value.startswith(self._value) -class FloatConstantPattern: - def __init__( - self, value: float, name: str, rel_tol: float = 1e-5, abs_tol: float = 1e-8 - ) -> None: - self._value = value - self._name = name - self._rel_tol = rel_tol - self._abs_tol = abs_tol +class AttrPattern(Pattern[Union[ir.Attr, ir.RefAttr]]): + """Base class for an attribute pattern. Matches any attribute value by default.""" - @property - def value(self): - return self._value + def __init__(self, name: str | None): + self.name = name - @property - def name(self): - return self._name + def matches(self, attr: ir.Attr | ir.RefAttr) -> bool: + return True - def matches(self, attr: ir.AttrFloat32): - return math.isclose( - attr.value, self.value, rel_tol=self._rel_tol, abs_tol=self._abs_tol - ) +# TODO: Support tensors. Align with usage elsewhere. +SupportedAttrTypes = Union[ + int, + float, + str, + Sequence[int], + Sequence[float], + Sequence[str], +] -class TensorConstantPattern: - def __init__( - self, value: ir.TensorProtocol, name, rel_tol: float = 1e-3, abs_tol: float = 1e-3 - ) -> None: - self._value = value - self._name = name - self._rel_tol = rel_tol - self._abs_tol = abs_tol - @property - def value(self): - return self._value +class AttrConstantPattern(AttrPattern): + """Matches attributes with given value. - @property - def name(self): - return self._name - - def matches(self, attr: ir.AttrTensor): - return ( - attr.value.dtype == self._value.dtype - and attr.value.shape == self._value.shape - and np.allclose( - attr.value, - self._value, - rtol=self._rel_tol, - atol=self._abs_tol, - ) - ) + Uses standard equality for matching. For list-valued attributes, the order of elements matters. + If order is immaterial, we need to define a separate pattern for that. + """ + def __init__(self, value: SupportedAttrTypes): + super().__init__(None) + self._value = value -def _make_constant_pattern( - value: float | int | Sequence | ir.TensorProtocol, name: str -) -> ( - IntConstantPattern - | FloatConstantPattern - | TensorConstantPattern - | StringConstantPattern - | ListConstantPattern -): - """Convert an attrbute value to a ConstantPattern.""" - if isinstance(value, float): - return FloatConstantPattern(value, name) - if isinstance(value, int): - return IntConstantPattern(value, name) - if isinstance(value, str): - return StringConstantPattern(value, name) + def matches(self, attr: ir.Attr | ir.RefAttr) -> bool: + return isinstance(attr, ir.Attr) and attr.value == self._value + + +def _to_attr_pattern(value: AttrPattern | ValuePattern | SupportedAttrTypes) -> AttrPattern: + """Represents promotion of values allowed as keyword-arguments in a pattern-builder call to an AttrPattern.""" + if isinstance(value, AttrPattern): + return value + if type(value) == ValuePattern: + # This is a hack. Currently, when we create pattern-variables, we create them as ValuePattern, + # and change them to AttrPattern if/when used in an attribute context. We could use type + # annotations to distinguish between ValuePattern and AttrPattern, but forces users to + # use these type annotations. + # TODO: check for misuse at rule-creation time. (Currently will be caught by matcher at match-time.) + return AttrPattern(value.name) + if isinstance(value, (int, float, str)): + return AttrConstantPattern(value) if isinstance(value, Sequence): - return ListConstantPattern(value, name) - if isinstance(value, ir.TensorProtocol): - return TensorConstantPattern(value, name) - raise TypeError(f"Cannot convert {type(value)} to ConstantPattern") - - -class AnyPattern: - def matches(self, value) -> bool: - return True - - -class AttrPattern: - def __init__( - self, value: Var | int | float | Sequence | ir.TensorProtocol, name: str - ) -> None: - if isinstance(value, Var): - self.value_pattern = value - elif isinstance(value, (int, float, Sequence, ir.TensorProtocol)): - self.value_pattern = _make_constant_pattern(value, name) # type: ignore[assignment] - else: - raise TypeError(f"Cannot convert {type(value)} to AttrPattern") - - def matches( - self, - attr_val: int | float | Sequence | Var | ir.TensorProtocol | ir.Value, - model: ir.Model, - ) -> MatchResult: - if isinstance(self.value_pattern, Var): - return self.value_pattern.matches(attr_val, model) # type: ignore[arg-type] - return self.value_pattern.matches(attr_val) + if all(isinstance(i, (int, float)) for i in value): + return AttrConstantPattern(value) + if all(isinstance(i, str) for i in value): + return AttrConstantPattern(value) + raise ValueError("Only lists of int/float/str can be used as an AttrPattern") + raise TypeError(f"Cannot convert {type(value)} to AttrPattern") -class OpsetPattern: +class OpsetPatternBuilder(Pattern[str]): """Represents an opset pattern. - It is used primarily to create a NodePattern (via OpPattern). + (i) It is used to create a NodePattern (via OpPatternBuilder). Example usage: :: z = op.Matmul(x, y) - Here, `op` is an instance of OpsetPattern and `op.Matmul` is an instance - of OpPattern, and `op.Matmul(x, y)` is an instance of NodePattern. + Here, `op` is an instance of OpsetPatternBuilder and `op.Matmul` is an instance + of OpPatternBuilder, and `op.Matmul(x, y)` is an instance of NodePattern. - An opset pattern is also matched against the actual opset used in the + (ii) An opset pattern is also matched against the actual opset domain used in the input model. """ - def __init__(self, domain_pattern: PythonPattern | PrefixPattern | str) -> None: + def __init__(self, domain_pattern: Pattern[str] | str) -> None: if isinstance(domain_pattern, str): - domain_pattern = PythonPattern(domain_pattern) + domain_pattern = StringConstantPattern(domain_pattern) self.domain_pattern = domain_pattern @classmethod - def domain_prefix(cls, domain: str) -> OpsetPattern: + def domain_prefix(cls, domain: str) -> OpsetPatternBuilder: return cls(PrefixPattern(domain)) def matches(self, domain): return self.domain_pattern.matches(domain) - def __getattr__(self, name: str) -> Any: - return OpPattern(self, PythonPattern(name)) + def __getattr__(self, name: str) -> OpPatternBuilder: + return OpPatternBuilder(self, StringConstantPattern(name)) - def submodule(self, name: str) -> Any: + def submodule(self, name: str) -> OpPatternBuilder: """This method is used to match against submodule ops with prefix.""" - return OpPattern(self, PrefixPattern(name)) + return OpPatternBuilder(self, PrefixPattern(name)) -onnxop = OpsetPattern("") +onnxop = OpsetPatternBuilder("") -msft_op = OpsetPattern("com.microsoft") +msft_op = OpsetPatternBuilder("com.microsoft") -torch_module_op = OpsetPattern.domain_prefix("pkg.torch") +torch_module_op = OpsetPatternBuilder.domain_prefix("pkg.torch") -class OpPattern: +class OpPatternBuilder: """A utility class to build a NodePattern. It is used primarily to create a NodePattern. @@ -260,15 +162,15 @@ class OpPattern: z = op.Matmul(x, y) - Here, `op` is an instance of OpsetPattern and `op.Matmul` is an instance - of OpPattern, and `op.Matmul(x, y)` is an instance of NodePattern. + Here, `op` is an instance of OpsetPatternBuilder and `op.Matmul` is an instance + of OpPatternBuilder, and `op.Matmul(x, y)` is an instance of NodePattern. """ def __init__( self, - opset_pattern: OpsetPattern, - op_name_pattern: PythonPattern | PrefixPattern, + opset_pattern: Pattern[str], + op_name_pattern: Pattern[str], ) -> None: self.opset_pattern = opset_pattern self.op_name_pattern = op_name_pattern @@ -280,10 +182,11 @@ def __call__(self, *args, **kwargs): del kwargs["_num_outputs"] else: num_outputs = 1 - attributes = { - name: AttrPattern(value=value, name=name) for (name, value) in kwargs.items() - } - node_pattern = NodePattern(self.opset_pattern, self.op_name_pattern, args, attributes) + inputs = [_to_value_pattern(x) for x in args] + attributes = {name: _to_attr_pattern(value) for (name, value) in kwargs.items()} + node_pattern = NodePattern( + self.opset_pattern, self.op_name_pattern, inputs, attributes + ) if num_outputs == 1: return NodeOutputPattern(node_pattern, 0) else: @@ -292,7 +195,7 @@ def __call__(self, *args, **kwargs): def _to_value_pattern( x: ValuePattern | int | float | None, -) -> NodeOutputPattern | Constant | ValuePattern | None: +) -> ValuePattern | None: """Promotes an input-value used to construct a NodePattern to a ValuePattern. Example usage: @@ -310,8 +213,13 @@ def _to_value_pattern( """ if x is None or isinstance(x, ValuePattern): return x - if isinstance(x, (int, float, Sequence)): + if isinstance(x, (int, float)): return Constant(x) + # TODO(rama): support lists of int/float + # if isinstance(x, list): + # if all(isinstance(i, (int, float)) for i in x): + # return Constant(x) + # raise ValueError("Only lists of int/float can be used as a ValuePattern") # TODO(titaiwang): Could this be wrapped Constant? raise TypeError(f"Cannot convert {type(x)} to ValuePattern") @@ -356,8 +264,19 @@ def FAIL(cls): def nodes(self) -> MutableSequence[ir.Node]: return self.matched_nodes - def bind(self, var: str, value: Any): + def bind(self, var: str, value: Any) -> bool: + """Binds a pattern variable name to a value from the matched IR. + + Returns True if the binding is successful, False otherwise (when the binding is inconsistent). + """ + if var in self.bindings: + # TODO(rama): Use appropriate equality-check here. + if self.bindings[var] == value: + return True + self.success = False + return False self.bindings[var] = value + return True def extend(self, other: MatchResult | bool): if not self.success: @@ -392,7 +311,7 @@ def __init__(self, name: str | None) -> None: def __repr__(self) -> str: return f"ValuePattern({self.name!r})" - def matches(self, value: ir.Value, model: ir.Model): + def matches(self, value: ir.Value): result = MatchResult(success=True) if self.name is not None: result.bind(self.name, value) @@ -445,8 +364,8 @@ class NodePattern: def __init__( self, - domain: OpsetPattern, - op: PythonPattern | PrefixPattern, + domain: Pattern[str], + op: Pattern[str], inputs: Sequence[int | float | ValuePattern | None], attributes: dict[str, AttrPattern], ): @@ -454,9 +373,8 @@ def __init__( self.op = op self.inputs = [_to_value_pattern(x) for x in inputs] self.attributes = attributes - self.bound_value = None - def matches_node(self, node: ir.Node, model: ir.Model) -> MatchResult: + def matches_node(self, node: ir.Node) -> MatchResult: """Examine if the IR node matches the self pattern.""" if not self.domain.matches(node.domain): return MatchResult.FAIL() @@ -472,7 +390,7 @@ def matches_node(self, node: ir.Node, model: ir.Model) -> MatchResult: continue if arg_value is None or previous_node_output_pattern is None: return MatchResult.FAIL() - sub_match = previous_node_output_pattern.matches(arg_value, model) # type: ignore[attr-defined] + sub_match = previous_node_output_pattern.matches(arg_value) match.extend(sub_match) if not match: # If sub-match failed, return match @@ -481,15 +399,15 @@ def matches_node(self, node: ir.Node, model: ir.Model) -> MatchResult: attr_value = node.attributes.get(name) if attr_value is None: return MatchResult.FAIL() - sub_match = attr_pattern.matches(attr_value, model) # type: ignore[arg-type] - if not sub_match: + if not attr_pattern.matches(attr_value): return MatchResult.FAIL() - match.extend(sub_match) + if attr_pattern.name is not None: + if not match.bind(attr_pattern.name, attr_value): + return match for name in node.attributes: # TODO: Support matching default nodes for attributes. if name not in self.attributes: return MatchResult.FAIL() - assert match.nodes is not None, "Matched nodes should not be None." match.nodes.append(node) return match @@ -528,16 +446,17 @@ def __init__( self.node_pattern = node_pattern self.output_index = output_index - def matches(self, value: ir.Value, model: ir.Model): + def matches(self, value: ir.Value): """Match the StaticValueInfo from IR with the `matches_node()` in node pattern.""" node = value.producer() if node is None: return MatchResult.FAIL() if value.index() != self.output_index: return MatchResult.FAIL() - return self.node_pattern.matches_node(node, model) + return self.node_pattern.matches_node(node) def commute(self) -> Sequence[ValuePattern]: + # TODO return [ NodeOutputPattern(pattern, self.output_index, self.name) for pattern in self.node_pattern.commute() @@ -568,7 +487,7 @@ def match_scalar(self, scalar_value): # used elsewhere. return MatchResult(success=status) - def matches(self, value: ir.Value, model: ir.Model): + def matches(self, value: ir.Value): value = _ir_utils.propagate_const_value(value) constant_value = _ir_utils.get_numpy_from_ir_value(value) if constant_value is None: @@ -584,49 +503,78 @@ def commute(self) -> list[ValuePattern]: return [self] -def _handle_pattern_return_value( - node_output_pattern: NodeOutputPattern | list[NodeOutputPattern], -) -> tuple[NodePattern, int]: - """This checks and cleans up the return value of a pattern-construction function. +class GraphPattern: + """Represents a pattern that can be matched against a subgraph.""" + + def __init__(self, outputs: Sequence[ValuePattern]) -> None: + self.outputs = outputs + if len(outputs) == 0: + raise ValueError("GraphPattern must have at least one output") + # Check if all outputs are produced by the same node. + output_node = None + for i, value_pattern in enumerate(outputs): + if not isinstance(value_pattern, ValuePattern): + raise TypeError( + f"Invalid type {type(value_pattern)} for graph pattern output." + ) + if not isinstance(value_pattern, NodeOutputPattern) or ( + value_pattern.output_index != i + ): + output_node = None + elif i == 0: + output_node = value_pattern.node_pattern + elif value_pattern.node_pattern is not output_node: + output_node = None + self._output_node = output_node + + @property + def num_outputs(self) -> int: + return len(self.outputs) + + def matches_node(self, node: ir.Node) -> MatchResult: + if self._output_node is None: + return MatchResult.FAIL() + return self._output_node.matches_node(node) + + def commute(self) -> Sequence[GraphPattern]: + if self._output_node is None: + raise NotImplementedError( + "Cannot commute a graph pattern with multiple output nodes." + ) + nodes = self._output_node.commute() + return [ + GraphPattern([NodeOutputPattern(n, i) for i in range(self.num_outputs)]) + for n in nodes + ] + + +def _to_graph_pattern(pattern_constructor: Callable) -> GraphPattern: + """Convert a pattern-construction function to a GraphPattern. A pattern-construction function will return values as below: :: - def pattern(x, shape1, shape2): + def pattern(x: Var, shape1: Var, shape2: Var): ... - return op.SomeOp(...) - However, `SomeOp` may represent an ONNX op that produces multiple outputs. - This function validates that the return values represent the outputs of - a single NodePattern. It returns the node_pattern and the number of outputs. - - This follows an important restriction of the pattern-matcher algorithm: it - only matches against subgraphs that end in a single terminal node. If we - permit two terminal nodes, then we would have to match against all possible - pairs of nodes in the graph, which produces an extra quadratic factor in the - complexity of the pattern-matching algorithm. In general, the complexity becomes - exponential in the number of terminal nodes. + return outputs + + We create a pattern graph by creating pattern-variables for each parameter of the function, + and calling the function. The returned values are normalized to a list of ValuePatterns, + which represent the outputs of the pattern graph. Args: - node_output_pattern: NodeOutputPattern | Sequence[NodeOutputPattern] + pattern_constructor: Callable Returns: - tuple[NodePattern, int]: The last node_pattern, num_outputs + GraphPattern: A representation of the pattern that can be matched against a subgraph. """ - if isinstance(node_output_pattern, NodeOutputPattern): - node_pattern = node_output_pattern.node_pattern - num_outputs = 1 - elif isinstance(node_output_pattern, Sequence): - node_pattern = node_output_pattern[0].node_pattern - num_outputs = len(node_output_pattern) - for i, p in enumerate(node_output_pattern): - assert isinstance(p, NodeOutputPattern) - if (p.node_pattern is not node_pattern) or (p.output_index != i): - raise NotImplementedError( - "Multi-output pattern not handled by this API. " - "Use other APIs to handle multi-output patterns." - ) - else: - raise TypeError(f"Invalid type {type(node_output_pattern)} for pattern") - return node_pattern, num_outputs + _pattern_vars = inspect.signature(pattern_constructor).parameters + vars = [Var(v) for v in _pattern_vars] + pattern_outputs = pattern_constructor(*vars) + # Returned value could be a single ValuePattern or a list of ValuePatterns. + # Normalize representation to a list of ValuePatterns. + if isinstance(pattern_outputs, ValuePattern): + pattern_outputs = [pattern_outputs] + return GraphPattern(pattern_outputs) def _valid_to_replace(matched_nodes: Sequence[ir.Node]) -> bool: @@ -647,25 +595,6 @@ def _valid_to_replace(matched_nodes: Sequence[ir.Node]) -> bool: return True -class TargetPatternFunction: - """The targeted pattern that will be replaced by the replacement pattern. - - Attributes: - function (Callable): The pattern function that will be matched against the IR. - """ - - def __init__(self, function: Callable) -> None: - self._function = function - - @property - def function(self) -> Callable: - return self._function - - def get_pattern(self, variables: Sequence[Var]) -> tuple[NodePattern, int]: - node_output_pattern = self._function(*variables) - return _handle_pattern_return_value(node_output_pattern) - - # A type representing the domains/versions used in creating a replacement subgraph UsedOpsets = List[Tuple[str, Optional[int]]] @@ -759,7 +688,7 @@ def _update_opset_imports( class RewriteRule: def __init__( self, - target_pattern: TargetPatternFunction | Callable | None = None, + target_pattern: GraphPattern | Callable | None = None, replacement_pattern: ReplacementPatternFunction | Callable | None = None, condition_function: Callable | None = None, ) -> None: @@ -784,34 +713,21 @@ def __init__( raise ValueError( "replacement_pattern must be provided if target_pattern is provided" ) - # TODO: Do we want to tolerate Callable inputs? - if callable(target_pattern): - target_pattern = TargetPatternFunction(target_pattern) - if callable(replacement_pattern): - replacement_pattern = ReplacementPatternFunction(replacement_pattern) + if not isinstance(target_pattern, GraphPattern): + target_pattern = _to_graph_pattern(target_pattern) self._target_pattern = target_pattern + + if not isinstance(replacement_pattern, ReplacementPatternFunction): + replacement_pattern = ReplacementPatternFunction(replacement_pattern) self._replacement_pattern = replacement_pattern self._condition_function = condition_function - _pattern_vars = inspect.signature(self._target_pattern.function).parameters - - self._vars = [Var(v) for v in _pattern_vars] - # Get the last node pattern and number of outputs from the pattern function - self._target_node_pattern, self._target_num_outputs = self._target_pattern.get_pattern( - self._vars # type: ignore[arg-type] - ) - def matches(self, node: ir.Node, model: ir.Model) -> MatchResult: """Check if the node from IR matches the pattern.""" - if len(node.outputs) != self._target_num_outputs: + if len(node.outputs) != self._target_pattern.num_outputs: return MatchResult.FAIL() - match = self._target_node_pattern.matches_node(node, model) - # NOTE: migrating to a simpler interface for match_condition signature. - # Ideally, the caller should pass in match_bindings as **match_bindings. - # This makes it easier to define this as a function with inputs like - # (input_a, input_b, **_) and omit all references to match_bindings. - # **_ refers to all the unused parameters in the match_condition function. + match = self._target_pattern.matches_node(node) if ( self._condition_function is not None and match @@ -831,10 +747,10 @@ def try_rewrite( replacement_subgraph = self._replacement_pattern.get_replacement(match) if replacement_subgraph is None: return None - if len(replacement_subgraph.new_outputs) != self._target_num_outputs: + if len(replacement_subgraph.new_outputs) != self._target_pattern.num_outputs: raise ValueError( f"Number of outputs from replacement function does not match the number of outputs from the target pattern. " - f"Expected {self._target_num_outputs}, but got {len(replacement_subgraph.new_outputs)}." + f"Expected {self._target_pattern.num_outputs}, but got {len(replacement_subgraph.new_outputs)}." ) # TODO(rama): Check/update opset-imports # (i) Following is required by multi-output matcher too; move this. @@ -856,13 +772,11 @@ def replace_pattern(new_pattern): """Return a shallow copy of self with node_pattern replaced by new_pattern.""" rule = RewriteRule() rule._condition_function = self._condition_function - rule._target_node_pattern = new_pattern - rule._target_num_outputs = self._target_num_outputs + rule._target_pattern = new_pattern rule._replacement_pattern = self._replacement_pattern - rule._vars = self._vars return rule - return [replace_pattern(p) for p in self._target_node_pattern.commute()] + return [replace_pattern(p) for p in self._target_pattern.commute()] def _apply_delta(