diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 6f6d6185..63b0595c 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -10,9 +10,15 @@ Change log 0.14.0 (unreleased) ------------------- +**Bug fix** + +- Adds missing shape inference logic for :func:`spox.opsets.ai.v19.loop` and :func:`spox.opsets.ai.v21.loop`. + **Other changes** - Propagated values may now be garbage collected if their associated `Var` object goes out of scope. +- :func:`spox.opsets.ai.v17.loop`, :func:`spox.opsets.ai.v19.loop` and :func:`spox.opsets.ai.v21.loop` will only infer shapes for loop-carried dependencies if their shapes are unchanged across iterations. + 0.13.0 (2024-12-06) ------------------- diff --git a/src/spox/_type_inference_utils.py b/src/spox/_type_inference_utils.py new file mode 100644 index 00000000..01dc7724 --- /dev/null +++ b/src/spox/_type_inference_utils.py @@ -0,0 +1,26 @@ +# Copyright (c) QuantCo 2023-2025 +# SPDX-License-Identifier: BSD-3-Clause + +from spox._standard import InferenceError +from spox._type_system import Optional, Sequence, Tensor, Type + + +def loop_erase_shape_info(typ: Type) -> Type: +    """Erases the shape information for a type that can exist as a state variable in a Loop""" +    if isinstance(typ, Tensor): +        return Tensor(typ.dtype, None) +    elif isinstance(typ, Sequence): +        if not isinstance(typ.elem_type, Tensor): +            raise InferenceError( +                f"Type {typ} not allowed for state variables in Loop operator, sequence element can only be a tensor" +            ) +        return Sequence(loop_erase_shape_info(typ.elem_type)) +    elif isinstance(typ, Optional): +        if isinstance(typ.elem_type, Optional): +            raise InferenceError( +                f"Type {typ} not allowed for state variables in Loop operator, optionals of optionals are not allowed" +            ) +        return Optional(loop_erase_shape_info(typ.elem_type)) +    raise InferenceError( +        f"Type {typ} not allowed for state variables in Loop operator."
+ ) diff --git a/src/spox/_type_system.py b/src/spox/_type_system.py index 0d1887e6..ffa1dca3 100644 --- a/src/spox/_type_system.py +++ b/src/spox/_type_system.py @@ -1,4 +1,4 @@ -# Copyright (c) QuantCo 2023-2024 +# Copyright (c) QuantCo 2023-2025 # SPDX-License-Identifier: BSD-3-Clause from __future__ import annotations diff --git a/src/spox/opset/ai/onnx/v17.py b/src/spox/opset/ai/onnx/v17.py index e02a8fa8..b0ab6a8d 100644 --- a/src/spox/opset/ai/onnx/v17.py +++ b/src/spox/opset/ai/onnx/v17.py @@ -29,6 +29,7 @@ from spox._graph import Graph, subgraph from spox._node import OpType from spox._standard import InferenceError, StandardNode +from spox._type_inference_utils import loop_erase_shape_info from spox._type_system import Sequence as SpoxSequence from spox._type_system import Tensor, Type from spox._value_prop import PropDict, PropValueType @@ -1780,16 +1781,26 @@ class Outputs(BaseOutputs): v_final_and_scan_outputs: Sequence[_VarInfo] def infer_output_types(self, input_prop_values: PropDict) -> dict[str, Type]: - output_types = super().infer_output_types({}) + output_types = super().infer_output_types(input_prop_values) + output_names = list(self.outputs.get_var_infos()) body = self.attrs.body.value - n = len(body.requested_arguments) - 2 - carried_names = list(self.outputs.get_var_infos())[:n] - carried_types = [v.type for v in list(body.requested_results.values())[1:][:n]] + # We skip the iteration_num and condition as they are correctly inferred + initial_types = [v.type for v in list(body.requested_arguments)[2:]] + # We skip the returned condition as it is correctly inferred + carried_types = [v.type for v in list(body.requested_results.values())[1:]] - for name, typ in zip(carried_names, carried_types): - output_types[name] = typ + shape_unchanged_between_iterations = all( + i_typ == c_typ for i_typ, c_typ in zip(initial_types, carried_types) + ) + + for name, _, c_typ in zip(output_names, initial_types, carried_types): + output_types[name] = ( + 
c_typ + if shape_unchanged_between_iterations + else loop_erase_shape_info(c_typ) + ) return output_types diff --git a/src/spox/opset/ai/onnx/v19.py b/src/spox/opset/ai/onnx/v19.py index bbc6d9b9..386dc642 100644 --- a/src/spox/opset/ai/onnx/v19.py +++ b/src/spox/opset/ai/onnx/v19.py @@ -28,6 +28,7 @@ from spox._graph import Graph, subgraph from spox._node import OpType from spox._standard import StandardNode +from spox._type_inference_utils import loop_erase_shape_info from spox._type_system import Tensor, Type from spox._value_prop import PropDict, PropValueType from spox._var import Var, _VarInfo, create_prop_dict, unwrap_vars @@ -617,6 +618,30 @@ class Inputs(BaseInputs): class Outputs(BaseOutputs): v_final_and_scan_outputs: Sequence[_VarInfo] + def infer_output_types(self, input_prop_values: PropDict) -> dict[str, Type]: + output_types = super().infer_output_types(input_prop_values) + output_names = list(self.outputs.get_var_infos()) + + body = self.attrs.body.value + + # We skip the iteration_num and condition as they are correctly inferred + initial_types = [v.type for v in list(body.requested_arguments)[2:]] + # We skip the returned condition as it is correctly inferred + carried_types = [v.type for v in list(body.requested_results.values())[1:]] + + shape_unchanged_between_iterations = all( + i_typ == c_typ for i_typ, c_typ in zip(initial_types, carried_types) + ) + + for name, _, c_typ in zip(output_names, initial_types, carried_types): + output_types[name] = ( + c_typ + if shape_unchanged_between_iterations + else loop_erase_shape_info(c_typ) + ) + + return output_types + op_type = OpType("Loop", "", 19) attrs: Attributes diff --git a/src/spox/opset/ai/onnx/v21.py b/src/spox/opset/ai/onnx/v21.py index cc5cb1dd..359dfd7a 100644 --- a/src/spox/opset/ai/onnx/v21.py +++ b/src/spox/opset/ai/onnx/v21.py @@ -28,6 +28,7 @@ from spox._graph import Graph, subgraph from spox._node import OpType from spox._standard import StandardNode +from 
spox._type_inference_utils import loop_erase_shape_info from spox._type_system import Tensor, Type from spox._value_prop import PropDict, PropValueType from spox._var import Var, _VarInfo, create_prop_dict, unwrap_vars @@ -612,6 +613,30 @@ class Inputs(BaseInputs): class Outputs(BaseOutputs): v_final_and_scan_outputs: Sequence[_VarInfo] + def infer_output_types(self, input_prop_values: PropDict) -> dict[str, Type]: + output_types = super().infer_output_types(input_prop_values) + output_names = list(self.outputs.get_var_infos()) + + body = self.attrs.body.value + + # We skip the iteration_num and condition as they are correctly inferred + initial_types = [v.type for v in list(body.requested_arguments)[2:]] + # We skip the returned condition as it is correctly inferred + carried_types = [v.type for v in list(body.requested_results.values())[1:]] + + shape_unchanged_between_iterations = all( + i_typ == c_typ for i_typ, c_typ in zip(initial_types, carried_types) + ) + + for name, _, c_typ in zip(output_names, initial_types, carried_types): + output_types[name] = ( + c_typ + if shape_unchanged_between_iterations + else loop_erase_shape_info(c_typ) + ) + + return output_types + op_type = OpType("Loop", "", 21) attrs: Attributes diff --git a/tests/type_inference/test_loop.py b/tests/type_inference/test_loop.py index c193d519..12ea9e28 100644 --- a/tests/type_inference/test_loop.py +++ b/tests/type_inference/test_loop.py @@ -1,15 +1,96 @@ -# Copyright (c) QuantCo 2023-2024 +# Copyright (c) QuantCo 2023-2025 # SPDX-License-Identifier: BSD-3-Clause -import spox.opset.ai.onnx.v17 as op -from spox import Tensor, argument +import numpy as np +import pytest +import spox.opset.ai.onnx.v17 as op17 +import spox.opset.ai.onnx.v19 as op19 +import spox.opset.ai.onnx.v21 as op21 +from spox import Optional, Sequence, Tensor, argument -def test_loop_inference(): + +@pytest.mark.parametrize("op", [op17, op19, op21]) +def test_loop_inference(op): x, y, zs = op.loop( - 
v_initial=[argument(Tensor(float, (None,))), argument(Tensor(int, ("N", 2)))], + v_initial=[ + argument(Tensor(np.float64, (None,))), + argument(Tensor(np.int64, ("N", 2))), + ], body=lambda i, c, a, b: [op.const(True), a, op.add(i, b), i], ) - assert x.type == Tensor(float, (None,)) - assert y.type == Tensor(int, ("N", 2)) - assert zs.type == Tensor(int, (None, 1)) + assert x.type == Tensor(np.float64, (None,)) + assert y.type == Tensor(np.int64, ("N", 2)) + assert zs.type == Tensor(np.int64, (None, 1)) + + +@pytest.mark.parametrize("op", [op17, op19, op21]) +def test_loop_concat(op): + num_iters = op.const(1) + v = op.const([], dtype=np.int64) + + result = op.loop( + num_iters, + v_initial=[v], + body=lambda i, c, x: (op.const(True), op.concat([x, op.const([1])], axis=0)), + )[0] + + # type can change, so we cannot infer anything + assert result.type == Tensor(np.int64, None) + + +@pytest.mark.parametrize("op", [op17, op19, op21]) +def test_loop_sequence(op): + num_iters = op.const(1) + v = op.sequence_empty(dtype=np.int64) + + result = op.loop( + num_iters, + v_initial=[v], + body=lambda i, c, x: (op.const(True), op.sequence_insert(x, op.const([1]))), + )[0] + + assert result.type == Sequence(Tensor(np.int64, None)) + + +@pytest.mark.parametrize("op", [op17, op19, op21]) +def test_loop_optional(op): + num_iters = op.const(1) + v = op.optional(type=Tensor(np.int64, (1, 2))) + + result = op.loop( + num_iters, + v_initial=[v], + body=lambda i, c, x: ( + op.const(True), + op.if_( + op.optional_has_element(x), + then_branch=lambda: [op.optional(type=Tensor(np.int64, (1, 2)))], + else_branch=lambda: [op.optional(op.const([[1, 1]]))], + )[0], + ), + )[0] + + assert result.type == Optional(Tensor(np.int64, (1, 2))) + + +@pytest.mark.parametrize("op", [op17, op19, op21]) +def test_loop_optional_no_shape(op): + num_iters = op.const(1) + v = op.optional(type=Tensor(np.int64, (1, 2))) + + result = op.loop( + num_iters, + v_initial=[v], + body=lambda i, c, x: ( + 
op.const(True), + op.if_( + op.optional_has_element(x), + then_branch=lambda: [op.optional(type=Tensor(np.int64, (1, 2)))], + else_branch=lambda: [op.optional(op.const([[1]]))], + )[0], + ), + )[0] + + # shape can change, we cannot infer type + assert result.type == Optional(Tensor(np.int64, None)) diff --git a/tools/generate_opset.py b/tools/generate_opset.py index 1866fb9f..9c542995 100644 --- a/tools/generate_opset.py +++ b/tools/generate_opset.py @@ -692,7 +692,7 @@ def main( "ai.onnx", 19, extras=["const"], - type_inference={"Compress": "compress11"}, + type_inference={"Compress": "compress11", "Loop": "loop16-fix"}, value_propagation={"Constant": "constant13"}, out_variadic_solutions=V18_OUT_VARIADIC_SOLUTIONS, subgraphs_solutions=V16_SUBGRAPH_SOLUTIONS, @@ -715,6 +715,7 @@ def main( "ai.onnx", 21, extras=["const"], + type_inference={"Compress": "compress11", "Loop": "loop16-fix"}, value_propagation={"Constant": "constant13"}, out_variadic_solutions=V18_OUT_VARIADIC_SOLUTIONS, subgraphs_solutions=V16_SUBGRAPH_SOLUTIONS, diff --git a/tools/templates/preamble.jinja2 b/tools/templates/preamble.jinja2 index b0dd5d63..8a05638b 100644 --- a/tools/templates/preamble.jinja2 +++ b/tools/templates/preamble.jinja2 @@ -33,4 +33,7 @@ from spox._internal_op import intro from spox._node import OpType from spox._standard import InferenceError, StandardNode from spox._type_system import Tensor, Type, Sequence as SpoxSequence +from spox._type_inference_utils import loop_erase_shape_info from spox._value_prop import PropValueType, PropDict +from spox._shape import Shape + diff --git a/tools/templates/type_inference/loop16-fix.jinja2 b/tools/templates/type_inference/loop16-fix.jinja2 index b797693c..ffc9f490 100644 --- a/tools/templates/type_inference/loop16-fix.jinja2 +++ b/tools/templates/type_inference/loop16-fix.jinja2 @@ -1,12 +1,20 @@ -output_types = super().infer_output_types({}) +output_types = super().infer_output_types(input_prop_values) +output_names = 
list(self.outputs.get_var_infos()) body = self.attrs.body.value -n = len(body.requested_arguments) - 2 -carried_names = list(self.outputs.get_var_infos())[:n] -carried_types = [v.type for v in list(body.requested_results.values())[1:][:n]] +# We skip the iteration_num and condition as they are correctly inferred +initial_types = [v.type for v in list(body.requested_arguments)[2:]] +# We skip the returned condition as it is correctly inferred +carried_types = [v.type for v in list(body.requested_results.values())[1:]] -for name, typ in zip(carried_names, carried_types): - output_types[name] = typ +shape_unchanged_between_iterations = all( + i_typ == c_typ for i_typ, c_typ in zip(initial_types, carried_types) +) + +for name, _, c_typ in zip(output_names, initial_types, carried_types): + output_types[name] = ( + c_typ if shape_unchanged_between_iterations else loop_erase_shape_info(c_typ) + ) return output_types