diff --git a/src/awkward/_layout.py b/src/awkward/_layout.py index 4dfd1216a4..f642dd6ed4 100644 --- a/src/awkward/_layout.py +++ b/src/awkward/_layout.py @@ -29,7 +29,6 @@ T = TypeVar("T") if TYPE_CHECKING: - from awkward._namedaxis import AttrsNamedAxisMapping from awkward.highlevel import Array from awkward.highlevel import Record as HighLevelRecord @@ -96,7 +95,7 @@ def _ensure_not_finalized(self): raise RuntimeError("HighLevelContext has already been finalized") @property - def attrs(self) -> Mapping | AttrsNamedAxisMapping: + def attrs(self) -> Mapping: self._ensure_finalized() if self._attrs is None: self._attrs = {} diff --git a/src/awkward/_namedaxis.py b/src/awkward/_namedaxis.py index 31cc06e12f..9f0b8d36cc 100644 --- a/src/awkward/_namedaxis.py +++ b/src/awkward/_namedaxis.py @@ -24,16 +24,6 @@ ) -class AttrsNamedAxisMapping(tp.TypedDict, total=False): - __named_axis__: AxisMapping - - -@tp.runtime_checkable -class MaybeSupportsNamedAxis(tp.Protocol): - @property - def attrs(self) -> tp.Mapping | AttrsNamedAxisMapping: ... - - # just a class for inplace mutation class NamedAxis: mapping: AxisMapping @@ -92,37 +82,31 @@ def _prettify(ax: AxisName) -> str: return delimiter.join(items) -def _get_named_axis( - ctx: MaybeSupportsNamedAxis | AttrsNamedAxisMapping | tp.Mapping | tp.Any, -) -> AxisMapping: +def _get_named_axis(ctx: tp.Any) -> AxisMapping: """ - Retrieves the named axis from the provided context. The context can either be an object that supports named axis - or a dictionary that includes a named axis mapping. + Retrieves the named axis from the provided context. Args: - ctx (MaybeSupportsNamedAxis | AttrsNamedAxisMapping | Mapping | Any): The context from which the named axis is to be retrieved. + ctx (Any): The context from which the named axis is to be retrieved. Returns: AxisMapping: The named axis retrieved from the context. If the context does not include a named axis, an empty dictionary is returned. Examples: - >>> class Test(MaybeSupportsNamedAxis): - ... @property - ... def attrs(self): - ... return {NAMED_AXIS_KEY: {"x": 0, "y": 1, "z": 2}} - ... - >>> _get_named_axis(Test()) - {"x": 0, "y": 1, "z": 2} + >>> _get_named_axis(ak.Array([1, 2, 3], named_axis={"x": 0})) + {"x": 0} + >>> _get_named_axis(np.array([1, 2, 3])) + {} >>> _get_named_axis({NAMED_AXIS_KEY: {"x": 0, "y": 1, "z": 2}}) {"x": 0, "y": 1, "z": 2} >>> _get_named_axis({"other_key": "other_value"}) {} """ - if isinstance(ctx, MaybeSupportsNamedAxis): + if hasattr(ctx, "attrs"): return _get_named_axis(ctx.attrs) elif isinstance(ctx, tp.Mapping) and NAMED_AXIS_KEY in ctx: - return ctx[NAMED_AXIS_KEY] + return dict(ctx[NAMED_AXIS_KEY]) else: return {} diff --git a/src/awkward/highlevel.py b/src/awkward/highlevel.py index 5c84a99f31..6d1d6649aa 100644 --- a/src/awkward/highlevel.py +++ b/src/awkward/highlevel.py @@ -25,7 +25,6 @@ from awkward._layout import wrap_layout from awkward._namedaxis import ( NAMED_AXIS_KEY, - AttrsNamedAxisMapping, AxisMapping, NamedAxis, _get_named_axis, @@ -380,7 +379,7 @@ def _update_class(self): self.__class__ = get_array_class(self._layout, self._behavior) @property - def attrs(self) -> Mapping | AttrsNamedAxisMapping: + def attrs(self) -> Mapping: """ The mutable mapping containing top-level metadata, which is serialised with the array during pickling.