From 7e9c9e61ed4ecee03f8f3e2021e7ba2796f80751 Mon Sep 17 00:00:00 2001 From: Shubham Bhokare <32080845+shubhambhokare1@users.noreply.github.com> Date: Thu, 2 May 2024 16:27:33 -0700 Subject: [PATCH] [optimizer][docs] Add Tutorial for Optimizer API (#1482) Add Tutorial for Optimizer API --- docs/conf.py | 1 + docs/index.md | 3 +- docs/optimizer/index.md | 5 ++ docs/optimizer/optimize.md | 53 ++++++++++++ docs/rewriter/rewrite_patterns.md | 20 ++--- onnxscript/optimizer/__init__.py | 7 -- onnxscript/optimizer/constant_folding.py | 5 +- onnxscript/optimizer/copy_propagation.py | 81 ------------------- onnxscript/optimizer/copy_propagation_test.py | 49 ----------- .../optimizer/simple_function_folding.py | 2 + 10 files changed, 77 insertions(+), 149 deletions(-) create mode 100644 docs/optimizer/index.md create mode 100644 docs/optimizer/optimize.md delete mode 100644 onnxscript/optimizer/copy_propagation.py delete mode 100644 onnxscript/optimizer/copy_propagation_test.py diff --git a/docs/conf.py b/docs/conf.py index e981146a6..319b4044e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -26,6 +26,7 @@ "sphinx.ext.ifconfig", "sphinx.ext.viewcode", "sphinx.ext.autodoc", + "sphinx.ext.autosummary", "sphinx.ext.githubpages", "sphinx_gallery.gen_gallery", "sphinx.ext.autodoc", diff --git a/docs/index.md b/docs/index.md index ed1bab77b..856b57e36 100644 --- a/docs/index.md +++ b/docs/index.md @@ -99,8 +99,9 @@ result = MatmulAdd(x, wt, bias) tutorial/index api/index intermediate_representation/index -auto_examples/index +optimizer/index rewriter/index +auto_examples/index articles/index ``` diff --git a/docs/optimizer/index.md b/docs/optimizer/index.md new file mode 100644 index 000000000..f5f66c38c --- /dev/null +++ b/docs/optimizer/index.md @@ -0,0 +1,5 @@ +# Optimizer Tutorials + +```{toctree} +optimize +``` diff --git a/docs/optimizer/optimize.md b/docs/optimizer/optimize.md new file mode 100644 index 000000000..5ceb7dfb8 --- /dev/null +++ b/docs/optimizer/optimize.md @@ 
-0,0 +1,53 @@ +# Optimizing a Model using the Optimizer + +## Introduction + +The ONNX Script `Optimizer` tool provides the user with the functionality to optimize an ONNX model by performing optimizations and clean-ups such as constant folding, dead code elimination, etc. + +## Usage + +In order to utilize the optimizer tool, + +```python +import onnxscript + +onnxscript.optimizer.optimize(model) +``` + +### optimize API +The `onnxscript.optimizer.optimize` call takes in several optional parameters that allow the caller to further fine-tune the process of optimization. + +```{eval-rst} +.. autofunction:: onnxscript.optimizer.optimize + :noindex: +``` + +## Description of optimizations applied by `onnxscript.optimizer.optimize` + +:::{table} +:widths: auto +:align: center + +| Optimization `onnxscript.optimizer.*` | Description | +| - | - | +| **Constant folding**
`constant_folding.fold_constants` | Applies constant folding optimization to the model. | +| **Constant propagation**
`constant_folding.fold_constants` | Applies constant propagation optimization to the model. Applied as part of the constant folding optimization. | +| **Sequence simplification**
`constant_folding.fold_constants` | Simplifies Sequence based ops (SequenceConstruct, ConcatFromSequence) present in the model. Applied as part of the constant folding optimization. | +| **Remove unused nodes**
`remove_unused.remove_unused_nodes` | Removes unused nodes from the model. | +| **Remove unused functions**
`remove_unused_function.remove_unused_functions` | Removes unused function protos from the model. | +| **Inline functions with unused outputs**
`simple_function_folding.inline_functions_with_unused_outputs` | Inlines function nodes that have unused outputs. | +| **Inline simple functions**
`simple_function_folding.inline_simple_functions` | Inlines simple functions based on a node count threshold. | +::: + +## List of pattern rewrite rules applied by `onnxscript.optimizer.optimize` + +```{eval-rst} +.. autosummary:: + :nosignatures: + + onnxscript.rewriter.broadcast_to_matmul + onnxscript.rewriter.cast_constant_of_shape + onnxscript.rewriter.gemm_to_matmul_add + onnxscript.rewriter.no_op + +``` diff --git a/docs/rewriter/rewrite_patterns.md b/docs/rewriter/rewrite_patterns.md index 87a5e31af..ba58b2636 100644 --- a/docs/rewriter/rewrite_patterns.md +++ b/docs/rewriter/rewrite_patterns.md @@ -28,10 +28,10 @@ We will show how we can find a subgraph matching this computation and replace it Firstly, include all the rewriter relevant imports. ```python - from onnxscript.rewriter import pattern - from onnxscript import ir +from onnxscript.rewriter import pattern +from onnxscript import ir - _op = pattern.onnxop +_op = pattern.onnxop ``` Then create a target pattern that needs to be replaced using onnxscript operators. @@ -55,10 +55,10 @@ The inputs to the replacement pattern are of type `ir.Value`. For detailed usage For this example, we do not require a `match_condition` so that option is skipped for now. Then the rewrite rule is created using the `RewriteRule` function. ```python - rule = pattern.RewriteRule( - erf_gelu_pattern, # Target Pattern - gelu, # Replacement Pattern - ) +rule = pattern.RewriteRule( + erf_gelu_pattern, # Target Pattern + gelu, # Replacement Pattern +) ``` Now that the rewrite rule has been created, the next step is to apply these pattern-based rewrite rules. The `rewriter.rewrite` call consists of three main components: @@ -117,8 +117,8 @@ Only one of the patterns has been successfully matched and replaced by a `GELU` This method requires creating two separate rules and packing them into either a sequence of `PatternRewriteRule`s or a `RewriteRuleSet`. Creating a `RewriteRuleSet` is the preferable option but either can be used. 
In order to create a `RewriteRuleSet` with multiple rules `rule1` and `rule2` for example: ```python - from onnxscript.rewriter import pattern - rewrite_rule_set = pattern.RewriteRuleSet(rules=[rule1, rule2]) +from onnxscript.rewriter import pattern +rewrite_rule_set = pattern.RewriteRuleSet(rules=[rule1, rule2]) ``` In order to apply this method to the example above, first create the two separate target patterns as follows: @@ -171,7 +171,7 @@ First, write a target pattern and replacement pattern in a similar way to the fi ``` ```{literalinclude} examples/broadcast_matmul.py -:pyobject: matmul +:pyobject: matmul_pattern ``` :::{note} diff --git a/onnxscript/optimizer/__init__.py b/onnxscript/optimizer/__init__.py index f70d4d35e..03c1e748e 100644 --- a/onnxscript/optimizer/__init__.py +++ b/onnxscript/optimizer/__init__.py @@ -6,10 +6,6 @@ from onnxscript import rewriter from onnxscript.optimizer.constant_folding import fold_constants -from onnxscript.optimizer.copy_propagation import ( - do_copy_propagation, - do_sequence_simplification, -) from onnxscript.optimizer.remove_unused import remove_unused_nodes from onnxscript.optimizer.remove_unused_function import remove_unused_functions from onnxscript.optimizer.simple_function_folding import ( @@ -108,7 +104,6 @@ def optimize( node.name, ) - # do_sequence_simplification(model) return model @@ -116,6 +111,4 @@ def optimize( "fold_constants", "remove_unused_nodes", "optimize", - "do_copy_propagation", - "do_sequence_simplification", ] diff --git a/onnxscript/optimizer/constant_folding.py b/onnxscript/optimizer/constant_folding.py index 9a51298c7..283a13fd1 100644 --- a/onnxscript/optimizer/constant_folding.py +++ b/onnxscript/optimizer/constant_folding.py @@ -263,7 +263,10 @@ def fold_constants( *, onnx_shape_inference: bool = False, ) -> bool: - """Returns true iff the model was modified.""" + """ + Applies constant folding optimization to the model. + Returns true iff the model was modified. 
+ """ folder = ConstantFolder( evaluator.registry, external_data_folder, diff --git a/onnxscript/optimizer/copy_propagation.py b/onnxscript/optimizer/copy_propagation.py deleted file mode 100644 index 6a7d4143d..000000000 --- a/onnxscript/optimizer/copy_propagation.py +++ /dev/null @@ -1,81 +0,0 @@ -from __future__ import annotations - -from typing import Any - -import onnx - -import onnxscript.optimizer.remove_unused -from onnxscript._legacy_ir import visitor -from onnxscript.utils.utils import is_onnx_op - - -class CopyPropagator(visitor.ProtoVisitor): - def __init__(self): - super().__init__() - - def visit_node(self, node: onnx.NodeProto) -> None: - super().visit_node(node) - for i in range(len(node.input)): - input = self.get_input(node, i) - if input is not None and input.is_copy(): - node.input[i] = input.symbolic_value # type: ignore[assignment] - - if is_onnx_op(node, "Identity"): - input = self.get_input(node, 0) - output = self.get_output(node, 0) - if input is not None and output is not None: - output.symbolic_value = input.name - - -# TODO: "Z = Identity(x)" where Z is a graph-output cannot be handled by this optimization, -# and requires some extension. (Eg., we could rename graph-output to be Z or we can try to -# rename x to be Z.) 
- - -def get_node_attr_value(node: onnx.NodeProto, attr_name: str, default: Any) -> Any: - matching = [x for x in node.attribute if x.name == attr_name] - if len(matching) > 1: - raise ValueError(f"Node has multiple attributes with name {attr_name}") - if len(matching) < 1: - return default - return onnx.helper.get_attribute_value(matching[0]) - - -class SymbolicEvaluator(CopyPropagator): - def __init__(self): - super().__init__() - - def visit_node(self, node: onnx.NodeProto) -> None: - super().visit_node(node) - - if is_onnx_op(node, "SequenceConstruct"): - output = self.get_output(node, 0) - if output is not None: - output.symbolic_value = list(node.input) - - if is_onnx_op(node, "ConcatFromSequence"): - input = self.get_input(node, 0) - new_axis = get_node_attr_value(node, "new_axis", 0) - if input is not None and isinstance(input.symbolic_value, list) and new_axis == 0: - node.op_type = "Concat" - node.input[:] = input.symbolic_value - for i in range(len(node.attribute)): - if node.attribute[i].name == "new_axis": - del node.attribute[i] - break - - # TODO: handle SequenceEmpty, SequenceAt, etc. 
- - -def do_copy_propagation(model: onnx.ModelProto, *, remove_unused: bool = True) -> None: - transformer = CopyPropagator() - transformer.visit_model(model) - if remove_unused: - onnxscript.optimizer.remove_unused_nodes(model) - - -def do_sequence_simplification(model: onnx.ModelProto, *, remove_unused: bool = True) -> None: - transformer = SymbolicEvaluator() - transformer.visit_model(model) - if remove_unused: - onnxscript.optimizer.remove_unused_nodes(model) diff --git a/onnxscript/optimizer/copy_propagation_test.py b/onnxscript/optimizer/copy_propagation_test.py deleted file mode 100644 index 6b88b027a..000000000 --- a/onnxscript/optimizer/copy_propagation_test.py +++ /dev/null @@ -1,49 +0,0 @@ -import unittest - -import onnx - -from onnxscript import optimizer - - -class RemoveUnusedTest(unittest.TestCase): - def test_simple_identity_removal(self): - model = onnx.parser.parse_model( - """ - - agraph (float[N] x) => (float[N] z) { - t = Identity(x) - t2 = Identity(t) - z = Identity(t2) - } - """ - ) - optimizer.do_copy_propagation(model) - self.assertEqual(len(model.graph.node), 1) - - def test_subgraph_identity_removal(self): - model = onnx.parser.parse_model( - """ - - agraph (float[N] x, bool cond) => (float[N] z) { - t = Identity(x) - t2 = Identity(t) - t3 = If (cond) < - then_branch = then_graph() => (t4) { - t5 = Identity(t2) - t4 = Identity(t5) - }, - else_branch = else__graph() => (t6) { - t7 = Identity(t) - t6 = Identity(t7) - } - > - z = Identity(t3) - } - """ - ) - optimizer.do_copy_propagation(model) - self.assertEqual(len(model.graph.node), 2) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/optimizer/simple_function_folding.py b/onnxscript/optimizer/simple_function_folding.py index b15a9c8a0..8b6f6662b 100644 --- a/onnxscript/optimizer/simple_function_folding.py +++ b/onnxscript/optimizer/simple_function_folding.py @@ -200,6 +200,7 @@ def function_with_unused_outputs(self) -> dict[ir.FunctionId, onnx.FunctionProto def 
inline_simple_functions(model: onnx.ModelProto, node_count: int = 2) -> bool: + """Inlines simple functions based on a node count threshold.""" inliner = FunctionInliner(node_count) inliner.visit_model(model) logger.info( @@ -218,6 +219,7 @@ def inline_simple_functions(model: onnx.ModelProto, node_count: int = 2) -> bool def inline_functions_with_unused_outputs(model: onnx.ModelProto) -> bool: + """Inlines function nodes that have unused outputs.""" # TODO: Use onnx.inliner after 1.16. # This visitor based inliner is used to ensure the function inner value info remains consistent. visitor = FindFunctionWithUnusedOutputsVisitor()