Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix shape inference for Loop@v19 and v21 #164

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Change log
**Bug fix**

- Value propagation of string tensors no longer raises an erroneous ``ValueError`` in some instances.
- Apply custom shape inference logic in :func:`spox.opsets.ai.onnx.v19.loop` and :func:`spox.opsets.ai.onnx.v21.loop`.


0.12.1 (2024-06-18)
Expand Down
14 changes: 14 additions & 0 deletions src/spox/opset/ai/onnx/v19.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,20 @@ class Inputs(BaseInputs):
class Outputs(BaseOutputs):
v_final_and_scan_outputs: Sequence[Var]

def infer_output_types(self) -> dict[str, Type]:
output_types = super().infer_output_types()

body = self.attrs.body.value
n = len(body.requested_arguments) - 2

carried_names = list(self.outputs.get_vars())[:n]
carried_types = [v.type for v in list(body.requested_results.values())[1:][:n]]

for name, typ in zip(carried_names, carried_types):
output_types[name] = typ

return output_types

op_type = OpType("Loop", "", 19)

attrs: Attributes
Expand Down
14 changes: 14 additions & 0 deletions src/spox/opset/ai/onnx/v21.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,20 @@ class Inputs(BaseInputs):
class Outputs(BaseOutputs):
v_final_and_scan_outputs: Sequence[Var]

def infer_output_types(self) -> dict[str, Type]:
output_types = super().infer_output_types()

body = self.attrs.body.value
n = len(body.requested_arguments) - 2

carried_names = list(self.outputs.get_vars())[:n]
carried_types = [v.type for v in list(body.requested_results.values())[1:][:n]]

for name, typ in zip(carried_names, carried_types):
output_types[name] = typ

Copy link
Member

@adityagoel4512 adityagoel4512 Jul 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you not want to check the output type of the body is compatible with the initial value as per #163 (comment)?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right now this PR just continues with the existing behavior of earlier opsets, which is arguably not good enough. @MatejUrbanQC How did you end up solving your issue? Would we break your solution if we were to impose stricter checks here?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not using loops in the end, so that's fine.

But I think this only make sense if you can explicitly mark dimensions that you want to change in the loop as None.

Let's say I want to use a function lambda x: concat([x, [0]], axis=0) in the loop.
If my input has shape (None,), then the shape inference passes, because the output is also (None,).
However, if my input has known shape, it will fail.

So it's not just very limiting, it's also inconsistent.

These issues should go away if you can say "This dimensions will be changed, treat it as None". If the number of dimensions will also change, you set the whole shape to None.

return output_types

op_type = OpType("Loop", "", 21)

attrs: Attributes
Expand Down
9 changes: 7 additions & 2 deletions tests/type_inference/test_loop.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
# Copyright (c) QuantCo 2023-2024
# SPDX-License-Identifier: BSD-3-Clause

import spox.opset.ai.onnx.v17 as op
import pytest

import spox.opset.ai.onnx.v17 as op17
import spox.opset.ai.onnx.v19 as op19
import spox.opset.ai.onnx.v21 as op21
from spox import Tensor, argument


def test_loop_inference():
@pytest.mark.parametrize("op", [op17, op19, op21])
def test_loop_inference(op):
x, y, zs = op.loop(
v_initial=[argument(Tensor(float, (None,))), argument(Tensor(int, ("N", 2)))],
body=lambda i, c, a, b: [op.const(True), a, op.add(i, b), i],
Expand Down
3 changes: 2 additions & 1 deletion tools/generate_opset.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,7 @@ def main(
"ai.onnx",
19,
extras=["const"],
type_inference={"Compress": "compress11"},
type_inference={"Loop": "loop16-fix"},
value_propagation={"Constant": "constant13"},
out_variadic_solutions=V18_OUT_VARIADIC_SOLUTIONS,
subgraphs_solutions=V16_SUBGRAPH_SOLUTIONS,
Expand All @@ -711,6 +711,7 @@ def main(
"ai.onnx",
21,
extras=["const"],
type_inference={"Loop": "loop16-fix"},
value_propagation={"Constant": "constant13"},
out_variadic_solutions=V18_OUT_VARIADIC_SOLUTIONS,
subgraphs_solutions=V16_SUBGRAPH_SOLUTIONS,
Expand Down