diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index be265963c..1f00840d4 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -959,6 +959,12 @@ def _match_node(self, pattern_node: NodePattern, node: ir.Node) -> bool: self._matched[pattern_node] = node + # TODO: Revisit this to handle optional trailing inputs better. + if len(node.inputs) != len(pattern_node.inputs): + return self.fail( + "Input nums mismatch. {len(node.inputs)} vs {len(pattern_node.inputs)}" + ) + for arg_value, arg_pattern in zip(node.inputs, pattern_node.inputs): # arg_pattern could be a Var, if it's the original arg. if arg_pattern is None: