Skip to content

Commit

Permalink
RewriteRule constructor parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
gramalingam committed May 1, 2024
1 parent dfaf685 commit 98df760
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions onnxscript/rewriter/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,7 +705,7 @@ def _update_opset_imports(
class RewriteRule:
def __init__(
self,
target_pattern: Callable | None = None,
target_pattern: GraphPattern | Callable | None = None,
replacement_pattern: ReplacementPatternFunction | Callable | None = None,
condition_function: Callable | None = None,
) -> None:
Expand All @@ -731,15 +731,15 @@ def __init__(
"replacement_pattern must be provided if target_pattern is provided"
)

if callable(replacement_pattern):
replacement_pattern = ReplacementPatternFunction(replacement_pattern)
if not isinstance(target_pattern, GraphPattern):
target_pattern = _to_graph_pattern(target_pattern)
self._target_pattern = target_pattern

if not isinstance(replacement_pattern, ReplacementPatternFunction):
replacement_pattern = ReplacementPatternFunction(replacement_pattern)
self._replacement_pattern = replacement_pattern
self._condition_function = condition_function

# Get the last node pattern and number of outputs from the pattern function
self._target_pattern = _to_graph_pattern(target_pattern)

def matches(self, node: ir.Node, model: ir.Model) -> MatchResult:
"""Check if the node from IR matches the pattern."""
if len(node.outputs) != self._target_pattern.num_outputs:
Expand Down

0 comments on commit 98df760

Please sign in to comment.