From 54b4d7444df35a7a349bba68c5d44cfdbac1830e Mon Sep 17 00:00:00 2001 From: Doug Davis Date: Wed, 29 Nov 2023 16:23:35 -0600 Subject: [PATCH] addressing some comments Angus --- src/dask_awkward/layers/layers.py | 15 +++-- src/dask_awkward/lib/io/io.py | 106 +++++++++++++++++++++--------- tests/test_io.py | 4 +- 3 files changed, 86 insertions(+), 39 deletions(-) diff --git a/src/dask_awkward/layers/layers.py b/src/dask_awkward/layers/layers.py index 5ab32095a..ce1773559 100644 --- a/src/dask_awkward/layers/layers.py +++ b/src/dask_awkward/layers/layers.py @@ -4,6 +4,7 @@ from collections.abc import Callable, Mapping from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeVar, Union, cast +import awkward as ak from dask.blockwise import Blockwise, BlockwiseDepDict, blockwise_token from dask.highlevelgraph import MaterializedLayer from dask.layers import DataFrameTreeReduction @@ -107,13 +108,13 @@ def mock_empty(self, backend: BackendT = "cpu") -> AwkwardArray: ) -def io_func_rer_wrapped(func: ImplementsIOFunction) -> bool: - return hasattr(func, "__rer_wrapped__") +def io_func_empty_on_error_wrapped(func: ImplementsIOFunction) -> bool: + return hasattr(func, "_empty_on_error_wrapped") def maybe_unwrap(func: Callable) -> Callable: - if io_func_rer_wrapped(func): - return func.__rer_wrapped__ # type: ignore + if io_func_empty_on_error_wrapped(func): + return func._empty_on_error_wrapped # type: ignore return func @@ -242,8 +243,8 @@ def prepare_for_projection(self) -> tuple[AwkwardInputLayer, TypeTracerReport, T ImplementsProjection, fn ).prepare_for_projection() - if io_func_rer_wrapped(self.io_func): - new_return = (new_meta_array, type(new_meta_array)([])) + if io_func_empty_on_error_wrapped(self.io_func): + new_return = (new_meta_array, ak.from_iter([])) else: new_return = new_meta_array @@ -267,7 +268,7 @@ def project( fn = maybe_unwrap(self.io_func) io_func = cast(ImplementsProjection, fn).project(report=report, state=state) - if io_func_rer_wrapped(self.io_func): + if io_func_empty_on_error_wrapped(self.io_func): io_func = self.io_func.recreate(io_func) # type: ignore return AwkwardInputLayer( diff --git a/src/dask_awkward/lib/io/io.py b/src/dask_awkward/lib/io/io.py index 543605e72..53f8fc852 100644 --- a/src/dask_awkward/lib/io/io.py +++ b/src/dask_awkward/lib/io/io.py @@ -8,6 +8,9 @@ import awkward as ak import numpy as np +from awkward.forms.listoffsetform import ListOffsetForm +from awkward.forms.numpyform import NumpyForm +from awkward.forms.recordform import RecordForm from awkward.types.numpytype import primitive_to_dtype from awkward.typetracer import length_zero_if_typetracer from dask.base import flatten, tokenize @@ -496,20 +499,47 @@ def __call__(self, packed_arg): ) -def default_report_success(*args: Any, **kwargs: Any) -> ak.Array: - return ak.Array( - [ - { - "args": [], - "kwargs": [], - "exception": "", - "message": "", - }, - ], - ) +_default_failure_array_form = RecordForm( + [ + ListOffsetForm( + "i64", + ListOffsetForm( + "i64", + NumpyForm("uint8", parameters={"__array__": "char"}), + parameters={"__array__": "string"}, + ), + ), + ListOffsetForm( + "i64", + ListOffsetForm( + "i64", + ListOffsetForm( + "i64", + NumpyForm("uint8", parameters={"__array__": "char"}), + parameters={"__array__": "string"}, + ), + ), + ), + ListOffsetForm( + "i64", + NumpyForm("uint8", parameters={"__array__": "char"}), + parameters={"__array__": "string"}, + ), + ListOffsetForm( + "i64", + NumpyForm("uint8", parameters={"__array__": "char"}), + parameters={"__array__": "string"}, + ), + ], + ["args", "kwargs", "exception", "message"], +) + +def on_success_default(*args: Any, **kwargs: Any) -> ak.Array: + return ak.Array(_default_failure_array_form.length_one_array(highlevel=False)) -def default_report_failure( + +def on_failure_default( exception: type[BaseException], *args: Any, **kwargs: Any, @@ -526,38 +556,54 @@ def default_report_failure( ) -class return_empty_on_raise: +class ReturnEmptyOnRaise: def __init__( self, fn: Callable[..., ak.Array], allowed_exceptions: tuple[type[BaseException], ...], backend: BackendT, - success_callback: Callable[..., ak.Array], - failure_callback: Callable[..., ak.Array], + on_success: Callable[..., ak.Array], + on_failure: Callable[..., ak.Array], ): - self.__rer_wrapped__ = fn + self._empty_on_error_wrapped = fn self.fn = fn self.allowed_exceptions = allowed_exceptions self.backend = backend - self.success_callback = success_callback - self.failure_callback = failure_callback + self.on_success = on_success + self.on_failure = on_failure def recreate(self, fn): return return_empty_on_raise( fn, self.allowed_exceptions, self.backend, - self.success_callback, - self.failure_callback, + self.on_success, + self.on_failure, ) def __call__(self, *args, **kwargs): try: result = self.fn(*args, **kwargs) - return result, self.success_callback(*args, **kwargs) + return result, self.on_success(*args, **kwargs) except self.allowed_exceptions as err: result = self.fn.mock_empty(self.backend) - return result, self.failure_callback(err, *args, **kwargs) + return result, self.on_failure(err, *args, **kwargs) + + +def return_empty_on_raise( + fn: Callable[..., ak.Array], + allowed_exceptions: tuple[type[BaseException], ...], + backend: BackendT, + on_success: Callable[..., ak.Array], + on_failure: Callable[..., ak.Array], +) -> ReturnEmptyOnRaise: + return ReturnEmptyOnRaise( + fn, + allowed_exceptions, + backend, + on_success, + on_failure, + ) @overload @@ -571,8 +617,8 @@ def from_map( meta: ak.Array | None = None, empty_on_raise: None = None, empty_backend: None = None, - empty_success_callback: Callable[..., ak.Array] = default_report_success, - empty_failure_callback: Callable[..., ak.Array] = default_report_failure, + on_success: Callable[..., ak.Array] = on_success_default, + on_failure: Callable[..., ak.Array] = on_failure_default, **kwargs: Any, ) -> Array: ... @@ -589,8 +635,8 @@ def from_map( token: str | None = None, divisions: tuple[int, ...] | tuple[None, ...] | None = None, meta: ak.Array | None = None, - empty_success_callback: Callable[..., ak.Array] = default_report_success, - empty_failure_callback: Callable[..., ak.Array] = default_report_failure, + on_success: Callable[..., ak.Array] = on_success_default, + on_failure: Callable[..., ak.Array] = on_failure_default, **kwargs: Any, ) -> tuple[Array, Array]: ... @@ -606,8 +652,8 @@ def from_map( meta: ak.Array | None = None, empty_on_raise: tuple[type[BaseException], ...] | None = None, empty_backend: BackendT | None = None, - empty_success_callback: Callable[..., ak.Array] = default_report_success, - empty_failure_callback: Callable[..., ak.Array] = default_report_failure, + on_success: Callable[..., ak.Array] = on_success_default, + on_failure: Callable[..., ak.Array] = on_failure_default, **kwargs: Any, ) -> Array | tuple[Array, Array]: """Create an Array collection from a custom mapping. @@ -731,8 +777,8 @@ def from_map( io_func, allowed_exceptions=empty_on_raise, backend=empty_backend, - success_callback=empty_success_callback, - failure_callback=empty_failure_callback, + on_success=on_success, + on_failure=on_failure, ) dsk = AwkwardInputLayer(name=name, inputs=inputs, io_func=io_func) diff --git a/tests/test_io.py b/tests/test_io.py index 11950dd17..d7c69d5f4 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -460,8 +460,8 @@ def fail(excep, *args, **kwargs): label="from-lists", empty_on_raise=(OSError,), empty_backend="cpu", - empty_failure_callback=fail, - empty_success_callback=succ, + on_failure=fail, + on_success=succ, ) _, rep = dask.compute(array, report)