From d62046637ff2db1064dfd4df30465b1eb9e238ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 17 Jun 2024 22:02:08 +0200 Subject: [PATCH] Refactors onnxscript.rewriter.onnxruntime.rewrite to call onnx.rewriter.rewrite (#1628) Signed-off-by: Xavier Dupre --- docs/api/tools.md | 4 ++-- onnxscript/_legacy_ir/visitor.py | 2 ++ onnxscript/optimizer/remove_unused_ir.py | 6 ++++-- onnxscript/rewriter/__init__.py | 4 ++-- onnxscript/rewriter/onnxruntime/__init__.py | 21 ++++----------------- 5 files changed, 14 insertions(+), 23 deletions(-) diff --git a/docs/api/tools.md b/docs/api/tools.md index d67074664..459e6ac54 100644 --- a/docs/api/tools.md +++ b/docs/api/tools.md @@ -7,9 +7,9 @@ ``` ```{eval-rst} -.. autofunction:: onnxscript.tools.transformers_models.phi.get_phi_model_config +.. autofunction:: onnxscript.tools.transformers_models.phi.get_phi_model_from_config ``` ```{eval-rst} -.. autofunction:: onnxscript.tools.transformers_models.llama.get_llama_model_config +.. autofunction:: onnxscript.tools.transformers_models.llama.get_llama_model_from_config ``` diff --git a/onnxscript/_legacy_ir/visitor.py b/onnxscript/_legacy_ir/visitor.py index 2a7257451..8dcc3893a 100644 --- a/onnxscript/_legacy_ir/visitor.py +++ b/onnxscript/_legacy_ir/visitor.py @@ -590,6 +590,8 @@ def get_constant_value(i: int) -> onnx.TensorProto | None: ) for output in node.output: + if output == "": + continue info = self.lookup_or_create(output) if output in output_types: if info.type is not None: diff --git a/onnxscript/optimizer/remove_unused_ir.py b/onnxscript/optimizer/remove_unused_ir.py index 217206787..8a8b0b713 100644 --- a/onnxscript/optimizer/remove_unused_ir.py +++ b/onnxscript/optimizer/remove_unused_ir.py @@ -32,8 +32,10 @@ def is_used_output(i: int) -> bool: if is_used_output(1) or is_used_output(2): return - node.outputs[1].name = "" - node.outputs[2].name = "" + if len(node.outputs) > 1: + node.outputs[1].name = "" + if len(node.outputs) > 2: + node.outputs[2].name = "" node.attributes.pop("training_mode", None) return diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index f6eb0d793..3eac373d6 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -40,7 +40,7 @@ def rewrite( pattern_rewrite_rules = pattern.RewriteRuleSet(pattern_rewrite_rules) count = pattern_rewrite_rules.apply_to_model(model_ir) print(f"Applied {count} of general pattern rewrite rules.") + remove_unused.remove_unused_nodes(model_ir) + model_ir = remove_unused_function.remove_unused_functions(model_ir) model = ir.serde.serialize_model(model_ir) - remove_unused.remove_unused_nodes(model) - model = remove_unused_function.remove_unused_functions(model) return model diff --git a/onnxscript/rewriter/onnxruntime/__init__.py b/onnxscript/rewriter/onnxruntime/__init__.py index 1b61e29a8..f76dd680c 100644 --- a/onnxscript/rewriter/onnxruntime/__init__.py +++ b/onnxscript/rewriter/onnxruntime/__init__.py @@ -4,9 +4,8 @@ import onnx -from onnxscript import ir -from onnxscript.optimizer import remove_unused, remove_unused_function from onnxscript.rewriter import function_rule, pattern +from onnxscript.rewriter import rewrite as _rewrite from onnxscript.rewriter.onnxruntime import ( group_normalization_merge_silu, instance_to_group_normalization, @@ -44,18 +43,6 @@ def rewrite( """ function_rules = function_rules or ORT_FUNCTION_REWRITE_RULES pattern_rules = pattern_rules or ORT_PATTERN_REWRITE_RULES - model = ir.serde.deserialize_model(model_proto) - # TODO(bowenbao): Function rules first, or pattern rules first? - if function_rules: - for rule_cls in function_rules: - count, model = rule_cls().apply_to_model(model) - if count > 0: - print(f"Applied {count} of onnxruntime specific function rewrite rules.") - if pattern_rules: - count = pattern.RewriteRuleSet(pattern_rules).apply_to_model(model) - print(f"Applied {count} of onnxruntime specific pattern rewrite rules.") - - model_proto = ir.serde.serialize_model(model) - remove_unused.remove_unused_nodes(model_proto) - model_proto = remove_unused_function.remove_unused_functions(model_proto) - return model_proto + return _rewrite( + model_proto, function_rewrite_rules=function_rules, pattern_rewrite_rules=pattern_rules + )