Skip to content

Commit

Permalink
add tests & proper support for negative named axes
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Oct 7, 2024
1 parent 13700d8 commit d3dc258
Show file tree
Hide file tree
Showing 8 changed files with 900 additions and 84 deletions.
67 changes: 47 additions & 20 deletions src/awkward/_namedaxis.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,34 @@ def _prettify(ax: AxisName) -> str:
)


def _neg2pos_axis(
axis: int,
total: int,
) -> int:
"""
Converts a negative axis index to a positive one.
This function takes a negative axis index and the total number of axes and returns the corresponding positive axis index.
If the input axis index is already positive, it is returned as is.
Args:
axis (int): The axis index to convert. Can be negative.
total (int): The total number of axes.
Returns:
int: The positive axis index corresponding to the input axis index.
Examples:
>>> _neg2pos_axis(-1, 3)
2
>>> _neg2pos_axis(1, 3)
1
"""
if axis < 0:
return total + axis
return axis


def _get_named_axis(
ctx: MaybeSupportsNamedAxis | AttrsNamedAxisMapping | tp.Mapping | tp.Any,
) -> AxisMapping:
Expand Down Expand Up @@ -365,6 +393,7 @@ def _keep_named_axis(
def _keep_named_axis_up_to(
named_axis: AxisMapping,
axis: int,
total: int,
) -> AxisMapping:
"""
Determines the new named axis after keeping all axes up to the specified axis. This function is useful when an operation
Expand All @@ -373,26 +402,20 @@ def _keep_named_axis_up_to(
Args:
named_axis (AxisMapping): The current named axis.
axis (int): The index of the axis up to which to keep.
total (int): The total number of axes.
Returns:
AxisMapping: The new named axis after keeping all axes up to the specified axis.
Examples:
>>> _keep_named_axis_up_to({"x": 0, "y": 2}, 0)
{"x": 0}
>>> _keep_named_axis_up_to({"x": 0, "y": 2}, 1)
{"x": 0}
>>> _keep_named_axis_up_to({"x": 0, "y": 2}, 2)
{"x": 0, "y": 2}
>>> _keep_named_axis_up_to({"x": 0, "y": -2}, 0)
>>> _keep_named_axis_up_to({"x": 0, "y": 1, "z": 2}, 1, 3)
{"x": 0, "y": 1}
>>> _keep_named_axis_up_to({"x": 0, "y": 1, "z": 2}, -1, 3)
{"x": 0, "y": 1, "z": 2}
>>> _keep_named_axis_up_to({"x": 0, "y": 1, "z": 2}, 0, 3)
{"x": 0}
>>> _keep_named_axis_up_to({"x": 0, "y": -2}, 1)
{"x": 0, "y": -2}
>>> _keep_named_axis_up_to({"x": 0, "y": -2}, 2)
{"x": 0, "y": -2}
"""
if axis < 0:
raise ValueError("The axis must be a positive integer.")
axis = _neg2pos_axis(axis, total)
out = {}
for k, v in named_axis.items():
if v >= 0 and v <= axis:
Expand Down Expand Up @@ -458,7 +481,11 @@ def _remove_named_axis(
total = len(named_axis)

# remove the specified axis
out = {ax: pos for ax, pos in named_axis.items() if pos != axis}
out = {
ax: pos
for ax, pos in named_axis.items()
if _neg2pos_axis(pos, total) != _neg2pos_axis(axis, total)
}

return _adjust_pos_axis(out, axis, total, direction=-1)

Expand Down Expand Up @@ -504,14 +531,14 @@ def _adjust(pos: int, axis: int, direction: int) -> int:
# -> change position by direction
if pos >= axis:
return pos + direction
# positive axis and position smaller than the removed/added (positive) axis, but greater than 0
# -> keep position
elif pos >= 0:
return pos
# positive axis and negative position
# -> change position by direction
elif pos < 0 and pos + total < axis:
return pos - direction
# positive axis and position smaller than the removed/added (positive) axis, but greater than 0
# -> keep position
else:
return _adjust(pos, axis - total, direction)
return pos
# negative axis
else:
# negative axis and position smaller than the removed/added (negative) axis
Expand All @@ -520,7 +547,7 @@ def _adjust(pos: int, axis: int, direction: int) -> int:
return pos - direction
# negative axis and positive position
# -> change position by inverse direction
elif pos > axis + total:
elif pos > 0 and pos > axis + total:
return pos + direction
# negative axis and position greater than the removed/added (negative) axis, but smaller than 0
# -> keep position
Expand Down
31 changes: 19 additions & 12 deletions src/awkward/contents/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)
from awkward._behavior import get_array_class, get_record_class
from awkward._kernels import KernelError
from awkward._layout import wrap_layout
from awkward._layout import maybe_posaxis, wrap_layout
from awkward._meta.meta import Meta
from awkward._namedaxis import (
NamedAxis,
Expand Down Expand Up @@ -591,7 +591,7 @@ def _getitem(self, where, named_axis: Type[NamedAxis] = NamedAxis):
for x in nextwhere:
if x is np.newaxis or x is None:
n_newaxis += 1
n_total = self.purelist_depth
n_total = self.minmax_depth[1]
n_slice_none = n_total - (len(_nextwhere) - n_newaxis - 1)
# expand `[...]` to `[:]*n_slice_none`
_nextwhere = (
Expand All @@ -602,26 +602,33 @@ def _getitem(self, where, named_axis: Type[NamedAxis] = NamedAxis):

# now propagate named axis
_named_axis = _keep_named_axis(named_axis.mapping, None)
iter_named_axis = iter(dict(_named_axis).items())
_adjust_dims = 0
_adjust_dim = 0
# this loop does the following:
# - remove a named axis for integer indices, e.g. `a[1, 2]`
# - add a named axis for None (or np.newaxis) indices, e.g. `a[..., None]`
# - keep named axis for any other index, e.g. `a[:]`, `a[0:1]`, or `a[a>0]`
# (these may only remove elements, but not dimensions)
for i, nw in enumerate(_nextwhere):
i += _adjust_dims
for _, pos in iter_named_axis:
if pos == i:
for dim, nw in enumerate(_nextwhere):
dim_adjusted = dim + _adjust_dim
total_adjusted = self.minmax_depth[1] + _adjust_dim
for _, pos in _named_axis.items():
if maybe_posaxis(self, pos, 0) == dim_adjusted:
break

if is_integer(nw) or (is_array_like(nw) and nw.ndim == 0):
_named_axis = _remove_named_axis(
_named_axis, i, self.purelist_depth
named_axis=_named_axis,
axis=dim_adjusted,
total=total_adjusted,
)
_adjust_dims -= 1
_adjust_dim -= 1
elif nw is None:
_named_axis = _add_named_axis(_named_axis, i, self.purelist_depth)
_adjust_dims += 1
_named_axis = _add_named_axis(
named_axis=_named_axis,
axis=dim_adjusted,
total=total_adjusted,
)
_adjust_dim += 1

# set propagated named axis
named_axis.mapping = _named_axis
Expand Down
6 changes: 3 additions & 3 deletions src/awkward/highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1109,12 +1109,12 @@ def __getitem__(self, where):
)

if NamedAxis.mapping:
out = ak.operations.ak_with_named_axis._impl(
return ak.operations.ak_with_named_axis._impl(
out,
named_axis=NamedAxis.mapping,
highlevel=True,
behavior=self._behavior,
attrs=self._attrs,
behavior=None,
attrs=None,
)
return out

Expand Down
5 changes: 3 additions & 2 deletions src/awkward/operations/ak_cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ def _impl(arrays, axis, nested, parameters, with_name, highlevel, behavior, attr
# Handle named axis
# Step 1: Normalize named axis to positional axis
axis = _named_axis_to_positional_axis(out_named_axis, axis)
max_ndim = max([layout.minmax_depth[1] for layout in layouts])

Check warning on line 259 in src/awkward/operations/ak_cartesian.py

View workflow job for this annotation

GitHub Actions / Run PyLint

Consider using a generator instead 'max(layout.minmax_depth[1] for layout in layouts)'

if with_name is not None:
if parameters is None:
Expand Down Expand Up @@ -284,7 +285,7 @@ def _impl(arrays, axis, nested, parameters, with_name, highlevel, behavior, attr
if nested is None or nested is False:
nested = []
elif nested is True:
out_named_axis = _add_named_axis(out_named_axis, 0)
out_named_axis = _add_named_axis(out_named_axis, 0, max_ndim)
if fields is not None:
nested = list(fields)[:-1]
else:
Expand All @@ -311,7 +312,7 @@ def _impl(arrays, axis, nested, parameters, with_name, highlevel, behavior, attr
"[0, len(arrays) - 1) for an iterable of arrays"
)
for n in nested:
out_named_axis = _add_named_axis(out_named_axis, n)
out_named_axis = _add_named_axis(out_named_axis, n, max_ndim)

backend = next((layout.backend for layout in layouts), cpu)
if posaxis == 0:
Expand Down
4 changes: 1 addition & 3 deletions src/awkward/operations/ak_is_none.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,7 @@ def _impl(array, axis, highlevel, behavior, attrs):

# Step 2: propagate named axis from input to output,
# use strategy "keep up to" (see: awkward._namedaxis)
out_named_axis = _keep_named_axis_up_to(
named_axis, axis + layout.minmax_depth[1] if axis < 0 else axis
)
out_named_axis = _keep_named_axis_up_to(named_axis, axis, layout.minmax_depth[1])

def action(layout, depth, backend, lateral_context, **kwargs):
posaxis = maybe_posaxis(layout, axis, depth)
Expand Down
4 changes: 1 addition & 3 deletions src/awkward/operations/ak_local_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,7 @@ def _impl(array, axis, highlevel, behavior, attrs):

# Step 2: propagate named axis from input to output,
# use strategy "keep up to" (see: awkward._namedaxis)
out_named_axis = _keep_named_axis_up_to(
named_axis, axis + layout.minmax_depth[1] if axis < 0 else axis
)
out_named_axis = _keep_named_axis_up_to(named_axis, axis, layout.minmax_depth[1])

out = ak._do.local_index(layout, axis)

Expand Down
4 changes: 3 additions & 1 deletion src/awkward/operations/ak_singletons.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def _impl(array, axis, highlevel, behavior, attrs):

# Step 2: propagate named axis from input to output,
# use strategy "add one" (see: awkward._namedaxis)
out_named_axis = _add_named_axis(named_axis, axis + 1)
out_named_axis = _add_named_axis(
named_axis, (axis + 1) if axis >= 0 else axis, layout.minmax_depth[1]
)

def action(layout, depth, backend, **kwargs):
posaxis = maybe_posaxis(layout, axis, depth)
Expand Down
Loading

0 comments on commit d3dc258

Please sign in to comment.