Skip to content

Commit

Permalink
Enable additional ruff rules (#1311)
Browse files Browse the repository at this point in the history
Enable additional ruff rules including pylint rules and others to guard
code quality.
  • Loading branch information
justinchuby authored Mar 26, 2024
1 parent 1bfb4b1 commit 46c8751
Show file tree
Hide file tree
Showing 22 changed files with 61 additions and 52 deletions.
2 changes: 1 addition & 1 deletion docs/examples/01_plot_selu.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def Selu(X, alpha: float, gamma: float):
neg = gammaX * (alphaX * op.Exp(X) - alphaX)
pos = gammaX * X
zero = op.CastLike(0, X)
return op.Where(X <= zero, neg, pos)
return op.Where(zero >= X, neg, pos)


# %%
Expand Down
2 changes: 1 addition & 1 deletion onnxscript/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
from ._internal.utils import external_tensor
from .values import OnnxFunction, TracedOnnxFunction

try:
try: # noqa: SIM105
__version__ = importlib.metadata.version("onnxscript")
except importlib.metadata.PackageNotFoundError:
# package is not installed
Expand Down
2 changes: 1 addition & 1 deletion onnxscript/_internal/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

def _get_loop_var(for_stmt: ast.For, formatter: sourceinfo.Formatter) -> str:
if not isinstance(for_stmt.target, ast.Name):
raise ValueError(formatter(for_stmt, "For loop target must be a single variable."))
raise TypeError(formatter(for_stmt, "For loop target must be a single variable."))
return for_stmt.target.id


Expand Down
2 changes: 1 addition & 1 deletion onnxscript/backend/onnx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _load(folder, names):
elif isinstance(new_tensor, onnx.TensorProto):
t = onnx.numpy_helper.to_array(new_tensor)
else:
raise RuntimeError( # pragma: no cover
raise RuntimeError( # noqa: TRY004
f"Unexpected type {type(new_tensor)!r} for {full!r}."
)
res.append(t)
Expand Down
2 changes: 1 addition & 1 deletion onnxscript/converter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ def test_graph_attr_loop_error(self):
from onnxscript.tests.models.graph_attr import sum_to_error

input = np.array(6, dtype=np.int64)
with self.assertRaisesRegex(ValueError, "@graph"):
with self.assertRaisesRegex(TypeError, "@graph"):
sum_to_error(input)

def test_loop_outer_scope(self):
Expand Down
21 changes: 10 additions & 11 deletions onnxscript/diagnostics/infra/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
import dataclasses
import gzip
import logging
from typing import Callable, Generator, List, Literal, Mapping, Optional, TypeVar
import typing
from typing import Callable, Generator, List, Literal, Mapping, Optional

from onnxscript.diagnostics import infra
from onnxscript.diagnostics.infra import formatter, sarif, utils
from onnxscript.diagnostics.infra.sarif import version as sarif_version

# This is a workaround for mypy not supporting Self from typing_extensions.
_Diagnostic = TypeVar("_Diagnostic", bound="Diagnostic")
if typing.TYPE_CHECKING:
from typing_extensions import Self


@dataclasses.dataclass
Expand Down Expand Up @@ -70,37 +71,35 @@ def sarif(self) -> sarif.Result:
sarif_result.properties = sarif.PropertyBag(tags=[tag.value for tag in self.tags])
return sarif_result

def with_location(self: _Diagnostic, location: infra.Location) -> _Diagnostic:
def with_location(self: Self, location: infra.Location) -> Self:
"""Adds a location to the diagnostic."""
self.locations.append(location)
return self

def with_thread_flow_location(
self: _Diagnostic, location: infra.ThreadFlowLocation
) -> _Diagnostic:
def with_thread_flow_location(self: Self, location: infra.ThreadFlowLocation) -> Self:
"""Adds a thread flow location to the diagnostic."""
self.thread_flow_locations.append(location)
return self

def with_stack(self: _Diagnostic, stack: infra.Stack) -> _Diagnostic:
def with_stack(self: Self, stack: infra.Stack) -> Self:
"""Adds a stack to the diagnostic."""
self.stacks.append(stack)
return self

def with_graph(self: _Diagnostic, graph: infra.Graph) -> _Diagnostic:
def with_graph(self: Self, graph: infra.Graph) -> Self:
"""Adds a graph to the diagnostic."""
self.graphs.append(graph)
return self

def with_additional_message(self: _Diagnostic, message: str) -> _Diagnostic:
def with_additional_message(self: Self, message: str) -> Self:
"""Adds an additional message to the diagnostic."""
if self.additional_message is None:
self.additional_message = message
else:
self.additional_message = f"{self.additional_message}\n{message}"
return self

def with_source_exception(self: _Diagnostic, exception: Exception) -> _Diagnostic:
def with_source_exception(self: Self, exception: Exception) -> Self:
"""Adds the source exception to the diagnostic."""
self.source_exception = exception
return self
Expand Down
7 changes: 4 additions & 3 deletions onnxscript/diagnostics/infra/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def decorator(fn):
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
common_error_message = "diagnose_call can only be applied to callables"
if not callable(fn):
raise AssertionError(
raise AssertionError( # noqa: TRY004
f"{common_error_message}. Got {type(fn)} instead of callable."
)
arg0 = args[0] if len(args) > 0 else None
Expand All @@ -88,7 +88,7 @@ def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
else:
# NOTE: At decorate time, it can't tell if a callable is function or method.
# Technically both are regarded as function at that time.
raise AssertionError(
raise AssertionError( # noqa: TRY004
f"{common_error_message}. For {fn}, "
f"If it is a function, a DiagnosticContext instance must be present as "
f"the first argument. "
Expand Down Expand Up @@ -129,7 +129,6 @@ def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
additional_messages.append(
format_return_values_in_markdown(return_values, format_argument)
)
return return_values
except Exception as e: # pylint: disable=broad-exception-caught
# Record exception.
diag.level = infra.levels.ERROR
Expand All @@ -138,6 +137,8 @@ def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
diag.message += f"Raised from:\n {type(e).__name__}: {e}"
diag.with_source_exception(e)
additional_messages.append(format_exception_in_markdown(e))
else:
return return_values
finally:
diag.with_additional_message("\n".join(additional_messages).strip())
ctx.log_and_raise_if_error(diag)
Expand Down
2 changes: 1 addition & 1 deletion onnxscript/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def adapt_attributes(
else:
adapted_attributes[k] = v.function
elif callable(v):
raise ValueError(
raise TypeError(
f"Error: function-valued attribute {v.__name__} has no graph_proto"
"attribute. Did you forget to decorate it with @graph?"
)
Expand Down
2 changes: 1 addition & 1 deletion onnxscript/evaluator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_function(x, y: float = 1.0):
return op.Add(x, y)

x = np.array(0.0, dtype=np.float32)
with evaluator.default_as(evaluator.ORTEvaluator()):
with evaluator.default_as(evaluator.ORTEvaluator()): # noqa: SIM117
with self.assertRaises(TypeError):
_ = test_function(x, unknown=42) # pylint: disable=unexpected-keyword-arg

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def __repr__(self):
)
repr_strs += [
f" {type_constraint.name}: {type_constraint.type_strs}"
for type_constraint in ordered_unique_type_constraints.keys()
for type_constraint in ordered_unique_type_constraints
if type_constraint is not None
]

Expand All @@ -180,7 +180,7 @@ def __repr__(self):
}
repr_strs += [
f" {type_constraint.name}: {type_constraint.type_strs}"
for type_constraint in ordered_unique_type_constraints.keys()
for type_constraint in ordered_unique_type_constraints
if type_constraint is not None
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,16 +154,17 @@ def parse_default_value(arg: torchgen.model.Argument) -> Any:

try:
value = ast.literal_eval(default)
except ValueError:
# Treat it as a string.
return default.lower()
else:
if isinstance(value, int):
# Expand the value to a tuple if the type is a list.
if isinstance(arg.type, torchgen.model.ListType):
if arg.type.size is not None:
return (value,) * arg.type.size
return (value,)
return value
except ValueError:
# Treat it as a string.
return default.lower()


def create_return_type(returns: Sequence[torchgen.model.Return]) -> cg.TypeRef:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,16 +123,17 @@ def parse_default_value(arg: torchgen.model.Argument) -> Any:

try:
value = ast.literal_eval(default)
except ValueError:
# Treat it as a string.
return default.lower()
else:
if isinstance(value, int):
# Expand the value to a tuple if the type is a list.
if isinstance(arg.type, torchgen.model.ListType):
if arg.type.size is not None:
return (value,) * arg.type.size
return (value,)
return value
except ValueError:
# Treat it as a string.
return default.lower()


def create_return_type(returns: Sequence[torchgen.model.Return]) -> cg.TypeRef:
Expand Down
2 changes: 1 addition & 1 deletion onnxscript/function_libs/torch_lib/graph_building.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def eval_function( # type: ignore[override]
else:
# Fall to call add_function_call
pass
elif isinstance(args[0], Sequence):
elif isinstance(args[0], Sequence): # noqa: SIM103
return False
else:
# Python constants are scalars
Expand Down
2 changes: 1 addition & 1 deletion onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4025,7 +4025,7 @@ def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorTy
*range(
advanced_indexing_rank, starting_position_of_none_in_back
), # None_front_1...x_None_back_1
*range(0, advanced_indexing_rank), # 0...len(broadcasted_shape)
*range(advanced_indexing_rank), # 0...len(broadcasted_shape)
*range(
starting_position_of_none_in_back,
result_rank,
Expand Down
2 changes: 1 addition & 1 deletion onnxscript/irbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class IRVar:

def __init__(self, varname: str, typeinfo: IRTypeLike, sourceinfo: SourceInfo) -> None:
if not isinstance(varname, str):
raise ValueError(f"varname must be a string not {type(varname)!r}.")
raise TypeError(f"varname must be a string not {type(varname)!r}.")
self.name = varname
self.info = sourceinfo
self.typeinfo = typeinfo
Expand Down
10 changes: 2 additions & 8 deletions onnxscript/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,7 @@ def _same_optional(field, obj1, obj2, equals=_default_equality_op):
def _same_repeated(values1, values2, equals=_default_equality_op):
if len(values1) != len(values2):
return False
for val1, val2 in zip(values1, values2):
if not equals(val1, val2):
return False
return True
return all(equals(val1, val2) for val1, val2 in zip(values1, values2))


def _same_string_string_map(proto1, proto2):
Expand Down Expand Up @@ -232,10 +229,7 @@ def same_value_list(self, list1, list2):
"""Match two lists of variables (either a string or ValueInfoProto)"""
if len(list1) != len(list2):
return False
for x, y in zip(list1, list2):
if not self.same_value(_ioname(x), _ioname(y)):
return False
return True
return all(self.same_value(_ioname(x), _ioname(y)) for x, y in zip(list1, list2))

def same_sub_graph(self, g1, g2):
"""Match two sub-graphs."""
Expand Down
4 changes: 2 additions & 2 deletions onnxscript/tests/function_libs/torch_lib/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def test_all_script_functions_are_onnx_functions(self):
with self.subTest(name=info.op_info_name):
func = info.op
if not isinstance(func, onnxscript.OnnxFunction):
raise AssertionError(
raise TypeError(
f"'{func}' is not an OnnxFunction. Was it decorated with '@torch_op'? "
"If the function is trace_only, please specify trace_only=True "
"in the TorchLibOpInfo entry."
Expand All @@ -116,7 +116,7 @@ def test_all_trace_only_functions_are_not_onnx_functions(self):
with self.subTest(name=info.op_info_name):
func = info.op
if not isinstance(func, onnxscript.TracedOnnxFunction):
raise AssertionError(
raise TypeError(
f"'{func.name}' is not a TracedOnnxFunction. "
"If the function is not trace_only, please remove trace_only=True "
"in the TorchLibOpInfo entry."
Expand Down
4 changes: 2 additions & 2 deletions onnxscript/tests/function_libs/torch_lib/ops_test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,9 +622,9 @@ def normal_xfail_skip_test_behaviors(
yield
# We could use `except (AssertionError, RuntimeError, ...) as e:`, but it needs
# to go over all test cases to find the right exception type.
except Exception as e: # pylint: disable=broad-exception-caught
except Exception: # pylint: disable=broad-exception-caught
if test_behavior is None:
raise e
raise
if test_behavior == "xfail":
pytest.xfail(reason=reason)
else:
Expand Down
10 changes: 5 additions & 5 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ def _where_input_wrangler(
"all",
core_ops.aten_all,
).skip(
matcher=lambda sample: not (len(sample.kwargs) == 0),
matcher=lambda sample: len(sample.kwargs) != 0,
reason="this Aten overload only support one tensor as input by design",
),
TorchLibOpInfo("abs", core_ops.aten_abs),
Expand Down Expand Up @@ -582,7 +582,7 @@ def _where_input_wrangler(
"any",
core_ops.aten_any,
).skip(
matcher=lambda sample: not (len(sample.kwargs) == 0),
matcher=lambda sample: len(sample.kwargs) != 0,
reason="this Aten overload only support one tensor as input by design",
),
TorchLibOpInfo(
Expand Down Expand Up @@ -852,15 +852,15 @@ def _where_input_wrangler(
"index_put_bool",
core_ops.aten_index_put_bool,
).skip(
matcher=lambda sample: not (sample.args[0][0].dtype == torch.bool),
matcher=lambda sample: sample.args[0][0].dtype != torch.bool,
reason="this Aten overload only supports tensor(bool) as indices",
),
TorchLibOpInfo(
"index_put",
core_ops.aten_index_put,
)
.skip(
matcher=lambda sample: not (sample.args[0][0].dtype == torch.int64),
matcher=lambda sample: sample.args[0][0].dtype != torch.int64,
reason="this Aten overload only supports tensor(int) as indices",
)
.xfail(
Expand Down Expand Up @@ -1549,7 +1549,7 @@ def _where_input_wrangler(
"squeeze",
core_ops.aten_squeeze,
).skip(
matcher=lambda sample: not (len(sample.args) == 0),
matcher=lambda sample: len(sample.args) != 0,
reason="this Aten overload only support one tensor as input by design",
),
TorchLibOpInfo("stack", core_ops.aten_stack),
Expand Down
3 changes: 2 additions & 1 deletion onnxscript/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,10 @@ def __getitem__(self, opname):
def __contains__(self, opname):
try:
onnx.defs.get_schema(opname, self.version, self.domain)
return True
except Exception: # pylint: disable=broad-except # TODO: more specific exception
return False
else:
return True

def __str__(self) -> str:
return self.domain
Expand Down
2 changes: 1 addition & 1 deletion opgen/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
)
args = argparser.parse_args()

try:
try: # noqa: SIM105
shutil.rmtree(opsets_path)
except FileNotFoundError:
pass # if base_path doesn't exist, that's great
Expand Down
Loading

0 comments on commit 46c8751

Please sign in to comment.