-
Notifications
You must be signed in to change notification settings - Fork 57
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add utility to remove unused values/nodes in IR (#1617)
Mostly a re-implementation of the existing proto-based optimization to remove unused-values/nodes to use the IR.
- Loading branch information
1 parent
4a9b04e
commit dc31a6e
Showing
5 changed files
with
297 additions
and
156 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
from __future__ import annotations | ||
|
||
import logging | ||
|
||
import onnx | ||
|
||
from onnxscript import ir | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def remove_unused_optional_outputs( | ||
node: ir.Node, graph_outputs: frozenset[ir.Value], onnx_opset_version: int | ||
) -> None: | ||
try: | ||
if node.domain not in {"", "onnx.ai"}: | ||
return | ||
op_schema = onnx.defs.get_schema(node.op_type, onnx_opset_version, domain=node.domain) | ||
except Exception: | ||
return | ||
|
||
if node.op_type == "BatchNormalization": | ||
# BatchNormalization op has 3 outputs: Y, running_mean, running_var | ||
# If running_mean and running_var are not used, remove them, and the training_mode attribute | ||
def is_used_output(i: int) -> bool: | ||
if i < len(node.outputs): | ||
val = node.outputs[i] | ||
return val in graph_outputs or bool(val.uses()) | ||
return False | ||
|
||
if is_used_output(1) or is_used_output(2): | ||
return | ||
node.outputs[1].name = "" | ||
node.outputs[2].name = "" | ||
node.attributes.pop("training_mode", None) | ||
return | ||
|
||
optional_info = [] | ||
for o in op_schema.outputs: | ||
# Current ops do not have optional outputs if they have variable number of outputs | ||
if o.option == onnx.defs.OpSchema.FormalParameterOption.Variadic: | ||
return | ||
optional_info.append(o.option == onnx.defs.OpSchema.FormalParameterOption.Optional) | ||
# If no optional outputs in spec, skip delete operations | ||
if len([o == 1 for o in optional_info]) == 0: | ||
return | ||
|
||
for i, out in enumerate(node.outputs): | ||
if out not in graph_outputs and (not out.uses()) and optional_info[i] is True: | ||
out.name = "" | ||
|
||
|
||
def process_function_or_graph(function_or_graph: ir.Function | ir.Graph) -> int: | ||
graph_outputs = frozenset(function_or_graph.outputs) | ||
onnx_opset_version = function_or_graph.opset_imports.get("", None) | ||
count = 0 | ||
for node in reversed(function_or_graph): | ||
removable = True | ||
for output in node.outputs: | ||
if output in graph_outputs or output.uses(): | ||
removable = False | ||
break | ||
if removable: | ||
function_or_graph.remove(node, safe=True) | ||
count += 1 | ||
else: | ||
if onnx_opset_version is not None: | ||
remove_unused_optional_outputs(node, graph_outputs, onnx_opset_version) | ||
for attr in node.attributes.values(): | ||
if isinstance(attr, ir.AttrGraph): | ||
count += process_function_or_graph(attr.value) | ||
elif isinstance(attr, ir.AttrGraphs): | ||
for graph in attr.value: | ||
count += process_function_or_graph(graph) | ||
return count | ||
|
||
|
||
def remove_unused_nodes(model: ir.Model) -> None: | ||
"""Removes unused nodes from the model.""" | ||
count = process_function_or_graph(model.graph) | ||
graph_outputs = frozenset(model.graph.outputs) | ||
initializers = model.graph.initializers | ||
for init in list(initializers.values()): | ||
if not (init in graph_outputs or init.uses()): | ||
del initializers[init.name] # type: ignore[arg-type] | ||
count += 1 | ||
|
||
for function in model.functions.values(): | ||
count += process_function_or_graph(function) | ||
|
||
logger.info("Removed %s unused nodes", count) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
from __future__ import annotations | ||
|
||
import logging | ||
from typing import Sequence | ||
|
||
import onnx | ||
from google.protobuf.internal.containers import ( # type: ignore | ||
RepeatedCompositeFieldContainer, | ||
) | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def remove_unused_optional_outputs( | ||
n: onnx.NodeProto, used: set, opset_import: Sequence[onnx.OperatorSetIdProto] | ||
) -> None: | ||
try: | ||
if n.domain not in {"", "onnx.ai"}: | ||
return | ||
onnx_opset_version = 1 | ||
for opset in opset_import: | ||
if opset.domain == n.domain: | ||
onnx_opset_version = opset.version | ||
op_schema = onnx.defs.get_schema(n.op_type, onnx_opset_version, domain=n.domain) | ||
except Exception: | ||
return | ||
|
||
if n.op_type == "BatchNormalization": | ||
# BatchNormalization op has 3 outputs: Y, running_mean, running_var | ||
# If running_mean and running_var are not used, remove them, and the training_mode attribute | ||
def is_used_output(i: int) -> bool: | ||
if i < len(n.output): | ||
return n.output[i] in used | ||
return False | ||
|
||
if is_used_output(1) or is_used_output(2): | ||
return | ||
del n.output[1:] | ||
for j, attr in enumerate(n.attribute): | ||
if attr.name == "training_mode": | ||
del n.attribute[j] | ||
break | ||
|
||
optional_info = [] | ||
for o in op_schema.outputs: | ||
# Current ops do not have optional outputs if they have variable number of outputs | ||
if o.option == onnx.defs.OpSchema.FormalParameterOption.Variadic: | ||
return | ||
optional_info.append(o.option == onnx.defs.OpSchema.FormalParameterOption.Optional) | ||
# If no optional outputs in spec, skip delete operations | ||
if len([o == 1 for o in optional_info]) == 0: | ||
return | ||
|
||
for i, out in enumerate(n.output): | ||
if out not in used and optional_info[i] is True: | ||
n.output[i] = "" | ||
# Only delete trailing unused optional outputs | ||
for o in n.output[::-1]: # type: ignore[assignment] | ||
if o == "": | ||
n.output.pop() | ||
else: | ||
return | ||
|
||
|
||
def compute_used_in_node(n: onnx.NodeProto) -> set[str]: | ||
used = {n for n in n.input if n != ""} | ||
for attr in n.attribute: | ||
if attr.HasField("g"): | ||
used |= compute_used_in_graph(attr.g) | ||
elif len(attr.graphs) > 0: | ||
for graph in attr.graphs: | ||
used |= compute_used_in_graph(graph) | ||
return used | ||
|
||
|
||
def compute_used_in_graph(g: onnx.GraphProto) -> set[str]: | ||
used = set() | ||
for n in g.node: | ||
used |= compute_used_in_node(n) | ||
return used | ||
|
||
|
||
def process_nodes( | ||
nodes: RepeatedCompositeFieldContainer[onnx.NodeProto], | ||
used: set, | ||
opset_import: Sequence[onnx.OperatorSetIdProto], | ||
) -> int: | ||
count = 0 | ||
i = len(nodes) - 1 | ||
while i >= 0: | ||
node = nodes[i] | ||
remove_unused_optional_outputs(node, used, opset_import) | ||
used_outputs = [x for x in node.output if x in used] | ||
if not used_outputs: | ||
del nodes[i] | ||
count += 1 | ||
i -= 1 | ||
continue | ||
for attr in node.attribute: | ||
if attr.HasField("g"): | ||
process_graph(attr.g, opset_import) | ||
elif len(attr.graphs) > 0: | ||
for graph in attr.graphs: | ||
process_graph(graph, opset_import) | ||
used |= compute_used_in_node(node) | ||
i -= 1 | ||
return count | ||
|
||
|
||
def process_graph( | ||
graph: onnx.GraphProto, opset_import: Sequence[onnx.OperatorSetIdProto] | ||
) -> int: | ||
used = {output.name for output in graph.output} | ||
|
||
count = process_nodes(graph.node, used, opset_import) | ||
|
||
for i in range(len(graph.initializer) - 1, -1, -1): | ||
if graph.initializer[i].name not in used: | ||
del graph.initializer[i] | ||
count += 1 | ||
|
||
return count | ||
|
||
|
||
def process_function( | ||
function: onnx.FunctionProto, opset_import: Sequence[onnx.OperatorSetIdProto] | ||
) -> int: | ||
used = set(function.output) | ||
|
||
return process_nodes(function.node, used, opset_import) | ||
|
||
|
||
def remove_unused_nodes(model: onnx.ModelProto) -> None: | ||
"""Removes unused nodes from the model.""" | ||
count = process_graph(model.graph, model.opset_import) | ||
for function in model.functions: | ||
count += process_function(function, model.opset_import) | ||
|
||
logger.info("Removed %s unused nodes", count) |
Oops, something went wrong.