Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix shape inference for Loop #198

Merged
merged 5 commits into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,15 @@ Change log
0.14.0 (unreleased)
-------------------

**Bug fix**

- Adds missing shape inference logic for :func:`spox.opsets.ai.v19.loop` and :func:`spox.opsets.ai.v21.loop`.

**Other changes**

- Propagated values may now be garbage collected if their associated `Var` object goes out of scope.
- :func:`spox.opsets.ai.v17.loop`, :func:`spox.opsets.ai.v19.loop` and :func:`spox.opsets.ai.v21.loop` will only infer shapes for loop carried dependencies if their shapes are unchanged across iterations.

neNasko1 marked this conversation as resolved.
Show resolved Hide resolved

0.13.0 (2024-12-06)
-------------------
Expand Down
26 changes: 26 additions & 0 deletions src/spox/_type_inference_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause

from spox._standard import InferenceError
from spox._type_system import Optional, Sequence, Tensor, Type


def loop_erase_shape_info(typ: Type) -> Type:
"""Erases the shape information for a type, that can exists as a state variable in a Loop"""
if isinstance(typ, Tensor):
return Tensor(typ.dtype, None)
elif isinstance(typ, Sequence):
if not isinstance(typ.elem_type, Tensor):
raise InferenceError(
f"Type {typ} not allowed for state variables in Loop operator, sequence element can only be a tensor"
)
return Sequence(loop_erase_shape_info(typ.elem_type))
elif isinstance(typ, Optional):
if isinstance(typ.elem_type, Optional):
raise InferenceError(
f"Type {typ} not allowed for state variables in Loop operator, optionals of optionals are not allowed"
)
return Optional(loop_erase_shape_info(typ.elem_type))
raise InferenceError(
f"Type {typ} not allowed for state variables in Loop operator."
)
2 changes: 1 addition & 1 deletion src/spox/_type_system.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) QuantCo 2023-2024
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations
Expand Down
23 changes: 17 additions & 6 deletions src/spox/opset/ai/onnx/v17.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from spox._graph import Graph, subgraph
from spox._node import OpType
from spox._standard import InferenceError, StandardNode
from spox._type_inference_utils import loop_erase_shape_info
from spox._type_system import Sequence as SpoxSequence
from spox._type_system import Tensor, Type
from spox._value_prop import PropDict, PropValueType
Expand Down Expand Up @@ -1780,16 +1781,26 @@ class Outputs(BaseOutputs):
v_final_and_scan_outputs: Sequence[_VarInfo]

def infer_output_types(self, input_prop_values: PropDict) -> dict[str, Type]:
output_types = super().infer_output_types({})
output_types = super().infer_output_types(input_prop_values)
output_names = list(self.outputs.get_var_infos())

body = self.attrs.body.value
n = len(body.requested_arguments) - 2

carried_names = list(self.outputs.get_var_infos())[:n]
carried_types = [v.type for v in list(body.requested_results.values())[1:][:n]]
# We skip the iteration_num and condition as they are correctly inferred
initial_types = [v.type for v in list(body.requested_arguments)[2:]]
# We skip the returned condition as it is correctly inferred
carried_types = [v.type for v in list(body.requested_results.values())[1:]]
neNasko1 marked this conversation as resolved.
Show resolved Hide resolved

for name, typ in zip(carried_names, carried_types):
output_types[name] = typ
shape_unchanged_between_iterations = all(
i_typ == c_typ for i_typ, c_typ in zip(initial_types, carried_types)
)

for name, _, c_typ in zip(output_names, initial_types, carried_types):
output_types[name] = (
c_typ
if shape_unchanged_between_iterations
else loop_erase_shape_info(c_typ)
)

return output_types

Expand Down
25 changes: 25 additions & 0 deletions src/spox/opset/ai/onnx/v19.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from spox._graph import Graph, subgraph
from spox._node import OpType
from spox._standard import StandardNode
from spox._type_inference_utils import loop_erase_shape_info
from spox._type_system import Tensor, Type
from spox._value_prop import PropDict, PropValueType
from spox._var import Var, _VarInfo, create_prop_dict, unwrap_vars
Expand Down Expand Up @@ -617,6 +618,30 @@ class Inputs(BaseInputs):
class Outputs(BaseOutputs):
v_final_and_scan_outputs: Sequence[_VarInfo]

def infer_output_types(self, input_prop_values: PropDict) -> dict[str, Type]:
output_types = super().infer_output_types(input_prop_values)
output_names = list(self.outputs.get_var_infos())

body = self.attrs.body.value

# We skip the iteration_num and condition as they are correctly inferred
initial_types = [v.type for v in list(body.requested_arguments)[2:]]
# We skip the returned condition as it is correctly inferred
carried_types = [v.type for v in list(body.requested_results.values())[1:]]

shape_unchanged_between_iterations = all(
i_typ == c_typ for i_typ, c_typ in zip(initial_types, carried_types)
)

for name, _, c_typ in zip(output_names, initial_types, carried_types):
output_types[name] = (
c_typ
if shape_unchanged_between_iterations
else loop_erase_shape_info(c_typ)
)

return output_types

op_type = OpType("Loop", "", 19)

attrs: Attributes
Expand Down
25 changes: 25 additions & 0 deletions src/spox/opset/ai/onnx/v21.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from spox._graph import Graph, subgraph
from spox._node import OpType
from spox._standard import StandardNode
from spox._type_inference_utils import loop_erase_shape_info
from spox._type_system import Tensor, Type
from spox._value_prop import PropDict, PropValueType
from spox._var import Var, _VarInfo, create_prop_dict, unwrap_vars
Expand Down Expand Up @@ -612,6 +613,30 @@ class Inputs(BaseInputs):
class Outputs(BaseOutputs):
v_final_and_scan_outputs: Sequence[_VarInfo]

def infer_output_types(self, input_prop_values: PropDict) -> dict[str, Type]:
output_types = super().infer_output_types(input_prop_values)
output_names = list(self.outputs.get_var_infos())

body = self.attrs.body.value

# We skip the iteration_num and condition as they are correctly inferred
initial_types = [v.type for v in list(body.requested_arguments)[2:]]
# We skip the returned condition as it is correctly inferred
carried_types = [v.type for v in list(body.requested_results.values())[1:]]

shape_unchanged_between_iterations = all(
i_typ == c_typ for i_typ, c_typ in zip(initial_types, carried_types)
)

for name, _, c_typ in zip(output_names, initial_types, carried_types):
output_types[name] = (
c_typ
if shape_unchanged_between_iterations
else loop_erase_shape_info(c_typ)
)

return output_types

op_type = OpType("Loop", "", 21)

attrs: Attributes
Expand Down
97 changes: 89 additions & 8 deletions tests/type_inference/test_loop.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,96 @@
# Copyright (c) QuantCo 2023-2024
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause

import spox.opset.ai.onnx.v17 as op
from spox import Tensor, argument
import numpy as np
import pytest

import spox.opset.ai.onnx.v17 as op17
import spox.opset.ai.onnx.v19 as op19
import spox.opset.ai.onnx.v21 as op21
from spox import Optional, Sequence, Tensor, argument

def test_loop_inference():

@pytest.mark.parametrize("op", [op17, op19, op21])
neNasko1 marked this conversation as resolved.
Show resolved Hide resolved
def test_loop_inference(op):
x, y, zs = op.loop(
v_initial=[argument(Tensor(float, (None,))), argument(Tensor(int, ("N", 2)))],
v_initial=[
argument(Tensor(np.float64, (None,))),
argument(Tensor(np.int64, ("N", 2))),
],
body=lambda i, c, a, b: [op.const(True), a, op.add(i, b), i],
)
assert x.type == Tensor(float, (None,))
assert y.type == Tensor(int, ("N", 2))
assert zs.type == Tensor(int, (None, 1))
assert x.type == Tensor(np.float64, (None,))
assert y.type == Tensor(np.int64, ("N", 2))
assert zs.type == Tensor(np.int64, (None, 1))


@pytest.mark.parametrize("op", [op17, op19, op21])
def test_loop_concat(op):
num_iters = op.const(1)
v = op.const([], dtype=np.int64)

result = op.loop(
num_iters,
v_initial=[v],
body=lambda i, c, x: (op.const(True), op.concat([x, op.const([1])], axis=0)),
)[0]

# type can change, so we cannot infer anything
assert result.type == Tensor(np.int64, None)


@pytest.mark.parametrize("op", [op17, op19, op21])
def test_loop_sequence(op):
num_iters = op.const(1)
v = op.sequence_empty(dtype=np.int64)

result = op.loop(
num_iters,
v_initial=[v],
body=lambda i, c, x: (op.const(True), op.sequence_insert(x, op.const([1]))),
)[0]

assert result.type == Sequence(Tensor(np.int64, None))


@pytest.mark.parametrize("op", [op17, op19, op21])
def test_loop_optional(op):
num_iters = op.const(1)
v = op.optional(type=Tensor(np.int64, (1, 2)))

result = op.loop(
num_iters,
v_initial=[v],
body=lambda i, c, x: (
op.const(True),
op.if_(
op.optional_has_element(x),
then_branch=lambda: [op.optional(type=Tensor(np.int64, (1, 2)))],
else_branch=lambda: [op.optional(op.const([[1, 1]]))],
)[0],
),
)[0]

assert result.type == Optional(Tensor(np.int64, (1, 2)))


@pytest.mark.parametrize("op", [op17, op19, op21])
def test_loop_optional_no_shape(op):
num_iters = op.const(1)
v = op.optional(type=Tensor(np.int64, (1, 2)))

result = op.loop(
num_iters,
v_initial=[v],
body=lambda i, c, x: (
op.const(True),
op.if_(
op.optional_has_element(x),
then_branch=lambda: [op.optional(type=Tensor(np.int64, (1, 2)))],
else_branch=lambda: [op.optional(op.const([[1]]))],
)[0],
),
)[0]

# shape can change, we cannot infer type
assert result.type == Optional(Tensor(np.int64, None))
3 changes: 2 additions & 1 deletion tools/generate_opset.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,7 @@ def main(
"ai.onnx",
19,
extras=["const"],
type_inference={"Compress": "compress11"},
type_inference={"Compress": "compress11", "Loop": "loop16-fix"},
value_propagation={"Constant": "constant13"},
out_variadic_solutions=V18_OUT_VARIADIC_SOLUTIONS,
subgraphs_solutions=V16_SUBGRAPH_SOLUTIONS,
Expand All @@ -715,6 +715,7 @@ def main(
"ai.onnx",
21,
extras=["const"],
type_inference={"Compress": "compress11", "Loop": "loop16-fix"},
value_propagation={"Constant": "constant13"},
out_variadic_solutions=V18_OUT_VARIADIC_SOLUTIONS,
subgraphs_solutions=V16_SUBGRAPH_SOLUTIONS,
Expand Down
3 changes: 3 additions & 0 deletions tools/templates/preamble.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,7 @@ from spox._internal_op import intro
from spox._node import OpType
from spox._standard import InferenceError, StandardNode
from spox._type_system import Tensor, Type, Sequence as SpoxSequence
from spox._type_inference_utils import loop_erase_shape_info
from spox._value_prop import PropValueType, PropDict
from spox._shape import Shape

20 changes: 14 additions & 6 deletions tools/templates/type_inference/loop16-fix.jinja2
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
output_types = super().infer_output_types({})
output_types = super().infer_output_types(input_prop_values)
output_names = list(self.outputs.get_var_infos())

body = self.attrs.body.value
n = len(body.requested_arguments) - 2

carried_names = list(self.outputs.get_var_infos())[:n]
carried_types = [v.type for v in list(body.requested_results.values())[1:][:n]]
# We skip the iteration_num and condition as they are correctly inferred
initial_types = [v.type for v in list(body.requested_arguments)[2:]]
# We skip the returned condition as it is correctly inferred
carried_types = [v.type for v in list(body.requested_results.values())[1:]]

for name, typ in zip(carried_names, carried_types):
output_types[name] = typ
shape_unchanged_between_iterations = all(
i_typ == c_typ for i_typ, c_typ in zip(initial_types, carried_types)
)

for name, _, c_typ in zip(output_names, initial_types, carried_types):
output_types[name] = (
c_typ if shape_unchanged_between_iterations else loop_erase_shape_info(c_typ)
)

return output_types
Loading