Skip to content

Commit

Permalink
add ak.(arg)combinations and ak.(arg)cartesian; make named axis compa…
Browse files Browse the repository at this point in the history
…tible with branched structures;fix regularize_axis in all highlevel ops
  • Loading branch information
pfackeldey committed Sep 23, 2024
1 parent bb97999 commit 8720a05
Show file tree
Hide file tree
Showing 40 changed files with 419 additions and 301 deletions.
13 changes: 8 additions & 5 deletions src/awkward/_regularize.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from collections.abc import Iterable, Sequence, Sized

from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._typing import AxisMaybeNone, SupportsInt
from awkward._typing import Any

np = NumpyMetadata.instance()

Expand Down Expand Up @@ -51,8 +51,11 @@ def is_non_string_like_sequence(obj) -> bool:
return not isinstance(obj, (str, bytes)) and isinstance(obj, Sequence)


def regularize_axis(axis: SupportsInt | None) -> AxisMaybeNone:
if axis is None:
return None
else:
def regularize_axis(axis: Any) -> int | Any:
"""
This function's purpose is to convert "0" to 0, "1" to 1, etc., but leave any other value as it is.
"""
try:
return int(axis)
except (TypeError, ValueError):
return axis
8 changes: 4 additions & 4 deletions src/awkward/highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,8 @@ def behavior(self, behavior):

@property
def positional_axis(self) -> tuple[int, ...]:
return _make_positional_axis_tuple(self.ndim)
(_, ndim) = self._layout.minmax_depth
return _make_positional_axis_tuple(ndim)

@property
def named_axis(self) -> AxisMapping:
Expand Down Expand Up @@ -1079,10 +1080,9 @@ def __getitem__(self, where):
"""
with ak._errors.SlicingErrorContext(self, where):
# Handle named axis
(_, ndim) = self._layout.minmax_depth
if named_axis := _get_named_axis(self):
where = _normalize_named_slice(
named_axis, where, self._layout.purelist_depth
)
where = _normalize_named_slice(named_axis, where, ndim)

NamedAxis.mapping = named_axis

Expand Down
24 changes: 11 additions & 13 deletions src/awkward/operations/ak_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")

axis = regularize_axis(axis)

# Handle named axis
out_named_axis = None
if named_axis := _get_named_axis(ctx):
Expand All @@ -92,11 +94,9 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
out_named_axis = _remove_named_axis(
named_axis=out_named_axis,
axis=axis,
total=layout.purelist_depth,
total=layout.minmax_depth[1],
)

axis = regularize_axis(axis)

if not is_integer(axis) and axis is not None:
raise TypeError(f"'axis' must be an integer or None by now, not {axis!r}")

Expand All @@ -117,16 +117,14 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
allow_other=True,
)

if out_named_axis:
# propagate named axis to output
return ak.operations.ak_with_named_axis._impl(
wrapped_out,
named_axis=out_named_axis,
highlevel=highlevel,
behavior=ctx.behavior,
attrs=ctx.attrs,
)
return wrapped_out
# propagate named axis to output
return ak.operations.ak_with_named_axis._impl(
wrapped_out,
named_axis=out_named_axis,
highlevel=highlevel,
behavior=ctx.behavior,
attrs=ctx.attrs,
)


@ak._connect.numpy.implements("all")
Expand Down
24 changes: 11 additions & 13 deletions src/awkward/operations/ak_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")

axis = regularize_axis(axis)

# Handle named axis
out_named_axis = None
if named_axis := _get_named_axis(ctx):
Expand All @@ -92,11 +94,9 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
out_named_axis = _remove_named_axis(
named_axis=out_named_axis,
axis=axis,
total=layout.purelist_depth,
total=layout.minmax_depth[1],
)

axis = regularize_axis(axis)

if not is_integer(axis) and axis is not None:
raise TypeError(f"'axis' must be an integer or None by now, not {axis!r}")

Expand All @@ -117,16 +117,14 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
allow_other=True,
)

if out_named_axis:
# propagate named axis to output
return ak.operations.ak_with_named_axis._impl(
wrapped_out,
named_axis=out_named_axis,
highlevel=highlevel,
behavior=ctx.behavior,
attrs=ctx.attrs,
)
return wrapped_out
# propagate named axis to output
return ak.operations.ak_with_named_axis._impl(
wrapped_out,
named_axis=out_named_axis,
highlevel=highlevel,
behavior=ctx.behavior,
attrs=ctx.attrs,
)


@ak._connect.numpy.implements("any")
Expand Down
24 changes: 11 additions & 13 deletions src/awkward/operations/ak_argmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")

axis = regularize_axis(axis)

# Handle named axis
out_named_axis = None
if named_axis := _get_named_axis(ctx):
Expand All @@ -157,11 +159,9 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
out_named_axis = _remove_named_axis(
named_axis=out_named_axis,
axis=axis,
total=layout.purelist_depth,
total=layout.minmax_depth[1],
)

axis = regularize_axis(axis)

if not is_integer(axis) and axis is not None:
raise TypeError(f"'axis' must be an integer or None by now, not {axis!r}")

Expand All @@ -184,16 +184,14 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
allow_other=True,
)

if out_named_axis:
# propagate named axis to output
return ak.operations.ak_with_named_axis._impl(
wrapped_out,
named_axis=out_named_axis,
highlevel=highlevel,
behavior=ctx.behavior,
attrs=ctx.attrs,
)
return wrapped_out
# propagate named axis to output
return ak.operations.ak_with_named_axis._impl(
wrapped_out,
named_axis=out_named_axis,
highlevel=highlevel,
behavior=ctx.behavior,
attrs=ctx.attrs,
)


@ak._connect.numpy.implements("argmax")
Expand Down
24 changes: 11 additions & 13 deletions src/awkward/operations/ak_argmin.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")

axis = regularize_axis(axis)

# Handle named axis
out_named_axis = None
if named_axis := _get_named_axis(ctx):
Expand All @@ -154,11 +156,9 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
out_named_axis = _remove_named_axis(
named_axis=out_named_axis,
axis=axis,
total=layout.purelist_depth,
total=layout.minmax_depth[1],
)

axis = regularize_axis(axis)

if not is_integer(axis) and axis is not None:
raise TypeError(f"'axis' must be an integer or None by now, not {axis!r}")

Expand All @@ -179,16 +179,14 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
allow_other=True,
)

if out_named_axis:
# propagate named axis to output
return ak.operations.ak_with_named_axis._impl(
wrapped_out,
named_axis=out_named_axis,
highlevel=highlevel,
behavior=ctx.behavior,
attrs=ctx.attrs,
)
return wrapped_out
# propagate named axis to output
return ak.operations.ak_with_named_axis._impl(
wrapped_out,
named_axis=out_named_axis,
highlevel=highlevel,
behavior=ctx.behavior,
attrs=ctx.attrs,
)


@ak._connect.numpy.implements("argmin")
Expand Down
4 changes: 2 additions & 2 deletions src/awkward/operations/ak_argsort.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ def _impl(array, axis, ascending, stable, highlevel, behavior, attrs):
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")

axis = regularize_axis(axis)

# Handle named axis
out_named_axis = None
if named_axis := _get_named_axis(ctx):
Expand All @@ -90,8 +92,6 @@ def _impl(array, axis, ascending, stable, highlevel, behavior, attrs):
# use strategy "keep all" (see: awkward._namedaxis)
out_named_axis = _keep_named_axis(named_axis, None)

axis = regularize_axis(axis)

if not is_integer(axis):
raise TypeError(f"'axis' must be an integer by now, not {axis!r}")

Expand Down
40 changes: 37 additions & 3 deletions src/awkward/operations/ak_cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,19 @@
from __future__ import annotations

from collections.abc import Mapping
from functools import reduce

import awkward as ak
from awkward._backends.numpy import NumpyBackend
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext, ensure_same_backend, maybe_posaxis
from awkward._namedaxis import (
_add_named_axis,
_get_named_axis,
_is_valid_named_axis,
_named_axis_to_positional_axis,
_unify_named_axis,
)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis
from awkward.errors import AxisError
Expand Down Expand Up @@ -214,8 +222,6 @@ def cartesian(


def _impl(arrays, axis, nested, parameters, with_name, highlevel, behavior, attrs):
axis = regularize_axis(axis)

with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
if isinstance(arrays, Mapping):
layouts = ensure_same_backend(
Expand All @@ -227,6 +233,11 @@ def _impl(arrays, axis, nested, parameters, with_name, highlevel, behavior, attr
fields = list(arrays.keys())
array_layouts = dict(zip(fields, layouts))

# propagate named axis from input to output,
# use strategy "unify" (see: awkward._namedaxis)
out_named_axis = reduce(
_unify_named_axis, map(_get_named_axis, arrays.values())
)
else:
layouts = array_layouts = ensure_same_backend(
*(
Expand All @@ -235,6 +246,17 @@ def _impl(arrays, axis, nested, parameters, with_name, highlevel, behavior, attr
)
)
fields = None
# propagate named axis from input to output,
# use strategy "unify" (see: awkward._namedaxis)
out_named_axis = reduce(_unify_named_axis, map(_get_named_axis, arrays))

axis = regularize_axis(axis)

# Handle named axis
if out_named_axis:
if _is_valid_named_axis(axis):
# Step 1: Normalize named axis to positional axis
axis = _named_axis_to_positional_axis(out_named_axis, axis)

if with_name is not None:
if parameters is None:
Expand Down Expand Up @@ -263,6 +285,7 @@ def _impl(arrays, axis, nested, parameters, with_name, highlevel, behavior, attr
if nested is None or nested is False:
nested = []
elif nested is True:
out_named_axis = _add_named_axis(out_named_axis, 0)
if fields is not None:
nested = list(fields)[:-1]
else:
Expand All @@ -288,6 +311,8 @@ def _impl(arrays, axis, nested, parameters, with_name, highlevel, behavior, attr
"the 'nested' parameter of cartesian must be integers in "
"[0, len(arrays) - 1) for an iterable of arrays"
)
for n in nested:
out_named_axis = _add_named_axis(out_named_axis, n)

backend = next((layout.backend for layout in layouts), cpu)
if posaxis == 0:
Expand Down Expand Up @@ -411,4 +436,13 @@ def apply_build_record(inputs, depth, **kwargs):
result, axis=axis_to_flatten, highlevel=False, behavior=behavior
)

return ctx.wrap(result, highlevel=highlevel)
wrapped_out = ctx.wrap(result, highlevel=highlevel)

# propagate named axis to output
return ak.operations.ak_with_named_axis._impl(
wrapped_out,
named_axis=out_named_axis,
highlevel=highlevel,
behavior=ctx.behavior,
attrs=ctx.attrs,
)
16 changes: 14 additions & 2 deletions src/awkward/operations/ak_combinations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
import awkward as ak
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
from awkward._namedaxis import (
_get_named_axis,
_is_valid_named_axis,
_named_axis_to_positional_axis,
)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis

Expand Down Expand Up @@ -214,17 +219,24 @@ def _impl(
behavior,
attrs,
):
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")

axis = regularize_axis(axis)

# Handle named axis
if named_axis := _get_named_axis(ctx):
if _is_valid_named_axis(axis):
# Step 1: Normalize named axis to positional axis
axis = _named_axis_to_positional_axis(named_axis, axis)

if with_name is None:
pass
elif parameters is None:
parameters = {"__record__": with_name}
else:
parameters = {**parameters, "__record__": with_name}

with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")
out = ak._do.combinations(
layout,
n,
Expand Down
Loading

0 comments on commit 8720a05

Please sign in to comment.