diff --git a/src/awkward/operations/ak_transform.py b/src/awkward/operations/ak_transform.py index 011739f017..f7efc6ed58 100644 --- a/src/awkward/operations/ak_transform.py +++ b/src/awkward/operations/ak_transform.py @@ -26,6 +26,7 @@ def transform( numpy_to_regular=False, regular_to_jagged=False, return_value="simplified", + expect_return_value=False, highlevel=True, behavior=None, ): @@ -73,6 +74,8 @@ def transform( nodes are not nested inappropriately. Note that if `return_value` is `"none"`, the only way to get information out of this function is through the `lateral_context`. + expect_return_value (bool): If True, raise a `RuntimeError` if the transformer + does not terminate the recursion. highlevel (bool): If True, return an #ak.Array; otherwise, return a low-level #ak.contents.Content subclass. behavior (None or dict): Custom #ak.behavior for the output array, if @@ -428,6 +431,7 @@ def transform( numpy_to_regular, regular_to_jagged, return_value, + expect_return_value, behavior, highlevel, ) @@ -446,6 +450,7 @@ def _impl( numpy_to_regular, regular_to_jagged, return_value, + expect_return_value, behavior, highlevel, ): @@ -472,17 +477,22 @@ def _impl( "return_array": return_value != "none", "function_name": "ak.transform", "broadcast_parameters_rule": broadcast_parameters_rule, + "expect_return_value": expect_return_value, } + transformer_did_terminate = False + if len(more_layouts) == 0: def action(layout, **kwargs): + nonlocal transformer_did_terminate out = transformation(layout, **kwargs) if out is None: return out elif isinstance(out, ak.contents.Content): + transformer_did_terminate = True return out else: @@ -501,12 +511,20 @@ def action(layout, **kwargs): options, ) - if return_value != "none": + if return_value == "none": + return None + elif expect_return_value and not transformer_did_terminate: + raise RuntimeError( + "the transformation function was expected to terminate by returning a Content, " + "but instead only returned None." + ) + else: return wrap_layout(out, behavior, highlevel) else: def action(inputs, **kwargs): + nonlocal transformer_did_terminate out = transformation(tuple(inputs), **kwargs) if out is None: @@ -516,14 +534,17 @@ def action(inputs, **kwargs): return None elif isinstance(out, tuple): + transformer_did_terminate = True for x in out: if not isinstance(x, ak.contents.Content): raise TypeError( - f"transformation must return a Content, tuple of Contents, or None, not a tuple containing {type(x)}\n\n{x!r}" + f"transformation must return a Content, tuple of Contents, or None, " + f"not a tuple containing {type(x)}\n\n{x!r}" ) return out elif isinstance(out, ak.contents.Content): + transformer_did_terminate = True return (out,) else: @@ -546,8 +567,14 @@ def action(inputs, **kwargs): assert isinstance(out, tuple) out = [ak._broadcasting.broadcast_unpack(x, isscalar, backend) for x in out] - if return_value != "none": - if len(out) == 1: - return wrap_layout(out[0], behavior, highlevel) - else: - return tuple(wrap_layout(x, behavior, highlevel) for x in out) + if return_value == "none": + return + elif expect_return_value and not transformer_did_terminate: + raise RuntimeError( + "the transformation function was expected to terminate by returning a Content, " + "or tuple of Contents, but instead only returned None." + ) + elif len(out) == 1: + return wrap_layout(out[0], behavior, highlevel) + else: + return tuple(wrap_layout(x, behavior, highlevel) for x in out) diff --git a/tests/test_2595_transform_termination.py b/tests/test_2595_transform_termination.py new file mode 100644 index 0000000000..5e2b7e9dc4 --- /dev/null +++ b/tests/test_2595_transform_termination.py @@ -0,0 +1,73 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE + +import pytest + +import awkward as ak + + +def test_single_no_termination(): + def transform(layout, **kwargs): + pass + + with pytest.raises(RuntimeError, match=r"expected to terminate"): + ak.transform(transform, [{"x": 1}], expect_return_value=True) + + result = ak.transform(transform, [{"x": 1}], expect_return_value=False) + assert result.to_list() == [{"x": 1}] + + result = ak.transform( + transform, [{"x": 1}], expect_return_value=True, return_value="none" + ) + assert result is None + + +def test_single_termination(): + def transform(layout, **kwargs): + if layout.is_numpy: + return ak.contents.NumpyArray(layout.data * 2) + + result = ak.transform(transform, [{"x": 1}], expect_return_value=True) + assert result.to_list() == [{"x": 2}] + + result = ak.transform(transform, [{"x": 1}], expect_return_value=False) + assert result.to_list() == [{"x": 2}] + + ak.transform(transform, [{"x": 1}], expect_return_value=True, return_value="none") + assert result.to_list() == [{"x": 2}] + + +def test_many_no_termination(): + def transform(layout, **kwargs): + pass + + with pytest.raises(RuntimeError, match=r"expected to terminate"): + ak.transform(transform, [{"x": 1}], [2], expect_return_value=True) + + result = ak.transform(transform, [{"x": 1}], [2], expect_return_value=False) + assert result[0].to_list() == [{"x": 1}] + assert result[1].to_list() == [{"x": 2}] + + result = ak.transform( + transform, [{"x": 1}], [2], expect_return_value=True, return_value="none" + ) + assert result is None + + +def test_many_termination(): + def transform(inputs, **kwargs): + if all(layout.is_numpy for layout in inputs): + return tuple([ak.contents.NumpyArray(layout.data * 2) for layout in inputs]) + + result = ak.transform(transform, [{"x": 1}], [2], expect_return_value=True) + assert result[0].to_list() == [{"x": 2}] + assert result[1].to_list() == [{"x": 4}] + + result = ak.transform(transform, [{"x": 1}], [2], expect_return_value=False) + assert result[0].to_list() == [{"x": 2}] + assert result[1].to_list() == [{"x": 4}] + + ak.transform( + transform, [{"x": 1}], [2], expect_return_value=True, return_value="none" + ) + assert result[0].to_list() == [{"x": 2}] + assert result[1].to_list() == [{"x": 4}]