From a127a4136249c79e14bc200c41dd48b2731f937e Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 17 Nov 2023 16:28:36 -0800 Subject: [PATCH] Fix name conflict resolution Signed-off-by: Ganesan Ramalingam --- onnxscript/backend/onnx_export.py | 46 ++++++++++++++++++------------- 1 file changed, 27 insertions(+), 19 deletions(-) diff --git a/onnxscript/backend/onnx_export.py b/onnxscript/backend/onnx_export.py index fb102f745..32ae524d5 100644 --- a/onnxscript/backend/onnx_export.py +++ b/onnxscript/backend/onnx_export.py @@ -223,28 +223,36 @@ class Exporter: def __init__(self, rename_function, use_operators=False, inline_const=False) -> None: self.use_operators = use_operators - self._rename_variable = rename_function + self._rename_variable = self._handle_attrname_conflict(rename_function) self.inline_const = inline_const self.constants: dict[str, str] = {} self._attr_renaming: dict[str, str] = {} # For current function. self._names_used: set[str] = set() # For current function. + def _handle_attrname_conflict(self, renamer): + """Add ref-attr-name-conflict handling logic to renaming function.""" + def new_renamer(name): + new_name = renamer(name) + if new_name not in self._attr_renaming: + return new_name + # Name conflicts with attribute parameter name. + alternate = self._attr_renaming[new_name] + if alternate is not None: + return alternate + counter = 0 + candidate = new_name + while candidate in self._names_used: + candidate = f"{new_name}_{counter}" + counter += 1 + self._attr_renaming[new_name] = candidate + self._names_used.add(candidate) + return candidate + return new_renamer + def _rename_variable_s(self, name): """Renames all names equal to a python keyword.""" return str(self._rename_variable(name)) - def _rename_attr_parameter(self, name): - """Renames an attribute parameter.""" - if name in self._attr_renaming: - return self._attr_renaming[name] - counter = 0 - candidate = name - while candidate in self._names_used: - candidate = f"{name}_attr_{counter}" - self._attr_renaming[name] = candidate - self._names_used.add(candidate) - return candidate - def _rename_domain(self, domain: str) -> str: if domain == "": return "opset" @@ -254,7 +262,7 @@ def make_opset_name(self, domain, version): return f"{self._rename_domain(domain)}{version}" def _python_make_node_name(self, domain, version, name, node=False): - name = _rename_variable(name) + name = _rename_variable(name) # TODO: Is this a typo? Is it supposed to be self._rename_variable(name)? if node: if version is None: version = 1 @@ -298,8 +306,7 @@ def _python_make_node_make_attribute_str(self, node): attributes = [] for at in node.attribute: if _is_attribute_ref(at): - ref_attr_name = self._attr_renaming[at.ref_attr_name] - attributes.append((at.name, ref_attr_name)) + attributes.append((at.name, at.ref_attr_name)) continue value = _attribute_value(at) if isinstance(value, str): @@ -494,11 +501,12 @@ def translate_function_signature(self, funproto: onnx.FunctionProto) -> str: type_map = _attribute_param_types(funproto) def attr_sig(attr_name: str) -> str: - name = self._rename_attr_parameter(attr_name) + self._attr_renaming[attr_name] = None + self._names_used.add(attr_name) # A default type of INT is used for attribute parameters that are never used. type = type_map.get(attr_name, onnx.AttributeProto.INT) typerep = onnxscript.type_annotation.onnx_attr_type_to_onnxscript_repr(type) - return f"{name}: {typerep}" + return f"{attr_name}: {typerep}" inputs = [self._rename_variable(x) for x in funproto.input] attrs = [attr_sig(x) for x in funproto.attribute] @@ -589,7 +597,7 @@ def rename_variable(name): if var_name in variable_names: return variable_names[var_name] new_name = f"v{len(variable_names) + 1}" - assert var_name is not None + assert var_name is not None # TODO(rama): This looks suspect. variable_names[var_name] = new_name return new_name