From 766791dbf117d3538eab3ba1583473f8811a26f6 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 23 Dec 2024 11:10:11 -0800 Subject: [PATCH] Fix lint issues --- onnxscript/rewriter/_ir_utils.py | 2 +- onnxscript/rewriter/pattern.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxscript/rewriter/_ir_utils.py b/onnxscript/rewriter/_ir_utils.py index cefbec823..1d657a5ab 100644 --- a/onnxscript/rewriter/_ir_utils.py +++ b/onnxscript/rewriter/_ir_utils.py @@ -95,7 +95,7 @@ def is_singleton_value( return expected == scalar # rtol must be specified for float comparison assert rtol is not None - return math.isclose(scalar, expected, rtol=rtol) + return math.isclose(scalar, expected, rel_tol=rtol) def has_rank(value: ir.Value | None, rank: int) -> bool: diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index eec45ab02..fa43c19d5 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -1178,6 +1178,8 @@ def _multi_match(self, candidate: Iterable[ir.Node], check_removable: bool) -> M Args: candidate: An iterable of nodes that will be matched against the pattern output nodes. + check_removable: If True, check that the matched nodes can be removed (that is, that + they are not used elsewhere in the graph). """ match = self._match for pattern_node, node in zip(self.pattern.output_nodes, candidate):