-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* init Fix inconsistencies Signed-off-by: neNasko1 <[email protected]> Add CHANGELOG Signed-off-by: neNasko1 <[email protected]> Fix docstring Signed-off-by: neNasko1 <[email protected]> Apply suggestions from code review Co-authored-by: Aditya Goel <[email protected]> Comments after code review Signed-off-by: neNasko1 <[email protected]> Comments after code review Signed-off-by: neNasko1 <[email protected]> Fix license Signed-off-by: neNasko1 <[email protected]> * Fix diff Signed-off-by: neNasko1 <[email protected]> * Update CHANGELOG.rst Co-authored-by: Aditya Goel <[email protected]> --------- Signed-off-by: neNasko1 <[email protected]> Co-authored-by: Aditya Goel <[email protected]>
- Loading branch information
1 parent
d7c1792
commit 0741557
Showing
10 changed files
with
208 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# Copyright (c) QuantCo 2023-2025 | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
|
||
from spox._standard import InferenceError | ||
from spox._type_system import Optional, Sequence, Tensor, Type | ||
|
||
|
||
def loop_erase_shape_info(typ: Type) -> Type: | ||
"""Erases the shape information for a type, that can exists as a state variable in a Loop""" | ||
if isinstance(typ, Tensor): | ||
return Tensor(typ.dtype, None) | ||
elif isinstance(typ, Sequence): | ||
if not isinstance(typ.elem_type, Tensor): | ||
raise InferenceError( | ||
f"Type {typ} not allowed for state variables in Loop operator, sequence element can only be a tensor" | ||
) | ||
return Sequence(loop_erase_shape_info(typ.elem_type)) | ||
elif isinstance(typ, Optional): | ||
if isinstance(typ.elem_type, Optional): | ||
raise InferenceError( | ||
f"Type {typ} not allowed for state variables in Loop operator, optionals of optionals are not allowed" | ||
) | ||
return Optional(loop_erase_shape_info(typ.elem_type)) | ||
raise InferenceError( | ||
f"Type {typ} not allowed for state variables in Loop operator." | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,96 @@ | ||
# Copyright (c) QuantCo 2023-2024 | ||
# Copyright (c) QuantCo 2023-2025 | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
|
||
import spox.opset.ai.onnx.v17 as op | ||
from spox import Tensor, argument | ||
import numpy as np | ||
import pytest | ||
|
||
import spox.opset.ai.onnx.v17 as op17 | ||
import spox.opset.ai.onnx.v19 as op19 | ||
import spox.opset.ai.onnx.v21 as op21 | ||
from spox import Optional, Sequence, Tensor, argument | ||
|
||
def test_loop_inference(): | ||
|
||
@pytest.mark.parametrize("op", [op17, op19, op21]) | ||
def test_loop_inference(op): | ||
x, y, zs = op.loop( | ||
v_initial=[argument(Tensor(float, (None,))), argument(Tensor(int, ("N", 2)))], | ||
v_initial=[ | ||
argument(Tensor(np.float64, (None,))), | ||
argument(Tensor(np.int64, ("N", 2))), | ||
], | ||
body=lambda i, c, a, b: [op.const(True), a, op.add(i, b), i], | ||
) | ||
assert x.type == Tensor(float, (None,)) | ||
assert y.type == Tensor(int, ("N", 2)) | ||
assert zs.type == Tensor(int, (None, 1)) | ||
assert x.type == Tensor(np.float64, (None,)) | ||
assert y.type == Tensor(np.int64, ("N", 2)) | ||
assert zs.type == Tensor(np.int64, (None, 1)) | ||
|
||
|
||
@pytest.mark.parametrize("op", [op17, op19, op21]) | ||
def test_loop_concat(op): | ||
num_iters = op.const(1) | ||
v = op.const([], dtype=np.int64) | ||
|
||
result = op.loop( | ||
num_iters, | ||
v_initial=[v], | ||
body=lambda i, c, x: (op.const(True), op.concat([x, op.const([1])], axis=0)), | ||
)[0] | ||
|
||
# type can change, so we cannot infer anything | ||
assert result.type == Tensor(np.int64, None) | ||
|
||
|
||
@pytest.mark.parametrize("op", [op17, op19, op21]) | ||
def test_loop_sequence(op): | ||
num_iters = op.const(1) | ||
v = op.sequence_empty(dtype=np.int64) | ||
|
||
result = op.loop( | ||
num_iters, | ||
v_initial=[v], | ||
body=lambda i, c, x: (op.const(True), op.sequence_insert(x, op.const([1]))), | ||
)[0] | ||
|
||
assert result.type == Sequence(Tensor(np.int64, None)) | ||
|
||
|
||
@pytest.mark.parametrize("op", [op17, op19, op21]) | ||
def test_loop_optional(op): | ||
num_iters = op.const(1) | ||
v = op.optional(type=Tensor(np.int64, (1, 2))) | ||
|
||
result = op.loop( | ||
num_iters, | ||
v_initial=[v], | ||
body=lambda i, c, x: ( | ||
op.const(True), | ||
op.if_( | ||
op.optional_has_element(x), | ||
then_branch=lambda: [op.optional(type=Tensor(np.int64, (1, 2)))], | ||
else_branch=lambda: [op.optional(op.const([[1, 1]]))], | ||
)[0], | ||
), | ||
)[0] | ||
|
||
assert result.type == Optional(Tensor(np.int64, (1, 2))) | ||
|
||
|
||
@pytest.mark.parametrize("op", [op17, op19, op21]) | ||
def test_loop_optional_no_shape(op): | ||
num_iters = op.const(1) | ||
v = op.optional(type=Tensor(np.int64, (1, 2))) | ||
|
||
result = op.loop( | ||
num_iters, | ||
v_initial=[v], | ||
body=lambda i, c, x: ( | ||
op.const(True), | ||
op.if_( | ||
op.optional_has_element(x), | ||
then_branch=lambda: [op.optional(type=Tensor(np.int64, (1, 2)))], | ||
else_branch=lambda: [op.optional(op.const([[1]]))], | ||
)[0], | ||
), | ||
)[0] | ||
|
||
# shape can change, we cannot infer type | ||
assert result.type == Optional(Tensor(np.int64, None)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,20 @@ | ||
output_types = super().infer_output_types({}) | ||
output_types = super().infer_output_types(input_prop_values) | ||
output_names = list(self.outputs.get_var_infos()) | ||
|
||
body = self.attrs.body.value | ||
n = len(body.requested_arguments) - 2 | ||
|
||
carried_names = list(self.outputs.get_var_infos())[:n] | ||
carried_types = [v.type for v in list(body.requested_results.values())[1:][:n]] | ||
# We skip the iteration_num and condition as they are correctly inferred | ||
initial_types = [v.type for v in list(body.requested_arguments)[2:]] | ||
# We skip the returned condition as it is correctly inferred | ||
carried_types = [v.type for v in list(body.requested_results.values())[1:]] | ||
|
||
for name, typ in zip(carried_names, carried_types): | ||
output_types[name] = typ | ||
shape_unchanged_between_iterations = all( | ||
i_typ == c_typ for i_typ, c_typ in zip(initial_types, carried_types) | ||
) | ||
|
||
for name, _, c_typ in zip(output_names, initial_types, carried_types): | ||
output_types[name] = ( | ||
c_typ if shape_unchanged_between_iterations else loop_erase_shape_info(c_typ) | ||
) | ||
|
||
return output_types |