diff --git a/docs/optimizer/optimize.md b/docs/optimizer/optimize.md index 68b342abd6..ceb846b4a1 100644 --- a/docs/optimizer/optimize.md +++ b/docs/optimizer/optimize.md @@ -38,7 +38,7 @@ The `onnxscript.optimizer.optimize` call takes in several optional parameters th ``` -## Description of pattern rewrite rules applied by `onnxscript.optimizer.optimize` +## List of pattern rewrite rules applied by `onnxscript.optimizer.optimize` ```{eval-rst} .. autosummary:: diff --git a/onnxscript/optimizer/constant_folding.py b/onnxscript/optimizer/constant_folding.py index 9a51298c7f..283a13fd13 100644 --- a/onnxscript/optimizer/constant_folding.py +++ b/onnxscript/optimizer/constant_folding.py @@ -263,7 +263,10 @@ def fold_constants( *, onnx_shape_inference: bool = False, ) -> bool: - """Returns true iff the model was modified.""" + """ + Applies constant folding optimization to the model. + Returns true iff the model was modified. + """ folder = ConstantFolder( evaluator.registry, external_data_folder, diff --git a/onnxscript/optimizer/copy_propagation.py b/onnxscript/optimizer/copy_propagation.py index 6a7d4143d4..f7d006b045 100644 --- a/onnxscript/optimizer/copy_propagation.py +++ b/onnxscript/optimizer/copy_propagation.py @@ -68,6 +68,7 @@ def visit_node(self, node: onnx.NodeProto) -> None: def do_copy_propagation(model: onnx.ModelProto, *, remove_unused: bool = True) -> None: + """Applies copy propagation optimization to the model.""" transformer = CopyPropagator() transformer.visit_model(model) if remove_unused: @@ -75,6 +76,7 @@ def do_copy_propagation(model: onnx.ModelProto, *, remove_unused: bool = True) - def do_sequence_simplification(model: onnx.ModelProto, *, remove_unused: bool = True) -> None: + """Simplifies Sequence based ops (SequenceConstruct, ConcatFromSequence) present in the model.""" transformer = SymbolicEvaluator() transformer.visit_model(model) if remove_unused: diff --git a/onnxscript/optimizer/simple_function_folding.py b/onnxscript/optimizer/simple_function_folding.py index b15a9c8a0d..8b6f6662b0 100644 --- a/onnxscript/optimizer/simple_function_folding.py +++ b/onnxscript/optimizer/simple_function_folding.py @@ -200,6 +200,7 @@ def function_with_unused_outputs(self) -> dict[ir.FunctionId, onnx.FunctionProto def inline_simple_functions(model: onnx.ModelProto, node_count: int = 2) -> bool: + """Inlines simple functions based on a node count threshold""" inliner = FunctionInliner(node_count) inliner.visit_model(model) logger.info( @@ -218,6 +219,7 @@ def inline_simple_functions(model: onnx.ModelProto, node_count: int = 2) -> bool def inline_functions_with_unused_outputs(model: onnx.ModelProto) -> bool: + """Inlines function nodes that have unused outputs.""" # TODO: Use onnx.inliner after 1.16. # This visitor based inliner is used to ensure the function inner value info remains consistent. visitor = FindFunctionWithUnusedOutputsVisitor()