From 816bfc08c1ce0ec61827d1163daab7ff7965ee3f Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 17 Nov 2023 16:31:16 -0800 Subject: [PATCH] Undo isomorphism test change Signed-off-by: Ganesan Ramalingam --- onnxscript/testing.py | 24 ++++++------------------ 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/onnxscript/testing.py b/onnxscript/testing.py index afcc08042..eb7203ec0 100644 --- a/onnxscript/testing.py +++ b/onnxscript/testing.py @@ -117,12 +117,9 @@ def _same_value_info(vi1, vi2): ) -def _same_attr(attr1, attr2, graph_equality, ref_equality): +def _same_attr(attr1, attr2, graph_equality): # no name check; names used to match attributes already. - if not _same_optional("ref_attr_name", attr1, attr2, ref_equality): - return False - - for field in ["type", "f", "i", "s"]: + for field in ["type", "ref_attr_name", "f", "i", "s"]: if not _same_optional(field, attr1, attr2): return False @@ -149,7 +146,7 @@ def _same_attr(attr1, attr2, graph_equality, ref_equality): return True -def _same_attrs(attrs1, attrs2, graph_equality, ref_equality): +def _same_attrs(attrs1, attrs2, graph_equality): if len(attrs1) != len(attrs2): return False attrs1map = {a.name: a for a in attrs1} @@ -157,7 +154,7 @@ def _same_attrs(attrs1, attrs2, graph_equality, ref_equality): if attr2.name not in attrs1map: return False attr1 = attrs1map[attr2.name] - if not _same_attr(attr1, attr2, graph_equality, ref_equality): + if not _same_attr(attr1, attr2, graph_equality): return False return True @@ -191,15 +188,6 @@ def defmap(f): self.node_mapping: dict[onnx.NodeProto, onnx.NodeProto] = {} self.outer_scope = outer_scope - def same_ref_attr(self, ref_attr_name1, ref_attr_name2) -> bool: - def find(fun, name): - for i, attr_name in enumerate(fun.attribute): - if attr_name == name: - return i - # TODO: handle attribute_protos - return None - return find(self.fg1, ref_attr_name1) == find(self.fg2, ref_attr_name2) - def same_value(self, var1, var2): """Match two variables (strings).""" if var1 not in self.defmap1 or var2 not in self.defmap2: @@ -227,7 +215,7 @@ def same_node(self, n1, n2): if node1.domain != node2.domain: return False # check attrs - if not _same_attrs(node1.attribute, node2.attribute, self.same_sub_graph, self.same_ref_attr): + if not _same_attrs(node1.attribute, node2.attribute, self.same_sub_graph): return False if not self.same_value_list(node1.input, node2.input): return False @@ -273,7 +261,7 @@ def same_function(self): if len(self.fg1.input) != len(self.fg2.input): return False - if len(self.fg1.attribute) != len(self.fg2.attribute): + if set(self.fg1.attribute) != set(self.fg2.attribute): return False # Opset imports must be same (but possibly in different order):