Skip to content

Commit

Permalink
Handle attr-ref name conflict
Browse files Browse the repository at this point in the history
Signed-off-by: Ganesan Ramalingam <[email protected]>
  • Loading branch information
gramalingam committed Nov 17, 2023
1 parent 10f9a1f commit 4fffceb
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 18 deletions.
80 changes: 69 additions & 11 deletions onnxscript/backend/onnx_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,7 @@
{%- endif %}
{{translate_opset_imports_of(main_model)}}
{% for domain, name, fct in functions: %}
@script({{make_opset_name(domain, 1)}})
def {{ python_make_node_name(fct['proto'].domain, 1, fct['proto'].name) }}{{
translate_function_signature(fct['proto'])}}
{% if fct['proto'].doc_string %}"""
{{ fct['proto'].doc_string }}
"""{%- endif %}
{%- for node in fct['proto'].node: %}
{{ python_make_node(node, opsets, indent=1) }}{% endfor %}
return {{ ", ".join(map(rename, fct['proto'].output)) }}
{{translate_function(fct["proto"])}}
{% endfor %}
{% if graph %}
@script()
Expand Down Expand Up @@ -198,6 +190,33 @@ def _attribute_value(attr: onnx.AttributeProto):
# - onnx.AttributeProto.TYPE_PROTOS
raise NotImplementedError(f"Unable to return a value for attribute {attr!r}.")

def _update_names_used_in_graph(names: set[str], graph: GraphProto) -> None:
"""Returns the names used in a graph."""
names.update(x.name for x in graph.input)
names.update(x.name for x in graph.output)
names.update(x.name for x in graph.initializer)
for node in graph.node:
_update_names_used_in_node(names, node)

def _update_names_used_in_node(names: set[str], node: onnx.NodeProto) -> None:
names.update(node.input)
names.update(node.output)
for attr in node.attribute:
if attr.HasField("g"):
_update_names_used_in_graph(names, attr.g)
for g in attr.graphs:
_update_names_used_in_graph(names, g)

def _update_names_used_in_function(names: set[str], fun: FunctionProto) -> None:
names.update(fun.input)
names.update(fun.output)
for node in fun.node:
_update_names_used_in_node(names, node)

def _names_used_in_function(fun: FunctionProto) -> set[str]:
names = set()
_update_names_used_in_function(names, fun)
return names

class Exporter:
"""Class used for recursive traversal of Proto structures."""
Expand All @@ -207,11 +226,25 @@ def __init__(self, rename_function, use_operators=False, inline_const=False) ->
self._rename_variable = rename_function
self.inline_const = inline_const
self.constants: dict[str, str] = {}
self._attr_renaming: dict[str, str] = {} # For current function.
self._names_used: set[str] = set() # For current function.

def _rename_variable_s(self, name):
"""Renames all names equal to a python keyword."""
return str(self._rename_variable(name))

def _rename_attr_parameter(self, name):
"""Renames an attribute parameter."""
if name in self._attr_renaming:
return self._attr_renaming[name]
counter = 0
candidate = name
while candidate in self._names_used:
candidate = f"{name}_attr_{counter}"
self._attr_renaming[name] = candidate
self._names_used.add(candidate)
return candidate

def _rename_domain(self, domain: str) -> str:
if domain == "":
return "opset"
Expand Down Expand Up @@ -265,7 +298,8 @@ def _python_make_node_make_attribute_str(self, node):
attributes = []
for at in node.attribute:
if _is_attribute_ref(at):
attributes.append((at.name, at.ref_attr_name))
ref_attr_name = self._attr_renaming[at.ref_attr_name]
attributes.append((at.name, ref_attr_name))
continue
value = _attribute_value(at)
if isinstance(value, str):
Expand Down Expand Up @@ -460,7 +494,7 @@ def translate_function_signature(self, funproto: onnx.FunctionProto) -> str:
type_map = _attribute_param_types(funproto)

def attr_sig(attr_name: str) -> str:
name = self._rename_variable(attr_name)
name = self._rename_attr_parameter(attr_name)
# A default type of INT is used for attribute parameters that are never used.
type = type_map.get(attr_name, onnx.AttributeProto.INT)
typerep = onnxscript.type_annotation.onnx_attr_type_to_onnxscript_repr(type)
Expand All @@ -475,6 +509,28 @@ def attr_sig(attr_name: str) -> str:
message = ""
return f"({input_and_attrs}):{message}"

def translate_function(self, funproto: onnx.FunctionProto) -> str:
"""Generate python code for FunctionProto."""
self._attr_renaming = {}
used_proto_names = _names_used_in_function(funproto)
renamed_names_used = [self._rename_variable(x) for x in used_proto_names]
self._names_used = set(renamed_names_used)
result = []
def add_line(line: str) -> None:
result.append(line)
opset_name = self.make_opset_name(funproto.domain, 1)
add_line(f"@script({opset_name})")
fun_name = self._python_make_node_name(funproto.domain, 1, funproto.name)
fun_sig = self.translate_function_signature(funproto)
add_line(f"def {fun_name}{fun_sig}")
if funproto.doc_string:
add_line(f' """{funproto.doc_string}"""')
for node in funproto.node:
add_line(self._python_make_node(node, self.opsets, indent=1))
return_values = ", ".join(self._rename_variable(x) for x in funproto.output)
add_line(f" return {return_values}")
return "\n".join(result)


def _attribute_param_types(
funproto: onnx.FunctionProto,
Expand Down Expand Up @@ -553,6 +609,7 @@ def rename_variable(name):
"rename": rename_variable,
"translate_sig": _translate_signature,
"translate_function_signature": exporter.translate_function_signature,
"translate_function": exporter.translate_function,
"translate_opset_imports_of": exporter.translate_opset_imports_of,
"hasattr": hasattr,
"make_opset_name": exporter.make_opset_name,
Expand All @@ -564,6 +621,7 @@ def rename_variable(name):
for oimp in model_onnx.opset_import:
opsets[oimp.domain] = oimp.version
context["opsets"] = opsets
exporter.opsets = opsets

graph = model_onnx.graph if hasattr(model_onnx, "graph") else model_onnx

Expand Down
21 changes: 20 additions & 1 deletion onnxscript/backend/onnx_export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def run_function(obj, *inputs):
return got


def extract_functions(name: str, content: str, test_folder: pathlib.Path):
def write_functions(name: str, content: str, test_folder: pathlib.Path):
"""Write the content into a file and import all OnnxFunctions from it."""
if not test_folder.exists():
test_folder.mkdir(exist_ok=True, parents=True)
Expand All @@ -112,6 +112,8 @@ def extract_functions(name: str, content: str, test_folder: pathlib.Path):
file = test_folder / f"{name}.py"
file.write_text(content, encoding="utf-8")

def extract_functions(name: str, content: str, test_folder: pathlib.Path):
write_functions(name, content, test_folder)
import_name = f"onnxscript.tests.{test_folder.parts[-1]}.{name}"
try:
mod = importlib.import_module(import_name)
Expand All @@ -137,6 +139,12 @@ class TestOnnxBackEnd(unittest.TestCase):
test_folder = root_folder / "tests" / "onnx_backend_test_code"
temp_folder = root_folder / "tests" / "export"

def _proto_to_os_and_back(self, proto: onnxscript.FunctionProto, **export_options):
"""Convert a proto to onnxscript code and convert it back to a proto."""
code = onnx_export.export2python(proto, **export_options)
map = extract_functions(proto.name, code, TestOnnxBackEnd.temp_folder)
return map[proto.name]

def _round_trip_check(self, script_function, **export_options):
proto = script_function.to_function_proto()
code = onnx_export.export2python(proto, **export_options)
Expand All @@ -154,6 +162,17 @@ def fun_with_attr_param(X, dtype: int):

self._round_trip_check(fun_with_attr_param)

def test_double_attr_val_promotion(self):
op = onnxscript.opset17

@onnxscript.script()
def fun_with_double_attr_promotion(X, dtype: int):
Y = op.Add(X, dtype)
Z = op.Add(Y, dtype)
return Z

self._round_trip_check(fun_with_double_attr_promotion)

def test_qualified_domain(self):
"""Test use of qualified domain name."""
op = onnxscript.opset17
Expand Down
24 changes: 18 additions & 6 deletions onnxscript/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,12 @@ def _same_value_info(vi1, vi2):
)


def _same_attr(attr1, attr2, graph_equality):
def _same_attr(attr1, attr2, graph_equality, ref_equality):
# no name check; names used to match attributes already.
for field in ["type", "ref_attr_name", "f", "i", "s"]:
if not _same_optional("ref_attr_name", attr1, attr2, ref_equality):
return False

for field in ["type", "f", "i", "s"]:
if not _same_optional(field, attr1, attr2):
return False

Expand All @@ -146,15 +149,15 @@ def _same_attr(attr1, attr2, graph_equality):
return True


def _same_attrs(attrs1, attrs2, graph_equality):
def _same_attrs(attrs1, attrs2, graph_equality, ref_equality):
if len(attrs1) != len(attrs2):
return False
attrs1map = {a.name: a for a in attrs1}
for attr2 in attrs2:
if attr2.name not in attrs1map:
return False
attr1 = attrs1map[attr2.name]
if not _same_attr(attr1, attr2, graph_equality):
if not _same_attr(attr1, attr2, graph_equality, ref_equality):
return False
return True

Expand Down Expand Up @@ -188,6 +191,15 @@ def defmap(f):
self.node_mapping: dict[onnx.NodeProto, onnx.NodeProto] = {}
self.outer_scope = outer_scope

def same_ref_attr(self, ref_attr_name1, ref_attr_name2) -> bool:
def find(fun, name):
for i, attr_name in enumerate(fun.attribute):
if attr_name == name:
return i
# TODO: handle attribute_protos
return None
return find(self.fg1, ref_attr_name1) == find(self.fg2, ref_attr_name2)

def same_value(self, var1, var2):
"""Match two variables (strings)."""
if var1 not in self.defmap1 or var2 not in self.defmap2:
Expand Down Expand Up @@ -215,7 +227,7 @@ def same_node(self, n1, n2):
if node1.domain != node2.domain:
return False
# check attrs
if not _same_attrs(node1.attribute, node2.attribute, self.same_sub_graph):
if not _same_attrs(node1.attribute, node2.attribute, self.same_sub_graph, self.same_ref_attr):
return False
if not self.same_value_list(node1.input, node2.input):
return False
Expand Down Expand Up @@ -261,7 +273,7 @@ def same_function(self):

if len(self.fg1.input) != len(self.fg2.input):
return False
if set(self.fg1.attribute) != set(self.fg2.attribute):
if len(self.fg1.attribute) != len(self.fg2.attribute):
return False

# Opset imports must be same (but possibly in different order):
Expand Down

0 comments on commit 4fffceb

Please sign in to comment.