Skip to content

Commit

Permalink
Undo isomorphism test change
Browse files Browse the repository at this point in the history
Signed-off-by: Ganesan Ramalingam <[email protected]>
  • Loading branch information
gramalingam committed Nov 18, 2023
1 parent a127a41 commit 816bfc0
Showing 1 changed file with 6 additions and 18 deletions.
24 changes: 6 additions & 18 deletions onnxscript/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,9 @@ def _same_value_info(vi1, vi2):
)


def _same_attr(attr1, attr2, graph_equality, ref_equality):
def _same_attr(attr1, attr2, graph_equality):
# no name check; names used to match attributes already.
if not _same_optional("ref_attr_name", attr1, attr2, ref_equality):
return False

for field in ["type", "f", "i", "s"]:
for field in ["type", "ref_attr_name", "f", "i", "s"]:
if not _same_optional(field, attr1, attr2):
return False

Expand All @@ -149,15 +146,15 @@ def _same_attr(attr1, attr2, graph_equality, ref_equality):
return True


def _same_attrs(attrs1, attrs2, graph_equality, ref_equality):
def _same_attrs(attrs1, attrs2, graph_equality):
if len(attrs1) != len(attrs2):
return False
attrs1map = {a.name: a for a in attrs1}
for attr2 in attrs2:
if attr2.name not in attrs1map:
return False
attr1 = attrs1map[attr2.name]
if not _same_attr(attr1, attr2, graph_equality, ref_equality):
if not _same_attr(attr1, attr2, graph_equality):
return False
return True

Expand Down Expand Up @@ -191,15 +188,6 @@ def defmap(f):
self.node_mapping: dict[onnx.NodeProto, onnx.NodeProto] = {}
self.outer_scope = outer_scope

def same_ref_attr(self, ref_attr_name1, ref_attr_name2) -> bool:
def find(fun, name):
for i, attr_name in enumerate(fun.attribute):
if attr_name == name:
return i
# TODO: handle attribute_protos
return None
return find(self.fg1, ref_attr_name1) == find(self.fg2, ref_attr_name2)

def same_value(self, var1, var2):
"""Match two variables (strings)."""
if var1 not in self.defmap1 or var2 not in self.defmap2:
Expand Down Expand Up @@ -227,7 +215,7 @@ def same_node(self, n1, n2):
if node1.domain != node2.domain:
return False
# check attrs
if not _same_attrs(node1.attribute, node2.attribute, self.same_sub_graph, self.same_ref_attr):
if not _same_attrs(node1.attribute, node2.attribute, self.same_sub_graph):
return False
if not self.same_value_list(node1.input, node2.input):
return False
Expand Down Expand Up @@ -273,7 +261,7 @@ def same_function(self):

if len(self.fg1.input) != len(self.fg2.input):
return False
if len(self.fg1.attribute) != len(self.fg2.attribute):
if set(self.fg1.attribute) != set(self.fg2.attribute):
return False

# Opset imports must be same (but possibly in different order):
Expand Down

0 comments on commit 816bfc0

Please sign in to comment.