From 4fffcebadf984d9bcd06461905e45dc149a27fec Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 17 Nov 2023 13:21:00 -0800 Subject: [PATCH 1/6] Handle attr-ref name conflict Signed-off-by: Ganesan Ramalingam --- onnxscript/backend/onnx_export.py | 80 ++++++++++++++++++++++---- onnxscript/backend/onnx_export_test.py | 21 ++++++- onnxscript/testing.py | 24 ++++++-- 3 files changed, 107 insertions(+), 18 deletions(-) diff --git a/onnxscript/backend/onnx_export.py b/onnxscript/backend/onnx_export.py index 063ded7f0..fb102f745 100644 --- a/onnxscript/backend/onnx_export.py +++ b/onnxscript/backend/onnx_export.py @@ -25,15 +25,7 @@ {%- endif %} {{translate_opset_imports_of(main_model)}} {% for domain, name, fct in functions: %} -@script({{make_opset_name(domain, 1)}}) -def {{ python_make_node_name(fct['proto'].domain, 1, fct['proto'].name) }}{{ - translate_function_signature(fct['proto'])}} - {% if fct['proto'].doc_string %}""" - {{ fct['proto'].doc_string }} - """{%- endif %} - {%- for node in fct['proto'].node: %} -{{ python_make_node(node, opsets, indent=1) }}{% endfor %} - return {{ ", ".join(map(rename, fct['proto'].output)) }} +{{translate_function(fct["proto"])}} {% endfor %} {% if graph %} @script() @@ -198,6 +190,33 @@ def _attribute_value(attr: onnx.AttributeProto): # - onnx.AttributeProto.TYPE_PROTOS raise NotImplementedError(f"Unable to return a value for attribute {attr!r}.") +def _update_names_used_in_graph(names: set[str], graph: GraphProto) -> None: + """Returns the names used in a graph.""" + names.update(x.name for x in graph.input) + names.update(x.name for x in graph.output) + names.update(x.name for x in graph.initializer) + for node in graph.node: + _update_names_used_in_node(names, node) + +def _update_names_used_in_node(names: set[str], node: onnx.NodeProto) -> None: + names.update(node.input) + names.update(node.output) + for attr in node.attribute: + if attr.HasField("g"): + _update_names_used_in_graph(names, attr.g) + for g in attr.graphs: + _update_names_used_in_graph(names, g) + +def _update_names_used_in_function(names: set[str], fun: FunctionProto) -> None: + names.update(fun.input) + names.update(fun.output) + for node in fun.node: + _update_names_used_in_node(names, node) + +def _names_used_in_function(fun: FunctionProto) -> set[str]: + names = set() + _update_names_used_in_function(names, fun) + return names class Exporter: """Class used for recursive traversal of Proto structures.""" @@ -207,11 +226,25 @@ def __init__(self, rename_function, use_operators=False, inline_const=False) -> self._rename_variable = rename_function self.inline_const = inline_const self.constants: dict[str, str] = {} + self._attr_renaming: dict[str, str] = {} # For current function. + self._names_used: set[str] = set() # For current function. def _rename_variable_s(self, name): """Renames all names equal to a python keyword.""" return str(self._rename_variable(name)) + def _rename_attr_parameter(self, name): + """Renames an attribute parameter.""" + if name in self._attr_renaming: + return self._attr_renaming[name] + counter = 0 + candidate = name + while candidate in self._names_used: + candidate = f"{name}_attr_{counter}" + self._attr_renaming[name] = candidate + self._names_used.add(candidate) + return candidate + def _rename_domain(self, domain: str) -> str: if domain == "": return "opset" @@ -265,7 +298,8 @@ def _python_make_node_make_attribute_str(self, node): attributes = [] for at in node.attribute: if _is_attribute_ref(at): - attributes.append((at.name, at.ref_attr_name)) + ref_attr_name = self._attr_renaming[at.ref_attr_name] + attributes.append((at.name, ref_attr_name)) continue value = _attribute_value(at) if isinstance(value, str): @@ -460,7 +494,7 @@ def translate_function_signature(self, funproto: onnx.FunctionProto) -> str: type_map = _attribute_param_types(funproto) def attr_sig(attr_name: str) -> str: - name = self._rename_variable(attr_name) + name = self._rename_attr_parameter(attr_name) # A default type of INT is used for attribute parameters that are never used. type = type_map.get(attr_name, onnx.AttributeProto.INT) typerep = onnxscript.type_annotation.onnx_attr_type_to_onnxscript_repr(type) @@ -475,6 +509,28 @@ def attr_sig(attr_name: str) -> str: message = "" return f"({input_and_attrs}):{message}" + def translate_function(self, funproto: onnx.FunctionProto) -> str: + """Generate python code for FunctionProto.""" + self._attr_renaming = {} + used_proto_names = _names_used_in_function(funproto) + renamed_names_used = [self._rename_variable(x) for x in used_proto_names] + self._names_used = set(renamed_names_used) + result = [] + def add_line(line: str) -> None: + result.append(line) + opset_name = self.make_opset_name(funproto.domain, 1) + add_line(f"@script({opset_name})") + fun_name = self._python_make_node_name(funproto.domain, 1, funproto.name) + fun_sig = self.translate_function_signature(funproto) + add_line(f"def {fun_name}{fun_sig}") + if funproto.doc_string: + add_line(f' """{funproto.doc_string}"""') + for node in funproto.node: + add_line(self._python_make_node(node, self.opsets, indent=1)) + return_values = ", ".join(self._rename_variable(x) for x in funproto.output) + add_line(f" return {return_values}") + return "\n".join(result) + def _attribute_param_types( funproto: onnx.FunctionProto, @@ -553,6 +609,7 @@ def rename_variable(name): "rename": rename_variable, "translate_sig": _translate_signature, "translate_function_signature": exporter.translate_function_signature, + "translate_function": exporter.translate_function, "translate_opset_imports_of": exporter.translate_opset_imports_of, "hasattr": hasattr, "make_opset_name": exporter.make_opset_name, @@ -564,6 +621,7 @@ def rename_variable(name): for oimp in model_onnx.opset_import: opsets[oimp.domain] = oimp.version context["opsets"] = opsets + exporter.opsets = opsets graph = model_onnx.graph if hasattr(model_onnx, "graph") else model_onnx diff --git a/onnxscript/backend/onnx_export_test.py b/onnxscript/backend/onnx_export_test.py index efaa6c62d..fcba283d8 100644 --- a/onnxscript/backend/onnx_export_test.py +++ b/onnxscript/backend/onnx_export_test.py @@ -103,7 +103,7 @@ def run_function(obj, *inputs): return got -def extract_functions(name: str, content: str, test_folder: pathlib.Path): +def write_functions(name: str, content: str, test_folder: pathlib.Path): """Write the content into a file and import all OnnxFunctions from it.""" if not test_folder.exists(): test_folder.mkdir(exist_ok=True, parents=True) @@ -112,6 +112,8 @@ def extract_functions(name: str, content: str, test_folder: pathlib.Path): file = test_folder / f"{name}.py" file.write_text(content, encoding="utf-8") +def extract_functions(name: str, content: str, test_folder: pathlib.Path): + write_functions(name, content, test_folder) import_name = f"onnxscript.tests.{test_folder.parts[-1]}.{name}" try: mod = importlib.import_module(import_name) @@ -137,6 +139,12 @@ class TestOnnxBackEnd(unittest.TestCase): test_folder = root_folder / "tests" / "onnx_backend_test_code" temp_folder = root_folder / "tests" / "export" + def _proto_to_os_and_back(self, proto: onnxscript.FunctionProto, **export_options): + """Convert a proto to onnxscript code and convert it back to a proto.""" + code = onnx_export.export2python(proto, **export_options) + map = extract_functions(proto.name, code, TestOnnxBackEnd.temp_folder) + return map[proto.name] + def _round_trip_check(self, script_function, **export_options): proto = script_function.to_function_proto() code = onnx_export.export2python(proto, **export_options) @@ -154,6 +162,17 @@ def fun_with_attr_param(X, dtype: int): self._round_trip_check(fun_with_attr_param) + def test_double_attr_val_promotion(self): + op = onnxscript.opset17 + + @onnxscript.script() + def fun_with_double_attr_promotion(X, dtype: int): + Y = op.Add(X, dtype) + Z = op.Add(Y, dtype) + return Z + + self._round_trip_check(fun_with_double_attr_promotion) + def test_qualified_domain(self): """Test use of qualified domain name.""" op = onnxscript.opset17 diff --git a/onnxscript/testing.py b/onnxscript/testing.py index eb7203ec0..afcc08042 100644 --- a/onnxscript/testing.py +++ b/onnxscript/testing.py @@ -117,9 +117,12 @@ def _same_value_info(vi1, vi2): ) -def _same_attr(attr1, attr2, graph_equality): +def _same_attr(attr1, attr2, graph_equality, ref_equality): # no name check; names used to match attributes already. - for field in ["type", "ref_attr_name", "f", "i", "s"]: + if not _same_optional("ref_attr_name", attr1, attr2, ref_equality): + return False + + for field in ["type", "f", "i", "s"]: if not _same_optional(field, attr1, attr2): return False @@ -146,7 +149,7 @@ def _same_attr(attr1, attr2, graph_equality): return True -def _same_attrs(attrs1, attrs2, graph_equality): +def _same_attrs(attrs1, attrs2, graph_equality, ref_equality): if len(attrs1) != len(attrs2): return False attrs1map = {a.name: a for a in attrs1} @@ -154,7 +157,7 @@ def _same_attrs(attrs1, attrs2, graph_equality): if attr2.name not in attrs1map: return False attr1 = attrs1map[attr2.name] - if not _same_attr(attr1, attr2, graph_equality): + if not _same_attr(attr1, attr2, graph_equality, ref_equality): return False return True @@ -188,6 +191,15 @@ def defmap(f): self.node_mapping: dict[onnx.NodeProto, onnx.NodeProto] = {} self.outer_scope = outer_scope + def same_ref_attr(self, ref_attr_name1, ref_attr_name2) -> bool: + def find(fun, name): + for i, attr_name in enumerate(fun.attribute): + if attr_name == name: + return i + # TODO: handle attribute_protos + return None + return find(self.fg1, ref_attr_name1) == find(self.fg2, ref_attr_name2) + def same_value(self, var1, var2): """Match two variables (strings).""" if var1 not in self.defmap1 or var2 not in self.defmap2: @@ -215,7 +227,7 @@ def same_node(self, n1, n2): if node1.domain != node2.domain: return False # check attrs - if not _same_attrs(node1.attribute, node2.attribute, self.same_sub_graph): + if not _same_attrs(node1.attribute, node2.attribute, self.same_sub_graph, self.same_ref_attr): return False if not self.same_value_list(node1.input, node2.input): return False @@ -261,7 +273,7 @@ def same_function(self): if len(self.fg1.input) != len(self.fg2.input): return False - if set(self.fg1.attribute) != set(self.fg2.attribute): + if len(self.fg1.attribute) != len(self.fg2.attribute): return False # Opset imports must be same (but possibly in different order): From a127a4136249c79e14bc200c41dd48b2731f937e Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 17 Nov 2023 16:28:36 -0800 Subject: [PATCH 2/6] Fix name conflict resolution Signed-off-by: Ganesan Ramalingam --- onnxscript/backend/onnx_export.py | 46 ++++++++++++++++++------------- 1 file changed, 27 insertions(+), 19 deletions(-) diff --git a/onnxscript/backend/onnx_export.py b/onnxscript/backend/onnx_export.py index fb102f745..32ae524d5 100644 --- a/onnxscript/backend/onnx_export.py +++ b/onnxscript/backend/onnx_export.py @@ -223,28 +223,36 @@ class Exporter: def __init__(self, rename_function, use_operators=False, inline_const=False) -> None: self.use_operators = use_operators - self._rename_variable = rename_function + self._rename_variable = self._handle_attrname_conflict(rename_function) self.inline_const = inline_const self.constants: dict[str, str] = {} self._attr_renaming: dict[str, str] = {} # For current function. self._names_used: set[str] = set() # For current function. + def _handle_attrname_conflict(self, renamer): + """Add ref-attr-name-conflict handling logic to renaming function.""" + def new_renamer(name): + new_name = renamer(name) + if new_name not in self._attr_renaming: + return new_name + # Name conflicts with attribute parameter name. + alternate = self._attr_renaming[new_name] + if alternate is not None: + return alternate + counter = 0 + candidate = new_name + while candidate in self._names_used: + candidate = f"{new_name}_{counter}" + counter += 1 + self._attr_renaming[new_name] = candidate + self._names_used.add(candidate) + return candidate + return new_renamer + def _rename_variable_s(self, name): """Renames all names equal to a python keyword.""" return str(self._rename_variable(name)) - def _rename_attr_parameter(self, name): - """Renames an attribute parameter.""" - if name in self._attr_renaming: - return self._attr_renaming[name] - counter = 0 - candidate = name - while candidate in self._names_used: - candidate = f"{name}_attr_{counter}" - self._attr_renaming[name] = candidate - self._names_used.add(candidate) - return candidate - def _rename_domain(self, domain: str) -> str: if domain == "": return "opset" @@ -254,7 +262,7 @@ def make_opset_name(self, domain, version): return f"{self._rename_domain(domain)}{version}" def _python_make_node_name(self, domain, version, name, node=False): - name = _rename_variable(name) + name = _rename_variable(name) # TODO: Is this a typo? Is it supposed to be self._rename_variable(name)? if node: if version is None: version = 1 @@ -298,8 +306,7 @@ def _python_make_node_make_attribute_str(self, node): attributes = [] for at in node.attribute: if _is_attribute_ref(at): - ref_attr_name = self._attr_renaming[at.ref_attr_name] - attributes.append((at.name, ref_attr_name)) + attributes.append((at.name, at.ref_attr_name)) continue value = _attribute_value(at) if isinstance(value, str): @@ -494,11 +501,12 @@ def translate_function_signature(self, funproto: onnx.FunctionProto) -> str: type_map = _attribute_param_types(funproto) def attr_sig(attr_name: str) -> str: - name = self._rename_attr_parameter(attr_name) + self._attr_renaming[attr_name] = None + self._names_used.add(attr_name) # A default type of INT is used for attribute parameters that are never used. type = type_map.get(attr_name, onnx.AttributeProto.INT) typerep = onnxscript.type_annotation.onnx_attr_type_to_onnxscript_repr(type) - return f"{name}: {typerep}" + return f"{attr_name}: {typerep}" inputs = [self._rename_variable(x) for x in funproto.input] attrs = [attr_sig(x) for x in funproto.attribute] @@ -589,7 +597,7 @@ def rename_variable(name): if var_name in variable_names: return variable_names[var_name] new_name = f"v{len(variable_names) + 1}" - assert var_name is not None + assert var_name is not None # TODO(rama): This looks suspect. variable_names[var_name] = new_name return new_name From 816bfc08c1ce0ec61827d1163daab7ff7965ee3f Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 17 Nov 2023 16:31:16 -0800 Subject: [PATCH 3/6] Undo isomorphism test change Signed-off-by: Ganesan Ramalingam --- onnxscript/testing.py | 24 ++++++------------------ 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/onnxscript/testing.py b/onnxscript/testing.py index afcc08042..eb7203ec0 100644 --- a/onnxscript/testing.py +++ b/onnxscript/testing.py @@ -117,12 +117,9 @@ def _same_value_info(vi1, vi2): ) -def _same_attr(attr1, attr2, graph_equality, ref_equality): +def _same_attr(attr1, attr2, graph_equality): # no name check; names used to match attributes already. - if not _same_optional("ref_attr_name", attr1, attr2, ref_equality): - return False - - for field in ["type", "f", "i", "s"]: + for field in ["type", "ref_attr_name", "f", "i", "s"]: if not _same_optional(field, attr1, attr2): return False @@ -149,7 +146,7 @@ def _same_attr(attr1, attr2, graph_equality, ref_equality): return True -def _same_attrs(attrs1, attrs2, graph_equality, ref_equality): +def _same_attrs(attrs1, attrs2, graph_equality): if len(attrs1) != len(attrs2): return False attrs1map = {a.name: a for a in attrs1} @@ -157,7 +154,7 @@ def _same_attrs(attrs1, attrs2, graph_equality, ref_equality): if attr2.name not in attrs1map: return False attr1 = attrs1map[attr2.name] - if not _same_attr(attr1, attr2, graph_equality, ref_equality): + if not _same_attr(attr1, attr2, graph_equality): return False return True @@ -191,15 +188,6 @@ def defmap(f): self.node_mapping: dict[onnx.NodeProto, onnx.NodeProto] = {} self.outer_scope = outer_scope - def same_ref_attr(self, ref_attr_name1, ref_attr_name2) -> bool: - def find(fun, name): - for i, attr_name in enumerate(fun.attribute): - if attr_name == name: - return i - # TODO: handle attribute_protos - return None - return find(self.fg1, ref_attr_name1) == find(self.fg2, ref_attr_name2) - def same_value(self, var1, var2): """Match two variables (strings).""" if var1 not in self.defmap1 or var2 not in self.defmap2: @@ -227,7 +215,7 @@ def same_node(self, n1, n2): if node1.domain != node2.domain: return False # check attrs - if not _same_attrs(node1.attribute, node2.attribute, self.same_sub_graph, self.same_ref_attr): + if not _same_attrs(node1.attribute, node2.attribute, self.same_sub_graph): return False if not self.same_value_list(node1.input, node2.input): return False @@ -273,7 +261,7 @@ def same_function(self): if len(self.fg1.input) != len(self.fg2.input): return False - if len(self.fg1.attribute) != len(self.fg2.attribute): + if set(self.fg1.attribute) != set(self.fg2.attribute): return False # Opset imports must be same (but possibly in different order): From b673b480a529b8589c6f23ff1aec10c694376c2f Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 17 Nov 2023 17:18:47 -0800 Subject: [PATCH 4/6] address lint --- onnxscript/backend/onnx_export_test.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/onnxscript/backend/onnx_export_test.py b/onnxscript/backend/onnx_export_test.py index fcba283d8..79f66bad2 100644 --- a/onnxscript/backend/onnx_export_test.py +++ b/onnxscript/backend/onnx_export_test.py @@ -103,17 +103,13 @@ def run_function(obj, *inputs): return got -def write_functions(name: str, content: str, test_folder: pathlib.Path): - """Write the content into a file and import all OnnxFunctions from it.""" +def extract_functions(name: str, content: str, test_folder: pathlib.Path): if not test_folder.exists(): test_folder.mkdir(exist_ok=True, parents=True) init = test_folder / "__init__.py" init.touch(exist_ok=True) file = test_folder / f"{name}.py" file.write_text(content, encoding="utf-8") - -def extract_functions(name: str, content: str, test_folder: pathlib.Path): - write_functions(name, content, test_folder) import_name = f"onnxscript.tests.{test_folder.parts[-1]}.{name}" try: mod = importlib.import_module(import_name) From adccc4751746aa67be0a7d21d27c9799d3e04c3f Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 17 Nov 2023 20:05:20 -0800 Subject: [PATCH 5/6] run lint --- onnxscript/backend/onnx_export.py | 15 +++++++++++++-- onnxscript/backend/onnx_export_test.py | 2 +- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/onnxscript/backend/onnx_export.py b/onnxscript/backend/onnx_export.py index 32ae524d5..315352daf 100644 --- a/onnxscript/backend/onnx_export.py +++ b/onnxscript/backend/onnx_export.py @@ -190,6 +190,7 @@ def _attribute_value(attr: onnx.AttributeProto): # - onnx.AttributeProto.TYPE_PROTOS raise NotImplementedError(f"Unable to return a value for attribute {attr!r}.") + def _update_names_used_in_graph(names: set[str], graph: GraphProto) -> None: """Returns the names used in a graph.""" names.update(x.name for x in graph.input) @@ -198,6 +199,7 @@ def _update_names_used_in_graph(names: set[str], graph: GraphProto) -> None: for node in graph.node: _update_names_used_in_node(names, node) + def _update_names_used_in_node(names: set[str], node: onnx.NodeProto) -> None: names.update(node.input) names.update(node.output) @@ -207,17 +209,20 @@ def _update_names_used_in_node(names: set[str], node: onnx.NodeProto) -> None: for g in attr.graphs: _update_names_used_in_graph(names, g) + def _update_names_used_in_function(names: set[str], fun: FunctionProto) -> None: names.update(fun.input) names.update(fun.output) for node in fun.node: _update_names_used_in_node(names, node) + def _names_used_in_function(fun: FunctionProto) -> set[str]: names = set() _update_names_used_in_function(names, fun) return names + class Exporter: """Class used for recursive traversal of Proto structures.""" @@ -231,6 +236,7 @@ def __init__(self, rename_function, use_operators=False, inline_const=False) -> def _handle_attrname_conflict(self, renamer): """Add ref-attr-name-conflict handling logic to renaming function.""" + def new_renamer(name): new_name = renamer(name) if new_name not in self._attr_renaming: @@ -246,7 +252,8 @@ def new_renamer(name): counter += 1 self._attr_renaming[new_name] = candidate self._names_used.add(candidate) - return candidate + return candidate + return new_renamer def _rename_variable_s(self, name): @@ -262,7 +269,9 @@ def make_opset_name(self, domain, version): return f"{self._rename_domain(domain)}{version}" def _python_make_node_name(self, domain, version, name, node=False): - name = _rename_variable(name) # TODO: Is this a typo? Is it supposed to be self._rename_variable(name)? + name = _rename_variable( + name + ) # TODO: Is this a typo? Is it supposed to be self._rename_variable(name)? if node: if version is None: version = 1 @@ -524,8 +533,10 @@ def translate_function(self, funproto: onnx.FunctionProto) -> str: renamed_names_used = [self._rename_variable(x) for x in used_proto_names] self._names_used = set(renamed_names_used) result = [] + def add_line(line: str) -> None: result.append(line) + opset_name = self.make_opset_name(funproto.domain, 1) add_line(f"@script({opset_name})") fun_name = self._python_make_node_name(funproto.domain, 1, funproto.name) diff --git a/onnxscript/backend/onnx_export_test.py b/onnxscript/backend/onnx_export_test.py index 79f66bad2..0197e0edc 100644 --- a/onnxscript/backend/onnx_export_test.py +++ b/onnxscript/backend/onnx_export_test.py @@ -168,7 +168,7 @@ def fun_with_double_attr_promotion(X, dtype: int): return Z self._round_trip_check(fun_with_double_attr_promotion) - + def test_qualified_domain(self): """Test use of qualified domain name.""" op = onnxscript.opset17 From 632d44ba0eb32eb056e53b18227fc2c8d10dd344 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Sat, 18 Nov 2023 16:29:40 -0800 Subject: [PATCH 6/6] address mypy issues --- onnxscript/backend/onnx_export.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/onnxscript/backend/onnx_export.py b/onnxscript/backend/onnx_export.py index 315352daf..47deb1a83 100644 --- a/onnxscript/backend/onnx_export.py +++ b/onnxscript/backend/onnx_export.py @@ -218,7 +218,7 @@ def _update_names_used_in_function(names: set[str], fun: FunctionProto) -> None: def _names_used_in_function(fun: FunctionProto) -> set[str]: - names = set() + names: set[str] = set() _update_names_used_in_function(names, fun) return names @@ -231,8 +231,9 @@ def __init__(self, rename_function, use_operators=False, inline_const=False) -> self._rename_variable = self._handle_attrname_conflict(rename_function) self.inline_const = inline_const self.constants: dict[str, str] = {} - self._attr_renaming: dict[str, str] = {} # For current function. + self._attr_renaming: dict[str, str | None] = {} # For current function. self._names_used: set[str] = set() # For current function. + self.opsets: dict[str, int] = {} def _handle_attrname_conflict(self, renamer): """Add ref-attr-name-conflict handling logic to renaming function."""