diff --git a/onnxscript/backend/onnx_export.py b/onnxscript/backend/onnx_export.py index 063ded7f0..47deb1a83 100644 --- a/onnxscript/backend/onnx_export.py +++ b/onnxscript/backend/onnx_export.py @@ -25,15 +25,7 @@ {%- endif %} {{translate_opset_imports_of(main_model)}} {% for domain, name, fct in functions: %} -@script({{make_opset_name(domain, 1)}}) -def {{ python_make_node_name(fct['proto'].domain, 1, fct['proto'].name) }}{{ - translate_function_signature(fct['proto'])}} - {% if fct['proto'].doc_string %}""" - {{ fct['proto'].doc_string }} - """{%- endif %} - {%- for node in fct['proto'].node: %} -{{ python_make_node(node, opsets, indent=1) }}{% endfor %} - return {{ ", ".join(map(rename, fct['proto'].output)) }} +{{translate_function(fct["proto"])}} {% endfor %} {% if graph %} @script() @@ -199,14 +191,71 @@ def _attribute_value(attr: onnx.AttributeProto): raise NotImplementedError(f"Unable to return a value for attribute {attr!r}.") +def _update_names_used_in_graph(names: set[str], graph: GraphProto) -> None: + """Returns the names used in a graph.""" + names.update(x.name for x in graph.input) + names.update(x.name for x in graph.output) + names.update(x.name for x in graph.initializer) + for node in graph.node: + _update_names_used_in_node(names, node) + + +def _update_names_used_in_node(names: set[str], node: onnx.NodeProto) -> None: + names.update(node.input) + names.update(node.output) + for attr in node.attribute: + if attr.HasField("g"): + _update_names_used_in_graph(names, attr.g) + for g in attr.graphs: + _update_names_used_in_graph(names, g) + + +def _update_names_used_in_function(names: set[str], fun: FunctionProto) -> None: + names.update(fun.input) + names.update(fun.output) + for node in fun.node: + _update_names_used_in_node(names, node) + + +def _names_used_in_function(fun: FunctionProto) -> set[str]: + names: set[str] = set() + _update_names_used_in_function(names, fun) + return names + + class Exporter: """Class used for recursive traversal of Proto structures.""" def __init__(self, rename_function, use_operators=False, inline_const=False) -> None: self.use_operators = use_operators - self._rename_variable = rename_function + self._rename_variable = self._handle_attrname_conflict(rename_function) self.inline_const = inline_const self.constants: dict[str, str] = {} + self._attr_renaming: dict[str, str | None] = {} # For current function. + self._names_used: set[str] = set() # For current function. + self.opsets: dict[str, int] = {} + + def _handle_attrname_conflict(self, renamer): + """Add ref-attr-name-conflict handling logic to renaming function.""" + + def new_renamer(name): + new_name = renamer(name) + if new_name not in self._attr_renaming: + return new_name + # Name conflicts with attribute parameter name. + alternate = self._attr_renaming[new_name] + if alternate is not None: + return alternate + counter = 0 + candidate = new_name + while candidate in self._names_used: + candidate = f"{new_name}_{counter}" + counter += 1 + self._attr_renaming[new_name] = candidate + self._names_used.add(candidate) + return candidate + + return new_renamer def _rename_variable_s(self, name): """Renames all names equal to a python keyword.""" @@ -221,7 +270,9 @@ def make_opset_name(self, domain, version): return f"{self._rename_domain(domain)}{version}" def _python_make_node_name(self, domain, version, name, node=False): - name = _rename_variable(name) + name = _rename_variable( + name + ) # TODO: Is this a typo? Is it supposed to be self._rename_variable(name)? if node: if version is None: version = 1 @@ -460,11 +511,12 @@ def translate_function_signature(self, funproto: onnx.FunctionProto) -> str: type_map = _attribute_param_types(funproto) def attr_sig(attr_name: str) -> str: - name = self._rename_variable(attr_name) + self._attr_renaming[attr_name] = None + self._names_used.add(attr_name) # A default type of INT is used for attribute parameters that are never used. type = type_map.get(attr_name, onnx.AttributeProto.INT) typerep = onnxscript.type_annotation.onnx_attr_type_to_onnxscript_repr(type) - return f"{name}: {typerep}" + return f"{attr_name}: {typerep}" inputs = [self._rename_variable(x) for x in funproto.input] attrs = [attr_sig(x) for x in funproto.attribute] @@ -475,6 +527,30 @@ def attr_sig(attr_name: str) -> str: message = "" return f"({input_and_attrs}):{message}" + def translate_function(self, funproto: onnx.FunctionProto) -> str: + """Generate python code for FunctionProto.""" + self._attr_renaming = {} + used_proto_names = _names_used_in_function(funproto) + renamed_names_used = [self._rename_variable(x) for x in used_proto_names] + self._names_used = set(renamed_names_used) + result = [] + + def add_line(line: str) -> None: + result.append(line) + + opset_name = self.make_opset_name(funproto.domain, 1) + add_line(f"@script({opset_name})") + fun_name = self._python_make_node_name(funproto.domain, 1, funproto.name) + fun_sig = self.translate_function_signature(funproto) + add_line(f"def {fun_name}{fun_sig}") + if funproto.doc_string: + add_line(f' """{funproto.doc_string}"""') + for node in funproto.node: + add_line(self._python_make_node(node, self.opsets, indent=1)) + return_values = ", ".join(self._rename_variable(x) for x in funproto.output) + add_line(f" return {return_values}") + return "\n".join(result) + def _attribute_param_types( funproto: onnx.FunctionProto, @@ -533,7 +609,7 @@ def rename_variable(name): if var_name in variable_names: return variable_names[var_name] new_name = f"v{len(variable_names) + 1}" - assert var_name is not None + assert var_name is not None # TODO(rama): This looks suspect. variable_names[var_name] = new_name return new_name @@ -553,6 +629,7 @@ def rename_variable(name): "rename": rename_variable, "translate_sig": _translate_signature, "translate_function_signature": exporter.translate_function_signature, + "translate_function": exporter.translate_function, "translate_opset_imports_of": exporter.translate_opset_imports_of, "hasattr": hasattr, "make_opset_name": exporter.make_opset_name, @@ -564,6 +641,7 @@ def rename_variable(name): for oimp in model_onnx.opset_import: opsets[oimp.domain] = oimp.version context["opsets"] = opsets + exporter.opsets = opsets graph = model_onnx.graph if hasattr(model_onnx, "graph") else model_onnx diff --git a/onnxscript/backend/onnx_export_test.py b/onnxscript/backend/onnx_export_test.py index efaa6c62d..0197e0edc 100644 --- a/onnxscript/backend/onnx_export_test.py +++ b/onnxscript/backend/onnx_export_test.py @@ -104,14 +104,12 @@ def run_function(obj, *inputs): def extract_functions(name: str, content: str, test_folder: pathlib.Path): - """Write the content into a file and import all OnnxFunctions from it.""" if not test_folder.exists(): test_folder.mkdir(exist_ok=True, parents=True) init = test_folder / "__init__.py" init.touch(exist_ok=True) file = test_folder / f"{name}.py" file.write_text(content, encoding="utf-8") - import_name = f"onnxscript.tests.{test_folder.parts[-1]}.{name}" try: mod = importlib.import_module(import_name) @@ -137,6 +135,12 @@ class TestOnnxBackEnd(unittest.TestCase): test_folder = root_folder / "tests" / "onnx_backend_test_code" temp_folder = root_folder / "tests" / "export" + def _proto_to_os_and_back(self, proto: onnxscript.FunctionProto, **export_options): + """Convert a proto to onnxscript code and convert it back to a proto.""" + code = onnx_export.export2python(proto, **export_options) + map = extract_functions(proto.name, code, TestOnnxBackEnd.temp_folder) + return map[proto.name] + def _round_trip_check(self, script_function, **export_options): proto = script_function.to_function_proto() code = onnx_export.export2python(proto, **export_options) @@ -154,6 +158,17 @@ def fun_with_attr_param(X, dtype: int): self._round_trip_check(fun_with_attr_param) + def test_double_attr_val_promotion(self): + op = onnxscript.opset17 + + @onnxscript.script() + def fun_with_double_attr_promotion(X, dtype: int): + Y = op.Add(X, dtype) + Z = op.Add(Y, dtype) + return Z + + self._round_trip_check(fun_with_double_attr_promotion) + def test_qualified_domain(self): """Test use of qualified domain name.""" op = onnxscript.opset17