Skip to content

Commit

Permalink
Refactoring onnxscript.rewriter.onnxruntime.rewrite
Browse files Browse the repository at this point in the history
Signed-off-by: Xavier Dupre <[email protected]>
  • Loading branch information
xadupre committed Jun 17, 2024
1 parent dc31a6e commit d30fde0
Showing 1 changed file with 4 additions and 17 deletions.
21 changes: 4 additions & 17 deletions onnxscript/rewriter/onnxruntime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@

import onnx

from onnxscript import ir
from onnxscript.optimizer import remove_unused, remove_unused_function
from onnxscript.rewriter import function_rule, pattern
from onnxscript.rewriter import rewrite as _rewrite
from onnxscript.rewriter.onnxruntime import (
group_normalization_merge_silu,
instance_to_group_normalization,
Expand Down Expand Up @@ -44,18 +43,6 @@ def rewrite(
"""
function_rules = function_rules or ORT_FUNCTION_REWRITE_RULES
pattern_rules = pattern_rules or ORT_PATTERN_REWRITE_RULES
model = ir.serde.deserialize_model(model_proto)
# TODO(bowenbao): Function rules first, or pattern rules first?
if function_rules:
for rule_cls in function_rules:
count, model = rule_cls().apply_to_model(model)
if count > 0:
print(f"Applied {count} of onnxruntime specific function rewrite rules.")
if pattern_rules:
count = pattern.RewriteRuleSet(pattern_rules).apply_to_model(model)
print(f"Applied {count} of onnxruntime specific pattern rewrite rules.")

model_proto = ir.serde.serialize_model(model)
remove_unused.remove_unused_nodes(model_proto)
model_proto = remove_unused_function.remove_unused_functions(model_proto)
return model_proto
return _rewrite(
model_proto, function_rewrite_rules=function_rules, pattern_rewrite_rules=pattern_rules
)

0 comments on commit d30fde0

Please sign in to comment.