Skip to content

Commit

Permalink
Fix name conflict resolution
Browse files Browse the repository at this point in the history
Signed-off-by: Ganesan Ramalingam <[email protected]>
  • Loading branch information
gramalingam committed Nov 18, 2023
1 parent 4fffceb commit a127a41
Showing 1 changed file with 27 additions and 19 deletions.
46 changes: 27 additions & 19 deletions onnxscript/backend/onnx_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,28 +223,36 @@ class Exporter:

def __init__(self, rename_function, use_operators=False, inline_const=False) -> None:
self.use_operators = use_operators
self._rename_variable = rename_function
self._rename_variable = self._handle_attrname_conflict(rename_function)
self.inline_const = inline_const
self.constants: dict[str, str] = {}
self._attr_renaming: dict[str, str] = {} # For current function.
self._names_used: set[str] = set() # For current function.

def _handle_attrname_conflict(self, renamer):
"""Add ref-attr-name-conflict handling logic to renaming function."""
def new_renamer(name):
new_name = renamer(name)
if new_name not in self._attr_renaming:
return new_name
# Name conflicts with attribute parameter name.
alternate = self._attr_renaming[new_name]
if alternate is not None:
return alternate
counter = 0
candidate = new_name
while candidate in self._names_used:
candidate = f"{new_name}_{counter}"
counter += 1
self._attr_renaming[new_name] = candidate
self._names_used.add(candidate)
return candidate
return new_renamer

def _rename_variable_s(self, name):
"""Renames all names equal to a python keyword."""
return str(self._rename_variable(name))

def _rename_attr_parameter(self, name):
"""Renames an attribute parameter."""
if name in self._attr_renaming:
return self._attr_renaming[name]
counter = 0
candidate = name
while candidate in self._names_used:
candidate = f"{name}_attr_{counter}"
self._attr_renaming[name] = candidate
self._names_used.add(candidate)
return candidate

def _rename_domain(self, domain: str) -> str:
if domain == "":
return "opset"
Expand All @@ -254,7 +262,7 @@ def make_opset_name(self, domain, version):
return f"{self._rename_domain(domain)}{version}"

def _python_make_node_name(self, domain, version, name, node=False):
name = _rename_variable(name)
name = _rename_variable(name) # TODO: Is this a typo? Is it supposed to be self._rename_variable(name)?
if node:
if version is None:
version = 1
Expand Down Expand Up @@ -298,8 +306,7 @@ def _python_make_node_make_attribute_str(self, node):
attributes = []
for at in node.attribute:
if _is_attribute_ref(at):
ref_attr_name = self._attr_renaming[at.ref_attr_name]
attributes.append((at.name, ref_attr_name))
attributes.append((at.name, at.ref_attr_name))
continue
value = _attribute_value(at)
if isinstance(value, str):
Expand Down Expand Up @@ -494,11 +501,12 @@ def translate_function_signature(self, funproto: onnx.FunctionProto) -> str:
type_map = _attribute_param_types(funproto)

def attr_sig(attr_name: str) -> str:
name = self._rename_attr_parameter(attr_name)
self._attr_renaming[attr_name] = None
self._names_used.add(attr_name)
# A default type of INT is used for attribute parameters that are never used.
type = type_map.get(attr_name, onnx.AttributeProto.INT)
typerep = onnxscript.type_annotation.onnx_attr_type_to_onnxscript_repr(type)
return f"{name}: {typerep}"
return f"{attr_name}: {typerep}"

inputs = [self._rename_variable(x) for x in funproto.input]
attrs = [attr_sig(x) for x in funproto.attribute]
Expand Down Expand Up @@ -589,7 +597,7 @@ def rename_variable(name):
if var_name in variable_names:
return variable_names[var_name]
new_name = f"v{len(variable_names) + 1}"
assert var_name is not None
assert var_name is not None # TODO(rama): This looks suspect.
variable_names[var_name] = new_name
return new_name

Expand Down

0 comments on commit a127a41

Please sign in to comment.