Skip to content

Commit

Permalink
Refactors onnxscript.rewriter.onnxruntime.rewrite to call onnx.rewrit…
Browse files Browse the repository at this point in the history
…er.rewrite (#1628)

Signed-off-by: Xavier Dupre <[email protected]>
  • Loading branch information
xadupre authored Jun 17, 2024
1 parent dc31a6e commit d620466
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 23 deletions.
4 changes: 2 additions & 2 deletions docs/api/tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
```

```{eval-rst}
.. autofunction:: onnxscript.tools.transformers_models.phi.get_phi_model_config
.. autofunction:: onnxscript.tools.transformers_models.phi.get_phi_model_from_config
```

```{eval-rst}
.. autofunction:: onnxscript.tools.transformers_models.llama.get_llama_model_config
.. autofunction:: onnxscript.tools.transformers_models.llama.get_llama_model_from_config
```
2 changes: 2 additions & 0 deletions onnxscript/_legacy_ir/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,8 @@ def get_constant_value(i: int) -> onnx.TensorProto | None:
)

for output in node.output:
if output == "":
continue
info = self.lookup_or_create(output)
if output in output_types:
if info.type is not None:
Expand Down
6 changes: 4 additions & 2 deletions onnxscript/optimizer/remove_unused_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ def is_used_output(i: int) -> bool:

if is_used_output(1) or is_used_output(2):
return
node.outputs[1].name = ""
node.outputs[2].name = ""
if len(node.outputs) > 1:
node.outputs[1].name = ""
if len(node.outputs) > 2:
node.outputs[2].name = ""
node.attributes.pop("training_mode", None)
return

Expand Down
4 changes: 2 additions & 2 deletions onnxscript/rewriter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def rewrite(
pattern_rewrite_rules = pattern.RewriteRuleSet(pattern_rewrite_rules)
count = pattern_rewrite_rules.apply_to_model(model_ir)
print(f"Applied {count} of general pattern rewrite rules.")
remove_unused.remove_unused_nodes(model_ir)
model_ir = remove_unused_function.remove_unused_functions(model_ir)
model = ir.serde.serialize_model(model_ir)
remove_unused.remove_unused_nodes(model)
model = remove_unused_function.remove_unused_functions(model)
return model
21 changes: 4 additions & 17 deletions onnxscript/rewriter/onnxruntime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@

import onnx

from onnxscript import ir
from onnxscript.optimizer import remove_unused, remove_unused_function
from onnxscript.rewriter import function_rule, pattern
from onnxscript.rewriter import rewrite as _rewrite
from onnxscript.rewriter.onnxruntime import (
group_normalization_merge_silu,
instance_to_group_normalization,
Expand Down Expand Up @@ -44,18 +43,6 @@ def rewrite(
"""
function_rules = function_rules or ORT_FUNCTION_REWRITE_RULES
pattern_rules = pattern_rules or ORT_PATTERN_REWRITE_RULES
model = ir.serde.deserialize_model(model_proto)
# TODO(bowenbao): Function rules first, or pattern rules first?
if function_rules:
for rule_cls in function_rules:
count, model = rule_cls().apply_to_model(model)
if count > 0:
print(f"Applied {count} of onnxruntime specific function rewrite rules.")
if pattern_rules:
count = pattern.RewriteRuleSet(pattern_rules).apply_to_model(model)
print(f"Applied {count} of onnxruntime specific pattern rewrite rules.")

model_proto = ir.serde.serialize_model(model)
remove_unused.remove_unused_nodes(model_proto)
model_proto = remove_unused_function.remove_unused_functions(model_proto)
return model_proto
return _rewrite(
model_proto, function_rewrite_rules=function_rules, pattern_rewrite_rules=pattern_rules
)

0 comments on commit d620466

Please sign in to comment.