Skip to content

Commit

Permalink
fix: restrict named_axis inferring to `ak.Arrays/ak.Records/ak.High…
Browse files Browse the repository at this point in the history
…LevelContexts` by default (#3304)

* fix: use None when ndims can't be inferred from a layout-like obj

* restrict _get_named_axis by default to awkward-arrays only

* fix boolean condition
  • Loading branch information
pfackeldey authored Nov 13, 2024
1 parent ab484af commit 13ec1d6
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 12 deletions.
6 changes: 3 additions & 3 deletions src/awkward/_broadcasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def apply_step(
named_axis = _add_named_axis(named_axis, depth, ndim)
depth_context[NAMED_AXIS_KEY][i] = (
_unify_named_axis(named_axis, seen_named_axis),
ndim + 1,
ndim + 1 if ndim is not None else ndim,
)
if o.is_leaf:
_export_named_axis_from_depth_to_lateral(
Expand Down Expand Up @@ -645,7 +645,7 @@ def broadcast_any_list():
# rightbroadcasting adds a new first(!) dimension as depth
depth_context[NAMED_AXIS_KEY][i] = (
_add_named_axis(named_axis, depth, ndim),
ndim + 1,
ndim + 1 if ndim is not None else ndim,
)
if x.is_leaf:
_export_named_axis_from_depth_to_lateral(
Expand Down Expand Up @@ -734,7 +734,7 @@ def broadcast_any_list():
# leftbroadcasting adds a new last dimension at depth + 1
depth_context[NAMED_AXIS_KEY][i] = (
_add_named_axis(named_axis, depth + 1, ndim),
ndim + 1,
ndim + 1 if ndim is not None else ndim,
)
if x.is_leaf:
_export_named_axis_from_depth_to_lateral(
Expand Down
13 changes: 9 additions & 4 deletions src/awkward/_namedaxis.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def _prettify(ax: AxisName) -> str:
return delimiter.join(items)


def _get_named_axis(ctx: tp.Any) -> AxisMapping:
def _get_named_axis(ctx: tp.Any, allow_any: bool = False) -> AxisMapping:
"""
Retrieves the named axis from the provided context.
Expand All @@ -103,9 +103,14 @@ def _get_named_axis(ctx: tp.Any) -> AxisMapping:
>>> _get_named_axis({"other_key": "other_value"})
{}
"""
if hasattr(ctx, "attrs"):
return _get_named_axis(ctx.attrs)
elif isinstance(ctx, tp.Mapping) and NAMED_AXIS_KEY in ctx:
from awkward._layout import HighLevelContext
from awkward.highlevel import Array, Record

if hasattr(ctx, "attrs") and (
isinstance(ctx, (HighLevelContext, Array, Record)) or allow_any
):
return _get_named_axis(ctx.attrs, allow_any=True)
elif allow_any and isinstance(ctx, tp.Mapping) and NAMED_AXIS_KEY in ctx:
return dict(ctx[NAMED_AXIS_KEY])
else:
return {}
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/operations/ak_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def _impl(x, weight, axis, keepdims, mask_identity, highlevel, behavior, attrs):
)
return ak.operations.ak_with_named_axis._impl(
wrapped,
named_axis=_get_named_axis(attrs_of_obj(out) or {}),
named_axis=_get_named_axis(attrs_of_obj(out), allow_any=True),
highlevel=highlevel,
behavior=None,
attrs=None,
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/operations/ak_moment.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def _impl(
# propagate named axis to output
return ak.operations.ak_with_named_axis._impl(
wrapped,
named_axis=_get_named_axis(attrs_of_obj(out)),
named_axis=_get_named_axis(attrs_of_obj(out), allow_any=True),
highlevel=highlevel,
behavior=None,
attrs=None,
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/operations/ak_ptp.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
# propagate named axis to output
return ak.operations.ak_with_named_axis._impl(
wrapped,
named_axis=_get_named_axis(attrs_of_obj(out)),
named_axis=_get_named_axis(attrs_of_obj(out), allow_any=True),
highlevel=highlevel,
behavior=None,
attrs=None,
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/operations/ak_std.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def _impl(x, weight, ddof, axis, keepdims, mask_identity, highlevel, behavior, a
# propagate named axis to output
return ak.operations.ak_with_named_axis._impl(
wrapped,
named_axis=_get_named_axis(attrs_of_obj(out)),
named_axis=_get_named_axis(attrs_of_obj(out), allow_any=True),
highlevel=highlevel,
behavior=None,
attrs=None,
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/operations/ak_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def _impl(x, weight, ddof, axis, keepdims, mask_identity, highlevel, behavior, a
# propagate named axis to output
return ak.operations.ak_with_named_axis._impl(
wrapped,
named_axis=_get_named_axis(attrs_of_obj(out)),
named_axis=_get_named_axis(attrs_of_obj(out), allow_any=True),
highlevel=highlevel,
behavior=None,
attrs=None,
Expand Down

0 comments on commit 13ec1d6

Please sign in to comment.