diff --git a/src/dask_awkward/layers/layers.py b/src/dask_awkward/layers/layers.py index c33a54a6..ce177355 100644 --- a/src/dask_awkward/layers/layers.py +++ b/src/dask_awkward/layers/layers.py @@ -4,6 +4,7 @@ from collections.abc import Callable, Mapping from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeVar, Union, cast +import awkward as ak from dask.blockwise import Blockwise, BlockwiseDepDict, blockwise_token from dask.highlevelgraph import MaterializedLayer from dask.layers import DataFrameTreeReduction @@ -107,20 +108,30 @@ def mock_empty(self, backend: BackendT = "cpu") -> AwkwardArray: ) +def io_func_empty_on_error_wrapped(func: ImplementsIOFunction) -> bool: + return hasattr(func, "_empty_on_error_wrapped") + + +def maybe_unwrap(func: Callable) -> Callable: + if io_func_empty_on_error_wrapped(func): + return func._empty_on_error_wrapped # type: ignore + return func + + def io_func_implements_projection(func: ImplementsIOFunction) -> bool: - return hasattr(func, "prepare_for_projection") + return hasattr(maybe_unwrap(func), "prepare_for_projection") def io_func_implements_mocking(func: ImplementsIOFunction) -> bool: - return hasattr(func, "mock") + return hasattr(maybe_unwrap(func), "mock") def io_func_implements_mock_empty(func: ImplementsIOFunction) -> bool: - return hasattr(func, "mock_empty") + return hasattr(maybe_unwrap(func), "mock_empty") def io_func_implements_columnar(func: ImplementsIOFunction) -> bool: - return hasattr(func, "necessary_columns") + return hasattr(maybe_unwrap(func), "necessary_columns") class AwkwardInputLayer(AwkwardBlockwiseLayer): @@ -184,10 +195,12 @@ def is_columnar(self) -> bool: def mock(self) -> AwkwardInputLayer: assert self.is_mockable + fn = maybe_unwrap(self.io_func) + return AwkwardInputLayer( name=self.name, inputs=[None][: int(list(self.numblocks.values())[0][0])], - io_func=lambda *_, **__: cast(ImplementsMocking, self.io_func).mock(), + io_func=lambda *_, **__: cast(ImplementsMocking, fn).mock(), label=self.label, produces_tasks=self.produces_tasks, creation_info=self.creation_info, @@ -225,14 +238,20 @@ def prepare_for_projection(self) -> tuple[AwkwardInputLayer, TypeTracerReport, T The black-box state object returned by the IO function. """ assert self.is_projectable + fn = maybe_unwrap(self.io_func) new_meta_array, report, state = cast( - ImplementsProjection, self.io_func + ImplementsProjection, fn ).prepare_for_projection() + if io_func_empty_on_error_wrapped(self.io_func): + new_return = (new_meta_array, ak.from_iter([])) + else: + new_return = new_meta_array + new_input_layer = AwkwardInputLayer( name=self.name, inputs=[None][: int(list(self.numblocks.values())[0][0])], - io_func=lambda *_, **__: new_meta_array, + io_func=lambda *_, **__: new_return, label=self.label, produces_tasks=self.produces_tasks, creation_info=self.creation_info, @@ -246,12 +265,16 @@ def project( state: T, ) -> AwkwardInputLayer: assert self.is_projectable + fn = maybe_unwrap(self.io_func) + io_func = cast(ImplementsProjection, fn).project(report=report, state=state) + + if io_func_empty_on_error_wrapped(self.io_func): + io_func = self.io_func.recreate(io_func) # type: ignore + return AwkwardInputLayer( name=self.name, inputs=self.inputs, - io_func=cast(ImplementsProjection, self.io_func).project( - report=report, state=state - ), + io_func=io_func, label=self.label, produces_tasks=self.produces_tasks, creation_info=self.creation_info, @@ -260,7 +283,8 @@ def project( def necessary_columns(self, report: TypeTracerReport, state: T) -> frozenset[str]: assert self.is_columnar - return cast(ImplementsNecessaryColumns, self.io_func).necessary_columns( + fn = maybe_unwrap(self.io_func) + return cast(ImplementsNecessaryColumns, fn).necessary_columns( report=report, state=state ) diff --git a/src/dask_awkward/lib/io/io.py b/src/dask_awkward/lib/io/io.py index 246de43e..46d09fca 100644 --- a/src/dask_awkward/lib/io/io.py +++ b/src/dask_awkward/lib/io/io.py @@ -1,14 +1,16 @@ from __future__ import annotations -import functools import logging import math from collections.abc import Callable, Iterable, Mapping from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, cast, overload import awkward as ak import numpy as np +from awkward.forms.listoffsetform import ListOffsetForm +from awkward.forms.numpyform import NumpyForm +from awkward.forms.recordform import RecordForm from awkward.types.numpytype import primitive_to_dtype from awkward.typetracer import length_zero_if_typetracer from dask.base import flatten, tokenize @@ -32,6 +34,7 @@ new_array_object, typetracer_array, ) +from dask_awkward.utils import first, second if TYPE_CHECKING: from dask.array.core import Array as DaskArray @@ -496,29 +499,147 @@ def __call__(self, packed_arg): ) +_default_failure_array_form = RecordForm( + [ + ListOffsetForm( + "i64", + ListOffsetForm( + "i64", + NumpyForm("uint8", parameters={"__array__": "char"}), + parameters={"__array__": "string"}, + ), + ), + ListOffsetForm( + "i64", + ListOffsetForm( + "i64", + ListOffsetForm( + "i64", + NumpyForm("uint8", parameters={"__array__": "char"}), + parameters={"__array__": "string"}, + ), + ), + ), + ListOffsetForm( + "i64", + NumpyForm("uint8", parameters={"__array__": "char"}), + parameters={"__array__": "string"}, + ), + ListOffsetForm( + "i64", + NumpyForm("uint8", parameters={"__array__": "char"}), + parameters={"__array__": "string"}, + ), + ], + ["args", "kwargs", "exception", "message"], +) + + +def on_success_default(*args: Any, **kwargs: Any) -> ak.Array: + return ak.Array(_default_failure_array_form.length_one_array(highlevel=False)) + + +def on_failure_default( + exception: type[BaseException], + *args: Any, + **kwargs: Any, +) -> ak.Array: + return ak.Array( + [ + { + "args": [repr(a) for a in args], + "kwargs": [[k, repr(v)] for k, v in kwargs.items()], + "exception": type(exception).__name__, + "message": str(exception), + }, + ], + ) + + +class ReturnEmptyOnRaise: + def __init__( + self, + fn: Callable[..., ak.Array], + allowed_exceptions: tuple[type[BaseException], ...], + backend: BackendT, + on_success: Callable[..., ak.Array], + on_failure: Callable[..., ak.Array], + ): + self._empty_on_error_wrapped = fn + self.fn = fn + self.allowed_exceptions = allowed_exceptions + self.backend = backend + self.on_success = on_success + self.on_failure = on_failure + + def recreate(self, fn): + return return_empty_on_raise( + fn, + self.allowed_exceptions, + self.backend, + self.on_success, + self.on_failure, + ) + + def __call__(self, *args, **kwargs): + try: + result = self.fn(*args, **kwargs) + return result, self.on_success(*args, **kwargs) + except self.allowed_exceptions as err: + result = self.fn.mock_empty(self.backend) + return result, self.on_failure(err, *args, **kwargs) + + def return_empty_on_raise( - fn: Callable, + fn: Callable[..., ak.Array], allowed_exceptions: tuple[type[BaseException], ...], backend: BackendT, -) -> Callable: - @functools.wraps(fn) - def wrapped(*args, **kwargs): - try: - return fn(*args, **kwargs) - except allowed_exceptions as err: - logmsg = ( - "%s call failed with args %s and kwargs %s; empty array returned. %s" - % ( - str(fn), - str(args), - str(kwargs), - str(err), - ) - ) - logger.info(logmsg) - return fn.mock_empty(backend) + on_success: Callable[..., ak.Array], + on_failure: Callable[..., ak.Array], +) -> ReturnEmptyOnRaise: + return ReturnEmptyOnRaise( + fn, + allowed_exceptions, + backend, + on_success, + on_failure, + ) + + +@overload +def from_map( + func: Callable, + *iterables: Iterable, + args: tuple[Any, ...] | None = None, + label: str | None = None, + token: str | None = None, + divisions: tuple[int, ...] | tuple[None, ...] | None = None, + meta: ak.Array | None = None, + empty_on_raise: None = None, + empty_backend: None = None, + on_success: Callable[..., ak.Array] = on_success_default, + on_failure: Callable[..., ak.Array] = on_failure_default, + **kwargs: Any, +) -> Array: + ... - return wrapped + +@overload +def from_map( + func: Callable, + *iterables: Iterable, + empty_on_raise: tuple[type[BaseException], ...], + empty_backend: BackendT, + args: tuple[Any, ...] | None = None, + label: str | None = None, + token: str | None = None, + divisions: tuple[int, ...] | tuple[None, ...] | None = None, + meta: ak.Array | None = None, + on_success: Callable[..., ak.Array] = on_success_default, + on_failure: Callable[..., ak.Array] = on_failure_default, + **kwargs: Any, +) -> tuple[Array, Array]: + ... def from_map( @@ -531,8 +652,10 @@ def from_map( meta: ak.Array | None = None, empty_on_raise: tuple[type[BaseException], ...] | None = None, empty_backend: BackendT | None = None, + on_success: Callable[..., ak.Array] = on_success_default, + on_failure: Callable[..., ak.Array] = on_failure_default, **kwargs: Any, -) -> Array: +) -> Array | tuple[Array, Array]: """Create an Array collection from a custom mapping. Parameters @@ -654,6 +777,8 @@ def from_map( io_func, allowed_exceptions=empty_on_raise, backend=empty_backend, + on_success=on_success, + on_failure=on_failure, ) dsk = AwkwardInputLayer(name=name, inputs=inputs, io_func=io_func) @@ -664,6 +789,11 @@ def from_map( else: result = new_array_object(hlg, name, meta=array_meta, npartitions=len(inputs)) + if empty_on_raise and empty_backend: + res = result.map_partitions(first, meta=array_meta, output_divisions=1) + rep = result.map_partitions(second, meta=empty_typetracer()) + return res, rep + return result diff --git a/src/dask_awkward/utils.py b/src/dask_awkward/utils.py index 19bb69ce..750ada95 100644 --- a/src/dask_awkward/utils.py +++ b/src/dask_awkward/utils.py @@ -143,3 +143,9 @@ def first(seq: Iterable[T]) -> T: """ return next(iter(seq)) + + +def second(seq: Iterable[T]) -> T: + the_iter = iter(seq) + next(the_iter) + return next(the_iter) diff --git a/tests/test_io.py b/tests/test_io.py index 580ad5c1..d7c69d5f 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -3,6 +3,7 @@ from pathlib import Path import awkward as ak +import dask import numpy as np import pytest from dask.array.utils import assert_eq as da_assert_eq @@ -344,7 +345,7 @@ def test_bytes_with_sample( assert len(sample_bytes) == 127 -def test_random_fail_from_lists(): +def test_from_map_random_fail_from_lists(): from dask_awkward.lib.testutils import RandomFailFromListsFn single = [[1, 2, 3], [4, 5], [6], [], [1, 2, 3]] @@ -352,7 +353,7 @@ def test_random_fail_from_lists(): divs = (0, *np.cumsum(list(map(len, many)))) form = ak.Array(many[0]).layout.form - array = from_map( + array, report = from_map( RandomFailFromListsFn(form), many, meta=typetracer_array(ak.Array(many[0])), @@ -363,8 +364,18 @@ def test_random_fail_from_lists(): ) assert len(array.compute()) < (len(single) * len(many)) + computed_report = report.compute() + + # we expect the 'args' field in the report to be empty if the + # from_map node succeded; so we use ak.num(..., axis=1) to filter + # those out. + succ = ak.num(computed_report["args"], axis=1) == 0 + fail = np.invert(succ) + assert len(computed_report[succ]) < len(computed_report) + assert ak.all(computed_report[fail].exception == "OSError") + with pytest.raises(OSError, match="BAD"): - array = from_map( + array, report = from_map( RandomFailFromListsFn(form), many, meta=typetracer_array(ak.Array(many[0])), @@ -386,7 +397,7 @@ def test_random_fail_from_lists(): array.compute() with pytest.raises(ValueError, match="must be used together"): - array = from_map( + from_map( RandomFailFromListsFn(form), many, meta=typetracer_array(ak.Array(many[0])), @@ -396,7 +407,7 @@ def test_random_fail_from_lists(): ) with pytest.raises(ValueError, match="must be used together"): - array = from_map( + from_map( RandomFailFromListsFn(form), many, meta=typetracer_array(ak.Array(many[0])), @@ -416,7 +427,7 @@ def __call__(self, *args): return self.x * args[0] with pytest.raises(ValueError, match="must implement"): - array = from_map( + from_map( NoMockEmpty(5), many, meta=typetracer_array(ak.Array(many[0])), @@ -425,3 +436,40 @@ def __call__(self, *args): empty_on_raise=(RuntimeError,), empty_backend="cpu", ) + + +def test_from_map_fail_with_callbacks(): + from dask_awkward.lib.testutils import RandomFailFromListsFn + + single = [[1, 2, 3], [4, 5], [6], [], [1, 2, 3]] + many = [single] * 30 + divs = (0, *np.cumsum(list(map(len, many)))) + form = ak.Array(many[0]).layout.form + + def succ(*args, **kwargs): + return ak.Array(["1"]) + + def fail(excep, *args, **kwargs): + return ak.Array([type(excep).__name__ * 2]) + + array, report = from_map( + RandomFailFromListsFn(form), + many, + meta=typetracer_array(ak.Array(many[0])), + divisions=divs, + label="from-lists", + empty_on_raise=(OSError,), + empty_backend="cpu", + on_failure=fail, + on_success=succ, + ) + + _, rep = dask.compute(array, report) + + assert "OSErrorOSError" in rep.tolist() + + nfail = len(rep[rep == "OSErrorOSError"]) + total = len(many) + # total number of successes should be total number of instances of + # "1" in the array + assert (total - nfail) == len(rep[rep == "1"])