-
Notifications
You must be signed in to change notification settings - Fork 58
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add utility to inline all functions in a model. Still TODO: * Some edge cases to be considered for renaming and avoiding conflicts, especially with subgraphs. * Must ensure no variable capture happens (part of above renaming). * Test renaming of node names. Fixes #1769 --------- Signed-off-by: Ganesan Ramalingam <[email protected]>
- Loading branch information
1 parent
14e538e
commit d7a6411
Showing
2 changed files
with
509 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,298 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
"""Implementation of an inliner for onnxscript.ir""" | ||
|
||
from __future__ import annotations | ||
|
||
from collections import defaultdict | ||
from typing import Iterable, Sequence, Tuple | ||
|
||
import onnxscript.ir as ir | ||
import onnxscript.ir.convenience as ir_convenience | ||
|
||
# A replacement for a node specifies a list of nodes that replaces the original node, | ||
# and a list of values that replaces the original node's outputs. | ||
|
||
NodeReplacement = Tuple[Sequence[ir.Node], Sequence[ir.Value]] | ||
|
||
# A call stack is a list of identifiers of call sites, where the first element is the | ||
# outermost call site, and the last element is the innermost call site. This is used | ||
# primarily for generating unique names for values in the inlined functions. | ||
CallSiteId = str | ||
CallStack = list[CallSiteId] | ||
|
||
|
||
def _make_unique_name(name: str, callstack: CallStack, used_names: set[str]) -> str: | ||
"""Generate a unique name from a name, calling-context, and set of used names. | ||
When a value X in a function is inlined into a graph, we rename X by adding a prefix | ||
representing the call-stack of the function. This should typically avoid name clashes. | ||
If there is a name clash, even after this, we add a numeric suffix to the name to make | ||
it unique. We use the same strategy to make node names unique. | ||
""" | ||
prefix = "_".join(callstack) | ||
name = prefix + "_" + name | ||
candidate = name | ||
i = 1 | ||
while candidate in used_names: | ||
i += 1 | ||
candidate = f"{name}_{i}" | ||
used_names.add(candidate) | ||
return candidate | ||
|
||
|
||
class _CopyReplace: | ||
"""Utilities for creating a copy of IR objects with substitutions for attributes/input values.""" | ||
|
||
def __init__( | ||
self, | ||
inliner: _Inliner, | ||
attr_map: dict[str, ir.Attr | ir.RefAttr], | ||
value_map: dict[ir.Value, ir.Value | None], | ||
metadata_props: dict[str, str], | ||
call_stack: CallStack, | ||
) -> None: | ||
self._inliner = inliner | ||
self._value_map = value_map | ||
self._attr_map = attr_map | ||
self._metadata_props = metadata_props | ||
self._call_stack = call_stack | ||
|
||
def clone_value(self, value: ir.Value) -> ir.Value | None: | ||
if value in self._value_map: | ||
return self._value_map[value] | ||
# If the value is not in the value map, it must be a graph input. | ||
assert value.producer() is not None, f"Value {value} has no entry in the value map" | ||
new_value = ir.Value( | ||
name=value.name, | ||
type=value.type, | ||
shape=value.shape, | ||
doc_string=value.doc_string, | ||
const_value=value.const_value, | ||
) | ||
self._value_map[value] = new_value | ||
return new_value | ||
|
||
def clone_optional_value(self, value: ir.Value | None) -> ir.Value | None: | ||
if value is None: | ||
return None | ||
return self.clone_value(value) | ||
|
||
def clone_attr(self, key: str, attr: ir.Attr | ir.RefAttr) -> ir.Attr | ir.RefAttr | None: | ||
if isinstance(attr, ir.Attr): | ||
if attr.type == ir.AttributeType.GRAPH: | ||
graph = self.clone_graph(attr.value) | ||
return ir.Attr(key, ir.AttributeType.GRAPH, graph, doc_string=attr.doc_string) | ||
elif attr.type == ir.AttributeType.GRAPHS: | ||
graphs = [self.clone_graph(graph) for graph in attr.value] | ||
return ir.Attr( | ||
key, ir.AttributeType.GRAPHS, graphs, doc_string=attr.doc_string | ||
) | ||
return attr | ||
assert isinstance(attr, ir.RefAttr) | ||
if key in self._attr_map: | ||
return self._attr_map[key] | ||
# Note that if a function has an attribute-parameter X, and a call (node) to the function | ||
# has no attribute X, all references to X in nodes inside the function body will be | ||
# removed. This is just the ONNX representation of optional-attributes. | ||
return None | ||
|
||
def clone_node(self, node: ir.Node) -> ir.Node: | ||
new_inputs = [self.clone_optional_value(input) for input in node.inputs] | ||
new_attributes = [ | ||
new_value | ||
for key, value in node.attributes.items() | ||
if (new_value := self.clone_attr(key, value)) is not None | ||
] | ||
new_name = node.name | ||
if new_name is not None: | ||
new_name = _make_unique_name( | ||
new_name, self._call_stack, self._inliner.used_node_names | ||
) | ||
|
||
new_metadata = {**self._metadata_props, **node.metadata_props} | ||
# TODO: For now, node metadata overrides callnode metadata if there is a conflict. | ||
# Do we need to preserve both? | ||
|
||
new_node = ir.Node( | ||
node.domain, | ||
node.op_type, | ||
new_inputs, | ||
new_attributes, | ||
overload=node.overload, | ||
num_outputs=len(node.outputs), | ||
graph=None, | ||
name=new_name, | ||
doc_string=node.doc_string, | ||
metadata_props=new_metadata, | ||
) | ||
new_outputs = new_node.outputs | ||
for i, output in enumerate(node.outputs): | ||
self._value_map[output] = new_outputs[i] | ||
old_name = output.name if output.name is not None else f"output_{i}" | ||
new_outputs[i].name = _make_unique_name( | ||
old_name, self._call_stack, self._inliner.used_value_names | ||
) | ||
|
||
self._inliner.node_context[new_node] = self._call_stack | ||
|
||
return new_node | ||
|
||
def clone_graph(self, graph: ir.Graph) -> ir.Graph: | ||
input_values = [self.clone_value(v) for v in graph.inputs] | ||
nodes = [self.clone_node(node) for node in graph] | ||
initializers = [self.clone_value(init) for init in graph.initializers.values()] | ||
|
||
return ir.Graph( | ||
input_values, # type: ignore | ||
graph.outputs, | ||
nodes=nodes, | ||
initializers=initializers, # type: ignore | ||
doc_string=graph.doc_string, | ||
opset_imports=graph.opset_imports, | ||
name=graph.name, | ||
metadata_props=graph.metadata_props, | ||
) | ||
|
||
|
||
def _abbreviate( | ||
function_ids: Iterable[ir.OperatorIdentifier], | ||
) -> dict[ir.OperatorIdentifier, str]: | ||
"""Create a short unambiguous abbreviation for all function ids.""" | ||
|
||
def id_abbreviation(id: ir.OperatorIdentifier) -> str: | ||
"""Create a short unambiguous abbreviation for a function id.""" | ||
domain, name, overload = id | ||
# Omit the domain, if it remains unambiguous after omitting it. | ||
if any(x[0] != domain and x[1] == name and x[2] == overload for x in function_ids): | ||
short_domain = domain + "_" | ||
else: | ||
short_domain = "" | ||
if overload != "": | ||
return short_domain + name + "_" + overload | ||
return short_domain + name | ||
|
||
return {id: id_abbreviation(id) for id in function_ids} | ||
|
||
|
||
class _Inliner: | ||
def __init__(self, model: ir.Model) -> None: | ||
self._functions = model.functions | ||
self._function_id_abbreviations = _abbreviate(self._functions.keys()) | ||
self._opset_imports = model.opset_imports | ||
self.used_value_names: set[str] = set() | ||
self.used_node_names: set[str] = set() | ||
self.node_context: dict[ir.Node, CallStack] = {} | ||
|
||
def _instantiate_call(self, node: ir.Node, call_site_id: CallSiteId) -> NodeReplacement: | ||
id = node.op_identifier() | ||
function = self._functions[id] | ||
|
||
# check opset compatibility and update the opset imports | ||
for key, value in function.opset_imports.items(): | ||
if key not in self._opset_imports: | ||
self._opset_imports[key] = value | ||
elif self._opset_imports[key] != value: | ||
raise ValueError( | ||
f"Opset mismatch: {key} {self._opset_imports[key]} != {value}" | ||
) | ||
|
||
# Identify substitutions for both inputs and attributes of the function: | ||
attributes: dict[str, ir.Attr | ir.RefAttr] = node.attributes | ||
default_attr_values = { | ||
attr.name: attr | ||
for attr in function.attributes.values() | ||
if attr.name not in attributes and attr.value is not None | ||
} | ||
if default_attr_values: | ||
attributes = {**attributes, **default_attr_values} | ||
if any( | ||
attr.type == ir.AttributeType.GRAPH or attr.type == ir.AttributeType.GRAPHS | ||
for attr in attributes.values() | ||
): | ||
raise ValueError( | ||
"Inliner does not support graph attribute parameters to functions" | ||
) | ||
|
||
if len(node.inputs) > len(function.inputs): | ||
raise ValueError(f"Input mismatch: {len(node.inputs)} > {len(function.inputs)}") | ||
value_map = {} | ||
for i, input in enumerate(node.inputs): | ||
value_map[function.inputs[i]] = input | ||
for i in range(len(node.inputs), len(function.inputs)): | ||
value_map[function.inputs[i]] = None | ||
|
||
# Identify call-stack for node, used to generate unique names. | ||
call_stack = self.node_context.get(node, []) | ||
call_stack.append(call_site_id) | ||
|
||
cloner = _CopyReplace(self, attributes, value_map, node.metadata_props, call_stack) | ||
|
||
# iterate over the nodes in the function, creating a copy of each node | ||
# and replacing inputs with the corresponding values in the value map. | ||
# Update the value map with the new values. | ||
|
||
nodes = [cloner.clone_node(node) for node in function] | ||
output_values = [value_map[output] for output in function.outputs] | ||
return nodes, output_values # type: ignore | ||
|
||
def inline_calls_in(self, graph: ir.Graph) -> None: | ||
for input in graph.inputs: | ||
if input.name is not None: | ||
self.used_value_names.add(input.name) | ||
for initializer in graph.initializers: | ||
self.used_value_names.add(initializer) | ||
|
||
# Pre-processing: | ||
# * Count the number of times each function is called in the graph. | ||
# This is used for disambiguating names of values in the inlined functions. | ||
# * And identify names of values that are used in the graph. | ||
id_count: dict[ir.OperatorIdentifier, int] = defaultdict(int) | ||
for node in graph: | ||
if node.name: | ||
self.used_node_names.add(node.name) | ||
id = node.op_identifier() | ||
if id in self._functions: | ||
id_count[id] += 1 | ||
for output in node.outputs: | ||
if output.name is not None: | ||
self.used_value_names.add(output.name) | ||
next_id: dict[ir.OperatorIdentifier, int] = defaultdict(int) | ||
for node in graph: | ||
id = node.op_identifier() | ||
if id in self._functions: | ||
# If there are multiple calls to same function, we use a prefix to disambiguate | ||
# the different call-sites: | ||
if id_count[id] > 1: | ||
call_site_prefix = f"_{next_id[id]}" | ||
next_id[id] += 1 | ||
else: | ||
call_site_prefix = "" | ||
call_site = node.name or ( | ||
self._function_id_abbreviations[id] + call_site_prefix | ||
) | ||
nodes, values = self._instantiate_call(node, call_site) | ||
ir_convenience.replace_nodes_and_values( | ||
graph, | ||
insertion_point=node, | ||
old_nodes=[node], | ||
new_nodes=nodes, | ||
old_values=node.outputs, | ||
new_values=values, | ||
) | ||
else: | ||
for attr in node.attributes.values(): | ||
if not isinstance(attr, ir.Attr): | ||
continue | ||
if attr.type == ir.AttributeType.GRAPH: | ||
self.inline_calls_in(attr.value) | ||
elif attr.type == ir.AttributeType.GRAPHS: | ||
for graph in attr.value: | ||
self.inline_calls_in(graph) | ||
|
||
|
||
def inline(model: ir.Model) -> None: | ||
"""Inline all function calls (recursively) in the model.""" | ||
inliner = _Inliner(model) | ||
inliner.inline_calls_in(model.graph) | ||
model.functions.clear() |
Oops, something went wrong.