diff --git a/onnxscript/_internal/ast_utils.py b/onnxscript/_internal/ast_utils.py
index 104e82670..c7250e126 100644
--- a/onnxscript/_internal/ast_utils.py
+++ b/onnxscript/_internal/ast_utils.py
@@ -18,8 +18,7 @@ def get_src_and_ast(func: Callable, /) -> tuple[str, ast.FunctionDef]:
         src = inspect.getsource(func)
     except OSError as e:
         raise RuntimeError(
-            f"Decorator script does not work on dynamically "
-            f"compiled function {func.__name__}."
+            f"Decorator script does not work on dynamically compiled function {func.__name__}."
         ) from e
     src = textwrap.dedent(src)
     top_level_ast = ast.parse(src)
diff --git a/onnxscript/backend/onnx_export_test.py b/onnxscript/backend/onnx_export_test.py
index c1a2afbfb..1d05428a2 100644
--- a/onnxscript/backend/onnx_export_test.py
+++ b/onnxscript/backend/onnx_export_test.py
@@ -129,9 +129,9 @@ def extract_functions(name: str, content: str, test_folder: pathlib.Path):
     filename = str(test_folder / f"{name}.py")
     with open(filename, "w", encoding="utf-8") as f:
         f.write(content + "\n")
-    assert os.path.exists(
-        filename
-    ), f"{filename!r} ({os.path.abspath(filename)!r} does not exist."
+    assert os.path.exists(filename), (
+        f"{filename!r} ({os.path.abspath(filename)!r} does not exist."
+    )
     import_name = f"tests.{test_folder.parts[-1]}.{name}"
     try:
         mod = importlib.import_module(import_name)
diff --git a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_ir.py b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_ir.py
index 1270c6376..3915027aa 100644
--- a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_ir.py
+++ b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_ir.py
@@ -475,9 +475,9 @@ def register_outputs(
         if isinstance(outputs, TorchScriptTensor):
             outputs = (outputs,)
         for output in outputs:
-            assert isinstance(
-                output, TorchScriptTensor
-            ), f"output must be a TorchScriptTensor, not {type(output)}"
+            assert isinstance(output, TorchScriptTensor), (
+                f"output must be a TorchScriptTensor, not {type(output)}"
+            )
             self._graph.outputs.append(output)

     def _add_constant_to_graph(self, constant) -> Sequence[ir.Value | None]:
@@ -556,9 +556,9 @@ def _add_ir_graph_op_call(
                 # TODO(justinchuby): What is this case?
                 graph_inputs.append(input)
         for key, value in onnx_attributes.items():
-            assert not isinstance(
-                value, TorchScriptTensor
-            ), f"ONNX attribute must not be a TorchScriptTensor, got {key}: {value}."
+            assert not isinstance(value, TorchScriptTensor), (
+                f"ONNX attribute must not be a TorchScriptTensor, got {key}: {value}."
+            )
         tensors = _create_op_call_in_graph(
             self._graph,
             domain,
@@ -586,9 +586,9 @@ def _fetch_function_dict(
             domain = sub_torch_script_graph.domain_name
             assert domain is not None
             name_domain = (sub_graph_name, domain, "")
-            assert (
-                name_domain not in function_dict
-            ), f"Sub graph name already exists. {name_domain}"
+            assert name_domain not in function_dict, (
+                f"Sub graph name already exists. {name_domain}"
+            )
             function_dict[name_domain] = sub_torch_script_graph._to_function(  # pylint: disable=protected-access
                 opset_version, sub_graph_name
             )
diff --git a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py
index f59505ccc..8d0aab509 100644
--- a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py
+++ b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py
@@ -689,9 +689,9 @@ def register_outputs(
             return
         assert isinstance(unwrapped_outputs, Sequence)
         for ts_output in unwrapped_outputs:
-            assert isinstance(
-                ts_output, torch.Value
-            ), f"ts_output must be a torch.Value, not {type(ts_output)}"
+            assert isinstance(ts_output, torch.Value), (
+                f"ts_output must be a torch.Value, not {type(ts_output)}"
+            )
             self._torch_graph.registerOutput(ts_output)
         return

@@ -772,9 +772,9 @@ def _add_torchscript_op_call(
     ) -> Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]]:
         graph_inputs = self.preprocess_inputs(onnx_inputs)
         for key, value in onnx_attributes.items():
-            assert not isinstance(
-                value, TorchScriptTensor
-            ), f"ONNX attribute must not be a TorchScriptTensor, got {key}: {value}."
+            assert not isinstance(value, TorchScriptTensor), (
+                f"ONNX attribute must not be a TorchScriptTensor, got {key}: {value}."
+            )
         result = _create_op_call_in_torch_graph(
             self._torch_graph,
             name,
@@ -816,9 +816,9 @@ def fetch_function_proto_dict(
                 sub_graph_name,
                 domain,
             )
-            assert (
-                name_domain not in function_proto_dict
-            ), f"Sub graph name already exists. {name_domain}"
+            assert name_domain not in function_proto_dict, (
+                f"Sub graph name already exists. {name_domain}"
+            )
             function_proto_dict[name_domain] = sub_torch_script_graph.to_function_proto(
                 opset_version, sub_graph_name
             )
diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 1145e9b13..a1793858e 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -3048,9 +3048,9 @@ def aten_embedding_bag_padding_idx(
     We add default values for the attributes to accommodate _embedding_bag as well:
     _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1)
     """
-    assert (
-        padding_idx is not None
-    ), "padding_idx must not be None. This is likely a dispatcher error"
+    assert padding_idx is not None, (
+        "padding_idx must not be None. This is likely a dispatcher error"
+    )

     if per_sample_weights is None:
         per_sample_weights = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(indices))
@@ -4417,9 +4417,9 @@ def aten_instance_norm(

     if use_input_stats:
         return op.InstanceNormalization(input, weight, bias, epsilon=eps)
-    assert (
-        running_mean is not None and running_var is not None
-    ), "running_mean and running_var must be provided when use_input_stats is False"
+    assert running_mean is not None and running_var is not None, (
+        "running_mean and running_var must be provided when use_input_stats is False"
+    )

     batch_size = op.Shape(input, start=0, end=1)
     bn_input = op.Reshape(
diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py
index 0f0b5d891..35c89acd4 100644
--- a/onnxscript/function_libs/torch_lib/ops/nn.py
+++ b/onnxscript/function_libs/torch_lib/ops/nn.py
@@ -1801,13 +1801,13 @@ def aten_scaled_dot_product_attention(
     L is the target sequence length, S is the source sequence length, and E is the embedding size.
     """
     # Use trace_only to handle optional inputs
-    assert (not is_causal) or (
-        is_causal and attn_mask is None
-    ), "is_causal and attn_mask cannot be set at the same time"
+    assert (not is_causal) or (is_causal and attn_mask is None), (
+        "is_causal and attn_mask cannot be set at the same time"
+    )

-    assert (
-        not enable_gqa
-    ), "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
+    assert not enable_gqa, (
+        "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
+    )

     # Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
     if scale is None:
@@ -2018,13 +2018,13 @@ def aten_scaled_dot_product_attention_bool_mask(
     L is the target sequence length, S is the source sequence length, and E is the embedding size.
     """
     # Use trace_only to handle optional inputs
-    assert (not is_causal) or (
-        is_causal and attn_mask is None
-    ), "is_causal and attn_mask cannot be set at the same time"
+    assert (not is_causal) or (is_causal and attn_mask is None), (
+        "is_causal and attn_mask cannot be set at the same time"
+    )

-    assert (
-        not enable_gqa
-    ), "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
+    assert not enable_gqa, (
+        "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
+    )

     if scale is None:
         scale = _attention_scale(query)
diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py
index 519221509..faffde748 100644
--- a/onnxscript/ir/_core.py
+++ b/onnxscript/ir/_core.py
@@ -388,9 +388,9 @@ def __init__(
     def __array__(self, dtype: Any = None) -> np.ndarray:
         if isinstance(self._raw, np.ndarray) or _compatible_with_numpy(self._raw):
             return self._raw.__array__(dtype)
-        assert _compatible_with_dlpack(
-            self._raw
-        ), f"Bug: Expected DLPack or Numpy compatible objects, got {type(self._raw)}"
+        assert _compatible_with_dlpack(self._raw), (
+            f"Bug: Expected DLPack or Numpy compatible objects, got {type(self._raw)}"
+        )
         return np.from_dlpack(self._raw)

     def __dlpack__(self, *, stream: Any = None) -> Any:
@@ -765,9 +765,9 @@ def __init__(
     def __array__(self, dtype: Any = None) -> np.ndarray:
         if isinstance(self._raw, np.ndarray):
             return self._raw
-        assert isinstance(
-            self._raw, Sequence
-        ), f"Bug: Expected a sequence, got {type(self._raw)}"
+        assert isinstance(self._raw, Sequence), (
+            f"Bug: Expected a sequence, got {type(self._raw)}"
+        )
         return np.array(self._raw, dtype=dtype).reshape(self.shape.numpy())

     def __dlpack__(self, *, stream: Any = None) -> Any:
@@ -2228,11 +2228,11 @@ def _graph_str(graph: Graph | GraphView) -> str:
     )
     signature = f"""\
 graph(
-    name={graph.name or 'anonymous_graph:' + str(id(graph))},
-    inputs=({textwrap.indent(inputs_text, ' ' * 8)}
+    name={graph.name or "anonymous_graph:" + str(id(graph))},
+    inputs=({textwrap.indent(inputs_text, " " * 8)}
     ),
-    outputs=({textwrap.indent(outputs_text, ' ' * 8)}
-    ),{textwrap.indent(initializers_text, ' ' * 4)}
+    outputs=({textwrap.indent(outputs_text, " " * 8)}
+    ),{textwrap.indent(initializers_text, " " * 4)}
 )"""
     node_count = len(graph)
     number_width = len(str(node_count))
@@ -2266,11 +2266,11 @@ def _graph_repr(graph: Graph | GraphView) -> str:
     )
     return f"""\
 {graph.__class__.__name__}(
-    name={graph.name or 'anonymous_graph:' + str(id(graph))!r},
-    inputs=({textwrap.indent(inputs_text, ' ' * 8)}
+    name={graph.name or "anonymous_graph:" + str(id(graph))!r},
+    inputs=({textwrap.indent(inputs_text, " " * 8)}
     ),
-    outputs=({textwrap.indent(outputs_text, ' ' * 8)}
-    ),{textwrap.indent(initializers_text, ' ' * 4)}
+    outputs=({textwrap.indent(outputs_text, " " * 8)}
+    ),{textwrap.indent(initializers_text, " " * 4)}
     len()={len(graph)}
 )"""

@@ -2484,7 +2484,7 @@ def __repr__(self) -> str:
     domain={self.domain!r},
     model_version={self.model_version!r},
     functions={self.functions!r},
-    graph={textwrap.indent(repr(self.graph), ' ' * 4).strip()}
+    graph={textwrap.indent(repr(self.graph), " " * 4).strip()}
 )"""

@@ -2684,10 +2684,10 @@ def __str__(self) -> str:
 > def {full_name}(
     inputs=(
-{textwrap.indent(inputs_text, ' ' * 8)}
-    ),{textwrap.indent(attributes_text, ' ' * 4)}
+{textwrap.indent(inputs_text, " " * 8)}
+    ),{textwrap.indent(attributes_text, " " * 4)}
     outputs=(
-{textwrap.indent(outputs_text, ' ' * 8)}
+{textwrap.indent(outputs_text, " " * 8)}
     ),
 )"""

         node_count = len(self)
diff --git a/onnxscript/ir/_linked_list.py b/onnxscript/ir/_linked_list.py
index 2c12ad856..0db770e20 100644
--- a/onnxscript/ir/_linked_list.py
+++ b/onnxscript/ir/_linked_list.py
@@ -131,9 +131,9 @@ def __reversed__(self) -> Iterator[T]:
             box = box.prev

     def __len__(self) -> int:
-        assert self._length == len(
-            self._value_ids_to_boxes
-        ), "Bug in the implementation: length mismatch"
+        assert self._length == len(self._value_ids_to_boxes), (
+            "Bug in the implementation: length mismatch"
+        )
         return self._length

     def __getitem__(self, index: int) -> T:
diff --git a/onnxscript/ir/_schemas.py b/onnxscript/ir/_schemas.py
index 3422a0c28..d4d88ab5b 100644
--- a/onnxscript/ir/_schemas.py
+++ b/onnxscript/ir/_schemas.py
@@ -301,9 +301,9 @@ def _get_allowed_types_from_type_annotation(
         allowed_types = set()
         subtypes = typing.get_args(type_)
         for subtype in subtypes:
-            assert subtype is not type(
-                None
-            ), "Union should not contain None type because it is handled by _is_optional."
+            assert subtype is not type(None), (
+                "Union should not contain None type because it is handled by _is_optional."
+            )
             allowed_types.update(_get_allowed_types_from_type_annotation(subtype))
         return allowed_types

diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py
index 432af8cf1..b333df823 100644
--- a/onnxscript/ir/serde.py
+++ b/onnxscript/ir/serde.py
@@ -320,8 +320,7 @@ def numpy(self) -> np.ndarray:
             raise ValueError("Cannot convert UNDEFINED tensor to numpy array.")
         if self._proto.data_location == onnx.TensorProto.EXTERNAL:
             raise ValueError(
-                "Cannot convert external tensor to numpy array. "
-                "Use ir.ExternalTensor instead."
+                "Cannot convert external tensor to numpy array. Use ir.ExternalTensor instead."
             )

         if self._proto.HasField("raw_data"):
diff --git a/onnxscript/rewriter/broadcast_to_matmul_test.py b/onnxscript/rewriter/broadcast_to_matmul_test.py
index 49c97d2c7..c2f3b31f9 100644
--- a/onnxscript/rewriter/broadcast_to_matmul_test.py
+++ b/onnxscript/rewriter/broadcast_to_matmul_test.py
@@ -97,12 +97,12 @@ def test_reshape_matmul_reshape_does_not_replace_when_output_sizes_do_not_match(
             agraph (float{input_x_shape} input_x, float{input_y_shape} input_y) => (float{output_shape} output)
             {{
-                shape_a = Constant()
+                shape_a = Constant()
                 reshape_x = Reshape (input_x, shape_a)
-                shape_b = Constant()
+                shape_b = Constant()
                 reshape_y = Reshape (input_y, shape_b)
                 matmul = MatMul (reshape_x, reshape_y)
-                shape_c = Constant()
+                shape_c = Constant()
                 output = Reshape (matmul, shape_c)
             }}
         """
diff --git a/onnxscript/rewriter/generic_pattern.py b/onnxscript/rewriter/generic_pattern.py
index de06d7a22..563e88f2d 100644
--- a/onnxscript/rewriter/generic_pattern.py
+++ b/onnxscript/rewriter/generic_pattern.py
@@ -36,21 +36,21 @@ def __init__(
         self.matched_pattern_to_model_value: dict[orp.ValuePattern, ir.Value] = {}

         for graph_node, pattern_node in zip(model_nodes, pattern_nodes):
-            assert (
-                graph_node.op_identifier() == pattern_node.op_identifier()
-            ), f"Unexpected type mismatch {graph_node.op_identifier()!r} != {pattern_node.op_identifier()!r}"
-            assert len(graph_node.inputs) == len(
-                pattern_node.inputs
-            ), f"Unexpected number of inputs for type {graph_node.op_identifier()}"
+            assert graph_node.op_identifier() == pattern_node.op_identifier(), (
+                f"Unexpected type mismatch {graph_node.op_identifier()!r} != {pattern_node.op_identifier()!r}"
+            )
+            assert len(graph_node.inputs) == len(pattern_node.inputs), (
+                f"Unexpected number of inputs for type {graph_node.op_identifier()}"
+            )
             for a, b in zip(graph_node.inputs, pattern_node.inputs):
                 if b is None:
                     # optional input or not an interesting input
                     continue
                 self._bind(b, a)

-            assert len(graph_node.outputs) == len(
-                pattern_node.outputs
-            ), f"Unexpected number of outputs for type {graph_node.op_identifier()}"
+            assert len(graph_node.outputs) == len(pattern_node.outputs), (
+                f"Unexpected number of outputs for type {graph_node.op_identifier()}"
+            )
             for a, b in zip(graph_node.outputs, pattern_node.outputs):
                 self._bind(b, a)

@@ -494,8 +494,7 @@ def _match_values_forward(
             # 1. make assumptions and continue
             # 2. mark the node as incomplete matching, we could end up stuck anyway.
             raise NotImplementedError(
-                f"There are more than one option, this will be implemented later, "
-                f"ec={ec}, gc={gc}"
+                f"There are more than one option, this will be implemented later, ec={ec}, gc={gc}"
             )

     def _match_forward(
@@ -620,9 +619,9 @@ def match(
                     return result

                 nodes_not_in_pattern = set(matched.keys()) - all_pattern_nodes
-                assert (
-                    not nodes_not_in_pattern
-                ), f"Some nodes are not part of the pattern: {nodes_not_in_pattern}"
+                assert not nodes_not_in_pattern, (
+                    f"Some nodes are not part of the pattern: {nodes_not_in_pattern}"
+                )

                 result = self._match_forward(
                     node, matched, stack, next_graph_node, next_pattern_node
@@ -633,9 +632,9 @@ def match(
                     return result

                 nodes_not_in_pattern = set(matched.keys()) - all_pattern_nodes
-                assert (
-                    not nodes_not_in_pattern
-                ), f"Some nodes are not part of the pattern: {nodes_not_in_pattern}"
+                assert not nodes_not_in_pattern, (
+                    f"Some nodes are not part of the pattern: {nodes_not_in_pattern}"
+                )

                 if self.verbose > 5:
                     self._debug["iteration"] = iteration
diff --git a/onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py b/onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py
index 7fff108f6..b6c6f0a96 100644
--- a/onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py
+++ b/onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py
@@ -104,14 +104,14 @@ def infer_attn_size_config(self, function: ir.Function) -> AttnSizeConfig:
         # Reference:
         # https://github.com/huggingface/diffusers/blob/ae05050db9d37d5af48a6cd0d6510a5ffb1c1cd4/src/diffusers/models/attention_processor.py#L1269
         reshape_nodes = [node for node in function if node.op_type == "Reshape"]
-        assert (
-            len(reshape_nodes) == 4
-        ), "Expected 3 Reshape nodes for Q, K and V, and 1 reshape node for output of scaled_dot_product_attention."
+        assert len(reshape_nodes) == 4, (
+            "Expected 3 Reshape nodes for Q, K and V, and 1 reshape node for output of scaled_dot_product_attention."
+        )
         for reshape_node in reshape_nodes:
             constant_node = reshape_node.inputs[1].producer()
-            assert (
-                constant_node.op_type == "Constant"
-            ), "Expected the second input to Reshape to be a Constant node."
+            assert constant_node.op_type == "Constant", (
+                "Expected the second input to Reshape to be a Constant node."
+            )
             value = reshape_node.inputs[1]
             constant_value = _ir_utils.get_const_value(value)
             if constant_value is None:
diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py
index 333cb489d..84ac42beb 100644
--- a/onnxscript/rewriter/pattern.py
+++ b/onnxscript/rewriter/pattern.py
@@ -575,9 +575,9 @@ def matches(self, node: ir.Node, match: MatchResult) -> MatchResult:
     def clone(self, node_map: dict[NodePattern, NodePattern], swap: bool) -> NodePattern:
         inputs = [(v.clone(node_map) if v is not None else None) for v in self.inputs]
         if swap:
-            assert (
-                len(inputs) == 2
-            ), "Internal error: commutative swap applies only to binary ops."
+            assert len(inputs) == 2, (
+                "Internal error: commutative swap applies only to binary ops."
+            )
             inputs = [inputs[1], inputs[0]]
         outputs = [value.name for value in self.outputs]
         copied = NodePattern(
diff --git a/onnxscript/tools/benchmark/benchmark_helpers.py b/onnxscript/tools/benchmark/benchmark_helpers.py
index f9a46c8f5..b9101d5ec 100644
--- a/onnxscript/tools/benchmark/benchmark_helpers.py
+++ b/onnxscript/tools/benchmark/benchmark_helpers.py
@@ -224,16 +224,16 @@ def _flatten(outputs):
     rel_errs = []
     for torch_outputs_mixed_types, onnx_outputs in zip(expected, outputs):
         torch_outputs = _flatten(torch_outputs_mixed_types)
-        assert len(torch_outputs) == len(
-            onnx_outputs
-        ), f"Length mismatch {len(torch_outputs)} != {len(onnx_outputs)}"
+        assert len(torch_outputs) == len(onnx_outputs), (
+            f"Length mismatch {len(torch_outputs)} != {len(onnx_outputs)}"
+        )
         for torch_tensor, onnx_tensor in zip(torch_outputs, onnx_outputs):
-            assert (
-                torch_tensor.dtype == onnx_tensor.dtype
-            ), f"Type mismatch {torch_tensor.dtype} != {onnx_tensor.dtype}"
-            assert (
-                torch_tensor.shape == onnx_tensor.shape
-            ), f"Type mismatch {torch_tensor.shape} != {onnx_tensor.shape}"
+            assert torch_tensor.dtype == onnx_tensor.dtype, (
+                f"Type mismatch {torch_tensor.dtype} != {onnx_tensor.dtype}"
+            )
+            assert torch_tensor.shape == onnx_tensor.shape, (
+                f"Type mismatch {torch_tensor.shape} != {onnx_tensor.shape}"
+            )
             diff = torch_tensor - onnx_tensor
             abs_err = float(diff.abs().max())
             rel_err = float((diff.abs() / torch_tensor).max())
@@ -295,9 +295,9 @@ def common_export(
             dynamic_axes=dynamic_shapes,
         )
     elif exporter == "dynamo":
-        assert (
-            dynamic_shapes is None
-        ), f"dynamic_shapes={dynamic_shapes} is not implemented yet"
+        assert dynamic_shapes is None, (
+            f"dynamic_shapes={dynamic_shapes} is not implemented yet"
+        )
         with torch.no_grad():
             prog = torch.onnx.dynamo_export(model, *inputs)
         onnx.save(prog.model_proto, filename)
diff --git a/onnxscript/tools/benchmark/export_model_batch.py b/onnxscript/tools/benchmark/export_model_batch.py
index ffef9cbd4..8dff49e0c 100644
--- a/onnxscript/tools/benchmark/export_model_batch.py
+++ b/onnxscript/tools/benchmark/export_model_batch.py
@@ -73,7 +73,7 @@ def main(args: list[str] | None = None):

     if kwargs["verbose"]:
         for i, cf in enumerate(configs):
-            print(f"[export_common_batch] config {i+1}: {cf}")
+            print(f"[export_common_batch] config {i + 1}: {cf}")

     ################################
     # Running configuration.
diff --git a/opgen/onnx_opset_builder.py b/opgen/onnx_opset_builder.py
index 01c7f3bc2..fdf7f76bb 100644
--- a/opgen/onnx_opset_builder.py
+++ b/opgen/onnx_opset_builder.py
@@ -60,8 +60,7 @@ def __init__(self, domain: str, name: str, version: int):

     def __repr__(self) -> str:
         return (
-            f"QualOpName(domain={self.domain!r}, "
-            f"version={self.version!r}, name={self.name!r})"
+            f"QualOpName(domain={self.domain!r}, version={self.version!r}, name={self.name!r})"
         )

     def __str__(self) -> str:
diff --git a/pyproject.toml b/pyproject.toml
index 61128ac9e..ff873319f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -203,6 +203,7 @@ ignore = [
     "TRY003", # Messages can be constructed in the exception
     "UP006", # keep-runtime-typing
     "UP007", # keep-runtime-typing
+    "UP045", # TODO: Support new style type annotations
 ]
 ignore-init-module-imports = true
diff --git a/requirements/lintrunner/requirements.txt b/requirements/lintrunner/requirements.txt
index 3ea357152..d045e2036 100644
--- a/requirements/lintrunner/requirements.txt
+++ b/requirements/lintrunner/requirements.txt
@@ -1,7 +1,7 @@
 # This file is auto updated by dependabot
 lintrunner-adapters>=0.8.0
 # RUFF, RUFF-FIX
-ruff==0.8.6
+ruff==0.9.1
 # MYPY
 mypy==1.10.1
 types-PyYAML==6.0.12.20241230
diff --git a/tests/function_libs/torch_lib/error_reproduction.py b/tests/function_libs/torch_lib/error_reproduction.py
index 141946c56..1eac88c48 100644
--- a/tests/function_libs/torch_lib/error_reproduction.py
+++ b/tests/function_libs/torch_lib/error_reproduction.py
@@ -200,7 +200,7 @@ def create_reproduction_report(
     )

     # Turn test name into a valid file name
-    markdown_file_name = f'{short_test_name.replace("/", "-").replace(":", "-")}-{str(time.time()).replace(".", "_")}.md'
+    markdown_file_name = f"{short_test_name.replace('/', '-').replace(':', '-')}-{str(time.time()).replace('.', '_')}.md"
     markdown_file_path = save_error_report(markdown_file_name, markdown)
     print(f"Created reproduction report at {markdown_file_path}")

@@ -247,7 +247,7 @@ def create_mismatch_report(
         error_stack=error_stack,
     )

-    markdown_file_name = f'mismatch-{short_test_name.replace("/", "-").replace(":", "-")}-{str(time.time()).replace(".", "_")}.md'
+    markdown_file_name = f"mismatch-{short_test_name.replace('/', '-').replace(':', '-')}-{str(time.time()).replace('.', '_')}.md"
     markdown_file_path = save_error_report(markdown_file_name, markdown)
     print(f"Created reproduction report at {markdown_file_path}")

diff --git a/tests/function_libs/torch_lib/ops_test_common.py b/tests/function_libs/torch_lib/ops_test_common.py
index 3a9717cc3..e440a5b14 100644
--- a/tests/function_libs/torch_lib/ops_test_common.py
+++ b/tests/function_libs/torch_lib/ops_test_common.py
@@ -177,9 +177,9 @@ def add_decorate_info(
            # If the OpInfo doesn't exist and it is not enabled, we skip the OpInfo
            # because it could be an OpInfo that is in torch-nightly but not older versions.
            continue
-        assert (
-            opinfo is not None
-        ), f"Couldn't find OpInfo for {decorate_meta}. Did you need to specify variant_name?"
+        assert opinfo is not None, (
+            f"Couldn't find OpInfo for {decorate_meta}. Did you need to specify variant_name?"
+        )
         decorators = list(opinfo.decorators)
         new_decorator = opinfo_core.DecorateInfo(
             decorate_meta.decorator,
@@ -370,12 +370,7 @@ def _safe_ort_session_run(serialized_model: bytes, ort_inputs: Mapping[str, Any]


 def _format_model_and_input_information(onnx_model, inputs):
-    return (
-        f"Inputs:\n"
-        f"{pprint.pformat(inputs)}\n"
-        f"Model:\n"
-        f"{onnx.printer.to_text(onnx_model)}"
-    )
+    return f"Inputs:\n{pprint.pformat(inputs)}\nModel:\n{onnx.printer.to_text(onnx_model)}"


 TORCH_DTYPE_TO_ONNX_STRING = {
diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py
index bebd9a8ab..8422ab730 100644
--- a/tests/function_libs/torch_lib/ops_test_data.py
+++ b/tests/function_libs/torch_lib/ops_test_data.py
@@ -2370,6 +2370,6 @@ def _where_input_wrangler(
 ALL_OPS_IN_DB = frozenset(op_info.name for op_info in OPS_DB)
 # Assert all ops in OPINFO_FUNCTION_MAPPING are in the OPS_DB
 assert TESTED_OPS.issubset(ALL_OPS_IN_DB), f"{TESTED_OPS - ALL_OPS_IN_DB} not in OPS_DB"
-assert NONDETERMINISTIC_OPS.issubset(
-    TESTED_OPS
-), f"{NONDETERMINISTIC_OPS - TESTED_OPS} not in TESTED_OPS"
+assert NONDETERMINISTIC_OPS.issubset(TESTED_OPS), (
+    f"{NONDETERMINISTIC_OPS - TESTED_OPS} not in TESTED_OPS"
+)
diff --git a/tools/diagnostics/gen_diagnostics.py b/tools/diagnostics/gen_diagnostics.py
index d54449df4..cf0f0f35b 100644
--- a/tools/diagnostics/gen_diagnostics.py
+++ b/tools/diagnostics/gen_diagnostics.py
@@ -101,9 +101,9 @@ def _format_rule_for_python_class(rule: _RuleType) -> str:
         if field_name is not None
     ]
     for field_name in field_names:
-        assert isinstance(
-            field_name, str
-        ), f"Unexpected field type {type(field_name)} from {field_name}. "
+        assert isinstance(field_name, str), (
+            f"Unexpected field type {type(field_name)} from {field_name}. "
+        )
         "Field name must be string.\nFull message template: {message_template}"  # pylint: disable=pointless-string-statement
         assert not field_name.isnumeric(), f"Unexpected numeric field name {field_name}. "
         "Only keyword name formatting is supported.\nFull message template: {message_template}"  # pylint: disable=pointless-string-statement
diff --git a/tools/function_rewriter_testing/function_unittest_producer.py b/tools/function_rewriter_testing/function_unittest_producer.py
index b2d484531..d8c51c694 100644
--- a/tools/function_rewriter_testing/function_unittest_producer.py
+++ b/tools/function_rewriter_testing/function_unittest_producer.py
@@ -336,9 +336,9 @@ def visit_model(self, model: onnx.ModelProto):
                     tmp_model_path, providers=["CUDAExecutionProvider"]
                 )
                 outputs = sess.run(fetch_outputs, inputs)
-                assert (
-                    len(outputs) == len(fetch_outputs)
-                ), f"Number of outputs mismatch. outputs: {len(outputs)}, fetch_outputs: {len(fetch_outputs)}"
+                assert len(outputs) == len(fetch_outputs), (
+                    f"Number of outputs mismatch. outputs: {len(outputs)}, fetch_outputs: {len(fetch_outputs)}"
+                )
                 self._named_values = dict(zip(fetch_outputs, outputs))  # type: ignore[arg-type]

         for inputs, outputs in target_function_meta.values():