diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 2f2b3be4..663bf0fb 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -17,6 +17,7 @@ Change log **Bug fix** - Value propagation of string tensors no longer raises an erroneous ``ValueError`` in some instances. +- Apply custom shape inference logic in :func:`spox.opsets.ai.onnx.v19.loop` and :func:`spox.opsets.ai.onnx.v21.loop`. 0.12.1 (2024-06-18) diff --git a/src/spox/opset/ai/onnx/v19.py b/src/spox/opset/ai/onnx/v19.py index 6c14823d..4be48159 100644 --- a/src/spox/opset/ai/onnx/v19.py +++ b/src/spox/opset/ai/onnx/v19.py @@ -617,6 +617,20 @@ class Inputs(BaseInputs): class Outputs(BaseOutputs): v_final_and_scan_outputs: Sequence[Var] + def infer_output_types(self) -> dict[str, Type]: + output_types = super().infer_output_types() + + body = self.attrs.body.value + n = len(body.requested_arguments) - 2 + + carried_names = list(self.outputs.get_vars())[:n] + carried_types = [v.type for v in list(body.requested_results.values())[1:][:n]] + + for name, typ in zip(carried_names, carried_types): + output_types[name] = typ + + return output_types + op_type = OpType("Loop", "", 19) attrs: Attributes diff --git a/src/spox/opset/ai/onnx/v21.py b/src/spox/opset/ai/onnx/v21.py index f4f027cc..0c93face 100644 --- a/src/spox/opset/ai/onnx/v21.py +++ b/src/spox/opset/ai/onnx/v21.py @@ -612,6 +612,20 @@ class Inputs(BaseInputs): class Outputs(BaseOutputs): v_final_and_scan_outputs: Sequence[Var] + def infer_output_types(self) -> dict[str, Type]: + output_types = super().infer_output_types() + + body = self.attrs.body.value + n = len(body.requested_arguments) - 2 + + carried_names = list(self.outputs.get_vars())[:n] + carried_types = [v.type for v in list(body.requested_results.values())[1:][:n]] + + for name, typ in zip(carried_names, carried_types): + output_types[name] = typ + + return output_types + op_type = OpType("Loop", "", 21) attrs: Attributes diff --git a/tests/type_inference/test_loop.py b/tests/type_inference/test_loop.py index c193d519..e99c92e7 100644 --- a/tests/type_inference/test_loop.py +++ b/tests/type_inference/test_loop.py @@ -1,11 +1,16 @@ # Copyright (c) QuantCo 2023-2024 # SPDX-License-Identifier: BSD-3-Clause -import spox.opset.ai.onnx.v17 as op +import pytest + +import spox.opset.ai.onnx.v17 as op17 +import spox.opset.ai.onnx.v19 as op19 +import spox.opset.ai.onnx.v21 as op21 from spox import Tensor, argument -def test_loop_inference(): +@pytest.mark.parametrize("op", [op17, op19, op21]) +def test_loop_inference(op): x, y, zs = op.loop( v_initial=[argument(Tensor(float, (None,))), argument(Tensor(int, ("N", 2)))], body=lambda i, c, a, b: [op.const(True), a, op.add(i, b), i], diff --git a/tools/generate_opset.py b/tools/generate_opset.py index 399a311f..efc9d09b 100644 --- a/tools/generate_opset.py +++ b/tools/generate_opset.py @@ -688,7 +688,7 @@ def main( "ai.onnx", 19, extras=["const"], - type_inference={"Compress": "compress11"}, + type_inference={"Loop": "loop16-fix"}, value_propagation={"Constant": "constant13"}, out_variadic_solutions=V18_OUT_VARIADIC_SOLUTIONS, subgraphs_solutions=V16_SUBGRAPH_SOLUTIONS, @@ -711,6 +711,7 @@ def main( "ai.onnx", 21, extras=["const"], + type_inference={"Loop": "loop16-fix"}, value_propagation={"Constant": "constant13"}, out_variadic_solutions=V18_OUT_VARIADIC_SOLUTIONS, subgraphs_solutions=V16_SUBGRAPH_SOLUTIONS,