diff --git a/onnxscript/rewriter/onnxruntime/__init__.py b/onnxscript/rewriter/onnxruntime/__init__.py index 1b61e29a8..f76dd680c 100644 --- a/onnxscript/rewriter/onnxruntime/__init__.py +++ b/onnxscript/rewriter/onnxruntime/__init__.py @@ -4,9 +4,8 @@ import onnx -from onnxscript import ir -from onnxscript.optimizer import remove_unused, remove_unused_function from onnxscript.rewriter import function_rule, pattern +from onnxscript.rewriter import rewrite as _rewrite from onnxscript.rewriter.onnxruntime import ( group_normalization_merge_silu, instance_to_group_normalization, @@ -44,18 +43,6 @@ def rewrite( """ function_rules = function_rules or ORT_FUNCTION_REWRITE_RULES pattern_rules = pattern_rules or ORT_PATTERN_REWRITE_RULES - model = ir.serde.deserialize_model(model_proto) - # TODO(bowenbao): Function rules first, or pattern rules first? - if function_rules: - for rule_cls in function_rules: - count, model = rule_cls().apply_to_model(model) - if count > 0: - print(f"Applied {count} of onnxruntime specific function rewrite rules.") - if pattern_rules: - count = pattern.RewriteRuleSet(pattern_rules).apply_to_model(model) - print(f"Applied {count} of onnxruntime specific pattern rewrite rules.") - - model_proto = ir.serde.serialize_model(model) - remove_unused.remove_unused_nodes(model_proto) - model_proto = remove_unused_function.remove_unused_functions(model_proto) - return model_proto + return _rewrite( + model_proto, function_rewrite_rules=function_rules, pattern_rewrite_rules=pattern_rules + )