Skip to content

Commit

Permalink
run lint
Browse files Browse the repository at this point in the history
  • Loading branch information
gramalingam committed Nov 18, 2023
1 parent b673b48 commit adccc47
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
15 changes: 13 additions & 2 deletions onnxscript/backend/onnx_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ def _attribute_value(attr: onnx.AttributeProto):
# - onnx.AttributeProto.TYPE_PROTOS
raise NotImplementedError(f"Unable to return a value for attribute {attr!r}.")


def _update_names_used_in_graph(names: set[str], graph: GraphProto) -> None:
"""Returns the names used in a graph."""
names.update(x.name for x in graph.input)
Expand All @@ -198,6 +199,7 @@ def _update_names_used_in_graph(names: set[str], graph: GraphProto) -> None:
for node in graph.node:
_update_names_used_in_node(names, node)

Check warning on line 200 in onnxscript/backend/onnx_export.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/backend/onnx_export.py#L200

Added line #L200 was not covered by tests


def _update_names_used_in_node(names: set[str], node: onnx.NodeProto) -> None:
names.update(node.input)
names.update(node.output)
Expand All @@ -207,17 +209,20 @@ def _update_names_used_in_node(names: set[str], node: onnx.NodeProto) -> None:
for g in attr.graphs:
_update_names_used_in_graph(names, g)

Check warning on line 210 in onnxscript/backend/onnx_export.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/backend/onnx_export.py#L210

Added line #L210 was not covered by tests


def _update_names_used_in_function(names: set[str], fun: FunctionProto) -> None:
names.update(fun.input)
names.update(fun.output)
for node in fun.node:
_update_names_used_in_node(names, node)


def _names_used_in_function(fun: FunctionProto) -> set[str]:
names = set()

Check failure

Code scanning / lintrunner

MYPY/var-annotated Error

Need type annotation for "names" (hint: "names: Set[] = ...") To disable, use # type: ignore[var-annotated]
_update_names_used_in_function(names, fun)
return names


class Exporter:
"""Class used for recursive traversal of Proto structures."""

Expand All @@ -231,6 +236,7 @@ def __init__(self, rename_function, use_operators=False, inline_const=False) ->

def _handle_attrname_conflict(self, renamer):
"""Add ref-attr-name-conflict handling logic to renaming function."""

def new_renamer(name):
new_name = renamer(name)
if new_name not in self._attr_renaming:
Expand All @@ -246,7 +252,8 @@ def new_renamer(name):
counter += 1
self._attr_renaming[new_name] = candidate
self._names_used.add(candidate)
return candidate
return candidate

return new_renamer

def _rename_variable_s(self, name):
Expand All @@ -262,7 +269,9 @@ def make_opset_name(self, domain, version):
return f"{self._rename_domain(domain)}{version}"

def _python_make_node_name(self, domain, version, name, node=False):
name = _rename_variable(name) # TODO: Is this a typo? Is it supposed to be self._rename_variable(name)?
name = _rename_variable(
name
) # TODO: Is this a typo? Is it supposed to be self._rename_variable(name)?
if node:
if version is None:
version = 1
Expand Down Expand Up @@ -524,8 +533,10 @@ def translate_function(self, funproto: onnx.FunctionProto) -> str:
renamed_names_used = [self._rename_variable(x) for x in used_proto_names]
self._names_used = set(renamed_names_used)
result = []

def add_line(line: str) -> None:
result.append(line)

opset_name = self.make_opset_name(funproto.domain, 1)
add_line(f"@script({opset_name})")
fun_name = self._python_make_node_name(funproto.domain, 1, funproto.name)
Expand Down
2 changes: 1 addition & 1 deletion onnxscript/backend/onnx_export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def fun_with_double_attr_promotion(X, dtype: int):
return Z

Check warning on line 168 in onnxscript/backend/onnx_export_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/backend/onnx_export_test.py#L166-L168

Added lines #L166 - L168 were not covered by tests

self._round_trip_check(fun_with_double_attr_promotion)

def test_qualified_domain(self):
"""Test use of qualified domain name."""
op = onnxscript.opset17
Expand Down

0 comments on commit adccc47

Please sign in to comment.