Skip to content

Commit

Permalink
Fix shape inference for Loop (#198)
Browse files Browse the repository at this point in the history
* init

Fix inconsistencies

Signed-off-by: neNasko1 <[email protected]>

Add CHANGELOG

Signed-off-by: neNasko1 <[email protected]>

Fix docstring

Signed-off-by: neNasko1 <[email protected]>

Apply suggestions from code review

Co-authored-by: Aditya Goel <[email protected]>

Comments after code review

Signed-off-by: neNasko1 <[email protected]>

Comments after code review

Signed-off-by: neNasko1 <[email protected]>

Fix license

Signed-off-by: neNasko1 <[email protected]>

* Fix diff

Signed-off-by: neNasko1 <[email protected]>

* Update CHANGELOG.rst

Co-authored-by: Aditya Goel <[email protected]>

---------

Signed-off-by: neNasko1 <[email protected]>
Co-authored-by: Aditya Goel <[email protected]>
  • Loading branch information
neNasko1 and adityagoel4512 authored Jan 3, 2025
1 parent d7c1792 commit 0741557
Show file tree
Hide file tree
Showing 10 changed files with 208 additions and 22 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,15 @@ Change log
0.14.0 (unreleased)
-------------------

**Bug fix**

- Adds missing shape inference logic for :func:`spox.opsets.ai.v19.loop` and :func:`spox.opsets.ai.v21.loop`.

**Other changes**

- Propagated values may now be garbage collected if their associated `Var` object goes out of scope.
- :func:`spox.opsets.ai.v17.loop`, :func:`spox.opsets.ai.v19.loop` and :func:`spox.opsets.ai.v21.loop` will only infer shapes for loop carried dependencies if their shapes are unchanged across iterations.


0.13.0 (2024-12-06)
-------------------
Expand Down
26 changes: 26 additions & 0 deletions src/spox/_type_inference_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause

from spox._standard import InferenceError
from spox._type_system import Optional, Sequence, Tensor, Type


def loop_erase_shape_info(typ: Type) -> Type:
"""Erases the shape information for a type, that can exists as a state variable in a Loop"""
if isinstance(typ, Tensor):
return Tensor(typ.dtype, None)
elif isinstance(typ, Sequence):
if not isinstance(typ.elem_type, Tensor):
raise InferenceError(
f"Type {typ} not allowed for state variables in Loop operator, sequence element can only be a tensor"
)
return Sequence(loop_erase_shape_info(typ.elem_type))
elif isinstance(typ, Optional):
if isinstance(typ.elem_type, Optional):
raise InferenceError(
f"Type {typ} not allowed for state variables in Loop operator, optionals of optionals are not allowed"
)
return Optional(loop_erase_shape_info(typ.elem_type))
raise InferenceError(
f"Type {typ} not allowed for state variables in Loop operator."
)
2 changes: 1 addition & 1 deletion src/spox/_type_system.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) QuantCo 2023-2024
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations
Expand Down
23 changes: 17 additions & 6 deletions src/spox/opset/ai/onnx/v17.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from spox._graph import Graph, subgraph
from spox._node import OpType
from spox._standard import InferenceError, StandardNode
from spox._type_inference_utils import loop_erase_shape_info
from spox._type_system import Sequence as SpoxSequence
from spox._type_system import Tensor, Type
from spox._value_prop import PropDict, PropValueType
Expand Down Expand Up @@ -1780,16 +1781,26 @@ class Outputs(BaseOutputs):
v_final_and_scan_outputs: Sequence[_VarInfo]

def infer_output_types(self, input_prop_values: PropDict) -> dict[str, Type]:
output_types = super().infer_output_types({})
output_types = super().infer_output_types(input_prop_values)
output_names = list(self.outputs.get_var_infos())

body = self.attrs.body.value
n = len(body.requested_arguments) - 2

carried_names = list(self.outputs.get_var_infos())[:n]
carried_types = [v.type for v in list(body.requested_results.values())[1:][:n]]
# We skip the iteration_num and condition as they are correctly inferred
initial_types = [v.type for v in list(body.requested_arguments)[2:]]
# We skip the returned condition as it is correctly inferred
carried_types = [v.type for v in list(body.requested_results.values())[1:]]

for name, typ in zip(carried_names, carried_types):
output_types[name] = typ
shape_unchanged_between_iterations = all(
i_typ == c_typ for i_typ, c_typ in zip(initial_types, carried_types)
)

for name, _, c_typ in zip(output_names, initial_types, carried_types):
output_types[name] = (
c_typ
if shape_unchanged_between_iterations
else loop_erase_shape_info(c_typ)
)

return output_types

Expand Down
25 changes: 25 additions & 0 deletions src/spox/opset/ai/onnx/v19.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from spox._graph import Graph, subgraph
from spox._node import OpType
from spox._standard import StandardNode
from spox._type_inference_utils import loop_erase_shape_info
from spox._type_system import Tensor, Type
from spox._value_prop import PropDict, PropValueType
from spox._var import Var, _VarInfo, create_prop_dict, unwrap_vars
Expand Down Expand Up @@ -617,6 +618,30 @@ class Inputs(BaseInputs):
class Outputs(BaseOutputs):
v_final_and_scan_outputs: Sequence[_VarInfo]

def infer_output_types(self, input_prop_values: PropDict) -> dict[str, Type]:
output_types = super().infer_output_types(input_prop_values)
output_names = list(self.outputs.get_var_infos())

body = self.attrs.body.value

# We skip the iteration_num and condition as they are correctly inferred
initial_types = [v.type for v in list(body.requested_arguments)[2:]]
# We skip the returned condition as it is correctly inferred
carried_types = [v.type for v in list(body.requested_results.values())[1:]]

shape_unchanged_between_iterations = all(
i_typ == c_typ for i_typ, c_typ in zip(initial_types, carried_types)
)

for name, _, c_typ in zip(output_names, initial_types, carried_types):
output_types[name] = (
c_typ
if shape_unchanged_between_iterations
else loop_erase_shape_info(c_typ)
)

return output_types

op_type = OpType("Loop", "", 19)

attrs: Attributes
Expand Down
25 changes: 25 additions & 0 deletions src/spox/opset/ai/onnx/v21.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from spox._graph import Graph, subgraph
from spox._node import OpType
from spox._standard import StandardNode
from spox._type_inference_utils import loop_erase_shape_info
from spox._type_system import Tensor, Type
from spox._value_prop import PropDict, PropValueType
from spox._var import Var, _VarInfo, create_prop_dict, unwrap_vars
Expand Down Expand Up @@ -612,6 +613,30 @@ class Inputs(BaseInputs):
class Outputs(BaseOutputs):
v_final_and_scan_outputs: Sequence[_VarInfo]

def infer_output_types(self, input_prop_values: PropDict) -> dict[str, Type]:
output_types = super().infer_output_types(input_prop_values)
output_names = list(self.outputs.get_var_infos())

body = self.attrs.body.value

# We skip the iteration_num and condition as they are correctly inferred
initial_types = [v.type for v in list(body.requested_arguments)[2:]]
# We skip the returned condition as it is correctly inferred
carried_types = [v.type for v in list(body.requested_results.values())[1:]]

shape_unchanged_between_iterations = all(
i_typ == c_typ for i_typ, c_typ in zip(initial_types, carried_types)
)

for name, _, c_typ in zip(output_names, initial_types, carried_types):
output_types[name] = (
c_typ
if shape_unchanged_between_iterations
else loop_erase_shape_info(c_typ)
)

return output_types

op_type = OpType("Loop", "", 21)

attrs: Attributes
Expand Down
97 changes: 89 additions & 8 deletions tests/type_inference/test_loop.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,96 @@
# Copyright (c) QuantCo 2023-2024
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause

import spox.opset.ai.onnx.v17 as op
from spox import Tensor, argument
import numpy as np
import pytest

import spox.opset.ai.onnx.v17 as op17
import spox.opset.ai.onnx.v19 as op19
import spox.opset.ai.onnx.v21 as op21
from spox import Optional, Sequence, Tensor, argument

def test_loop_inference():

@pytest.mark.parametrize("op", [op17, op19, op21])
def test_loop_inference(op):
x, y, zs = op.loop(
v_initial=[argument(Tensor(float, (None,))), argument(Tensor(int, ("N", 2)))],
v_initial=[
argument(Tensor(np.float64, (None,))),
argument(Tensor(np.int64, ("N", 2))),
],
body=lambda i, c, a, b: [op.const(True), a, op.add(i, b), i],
)
assert x.type == Tensor(float, (None,))
assert y.type == Tensor(int, ("N", 2))
assert zs.type == Tensor(int, (None, 1))
assert x.type == Tensor(np.float64, (None,))
assert y.type == Tensor(np.int64, ("N", 2))
assert zs.type == Tensor(np.int64, (None, 1))


@pytest.mark.parametrize("op", [op17, op19, op21])
def test_loop_concat(op):
num_iters = op.const(1)
v = op.const([], dtype=np.int64)

result = op.loop(
num_iters,
v_initial=[v],
body=lambda i, c, x: (op.const(True), op.concat([x, op.const([1])], axis=0)),
)[0]

# type can change, so we cannot infer anything
assert result.type == Tensor(np.int64, None)


@pytest.mark.parametrize("op", [op17, op19, op21])
def test_loop_sequence(op):
num_iters = op.const(1)
v = op.sequence_empty(dtype=np.int64)

result = op.loop(
num_iters,
v_initial=[v],
body=lambda i, c, x: (op.const(True), op.sequence_insert(x, op.const([1]))),
)[0]

assert result.type == Sequence(Tensor(np.int64, None))


@pytest.mark.parametrize("op", [op17, op19, op21])
def test_loop_optional(op):
num_iters = op.const(1)
v = op.optional(type=Tensor(np.int64, (1, 2)))

result = op.loop(
num_iters,
v_initial=[v],
body=lambda i, c, x: (
op.const(True),
op.if_(
op.optional_has_element(x),
then_branch=lambda: [op.optional(type=Tensor(np.int64, (1, 2)))],
else_branch=lambda: [op.optional(op.const([[1, 1]]))],
)[0],
),
)[0]

assert result.type == Optional(Tensor(np.int64, (1, 2)))


@pytest.mark.parametrize("op", [op17, op19, op21])
def test_loop_optional_no_shape(op):
num_iters = op.const(1)
v = op.optional(type=Tensor(np.int64, (1, 2)))

result = op.loop(
num_iters,
v_initial=[v],
body=lambda i, c, x: (
op.const(True),
op.if_(
op.optional_has_element(x),
then_branch=lambda: [op.optional(type=Tensor(np.int64, (1, 2)))],
else_branch=lambda: [op.optional(op.const([[1]]))],
)[0],
),
)[0]

# shape can change, we cannot infer type
assert result.type == Optional(Tensor(np.int64, None))
3 changes: 2 additions & 1 deletion tools/generate_opset.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,7 @@ def main(
"ai.onnx",
19,
extras=["const"],
type_inference={"Compress": "compress11"},
type_inference={"Compress": "compress11", "Loop": "loop16-fix"},
value_propagation={"Constant": "constant13"},
out_variadic_solutions=V18_OUT_VARIADIC_SOLUTIONS,
subgraphs_solutions=V16_SUBGRAPH_SOLUTIONS,
Expand All @@ -715,6 +715,7 @@ def main(
"ai.onnx",
21,
extras=["const"],
type_inference={"Compress": "compress11", "Loop": "loop16-fix"},
value_propagation={"Constant": "constant13"},
out_variadic_solutions=V18_OUT_VARIADIC_SOLUTIONS,
subgraphs_solutions=V16_SUBGRAPH_SOLUTIONS,
Expand Down
3 changes: 3 additions & 0 deletions tools/templates/preamble.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,7 @@ from spox._internal_op import intro
from spox._node import OpType
from spox._standard import InferenceError, StandardNode
from spox._type_system import Tensor, Type, Sequence as SpoxSequence
from spox._type_inference_utils import loop_erase_shape_info
from spox._value_prop import PropValueType, PropDict
from spox._shape import Shape

20 changes: 14 additions & 6 deletions tools/templates/type_inference/loop16-fix.jinja2
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
output_types = super().infer_output_types({})
output_types = super().infer_output_types(input_prop_values)
output_names = list(self.outputs.get_var_infos())

body = self.attrs.body.value
n = len(body.requested_arguments) - 2

carried_names = list(self.outputs.get_var_infos())[:n]
carried_types = [v.type for v in list(body.requested_results.values())[1:][:n]]
# We skip the iteration_num and condition as they are correctly inferred
initial_types = [v.type for v in list(body.requested_arguments)[2:]]
# We skip the returned condition as it is correctly inferred
carried_types = [v.type for v in list(body.requested_results.values())[1:]]

for name, typ in zip(carried_names, carried_types):
output_types[name] = typ
shape_unchanged_between_iterations = all(
i_typ == c_typ for i_typ, c_typ in zip(initial_types, carried_types)
)

for name, _, c_typ in zip(output_names, initial_types, carried_types):
output_types[name] = (
c_typ if shape_unchanged_between_iterations else loop_erase_shape_info(c_typ)
)

return output_types

0 comments on commit 0741557

Please sign in to comment.