diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index ea5c823a9..7dc784650 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Sequence +from typing import Sequence, Union __all__ = [ # Modules @@ -16,6 +16,7 @@ from onnxscript.optimizer import remove_unused, remove_unused_function from onnxscript.rewriter import function_rule, pattern +RewriteRuleSet = pattern.RewriteRuleSet PatternRewriteRule = pattern.RewriteRule FunctionRewriteRule = function_rule.FunctionRewriteRule @@ -23,7 +24,7 @@ def rewrite( model: onnx.ModelProto, function_rewrite_rules: Sequence[type[FunctionRewriteRule]] = (), - pattern_rewrite_rules: Sequence[PatternRewriteRule] = (), + pattern_rewrite_rules: Union[Sequence[PatternRewriteRule], RewriteRuleSet] = (), ) -> onnx.ModelProto: model_ir = ir.serde.deserialize_model(model) if function_rewrite_rules: @@ -31,7 +32,10 @@ def rewrite( count, model_ir = rule_cls().apply_to_model(model_ir) print(f"Applied {count} of onnxruntime specific function rewrite rules.") if pattern_rewrite_rules: - count = pattern.RewriteRuleSet(pattern_rewrite_rules).apply_to_model(model_ir) + if not isinstance(pattern_rewrite_rules, RewriteRuleSet): + # Create a pattern rule-set using provided rules + pattern_rewrite_rules = pattern.RewriteRuleSet(pattern_rewrite_rules) + count = pattern_rewrite_rules.apply_to_model(model_ir) print(f"Applied {count} of general pattern rewrite rules.") model = ir.serde.serialize_model(model_ir) remove_unused.remove_unused_nodes(model)