Commit

[optimizer][docs] Add Tutorial for Optimizer API (#1482)
shubhambhokare1 authored May 2, 2024
1 parent a722e60 commit 7e9c9e6
Showing 10 changed files with 77 additions and 149 deletions.
1 change: 1 addition & 0 deletions docs/conf.py
@@ -26,6 +26,7 @@
"sphinx.ext.ifconfig",
"sphinx.ext.viewcode",
"sphinx.ext.autodoc",
+    "sphinx.ext.autosummary",
"sphinx.ext.githubpages",
"sphinx_gallery.gen_gallery",
"sphinx.ext.autodoc",
3 changes: 2 additions & 1 deletion docs/index.md
@@ -99,8 +99,9 @@ result = MatmulAdd(x, wt, bias)
tutorial/index
api/index
intermediate_representation/index
-auto_examples/index
+optimizer/index
rewriter/index
+auto_examples/index
articles/index
```

5 changes: 5 additions & 0 deletions docs/optimizer/index.md
@@ -0,0 +1,5 @@
# Optimizer Tutorials

```{toctree}
optimize
```
53 changes: 53 additions & 0 deletions docs/optimizer/optimize.md
@@ -0,0 +1,53 @@
# Optimizing a Model using the Optimizer

## Introduction

The ONNX Script `Optimizer` optimizes an ONNX model by applying clean-up transformations such as constant folding and dead-code elimination.
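
To build intuition for what constant folding and dead-code elimination actually do, here is a minimal self-contained sketch over a toy graph representation (plain tuples and a dict — not the optimizer's real data structures or algorithm):

```python
# Toy graph: each node is (output_name, op, input_names); `consts` maps names
# to known constant values. Illustration only, not the onnxscript internals.

def fold_constants(nodes, consts):
    """Replace ops whose inputs are all constants with precomputed constants."""
    kept = []
    for out, op, ins in nodes:
        if op == "Add" and all(i in consts for i in ins):
            consts[out] = consts[ins[0]] + consts[ins[1]]  # folded away
        elif op == "Mul" and all(i in consts for i in ins):
            consts[out] = consts[ins[0]] * consts[ins[1]]  # folded away
        else:
            kept.append((out, op, ins))
    return kept

def remove_unused(nodes, outputs):
    """Drop nodes whose results are never consumed (dead-code elimination)."""
    live, kept = set(outputs), []
    for out, op, ins in reversed(nodes):  # walk backwards from the outputs
        if out in live:
            kept.append((out, op, ins))
            live.update(ins)
    return list(reversed(kept))

consts = {"two": 2, "three": 3}
nodes = [
    ("six", "Mul", ["two", "three"]),  # both inputs constant: foldable
    ("y", "Add", ["x", "six"]),        # depends on runtime input x: kept
    ("dead", "Mul", ["y", "two"]),     # result never used: eliminated
]
nodes = remove_unused(fold_constants(nodes, consts), outputs=["y"])
print(nodes)          # -> [('y', 'Add', ['x', 'six'])]
print(consts["six"])  # -> 6
```

The real optimizer performs the same kind of bookkeeping on ONNX graph protos, including subtleties (subgraphs, shape inference) this sketch ignores.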

## Usage

To use the optimizer:

```python
import onnx
import onnxscript

model = onnx.load("model.onnx")  # an onnx.ModelProto
model = onnxscript.optimizer.optimize(model)
```

### optimize API
The `onnxscript.optimizer.optimize` call accepts several optional parameters that allow the caller to fine-tune the optimization process.

```{eval-rst}
.. autofunction:: onnxscript.optimizer.optimize
:noindex:
```

## Description of optimizations applied by `onnxscript.optimizer.optimize`

:::{table}
:widths: auto
:align: center

| Optimization (`onnxscript.optimizer.` + ...) | Description |
| - | - |
| **Constant folding** <br>`constant_folding.fold_constants` | Applies constant folding optimization to the model. |
| **Constant propagation** <br>`constant_folding.fold_constants` | Applies constant propagation optimization to the model. Applied as part of the constant folding optimization. |
| **Sequence simplification** <br>`constant_folding.fold_constants` | Simplifies Sequence based ops (SequenceConstruct, ConcatFromSequence) present in the model. Applied as part of the constant folding optimization. |
| **Remove unused nodes** <br>`remove_unused.remove_unused_nodes` | Removes unused nodes from the model. |
| **Remove unused functions** <br>`remove_unused_function.remove_unused_functions` | Removes unused function protos from the model. |
| **Inline functions with unused outputs** <br>`simple_function_folding.inline_functions_with_unused_outputs` | Inlines function nodes that have unused outputs. |
| **Inline simple functions** <br>`simple_function_folding.inline_simple_functions` | Inlines simple functions based on a node count threshold. |
:::
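
The inlining passes in the last two rows are threshold-driven: a function whose body is small enough is spliced into its call sites. A rough self-contained sketch of the idea (toy dictionaries, not the actual `FunctionInliner` visitor):

```python
# Sketch of "inline simple functions": calls to functions whose bodies have at
# most `node_count` nodes are replaced by the body itself. Illustration only.

def inline_simple_functions(graph, functions, node_count=2):
    """Inline calls to functions whose bodies are at most node_count nodes."""
    result = []
    for node in graph:
        body = functions.get(node["op"])
        if body is not None and len(body) <= node_count:
            result.extend(body)   # splice the small body in place of the call
        else:
            result.append(node)   # keep larger functions as opaque calls
    return result

functions = {
    "Square": [{"op": "Mul"}],                                       # 1 node: inlined
    "Softmax": [{"op": "Exp"}, {"op": "ReduceSum"}, {"op": "Div"}],  # 3 nodes: kept
}
graph = [{"op": "Square"}, {"op": "Softmax"}]
print([n["op"] for n in inline_simple_functions(graph, functions)])
# -> ['Mul', 'Softmax']
```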

## List of pattern rewrite rules applied by `onnxscript.optimizer.optimize`

```{eval-rst}
.. autosummary::
:nosignatures:
onnxscript.rewriter.broadcast_to_matmul
onnxscript.rewriter.cast_constant_of_shape
onnxscript.rewriter.gemm_to_matmul_add
onnxscript.rewriter.no_op
```
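
To give a flavor of what such a rule does, here is a standalone sketch in the spirit of `gemm_to_matmul_add`: a `Gemm` node is rewritten into a `MatMul` followed by an `Add`. This toy version uses plain dictionaries and ignores `Gemm` attributes such as `alpha`, `beta`, and `transA`/`transB`, which the real rule has to account for:

```python
# Conceptual sketch of a Gemm -> MatMul + Add rewrite over toy node dicts.
# The real rules operate on onnxscript's IR, not structures like these.

def gemm_to_matmul_add(nodes):
    rewritten = []
    for node in nodes:
        if node["op"] == "Gemm":
            a, b, c = node["inputs"]
            tmp = node["output"] + "_matmul"  # fresh intermediate value name
            rewritten.append({"op": "MatMul", "inputs": [a, b], "output": tmp})
            rewritten.append({"op": "Add", "inputs": [tmp, c], "output": node["output"]})
        else:
            rewritten.append(node)
    return rewritten

graph = [{"op": "Gemm", "inputs": ["x", "w", "bias"], "output": "y"}]
print([n["op"] for n in gemm_to_matmul_add(graph)])  # -> ['MatMul', 'Add']
```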
20 changes: 10 additions & 10 deletions docs/rewriter/rewrite_patterns.md
@@ -28,10 +28,10 @@ We will show how we can find a subgraph matching this computation and replace it
First, include the relevant rewriter imports.

```python
-    from onnxscript.rewriter import pattern
-    from onnxscript import ir
+from onnxscript.rewriter import pattern
+from onnxscript import ir

-    _op = pattern.onnxop
+_op = pattern.onnxop
```

Then create a target pattern that needs to be replaced using onnxscript operators.
@@ -55,10 +55,10 @@ The inputs to the replacement pattern are of type `ir.Value`. For detailed usage
For this example, we do not require a `match_condition` so that option is skipped for now. Then the rewrite rule is created using the `RewriteRule` function.

```python
-    rule = pattern.RewriteRule(
-        erf_gelu_pattern,  # Target Pattern
-        gelu,  # Replacement Pattern
-    )
+rule = pattern.RewriteRule(
+    erf_gelu_pattern,  # Target Pattern
+    gelu,  # Replacement Pattern
+)
```

Now that the rewrite rule has been created, the next step is to apply these pattern-based rewrite rules. The `rewriter.rewrite` call consists of three main components:
@@ -117,8 +117,8 @@ Only one of the patterns has been successfully matched and replaced by a `GELU`
This method requires creating two separate rules and packing them into either a sequence of `PatternRewriteRule`s or a `RewriteRuleSet`. Creating a `RewriteRuleSet` is preferable, but either can be used. For example, to create a `RewriteRuleSet` with two rules `rule1` and `rule2`:

```python
-    from onnxscript.rewriter import pattern
-    rewrite_rule_set = pattern.RewriteRuleSet(rules=[rule1, rule2])
+from onnxscript.rewriter import pattern
+rewrite_rule_set = pattern.RewriteRuleSet(rules=[rule1, rule2])
```
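
Conceptually, a rule set tries each rule in turn on every candidate and applies the first one that matches. A minimal standalone sketch of that dispatch loop, operating on bare op-name strings rather than real IR nodes (the actual `RewriteRuleSet` matches whole subgraphs, not single names):

```python
# Toy rule set: each rule is a (match, replace) pair of callables, tried in
# order on every node; the first match wins. Illustration only.

def apply_rule_set(nodes, rules):
    out = []
    for node in nodes:
        for match, replace in rules:
            if match(node):
                out.extend(replace(node))  # replacement may be 0..n nodes
                break
        else:
            out.append(node)               # no rule matched: keep the node
    return out

rule1 = (lambda n: n == "Erf", lambda n: ["Gelu"])  # stand-in for an erf-Gelu rule
rule2 = (lambda n: n == "Identity", lambda n: [])   # stand-in for a no-op removal
print(apply_rule_set(["Erf", "Identity", "Relu"], [rule1, rule2]))
# -> ['Gelu', 'Relu']
```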

In order to apply this method to the example above, first create the two separate target patterns as follows:
@@ -171,7 +171,7 @@ First, write a target pattern and replacement pattern in a similar way to the first example.
```

```{literalinclude} examples/broadcast_matmul.py
-:pyobject: matmul
+:pyobject: matmul_pattern
```

:::{note}
7 changes: 0 additions & 7 deletions onnxscript/optimizer/__init__.py
@@ -6,10 +6,6 @@

from onnxscript import rewriter
from onnxscript.optimizer.constant_folding import fold_constants
-from onnxscript.optimizer.copy_propagation import (
-    do_copy_propagation,
-    do_sequence_simplification,
-)
from onnxscript.optimizer.remove_unused import remove_unused_nodes
from onnxscript.optimizer.remove_unused_function import remove_unused_functions
from onnxscript.optimizer.simple_function_folding import (
@@ -108,14 +104,11 @@ def optimize(
node.name,
)

-    # do_sequence_simplification(model)
return model


__all__ = [
"fold_constants",
"remove_unused_nodes",
"optimize",
-    "do_copy_propagation",
-    "do_sequence_simplification",
]
5 changes: 4 additions & 1 deletion onnxscript/optimizer/constant_folding.py
@@ -263,7 +263,10 @@ def fold_constants(
*,
onnx_shape_inference: bool = False,
) -> bool:
-    """Returns true iff the model was modified."""
+    """
+    Applies constant folding optimization to the model.
+    Returns true iff the model was modified.
+    """
folder = ConstantFolder(
evaluator.registry,
external_data_folder,
81 changes: 0 additions & 81 deletions onnxscript/optimizer/copy_propagation.py

This file was deleted.

49 changes: 0 additions & 49 deletions onnxscript/optimizer/copy_propagation_test.py

This file was deleted.

2 changes: 2 additions & 0 deletions onnxscript/optimizer/simple_function_folding.py
@@ -200,6 +200,7 @@ def function_with_unused_outputs(self) -> dict[ir.FunctionId, onnx.FunctionProto


def inline_simple_functions(model: onnx.ModelProto, node_count: int = 2) -> bool:
+    """Inlines simple functions based on a node count threshold"""
inliner = FunctionInliner(node_count)
inliner.visit_model(model)
logger.info(
@@ -218,6 +219,7 @@ def inline_functions_with_unused_outputs(model: onnx.ModelProto) -> bool


def inline_functions_with_unused_outputs(model: onnx.ModelProto) -> bool:
+    """Inlines function nodes that have unused outputs."""
# TODO: Use onnx.inliner after 1.16.
# This visitor based inliner is used to ensure the function inner value info remains consistent.
visitor = FindFunctionWithUnusedOutputsVisitor()
