Skip to content

Commit

Permalink
feat: add array dispatcher (inline) (#2531)
Browse files Browse the repository at this point in the history
  • Loading branch information
agoose77 authored Jun 22, 2023
1 parent efa1168 commit 9d27d04
Show file tree
Hide file tree
Showing 103 changed files with 733 additions and 340 deletions.
53 changes: 53 additions & 0 deletions src/awkward/_dispatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from collections.abc import Callable, Collection, Generator
from functools import wraps
from inspect import isgenerator

from awkward._errors import OperationErrorContext
from awkward._typing import Any, TypeAlias, TypeVar

T = TypeVar("T")
DispatcherType: TypeAlias = "Callable[..., Generator[Collection[Any], None, T]]"
HighLevelType: TypeAlias = "Callable[..., T]"


def high_level_function(func: DispatcherType) -> HighLevelType:
"""Decorate a high-level function such that it may be overloaded by third-party array objects"""

@wraps(func)
def dispatch(*args, **kwargs):
# NOTE: this decorator assumes that the operation is exposed under `ak.`
with OperationErrorContext(f"ak.{func.__qualname__}", args, kwargs):
gen_or_result = func(*args, **kwargs)
if isgenerator(gen_or_result):
array_likes = next(gen_or_result)
assert isinstance(array_likes, Collection)

# Permit a third-party array object to intercept the invocation
for array_like in array_likes:
try:
custom_impl = array_like.__awkward_function__
except AttributeError:
continue
else:
result = custom_impl(dispatch, array_likes, args, kwargs)

# Future proof the implementation by permitting the `__awkward_function__` to return `NotImplemented`
# This may later be used to signal that another overload should be used.
if result is NotImplemented:
raise NotImplementedError
else:
return result

# Failed to find a custom overload, so resume the original function
try:
next(gen_or_result)
except StopIteration as err:
return err.value
else:
raise AssertionError(
"high-level functions should only implement a single yield statement"
)

return gen_or_result

return dispatch
8 changes: 6 additions & 2 deletions src/awkward/operations/ak_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
import awkward as ak
from awkward._behavior import behavior_of
from awkward._connect.numpy import UNSUPPORTED
from awkward._errors import with_operation_context
from awkward._dispatch import high_level_function
from awkward._layout import wrap_layout
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._regularize import regularize_axis

np = NumpyMetadata.instance()


@with_operation_context
@high_level_function
def all(
array,
axis=None,
Expand Down Expand Up @@ -53,6 +53,10 @@ def all(
See #ak.sum for a more complete description of nested list and missing
value (None) handling in reducers.
"""
# Dispatch
yield (array,)

# Implementation
return _impl(array, axis, keepdims, mask_identity, highlevel, behavior)


Expand Down
10 changes: 7 additions & 3 deletions src/awkward/operations/ak_almost_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
__all__ = ("almost_equal",)
from awkward._backends.dispatch import backend_of
from awkward._behavior import behavior_of, get_array_class, get_record_class
from awkward._errors import with_operation_context
from awkward._dispatch import high_level_function
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._parameters import parameters_are_equal
from awkward.operations.ak_to_layout import to_layout

np = NumpyMetadata.instance()


@with_operation_context
@high_level_function
def almost_equal(
left,
right,
Expand All @@ -22,7 +22,7 @@ def almost_equal(
dtype_exact: bool = True,
check_parameters: bool = True,
check_regular: bool = True,
) -> bool:
):
"""
Args:
left: Array-like data (anything #ak.to_layout recognizes).
Expand All @@ -45,6 +45,10 @@ def almost_equal(
TypeTracer arrays are not supported, as there is very little information to
be compared.
"""
# Dispatch
yield left, right

# Implementation
left_behavior = behavior_of(left)
right_behavior = behavior_of(right)

Expand Down
8 changes: 6 additions & 2 deletions src/awkward/operations/ak_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
import awkward as ak
from awkward._behavior import behavior_of
from awkward._connect.numpy import UNSUPPORTED
from awkward._errors import with_operation_context
from awkward._dispatch import high_level_function
from awkward._layout import wrap_layout
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._regularize import regularize_axis

np = NumpyMetadata.instance()


@with_operation_context
@high_level_function
def any(
array,
axis=None,
Expand Down Expand Up @@ -53,6 +53,10 @@ def any(
See #ak.sum for a more complete description of nested list and missing
value (None) handling in reducers.
"""
# Dispatch
yield (array,)

# Implementation
return _impl(array, axis, keepdims, mask_identity, highlevel, behavior)


Expand Down
13 changes: 11 additions & 2 deletions src/awkward/operations/ak_argcartesian.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE
__all__ = ("argcartesian",)
from collections.abc import Mapping

import awkward as ak
from awkward._backends.dispatch import backend_of
from awkward._backends.numpy import NumpyBackend
from awkward._behavior import behavior_of
from awkward._errors import with_operation_context
from awkward._dispatch import high_level_function
from awkward._layout import wrap_layout
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._regularize import regularize_axis
Expand All @@ -13,7 +15,7 @@
cpu = NumpyBackend.instance()


@with_operation_context
@high_level_function
def argcartesian(
arrays,
axis=1,
Expand Down Expand Up @@ -90,6 +92,13 @@ def argcartesian(
All of the parameters for #ak.cartesian apply equally to #ak.argcartesian,
so see the #ak.cartesian documentation for a more complete description.
"""
# Dispatch
if isinstance(arrays, Mapping):
yield arrays.values()
else:
yield arrays

# Implementation
return _impl(arrays, axis, nested, parameters, with_name, highlevel, behavior)


Expand Down
8 changes: 6 additions & 2 deletions src/awkward/operations/ak_argcombinations.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE
__all__ = ("argcombinations",)
import awkward as ak
from awkward._errors import with_operation_context
from awkward._dispatch import high_level_function
from awkward._layout import wrap_layout
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._regularize import regularize_axis

np = NumpyMetadata.instance()


@with_operation_context
@high_level_function
def argcombinations(
array,
n,
Expand Down Expand Up @@ -56,6 +56,10 @@ def argcombinations(
#ak.argcartesian. See #ak.combinations and #ak.argcartesian for a more
complete description.
"""
# Dispatch
yield (array,)

# Implementation
return _impl(
array,
n,
Expand Down
14 changes: 11 additions & 3 deletions src/awkward/operations/ak_argmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
import awkward as ak
from awkward._behavior import behavior_of
from awkward._connect.numpy import UNSUPPORTED
from awkward._errors import with_operation_context
from awkward._dispatch import high_level_function
from awkward._layout import wrap_layout
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._regularize import regularize_axis

np = NumpyMetadata.instance()


@with_operation_context
@high_level_function
def argmax(
array,
axis=None,
Expand Down Expand Up @@ -60,10 +60,14 @@ def argmax(
See also #ak.nanargmax.
"""
# Dispatch
yield (array,)

# Implementation
return _impl(array, axis, keepdims, mask_identity, highlevel, behavior)


@with_operation_context
@high_level_function
def nanargmax(
array,
axis=None,
Expand Down Expand Up @@ -103,6 +107,10 @@ def nanargmax(
See also #ak.argmax.
"""
# Dispatch
yield (array,)

# Implementation
return _impl(
ak.operations.ak_nan_to_none._impl(array, False, None),
axis,
Expand Down
14 changes: 11 additions & 3 deletions src/awkward/operations/ak_argmin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
import awkward as ak
from awkward._behavior import behavior_of
from awkward._connect.numpy import UNSUPPORTED
from awkward._errors import with_operation_context
from awkward._dispatch import high_level_function
from awkward._layout import wrap_layout
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._regularize import regularize_axis

np = NumpyMetadata.instance()


@with_operation_context
@high_level_function
def argmin(
array,
axis=None,
Expand Down Expand Up @@ -60,10 +60,14 @@ def argmin(
See also #ak.nanargmin.
"""
# Dispatch
yield (array,)

# Implementation
return _impl(array, axis, keepdims, mask_identity, highlevel, behavior)


@with_operation_context
@high_level_function
def nanargmin(
array,
axis=None,
Expand Down Expand Up @@ -102,6 +106,10 @@ def nanargmin(
See also #ak.argmin.
"""
# Dispatch
yield (array,)

# Implementation
return _impl(
ak.operations.ak_nan_to_none._impl(array, False, None),
axis,
Expand Down
8 changes: 6 additions & 2 deletions src/awkward/operations/ak_argsort.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
__all__ = ("argsort",)
import awkward as ak
from awkward._connect.numpy import UNSUPPORTED
from awkward._errors import with_operation_context
from awkward._dispatch import high_level_function
from awkward._layout import wrap_layout
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._regularize import regularize_axis

np = NumpyMetadata.instance()


@with_operation_context
@high_level_function
def argsort(
array, axis=-1, *, ascending=True, stable=True, highlevel=True, behavior=None
):
Expand Down Expand Up @@ -49,6 +49,10 @@ def argsort(
>>> data[index]
<Array [[5, 7, 7], [], [2], [2, 8]] type='4 * var * int64'>
"""
# Dispatch
yield (array,)

# Implementation
return _impl(array, axis, ascending, stable, highlevel, behavior)


Expand Down
10 changes: 7 additions & 3 deletions src/awkward/operations/ak_backend.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE
__all__ = ("backend",)
from awkward._backends.dispatch import backend_of
from awkward._errors import with_operation_context
from awkward._dispatch import high_level_function


@with_operation_context
def backend(*arrays) -> str:
@high_level_function
def backend(*arrays):
"""
Args:
arrays: Array-like data (anything #ak.to_layout recognizes).
Expand All @@ -21,6 +21,10 @@ def backend(*arrays) -> str:
See #ak.to_backend.
"""
# Dispatch
yield arrays

# Implementation
return _impl(arrays)


Expand Down
8 changes: 6 additions & 2 deletions src/awkward/operations/ak_broadcast_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
from awkward._backends.numpy import NumpyBackend
from awkward._behavior import behavior_of
from awkward._connect.numpy import UNSUPPORTED
from awkward._errors import with_operation_context
from awkward._dispatch import high_level_function
from awkward._layout import wrap_layout
from awkward._nplikes.numpylike import NumpyMetadata

np = NumpyMetadata.instance()
cpu = NumpyBackend.instance()


@with_operation_context
@high_level_function
def broadcast_arrays(
*arrays,
depth_limit=None,
Expand Down Expand Up @@ -178,6 +178,10 @@ def broadcast_arrays(
[],
[[6.6]]]
"""
# Dispatch
yield arrays

# Implementation
return _impl(
arrays,
depth_limit,
Expand Down
Loading

0 comments on commit 9d27d04

Please sign in to comment.