Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle name conflict (attribute parameter) when converting back to python #1166

Merged
merged 6 commits into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 91 additions & 14 deletions onnxscript/backend/onnx_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,7 @@
{%- endif %}
{{translate_opset_imports_of(main_model)}}
{% for domain, name, fct in functions: %}
@script({{make_opset_name(domain, 1)}})
def {{ python_make_node_name(fct['proto'].domain, 1, fct['proto'].name) }}{{
translate_function_signature(fct['proto'])}}
{% if fct['proto'].doc_string %}"""
{{ fct['proto'].doc_string }}
"""{%- endif %}
{%- for node in fct['proto'].node: %}
{{ python_make_node(node, opsets, indent=1) }}{% endfor %}
return {{ ", ".join(map(rename, fct['proto'].output)) }}
{{translate_function(fct["proto"])}}
{% endfor %}
{% if graph %}
@script()
Expand Down Expand Up @@ -199,14 +191,70 @@
raise NotImplementedError(f"Unable to return a value for attribute {attr!r}.")


def _update_names_used_in_graph(names: set[str], graph: GraphProto) -> None:
"""Returns the names used in a graph."""
names.update(x.name for x in graph.input)
names.update(x.name for x in graph.output)
names.update(x.name for x in graph.initializer)
for node in graph.node:
_update_names_used_in_node(names, node)

Check warning on line 200 in onnxscript/backend/onnx_export.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/backend/onnx_export.py#L200

Added line #L200 was not covered by tests


def _update_names_used_in_node(names: set[str], node: onnx.NodeProto) -> None:
names.update(node.input)
names.update(node.output)
for attr in node.attribute:
if attr.HasField("g"):
_update_names_used_in_graph(names, attr.g)

Check warning on line 208 in onnxscript/backend/onnx_export.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/backend/onnx_export.py#L208

Added line #L208 was not covered by tests
for g in attr.graphs:
_update_names_used_in_graph(names, g)

Check warning on line 210 in onnxscript/backend/onnx_export.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/backend/onnx_export.py#L210

Added line #L210 was not covered by tests


def _update_names_used_in_function(names: set[str], fun: FunctionProto) -> None:
names.update(fun.input)
names.update(fun.output)
for node in fun.node:
_update_names_used_in_node(names, node)


def _names_used_in_function(fun: FunctionProto) -> set[str]:
names = set()
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
justinchuby marked this conversation as resolved.
Show resolved Hide resolved
_update_names_used_in_function(names, fun)
return names


class Exporter:
"""Class used for recursive traversal of Proto structures."""

def __init__(self, rename_function, use_operators=False, inline_const=False) -> None:
self.use_operators = use_operators
self._rename_variable = rename_function
self._rename_variable = self._handle_attrname_conflict(rename_function)
self.inline_const = inline_const
self.constants: dict[str, str] = {}
self._attr_renaming: dict[str, str] = {} # For current function.
self._names_used: set[str] = set() # For current function.

def _handle_attrname_conflict(self, renamer):
"""Add ref-attr-name-conflict handling logic to renaming function."""

def new_renamer(name):
new_name = renamer(name)
if new_name not in self._attr_renaming:
return new_name
# Name conflicts with attribute parameter name.
alternate = self._attr_renaming[new_name]
if alternate is not None:
return alternate
counter = 0
candidate = new_name
while candidate in self._names_used:
candidate = f"{new_name}_{counter}"
counter += 1
self._attr_renaming[new_name] = candidate
self._names_used.add(candidate)
return candidate

return new_renamer

def _rename_variable_s(self, name):
"""Renames all names equal to a python keyword."""
Expand All @@ -221,7 +269,9 @@
return f"{self._rename_domain(domain)}{version}"

def _python_make_node_name(self, domain, version, name, node=False):
name = _rename_variable(name)
name = _rename_variable(
name
) # TODO: Is this a typo? Is it supposed to be self._rename_variable(name)?
gramalingam marked this conversation as resolved.
Show resolved Hide resolved
if node:
if version is None:
version = 1
Expand Down Expand Up @@ -460,11 +510,12 @@
type_map = _attribute_param_types(funproto)

def attr_sig(attr_name: str) -> str:
name = self._rename_variable(attr_name)
self._attr_renaming[attr_name] = None
Fixed Show fixed Hide fixed
self._names_used.add(attr_name)
# A default type of INT is used for attribute parameters that are never used.
type = type_map.get(attr_name, onnx.AttributeProto.INT)
typerep = onnxscript.type_annotation.onnx_attr_type_to_onnxscript_repr(type)
return f"{name}: {typerep}"
return f"{attr_name}: {typerep}"

inputs = [self._rename_variable(x) for x in funproto.input]
attrs = [attr_sig(x) for x in funproto.attribute]
Expand All @@ -475,6 +526,30 @@
message = ""
return f"({input_and_attrs}):{message}"

def translate_function(self, funproto: onnx.FunctionProto) -> str:
"""Generate python code for FunctionProto."""
self._attr_renaming = {}
used_proto_names = _names_used_in_function(funproto)
renamed_names_used = [self._rename_variable(x) for x in used_proto_names]
self._names_used = set(renamed_names_used)
result = []

def add_line(line: str) -> None:
result.append(line)

opset_name = self.make_opset_name(funproto.domain, 1)
add_line(f"@script({opset_name})")
fun_name = self._python_make_node_name(funproto.domain, 1, funproto.name)
fun_sig = self.translate_function_signature(funproto)
add_line(f"def {fun_name}{fun_sig}")
if funproto.doc_string:
add_line(f' """{funproto.doc_string}"""')

Check warning on line 546 in onnxscript/backend/onnx_export.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/backend/onnx_export.py#L546

Added line #L546 was not covered by tests
for node in funproto.node:
add_line(self._python_make_node(node, self.opsets, indent=1))
Fixed Show fixed Hide fixed
return_values = ", ".join(self._rename_variable(x) for x in funproto.output)
add_line(f" return {return_values}")
return "\n".join(result)


def _attribute_param_types(
funproto: onnx.FunctionProto,
Expand Down Expand Up @@ -533,7 +608,7 @@
if var_name in variable_names:
return variable_names[var_name]
new_name = f"v{len(variable_names) + 1}"
assert var_name is not None
assert var_name is not None # TODO(rama): This looks suspect.
variable_names[var_name] = new_name
return new_name

Expand All @@ -553,6 +628,7 @@
"rename": rename_variable,
"translate_sig": _translate_signature,
"translate_function_signature": exporter.translate_function_signature,
"translate_function": exporter.translate_function,
"translate_opset_imports_of": exporter.translate_opset_imports_of,
"hasattr": hasattr,
"make_opset_name": exporter.make_opset_name,
Expand All @@ -564,6 +640,7 @@
for oimp in model_onnx.opset_import:
opsets[oimp.domain] = oimp.version
context["opsets"] = opsets
exporter.opsets = opsets
Fixed Show fixed Hide fixed

graph = model_onnx.graph if hasattr(model_onnx, "graph") else model_onnx

Expand Down
19 changes: 17 additions & 2 deletions onnxscript/backend/onnx_export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,12 @@


def extract_functions(name: str, content: str, test_folder: pathlib.Path):
"""Write the content into a file and import all OnnxFunctions from it."""
if not test_folder.exists():
test_folder.mkdir(exist_ok=True, parents=True)
init = test_folder / "__init__.py"
init.touch(exist_ok=True)
file = test_folder / f"{name}.py"
file.write_text(content, encoding="utf-8")

import_name = f"onnxscript.tests.{test_folder.parts[-1]}.{name}"
try:
mod = importlib.import_module(import_name)
Expand All @@ -137,6 +135,12 @@
test_folder = root_folder / "tests" / "onnx_backend_test_code"
temp_folder = root_folder / "tests" / "export"

def _proto_to_os_and_back(self, proto: onnxscript.FunctionProto, **export_options):
"""Convert a proto to onnxscript code and convert it back to a proto."""
code = onnx_export.export2python(proto, **export_options)
map = extract_functions(proto.name, code, TestOnnxBackEnd.temp_folder)
return map[proto.name]

Check warning on line 142 in onnxscript/backend/onnx_export_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/backend/onnx_export_test.py#L140-L142

Added lines #L140 - L142 were not covered by tests

def _round_trip_check(self, script_function, **export_options):
proto = script_function.to_function_proto()
code = onnx_export.export2python(proto, **export_options)
Expand All @@ -154,6 +158,17 @@

self._round_trip_check(fun_with_attr_param)

def test_double_attr_val_promotion(self):
op = onnxscript.opset17

@onnxscript.script()
def fun_with_double_attr_promotion(X, dtype: int):
Y = op.Add(X, dtype)
Z = op.Add(Y, dtype)
return Z

Check warning on line 168 in onnxscript/backend/onnx_export_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/backend/onnx_export_test.py#L166-L168

Added lines #L166 - L168 were not covered by tests

self._round_trip_check(fun_with_double_attr_promotion)

def test_qualified_domain(self):
"""Test use of qualified domain name."""
op = onnxscript.opset17
Expand Down
Loading