Skip to content

Commit

Permalink
[rewriter] Modify rewriter API to accept RewriteRuleSet (#1456)
Browse files Browse the repository at this point in the history
Modify rewrite API to accept RewriteRuleSet as well as a Sequence of pattern.RewriteRule
  • Loading branch information
shubhambhokare1 authored Apr 25, 2024
1 parent b883020 commit eb7c907
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions onnxscript/rewriter/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Sequence
from typing import Sequence, Union

__all__ = [
# Modules
Expand All @@ -16,22 +16,26 @@
from onnxscript.optimizer import remove_unused, remove_unused_function
from onnxscript.rewriter import function_rule, pattern

RewriteRuleSet = pattern.RewriteRuleSet
PatternRewriteRule = pattern.RewriteRule
FunctionRewriteRule = function_rule.FunctionRewriteRule


def rewrite(
model: onnx.ModelProto,
function_rewrite_rules: Sequence[type[FunctionRewriteRule]] = (),
pattern_rewrite_rules: Sequence[PatternRewriteRule] = (),
pattern_rewrite_rules: Union[Sequence[PatternRewriteRule], RewriteRuleSet] = (),
) -> onnx.ModelProto:
model_ir = ir.serde.deserialize_model(model)
if function_rewrite_rules:
for rule_cls in function_rewrite_rules:
count, model_ir = rule_cls().apply_to_model(model_ir)
print(f"Applied {count} of onnxruntime specific function rewrite rules.")
if pattern_rewrite_rules:
count = pattern.RewriteRuleSet(pattern_rewrite_rules).apply_to_model(model_ir)
if not isinstance(pattern_rewrite_rules, RewriteRuleSet):
# Create a pattern rule-set using provided rules
pattern_rewrite_rules = pattern.RewriteRuleSet(pattern_rewrite_rules)
count = pattern_rewrite_rules.apply_to_model(model_ir)
print(f"Applied {count} of general pattern rewrite rules.")
model = ir.serde.serialize_model(model_ir)
remove_unused.remove_unused_nodes(model)
Expand Down

0 comments on commit eb7c907

Please sign in to comment.