diff --git a/onnxscript/ir/_convenience.py b/onnxscript/ir/_convenience.py index 86d2f88c3..166e7581b 100644 --- a/onnxscript/ir/_convenience.py +++ b/onnxscript/ir/_convenience.py @@ -395,3 +395,45 @@ def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]: continue values[value.name] = value return values + + +def replace_nodes_and_values( + graph_or_function: _core.Graph | _core.Function, + /, + insertion_point: _core.Node, + old_nodes: Sequence[_core.Node], + new_nodes: Sequence[_core.Node], + old_values: Sequence[_core.Value], + new_values: Sequence[_core.Value], +) -> None: + """Replaces nodes and values in the graph or function. + + Args: + graph_or_function: The graph or function to replace nodes and values in. + insertion_point: The node to insert the new nodes after. + old_nodes: The nodes to replace. + new_nodes: The nodes to replace with. + old_values: The values to replace. + new_values: The values to replace with. + """ + + for old_value, new_value in zip(old_values, new_values): + # Propagate relevant info from old value to new value + # TODO(Rama): Perhaps this should be a separate utility function. Also, consider + # merging old and new type/shape info. + new_value.type = old_value.type + new_value.shape = old_value.shape + new_value.const_value = old_value.const_value + new_value.name = old_value.name + + # Reconnect the users of the deleted values to use the new values + replace_all_uses_with(old_values, new_values) + # Update graph/function outputs if the node generates output + replacement_mapping = dict(zip(old_values, new_values)) + for idx, graph_or_function_output in enumerate(graph_or_function.outputs): + if graph_or_function_output in replacement_mapping: + graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output] + + # insert new nodes after the index node + graph_or_function.insert_after(insertion_point, new_nodes) + graph_or_function.remove(old_nodes, safe=True) diff --git a/onnxscript/ir/convenience.py b/onnxscript/ir/convenience.py index 03140f16a..fc8416cc1 100644 --- a/onnxscript/ir/convenience.py +++ b/onnxscript/ir/convenience.py @@ -8,12 +8,14 @@ "convert_attribute", "convert_attributes", "replace_all_uses_with", + "replace_nodes_and_values", ] from onnxscript.ir._convenience import ( convert_attribute, convert_attributes, replace_all_uses_with, + replace_nodes_and_values, ) # NOTE: Do not implement any other functions in this module. diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 6140b06f7..9f4899e0e 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -18,6 +18,7 @@ import onnxscript.ir._convenience as _convenience import onnxscript.optimizer.constant_folding as constant_folding import onnxscript.rewriter.pattern as orp +import onnxscript.utils.utils as utils def is_control_flow_op(node: ir.Node) -> bool: @@ -27,14 +28,13 @@ def is_control_flow_op(node: ir.Node) -> bool: def is_non_deterministic_op(node: ir.Node) -> bool: - return ( - node.op_type in constant_folding.non_deterministic_ops - and constant_folding.is_onnx_domain(node.domain) + return node.op_type in constant_folding.non_deterministic_ops and utils.is_onnx_domain( + node.domain ) def is_constant_op(node: ir.Node) -> bool: - return node.op_type in {"Constant", "ConstantOfShape"} and constant_folding.is_onnx_domain( + return node.op_type in {"Constant", "ConstantOfShape"} and utils.is_onnx_domain( node.domain ) @@ -648,32 +648,11 @@ def convert(av): def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function): logger.debug("Replacing node: %s::%s %s", node.domain, node.op_type, node.name) - # TODO: what about new opset_imports? - old_values = node.outputs - new_values = replacement.new_outputs - for old_value, new_value in zip(old_values, new_values): - # Propagate relevant info from old value to new value - # TODO(Rama): Perhaps we should merge old and new types. As of now, new - # values don't have type information. Note that this could be a problem - # for semantics-altering rewrite-rules: we should allow users to override - # this for such rules. - new_value.type = old_value.type - new_value.shape = old_value.shape - new_value.const_value = old_value.const_value - new_value.name = old_value.name - - # Reconnect the users of the deleted node to use the new outputs - _convenience.replace_all_uses_with(old_values, new_values) - # Update graph/function outputs if the node generates output - replacement_mapping = dict(zip(old_values, new_values)) - for idx, graph_or_function_output in enumerate(root.outputs): - if graph_or_function_output in replacement_mapping: - root.outputs[idx] = replacement_mapping[graph_or_function_output] - - # insert new nodes after the index node - root.insert_after(node, replacement.new_nodes) - root.remove(node, safe=True) + _convenience.replace_nodes_and_values( + root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs + ) + # TODO: what about new opset_imports? # TODO: track statistics about replaced nodes and sizes of new constants def visit_attribute(self, attr: ir.Attr | ir.RefAttr) -> None: @@ -698,12 +677,17 @@ def visit_graph(self, graph: ir.Graph) -> None: for node in graph: self.visit_node(node, graph) + def visit_function(self, function: ir.Function) -> None: + for node in function: + self.visit_node(node, function) + def visit_model(self, model: ir.Model) -> None: self._init() self.opset_imports = model.opset_imports self.visit_graph(model.graph) - # TODO(rama): handle functions - # Pending decision on whether we want to specialize functions or not. + for function in model.functions.values(): + # TODO(rama): Should we specialize functions? + self.visit_function(function) def fold_constants( diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 164b92f1e..04c1ffd13 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -1238,58 +1238,6 @@ def rewrite(cls, op, x: ir.Value, perm: ir.Attr | None = None): ) -def _apply_delta( - graph_or_function: ir.Graph | ir.Function, - node: ir.Node, - delta: ReplacementSubgraph, -): - """Applies delta. - - This code is valid is the considered pattern has only one output. - In case of multi output replacements, there is not need to rename - the outputs. - - In case of multi-output design, the nodes may not be necessary inserted - all at the same position. To be convinced, you can take a pattern - producing two outputs, but the second one needs the first one and - another input appeared after the first outputs. What could be - the right place to inserted all of the node. - - The current implementation insert all the nodes at the same position - but checks there is not inconsistency. In that case, it fails. - We could reorder (long) or do more clever changes. - The reordering would probably happen not very often. - """ - - assert isinstance(delta, ReplacementSubgraph) - # Replace matched nodes with new nodes, matched values with new values - old_values = delta.match.outputs - new_values = delta.new_outputs - - for old_value, new_value in zip(old_values, new_values): - # Propagate relevant info from old value to new value - # TODO(Rama): Perhaps we should merge old and new types. As of now, new - # values don't have type information. Note that this could be a problem - # for semantics-altering rewrite-rules: we should allow users to override - # this for such rules. - new_value.type = old_value.type - new_value.shape = old_value.shape - new_value.const_value = old_value.const_value - new_value.name = old_value.name - - # Reconnect the users of the deleted node to use the new outputs - _convenience.replace_all_uses_with(old_values, new_values) - # Update graph/function outputs if the node generates output - replacement_mapping = dict(zip(old_values, new_values)) - for idx, graph_or_function_output in enumerate(graph_or_function.outputs): - if graph_or_function_output in replacement_mapping: - graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output] - - # insert new nodes after the index node - graph_or_function.insert_after(node, delta.new_nodes) - graph_or_function.remove(delta.match.nodes, safe=True) - - class RewriteRuleSet: def __init__(self, rules: Sequence[RewriteRule], *, commute: bool = False) -> None: if commute: @@ -1311,7 +1259,19 @@ def _apply_to_graph_or_function( delta = rule.try_rewrite(model, graph_or_function, node, verbose=verbose) if delta is None: continue - _apply_delta(graph_or_function, node, delta) + assert isinstance(delta, ReplacementSubgraph) + # TODO: This does not yet handle the problem of determining the correct insertion point + # for inserted nodes in the case of patterns with multiple output-nodes. The following + # is sufficient for patterns with a single output-node "node", which can serve as the + # insertion-point. + _convenience.replace_nodes_and_values( + graph_or_function, + node, + delta.match.nodes, + delta.new_nodes, + delta.match.outputs, + delta.new_outputs, + ) count += 1 return count