From 7e9c9e61ed4ecee03f8f3e2021e7ba2796f80751 Mon Sep 17 00:00:00 2001
From: Shubham Bhokare <32080845+shubhambhokare1@users.noreply.github.com>
Date: Thu, 2 May 2024 16:27:33 -0700
Subject: [PATCH] [optimizer][docs] Add Tutorial for Optimizer API (#1482)
Add Tutorial for Optimizer API
---
docs/conf.py | 1 +
docs/index.md | 3 +-
docs/optimizer/index.md | 5 ++
docs/optimizer/optimize.md | 53 ++++++++++++
docs/rewriter/rewrite_patterns.md | 20 ++---
onnxscript/optimizer/__init__.py | 7 --
onnxscript/optimizer/constant_folding.py | 5 +-
onnxscript/optimizer/copy_propagation.py | 81 -------------------
onnxscript/optimizer/copy_propagation_test.py | 49 -----------
.../optimizer/simple_function_folding.py | 2 +
10 files changed, 77 insertions(+), 149 deletions(-)
create mode 100644 docs/optimizer/index.md
create mode 100644 docs/optimizer/optimize.md
delete mode 100644 onnxscript/optimizer/copy_propagation.py
delete mode 100644 onnxscript/optimizer/copy_propagation_test.py
diff --git a/docs/conf.py b/docs/conf.py
index e981146a6..319b4044e 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -26,6 +26,7 @@
"sphinx.ext.ifconfig",
"sphinx.ext.viewcode",
"sphinx.ext.autodoc",
+ "sphinx.ext.autosummary",
"sphinx.ext.githubpages",
"sphinx_gallery.gen_gallery",
"sphinx.ext.autodoc",
diff --git a/docs/index.md b/docs/index.md
index ed1bab77b..856b57e36 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -99,8 +99,9 @@ result = MatmulAdd(x, wt, bias)
tutorial/index
api/index
intermediate_representation/index
-auto_examples/index
+optimizer/index
rewriter/index
+auto_examples/index
articles/index
```
diff --git a/docs/optimizer/index.md b/docs/optimizer/index.md
new file mode 100644
index 000000000..f5f66c38c
--- /dev/null
+++ b/docs/optimizer/index.md
@@ -0,0 +1,5 @@
+# Optimizer Tutorials
+
+```{toctree}
+optimize
+```
diff --git a/docs/optimizer/optimize.md b/docs/optimizer/optimize.md
new file mode 100644
index 000000000..5ceb7dfb8
--- /dev/null
+++ b/docs/optimizer/optimize.md
@@ -0,0 +1,53 @@
+# Optimizing a Model using the Optimizer
+
+## Introduction
+
+The ONNX Script `Optimizer` tool provides the user with the functionality to optimize an ONNX model by performing optimizations and clean-ups such as constant folding, dead code elimination, etc.
+
+## Usage
+
+In order to utilize the optimizer tool,
+
+```python
+import onnxscript
+
+onnxscript.optimizer.optimize(model)
+```
+
+### optimize API
+The `onnxscript.optimizer.optimize` call takes in several optional parameters that allow the caller to further fine-tune the process of optimization.
+
+```{eval-rst}
+.. autofunction:: onnxscript.optimizer.optimize
+    :noindex:
+```
+
+## Description of optimizations applied by `onnxscript.optimizer.optimize`
+
+:::{table}
+:widths: auto
+:align: center
+
+| Optimization `onnxscript.optimizer.` + .. | Description |
+| - | - |
+| **Constant folding**<br/>`constant_folding.fold_constants` | Applies constant folding optimization to the model. |
+| **Constant propagation**<br/>`constant_folding.fold_constants` | Applies constant propagation optimization to the model. Applied as part of the constant folding optimization. |
+| **Sequence simplification**<br/>`constant_folding.fold_constants` | Simplifies Sequence based ops (SequenceConstruct, ConcatFromSequence) present in the model. Applied as part of the constant folding optimization. |
+| **Remove unused nodes**<br/>`remove_unused.remove_unused_nodes` | Removes unused nodes from the model. |
+| **Remove unused functions**<br/>`remove_unused_function.remove_unused_functions` | Removes unused function protos from the model. |
+| **Inline functions with unused outputs**<br/>`simple_function_folding.inline_functions_with_unused_outputs` | Inlines function nodes that have unused outputs. |
+| **Inline simple functions**<br/>`simple_function_folding.inline_simple_functions` | Inlines simple functions based on a node count threshold. |
+:::
+
+## List of pattern rewrite rules applied by `onnxscript.optimizer.optimize`
+
+```{eval-rst}
+.. autosummary::
+    :nosignatures:
+
+    onnxscript.rewriter.broadcast_to_matmul
+    onnxscript.rewriter.cast_constant_of_shape
+    onnxscript.rewriter.gemm_to_matmul_add
+    onnxscript.rewriter.no_op
+
+```
diff --git a/docs/rewriter/rewrite_patterns.md b/docs/rewriter/rewrite_patterns.md
index 87a5e31af..ba58b2636 100644
--- a/docs/rewriter/rewrite_patterns.md
+++ b/docs/rewriter/rewrite_patterns.md
@@ -28,10 +28,10 @@ We will show how we can find a subgraph matching this computation and replace it
Firstly, include all the rewriter relevant imports.
```python
- from onnxscript.rewriter import pattern
- from onnxscript import ir
+from onnxscript.rewriter import pattern
+from onnxscript import ir
- _op = pattern.onnxop
+_op = pattern.onnxop
```
Then create a target pattern that needs to be replaced using onnxscript operators.
@@ -55,10 +55,10 @@ The inputs to the replacement pattern are of type `ir.Value`. For detailed usage
For this example, we do not require a `match_condition` so that option is skipped for now. Then the rewrite rule is created using the `RewriteRule` function.
```python
- rule = pattern.RewriteRule(
- erf_gelu_pattern, # Target Pattern
- gelu, # Replacement Pattern
- )
+rule = pattern.RewriteRule(
+ erf_gelu_pattern, # Target Pattern
+ gelu, # Replacement Pattern
+)
```
Now that the rewrite rule has been created, the next step is to apply these pattern-based rewrite rules. The `rewriter.rewrite` call consists of three main components:
@@ -117,8 +117,8 @@ Only one of the patterns has been successfully matched and replaced by a `GELU`
This method requires creating two separate rules and packing them into either a sequence of `PatternRewriteRule`s or a `RewriteRuleSet`. Creating a `RewriteRuleSet` is the preferable option but either can be used. In order to create a `RewriteRuleSet` with multiple rules `rule1` and `rule2` for example:
```python
- from onnxscript.rewriter import pattern
- rewrite_rule_set = pattern.RewriteRuleSet(rules=[rule1, rule2])
+from onnxscript.rewriter import pattern
+rewrite_rule_set = pattern.RewriteRuleSet(rules=[rule1, rule2])
```
In order to apply this method to the example above, first create the two separate target patterns as follows:
@@ -171,7 +171,7 @@ First, write a target pattern and replacement pattern in a similar way to the fi
```
```{literalinclude} examples/broadcast_matmul.py
-:pyobject: matmul
+:pyobject: matmul_pattern
```
:::{note}
diff --git a/onnxscript/optimizer/__init__.py b/onnxscript/optimizer/__init__.py
index f70d4d35e..03c1e748e 100644
--- a/onnxscript/optimizer/__init__.py
+++ b/onnxscript/optimizer/__init__.py
@@ -6,10 +6,6 @@
from onnxscript import rewriter
from onnxscript.optimizer.constant_folding import fold_constants
-from onnxscript.optimizer.copy_propagation import (
- do_copy_propagation,
- do_sequence_simplification,
-)
from onnxscript.optimizer.remove_unused import remove_unused_nodes
from onnxscript.optimizer.remove_unused_function import remove_unused_functions
from onnxscript.optimizer.simple_function_folding import (
@@ -108,7 +104,6 @@ def optimize(
node.name,
)
- # do_sequence_simplification(model)
return model
@@ -116,6 +111,4 @@ def optimize(
"fold_constants",
"remove_unused_nodes",
"optimize",
- "do_copy_propagation",
- "do_sequence_simplification",
]
diff --git a/onnxscript/optimizer/constant_folding.py b/onnxscript/optimizer/constant_folding.py
index 9a51298c7..283a13fd1 100644
--- a/onnxscript/optimizer/constant_folding.py
+++ b/onnxscript/optimizer/constant_folding.py
@@ -263,7 +263,10 @@ def fold_constants(
*,
onnx_shape_inference: bool = False,
) -> bool:
- """Returns true iff the model was modified."""
+ """
+ Applies constant folding optimization to the model.
+ Returns true iff the model was modified.
+ """
folder = ConstantFolder(
evaluator.registry,
external_data_folder,
diff --git a/onnxscript/optimizer/copy_propagation.py b/onnxscript/optimizer/copy_propagation.py
deleted file mode 100644
index 6a7d4143d..000000000
--- a/onnxscript/optimizer/copy_propagation.py
+++ /dev/null
@@ -1,81 +0,0 @@
-from __future__ import annotations
-
-from typing import Any
-
-import onnx
-
-import onnxscript.optimizer.remove_unused
-from onnxscript._legacy_ir import visitor
-from onnxscript.utils.utils import is_onnx_op
-
-
-class CopyPropagator(visitor.ProtoVisitor):
- def __init__(self):
- super().__init__()
-
- def visit_node(self, node: onnx.NodeProto) -> None:
- super().visit_node(node)
- for i in range(len(node.input)):
- input = self.get_input(node, i)
- if input is not None and input.is_copy():
- node.input[i] = input.symbolic_value # type: ignore[assignment]
-
- if is_onnx_op(node, "Identity"):
- input = self.get_input(node, 0)
- output = self.get_output(node, 0)
- if input is not None and output is not None:
- output.symbolic_value = input.name
-
-
-# TODO: "Z = Identity(x)" where Z is a graph-output cannot be handled by this optimization,
-# and requires some extension. (Eg., we could rename graph-output to be Z or we can try to
-# rename x to be Z.)
-
-
-def get_node_attr_value(node: onnx.NodeProto, attr_name: str, default: Any) -> Any:
- matching = [x for x in node.attribute if x.name == attr_name]
- if len(matching) > 1:
- raise ValueError(f"Node has multiple attributes with name {attr_name}")
- if len(matching) < 1:
- return default
- return onnx.helper.get_attribute_value(matching[0])
-
-
-class SymbolicEvaluator(CopyPropagator):
- def __init__(self):
- super().__init__()
-
- def visit_node(self, node: onnx.NodeProto) -> None:
- super().visit_node(node)
-
- if is_onnx_op(node, "SequenceConstruct"):
- output = self.get_output(node, 0)
- if output is not None:
- output.symbolic_value = list(node.input)
-
- if is_onnx_op(node, "ConcatFromSequence"):
- input = self.get_input(node, 0)
- new_axis = get_node_attr_value(node, "new_axis", 0)
- if input is not None and isinstance(input.symbolic_value, list) and new_axis == 0:
- node.op_type = "Concat"
- node.input[:] = input.symbolic_value
- for i in range(len(node.attribute)):
- if node.attribute[i].name == "new_axis":
- del node.attribute[i]
- break
-
- # TODO: handle SequenceEmpty, SequenceAt, etc.
-
-
-def do_copy_propagation(model: onnx.ModelProto, *, remove_unused: bool = True) -> None:
- transformer = CopyPropagator()
- transformer.visit_model(model)
- if remove_unused:
- onnxscript.optimizer.remove_unused_nodes(model)
-
-
-def do_sequence_simplification(model: onnx.ModelProto, *, remove_unused: bool = True) -> None:
- transformer = SymbolicEvaluator()
- transformer.visit_model(model)
- if remove_unused:
- onnxscript.optimizer.remove_unused_nodes(model)
diff --git a/onnxscript/optimizer/copy_propagation_test.py b/onnxscript/optimizer/copy_propagation_test.py
deleted file mode 100644
index 6b88b027a..000000000
--- a/onnxscript/optimizer/copy_propagation_test.py
+++ /dev/null
@@ -1,49 +0,0 @@
-import unittest
-
-import onnx
-
-from onnxscript import optimizer
-
-
-class RemoveUnusedTest(unittest.TestCase):
- def test_simple_identity_removal(self):
- model = onnx.parser.parse_model(
- """
-
- agraph (float[N] x) => (float[N] z) {
- t = Identity(x)
- t2 = Identity(t)
- z = Identity(t2)
- }
- """
- )
- optimizer.do_copy_propagation(model)
- self.assertEqual(len(model.graph.node), 1)
-
- def test_subgraph_identity_removal(self):
- model = onnx.parser.parse_model(
- """
-
- agraph (float[N] x, bool cond) => (float[N] z) {
- t = Identity(x)
- t2 = Identity(t)
- t3 = If (cond) <
- then_branch = then_graph() => (t4) {
- t5 = Identity(t2)
- t4 = Identity(t5)
- },
- else_branch = else__graph() => (t6) {
- t7 = Identity(t)
- t6 = Identity(t7)
- }
- >
- z = Identity(t3)
- }
- """
- )
- optimizer.do_copy_propagation(model)
- self.assertEqual(len(model.graph.node), 2)
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/onnxscript/optimizer/simple_function_folding.py b/onnxscript/optimizer/simple_function_folding.py
index b15a9c8a0..8b6f6662b 100644
--- a/onnxscript/optimizer/simple_function_folding.py
+++ b/onnxscript/optimizer/simple_function_folding.py
@@ -200,6 +200,7 @@ def function_with_unused_outputs(self) -> dict[ir.FunctionId, onnx.FunctionProto
def inline_simple_functions(model: onnx.ModelProto, node_count: int = 2) -> bool:
+ """Inlines simple functions based on a node count threshold"""
inliner = FunctionInliner(node_count)
inliner.visit_model(model)
logger.info(
@@ -218,6 +219,7 @@ def inline_simple_functions(model: onnx.ModelProto, node_count: int = 2) -> bool
def inline_functions_with_unused_outputs(model: onnx.ModelProto) -> bool:
+ """Inlines function nodes that have unused outputs."""
# TODO: Use onnx.inliner after 1.16.
# This visitor based inliner is used to ensure the function inner value info remains consistent.
visitor = FindFunctionWithUnusedOutputsVisitor()