Skip to content

Commit

Permalink
Format
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby committed Jan 14, 2025
1 parent 57783f6 commit 4cac988
Show file tree
Hide file tree
Showing 23 changed files with 123 additions and 131 deletions.
3 changes: 1 addition & 2 deletions onnxscript/_internal/ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ def get_src_and_ast(func: Callable, /) -> tuple[str, ast.FunctionDef]:
src = inspect.getsource(func)
except OSError as e:
raise RuntimeError(
f"Decorator script does not work on dynamically "
f"compiled function {func.__name__}."
f"Decorator script does not work on dynamically compiled function {func.__name__}."
) from e
src = textwrap.dedent(src)
top_level_ast = ast.parse(src)
Expand Down
6 changes: 3 additions & 3 deletions onnxscript/backend/onnx_export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,9 @@ def extract_functions(name: str, content: str, test_folder: pathlib.Path):
filename = str(test_folder / f"{name}.py")
with open(filename, "w", encoding="utf-8") as f:
f.write(content + "\n")
assert os.path.exists(
filename
), f"{filename!r} ({os.path.abspath(filename)!r} does not exist."
assert os.path.exists(filename), (
f"{filename!r} ({os.path.abspath(filename)!r} does not exist."
)
import_name = f"tests.{test_folder.parts[-1]}.{name}"
try:
mod = importlib.import_module(import_name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -475,9 +475,9 @@ def register_outputs(
if isinstance(outputs, TorchScriptTensor):
outputs = (outputs,)
for output in outputs:
assert isinstance(
output, TorchScriptTensor
), f"output must be a TorchScriptTensor, not {type(output)}"
assert isinstance(output, TorchScriptTensor), (
f"output must be a TorchScriptTensor, not {type(output)}"
)
self._graph.outputs.append(output)

def _add_constant_to_graph(self, constant) -> Sequence[ir.Value | None]:
Expand Down Expand Up @@ -556,9 +556,9 @@ def _add_ir_graph_op_call(
# TODO(justinchuby): What is this case?
graph_inputs.append(input)
for key, value in onnx_attributes.items():
assert not isinstance(
value, TorchScriptTensor
), f"ONNX attribute must not be a TorchScriptTensor, got {key}: {value}."
assert not isinstance(value, TorchScriptTensor), (
f"ONNX attribute must not be a TorchScriptTensor, got {key}: {value}."
)
tensors = _create_op_call_in_graph(
self._graph,
domain,
Expand Down Expand Up @@ -586,9 +586,9 @@ def _fetch_function_dict(
domain = sub_torch_script_graph.domain_name
assert domain is not None
name_domain = (sub_graph_name, domain, "")
assert (
name_domain not in function_dict
), f"Sub graph name already exists. {name_domain}"
assert name_domain not in function_dict, (
f"Sub graph name already exists. {name_domain}"
)
function_dict[name_domain] = sub_torch_script_graph._to_function( # pylint: disable=protected-access
opset_version, sub_graph_name
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -689,9 +689,9 @@ def register_outputs(
return
assert isinstance(unwrapped_outputs, Sequence)
for ts_output in unwrapped_outputs:
assert isinstance(
ts_output, torch.Value
), f"ts_output must be a torch.Value, not {type(ts_output)}"
assert isinstance(ts_output, torch.Value), (
f"ts_output must be a torch.Value, not {type(ts_output)}"
)
self._torch_graph.registerOutput(ts_output)
return

Expand Down Expand Up @@ -772,9 +772,9 @@ def _add_torchscript_op_call(
) -> Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]]:
graph_inputs = self.preprocess_inputs(onnx_inputs)
for key, value in onnx_attributes.items():
assert not isinstance(
value, TorchScriptTensor
), f"ONNX attribute must not be a TorchScriptTensor, got {key}: {value}."
assert not isinstance(value, TorchScriptTensor), (
f"ONNX attribute must not be a TorchScriptTensor, got {key}: {value}."
)
result = _create_op_call_in_torch_graph(
self._torch_graph,
name,
Expand Down Expand Up @@ -816,9 +816,9 @@ def fetch_function_proto_dict(
sub_graph_name,
domain,
)
assert (
name_domain not in function_proto_dict
), f"Sub graph name already exists. {name_domain}"
assert name_domain not in function_proto_dict, (
f"Sub graph name already exists. {name_domain}"
)
function_proto_dict[name_domain] = sub_torch_script_graph.to_function_proto(
opset_version, sub_graph_name
)
Expand Down
12 changes: 6 additions & 6 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3048,9 +3048,9 @@ def aten_embedding_bag_padding_idx(
We add default values for the attributes to accommodate _embedding_bag as well:
_embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1)
"""
assert (
padding_idx is not None
), "padding_idx must not be None. This is likely a dispatcher error"
assert padding_idx is not None, (
"padding_idx must not be None. This is likely a dispatcher error"
)

if per_sample_weights is None:
per_sample_weights = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(indices))
Expand Down Expand Up @@ -4417,9 +4417,9 @@ def aten_instance_norm(
if use_input_stats:
return op.InstanceNormalization(input, weight, bias, epsilon=eps)

assert (
running_mean is not None and running_var is not None
), "running_mean and running_var must be provided when use_input_stats is False"
assert running_mean is not None and running_var is not None, (
"running_mean and running_var must be provided when use_input_stats is False"
)

batch_size = op.Shape(input, start=0, end=1)
bn_input = op.Reshape(
Expand Down
24 changes: 12 additions & 12 deletions onnxscript/function_libs/torch_lib/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1801,13 +1801,13 @@ def aten_scaled_dot_product_attention(
L is the target sequence length, S is the source sequence length, and E is the embedding size.
"""
# Use trace_only to handle optional inputs
assert (not is_causal) or (
is_causal and attn_mask is None
), "is_causal and attn_mask cannot be set at the same time"
assert (not is_causal) or (is_causal and attn_mask is None), (
"is_causal and attn_mask cannot be set at the same time"
)

assert (
not enable_gqa
), "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
assert not enable_gqa, (
"conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
)

# Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
if scale is None:
Expand Down Expand Up @@ -2018,13 +2018,13 @@ def aten_scaled_dot_product_attention_bool_mask(
L is the target sequence length, S is the source sequence length, and E is the embedding size.
"""
# Use trace_only to handle optional inputs
assert (not is_causal) or (
is_causal and attn_mask is None
), "is_causal and attn_mask cannot be set at the same time"
assert (not is_causal) or (is_causal and attn_mask is None), (
"is_causal and attn_mask cannot be set at the same time"
)

assert (
not enable_gqa
), "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
assert not enable_gqa, (
"conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
)

if scale is None:
scale = _attention_scale(query)
Expand Down
36 changes: 18 additions & 18 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,9 +388,9 @@ def __init__(
def __array__(self, dtype: Any = None) -> np.ndarray:
if isinstance(self._raw, np.ndarray) or _compatible_with_numpy(self._raw):
return self._raw.__array__(dtype)
assert _compatible_with_dlpack(
self._raw
), f"Bug: Expected DLPack or Numpy compatible objects, got {type(self._raw)}"
assert _compatible_with_dlpack(self._raw), (
f"Bug: Expected DLPack or Numpy compatible objects, got {type(self._raw)}"
)
return np.from_dlpack(self._raw)

def __dlpack__(self, *, stream: Any = None) -> Any:
Expand Down Expand Up @@ -765,9 +765,9 @@ def __init__(
def __array__(self, dtype: Any = None) -> np.ndarray:
if isinstance(self._raw, np.ndarray):
return self._raw
assert isinstance(
self._raw, Sequence
), f"Bug: Expected a sequence, got {type(self._raw)}"
assert isinstance(self._raw, Sequence), (
f"Bug: Expected a sequence, got {type(self._raw)}"
)
return np.array(self._raw, dtype=dtype).reshape(self.shape.numpy())

def __dlpack__(self, *, stream: Any = None) -> Any:
Expand Down Expand Up @@ -2228,11 +2228,11 @@ def _graph_str(graph: Graph | GraphView) -> str:
)
signature = f"""\
graph(
name={graph.name or 'anonymous_graph:' + str(id(graph))},
inputs=({textwrap.indent(inputs_text, ' ' * 8)}
name={graph.name or "anonymous_graph:" + str(id(graph))},
inputs=({textwrap.indent(inputs_text, " " * 8)}
),
outputs=({textwrap.indent(outputs_text, ' ' * 8)}
),{textwrap.indent(initializers_text, ' ' * 4)}
outputs=({textwrap.indent(outputs_text, " " * 8)}
),{textwrap.indent(initializers_text, " " * 4)}
)"""
node_count = len(graph)
number_width = len(str(node_count))
Expand Down Expand Up @@ -2266,11 +2266,11 @@ def _graph_repr(graph: Graph | GraphView) -> str:
)
return f"""\
{graph.__class__.__name__}(
name={graph.name or 'anonymous_graph:' + str(id(graph))!r},
inputs=({textwrap.indent(inputs_text, ' ' * 8)}
name={graph.name or "anonymous_graph:" + str(id(graph))!r},
inputs=({textwrap.indent(inputs_text, " " * 8)}
),
outputs=({textwrap.indent(outputs_text, ' ' * 8)}
),{textwrap.indent(initializers_text, ' ' * 4)}
outputs=({textwrap.indent(outputs_text, " " * 8)}
),{textwrap.indent(initializers_text, " " * 4)}
len()={len(graph)}
)"""

Expand Down Expand Up @@ -2484,7 +2484,7 @@ def __repr__(self) -> str:
domain={self.domain!r},
model_version={self.model_version!r},
functions={self.functions!r},
graph={textwrap.indent(repr(self.graph), ' ' * 4).strip()}
graph={textwrap.indent(repr(self.graph), " " * 4).strip()}
)"""


Expand Down Expand Up @@ -2684,10 +2684,10 @@ def __str__(self) -> str:
>
def {full_name}(
inputs=(
{textwrap.indent(inputs_text, ' ' * 8)}
),{textwrap.indent(attributes_text, ' ' * 4)}
{textwrap.indent(inputs_text, " " * 8)}
),{textwrap.indent(attributes_text, " " * 4)}
outputs=(
{textwrap.indent(outputs_text, ' ' * 8)}
{textwrap.indent(outputs_text, " " * 8)}
),
)"""
node_count = len(self)
Expand Down
6 changes: 3 additions & 3 deletions onnxscript/ir/_linked_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,9 @@ def __reversed__(self) -> Iterator[T]:
box = box.prev

def __len__(self) -> int:
assert self._length == len(
self._value_ids_to_boxes
), "Bug in the implementation: length mismatch"
assert self._length == len(self._value_ids_to_boxes), (
"Bug in the implementation: length mismatch"
)
return self._length

def __getitem__(self, index: int) -> T:
Expand Down
6 changes: 3 additions & 3 deletions onnxscript/ir/_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,9 +301,9 @@ def _get_allowed_types_from_type_annotation(
allowed_types = set()
subtypes = typing.get_args(type_)
for subtype in subtypes:
assert subtype is not type(
None
), "Union should not contain None type because it is handled by _is_optional."
assert subtype is not type(None), (
"Union should not contain None type because it is handled by _is_optional."
)
allowed_types.update(_get_allowed_types_from_type_annotation(subtype))
return allowed_types

Expand Down
3 changes: 1 addition & 2 deletions onnxscript/ir/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,7 @@ def numpy(self) -> np.ndarray:
raise ValueError("Cannot convert UNDEFINED tensor to numpy array.")
if self._proto.data_location == onnx.TensorProto.EXTERNAL:
raise ValueError(
"Cannot convert external tensor to numpy array. "
"Use ir.ExternalTensor instead."
"Cannot convert external tensor to numpy array. Use ir.ExternalTensor instead."
)

if self._proto.HasField("raw_data"):
Expand Down
6 changes: 3 additions & 3 deletions onnxscript/rewriter/broadcast_to_matmul_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,12 @@ def test_reshape_matmul_reshape_does_not_replace_when_output_sizes_do_not_match(
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float{input_x_shape} input_x, float{input_y_shape} input_y) => (float{output_shape} output)
{{
shape_a = Constant<value: tensor = int64[{len(shape_a)}] {{ {', '.join(str(i) for i in shape_a)} }}>()
shape_a = Constant<value: tensor = int64[{len(shape_a)}] {{ {", ".join(str(i) for i in shape_a)} }}>()
reshape_x = Reshape (input_x, shape_a)
shape_b = Constant<value: tensor = int64[{len(shape_b)}] {{ {', '.join(str(i) for i in shape_b)} }}>()
shape_b = Constant<value: tensor = int64[{len(shape_b)}] {{ {", ".join(str(i) for i in shape_b)} }}>()
reshape_y = Reshape (input_y, shape_b)
matmul = MatMul (reshape_x, reshape_y)
shape_c = Constant<value: tensor = int64[{len(shape_c)}] {{ {', '.join(str(i) for i in shape_c)} }}>()
shape_c = Constant<value: tensor = int64[{len(shape_c)}] {{ {", ".join(str(i) for i in shape_c)} }}>()
output = Reshape (matmul, shape_c)
}}
"""
Expand Down
33 changes: 16 additions & 17 deletions onnxscript/rewriter/generic_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,21 @@ def __init__(
self.matched_pattern_to_model_value: dict[orp.ValuePattern, ir.Value] = {}

for graph_node, pattern_node in zip(model_nodes, pattern_nodes):
assert (
graph_node.op_identifier() == pattern_node.op_identifier()
), f"Unexpected type mismatch {graph_node.op_identifier()!r} != {pattern_node.op_identifier()!r}"
assert len(graph_node.inputs) == len(
pattern_node.inputs
), f"Unexpected number of inputs for type {graph_node.op_identifier()}"
assert graph_node.op_identifier() == pattern_node.op_identifier(), (
f"Unexpected type mismatch {graph_node.op_identifier()!r} != {pattern_node.op_identifier()!r}"
)
assert len(graph_node.inputs) == len(pattern_node.inputs), (
f"Unexpected number of inputs for type {graph_node.op_identifier()}"
)
for a, b in zip(graph_node.inputs, pattern_node.inputs):
if b is None:
# optional input or not an interesting input
continue
self._bind(b, a)

assert len(graph_node.outputs) == len(
pattern_node.outputs
), f"Unexpected number of outputs for type {graph_node.op_identifier()}"
assert len(graph_node.outputs) == len(pattern_node.outputs), (
f"Unexpected number of outputs for type {graph_node.op_identifier()}"
)
for a, b in zip(graph_node.outputs, pattern_node.outputs):
self._bind(b, a)

Expand Down Expand Up @@ -494,8 +494,7 @@ def _match_values_forward(
# 1. make assumptions and continue
# 2. mark the node as incomplete matching, we could end up stuck anyway.
raise NotImplementedError(
f"There are more than one option, this will be implemented later, "
f"ec={ec}, gc={gc}"
f"There are more than one option, this will be implemented later, ec={ec}, gc={gc}"
)

def _match_forward(
Expand Down Expand Up @@ -620,9 +619,9 @@ def match(
return result

nodes_not_in_pattern = set(matched.keys()) - all_pattern_nodes
assert (
not nodes_not_in_pattern
), f"Some nodes are not part of the pattern: {nodes_not_in_pattern}"
assert not nodes_not_in_pattern, (
f"Some nodes are not part of the pattern: {nodes_not_in_pattern}"
)

result = self._match_forward(
node, matched, stack, next_graph_node, next_pattern_node
Expand All @@ -633,9 +632,9 @@ def match(
return result

nodes_not_in_pattern = set(matched.keys()) - all_pattern_nodes
assert (
not nodes_not_in_pattern
), f"Some nodes are not part of the pattern: {nodes_not_in_pattern}"
assert not nodes_not_in_pattern, (
f"Some nodes are not part of the pattern: {nodes_not_in_pattern}"
)

if self.verbose > 5:
self._debug["iteration"] = iteration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,14 @@ def infer_attn_size_config(self, function: ir.Function) -> AttnSizeConfig:
# Reference:
# https://github.com/huggingface/diffusers/blob/ae05050db9d37d5af48a6cd0d6510a5ffb1c1cd4/src/diffusers/models/attention_processor.py#L1269
reshape_nodes = [node for node in function if node.op_type == "Reshape"]
assert (
len(reshape_nodes) == 4
), "Expected 3 Reshape nodes for Q, K and V, and 1 reshape node for output of scaled_dot_product_attention."
assert len(reshape_nodes) == 4, (
"Expected 3 Reshape nodes for Q, K and V, and 1 reshape node for output of scaled_dot_product_attention."
)
for reshape_node in reshape_nodes:
constant_node = reshape_node.inputs[1].producer()
assert (
constant_node.op_type == "Constant"
), "Expected the second input to Reshape to be a Constant node."
assert constant_node.op_type == "Constant", (
"Expected the second input to Reshape to be a Constant node."
)
value = reshape_node.inputs[1]
constant_value = _ir_utils.get_const_value(value)
if constant_value is None:
Expand Down
Loading

0 comments on commit 4cac988

Please sign in to comment.