From 98df760fa1ac8effa0c4c4aae83f923be5497dbc Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Tue, 30 Apr 2024 22:18:22 -0700 Subject: [PATCH] RewriteRule constructor parameter --- onnxscript/rewriter/pattern.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index d82af2f10..1be2b4a60 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -705,7 +705,7 @@ def _update_opset_imports( class RewriteRule: def __init__( self, - target_pattern: Callable | None = None, + target_pattern: GraphPattern | Callable | None = None, replacement_pattern: ReplacementPatternFunction | Callable | None = None, condition_function: Callable | None = None, ) -> None: @@ -731,15 +731,15 @@ def __init__( "replacement_pattern must be provided if target_pattern is provided" ) - if callable(replacement_pattern): - replacement_pattern = ReplacementPatternFunction(replacement_pattern) + if not isinstance(target_pattern, GraphPattern): + target_pattern = _to_graph_pattern(target_pattern) + self._target_pattern = target_pattern + if not isinstance(replacement_pattern, ReplacementPatternFunction): + replacement_pattern = ReplacementPatternFunction(replacement_pattern) self._replacement_pattern = replacement_pattern self._condition_function = condition_function - # Get the last node pattern and number of outputs from the pattern function - self._target_pattern = _to_graph_pattern(target_pattern) - def matches(self, node: ir.Node, model: ir.Model) -> MatchResult: """Check if the node from IR matches the pattern.""" if len(node.outputs) != self._target_pattern.num_outputs: