diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py
index 180ac1717..d144502ed 100644
--- a/onnxscript/rewriter/pattern.py
+++ b/onnxscript/rewriter/pattern.py
@@ -4,7 +4,7 @@
 import inspect
 import itertools
 import math
-from typing import Any, Callable, List, Optional, Sequence, Tuple
+from typing import Any, Callable, List, MutableSequence, Optional, Sequence, Tuple

 import numpy as np
 import onnx
@@ -320,7 +320,7 @@ class MatchResult:
     """Represents the result of a match operation.

     A match can either succeed or fail.
-    If it succeeds, it returns a list of IR values that matched the pattern
+    If it succeeds, it returns a list of nodes that matched the pattern
     and a set of bindings for the variables in the pattern.

     Example:
@@ -330,46 +330,40 @@ def pattern(x, shape1, shape2):
         t2 = op.Reshape(t1, shape2)
         return t2
     The above pattern matches a sequence of two Reshape ops.
-    The matched_values will contain the values representing the (output of)
-    the two Reshape ops, and the bindings will contain the values that
-    are bound to the variables `x`, `shape1`, and `shape2`.
+    The matched_nodes will contain the two Reshape ops, and the bindings will
+    contain the values that are bound to the variables `x`, `shape1`, and `shape2`.
     """

-    def __init__(
-        self, matched_values=None, bindings: dict[str, ir.Value | Any] | None = None
-    ) -> None:
-        assert matched_values is None or isinstance(matched_values, list)
-        self.success: bool = matched_values is not None
-        # For a successful match, matched_values is a list of values that matched the pattern.
+    def __init__(self, success: bool) -> None:
+        self.success: bool = success
+        # For a successful match, matched_nodes is a list of nodes that matched the pattern.
         # These include the internal nodes of the pattern that were matched, but not
         # the leaves (sub-trees) that match against the variables in the pattern.
         # These represent the values that will be replaced by the replacement pattern.
-        self.matched_values: Sequence[Any] | None = matched_values
+        self.matched_nodes: MutableSequence[ir.Node] = []
-        # For a successful match, bindings is a dictionary of mapping pattern-variable-names
+        # For a successful match, bindings is a dictionary mapping pattern-variable-names
         # to values.
-        self.bindings: dict[str, Any] = bindings if bindings is not None else {}
+        self.bindings: dict[str, Any] = {}

     def __bool__(self):
         return self.success

     @classmethod
     def FAIL(cls):
-        return cls(None)
+        return cls(False)

     @property
-    def values(self) -> Sequence[Any] | None:
-        return self.matched_values
+    def nodes(self) -> MutableSequence[ir.Node]:
+        return self.matched_nodes

-    def fail(self):
-        self.success = False
-        self.matched_values = None
-        self.bindings = {}
+    def bind(self, var: str, value: Any):
+        self.bindings[var] = value

     def extend(self, other: MatchResult | bool):
         if not self.success:
             return
         if not other:
-            self.fail()
+            self.success = False
             return
         if isinstance(other, bool):
             return
@@ -377,12 +371,12 @@ def extend(self, other: MatchResult | bool):
             if var in self.bindings:
                 # TODO: handle attribute var bindings
                 if self.bindings[var] != val:
-                    self.fail()
+                    self.success = False
                     return
             else:
                 self.bindings[var] = val
-        assert self.matched_values is not None, "matched_values should not be None."
-        self.matched_values.extend(other.matched_values)  # type: ignore[attr-defined]
+        assert self.matched_nodes is not None, "matched_nodes should not be None."
+        self.matched_nodes.extend(other.matched_nodes)  # type: ignore[attr-defined]


 class ValuePattern:
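Note: MatchResult is now constructed empty and populated incrementally through bind() and extend(), rather than being built from a pre-computed list. A minimal sketch of that protocol, assuming the post-patch module path onnxscript.rewriter.pattern; the None and 42 bindings are stand-ins for real ir.Value objects:

    from onnxscript.rewriter.pattern import MatchResult

    m = MatchResult(success=True)
    m.bind("x", None)             # stand-in for an ir.Value
    sub = MatchResult(success=True)
    sub.bind("x", None)           # same value bound again: no conflict
    m.extend(sub)
    assert bool(m) and "x" in m.bindings

    conflicting = MatchResult(success=True)
    conflicting.bind("x", 42)     # a different value for the same variable
    m.extend(conflicting)
    assert not m                  # extend() flips success on a binding conflict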
@@ -399,9 +393,10 @@ def __repr__(self) -> str:
         return f"ValuePattern({self.name!r})"

     def matches(self, value: ir.Value, model: ir.Model):
-        if self.name is None:
-            return MatchResult([], {})
-        return MatchResult([], {self.name: value})
+        result = MatchResult(success=True)
+        if self.name is not None:
+            result.bind(self.name, value)
+        return result

     def commute(self) -> Sequence[ValuePattern]:
         """Return a list of commuted patterns.
@@ -467,7 +462,7 @@ def matches_node(self, node: ir.Node, model: ir.Model) -> MatchResult:
             return MatchResult.FAIL()
         if not self.op.matches(node.op_type):
             return MatchResult.FAIL()
-        match = MatchResult([])
+        match = MatchResult(success=True)
         # TODO: We should add filtered logging starting from here to emit why
-        # matching failed. This should cut a lot of noises compared to logging everything,
+        # matching failed. This should cut a lot of noise compared to logging everything,
         # because at least the starting node op_type is already matched.
@@ -491,11 +486,11 @@ def matches_node(self, node: ir.Node, model: ir.Model) -> MatchResult:
                 return MatchResult.FAIL()
             match.extend(sub_match)
         for name in node.attributes:
             # TODO: Support matching default values for attributes.
             if name not in self.attributes:
                 return MatchResult.FAIL()
-        assert match.values is not None, "Matched values should not be None."
-        match.values.append(node)  # type: ignore[attr-defined]
+        assert match.nodes is not None, "Matched nodes should not be None."
+        match.nodes.append(node)
         return match

     def commute(self) -> Sequence[NodePattern]:
@@ -563,10 +558,15 @@ def __init__(
         self.rel_tol = rel_tol
         self.abs_tol = abs_tol

-    def match_scalar(self, scalar_value, return_value: Sequence[ir.Node]):
-        if math.isclose(scalar_value, self.value, rel_tol=self.rel_tol, abs_tol=self.abs_tol):
-            return MatchResult(return_value)
-        return MatchResult.FAIL()
+    def match_scalar(self, scalar_value):
+        status = math.isclose(
+            scalar_value, self.value, rel_tol=self.rel_tol, abs_tol=self.abs_tol
+        )
+        # Note: If the value is produced by a Constant node, we could include
+        # the Constant node in the matched nodes. However, we don't do that.
+        # Instead, we rely on DCE to remove the Constant node if it is not
+        # used elsewhere.
+        return MatchResult(success=status)

     def matches(self, value: ir.Value, model: ir.Model):
         value = _ir_utils.propagate_const_value(value)
@@ -578,13 +578,7 @@ def matches(self, value: ir.Value, model: ir.Model):
         if constant_value.size != 1:
             return MatchResult.FAIL()

-        return_value: list[ir.Node] = []
-        # Note: If the value is produced by a Constant node, we could include
-        # the Constant node in the return_value list. However, we don't do that.
-        # Instead, we will rely on DCE to remove the constant node if it is not
-        # used elsewhere.
-
-        return self.match_scalar(constant_value.item(), return_value)
+        return self.match_scalar(constant_value.item())

     def commute(self) -> list[ValuePattern]:
         return [self]
@@ -635,7 +629,7 @@ def pattern(x, shape1, shape2):
     return node_pattern, num_outputs


-def _valid_to_replace(matched_nodes: Sequence[Any]) -> bool:
+def _valid_to_replace(matched_nodes: Sequence[ir.Node]) -> bool:
     """Check that values computed by the matched_nodes, except for the last one, are used only by the matched_nodes."""
     # * Must check that all values matched by pattern are used only by pattern,
     # except for the value that is replaced.
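Note: match_scalar (above) now reduces to a single tolerance test whose outcome becomes the success flag of the returned MatchResult. A self-contained sketch of the comparison it delegates to math.isclose; the rel_tol and abs_tol defaults below are illustrative, the real ones come from the Constant pattern's constructor:

    import math

    def scalar_matches(actual: float, expected: float,
                       rel_tol: float = 1e-9, abs_tol: float = 0.0) -> bool:
        # Mirrors the check performed in match_scalar.
        return math.isclose(actual, expected, rel_tol=rel_tol, abs_tol=abs_tol)

    assert scalar_matches(1.0 + 1e-12, 1.0)   # within relative tolerance
    assert not scalar_matches(1.1, 1.0)       # outside tolerance: match fails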
@@ -721,7 +715,8 @@ def used_opsets(self) -> UsedOpsets:

 class ReplacementSubgraph:
     """A subgraph that will replace the matched pattern."""

-    new_values: Sequence[ir.Value]
+    match: MatchResult
+    new_outputs: Sequence[ir.Value]
     new_nodes: Sequence[ir.Node]
     used_opsets: UsedOpsets
@@ -736,15 +731,14 @@ class ReplacementPatternFunction:

     def __init__(self, function) -> None:
         self._function = function

-    def get_replacement(
-        self,
-        match_bindings: dict[str, ir.Value | Any] | None = None,
-    ) -> ReplacementSubgraph:
+    def get_replacement(self, match: MatchResult) -> ReplacementSubgraph | None:
         context = RewriterContext()
-        new_values = self._function(context, **match_bindings)
-        if not isinstance(new_values, Sequence):
-            new_values = [new_values]
-        return ReplacementSubgraph(new_values, context.nodes, context.used_opsets)
+        new_outputs = self._function(context, **match.bindings)
+        if new_outputs is None:
+            return None  # Failed to create replacement subgraph
+        if not isinstance(new_outputs, Sequence):
+            new_outputs = [new_outputs]
+        return ReplacementSubgraph(match, new_outputs, context.nodes, context.used_opsets)


 def _update_opset_imports(
@@ -823,29 +817,26 @@ def matches(self, node: ir.Node, model: ir.Model) -> MatchResult:

     def try_rewrite(
         self, model: ir.Model, graph_or_function: ir.Graph | ir.Function, node: ir.Node
-    ):  # TODO(rama) -> ReplacementSubgraph | None:
+    ) -> ReplacementSubgraph | None:
         """If the node matches the pattern, then replace the node with the replacement pattern."""
         match = self.matches(node, model)
         if match:
-            assert match.values is not None, "Matched values should not be None."
-            if _valid_to_replace(match.values):
-                # bindings will be consumed by the replacement function
-                delta = self._replacement_pattern.get_replacement(match.bindings)
-                if len(delta.new_values) != self._target_num_outputs:
+            assert match.nodes is not None, "Matched nodes should not be None."
+            if _valid_to_replace(match.nodes):
+                replacement_subgraph = self._replacement_pattern.get_replacement(match)
+                if replacement_subgraph is None:
+                    return None
+                if len(replacement_subgraph.new_outputs) != self._target_num_outputs:
                     raise ValueError(
                         f"Number of outputs from replacement function does not match the number of outputs from the target pattern. "
-                        f"Expected {self._target_num_outputs}, but got {len(delta.new_values)}."
+                        f"Expected {self._target_num_outputs}, but got {len(replacement_subgraph.new_outputs)}."
                     )
                 # TODO(rama): Check/update opset-imports
-                # (i) Integrate following with the multi-output matcher and code elsewhere:
+                # (i) The following is also required by the multi-output matcher; move it there.
                 # (ii) Remove the opset imports from deleted nodes?
-                # (iii) Code in the caller (below) checks if match overlaps previous match, which
-                # appears incorrect for single-pattern matcher. Best to alter iteration to apply
-                # each rewrite immediately, instead of accumulating them.
-                # (iv) return delta here
-                _update_opset_imports(graph_or_function, delta)
-                _update_opset_imports(model.graph, delta)
-                return match.values, delta.new_nodes
+                _update_opset_imports(graph_or_function, replacement_subgraph)
+                _update_opset_imports(model.graph, replacement_subgraph)
+                return replacement_subgraph
         return None

     def apply_to_model(self, model: ir.Model, *, commute: bool = False):
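Note: get_replacement introduces a None-means-abort contract: the user's replacement function may return None, and try_rewrite then skips the candidate match instead of raising. A hypothetical replacement callback illustrating the contract; it receives the RewriterContext positionally and the pattern bindings as keyword arguments, matching the self._function(context, **match.bindings) call above (the rule and its guard are illustrative, not part of this change):

    def replace_reshape_pair(op, x, shape1, shape2):
        # Returning None makes get_replacement (and hence try_rewrite) give up
        # on this candidate match instead of raising an error.
        if shape2 is None:
            return None
        # Otherwise build the replacement: a single Reshape to the final
        # shape. shape1 is bound by the pattern but unused here.
        return op.Reshape(x, shape2)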
@@ -869,12 +860,13 @@ def replace_pattern(new_pattern):
         return [replace_pattern(p) for p in self._target_node_pattern.commute()]


-def _apply_deltas(
+def _apply_delta(
     graph_or_function: ir.Graph | ir.Function,
+    node: ir.Node,
-    # TODO(jutinchuby): Use a more descriptive data structure to store deltas
-    deltas: Sequence[tuple[int, tuple[Sequence[ir.Node], Sequence[ir.Node]]]],
+    # TODO(justinchuby): Use a more descriptive data structure to store deltas
+    delta,
 ):
-    """Applies deltas.
+    """Applies a delta.

-    This code is valid is the considered pattern has only one output.
-    In case of multi output replacements, there is not need to rename
+    This code is valid if the considered pattern has only one output.
+    In the case of multi-output replacements, there is no need to rename
@@ -891,56 +883,16 @@ def _apply_deltas(
     We could reorder (long) or do more clever changes.
-    The reordering would probably happen not very often.
+    The reordering would probably not happen very often.
     """
-    existing_ids = {id(n): (i, n) for i, n in enumerate(graph_or_function)}
-    to_delete: set[ir.Node] = set()
-    to_insert: list[tuple[ir.Node, list[ir.Node]]] = []
-
-    for i, delta in reversed(deltas):
-        if len(delta) == 3:
-            # multi-outut strategy
-            n_matches, deleted_nodes, inserted_nodes = delta
-            for d in deleted_nodes:
-                assert id(d) in existing_ids
-                to_delete.add(d)
-
-            # the position to insert must be chosen.
-            # we'll try position i
-            assert i not in to_insert  # conflicts should avoid that case
-            to_insert.append((graph_or_function[i], inserted_nodes))
-        else:
-            deleted_nodes, inserted_nodes = delta
-            # Replace deleted nodes with inserted nodes.
-            # TODO: simplify this
-            last_deleted = deleted_nodes[-1]
-            last_inserted = inserted_nodes[-1]
-
-            for old_value, new_value in zip(last_deleted.outputs, last_inserted.outputs):
-                # Propagate relevant info from old value to new value
-                # TODO(Rama): Perhaps we should merge old and new types. As of now, new
-                # values don't have type information. Note that this could be a problem
-                # for semantics-altering rewrite-rules: we should allow users to override
-                # this for such rules.
-                new_value.type = old_value.type
-                new_value.shape = old_value.shape
-                new_value.const_value = old_value.const_value
-                new_value.name = old_value.name
-
-            # Reconnect the users of the deleted node to use the new outputs
-            _convenience.replace_all_uses_with(last_deleted.outputs, last_inserted.outputs)
-            # Update graph/function outputs if the node generates output
-            replacement_mapping = dict(zip(last_deleted.outputs, last_inserted.outputs))
-            for idx, graph_or_function_output in enumerate(graph_or_function.outputs):
-                if graph_or_function_output in replacement_mapping:
-                    graph_or_function.outputs[idx] = replacement_mapping[
-                        graph_or_function_output
-                    ]
-
-            # insert new nodes after the index node
-            graph_or_function.insert_after(last_deleted, inserted_nodes)
-            graph_or_function.remove(deleted_nodes, safe=True)
-
-    for replaced_node, inserted_nodes in to_insert:
-        graph_or_function.insert_after(replaced_node, inserted_nodes)
+
+    if isinstance(delta, tuple):
+        # multi-output strategy
+        n_matches, matched_nodes, inserted_nodes = delta
+
+        # TODO(rama): Was "assert i not in to_insert"; seems wrong.
+        # What is this trying to check? Best-effort correction below.
+        assert node not in inserted_nodes  # conflicts should avoid that case
+
+        graph_or_function.insert_after(node, inserted_nodes)
         # TODO: improve this
         # This is updating the graph/function outputs to use the new outputs
         for inserted_node in inserted_nodes:
@@ -948,7 +900,36 @@ def _apply_deltas(
                 if (index := new_output.meta.get(_ir_utils.GRAPH_OUTPUT_META_KEY)) is not None:  # type: ignore[assignment]
                     graph_or_function.outputs[index] = new_output

-    graph_or_function.remove(to_delete, safe=True)
+        for d in matched_nodes:
+            assert d in graph_or_function
+        graph_or_function.remove(matched_nodes, safe=True)
+    else:
+        assert isinstance(delta, ReplacementSubgraph)
+        # Replace matched nodes with new nodes.
+        last_inserted = delta.new_nodes[-1]
+
+        for old_value, new_value in zip(node.outputs, last_inserted.outputs):
+            # Propagate relevant info from old value to new value
+            # TODO(Rama): Perhaps we should merge old and new types. As of now, new
+            # values don't have type information. Note that this could be a problem
+            # for semantics-altering rewrite-rules: we should allow users to override
+            # this for such rules.
+            new_value.type = old_value.type
+            new_value.shape = old_value.shape
+            new_value.const_value = old_value.const_value
+            new_value.name = old_value.name
+
+        # Reconnect the users of the matched node to use the new outputs
+        _convenience.replace_all_uses_with(node.outputs, last_inserted.outputs)
+        # Update graph/function outputs if the matched node produces a graph output
+        replacement_mapping = dict(zip(node.outputs, last_inserted.outputs))
+        for idx, graph_or_function_output in enumerate(graph_or_function.outputs):
+            if graph_or_function_output in replacement_mapping:
+                graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output]
+
+        # insert new nodes after the matched node
+        graph_or_function.insert_after(node, delta.new_nodes)
+        graph_or_function.remove(delta.match.nodes, safe=True)


 class RewriteRuleSet:
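Note: _apply_delta now applies one rewrite at a time and dispatches on the shape of the delta: a plain tuple from the multi-output matcher versus a ReplacementSubgraph from the single-output path. A self-contained sketch of that dispatch, using a reduced stand-in for ReplacementSubgraph:

    from dataclasses import dataclass
    from typing import Any, Sequence

    @dataclass
    class _FakeReplacement:
        # Stand-in for ReplacementSubgraph, reduced to the field used here.
        new_nodes: Sequence[Any]

    def describe_delta(delta) -> str:
        # Mirrors _apply_delta's branch: a tuple comes from the multi-output
        # matcher; anything else is treated as a ReplacementSubgraph.
        if isinstance(delta, tuple):
            _n_matches, matched_nodes, inserted_nodes = delta
            return f"multi-output: remove {len(matched_nodes)}, insert {len(inserted_nodes)}"
        return f"single-output: insert {len(delta.new_nodes)}"

    print(describe_delta((1, ["n1", "n2"], ["m1"])))           # multi-output: remove 2, insert 1
    print(describe_delta(_FakeReplacement(new_nodes=["m1"])))  # single-output: insert 1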
@@ -963,36 +944,17 @@ def _apply_to_graph_or_function(
         model: ir.Model,
         graph_or_function: ir.Graph | ir.Function,
     ) -> int:
         count = 0
-        marked = set()
+
         # NOTE: Rules should be prioritized in the order they are added to the RewriteRuleSet.
-        # And the graph is applied in order.
+        # And the rules are applied to the graph in node order.
         for rule in self.rules:
-            deltas = []
-            for i, node in enumerate(graph_or_function):
+            for node in graph_or_function:
                 delta = rule.try_rewrite(model, graph_or_function, node)
                 if delta is None:
                     continue
-
-                matched_nodes, _ = delta[-2:]
-
-                conflict = False
-                for n in matched_nodes:
-                    if id(n) in marked:
-                        # The same node cannot be matched twice with different patterns.
-                        conflict = True
-                        break
-
-                if conflict:
-                    # Some nodes are already marked as rewritten.
-                    continue
-
-                marked |= set(map(id, matched_nodes))
-
-                deltas.append((i, delta))
+                _apply_delta(graph_or_function, node, delta)
                 count += 1
-            _apply_deltas(graph_or_function, deltas)
         return count

     def apply_to_model(self, model: ir.Model) -> int:
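Note: because each delta is applied as soon as it is found, later iterations see the already-rewritten graph, which is what makes the old marked-set conflict detection unnecessary. A minimal sketch of the shape of the new driver loop, with stand-in callables rather than the real RewriteRuleSet internals:

    def apply_rules(rules, nodes, try_rewrite, apply_delta) -> int:
        count = 0
        for rule in rules:
            for node in nodes:
                delta = try_rewrite(rule, node)
                if delta is None:
                    continue                  # no match at this node
                apply_delta(node, delta)      # mutate the graph right away
                count += 1                    # later nodes see the rewritten graph
        return count

This realizes item (iii) of the removed TODO in try_rewrite: rewrites are applied immediately instead of being accumulated and reconciled afterwards.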