chore(deps): bump ruff from 0.8.6 to 0.9.1 in /requirements/lintrunner #2008

Merged (3 commits) on Jan 14, 2025
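This dependency bump only changes the ruff pin in requirements/lintrunner from 0.8.6 to 0.9.1; the file changes below are mechanical reformatting produced by the newer formatter. Two patterns account for essentially all of the diff: long assert statements now keep the condition on one line and parenthesize the message instead of splitting the condition, and expressions nested inside f-strings are now formatted, including quote normalization where the syntax allows it. A minimal sketch of both patterns, using hypothetical names rather than code from this repository:

```python
value = 3

# Old style (ruff 0.8.x): the formatter broke the assert *condition* across lines.
assert isinstance(
    value, int
), f"value must be an int, not {type(value)}"

# New style (ruff 0.9.x): the condition stays on one line and the *message*
# is wrapped in parentheses so it can wrap cleanly instead.
assert isinstance(value, int), (
    f"value must be an int, not {type(value)}"
)

# ruff 0.9.x also formats code inside f-string replacement fields and
# normalizes inner quotes to double quotes where that is syntactically safe
# (for example inside triple-quoted f-strings, as in several hunks below).
names = ["a", "b"]
old_style = f"""items: {', '.join(names)}"""
new_style = f"""items: {", ".join(names)}"""
assert old_style == new_style
```

In the hunks below, lines prefixed with - show the ruff 0.8.6 formatting, lines prefixed with + show the ruff 0.9.1 output, and unprefixed lines are unchanged context.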
3 changes: 1 addition & 2 deletions onnxscript/_internal/ast_utils.py
@@ -18,8 +18,7 @@ def get_src_and_ast(func: Callable, /) -> tuple[str, ast.FunctionDef]:
src = inspect.getsource(func)
except OSError as e:
raise RuntimeError(
-     f"Decorator script does not work on dynamically "
-     f"compiled function {func.__name__}."
+     f"Decorator script does not work on dynamically compiled function {func.__name__}."
) from e
src = textwrap.dedent(src)
top_level_ast = ast.parse(src)
6 changes: 3 additions & 3 deletions onnxscript/backend/onnx_export_test.py
@@ -129,9 +129,9 @@ def extract_functions(name: str, content: str, test_folder: pathlib.Path):
filename = str(test_folder / f"{name}.py")
with open(filename, "w", encoding="utf-8") as f:
f.write(content + "\n")
- assert os.path.exists(
-     filename
- ), f"{filename!r} ({os.path.abspath(filename)!r} does not exist."
+ assert os.path.exists(filename), (
+     f"{filename!r} ({os.path.abspath(filename)!r} does not exist."
+ )
import_name = f"tests.{test_folder.parts[-1]}.{name}"
try:
mod = importlib.import_module(import_name)
(changes to another file; the file name was not captured in this view)
@@ -475,9 +475,9 @@ def register_outputs(
if isinstance(outputs, TorchScriptTensor):
outputs = (outputs,)
for output in outputs:
- assert isinstance(
-     output, TorchScriptTensor
- ), f"output must be a TorchScriptTensor, not {type(output)}"
+ assert isinstance(output, TorchScriptTensor), (
+     f"output must be a TorchScriptTensor, not {type(output)}"
+ )
self._graph.outputs.append(output)

def _add_constant_to_graph(self, constant) -> Sequence[ir.Value | None]:
@@ -556,9 +556,9 @@ def _add_ir_graph_op_call(
# TODO(justinchuby): What is this case?
graph_inputs.append(input)
for key, value in onnx_attributes.items():
- assert not isinstance(
-     value, TorchScriptTensor
- ), f"ONNX attribute must not be a TorchScriptTensor, got {key}: {value}."
+ assert not isinstance(value, TorchScriptTensor), (
+     f"ONNX attribute must not be a TorchScriptTensor, got {key}: {value}."
+ )
tensors = _create_op_call_in_graph(
self._graph,
domain,
@@ -586,9 +586,9 @@ def _fetch_function_dict(
domain = sub_torch_script_graph.domain_name
assert domain is not None
name_domain = (sub_graph_name, domain, "")
- assert (
-     name_domain not in function_dict
- ), f"Sub graph name already exists. {name_domain}"
+ assert name_domain not in function_dict, (
+     f"Sub graph name already exists. {name_domain}"
+ )
function_dict[name_domain] = sub_torch_script_graph._to_function( # pylint: disable=protected-access
opset_version, sub_graph_name
)
(changes to another file; the file name was not captured in this view)
@@ -689,9 +689,9 @@ def register_outputs(
return
assert isinstance(unwrapped_outputs, Sequence)
for ts_output in unwrapped_outputs:
- assert isinstance(
-     ts_output, torch.Value
- ), f"ts_output must be a torch.Value, not {type(ts_output)}"
+ assert isinstance(ts_output, torch.Value), (
+     f"ts_output must be a torch.Value, not {type(ts_output)}"
+ )
self._torch_graph.registerOutput(ts_output)
return

@@ -772,9 +772,9 @@ def _add_torchscript_op_call(
) -> Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]]:
graph_inputs = self.preprocess_inputs(onnx_inputs)
for key, value in onnx_attributes.items():
- assert not isinstance(
-     value, TorchScriptTensor
- ), f"ONNX attribute must not be a TorchScriptTensor, got {key}: {value}."
+ assert not isinstance(value, TorchScriptTensor), (
+     f"ONNX attribute must not be a TorchScriptTensor, got {key}: {value}."
+ )
result = _create_op_call_in_torch_graph(
self._torch_graph,
name,
@@ -816,9 +816,9 @@ def fetch_function_proto_dict(
sub_graph_name,
domain,
)
- assert (
-     name_domain not in function_proto_dict
- ), f"Sub graph name already exists. {name_domain}"
+ assert name_domain not in function_proto_dict, (
+     f"Sub graph name already exists. {name_domain}"
+ )
function_proto_dict[name_domain] = sub_torch_script_graph.to_function_proto(
opset_version, sub_graph_name
)
12 changes: 6 additions & 6 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -3048,9 +3048,9 @@ def aten_embedding_bag_padding_idx(
We add default values for the attributes to accommodate _embedding_bag as well:
_embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1)
"""
- assert (
-     padding_idx is not None
- ), "padding_idx must not be None. This is likely a dispatcher error"
+ assert padding_idx is not None, (
+     "padding_idx must not be None. This is likely a dispatcher error"
+ )

if per_sample_weights is None:
per_sample_weights = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(indices))
@@ -4417,9 +4417,9 @@ def aten_instance_norm(
if use_input_stats:
return op.InstanceNormalization(input, weight, bias, epsilon=eps)

- assert (
-     running_mean is not None and running_var is not None
- ), "running_mean and running_var must be provided when use_input_stats is False"
+ assert running_mean is not None and running_var is not None, (
+     "running_mean and running_var must be provided when use_input_stats is False"
+ )

batch_size = op.Shape(input, start=0, end=1)
bn_input = op.Reshape(
24 changes: 12 additions & 12 deletions onnxscript/function_libs/torch_lib/ops/nn.py
@@ -1801,13 +1801,13 @@ def aten_scaled_dot_product_attention(
L is the target sequence length, S is the source sequence length, and E is the embedding size.
"""
# Use trace_only to handle optional inputs
- assert (not is_causal) or (
-     is_causal and attn_mask is None
- ), "is_causal and attn_mask cannot be set at the same time"
+ assert (not is_causal) or (is_causal and attn_mask is None), (
+     "is_causal and attn_mask cannot be set at the same time"
+ )

- assert (
-     not enable_gqa
- ), "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
+ assert not enable_gqa, (
+     "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
+ )

# Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
if scale is None:
@@ -2018,13 +2018,13 @@ def aten_scaled_dot_product_attention_bool_mask(
L is the target sequence length, S is the source sequence length, and E is the embedding size.
"""
# Use trace_only to handle optional inputs
- assert (not is_causal) or (
-     is_causal and attn_mask is None
- ), "is_causal and attn_mask cannot be set at the same time"
+ assert (not is_causal) or (is_causal and attn_mask is None), (
+     "is_causal and attn_mask cannot be set at the same time"
+ )

- assert (
-     not enable_gqa
- ), "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
+ assert not enable_gqa, (
+     "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
+ )

if scale is None:
scale = _attention_scale(query)
36 changes: 18 additions & 18 deletions onnxscript/ir/_core.py
@@ -388,9 +388,9 @@ def __init__(
def __array__(self, dtype: Any = None) -> np.ndarray:
if isinstance(self._raw, np.ndarray) or _compatible_with_numpy(self._raw):
return self._raw.__array__(dtype)
- assert _compatible_with_dlpack(
-     self._raw
- ), f"Bug: Expected DLPack or Numpy compatible objects, got {type(self._raw)}"
+ assert _compatible_with_dlpack(self._raw), (
+     f"Bug: Expected DLPack or Numpy compatible objects, got {type(self._raw)}"
+ )
return np.from_dlpack(self._raw)

def __dlpack__(self, *, stream: Any = None) -> Any:
@@ -765,9 +765,9 @@ def __init__(
def __array__(self, dtype: Any = None) -> np.ndarray:
if isinstance(self._raw, np.ndarray):
return self._raw
- assert isinstance(
-     self._raw, Sequence
- ), f"Bug: Expected a sequence, got {type(self._raw)}"
+ assert isinstance(self._raw, Sequence), (
+     f"Bug: Expected a sequence, got {type(self._raw)}"
+ )
return np.array(self._raw, dtype=dtype).reshape(self.shape.numpy())

def __dlpack__(self, *, stream: Any = None) -> Any:
@@ -2228,11 +2228,11 @@ def _graph_str(graph: Graph | GraphView) -> str:
)
signature = f"""\
graph(
- name={graph.name or 'anonymous_graph:' + str(id(graph))},
- inputs=({textwrap.indent(inputs_text, ' ' * 8)}
+ name={graph.name or "anonymous_graph:" + str(id(graph))},
+ inputs=({textwrap.indent(inputs_text, " " * 8)}
),
- outputs=({textwrap.indent(outputs_text, ' ' * 8)}
- ),{textwrap.indent(initializers_text, ' ' * 4)}
+ outputs=({textwrap.indent(outputs_text, " " * 8)}
+ ),{textwrap.indent(initializers_text, " " * 4)}
)"""
node_count = len(graph)
number_width = len(str(node_count))
@@ -2266,11 +2266,11 @@ def _graph_repr(graph: Graph | GraphView) -> str:
)
return f"""\
{graph.__class__.__name__}(
- name={graph.name or 'anonymous_graph:' + str(id(graph))!r},
- inputs=({textwrap.indent(inputs_text, ' ' * 8)}
+ name={graph.name or "anonymous_graph:" + str(id(graph))!r},
+ inputs=({textwrap.indent(inputs_text, " " * 8)}
),
- outputs=({textwrap.indent(outputs_text, ' ' * 8)}
- ),{textwrap.indent(initializers_text, ' ' * 4)}
+ outputs=({textwrap.indent(outputs_text, " " * 8)}
+ ),{textwrap.indent(initializers_text, " " * 4)}
len()={len(graph)}
)"""

@@ -2484,7 +2484,7 @@ def __repr__(self) -> str:
domain={self.domain!r},
model_version={self.model_version!r},
functions={self.functions!r},
- graph={textwrap.indent(repr(self.graph), ' ' * 4).strip()}
+ graph={textwrap.indent(repr(self.graph), " " * 4).strip()}
)"""


@@ -2684,10 +2684,10 @@ def __str__(self) -> str:
>
def {full_name}(
inputs=(
- {textwrap.indent(inputs_text, ' ' * 8)}
- ),{textwrap.indent(attributes_text, ' ' * 4)}
+ {textwrap.indent(inputs_text, " " * 8)}
+ ),{textwrap.indent(attributes_text, " " * 4)}
outputs=(
- {textwrap.indent(outputs_text, ' ' * 8)}
+ {textwrap.indent(outputs_text, " " * 8)}
),
)"""
node_count = len(self)
6 changes: 3 additions & 3 deletions onnxscript/ir/_linked_list.py
@@ -131,9 +131,9 @@ def __reversed__(self) -> Iterator[T]:
box = box.prev

def __len__(self) -> int:
- assert self._length == len(
-     self._value_ids_to_boxes
- ), "Bug in the implementation: length mismatch"
+ assert self._length == len(self._value_ids_to_boxes), (
+     "Bug in the implementation: length mismatch"
+ )
return self._length

def __getitem__(self, index: int) -> T:
6 changes: 3 additions & 3 deletions onnxscript/ir/_schemas.py
@@ -301,9 +301,9 @@ def _get_allowed_types_from_type_annotation(
allowed_types = set()
subtypes = typing.get_args(type_)
for subtype in subtypes:
- assert subtype is not type(
-     None
- ), "Union should not contain None type because it is handled by _is_optional."
+ assert subtype is not type(None), (
+     "Union should not contain None type because it is handled by _is_optional."
+ )
allowed_types.update(_get_allowed_types_from_type_annotation(subtype))
return allowed_types

3 changes: 1 addition & 2 deletions onnxscript/ir/serde.py
@@ -320,8 +320,7 @@ def numpy(self) -> np.ndarray:
raise ValueError("Cannot convert UNDEFINED tensor to numpy array.")
if self._proto.data_location == onnx.TensorProto.EXTERNAL:
raise ValueError(
-     "Cannot convert external tensor to numpy array. "
-     "Use ir.ExternalTensor instead."
+     "Cannot convert external tensor to numpy array. Use ir.ExternalTensor instead."
)

if self._proto.HasField("raw_data"):
6 changes: 3 additions & 3 deletions onnxscript/rewriter/broadcast_to_matmul_test.py
@@ -97,12 +97,12 @@ def test_reshape_matmul_reshape_does_not_replace_when_output_sizes_do_not_match(
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float{input_x_shape} input_x, float{input_y_shape} input_y) => (float{output_shape} output)
{{
- shape_a = Constant<value: tensor = int64[{len(shape_a)}] {{ {', '.join(str(i) for i in shape_a)} }}>()
+ shape_a = Constant<value: tensor = int64[{len(shape_a)}] {{ {", ".join(str(i) for i in shape_a)} }}>()
reshape_x = Reshape (input_x, shape_a)
- shape_b = Constant<value: tensor = int64[{len(shape_b)}] {{ {', '.join(str(i) for i in shape_b)} }}>()
+ shape_b = Constant<value: tensor = int64[{len(shape_b)}] {{ {", ".join(str(i) for i in shape_b)} }}>()
reshape_y = Reshape (input_y, shape_b)
matmul = MatMul (reshape_x, reshape_y)
- shape_c = Constant<value: tensor = int64[{len(shape_c)}] {{ {', '.join(str(i) for i in shape_c)} }}>()
+ shape_c = Constant<value: tensor = int64[{len(shape_c)}] {{ {", ".join(str(i) for i in shape_c)} }}>()
output = Reshape (matmul, shape_c)
}}
"""
33 changes: 16 additions & 17 deletions onnxscript/rewriter/generic_pattern.py
@@ -36,21 +36,21 @@ def __init__(
self.matched_pattern_to_model_value: dict[orp.ValuePattern, ir.Value] = {}

for graph_node, pattern_node in zip(model_nodes, pattern_nodes):
- assert (
-     graph_node.op_identifier() == pattern_node.op_identifier()
- ), f"Unexpected type mismatch {graph_node.op_identifier()!r} != {pattern_node.op_identifier()!r}"
- assert len(graph_node.inputs) == len(
-     pattern_node.inputs
- ), f"Unexpected number of inputs for type {graph_node.op_identifier()}"
+ assert graph_node.op_identifier() == pattern_node.op_identifier(), (
+     f"Unexpected type mismatch {graph_node.op_identifier()!r} != {pattern_node.op_identifier()!r}"
+ )
+ assert len(graph_node.inputs) == len(pattern_node.inputs), (
+     f"Unexpected number of inputs for type {graph_node.op_identifier()}"
+ )
for a, b in zip(graph_node.inputs, pattern_node.inputs):
if b is None:
# optional input or not an interesting input
continue
self._bind(b, a)

- assert len(graph_node.outputs) == len(
-     pattern_node.outputs
- ), f"Unexpected number of outputs for type {graph_node.op_identifier()}"
+ assert len(graph_node.outputs) == len(pattern_node.outputs), (
+     f"Unexpected number of outputs for type {graph_node.op_identifier()}"
+ )
for a, b in zip(graph_node.outputs, pattern_node.outputs):
self._bind(b, a)

@@ -494,8 +494,7 @@ def _match_values_forward(
# 1. make assumptions and continue
# 2. mark the node as incomplete matching, we could end up stuck anyway.
raise NotImplementedError(
-     f"There are more than one option, this will be implemented later, "
-     f"ec={ec}, gc={gc}"
+     f"There are more than one option, this will be implemented later, ec={ec}, gc={gc}"
)

def _match_forward(
@@ -620,9 +619,9 @@ def match(
return result

nodes_not_in_pattern = set(matched.keys()) - all_pattern_nodes
- assert (
-     not nodes_not_in_pattern
- ), f"Some nodes are not part of the pattern: {nodes_not_in_pattern}"
+ assert not nodes_not_in_pattern, (
+     f"Some nodes are not part of the pattern: {nodes_not_in_pattern}"
+ )

result = self._match_forward(
node, matched, stack, next_graph_node, next_pattern_node
@@ -633,9 +632,9 @@ def match(
return result

nodes_not_in_pattern = set(matched.keys()) - all_pattern_nodes
- assert (
-     not nodes_not_in_pattern
- ), f"Some nodes are not part of the pattern: {nodes_not_in_pattern}"
+ assert not nodes_not_in_pattern, (
+     f"Some nodes are not part of the pattern: {nodes_not_in_pattern}"
+ )

if self.verbose > 5:
self._debug["iteration"] = iteration
(changes to another file; the file name was not captured in this view)
@@ -104,14 +104,14 @@ def infer_attn_size_config(self, function: ir.Function) -> AttnSizeConfig:
# Reference:
# https://github.com/huggingface/diffusers/blob/ae05050db9d37d5af48a6cd0d6510a5ffb1c1cd4/src/diffusers/models/attention_processor.py#L1269
reshape_nodes = [node for node in function if node.op_type == "Reshape"]
- assert (
-     len(reshape_nodes) == 4
- ), "Expected 3 Reshape nodes for Q, K and V, and 1 reshape node for output of scaled_dot_product_attention."
+ assert len(reshape_nodes) == 4, (
+     "Expected 3 Reshape nodes for Q, K and V, and 1 reshape node for output of scaled_dot_product_attention."
+ )
for reshape_node in reshape_nodes:
constant_node = reshape_node.inputs[1].producer()
- assert (
-     constant_node.op_type == "Constant"
- ), "Expected the second input to Reshape to be a Constant node."
+ assert constant_node.op_type == "Constant", (
+     "Expected the second input to Reshape to be a Constant node."
+ )
value = reshape_node.inputs[1]
constant_value = _ir_utils.get_const_value(value)
if constant_value is None:
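For reference, formatting changes like these are typically regenerated rather than hand-edited: running `ruff format` with the bumped pin, or this repository's lintrunner setup (e.g. `lintrunner -a`, assuming ruff is wired in through requirements/lintrunner), should reproduce the same output.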