Skip to content

Commit

Permalink
Cleanup new-IR based constant propagation (#1739)
Browse files Browse the repository at this point in the history
Factor out some common logic between rewriter and constant-propagation
into a utility function, and other minor cleanup.
  • Loading branch information
gramalingam authored Jul 24, 2024
1 parent c37e98b commit 712aa87
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 84 deletions.
42 changes: 42 additions & 0 deletions onnxscript/ir/_convenience.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,3 +395,45 @@ def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
continue
values[value.name] = value
return values


def replace_nodes_and_values(
graph_or_function: _core.Graph | _core.Function,
/,
insertion_point: _core.Node,
old_nodes: Sequence[_core.Node],
new_nodes: Sequence[_core.Node],
old_values: Sequence[_core.Value],
new_values: Sequence[_core.Value],
) -> None:
"""Replaces nodes and values in the graph or function.
Args:
graph_or_function: The graph or function to replace nodes and values in.
insertion_point: The node to insert the new nodes after.
old_nodes: The nodes to replace.
new_nodes: The nodes to replace with.
old_values: The values to replace.
new_values: The values to replace with.
"""

for old_value, new_value in zip(old_values, new_values):
# Propagate relevant info from old value to new value
# TODO(Rama): Perhaps this should be a separate utility function. Also, consider
# merging old and new type/shape info.
new_value.type = old_value.type
new_value.shape = old_value.shape
new_value.const_value = old_value.const_value
new_value.name = old_value.name

# Reconnect the users of the deleted values to use the new values
replace_all_uses_with(old_values, new_values)
# Update graph/function outputs if the node generates output
replacement_mapping = dict(zip(old_values, new_values))
for idx, graph_or_function_output in enumerate(graph_or_function.outputs):
if graph_or_function_output in replacement_mapping:
graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output]

# insert new nodes after the index node
graph_or_function.insert_after(insertion_point, new_nodes)
graph_or_function.remove(old_nodes, safe=True)
2 changes: 2 additions & 0 deletions onnxscript/ir/convenience.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
"convert_attribute",
"convert_attributes",
"replace_all_uses_with",
"replace_nodes_and_values",
]

from onnxscript.ir._convenience import (
convert_attribute,
convert_attributes,
replace_all_uses_with,
replace_nodes_and_values,
)

# NOTE: Do not implement any other functions in this module.
Expand Down
46 changes: 15 additions & 31 deletions onnxscript/optimizer/_constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import onnxscript.ir._convenience as _convenience
import onnxscript.optimizer.constant_folding as constant_folding
import onnxscript.rewriter.pattern as orp
import onnxscript.utils.utils as utils


def is_control_flow_op(node: ir.Node) -> bool:
Expand All @@ -27,14 +28,13 @@ def is_control_flow_op(node: ir.Node) -> bool:


def is_non_deterministic_op(node: ir.Node) -> bool:
return (
node.op_type in constant_folding.non_deterministic_ops
and constant_folding.is_onnx_domain(node.domain)
return node.op_type in constant_folding.non_deterministic_ops and utils.is_onnx_domain(
node.domain
)


def is_constant_op(node: ir.Node) -> bool:
return node.op_type in {"Constant", "ConstantOfShape"} and constant_folding.is_onnx_domain(
return node.op_type in {"Constant", "ConstantOfShape"} and utils.is_onnx_domain(
node.domain
)

Expand Down Expand Up @@ -648,32 +648,11 @@ def convert(av):
def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function):
logger.debug("Replacing node: %s::%s %s", node.domain, node.op_type, node.name)

# TODO: what about new opset_imports?
old_values = node.outputs
new_values = replacement.new_outputs
for old_value, new_value in zip(old_values, new_values):
# Propagate relevant info from old value to new value
# TODO(Rama): Perhaps we should merge old and new types. As of now, new
# values don't have type information. Note that this could be a problem
# for semantics-altering rewrite-rules: we should allow users to override
# this for such rules.
new_value.type = old_value.type
new_value.shape = old_value.shape
new_value.const_value = old_value.const_value
new_value.name = old_value.name

# Reconnect the users of the deleted node to use the new outputs
_convenience.replace_all_uses_with(old_values, new_values)
# Update graph/function outputs if the node generates output
replacement_mapping = dict(zip(old_values, new_values))
for idx, graph_or_function_output in enumerate(root.outputs):
if graph_or_function_output in replacement_mapping:
root.outputs[idx] = replacement_mapping[graph_or_function_output]

# insert new nodes after the index node
root.insert_after(node, replacement.new_nodes)
root.remove(node, safe=True)
_convenience.replace_nodes_and_values(
root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs
)

# TODO: what about new opset_imports?
# TODO: track statistics about replaced nodes and sizes of new constants

def visit_attribute(self, attr: ir.Attr | ir.RefAttr) -> None:
Expand All @@ -698,12 +677,17 @@ def visit_graph(self, graph: ir.Graph) -> None:
for node in graph:
self.visit_node(node, graph)

def visit_function(self, function: ir.Function) -> None:
for node in function:
self.visit_node(node, function)

def visit_model(self, model: ir.Model) -> None:
self._init()
self.opset_imports = model.opset_imports
self.visit_graph(model.graph)
# TODO(rama): handle functions
# Pending decision on whether we want to specialize functions or not.
for function in model.functions.values():
# TODO(rama): Should we specialize functions?
self.visit_function(function)


def fold_constants(
Expand Down
66 changes: 13 additions & 53 deletions onnxscript/rewriter/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -1238,58 +1238,6 @@ def rewrite(cls, op, x: ir.Value, perm: ir.Attr | None = None):
)


def _apply_delta(
graph_or_function: ir.Graph | ir.Function,
node: ir.Node,
delta: ReplacementSubgraph,
):
"""Applies delta.
This code is valid is the considered pattern has only one output.
In case of multi output replacements, there is not need to rename
the outputs.
In case of multi-output design, the nodes may not be necessary inserted
all at the same position. To be convinced, you can take a pattern
producing two outputs, but the second one needs the first one and
another input appeared after the first outputs. What could be
the right place to inserted all of the node.
The current implementation insert all the nodes at the same position
but checks there is not inconsistency. In that case, it fails.
We could reorder (long) or do more clever changes.
The reordering would probably happen not very often.
"""

assert isinstance(delta, ReplacementSubgraph)
# Replace matched nodes with new nodes, matched values with new values
old_values = delta.match.outputs
new_values = delta.new_outputs

for old_value, new_value in zip(old_values, new_values):
# Propagate relevant info from old value to new value
# TODO(Rama): Perhaps we should merge old and new types. As of now, new
# values don't have type information. Note that this could be a problem
# for semantics-altering rewrite-rules: we should allow users to override
# this for such rules.
new_value.type = old_value.type
new_value.shape = old_value.shape
new_value.const_value = old_value.const_value
new_value.name = old_value.name

# Reconnect the users of the deleted node to use the new outputs
_convenience.replace_all_uses_with(old_values, new_values)
# Update graph/function outputs if the node generates output
replacement_mapping = dict(zip(old_values, new_values))
for idx, graph_or_function_output in enumerate(graph_or_function.outputs):
if graph_or_function_output in replacement_mapping:
graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output]

# insert new nodes after the index node
graph_or_function.insert_after(node, delta.new_nodes)
graph_or_function.remove(delta.match.nodes, safe=True)


class RewriteRuleSet:
def __init__(self, rules: Sequence[RewriteRule], *, commute: bool = False) -> None:
if commute:
Expand All @@ -1311,7 +1259,19 @@ def _apply_to_graph_or_function(
delta = rule.try_rewrite(model, graph_or_function, node, verbose=verbose)
if delta is None:
continue
_apply_delta(graph_or_function, node, delta)
assert isinstance(delta, ReplacementSubgraph)
# TODO: This does not yet handle the problem of determining the correct insertion point
# for inserted nodes in the case of patterns with multiple output-nodes. The following
# is sufficient for patterns with a single output-node "node", which can serve as the
# insertion-point.
_convenience.replace_nodes_and_values(
graph_or_function,
node,
delta.match.nodes,
delta.new_nodes,
delta.match.outputs,
delta.new_outputs,
)
count += 1

return count
Expand Down

0 comments on commit 712aa87

Please sign in to comment.