Skip to content

Commit

Permalink
named_axis: simplify type hint for named axis in attrs mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Oct 9, 2024
1 parent df924c0 commit 595b240
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 29 deletions.
3 changes: 1 addition & 2 deletions src/awkward/_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
T = TypeVar("T")

if TYPE_CHECKING:
from awkward._namedaxis import AttrsNamedAxisMapping
from awkward.highlevel import Array
from awkward.highlevel import Record as HighLevelRecord

Expand Down Expand Up @@ -96,7 +95,7 @@ def _ensure_not_finalized(self):
raise RuntimeError("HighLevelContext has already been finalized")

@property
def attrs(self) -> Mapping | AttrsNamedAxisMapping:
def attrs(self) -> Mapping:
self._ensure_finalized()
if self._attrs is None:
self._attrs = {}
Expand Down
34 changes: 9 additions & 25 deletions src/awkward/_namedaxis.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,6 @@
)


class AttrsNamedAxisMapping(tp.TypedDict, total=False):
__named_axis__: AxisMapping


@tp.runtime_checkable
class MaybeSupportsNamedAxis(tp.Protocol):
@property
def attrs(self) -> tp.Mapping | AttrsNamedAxisMapping: ...


# just a class for inplace mutation
class NamedAxis:
mapping: AxisMapping
Expand Down Expand Up @@ -92,37 +82,31 @@ def _prettify(ax: AxisName) -> str:
return delimiter.join(items)


def _get_named_axis(
ctx: MaybeSupportsNamedAxis | AttrsNamedAxisMapping | tp.Mapping | tp.Any,
) -> AxisMapping:
def _get_named_axis(ctx: tp.Any) -> AxisMapping:
"""
Retrieves the named axis from the provided context. The context can either be an object that supports named axis
or a dictionary that includes a named axis mapping.
Retrieves the named axis from the provided context.
Args:
ctx (MaybeSupportsNamedAxis | AttrsNamedAxisMapping | Mapping | Any): The context from which the named axis is to be retrieved.
ctx (Any): The context from which the named axis is to be retrieved.
Returns:
AxisMapping: The named axis retrieved from the context. If the context does not include a named axis,
an empty dictionary is returned.
Examples:
>>> class Test(MaybeSupportsNamedAxis):
... @property
... def attrs(self):
... return {NAMED_AXIS_KEY: {"x": 0, "y": 1, "z": 2}}
...
>>> _get_named_axis(Test())
{"x": 0, "y": 1, "z": 2}
>>> _get_named_axis(ak.Array([1, 2, 3], named_axis={"x": 0}))
{"x": 0}
>>> _get_named_axis(np.array([1, 2, 3]))
{}
>>> _get_named_axis({NAMED_AXIS_KEY: {"x": 0, "y": 1, "z": 2}})
{"x": 0, "y": 1, "z": 2}
>>> _get_named_axis({"other_key": "other_value"})
{}
"""
if isinstance(ctx, MaybeSupportsNamedAxis):
if hasattr(ctx, "attrs"):
return _get_named_axis(ctx.attrs)
elif isinstance(ctx, tp.Mapping) and NAMED_AXIS_KEY in ctx:
return ctx[NAMED_AXIS_KEY]
return dict(ctx[NAMED_AXIS_KEY])
else:
return {}

Expand Down
3 changes: 1 addition & 2 deletions src/awkward/highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from awkward._layout import wrap_layout
from awkward._namedaxis import (
NAMED_AXIS_KEY,
AttrsNamedAxisMapping,
AxisMapping,
NamedAxis,
_get_named_axis,
Expand Down Expand Up @@ -380,7 +379,7 @@ def _update_class(self):
self.__class__ = get_array_class(self._layout, self._behavior)

@property
def attrs(self) -> Mapping | AttrsNamedAxisMapping:
def attrs(self) -> Mapping:
"""
The mutable mapping containing top-level metadata, which is serialised
with the array during pickling.
Expand Down

0 comments on commit 595b240

Please sign in to comment.