Skip to content

Commit

Permalink
Handle name conflict (attribute parameter) when converting back to py…
Browse files Browse the repository at this point in the history
…thon (#1166)

Attribute parameters and normal values (like tensors) use the same
namespace in onnxscript/python. This can cause a conflict when
converting ONNX proto back to onnxscript (since these are different
namespaces in ONNX). Examples of where this happens shown in test-case
below: eg., when an attribute-parameter "yyy" is used as a value, the
onnxscript translator introduces a value called "yyy" which is bound to
the attribute-value "yyy".

Fix this in the exporter.

---------

Signed-off-by: Ganesan Ramalingam <[email protected]>
  • Loading branch information
gramalingam authored Nov 20, 2023
1 parent 10f9a1f commit e75da82
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 16 deletions.
106 changes: 92 additions & 14 deletions onnxscript/backend/onnx_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,7 @@
{%- endif %}
{{translate_opset_imports_of(main_model)}}
{% for domain, name, fct in functions: %}
@script({{make_opset_name(domain, 1)}})
def {{ python_make_node_name(fct['proto'].domain, 1, fct['proto'].name) }}{{
translate_function_signature(fct['proto'])}}
{% if fct['proto'].doc_string %}"""
{{ fct['proto'].doc_string }}
"""{%- endif %}
{%- for node in fct['proto'].node: %}
{{ python_make_node(node, opsets, indent=1) }}{% endfor %}
return {{ ", ".join(map(rename, fct['proto'].output)) }}
{{translate_function(fct["proto"])}}
{% endfor %}
{% if graph %}
@script()
Expand Down Expand Up @@ -199,14 +191,71 @@ def _attribute_value(attr: onnx.AttributeProto):
raise NotImplementedError(f"Unable to return a value for attribute {attr!r}.")


def _update_names_used_in_graph(names: set[str], graph: GraphProto) -> None:
"""Returns the names used in a graph."""
names.update(x.name for x in graph.input)
names.update(x.name for x in graph.output)
names.update(x.name for x in graph.initializer)
for node in graph.node:
_update_names_used_in_node(names, node)


def _update_names_used_in_node(names: set[str], node: onnx.NodeProto) -> None:
names.update(node.input)
names.update(node.output)
for attr in node.attribute:
if attr.HasField("g"):
_update_names_used_in_graph(names, attr.g)
for g in attr.graphs:
_update_names_used_in_graph(names, g)


def _update_names_used_in_function(names: set[str], fun: FunctionProto) -> None:
names.update(fun.input)
names.update(fun.output)
for node in fun.node:
_update_names_used_in_node(names, node)


def _names_used_in_function(fun: FunctionProto) -> set[str]:
names: set[str] = set()
_update_names_used_in_function(names, fun)
return names


class Exporter:
"""Class used for recursive traversal of Proto structures."""

def __init__(self, rename_function, use_operators=False, inline_const=False) -> None:
self.use_operators = use_operators
self._rename_variable = rename_function
self._rename_variable = self._handle_attrname_conflict(rename_function)
self.inline_const = inline_const
self.constants: dict[str, str] = {}
self._attr_renaming: dict[str, str | None] = {} # For current function.
self._names_used: set[str] = set() # For current function.
self.opsets: dict[str, int] = {}

def _handle_attrname_conflict(self, renamer):
"""Add ref-attr-name-conflict handling logic to renaming function."""

def new_renamer(name):
new_name = renamer(name)
if new_name not in self._attr_renaming:
return new_name
# Name conflicts with attribute parameter name.
alternate = self._attr_renaming[new_name]
if alternate is not None:
return alternate
counter = 0
candidate = new_name
while candidate in self._names_used:
candidate = f"{new_name}_{counter}"
counter += 1
self._attr_renaming[new_name] = candidate
self._names_used.add(candidate)
return candidate

return new_renamer

def _rename_variable_s(self, name):
"""Renames all names equal to a python keyword."""
Expand All @@ -221,7 +270,9 @@ def make_opset_name(self, domain, version):
return f"{self._rename_domain(domain)}{version}"

def _python_make_node_name(self, domain, version, name, node=False):
name = _rename_variable(name)
name = _rename_variable(
name
) # TODO: Is this a typo? Is it supposed to be self._rename_variable(name)?
if node:
if version is None:
version = 1
Expand Down Expand Up @@ -460,11 +511,12 @@ def translate_function_signature(self, funproto: onnx.FunctionProto) -> str:
type_map = _attribute_param_types(funproto)

def attr_sig(attr_name: str) -> str:
name = self._rename_variable(attr_name)
self._attr_renaming[attr_name] = None
self._names_used.add(attr_name)
# A default type of INT is used for attribute parameters that are never used.
type = type_map.get(attr_name, onnx.AttributeProto.INT)
typerep = onnxscript.type_annotation.onnx_attr_type_to_onnxscript_repr(type)
return f"{name}: {typerep}"
return f"{attr_name}: {typerep}"

inputs = [self._rename_variable(x) for x in funproto.input]
attrs = [attr_sig(x) for x in funproto.attribute]
Expand All @@ -475,6 +527,30 @@ def attr_sig(attr_name: str) -> str:
message = ""
return f"({input_and_attrs}):{message}"

def translate_function(self, funproto: onnx.FunctionProto) -> str:
"""Generate python code for FunctionProto."""
self._attr_renaming = {}
used_proto_names = _names_used_in_function(funproto)
renamed_names_used = [self._rename_variable(x) for x in used_proto_names]
self._names_used = set(renamed_names_used)
result = []

def add_line(line: str) -> None:
result.append(line)

opset_name = self.make_opset_name(funproto.domain, 1)
add_line(f"@script({opset_name})")
fun_name = self._python_make_node_name(funproto.domain, 1, funproto.name)
fun_sig = self.translate_function_signature(funproto)
add_line(f"def {fun_name}{fun_sig}")
if funproto.doc_string:
add_line(f' """{funproto.doc_string}"""')
for node in funproto.node:
add_line(self._python_make_node(node, self.opsets, indent=1))
return_values = ", ".join(self._rename_variable(x) for x in funproto.output)
add_line(f" return {return_values}")
return "\n".join(result)


def _attribute_param_types(
funproto: onnx.FunctionProto,
Expand Down Expand Up @@ -533,7 +609,7 @@ def rename_variable(name):
if var_name in variable_names:
return variable_names[var_name]
new_name = f"v{len(variable_names) + 1}"
assert var_name is not None
assert var_name is not None # TODO(rama): This looks suspect.
variable_names[var_name] = new_name
return new_name

Expand All @@ -553,6 +629,7 @@ def rename_variable(name):
"rename": rename_variable,
"translate_sig": _translate_signature,
"translate_function_signature": exporter.translate_function_signature,
"translate_function": exporter.translate_function,
"translate_opset_imports_of": exporter.translate_opset_imports_of,
"hasattr": hasattr,
"make_opset_name": exporter.make_opset_name,
Expand All @@ -564,6 +641,7 @@ def rename_variable(name):
for oimp in model_onnx.opset_import:
opsets[oimp.domain] = oimp.version
context["opsets"] = opsets
exporter.opsets = opsets

graph = model_onnx.graph if hasattr(model_onnx, "graph") else model_onnx

Expand Down
19 changes: 17 additions & 2 deletions onnxscript/backend/onnx_export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,12 @@ def run_function(obj, *inputs):


def extract_functions(name: str, content: str, test_folder: pathlib.Path):
"""Write the content into a file and import all OnnxFunctions from it."""
if not test_folder.exists():
test_folder.mkdir(exist_ok=True, parents=True)
init = test_folder / "__init__.py"
init.touch(exist_ok=True)
file = test_folder / f"{name}.py"
file.write_text(content, encoding="utf-8")

import_name = f"onnxscript.tests.{test_folder.parts[-1]}.{name}"
try:
mod = importlib.import_module(import_name)
Expand All @@ -137,6 +135,12 @@ class TestOnnxBackEnd(unittest.TestCase):
test_folder = root_folder / "tests" / "onnx_backend_test_code"
temp_folder = root_folder / "tests" / "export"

def _proto_to_os_and_back(self, proto: onnxscript.FunctionProto, **export_options):
"""Convert a proto to onnxscript code and convert it back to a proto."""
code = onnx_export.export2python(proto, **export_options)
map = extract_functions(proto.name, code, TestOnnxBackEnd.temp_folder)
return map[proto.name]

def _round_trip_check(self, script_function, **export_options):
proto = script_function.to_function_proto()
code = onnx_export.export2python(proto, **export_options)
Expand All @@ -154,6 +158,17 @@ def fun_with_attr_param(X, dtype: int):

self._round_trip_check(fun_with_attr_param)

def test_double_attr_val_promotion(self):
op = onnxscript.opset17

@onnxscript.script()
def fun_with_double_attr_promotion(X, dtype: int):
Y = op.Add(X, dtype)
Z = op.Add(Y, dtype)
return Z

self._round_trip_check(fun_with_double_attr_promotion)

def test_qualified_domain(self):
"""Test use of qualified domain name."""
op = onnxscript.opset17
Expand Down

0 comments on commit e75da82

Please sign in to comment.