Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for expected termination of transform functions #2595

Merged
merged 6 commits into from
Aug 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 34 additions & 7 deletions src/awkward/operations/ak_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def transform(
numpy_to_regular=False,
regular_to_jagged=False,
return_value="simplified",
expect_return_value=False,
highlevel=True,
behavior=None,
):
Expand Down Expand Up @@ -73,6 +74,8 @@ def transform(
nodes are not nested inappropriately. Note that if `return_value` is `"none"`,
the only way to get information out of this function is through the
`lateral_context`.
expect_return_value (bool): If True, raise a `RuntimeError` if the transformer
does not terminate the recursion.
highlevel (bool): If True, return an #ak.Array; otherwise, return
a low-level #ak.contents.Content subclass.
behavior (None or dict): Custom #ak.behavior for the output array, if
Expand Down Expand Up @@ -428,6 +431,7 @@ def transform(
numpy_to_regular,
regular_to_jagged,
return_value,
expect_return_value,
behavior,
highlevel,
)
Expand All @@ -446,6 +450,7 @@ def _impl(
numpy_to_regular,
regular_to_jagged,
return_value,
expect_return_value,
behavior,
highlevel,
):
Expand All @@ -472,17 +477,22 @@ def _impl(
"return_array": return_value != "none",
"function_name": "ak.transform",
"broadcast_parameters_rule": broadcast_parameters_rule,
"expect_return_value": expect_return_value,
}

transformer_did_terminate = False

if len(more_layouts) == 0:

def action(layout, **kwargs):
nonlocal transformer_did_terminate
out = transformation(layout, **kwargs)

if out is None:
return out

elif isinstance(out, ak.contents.Content):
transformer_did_terminate = True
return out

else:
Expand All @@ -501,12 +511,20 @@ def action(layout, **kwargs):
options,
)

if return_value != "none":
if return_value == "none":
return None
elif expect_return_value and not transformer_did_terminate:
raise RuntimeError(
"the transformation function was expected to terminate by returning a Content, "
"but instead only returned None."
)
else:
return wrap_layout(out, behavior, highlevel)

else:

def action(inputs, **kwargs):
nonlocal transformer_did_terminate
out = transformation(tuple(inputs), **kwargs)

if out is None:
Expand All @@ -516,14 +534,17 @@ def action(inputs, **kwargs):
return None

elif isinstance(out, tuple):
transformer_did_terminate = True
for x in out:
if not isinstance(x, ak.contents.Content):
raise TypeError(
f"transformation must return a Content, tuple of Contents, or None, not a tuple containing {type(x)}\n\n{x!r}"
f"transformation must return a Content, tuple of Contents, or None, "
f"not a tuple containing {type(x)}\n\n{x!r}"
)
return out

elif isinstance(out, ak.contents.Content):
transformer_did_terminate = True
return (out,)

else:
Expand All @@ -546,8 +567,14 @@ def action(inputs, **kwargs):
assert isinstance(out, tuple)
out = [ak._broadcasting.broadcast_unpack(x, isscalar, backend) for x in out]

if return_value != "none":
if len(out) == 1:
return wrap_layout(out[0], behavior, highlevel)
else:
return tuple(wrap_layout(x, behavior, highlevel) for x in out)
if return_value == "none":
return
elif expect_return_value and not transformer_did_terminate:
raise RuntimeError(
"the transformation function was expected to terminate by returning a Content, "
"or tuple of Contents, but instead only returned None."
)
elif len(out) == 1:
return wrap_layout(out[0], behavior, highlevel)
else:
return tuple(wrap_layout(x, behavior, highlevel) for x in out)
73 changes: 73 additions & 0 deletions tests/test_2595_transform_termination.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

import pytest

import awkward as ak


def test_single_no_termination():
def transform(layout, **kwargs):
pass

with pytest.raises(RuntimeError, match=r"expected to terminate"):
ak.transform(transform, [{"x": 1}], expect_return_value=True)

result = ak.transform(transform, [{"x": 1}], expect_return_value=False)
assert result.to_list() == [{"x": 1}]

result = ak.transform(
transform, [{"x": 1}], expect_return_value=True, return_value="none"
)
assert result is None


def test_single_termination():
def transform(layout, **kwargs):
if layout.is_numpy:
return ak.contents.NumpyArray(layout.data * 2)

result = ak.transform(transform, [{"x": 1}], expect_return_value=True)
assert result.to_list() == [{"x": 2}]

result = ak.transform(transform, [{"x": 1}], expect_return_value=False)
assert result.to_list() == [{"x": 2}]

ak.transform(transform, [{"x": 1}], expect_return_value=True, return_value="none")
assert result.to_list() == [{"x": 2}]


def test_many_no_termination():
def transform(layout, **kwargs):
pass

with pytest.raises(RuntimeError, match=r"expected to terminate"):
ak.transform(transform, [{"x": 1}], [2], expect_return_value=True)

result = ak.transform(transform, [{"x": 1}], [2], expect_return_value=False)
assert result[0].to_list() == [{"x": 1}]
assert result[1].to_list() == [{"x": 2}]

result = ak.transform(
transform, [{"x": 1}], [2], expect_return_value=True, return_value="none"
)
assert result is None


def test_many_termination():
def transform(inputs, **kwargs):
if all(layout.is_numpy for layout in inputs):
return tuple([ak.contents.NumpyArray(layout.data * 2) for layout in inputs])

result = ak.transform(transform, [{"x": 1}], [2], expect_return_value=True)
assert result[0].to_list() == [{"x": 2}]
assert result[1].to_list() == [{"x": 4}]

result = ak.transform(transform, [{"x": 1}], [2], expect_return_value=False)
assert result[0].to_list() == [{"x": 2}]
assert result[1].to_list() == [{"x": 4}]

ak.transform(
transform, [{"x": 1}], [2], expect_return_value=True, return_value="none"
)
assert result[0].to_list() == [{"x": 2}]
assert result[1].to_list() == [{"x": 4}]