Skip to content

Commit

Permalink
add ak.std & ak.var; fix bug in indexing where == comparisons against…
Browse files Browse the repository at this point in the history
… Ellipsis failed
  • Loading branch information
pfackeldey committed Sep 20, 2024
1 parent 00669a3 commit bb97999
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 20 deletions.
16 changes: 11 additions & 5 deletions src/awkward/contents/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,15 @@ def _getitem(self, where, named_axis: Type[NamedAxis] = NamedAxis):
if len(where) == 0:
return self

n_ellipsis = where.count(...)
# count number of ellipsis
# fun fact:
# where.count(Ellipsis) does not work, because it will do a == comparison against Ellipsis,
# but luckily we can use the fact that Ellipsis is a singleton and use the 'is' operator
n_ellipsis = 0
for w in where:
if w is ...:
n_ellipsis += 1

if n_ellipsis > 1:
raise IndexError("an index can only have a single ellipsis ('...')")

Expand All @@ -574,12 +582,10 @@ def _getitem(self, where, named_axis: Type[NamedAxis] = NamedAxis):
_nextwhere = tuple(nextwhere)
if n_ellipsis == 1:
# collect the ellipsis index
# fun fact:
# nextwhere.index(...) does not work, because it will do a == comparison and that fails with typetracers that compare against ...,
# but luckily we can use the fact that ... is a singleton and use the 'is' operator
# same fun fact as above for `nextwhere.index(...)`
(ellipsis_at,) = tuple(i for i, x in enumerate(nextwhere) if x is ...)
# calculate how many slice(None) we need to add
# same fun fact as above...
# same fun fact as above for `nextwhere.count(None)`
n_newaxis = 0
for x in nextwhere:
if x is np.newaxis or x is None:
Expand Down
10 changes: 9 additions & 1 deletion src/awkward/operations/ak_std.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import awkward as ak
from awkward._attrs import attrs_of_obj
from awkward._connect.numpy import UNSUPPORTED
from awkward._dispatch import high_level_function
from awkward._layout import (
Expand All @@ -15,6 +16,7 @@
_get_named_axis,
_is_valid_named_axis,
_named_axis_to_positional_axis,
_NamedAxisKey,
)
from awkward._nplikes import ufuncs
from awkward._nplikes.numpy_like import NumpyMetadata
Expand Down Expand Up @@ -229,7 +231,13 @@ def _impl(x, weight, ddof, axis, keepdims, mask_identity, highlevel, behavior, a
posaxis = maybe_posaxis(out.layout, axis, 1)
out = out[(slice(None, None),) * posaxis + (0,)]

# TODO: propagate named axis once slicing is implemented!
# propagate named axis to output
if out_named_axis := _get_named_axis(attrs_of_obj(out) or {}):
ctx = ctx.with_attr(
key=_NamedAxisKey,
value=out_named_axis,
)

return ctx.wrap(
maybe_highlevel_to_lowlevel(out),
highlevel=highlevel,
Expand Down
23 changes: 20 additions & 3 deletions src/awkward/operations/ak_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import awkward as ak
from awkward._attrs import attrs_of_obj
from awkward._connect.numpy import UNSUPPORTED
from awkward._dispatch import high_level_function
from awkward._layout import (
Expand All @@ -11,8 +12,13 @@
maybe_highlevel_to_lowlevel,
maybe_posaxis,
)
from awkward._namedaxis import (
_get_named_axis,
_is_valid_named_axis,
_named_axis_to_positional_axis,
_NamedAxisKey,
)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis

__all__ = ("var", "nanvar")

Expand Down Expand Up @@ -170,8 +176,6 @@ def nanvar(


def _impl(x, weight, ddof, axis, keepdims, mask_identity, highlevel, behavior, attrs):
axis = regularize_axis(axis)

with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
x_layout, weight_layout = ensure_same_backend(
ctx.unwrap(x, allow_record=False, primitive_policy="error"),
Expand All @@ -187,6 +191,12 @@ def _impl(x, weight, ddof, axis, keepdims, mask_identity, highlevel, behavior, a
x = ctx.wrap(x_layout)
weight = ctx.wrap(weight_layout, allow_other=True)

# Handle named axis
if named_axis := _get_named_axis(ctx):
if _is_valid_named_axis(axis):
# Step 1: Normalize named axis to positional axis
axis = _named_axis_to_positional_axis(named_axis, axis)

with np.errstate(invalid="ignore", divide="ignore"):
if weight is None:
sumw = ak.operations.ak_count._impl(
Expand Down Expand Up @@ -267,6 +277,13 @@ def _impl(x, weight, ddof, axis, keepdims, mask_identity, highlevel, behavior, a
posaxis = maybe_posaxis(out.layout, axis, 1)
out = out[(slice(None, None),) * posaxis + (0,)]

# propagate named axis to output
if out_named_axis := _get_named_axis(attrs_of_obj(out) or {}):
ctx = ctx.with_attr(
key=_NamedAxisKey,
value=out_named_axis,
)

return ctx.wrap(
maybe_highlevel_to_lowlevel(out),
highlevel=highlevel,
Expand Down
30 changes: 19 additions & 11 deletions tests/test_2596_named_axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -1212,19 +1212,17 @@ def test_named_axis_ak_sort():


def test_named_axis_ak_std():
# TODO: once slicing is implemented
# array = ak.Array([[1, 2], [3], [4, 5, 6]])
array = ak.Array([[1, 2], [3], [4, 5, 6]])

# named_array = ak.with_named_axis(array, ("x", "y"))
named_array = ak.with_named_axis(array, ("x", "y"))

# assert ak.all(ak.std(array, axis=0) == ak.std(named_array, axis="x"))
# assert ak.all(ak.std(array, axis=1) == ak.std(named_array, axis="y"))
# assert ak.std(array, axis=None) == ak.std(named_array, axis=None)
assert ak.all(ak.std(array, axis=0) == ak.std(named_array, axis="x"))
assert ak.all(ak.std(array, axis=1) == ak.std(named_array, axis="y"))
assert ak.std(array, axis=None) == ak.std(named_array, axis=None)

# assert ak.std(named_array, axis="x").named_axis == ("y",)
# assert ak.std(named_array, axis="y").named_axis == ("x",)
# assert ak.std(named_array, axis=None).named_axis == (None,)
assert True
assert ak.std(named_array, axis="x").named_axis == {"y": 0}
assert ak.std(named_array, axis="y").named_axis == {"x": 0}
assert not _get_named_axis(ak.std(named_array, axis=None))


def test_named_axis_ak_strings_astype():
Expand Down Expand Up @@ -1405,7 +1403,17 @@ def test_named_axis_ak_values_astype():


def test_named_axis_ak_var():
assert True
array = ak.Array([[1, 2], [3], [4, 5, 6]])

named_array = ak.with_named_axis(array, ("x", "y"))

assert ak.all(ak.var(array, axis=0) == ak.var(named_array, axis="x"))
assert ak.all(ak.var(array, axis=1) == ak.var(named_array, axis="y"))
assert ak.var(array, axis=None) == ak.var(named_array, axis=None)

assert ak.var(named_array, axis="x").named_axis == {"y": 0}
assert ak.var(named_array, axis="y").named_axis == {"x": 0}
assert not _get_named_axis(ak.var(named_array, axis=None))


def test_named_axis_ak_where():
Expand Down

0 comments on commit bb97999

Please sign in to comment.