From e946646c6315ff9bde0c8a2255ee6f7fc6a31539 Mon Sep 17 00:00:00 2001 From: Ianna Osborne Date: Thu, 26 Sep 2024 18:01:35 +0200 Subject: [PATCH 01/21] fix: ListArray slicing on GPU (#3248) * test: add slicing test for CPU and GPU in test_3140_cuda_slicing.py * style: pre-commit fixes * cast 'at' to int head in this case can be an array and it can be regularized to a proper backend, then the GPU kernel needs to be updated to handle a 'cp.array(0)' * style: pre-commit fixes * use ak._slicing.normalize_integer_like(head) * convert head ndarray to scalar * add item attribute to TypeTracerArray * cleanup tests * add more tests * style: pre-commit fixes * use 'ListArray-at' role * style: pre-commit fixes * use pointer awkward_ListArray_getitem_next_at.cu * use role listarray.py * revert changes * same for awkward_RegularArray_getitem_next_at.cu * remove special case for cuda * correct role for jagged size --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- dev/generate-kernel-signatures.py | 5 +- kernel-specification.yml | 6 +- .../awkward_ListArray_getitem_next_at.cu | 4 +- .../awkward_RegularArray_getitem_next_at.cu | 4 +- src/awkward/contents/listarray.py | 1 + src/awkward/contents/regulararray.py | 3 +- tests-cuda/test_3140_cuda_slicing.py | 120 ++++++++++++++++++ 7 files changed, 133 insertions(+), 10 deletions(-) diff --git a/dev/generate-kernel-signatures.py b/dev/generate-kernel-signatures.py index 74bc3c4fe1..7038d95f7f 100644 --- a/dev/generate-kernel-signatures.py +++ b/dev/generate-kernel-signatures.py @@ -429,7 +429,10 @@ def by_signature(cuda_kernel_templates): special = [repr(spec["name"])] [type_to_pytype(x["type"], special) for x in childfunc["args"]] dirlist = [repr(x["dir"]) for x in childfunc["args"]] - ispointerlist = [repr("List" in x["type"]) for x in childfunc["args"]] + ispointerlist = [ + repr("List" in x["type"] or "ListArray-at" == x.get("role", None)) + for x in childfunc["args"] + ] if spec["name"] in cuda_kernels_impl: with open( os.path.join( diff --git a/kernel-specification.yml b/kernel-specification.yml index 2838b8db5c..5c901f00e2 100644 --- a/kernel-specification.yml +++ b/kernel-specification.yml @@ -1466,7 +1466,7 @@ kernels: - {name: tocarry, type: "List[int64_t]", dir: out} - {name: fromstarts, type: "Const[List[int32_t]]", dir: in, role: ListArray-starts} - {name: fromstops, type: "Const[List[int32_t]]", dir: in, role: ListArray-stops} - - {name: jaggedsize, type: "int64_t", dir: in, role: ListArray-at} + - {name: jaggedsize, type: "int64_t", dir: in, role: ListArray-length} - {name: length, type: "int64_t", dir: in, role: default} - name: awkward_ListArray64_getitem_jagged_expand_64 args: @@ -1476,7 +1476,7 @@ kernels: - {name: tocarry, type: "List[int64_t]", dir: out} - {name: fromstarts, type: "Const[List[int64_t]]", dir: in, role: ListArray-starts} - {name: fromstops, type: "Const[List[int64_t]]", dir: in, role: ListArray-stops} - - {name: jaggedsize, type: "int64_t", dir: in, role: ListArray-at} + - {name: jaggedsize, type: "int64_t", dir: in, role: ListArray-length} - {name: length, type: "int64_t", dir: in, role: default} - name: awkward_ListArrayU32_getitem_jagged_expand_64 args: @@ -1486,7 +1486,7 @@ kernels: - {name: tocarry, type: "List[int64_t]", dir: out} - {name: fromstarts, type: "Const[List[uint32_t]]", dir: in, role: ListArray-starts} - {name: fromstops, type: "Const[List[uint32_t]]", dir: in, role: ListArray-stops} - - {name: jaggedsize, type: "int64_t", dir: in, role: ListArray-at} + - {name: jaggedsize, type: "int64_t", dir: in, role: ListArray-length} - {name: length, type: "int64_t", dir: in, role: default} description: null definition: | diff --git a/src/awkward/_connect/cuda/cuda_kernels/awkward_ListArray_getitem_next_at.cu b/src/awkward/_connect/cuda/cuda_kernels/awkward_ListArray_getitem_next_at.cu index 421f0d15c1..7fc13f2ae8 100644 --- a/src/awkward/_connect/cuda/cuda_kernels/awkward_ListArray_getitem_next_at.cu +++ b/src/awkward/_connect/cuda/cuda_kernels/awkward_ListArray_getitem_next_at.cu @@ -11,7 +11,7 @@ awkward_ListArray_getitem_next_at( const C* fromstarts, const U* fromstops, int64_t lenstarts, - int64_t at, + int64_t* at, uint64_t invocation_index, uint64_t* err_code) { if (err_code[0] == NO_ERROR) { @@ -19,7 +19,7 @@ awkward_ListArray_getitem_next_at( if (thread_id < lenstarts) { int64_t length = fromstops[thread_id] - fromstarts[thread_id]; - int64_t regular_at = at; + int64_t regular_at = at[0]; if (regular_at < 0) { regular_at += length; } diff --git a/src/awkward/_connect/cuda/cuda_kernels/awkward_RegularArray_getitem_next_at.cu b/src/awkward/_connect/cuda/cuda_kernels/awkward_RegularArray_getitem_next_at.cu index 8f1282974d..1b8bd53b38 100644 --- a/src/awkward/_connect/cuda/cuda_kernels/awkward_RegularArray_getitem_next_at.cu +++ b/src/awkward/_connect/cuda/cuda_kernels/awkward_RegularArray_getitem_next_at.cu @@ -8,14 +8,14 @@ template __global__ void awkward_RegularArray_getitem_next_at( T* tocarry, - int64_t at, + int64_t* at, int64_t length, int64_t size, uint64_t invocation_index, uint64_t* err_code) { if (err_code[0] == NO_ERROR) { int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x; - int64_t regular_at = at; + int64_t regular_at = at[0]; if (regular_at < 0) { regular_at += size; } diff --git a/src/awkward/contents/listarray.py b/src/awkward/contents/listarray.py index a05eeaea55..8f0a6f5e4b 100644 --- a/src/awkward/contents/listarray.py +++ b/src/awkward/contents/listarray.py @@ -712,6 +712,7 @@ def _getitem_next( nexthead, nexttail = ak._slicing.head_tail(tail) lenstarts = self._starts.length nextcarry = ak.index.Index64.empty(lenstarts, self._backend.index_nplike) + head = ak._slicing.normalize_integer_like(head) assert ( nextcarry.nplike is self._backend.index_nplike and self._starts.nplike is self._backend.index_nplike diff --git a/src/awkward/contents/regulararray.py b/src/awkward/contents/regulararray.py index a5a16fcdff..318a21bca5 100644 --- a/src/awkward/contents/regulararray.py +++ b/src/awkward/contents/regulararray.py @@ -471,8 +471,7 @@ def _getitem_next( nexthead, nexttail = ak._slicing.head_tail(tail) nextcarry = ak.index.Index64.empty(self._length, index_nplike) assert nextcarry.nplike is index_nplike - if ak.backend(head) == "cuda": - head = int(ak.to_backend(head, backend=self._backend)[0]) + head = ak._slicing.normalize_integer_like(head) self._maybe_index_error( self._backend[ "awkward_RegularArray_getitem_next_at", nextcarry.dtype.type diff --git a/tests-cuda/test_3140_cuda_slicing.py b/tests-cuda/test_3140_cuda_slicing.py index 59e2cfcb67..da1cd4bef7 100644 --- a/tests-cuda/test_3140_cuda_slicing.py +++ b/tests-cuda/test_3140_cuda_slicing.py @@ -677,3 +677,123 @@ def test_0127_tomask_operation(): [None], [6.6, None, None, 9.9], ] + + +def test_simple_slice_cpu(): + arr = ak.Array([[1, 2, 3], [0], [4, 5]]) + out = arr[:, 0] + expected = [1, 0, 4] + result = out.tolist() + cp.testing.assert_array_list_equal( + result, + expected, + err_msg=f"Slice of [[1, 2, 3], [0], [4, 5]] should be {expected}, but got {result}", + ) + + +def test_simple_slice_gpu(): + arr = ak.Array([[1, 2, 3], [0], [4, 5]], backend="cuda") + out = arr[:, 0] + expected = [1, 0, 4] + result = out.tolist() + cp.testing.assert_array_list_equal( + result, + expected, + err_msg=f"Slice of [[1, 2, 3], [0], [4, 5]] should be {expected}, but got {result}", + ) + + +def test_simple_slice_cpu1(): + arr = ak.Array([[1, 2, 3], [0], [4, 5]]) + out = arr[:, 1:] + expected = [[2, 3], [], [5]] + result = out.tolist() + cp.testing.assert_array_list_equal( + result, + expected, + err_msg=f"Slice of [[1, 2, 3], [0], [4, 5]] should be {expected}, but got {result}", + ) + + +def test_simple_slice_gpu1(): + arr = ak.Array([[1, 2, 3], [0], [4, 5]], backend="cuda") + out = arr[:, 1:] + expected = [[2, 3], [], [5]] + result = out.tolist() + cp.testing.assert_array_list_equal( + result, + expected, + err_msg=f"Slice of [[1, 2, 3], [0], [4, 5]] should be {expected}, but got {result}", + ) + + +def test_simple_slice_cpu2(): + arr = ak.Array([[1, 2, 3], [0], [4, 5]]) + out = arr[:, :1] + expected = [[1], [0], [4]] + result = out.tolist() + cp.testing.assert_array_list_equal( + result, + expected, + err_msg=f"Slice of [[1, 2, 3], [0], [4, 5]] should be {expected}, but got {result}", + ) + + +def test_simple_slice_gpu2(): + arr = ak.Array([[1, 2, 3], [0], [4, 5]], backend="cuda") + out = arr[:, :1] + expected = [[1], [0], [4]] + result = out.tolist() + cp.testing.assert_array_list_equal( + result, + expected, + err_msg=f"Slice of [[1, 2, 3], [0], [4, 5]] should be {expected}, but got {result}", + ) + + +def test_simple_slice_cpu3(): + arr = ak.Array([[1, 2, 3], [0], [4, 5]]) + out = arr[:, 1::2] + expected = [[2], [], [5]] + result = out.tolist() + cp.testing.assert_array_list_equal( + result, + expected, + err_msg=f"Slice of [[1, 2, 3], [0], [4, 5]] should be {expected}, but got {result}", + ) + + +def test_simple_slice_gpu3(): + arr = ak.Array([[1, 2, 3], [0], [4, 5]], backend="cuda") + out = arr[:, 1::2] + expected = [[2], [], [5]] + result = out.tolist() + cp.testing.assert_array_list_equal( + result, + expected, + err_msg=f"Slice of [[1, 2, 3], [0], [4, 5]] should be {expected}, but got {result}", + ) + + +def test_simple_slice_cpu4(): + arr = ak.Array([[1, 2, 3], [0], [4, 5]]) + out = arr[:, ::-1] + expected = [[3, 2, 1], [0], [5, 4]] + result = out.tolist() + cp.testing.assert_array_list_equal( + result, + expected, + err_msg=f"Slice of [[1, 2, 3], [0], [4, 5]] should be {expected}, but got {result}", + ) + + +def test_simple_slice_gpu4(): + arr = ak.Array([[1, 2, 3], [0], [4, 5]], backend="cuda") + out = arr[:, ::-1] + expected = [[3, 2, 1], [0], [5, 4]] + result = out.tolist() + cp.testing.assert_array_list_equal( + result, + expected, + err_msg=f"Slice of [[1, 2, 3], [0], [4, 5]] should be {expected}, but got {result}", + ) From c4268e087585933d92c5cc34155e238744543b1a Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Mon, 30 Sep 2024 14:38:32 -0500 Subject: [PATCH 02/21] docs: pybind11 demo project should have NumPy own the data (#3261) * step 0: remove all existing code and return None * step 1: make sure we can iterate over names_nbytes * step 2: make sure we can create a NumPy array through pybind11 * step 3: make sure we can see the raw data in the array * step 4: make sure we can fill the dict and the std::map * step 5: filling the cpp_container fills the py_container * done: we are now returning the array build by ak.from_buffers --- header-only/examples/pybind11/demo.cpp | 54 ++++++++++++-------------- 1 file changed, 24 insertions(+), 30 deletions(-) diff --git a/header-only/examples/pybind11/demo.cpp b/header-only/examples/pybind11/demo.cpp index 94c216d0af..138e4379e8 100644 --- a/header-only/examples/pybind11/demo.cpp +++ b/header-only/examples/pybind11/demo.cpp @@ -34,42 +34,36 @@ using MyBuilder = RecordBuilder< */ template py::object snapshot_builder(const T &builder) { + // We need NumPy (to allocate arrays) and Awkward Array (ak.from_buffers). + // pybind11 will raise a ModuleNotFoundError if they aren't installed. + auto np = py::module::import("numpy"); + auto ak = py::module::import("awkward"); + + auto dtype_u1 = np.attr("dtype")("u1"); + // How much memory to allocate? - std::map names_nbytes = {}; + std::map names_nbytes; builder.buffer_nbytes(names_nbytes); - // Allocate memory - std::map buffers = {}; - for (auto it: names_nbytes) { - uint8_t *ptr = new uint8_t[it.second]; - buffers[it.first] = (void *) ptr; - } + // Ask NumPy to allocate memory and get pointers to the raw buffers. + py::dict py_container; + std::map cpp_container; + for (auto name_nbytes : names_nbytes) { + py::object array = np.attr("empty")(name_nbytes.second, dtype_u1); - // Write non-contiguous contents to memory - builder.to_buffers(buffers); - auto from_buffers = py::module::import("awkward").attr("from_buffers"); - - // Build Python dictionary containing arrays - // dtypes not important here as long as they match the underlying buffer - // as Awkward Array calls `frombuffer` to convert to the correct type - py::dict container; - for (auto it: buffers) { - - py::capsule free_when_done(it.second, [](void *data) { - uint8_t *dataPtr = reinterpret_cast(data); - delete[] dataPtr; - }); - - uint8_t *data = reinterpret_cast(it.second); - container[py::str(it.first)] = py::array_t( - {names_nbytes[it.first]}, - {sizeof(uint8_t)}, - data, - free_when_done - ); + size_t pointer = py::cast(array.attr("ctypes").attr("data")); + void* raw_data = (void*)pointer; + + py::str py_name(name_nbytes.first); + py_container[py_name] = array; + cpp_container[name_nbytes.first] = raw_data; } - return from_buffers(builder.form(), builder.length(), container); + // Write non-contiguous contents to memory. + builder.to_buffers(cpp_container); + + // Build Python dictionary containing arrays. + return ak.attr("from_buffers")(builder.form(), builder.length(), py_container); } From cfe58f3ea34c1b716508851dc8a4dc70400c434c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 2 Oct 2024 10:04:24 -0500 Subject: [PATCH 03/21] chore(deps): bump the actions group across 1 directory with 3 updates (#3260) Bumps the actions group with 3 updates in the / directory: [pypa/cibuildwheel](https://github.com/pypa/cibuildwheel), [pypa/gh-action-pypi-publish](https://github.com/pypa/gh-action-pypi-publish) and [scientific-python/upload-nightly-action](https://github.com/scientific-python/upload-nightly-action). Updates `pypa/cibuildwheel` from 2.20 to 2.21 - [Release notes](https://github.com/pypa/cibuildwheel/releases) - [Changelog](https://github.com/pypa/cibuildwheel/blob/main/docs/changelog.md) - [Commits](https://github.com/pypa/cibuildwheel/compare/v2.20...v2.21) Updates `pypa/gh-action-pypi-publish` from 1.10.1 to 1.10.2 - [Release notes](https://github.com/pypa/gh-action-pypi-publish/releases) - [Commits](https://github.com/pypa/gh-action-pypi-publish/compare/v1.10.1...v1.10.2) Updates `scientific-python/upload-nightly-action` from 0.5.0 to 0.6.1 - [Release notes](https://github.com/scientific-python/upload-nightly-action/releases) - [Commits](https://github.com/scientific-python/upload-nightly-action/compare/b67d7fcc0396e1128a474d1ab2b48aa94680f9fc...82396a2ed4269ba06c6b2988bb4fd568ef3c3d6b) --- updated-dependencies: - dependency-name: pypa/cibuildwheel dependency-type: direct:production update-type: version-update:semver-minor dependency-group: actions - dependency-name: pypa/gh-action-pypi-publish dependency-type: direct:production update-type: version-update:semver-patch dependency-group: actions - dependency-name: scientific-python/upload-nightly-action dependency-type: direct:production update-type: version-update:semver-minor dependency-group: actions ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/build-wheels.yml | 4 ++-- .github/workflows/deploy-cpp.yml | 2 +- .github/workflows/deploy.yml | 2 +- .github/workflows/packaging-test.yml | 4 ++-- .github/workflows/upload-nightly-wheels.yml | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/build-wheels.yml b/.github/workflows/build-wheels.yml index 361d292285..e171eb5431 100644 --- a/.github/workflows/build-wheels.yml +++ b/.github/workflows/build-wheels.yml @@ -105,7 +105,7 @@ jobs: - name: Prepare build files run: pipx run nox -s prepare - - uses: pypa/cibuildwheel@v2.20 + - uses: pypa/cibuildwheel@v2.21 env: CIBW_BUILD: "${{ matrix.build }}*" CIBW_ARCHS: ${{ matrix.arch }} @@ -157,7 +157,7 @@ jobs: - uses: docker/setup-qemu-action@v3.2.0 - - uses: pypa/cibuildwheel@v2.20 + - uses: pypa/cibuildwheel@v2.21 env: CIBW_BUILD: cp${{ matrix.python }}-* CIBW_ARCHS: ${{ matrix.arch }} diff --git a/.github/workflows/deploy-cpp.yml b/.github/workflows/deploy-cpp.yml index 67fa3624bf..c23a63179c 100644 --- a/.github/workflows/deploy-cpp.yml +++ b/.github/workflows/deploy-cpp.yml @@ -39,4 +39,4 @@ jobs: with: subject-path: "dist/awkward*cpp-*" - - uses: pypa/gh-action-pypi-publish@v1.10.1 + - uses: pypa/gh-action-pypi-publish@v1.10.2 diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index e62bca6c2b..a45da9b9b2 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -135,7 +135,7 @@ jobs: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: gh attestation verify dist/awkward-*.whl --repo ${{ github.repository }} - - uses: pypa/gh-action-pypi-publish@v1.10.1 + - uses: pypa/gh-action-pypi-publish@v1.10.2 publish-headers: name: "Publish header-only libraries alongside release" diff --git a/.github/workflows/packaging-test.yml b/.github/workflows/packaging-test.yml index 20ae0fe714..de30eecc54 100644 --- a/.github/workflows/packaging-test.yml +++ b/.github/workflows/packaging-test.yml @@ -68,7 +68,7 @@ jobs: - name: Prepare build files run: pipx run nox -s prepare - - uses: pypa/cibuildwheel@v2.20 + - uses: pypa/cibuildwheel@v2.21 env: CIBW_ARCHS_MACOS: universal2 CIBW_BUILD: cp39-win_amd64 cp310-manylinux_x86_64 cp38-macosx_universal2 @@ -76,7 +76,7 @@ jobs: config-file: cibuildwheel.toml package-dir: awkward-cpp - - uses: pypa/cibuildwheel@v2.20 + - uses: pypa/cibuildwheel@v2.21 if: matrix.os == 'ubuntu-latest' env: CIBW_BUILD: cp312-manylinux_x86_64 diff --git a/.github/workflows/upload-nightly-wheels.yml b/.github/workflows/upload-nightly-wheels.yml index bfeb6b8abd..9e69a1a4c7 100644 --- a/.github/workflows/upload-nightly-wheels.yml +++ b/.github/workflows/upload-nightly-wheels.yml @@ -58,7 +58,7 @@ jobs: ls -l dist/ - name: Upload wheels to Anaconda Cloud as nightlies - uses: scientific-python/upload-nightly-action@b67d7fcc0396e1128a474d1ab2b48aa94680f9fc # 0.5.0 + uses: scientific-python/upload-nightly-action@82396a2ed4269ba06c6b2988bb4fd568ef3c3d6b # 0.6.1 with: artifacts_path: dist anaconda_nightly_upload_token: ${{ secrets.ANACONDA_ORG_UPLOAD_TOKEN }} From ee5865a95e2a9d6ceeb1d28f3d0b9d02f83c7c8a Mon Sep 17 00:00:00 2001 From: maxymnaumchyk <70752300+maxymnaumchyk@users.noreply.github.com> Date: Thu, 3 Oct 2024 17:27:32 +0300 Subject: [PATCH 04/21] feat: to/from PyTorch Tensor (#3259) * add new to_torch function * add new from_torch function * add changes suggested by Jim * style: pre-commit fixes * fix style --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/awkward/operations/__init__.py | 2 + src/awkward/operations/ak_from_torch.py | 65 ++++++++++++++++++++++ src/awkward/operations/ak_to_torch.py | 74 +++++++++++++++++++++++++ tests/test_3259_to_torch_from_torch.py | 72 ++++++++++++++++++++++++ 4 files changed, 213 insertions(+) create mode 100644 src/awkward/operations/ak_from_torch.py create mode 100644 src/awkward/operations/ak_to_torch.py create mode 100644 tests/test_3259_to_torch_from_torch.py diff --git a/src/awkward/operations/__init__.py b/src/awkward/operations/__init__.py index e9b1a3818b..d76d8e2688 100644 --- a/src/awkward/operations/__init__.py +++ b/src/awkward/operations/__init__.py @@ -47,6 +47,7 @@ from awkward.operations.ak_from_raggedtensor import * from awkward.operations.ak_from_rdataframe import * from awkward.operations.ak_from_regular import * +from awkward.operations.ak_from_torch import * from awkward.operations.ak_full_like import * from awkward.operations.ak_imag import * from awkward.operations.ak_is_categorical import * @@ -102,6 +103,7 @@ from awkward.operations.ak_to_raggedtensor import * from awkward.operations.ak_to_rdataframe import * from awkward.operations.ak_to_regular import * +from awkward.operations.ak_to_torch import * from awkward.operations.ak_transform import * from awkward.operations.ak_type import * from awkward.operations.ak_unflatten import * diff --git a/src/awkward/operations/ak_from_torch.py b/src/awkward/operations/ak_from_torch.py new file mode 100644 index 0000000000..fc739f17b5 --- /dev/null +++ b/src/awkward/operations/ak_from_torch.py @@ -0,0 +1,65 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE + +from __future__ import annotations + +import awkward as ak +from awkward._dispatch import high_level_function + +__all__ = ("from_torch",) + + +@high_level_function() +def from_torch(array): + """ + Args: + array: (PyTorch Tensor): + Tensor to convert into an Awkward Array. + + Converts a PyTorch Tensor into an Awkward Array. + + If `array` contains any other data types the function raises an error. + """ + + # Dispatch + yield (array,) + + # Implementation + return _impl(array) + + +def _impl(array): + try: + import torch + except ImportError as err: + raise ImportError( + """to use ak.from_torch, you must install 'torch' package with: + + pip install torch + +or + + conda install pytorch""" + ) from err + + # check if array is a Tensor + if not isinstance(array, torch.Tensor): + raise TypeError("""only PyTorch Tensor can be converted to Awkward Array""") + + # keep the resulting array on the same device as input tensor + device = "cuda" if array.is_cuda else "cpu" + + # convert tensors to cupy if they are on cuda + if device == "cuda": + from awkward._nplikes.cupy import Cupy + + cp = Cupy.instance() + + # zero-copy data exchange through DLPack + cp_array = cp.from_dlpack(array) + ak_array = ak.from_cupy(cp_array) + + else: + np_array = array.numpy() + ak_array = ak.from_numpy(np_array) + + return ak_array diff --git a/src/awkward/operations/ak_to_torch.py b/src/awkward/operations/ak_to_torch.py new file mode 100644 index 0000000000..9f45f288bc --- /dev/null +++ b/src/awkward/operations/ak_to_torch.py @@ -0,0 +1,74 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE + +from __future__ import annotations + +import awkward as ak +from awkward._dispatch import high_level_function +from awkward._nplikes.numpy_like import NumpyMetadata + +__all__ = ("to_torch",) + +np = NumpyMetadata.instance() + + +@high_level_function() +def to_torch(array): + """ + Args: + array: Array-like data. May be a high level #ak.Array, + or low-level #ak.contents.ListOffsetArray, #ak.contents.ListArray, + #ak.contents.RegularArray, #ak.contents.NumpyArray + + Converts `array` (only ListOffsetArray, ListArray, RegularArray and NumpyArray data types supported) + into a PyTorch Tensor, if possible. + + If `array` contains any other data types (RecordArray for example) the function raises a TypeError. + """ + + # Dispatch + yield (array,) + + # Implementation + return _impl(array) + + +def _impl(array): + try: + import torch + except ImportError as err: + raise ImportError( + """to use ak.to_torch, you must install 'torch' package with: + + pip install torch + +or + + conda install pytorch""" + ) from err + + # useful function that handles all possible input arrays + array = ak.to_layout(array, allow_record=False) + + # get the device array is on + device = ak.backend(array) + + if device not in ["cuda", "cpu"]: + raise ValueError("Only 'cpu' and 'cuda' backend conversions are allowed") + + # convert to numpy or cupy if `array` on gpu + try: + backend_array = array.to_backend_array(allow_missing=False) + except ValueError as err: + raise TypeError( + "Only arrays containing equal-length lists of numbers can be converted into a PyTorch Tensor" + ) from err + + # check if cupy or numpy + if isinstance(backend_array, np.ndarray): + # convert numpy to a torch tensor + tensor = torch.from_numpy(backend_array) + else: + # cupy -> torch tensor + tensor = torch.utils.dlpack.from_dlpack(backend_array.toDlpack()) + + return tensor diff --git a/tests/test_3259_to_torch_from_torch.py b/tests/test_3259_to_torch_from_torch.py new file mode 100644 index 0000000000..a20567c113 --- /dev/null +++ b/tests/test_3259_to_torch_from_torch.py @@ -0,0 +1,72 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE + +from __future__ import annotations + +import numpy as np +import pytest + +import awkward as ak + +to_torch = ak.operations.to_torch +from_torch = ak.operations.from_torch + +torch = pytest.importorskip("torch") + +a = np.arange(2 * 2 * 2, dtype=np.float64).reshape(2, 2, 2) +b = np.arange(2 * 2 * 2).reshape(2, 2, 2) + +array = np.arange(2 * 3 * 5).reshape(2, 3, 5) +content2 = ak.contents.NumpyArray(array.reshape(-1)) +inneroffsets = ak.index.Index64(np.array([0, 5, 10, 15, 20, 25, 30])) +outeroffsets = ak.index.Index64(np.array([0, 3, 6])) + + +def test_to_torch(): + # a basic test for a 4 dimensional array + array1 = ak.Array([a, b]) + i = 0 + for sub_array in [ + [[[0.0, 1.0], [2.0, 3.0]], [[4.0, 5.0], [6.0, 7.0]]], + [[[0.0, 1.0], [2.0, 3.0]], [[4.0, 5.0], [6.0, 7.0]]], + ]: + assert to_torch(array1)[i].tolist() == sub_array + i += 1 + + # test that the data types are remaining the same (float64 in this case) + assert array1.layout.to_backend_array().dtype.name in str(to_torch(array1).dtype) + + # try a listoffset array inside a listoffset array + array2 = ak.contents.ListOffsetArray( + outeroffsets, ak.contents.ListOffsetArray(inneroffsets, content2) + ) + assert to_torch(array2)[0].tolist() == [ + [0, 1, 2, 3, 4], + [5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + ] + assert to_torch(array2)[1].tolist() == [ + [15, 16, 17, 18, 19], + [20, 21, 22, 23, 24], + [25, 26, 27, 28, 29], + ] + + # try just a python list + array3 = [3, 1, 4, 1, 9, 2, 6] + assert to_torch(array3).tolist() == [3, 1, 4, 1, 9, 2, 6] + + +array1 = torch.tensor([[1.0, -1.0], [1.0, -1.0]], dtype=torch.float32) +array2 = torch.tensor(np.array([[1, 2, 3], [4, 5, 6]])) + + +def test_from_torch(): + # Awkward.to_list() == Tensor.tolist() + assert from_torch(array1).to_list() == array1.tolist() + + assert from_torch(array2).to_list() == array2.tolist() + + # test that the data types are remaining the same (int64 in this case) + assert from_torch(array1).layout.dtype.name in str(array1.dtype) + + # test that the data types are remaining the same (float32 in this case) + assert from_torch(array2).layout.dtype.name in str(array2.dtype) From d7264c1250a0c948ba65a056933119befc31b8fe Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 3 Oct 2024 11:53:38 -0400 Subject: [PATCH 05/21] chore: update pre-commit hooks (#3245) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.6.4 → v0.6.8](https://github.com/astral-sh/ruff-pre-commit/compare/v0.6.4...v0.6.8) - [github.com/python-jsonschema/check-jsonschema: 0.29.2 → 0.29.3](https://github.com/python-jsonschema/check-jsonschema/compare/0.29.2...0.29.3) - [github.com/abravalheri/validate-pyproject: v0.19 → v0.20.2](https://github.com/abravalheri/validate-pyproject/compare/v0.19...v0.20.2) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d8dac77444..9ae65f9cd4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,7 +27,7 @@ repos: additional_dependencies: [pyyaml] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.4 + rev: v0.6.8 hooks: - id: ruff args: ["--fix", "--show-fixes"] @@ -62,7 +62,7 @@ repos: files: ^tests/ - repo: https://github.com/python-jsonschema/check-jsonschema - rev: 0.29.2 + rev: 0.29.3 hooks: - id: check-github-workflows args: ["--verbose"] @@ -76,6 +76,6 @@ repos: - numpy>=1.24 - repo: https://github.com/abravalheri/validate-pyproject - rev: v0.19 + rev: v0.20.2 hooks: - id: validate-pyproject From 88ba3e5210643a37b462857c52a7fe5cd9399dd3 Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Sat, 5 Oct 2024 09:36:52 -0500 Subject: [PATCH 06/21] fix: ak.typetracer.length_one_if_typetracer with option and union types (#3266) * fix: ak.typetracer.length_one_if_typetracer with option and union types * forgot to add the test * style: pre-commit fixes * no, not the Emacs backup file --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/awkward/forms/form.py | 64 ++++++++++++++++--- ...gth_one_if_typetracer_with_option_types.py | 14 ++++ 2 files changed, 69 insertions(+), 9 deletions(-) create mode 100644 tests/test_3264_length_one_if_typetracer_with_option_types.py diff --git a/src/awkward/forms/form.py b/src/awkward/forms/form.py index 3f9bec55eb..49082970eb 100644 --- a/src/awkward/forms/form.py +++ b/src/awkward/forms/form.py @@ -558,6 +558,52 @@ def max_prefer_unknown(this: ShapeItem, that: ShapeItem) -> ShapeItem: container = {} + def prepare_empty(form): + form_key = f"node-{len(container)}" + + if isinstance(form, (ak.forms.BitMaskedForm, ak.forms.ByteMaskedForm)): + container[form_key] = b"" + return form.copy(content=prepare_empty(form.content), form_key=form_key) + + elif isinstance(form, ak.forms.IndexedOptionForm): + container[form_key] = b"" + return form.copy(content=prepare_empty(form.content), form_key=form_key) + + elif isinstance(form, ak.forms.EmptyForm): + return form + + elif isinstance(form, ak.forms.UnmaskedForm): + return form.copy(content=prepare_empty(form.content)) + + elif isinstance(form, (ak.forms.IndexedForm, ak.forms.ListForm)): + container[form_key] = b"" + return form.copy(content=prepare_empty(form.content), form_key=form_key) + + elif isinstance(form, ak.forms.ListOffsetForm): + container[form_key] = b"" + return form.copy(content=prepare_empty(form.content), form_key=form_key) + + elif isinstance(form, ak.forms.RegularForm): + return form.copy(content=prepare_empty(form.content)) + + elif isinstance(form, ak.forms.NumpyForm): + container[form_key] = b"" + return form.copy(form_key=form_key) + + elif isinstance(form, ak.forms.RecordForm): + return form.copy(contents=[prepare_empty(x) for x in form.contents]) + + elif isinstance(form, ak.forms.UnionForm): + # both tags and index will get this buffer + container[form_key] = b"" + return form.copy( + contents=[prepare_empty(x) for x in form.contents], + form_key=form_key, + ) + + else: + raise AssertionError(f"not a Form: {form!r}") + def prepare(form, multiplier): form_key = f"node-{len(container)}" @@ -566,11 +612,13 @@ def prepare(form, multiplier): container[form_key] = b"\x00" * multiplier else: container[form_key] = b"\xff" * multiplier - return form.copy(form_key=form_key) # DO NOT RECURSE + # switch from recursing down `prepare` to `prepare_empty` + return form.copy(content=prepare_empty(form.content), form_key=form_key) elif isinstance(form, ak.forms.IndexedOptionForm): container[form_key] = b"\xff\xff\xff\xff\xff\xff\xff\xff" # -1 - return form.copy(form_key=form_key) # DO NOT RECURSE + # switch from recursing down `prepare` to `prepare_empty` + return form.copy(content=prepare_empty(form.content), form_key=form_key) elif isinstance(form, ak.forms.EmptyForm): # no error if protected by non-recursing node type @@ -624,13 +672,11 @@ def prepare(form, multiplier): elif isinstance(form, ak.forms.UnionForm): # both tags and index will get this buffer, but index is 8 bytes container[form_key] = b"\x00" * (8 * multiplier) - return form.copy( - # only recurse down contents[0] because all index == 0 - contents=( - [prepare(form.contents[0], multiplier)] + form.contents[1:] - ), - form_key=form_key, - ) + # recurse down contents[0] with `prepare`, but others with `prepare_empty` + contents = [prepare(form.contents[0], multiplier)] + for x in form.contents[1:]: + contents.append(prepare_empty(x)) + return form.copy(contents=contents, form_key=form_key) else: raise AssertionError(f"not a Form: {form!r}") diff --git a/tests/test_3264_length_one_if_typetracer_with_option_types.py b/tests/test_3264_length_one_if_typetracer_with_option_types.py new file mode 100644 index 0000000000..dda733248d --- /dev/null +++ b/tests/test_3264_length_one_if_typetracer_with_option_types.py @@ -0,0 +1,14 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE +# ruff: noqa: E402 + +from __future__ import annotations + +import awkward as ak + + +def test(): + arr = ak.Array([[1], [2, 3], [1, 2, 4, 5]])[[0, None, 2]] + l1 = ak.typetracer.length_one_if_typetracer(ak.to_backend(arr, "typetracer")) + + assert l1.to_list() == [None] + assert str(l1.type) == "1 * option[var * int64]" From c75eeb43f0348c143a35138c2bf8356d7ed3dc53 Mon Sep 17 00:00:00 2001 From: Henry Schreiner Date: Mon, 7 Oct 2024 10:21:01 -0400 Subject: [PATCH 07/21] ci: use official GHA for uv (#3269) Signed-off-by: Henry Schreiner --- .github/workflows/build-wheels.yml | 6 +++--- .github/workflows/packaging-test.yml | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build-wheels.yml b/.github/workflows/build-wheels.yml index e171eb5431..3336ee49ed 100644 --- a/.github/workflows/build-wheels.yml +++ b/.github/workflows/build-wheels.yml @@ -100,7 +100,7 @@ jobs: python-version: '3.12' - name: Setup uv - uses: yezz123/setup-uv@v4 + uses: astral-sh/setup-uv@v3 - name: Prepare build files run: pipx run nox -s prepare @@ -150,7 +150,7 @@ jobs: python-version: '3.12' - name: Setup uv - uses: yezz123/setup-uv@v4 + uses: astral-sh/setup-uv@v3 - name: Prepare build files run: pipx run nox -s prepare @@ -192,7 +192,7 @@ jobs: submodules: true - name: Setup uv - uses: yezz123/setup-uv@v4 + uses: astral-sh/setup-uv@v3 - name: Prepare build files run: pipx run nox -s prepare diff --git a/.github/workflows/packaging-test.yml b/.github/workflows/packaging-test.yml index de30eecc54..5ecbc3e513 100644 --- a/.github/workflows/packaging-test.yml +++ b/.github/workflows/packaging-test.yml @@ -63,7 +63,7 @@ jobs: submodules: true - name: Setup uv - uses: yezz123/setup-uv@v4 + uses: astral-sh/setup-uv@v3 - name: Prepare build files run: pipx run nox -s prepare From e13c5e5a1e6196761d25bf35de95a69f67b70b59 Mon Sep 17 00:00:00 2001 From: Henry Schreiner Date: Mon, 7 Oct 2024 10:22:26 -0400 Subject: [PATCH 08/21] ci: restore Windows 32-bit wheels (#3268) --- .github/workflows/build-wheels.yml | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/.github/workflows/build-wheels.yml b/.github/workflows/build-wheels.yml index 3336ee49ed..b4de99315b 100644 --- a/.github/workflows/build-wheels.yml +++ b/.github/workflows/build-wheels.yml @@ -87,7 +87,7 @@ jobs: - os: windows-latest arch: auto32 - build: "cp{38,39}-" + build: "cp" steps: - uses: actions/checkout@v4 @@ -114,13 +114,7 @@ jobs: package-dir: awkward-cpp - name: Check metadata - shell: python - run: | - import subprocess, glob - subprocess.run( - ["pipx", "run", "twine", "check", *glob.glob("wheelhouse/*.whl")], - check=True - ) + run: pipx run twine check wheelhouse/*.whl - name: Upload wheels uses: actions/upload-artifact@v4 @@ -166,13 +160,7 @@ jobs: package-dir: awkward-cpp - name: Check metadata - shell: python - run: | - import subprocess, glob - subprocess.run( - ["pipx", "run", "twine", "check", *glob.glob("wheelhouse/*.whl")], - check=True - ) + run: pipx run twine check wheelhouse/*.whl - name: Upload wheels uses: actions/upload-artifact@v4 From 283fa618c02f2e8406609684c07f8a9512df169d Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Mon, 7 Oct 2024 09:24:28 -0500 Subject: [PATCH 09/21] chore: bump awkward and awkward-cpp versions --- awkward-cpp/pyproject.toml | 2 +- pyproject.toml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/awkward-cpp/pyproject.toml b/awkward-cpp/pyproject.toml index 43411b9299..2df2f56fca 100644 --- a/awkward-cpp/pyproject.toml +++ b/awkward-cpp/pyproject.toml @@ -7,7 +7,7 @@ build-backend = "scikit_build_core.build" [project] name = "awkward_cpp" -version = "38" +version = "39" dependencies = [ "numpy>=1.18.0", "importlib_resources;python_version < \"3.9\"" diff --git a/pyproject.toml b/pyproject.toml index b2876972bd..f5d1424359 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ build-backend = "hatchling.build" [project] name = "awkward" -version = "2.6.8" +version = "2.6.9" description = "Manipulate JSON-like data with NumPy-like idioms." license = { text = "BSD-3-Clause" } requires-python = ">=3.8" @@ -41,7 +41,7 @@ classifiers = [ "Topic :: Utilities", ] dependencies = [ - "awkward_cpp==38", + "awkward_cpp==39", "importlib_metadata>=4.13.0;python_version < \"3.12\"", "numpy>=1.18.0", "packaging", From 99efdc401e64ee4a4ffde6fb0fb44e0695f07eba Mon Sep 17 00:00:00 2001 From: "allcontributors[bot]" <46447321+allcontributors[bot]@users.noreply.github.com> Date: Mon, 7 Oct 2024 10:48:37 -0500 Subject: [PATCH 10/21] docs: add tacaswell as a contributor for maintenance (#3273) * docs: update README.md [skip ci] * docs: update .all-contributorsrc [skip ci] --------- Co-authored-by: allcontributors[bot] <46447321+allcontributors[bot]@users.noreply.github.com> --- .all-contributorsrc | 9 +++++++++ README.md | 1 + 2 files changed, 10 insertions(+) diff --git a/.all-contributorsrc b/.all-contributorsrc index ccaf3440b8..9d3fd353bc 100644 --- a/.all-contributorsrc +++ b/.all-contributorsrc @@ -483,6 +483,15 @@ "contributions": [ "code" ] + }, + { + "login": "tacaswell", + "name": "Thomas A Caswell", + "avatar_url": "https://avatars.githubusercontent.com/u/199813?v=4", + "profile": "https://tacaswell.github.io", + "contributions": [ + "maintenance" + ] } ], "contributorsPerLine": 7, diff --git a/README.md b/README.md index 65d916050d..ca4edcc01e 100644 --- a/README.md +++ b/README.md @@ -230,6 +230,7 @@ Thanks especially to the gracious help of Awkward Array contributors (including Peter Fackeldey
Peter Fackeldey

💻 Andres Rios Tascon
Andres Rios Tascon

💻 maxymnaumchyk
maxymnaumchyk

💻 + Thomas A Caswell
Thomas A Caswell

🚧 From e8f5f9a1692538bc13f459232dac8be099d2a08b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 8 Oct 2024 09:35:55 +0100 Subject: [PATCH 11/21] chore: update pre-commit hooks (#3275) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/pre-commit/pre-commit-hooks: v4.6.0 → v5.0.0](https://github.com/pre-commit/pre-commit-hooks/compare/v4.6.0...v5.0.0) - [github.com/astral-sh/ruff-pre-commit: v0.6.8 → v0.6.9](https://github.com/astral-sh/ruff-pre-commit/compare/v0.6.8...v0.6.9) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9ae65f9cd4..4ffeb0ed6f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ ci: exclude: ^(docs|studies|tests/samples|src/awkward/_typeparser/generated_parser.py) repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: check-added-large-files - id: check-case-conflict @@ -27,7 +27,7 @@ repos: additional_dependencies: [pyyaml] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.8 + rev: v0.6.9 hooks: - id: ruff args: ["--fix", "--show-fixes"] From 0a9253929f1420aba0a8dfbbd1efc63e8ec6595d Mon Sep 17 00:00:00 2001 From: Ianna Osborne Date: Wed, 9 Oct 2024 12:48:34 +0200 Subject: [PATCH 12/21] make sure it's a cupy zero dim array (#3271) --- .../cuda_kernels/awkward_ListArray_getitem_next_at.cu | 7 +++++++ .../cuda_kernels/awkward_RegularArray_getitem_next_at.cu | 8 ++++++++ 2 files changed, 15 insertions(+) diff --git a/src/awkward/_connect/cuda/cuda_kernels/awkward_ListArray_getitem_next_at.cu b/src/awkward/_connect/cuda/cuda_kernels/awkward_ListArray_getitem_next_at.cu index 7fc13f2ae8..1fa5333af5 100644 --- a/src/awkward/_connect/cuda/cuda_kernels/awkward_ListArray_getitem_next_at.cu +++ b/src/awkward/_connect/cuda/cuda_kernels/awkward_ListArray_getitem_next_at.cu @@ -1,5 +1,12 @@ // BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE +// BEGIN PYTHON +// def f(grid, block, args): +// (tocarry, fromstarts, fromstops, lenstarts, at, invocation_index, err_code) = args +// cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_getitem_next_at", tocarry.dtype, fromstarts.dtype, fromstops.dtype]))(grid, block, (tocarry, fromstarts, fromstops, lenstarts, cupy.array(at), invocation_index, err_code)) +// out["awkward_ListArray_getitem_next_at", {dtype_specializations}] = None +// END PYTHON + enum class LISTARRAY_GETITEM_NEXT_AT_ERRORS { IND_OUT_OF_RANGE, // message: "index out of range" }; diff --git a/src/awkward/_connect/cuda/cuda_kernels/awkward_RegularArray_getitem_next_at.cu b/src/awkward/_connect/cuda/cuda_kernels/awkward_RegularArray_getitem_next_at.cu index 1b8bd53b38..ddb8cd94ab 100644 --- a/src/awkward/_connect/cuda/cuda_kernels/awkward_RegularArray_getitem_next_at.cu +++ b/src/awkward/_connect/cuda/cuda_kernels/awkward_RegularArray_getitem_next_at.cu @@ -1,5 +1,13 @@ // BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE + +// BEGIN PYTHON +// def f(grid, block, args): +// (tocarry, at, length, size, invocation_index, err_code) = args +// cuda_kernel_templates.get_function(fetch_specialization(["awkward_RegularArray_getitem_next_at", tocarry.dtype]))(grid, block, (tocarry, cupy.array(at), length, size, invocation_index, err_code)) +// out["awkward_RegularArray_getitem_next_at", {dtype_specializations}] = None +// END PYTHON + enum class REGULARARRAY_GETITEM_NEXT_AT_ERRORS { IND_OUT_OF_RANGE // message: "index out of range" }; From cd7d7f6f43216b37dbf467521b8e7720f2bded91 Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Fri, 11 Oct 2024 12:14:24 -0400 Subject: [PATCH 13/21] feat: named axis for `ak.Array` (#3238) * start implementing named axis for awkward array * style: pre-commit fixes * add support for named axis for first batch of highlevel functions * formatting * next batch of high-level functions * fix type hints & safer control flow when checking for named axis * style: pre-commit fixes * (hopefully) fix old (<3.10) python type annotation syntax * (hopefully) fix old (<3.10) python type annotation syntax * next batch of highlevel functions * next batch of highlevel functions * style: pre-commit fixes * update named axis implementation to not use tuples at all; start indexing named axis propagation * add named axis propagation for binary ops, some highlevel ops, and fix named axis propagation in indexing for type tracers * fix keepdims in ak.covar & ak.corr; properly propagate named axis in ak.mean; remove inplace addition of arrays from test * add ak.std & ak.var; fix bug in indexing where == comparisons against Ellipsis failed * add ak.(arg)combinations and ak.(arg)cartesian; make named axis compatible with branched structures;fix regularize_axis in all highlevel ops * avoid touching shape too much for purelist_depth, minmax_depth, and branch_depth using inner_shape property * ak.without_named_axis: allow ak.Records * ak.with_named_axis: add check to validate the given named axis mapping * fix doc strings and remove obsolete functions in _namedaxis.py module * update Slicer doc string * docs: add documentation page for named axes * propagate named axis through broadcasting; add more highlevel ops (that depended on broadcasting to work) * add test for ak.where with named axis * fix test using pyarrow * add test case for ak.broadcast_fields * streamline code * add named axis to constructor, repr and .show(...) of highlevel ak.Array & ak.Record;streamline code more;add some more tests * satisfy pylint * improve docs, comments, and add named/positional axis property to Records * remove ak.Slicer as numpy provides np.s_ and is a strict dependency * fix docs * mark right-broadcasting test with xfail for windows 32-bit * add tests & proper support for negative named axes * satisfy pylint * make xfail marker not strict * chore: enable mypy on namedaxis * fix: use attribute name in type hints * refactor: rename _NamedAxisKey * fix: define NAMED_AXIS_KEY as a literal * fix: appease mypy * highlevel getitem: less (un)wrapping * highlevel: improve repr and .show for named_axis * named_axis: fix typo and avoid dictionary copies where possible * named_axis: improve doc string of _prettify_named_axes * highlevel: fix instance check for attrs in __init__ * add named_axis to jupyter repr * fix docs * connect _neg2pos_axis with maybe_posaxis * named_axis: simplify type hint for named axis in attrs mapping * regularize axis makes sure now that its type is either int or None --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Angus Hollands --- .gitignore | 1 + docs/_toc.yml | 12 +- docs/conf.py | 2 +- .../how-to-array-properties-named-axis.md | 304 +++ docs/user-guide/how-to-array-properties.md | 23 + pyproject.toml | 1 + src/awkward/__init__.py | 1 + src/awkward/_broadcasting.py | 84 +- src/awkward/_connect/numexpr.py | 48 +- src/awkward/_connect/numpy.py | 36 +- src/awkward/_layout.py | 62 +- src/awkward/_namedaxis.py | 764 ++++++ src/awkward/_nplikes/array_like.py | 3 +- src/awkward/_nplikes/typetracer.py | 3 +- src/awkward/_operators.py | 1 + src/awkward/_regularize.py | 21 +- src/awkward/_typing.py | 4 + src/awkward/contents/content.py | 118 +- src/awkward/contents/numpyarray.py | 13 +- src/awkward/highlevel.py | 203 +- src/awkward/operations/__init__.py | 2 + src/awkward/operations/ak_all.py | 41 +- src/awkward/operations/ak_almost_equal.py | 9 + src/awkward/operations/ak_any.py | 41 +- src/awkward/operations/ak_argcartesian.py | 3 - src/awkward/operations/ak_argcombinations.py | 9 +- src/awkward/operations/ak_argmax.py | 41 +- src/awkward/operations/ak_argmin.py | 40 +- src/awkward/operations/ak_argsort.py | 19 +- src/awkward/operations/ak_array_equal.py | 3 + src/awkward/operations/ak_broadcast_arrays.py | 38 +- src/awkward/operations/ak_cartesian.py | 65 +- src/awkward/operations/ak_categories.py | 12 +- src/awkward/operations/ak_combinations.py | 16 +- src/awkward/operations/ak_concatenate.py | 46 +- src/awkward/operations/ak_corr.py | 26 +- src/awkward/operations/ak_count.py | 41 +- src/awkward/operations/ak_count_nonzero.py | 43 +- src/awkward/operations/ak_covar.py | 24 +- src/awkward/operations/ak_drop_none.py | 17 +- src/awkward/operations/ak_fill_none.py | 13 +- src/awkward/operations/ak_firsts.py | 38 +- src/awkward/operations/ak_flatten.py | 48 +- src/awkward/operations/ak_from_regular.py | 3 +- src/awkward/operations/ak_is_none.py | 34 +- src/awkward/operations/ak_isclose.py | 23 +- src/awkward/operations/ak_linear_fit.py | 14 +- src/awkward/operations/ak_local_index.py | 33 +- src/awkward/operations/ak_mask.py | 25 +- src/awkward/operations/ak_max.py | 41 +- src/awkward/operations/ak_mean.py | 29 +- .../operations/ak_merge_option_of_records.py | 11 +- .../operations/ak_merge_union_of_records.py | 11 +- src/awkward/operations/ak_min.py | 41 +- src/awkward/operations/ak_moment.py | 49 +- src/awkward/operations/ak_nan_to_none.py | 12 +- src/awkward/operations/ak_nan_to_num.py | 51 +- src/awkward/operations/ak_num.py | 51 +- src/awkward/operations/ak_pad_none.py | 12 +- src/awkward/operations/ak_prod.py | 41 +- src/awkward/operations/ak_ptp.py | 27 +- src/awkward/operations/ak_ravel.py | 11 +- src/awkward/operations/ak_real.py | 12 +- src/awkward/operations/ak_singletons.py | 36 +- src/awkward/operations/ak_softmax.py | 28 +- src/awkward/operations/ak_sort.py | 19 +- src/awkward/operations/ak_std.py | 28 +- src/awkward/operations/ak_strings_astype.py | 1 + src/awkward/operations/ak_sum.py | 41 +- src/awkward/operations/ak_to_backend.py | 2 +- src/awkward/operations/ak_to_regular.py | 2 +- src/awkward/operations/ak_transform.py | 41 +- src/awkward/operations/ak_unflatten.py | 27 +- src/awkward/operations/ak_unzip.py | 8 +- src/awkward/operations/ak_values_astype.py | 1 + src/awkward/operations/ak_var.py | 28 +- src/awkward/operations/ak_where.py | 27 +- src/awkward/operations/ak_with_field.py | 26 +- src/awkward/operations/ak_with_named_axis.py | 72 + .../operations/ak_without_named_axis.py | 54 + src/awkward/operations/ak_zip.py | 25 +- src/awkward/operations/str/akstr_join.py | 26 +- .../operations/str/akstr_join_element_wise.py | 27 +- src/awkward/operations/str/akstr_repeat.py | 26 +- tests/test_2596_named_axis.py | 2243 +++++++++++++++++ 85 files changed, 5400 insertions(+), 258 deletions(-) create mode 100644 docs/user-guide/how-to-array-properties-named-axis.md create mode 100644 docs/user-guide/how-to-array-properties.md create mode 100644 src/awkward/_namedaxis.py create mode 100644 src/awkward/operations/ak_with_named_axis.py create mode 100644 src/awkward/operations/ak_without_named_axis.py create mode 100644 tests/test_2596_named_axis.py diff --git a/.gitignore b/.gitignore index d28e4e3ea8..fd5b1bf8cf 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ studies/**/sample-* +studies/named_axis.* docs/demos/countries.geojson docs/demos/test-program docs/demos/test-program.cpp diff --git a/docs/_toc.yml b/docs/_toc.yml index 2f4ff6663f..12c406ab15 100644 --- a/docs/_toc.yml +++ b/docs/_toc.yml @@ -4,12 +4,11 @@ title: "Awkward Array" defaults: titlesonly: True - subtrees: - entries: - file: getting-started/index subtrees: - - entries: + - entries: - file: getting-started/what-is-an-awkward-array - file: getting-started/10-minutes-to-awkward-array - file: getting-started/uproot-awkward-columnar-hats @@ -18,7 +17,7 @@ subtrees: - file: getting-started/papers-and-talks - file: user-guide/index subtrees: - - entries: + - entries: - file: user-guide/how-to-convert title: "Converting arrays" subtrees: @@ -74,6 +73,13 @@ subtrees: - file: user-guide/how-to-examine-checking-validity title: "Checking validity" + - file: user-guide/how-to-array-properties + title: "Array properties" + subtrees: + - entries: + - file: user-guide/how-to-array-properties-named-axis + title: "Named axes" + - file: user-guide/how-to-math title: "Numerical math" subtrees: diff --git a/docs/conf.py b/docs/conf.py index 37faac9e39..f6ea6f5e64 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -140,7 +140,7 @@ html_js_files = ["js/awkward.js"] # MyST settings -myst_enable_extensions = ["colon_fence"] +myst_enable_extensions = ["colon_fence", "deflist"] nb_execution_mode = "cache" nb_execution_raise_on_error = True diff --git a/docs/user-guide/how-to-array-properties-named-axis.md b/docs/user-guide/how-to-array-properties-named-axis.md new file mode 100644 index 0000000000..9d5321b67f --- /dev/null +++ b/docs/user-guide/how-to-array-properties-named-axis.md @@ -0,0 +1,304 @@ +--- +jupytext: + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.14.1 +kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +Named axes +========== + +Named axes are a feature in Awkward Array that allows you to give names to the axes of an array. +This can be useful for documentation, debugging, and for writing code that is more robust to changes in the structure of the data. +As argumented at [PyHEP.dev 2023](https://indico.cern.ch/event/1234156/) and by the Harvard NLP group in their ["Tensor Considered Harmful"](https://nlp.seas.harvard.edu/NamedTensor.html) write-up, named axes can be a powerful tool to make code more readable and less error-prone. + +Awkward array ensures that named axes are properly propagated to the result. +All highlevel, indexing, and broadcasting operations in awkward array support named axes. + +Other libraries that support named axes include: +- [hist](https://hist.readthedocs.io/en/latest/) +- [haliax](https://github.com/stanford-crfm/haliax) +- [Tensor Considered Harmful](https://nlp.seas.harvard.edu/NamedTensor.html) +- [PyTorch Named Tensors](https://pytorch.org/docs/stable/name_inference.html#name-inference-reference-doc) +- [Penzai Named Axis](https://penzai.readthedocs.io/en/stable/notebooks/named_axes.html) +- [xarray Named Axis](https://docs.xarray.dev/en/stable/user-guide/indexing.html#) + +Named axes in Awkward Array are inspired primarily by `hist` and `PyTorch Named Tensors`. + ++++ + +How to (de-)attach named axes? +------------------------- + +Named axes can be attached to an array using the high-level {func}`ak.with_named_axis` function. +Awkward Array allows strings as named axes and integers as positional axes. + +The `named_axis` argument of {func}`ak.with_named_axis` accepts either a `tuple` or `dict`: +- `tuple`: + - `named axis`: item + - `positional axis`: index of the item + - _additional_: `None` represents a wildcard for not specifying a name, e.g.: `("x", None)` means that the first axis is named "x" and the second is not named. +- `dict`: + - `named axis`: key + - `positional axis`: value + - _additional_: not specifying a name is not allowed, e.g.: `{"x": 0}` means that the first axis is named "x", all other existing dimensions are unnamed. The `dict` option also allows for renaming negative axes, e.g.: `{"x": -1}` means that the last axis is named "x". + + +```{code-cell} +import awkward as ak +import numpy as np +``` + +The axis names of an array can be attached through the constructor: +```{code-cell} +named_array = ak.Array([[1, 2], [3], [], [4, 5, 6]], named_axis=("x", "y")) +# or +named_array = ak.Array([[1, 2], [3], [], [4, 5, 6]], named_axis={"x": 0, "y": 1}) +``` + +... or through `ak.with_named_axis`: +```{code-cell} +array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) +named_array = ak.with_named_axis(array, named_axis=("x", "y")) +# or +named_array = ak.with_named_axis(array, named_axis={"x": 0, "y": 1}) +``` + +After attaching named axes, you can see the named axes comma-separated in the arrays representation and in `.show(named_axis=True)`: + +```{code-cell} +ak.Array([[1, 2], [3], [], [4, 5, 6]], named_axis=("x", "y")) +``` + +```{code-cell} +ak.Array([[1, 2], [3], [], [4, 5, 6]], named_axis=("x", "y")).show(named_axis=True) +``` + +Accessing the named axis mapping to positional axis can be done using the `named_axis` and `positional_axis` properties: + +```{code-cell} +named_array.named_axis +``` + +```{code-cell} +named_array.positional_axis +``` + +If you want to remove the named axes from an array, you can use the {func}`ak.without_named_axis` function: + +```{code-cell} +array = ak.without_named_axis(named_array) +array.named_axis +``` + + +Indexing with Named Axes +------------------------ + +Named axes can be used for indexing operations. +This is enabled throuhg a special syntax that allows you to index with a dictionary where keys refer to named (or positional) axes and the values to the slice or index. + +Simple examples: + +```{code-cell} +array = ak.Array([[[1, 2]], [[3]], [[4]], [[5, 6], [7]]]) +named_array = ak.with_named_axis(array, named_axis=("x", "y", "z")) + +# named axes +named_array[{"x": 0}] # array[0, :, :] +named_array[{"z": 0}] # array[:, :, 0] + +named_array[{"x": 0, "y": 0}] # array[0, 0, :] +named_array[{"x": slice(0, 1), "y": 0}] # array[0:1, 0, :] + +named_array[named_array > 3] # array[array > 3] + + +# positional axes +named_array[{0: 0}] # array[0, :, :] +named_array[{2: 0}] # array[:, :, 0] + +named_array[{-3: 0}] # array[0, :, :] +named_array[{-1: 0}] # array[:, :, 0] +None +``` + +If multiple keys that point to the same positional axis are used, the last key will be used and all others will be ignored: + +```{code-cell} +array = ak.Array([[[1, 2]], [[3]], [[4]], [[5, 6], [7]]]) +named_array = ak.with_named_axis(array, named_axis=("x", "y", "z")) + +assert ak.all(named_array[{0: 0, "x": slice(0, 2)}] == named_array[0:2]) +assert ak.all(named_array[{"x": slice(0, 2), 0: 0}] == named_array[0]) +``` + + +More detailed example: + +```{code-cell} +# create a Record Array that represents four events with a variable number of jets +events = ak.zip({ + "event_no": np.arange(4), + "jetpt": ak.Array([[50, 60], [45], [], [80, 30, 50]]), +}) +named_events = ak.with_named_axis(events, ("events", "jets")) + +print("classic indexing:", named_events[0, 0:1]) +print("named indexing :", named_events[{"events": 0, "jets": slice(0, 1)}]) +``` + +For syntatic suger, use `np.s_` to define slices more easily: + +```{code-cell} +array = ak.Array([[[1, 2]], [[3]], [[4]], [[5, 6], [7]]]) +named_array = ak.with_named_axis(array, named_axis=("x", "y", "z")) + +assert ak.all(named_array[{"x": np.s_[0:2]}] == named_array[{"x": slice(0, 2)}]) +``` + +Highlevel Operations with Named Axes +------------------------------------ + +Named axes can be used for specifying the axis of a highlevel operation given that the operation is performed on an array that supports this named axis. + +For example, the `ak.sum` operation can be performed on an array with named axes: + +```{code-cell} +array = ak.Array([[[1, 2]], [[3]], [[4]], [[5, 6], [7]]]) +named_array = ak.with_named_axis(array, named_axis=("x", "y", "z")) + +print("Sum over axis 'x':", ak.sum(named_array, axis="x")) # ak.sum(array, axis=0) +print("Sum over axis 'y':", ak.sum(named_array, axis="y")) # ak.sum(array, axis=1) +print("Sum over axis 'z':", ak.sum(named_array, axis="z")) # ak.sum(array, axis=2) +``` + + +Named Axes Propagation Strategies +--------------------------------- + + +Named axes are propagated through all operations in Awkward Array. +For this, specific strategies are defined for each operation to ensure that the named axes are properly propagated to the result. + +The possible strategies are: +- `keep all`: keep all named axes +- `keep one`: keep one named axis +- `keep up to`: keep all named axes up to a certain positional axis +- `remove all`: remove all named axis +- `remove one`: remove one named axis +- `add one`: add a new axis +- `unify`: unify named axes of two arrays. The named axes are unifiable if the have the same name (or `None`) and point to the same positional axis. + +Indexing operations +: The following table shows the strategy for indexing operations: + +| Operation | Strategy | +|----------------------|--------------| +| `array[:]` | `keep all` | +| `array[...]` | `keep all` | +| `array[()]` | `keep all` | +| `array[0:1]` | `keep all` | +| `array[[0, 1]]` | `keep all` | +| `array[array % 2]` | `keep all` | +| `array[0]` | `remove one` | +| `array[np.array(0)]` | `remove one` | +| `array[None]` | `add one` | +| `array[np.newaxis]` | `add one` | + +Universal functions (`ufuncs`) +: `ufuncs` with single argument signatures (i.e. unary operations, such as `__abs__`, `__neg__`, `__invert__`, ...) do not modify named axes (strategy: `keep all`). +: `ufuncs` with two argument signatures (i.e. binary operations, such as `__add__`, `__sub__`, `__mul__`, ...) try to merge named axis of the given arrays (strategy: `unify`). + This means that the named axes of the two arrays are merged if they have the same name (or either is `None`) and point to the same positional axis. + If there's a mismatch of named axes, e.g., the same named axis has different names or point to different positional axes, an exception is raised. + +```{code-cell} +array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) +named_array = ak.with_named_axis(array, named_axis=("x", "y")) + +# unary operations with named axes +assert (-named_array).named_axis == {"x": 0, "y": 1} +assert (+named_array).named_axis == {"x": 0, "y": 1} +assert (~named_array).named_axis == {"x": 0, "y": 1} +assert abs(named_array).named_axis == {"x": 0, "y": 1} + +# binary operations with named axes +named_array1 = ak.with_named_axis(array, named_axis=(None, "y")) +named_array2 = ak.with_named_axis(array, named_axis=("x", None)) +named_array3 = ak.with_named_axis(array, named_axis=("x", "y")) + +assert (array + array).named_axis == {} +assert (named_array1 + array).named_axis == {"y": 1} +assert (named_array2 + array).named_axis == {"x": 0} +assert (named_array3 + array).named_axis == {"x": 0, "y": 1} + +assert (named_array1 + named_array2).named_axis == {"x": 0, "y": 1} +assert (named_array3 + named_array3).named_axis == {"x": 0, "y": 1} +``` + +Reducers (`ak.sum`, `ak.any`, ...) +: If `axis=int` and `keepdims=False` (typical use-case) removes the named axis that is reduced (strategy: `remove one`). +: If `keepdims=True` is set, the named axis is kept (strategy: `keep all`). +: If `axis=None` is set, all named axes are removed (strategy: `remove all`). + +```{code-cell} +array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) +named_array = ak.with_named_axis(array, ("x", "y")) + +assert ak.sum(named_array, axis="x", keepdims=False).named_axis == {"y": 0} +assert ak.sum(named_array, axis="x", keepdims=True).named_axis == {"x": 0, "y": 1} +``` + +--- +A full list of operations and their strategies can be found in the following table. +If an operation is not listed, the strategy is either `keep all` or automatically inferred from the below listed operations. + + +| Operation | Strategy | +|-----------------------------------------------------|--------------------| +| `ak.all(..., axis=None)` | `remove all` | +| `ak.all(..., axis=int, keepdims=False)` | `remove one` | +| `ak.all(..., axis=int, keepdims=True)` | `keep all` | +| `ak.any(..., axis=None)` | `remove all` | +| `ak.any(..., axis=int, keepdims=False)` | `remove one` | +| `ak.any(..., axis=int, keepdims=True)` | `keep all` | +| `ak.[arg]cartesian` | `unify` | +| `ak.[arg]combinations` | `keep all` | +| `ak.[arg]max(..., axis=None)` | `remove all` | +| `ak.[arg]max(..., axis=int, keepdims=False)` | `remove one` | +| `ak.[arg]max(..., axis=int, keepdims=True)` | `keep all` | +| `ak.[arg]min(..., axis=None)` | `remove all` | +| `ak.[arg]min(..., axis=int, keepdims=False)` | `remove one` | +| `ak.[arg]min(..., axis=int, keepdims=True)` | `keep all` | +| `ak.[arg]sort` | `keep all` | +| `ak.broadcast_arrays` | `unify`, `add one` | +| `ak.broadcast_fields` | `unify`, `add one` | +| `ak.categories` | `remove all` | +| `ak.concatenate` | `unify` | +| `ak.count[_nonzero](..., axis=None)` | `remove all` | +| `ak.count[_nonzero](..., axis=int, keepdims=False)` | `remove one` | +| `ak.count[_nonzero](..., axis=int, keepdims=True)` | `keep all` | +| `ak.firsts` | `remove one` | +| `ak.flatten(..., axis=None)` | `remove all` | +| `ak.flatten(..., axis=0)` | `keep all` | +| `ak.flatten(..., axis=(!=0), keepdims=True)` | `remove one` | +| `ak.local_index` | `keep up to` | +| `ak.num` | `keep one` | +| `ak.prod(..., axis=None)` | `remove all` | +| `ak.prod(..., axis=int, keepdims=False)` | `remove one` | +| `ak.prod(..., axis=int, keepdims=True)` | `keep all` | +| `ak.ravel` | `remove all` | +| `ak.singletons` | `add one` | +| `ak.sum(..., axis=None)` | `remove all` | +| `ak.sum(..., axis=int, keepdims=False)` | `remove one` | +| `ak.sum(..., axis=int, keepdims=True)` | `keep all` | +| `ak.unflatten` | `remove all` | +| `ak.where` | `unify`, `add one` | +| `ak.with_field` | `unify`, `add one` | +| `ak.zip` | `unify`, `add one` | diff --git a/docs/user-guide/how-to-array-properties.md b/docs/user-guide/how-to-array-properties.md new file mode 100644 index 0000000000..be811e888e --- /dev/null +++ b/docs/user-guide/how-to-array-properties.md @@ -0,0 +1,23 @@ +--- +jupytext: + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.10.3 +kernelspec: + display_name: Python 3 + language: python + name: python3 +--- + +Array properties +================ + +The user guide is a collection of "how to..." guides for common tasks. See the left side-bar (or bring it into view by clicking on the upper-left `≡`) to access the guides, grouped by topic. + +If you're looking for documentation on a specific function, see the API reference instead. + +You can test any examples in a new window/tab by clicking on [![Try It! ⭷](https://img.shields.io/badge/-Try%20It%21%20%E2%86%97-orange?style=for-the-badge)](https://awkward-array.org/doc/main/_static/try-it.html). + +




diff --git a/pyproject.toml b/pyproject.toml index f5d1424359..f10e745671 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -232,6 +232,7 @@ ignore_missing_imports = true [[tool.mypy.overrides]] module = [ 'awkward._nplikes.*', + 'awkward._namedaxis', 'awkward._behavior.*', 'awkward._backends.*', 'awkward._meta.*', diff --git a/src/awkward/__init__.py b/src/awkward/__init__.py index c82e83777f..c84b655ccb 100644 --- a/src/awkward/__init__.py +++ b/src/awkward/__init__.py @@ -24,6 +24,7 @@ import awkward._errors import awkward._lookup import awkward._ext # strictly for unpickling from Awkward 1 +import awkward._namedaxis # third-party connectors import awkward._connect.numpy diff --git a/src/awkward/_broadcasting.py b/src/awkward/_broadcasting.py index 7eb2300372..8dab0af30e 100644 --- a/src/awkward/_broadcasting.py +++ b/src/awkward/_broadcasting.py @@ -11,6 +11,11 @@ import awkward as ak from awkward._backends.backend import Backend from awkward._backends.dispatch import backend_of +from awkward._namedaxis import ( + NAMED_AXIS_KEY, + _add_named_axis, + _unify_named_axis, +) from awkward._nplikes.numpy import Numpy from awkward._nplikes.numpy_like import NumpyMetadata from awkward._nplikes.shape import ShapeItem, unknown_length @@ -319,10 +324,18 @@ def is_string_like(obj) -> bool: } -def left_broadcast_to(content: Content, depth: int) -> Content: - for _ in range(content.purelist_depth, depth): - content = RegularArray(content, 1, content.length) - return content +def _export_named_axis_from_depth_to_lateral( + idx: int, + depth_context: dict[str, Any], + lateral_context: dict[str, Any], +) -> None: + # set adjusted named axes to lateral (inplace) + named_axis, ndim = depth_context[NAMED_AXIS_KEY][idx] + seen_named_axis, _ = lateral_context[NAMED_AXIS_KEY][idx] + lateral_context[NAMED_AXIS_KEY][idx] = ( + _unify_named_axis(named_axis, seen_named_axis), + ndim, + ) def broadcast_regular_dim_size(contents: Sequence[ak.contents.Content]) -> ShapeItem: @@ -433,10 +446,32 @@ def apply_step( max_depth = max(x.purelist_depth for x in contents) if max_depth > 0 and all(x.purelist_isregular for x in contents): - nextinputs = [ - left_broadcast_to(o, max_depth) if isinstance(o, Content) else o - for o in inputs - ] + nextinputs = [] + + named_axes_with_ndims = depth_context[NAMED_AXIS_KEY] + seen_named_axes = lateral_context[NAMED_AXIS_KEY] + for i, ((named_axis, ndim), o) in enumerate( + zip(named_axes_with_ndims, inputs) + ): + if isinstance(o, Content): + # rightbroadcast + for _ in range(o.purelist_depth, max_depth): + o = RegularArray(o, 1, o.length) + # track new dimensions for named axis + # rightbroadcasting adds a new first(!) dimension at depth + seen_named_axis, seen_ndim = seen_named_axes[i] + named_axis = _add_named_axis(named_axis, depth, ndim) + depth_context[NAMED_AXIS_KEY][i] = ( + _unify_named_axis(named_axis, seen_named_axis), + ndim + 1, + ) + if o.is_leaf: + _export_named_axis_from_depth_to_lateral( + i, depth_context, lateral_context + ) + nextinputs.append(o) + else: + nextinputs.append(o) # Did a broadcast take place? if any(x is not y for x, y in zip(inputs, nextinputs)): return apply_step( @@ -538,6 +573,7 @@ def broadcast_any_list(): # Under the category of "is_list", we have both strings and non-strings # The strings should behave like non-lists within these routines. + named_axes_with_ndims = depth_context[NAMED_AXIS_KEY] # Are the non-string list types exclusively regular? if all(x.is_regular or (is_string_like(x) or not x.is_list) for x in contents): # Compute the expected dim size @@ -586,7 +622,9 @@ def broadcast_any_list(): # we don't left-broadcast nextinputs = [] nextparameters = [] - for x, x_is_string in zip(inputs, inputs_are_strings): + for i, ((named_axis, ndim), x, x_is_string) in enumerate( + zip(named_axes_with_ndims, inputs, inputs_are_strings) + ): if isinstance(x, RegularArray) and not x_is_string: content_size_maybe_one = ( x.size is not unknown_length and x.size == 1 @@ -603,6 +641,16 @@ def broadcast_any_list(): ) ) nextparameters.append(x._parameters) + # track new dimensions for named axis + # rightbroadcasting adds a new first(!) dimension as depth + depth_context[NAMED_AXIS_KEY][i] = ( + _add_named_axis(named_axis, depth, ndim), + ndim + 1, + ) + if x.is_leaf: + _export_named_axis_from_depth_to_lateral( + i, depth_context, lateral_context + ) # Any unknown values or sizes are assumed to be correct as-is elif ( dim_size is unknown_length @@ -667,7 +715,9 @@ def broadcast_any_list(): nextinputs = [] nextparameters = [] - for x, x_is_string in zip(inputs, input_is_string): + for i, ((named_axis, ndim), x, x_is_string) in enumerate( + zip(named_axes_with_ndims, inputs, input_is_string) + ): if isinstance(x, listtypes) and not x_is_string: next_content = broadcast_to_offsets_avoiding_carry(x, offsets) nextinputs.append(next_content) @@ -680,6 +730,16 @@ def broadcast_any_list(): .content ) nextparameters.append(NO_PARAMETERS) + # track new dimensions for named axis + # leftbroadcasting adds a new last dimension at depth + 1 + depth_context[NAMED_AXIS_KEY][i] = ( + _add_named_axis(named_axis, depth + 1, ndim), + ndim + 1, + ) + if x.is_leaf: + _export_named_axis_from_depth_to_lateral( + i, depth_context, lateral_context + ) else: nextinputs.append(x) nextparameters.append(NO_PARAMETERS) @@ -889,7 +949,7 @@ def action_logical_or(inputs, backend, **kwargs): (xy_mask, cond_mask), action_logical_or, 0, - None, + depth_context, lateral_context, simple_options, )[0] @@ -917,7 +977,7 @@ def apply_mask_action(inputs, backend, **kwargs): (xy_unmasked, mask), apply_mask_action, 0, - None, + depth_context, lateral_context, simple_options, ) diff --git a/src/awkward/_connect/numexpr.py b/src/awkward/_connect/numexpr.py index 85ab566c8c..50c3e0485f 100644 --- a/src/awkward/_connect/numexpr.py +++ b/src/awkward/_connect/numexpr.py @@ -4,12 +4,15 @@ import sys import warnings +from functools import reduce from packaging.version import parse as parse_version import awkward as ak -from awkward._behavior import behavior_of +from awkward._attrs import attrs_of_obj +from awkward._behavior import behavior_of, behavior_of_obj from awkward._layout import wrap_layout +from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis _has_checked_version = False @@ -110,9 +113,26 @@ def action(inputs, **ignore): return None behavior = behavior_of(*arrays) - out = ak._broadcasting.broadcast_and_apply(arrays, action, allow_records=False) + depth_context, lateral_context = NamedAxesWithDims.prepare_contexts(arguments) + out = ak._broadcasting.broadcast_and_apply( + arrays, + action, + depth_context=depth_context, + lateral_context=lateral_context, + allow_records=False, + ) assert isinstance(out, tuple) and len(out) == 1 - return wrap_layout(out[0], behavior) + wrapped = wrap_layout(out[0], behavior) + out_named_axis = reduce( + _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis + ) + return ak.operations.ak_with_named_axis._impl( + wrapped, + named_axis=out_named_axis, + highlevel=True, + behavior=behavior_of_obj(wrapped), + attrs=attrs_of_obj(wrapped), + ) evaluate.evaluate = evaluate @@ -148,6 +168,24 @@ def action(inputs, **ignore): return None behavior = behavior_of(*arrays) - out = ak._broadcasting.broadcast_and_apply(arrays, action, allow_records=False) + + depth_context, lateral_context = NamedAxesWithDims.prepare_contexts(arguments) + out = ak._broadcasting.broadcast_and_apply( + arrays, + action, + depth_context=depth_context, + lateral_context=lateral_context, + allow_records=False, + ) assert isinstance(out, tuple) and len(out) == 1 - return wrap_layout(out[0], behavior) + wrapped = wrap_layout(out[0], behavior) + out_named_axis = reduce( + _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis + ) + return ak.operations.ak_with_named_axis._impl( + wrapped, + named_axis=out_named_axis, + highlevel=True, + behavior=behavior_of_obj(wrapped), + attrs=attrs_of_obj(wrapped), + ) diff --git a/src/awkward/_connect/numpy.py b/src/awkward/_connect/numpy.py index f17ee98b36..7f5a7cdb08 100644 --- a/src/awkward/_connect/numpy.py +++ b/src/awkward/_connect/numpy.py @@ -22,6 +22,7 @@ ) from awkward._categorical import as_hashable from awkward._layout import wrap_layout +from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis from awkward._nplikes import to_nplike from awkward._parameters import parameters_intersect from awkward._regularize import is_non_string_like_iterable @@ -363,6 +364,8 @@ def array_ufunc(ufunc, method: str, inputs, kwargs: dict[str, Any]): attrs = attrs_of(*inputs) backend = backend_of(*inputs, coerce_to_common=True) + depth_context, lateral_context = NamedAxesWithDims.prepare_contexts(inputs) + inputs = _array_ufunc_custom_cast(inputs, behavior, backend) def action(inputs, **ignore): @@ -464,13 +467,40 @@ def action(inputs, **ignore): return None out = ak._broadcasting.broadcast_and_apply( - inputs, action, allow_records=False, function_name=ufunc.__name__ + inputs, + action, + depth_context=depth_context, + lateral_context=lateral_context, + allow_records=False, + function_name=ufunc.__name__, ) + out_named_axis = functools.reduce( + _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis + ) if len(out) == 1: - return wrap_layout(out[0], behavior=behavior, attrs=attrs) + wrapped = wrap_layout(out[0], behavior=behavior, attrs=attrs) + return ak.operations.ak_with_named_axis._impl( + wrapped, + named_axis=out_named_axis, + highlevel=True, + behavior=None, + attrs=None, + ) else: - return tuple(wrap_layout(o, behavior=behavior, attrs=attrs) for o in out) + wrapped_out = [] + for o in out: + wrapped = wrap_layout(o, behavior=behavior, attrs=attrs) + wrapped_out.append( + ak.operations.ak_with_named_axis._impl( + wrapped, + named_axis=out_named_axis, + highlevel=True, + behavior=None, + attrs=None, + ) + ) + return tuple(wrapped_out) def action_for_matmul(inputs): diff --git a/src/awkward/_layout.py b/src/awkward/_layout.py index 11cb4bcbe5..f642dd6ed4 100644 --- a/src/awkward/_layout.py +++ b/src/awkward/_layout.py @@ -56,9 +56,7 @@ def merge_mappings( class HighLevelContext: - def __init__( - self, behavior: Mapping | None = None, attrs: Mapping[str, Any] | None = None - ): + def __init__(self, behavior: Mapping | None = None, attrs: Mapping | None = None): self._behavior = behavior self._attrs = attrs self._is_finalized = False @@ -66,6 +64,22 @@ def __init__( self._attrs_from_objects = [] self._behavior_from_objects = [] + def with_attr(self, key, value) -> Self: + self._ensure_finalized() + return type(self)( + behavior=self.behavior, + attrs={**self.attrs, key: value}, + ).finalize() + + def without_attr(self, key) -> Self: + self._ensure_finalized() + attrs = dict(self.attrs) + attrs.pop(key, None) + return type(self)( + behavior=self.behavior, + attrs=attrs, + ).finalize() + def __enter__(self): return self @@ -81,8 +95,10 @@ def _ensure_not_finalized(self): raise RuntimeError("HighLevelContext has already been finalized") @property - def attrs(self) -> Mapping[str, Any] | None: + def attrs(self) -> Mapping: self._ensure_finalized() + if self._attrs is None: + self._attrs = {} return self._attrs @property @@ -154,7 +170,11 @@ def unwrap( ) def wrap( - self, obj: Any, *, highlevel: bool = True, allow_other: bool = False + self, + obj: Any, + *, + highlevel: bool = True, + allow_other: bool = False, ) -> Any: self._ensure_finalized() @@ -230,7 +250,7 @@ def maybe_highlevel_to_lowlevel(obj): Args: obj: an object - Calls #ak.to_layout and returns the result iff. the object is a high-level + Calls #ak.to_layout and returns the result if the object is a high-level Awkward object, otherwise the object is returned as-is. This function should be removed once scalars are properly handled by `to_layout`. @@ -372,6 +392,34 @@ def attach(x): return layout +def _neg2pos_axis( + axis: int, + total: int, +) -> int: + """ + Converts a negative axis index to a positive one. + + This function takes a negative axis index and the total number of axes and returns the corresponding positive axis index. + If the input axis index is already positive, it is returned as is. + + Args: + axis (int): The axis index to convert. Can be negative. + total (int): The total number of axes. + + Returns: + int: The positive axis index corresponding to the input axis index. + + Examples: + >>> _neg2pos_axis(-1, 3) + 2 + >>> _neg2pos_axis(1, 3) + 1 + """ + if axis < 0: + return total + axis + return axis + + def maybe_posaxis(layout: Content, axis: int, depth: int) -> int | None: from awkward.record import Record @@ -386,6 +434,6 @@ def maybe_posaxis(layout: Content, axis: int, depth: int) -> int | None: else: is_branching, additional_depth = layout.branch_depth if not is_branching: - return axis + depth + additional_depth - 1 + return _neg2pos_axis(axis, additional_depth) + depth - 1 else: return None diff --git a/src/awkward/_namedaxis.py b/src/awkward/_namedaxis.py new file mode 100644 index 0000000000..9f0b8d36cc --- /dev/null +++ b/src/awkward/_namedaxis.py @@ -0,0 +1,764 @@ +from __future__ import annotations + +import json +import re +from dataclasses import dataclass + +import awkward._typing as tp +from awkward._layout import _neg2pos_axis +from awkward._regularize import is_integer + +# axis names are hashables, mostly strings, +# except for integers, which are reserved for positional axis. +AxisName: tp.TypeAlias = tp.Hashable + +# e.g.: {"x": 0, "y": 1, "z": 2} +AxisMapping: tp.TypeAlias = tp.Mapping[AxisName, int] + +# e.g.: ("x", "y", None) where None is a wildcard +AxisTuple: tp.TypeAlias = tp.Tuple[AxisName, ...] + + +NAMED_AXIS_KEY: tp.Literal["__named_axis__"] = ( + "__named_axis__" # reserved for named axis +) + + +# just a class for inplace mutation +class NamedAxis: + mapping: AxisMapping + + +NamedAxis.mapping = {} + + +def _prettify_named_axes( + named_axis: AxisMapping, + delimiter: str = ", ", + maxlen: None | int = None, +) -> str: + """ + This function takes a named axis mapping and returns a string representation of the mapping. + The axis names are sorted in ascending order of their corresponding integer values. + If the axis name is a valid Python identifier, it is represented as is. + Otherwise, it is represented as a JSON string. + + Args: + named_axis (AxisMapping): The named axis mapping to prettify. + delimiter (str, optional): The delimiter to use between items in the output string. Defaults to ", ". + maxlen (None | int, optional): The maximum length of the output string. If the string exceeds this length, it is truncated and ends with "...". Defaults to None. + + Returns: + str: The prettified string representation of the named axis mapping. + + Examples: + >>> _prettify_named_axes({"x": 0, "y": 1, "z": 2}) + 'x:0, y:1, z:2' + >>> _prettify_named_axes({"x": 0, "y": 1, "$": 2}) + 'x:0, y:1, "$":2' + >>> _prettify_named_axes({"x": 0, "y": 1, "z": 2}, delimiter="; ") + 'x:0; y:1; z:2' + >>> _prettify_named_axes({"foo": 0, "bar": 1, "baz": 2}, maxlen=17) + 'foo:0, bar:1, ...' + """ + + def _prettify(ax: AxisName) -> str: + repr_ax = str(ax) + if re.match("[A-Za-z_][A-Za-z_0-9]*", repr_ax): + return repr_ax + return json.dumps(repr_ax) + + sorted_named_axis = sorted(named_axis.items(), key=lambda x: x[1]) + items = [ + f"{_prettify(named_ax)}:{pos_ax}" for named_ax, pos_ax in sorted_named_axis + ] + if maxlen is not None: + if len(delimiter.join(items)) > maxlen: + while ( + len(delimiter.join(items)) > maxlen - len(delimiter + "...") + ) and items: + items.pop(-1) + items.append("...") + return delimiter.join(items) + + +def _get_named_axis(ctx: tp.Any) -> AxisMapping: + """ + Retrieves the named axis from the provided context. + + Args: + ctx (Any): The context from which the named axis is to be retrieved. + + Returns: + AxisMapping: The named axis retrieved from the context. If the context does not include a named axis, + an empty dictionary is returned. + + Examples: + >>> _get_named_axis(ak.Array([1, 2, 3], named_axis={"x": 0})) + {"x": 0} + >>> _get_named_axis(np.array([1, 2, 3])) + {} + >>> _get_named_axis({NAMED_AXIS_KEY: {"x": 0, "y": 1, "z": 2}}) + {"x": 0, "y": 1, "z": 2} + >>> _get_named_axis({"other_key": "other_value"}) + {} + """ + if hasattr(ctx, "attrs"): + return _get_named_axis(ctx.attrs) + elif isinstance(ctx, tp.Mapping) and NAMED_AXIS_KEY in ctx: + return dict(ctx[NAMED_AXIS_KEY]) + else: + return {} + + +def _make_positional_axis_tuple(n: int) -> tuple[int, ...]: + """ + Generates a positional axis tuple of length n. + + Args: + n (int): The length of the positional axis tuple to generate. + + Returns: + tuple[int, ...]: The generated positional axis tuple. + + Examples: + >>> _make_positional_axis_tuple(3) + (0, 1, 2) + """ + return tuple(range(n)) + + +def _is_valid_named_axis(axis: AxisName) -> bool: + """ + Checks if the given axis is a valid named axis. A valid named axis is a hashable object that is not an integer or None. Currently it is restricted to strings. + + Args: + axis (AxisName): The axis to check. + + Returns: + bool: True if the axis is a valid named axis, False otherwise. + + Examples: + >>> _is_valid_named_axis("x") + True + >>> _is_valid_named_axis(1) + False + """ + return ( + # axis must be hashable + isinstance(axis, AxisName) + # ... but not an integer, otherwise we would confuse it with positional axis + and not is_integer(axis) + # we also prohibit None, which is reserved for wildcard + and axis is not None + # Let's only allow strings for now, in the future we can open up to more types + # by removing the isinstance(axis, str) check. + and isinstance(axis, str) + ) + + +def _check_valid_axis(axis: AxisName) -> AxisName: + """ + Checks if the given axis is a valid named axis. If not, raises a ValueError. + + Args: + axis (AxisName): The axis to check. + + Returns: + AxisName: The axis if it is a valid named axis. + + Raises: + ValueError: If the axis is not a valid named axis. + + Examples: + >>> _check_valid_axis("x") + "x" + >>> _check_valid_axis(1) + Traceback (most recent call last): + ... + ValueError: Axis names must be hashable and not int, got 1 [type(axis)=] + """ + if not _is_valid_named_axis(axis): + raise ValueError( + f"Axis names must be hashable and not int, got {axis!r} [{type(axis)=}]" + ) + return axis + + +def _check_valid_named_axis_mapping(named_axis: AxisMapping) -> AxisMapping: + """ + Checks if the given named axis mapping is valid. A valid named axis mapping is a dictionary where the keys are valid named axes + (hashable objects that are not integers) and the values are integers. + + Args: + named_axis (AxisMapping): The named axis mapping to check. + + Raises: + ValueError: If any of the keys in the named axis mapping is not a valid named axis or if any of the values is not an integer. + + Examples: + >>> _check_valid_named_axis_mapping({"x": 0, "y": 1, "z": 2}) # No exception is raised + >>> _check_valid_named_axis_mapping({"x": 0, "y": 1, "z": "2"}) + Traceback (most recent call last): + ... + ValueError: Named axis mapping values must be integers, got '2' [type(axis)=] + >>> _check_valid_named_axis_mapping({"x": 0, 1: 1, "z": 2}) + Traceback (most recent call last): + ... + ValueError: Axis names must be hashable and not int, got 1 [type(axis)=] + """ + for name, axis in named_axis.items(): + _check_valid_axis(name) + if not is_integer(axis): + raise ValueError( + f"Named axis mapping values must be integers, got {axis!r} [{type(axis)=}]" + ) + return named_axis + + +def _axis_tuple_to_mapping(axis_tuple: AxisTuple) -> AxisMapping: + """ + Converts a tuple of axis names to a dictionary mapping axis names to their positions. + + Args: + axis_tuple (AxisTuple): A tuple of axis names. Can include None as a wildcard. + + Returns: + AxisMapping: A dictionary mapping axis names to their positions. + + Examples: + >>> _axis_tuple_to_mapping(("x", None, "y")) + {"x": 0, "y": 2} + """ + return {axis: i for i, axis in enumerate(axis_tuple) if axis is not None} + + +def _prepare_named_axis_for_attrs( + named_axis: AxisMapping | AxisTuple, + ndim: int, +) -> AxisMapping: + """ + Prepares the named axis for attribute assignment. + + This function takes a named axis, which can either be a mapping or a tuple, and returns a dictionary mapping axis names to their positions. + The function checks if the named axis is valid and if the positional axes match the number of dimensions. If not, an error is raised. + + Args: + named_axis (AxisMapping | AxisTuple): The named axis to prepare. Can either be a mapping or a tuple. + ndim (int): The number of dimensions. + + Returns: + AxisMapping: The prepared named axis. + + Raises: + TypeError: If the named axis is not a mapping or a tuple. + ValueError: If the named axes do not point to positional axes matching the number of dimensions. + + Examples: + >>> _prepare_named_axis_for_attrs({"x": 0, "y": 1, "z": 2}, 3) + {"x": 0, "y": 1, "z": 2} + >>> _prepare_named_axis_for_attrs(("x", "y", "z"), 3) + {"x": 0, "y": 1, "z": 2} + >>> _prepare_named_axis_for_attrs({"x": 0, "y": 1, "z": 2}, 2) + Traceback (most recent call last): + ... + ValueError: Named axes must point to positional axes matching 2 dimensions, got named_axis={"x": 0, "y": 1, "z": 2}, ndim=2 + """ + if isinstance(named_axis, tuple): + _named_axis = _axis_tuple_to_mapping(named_axis) + elif isinstance(named_axis, dict): + _named_axis = named_axis + else: + raise TypeError( + f"named_axis must be a mapping or a tuple, got {named_axis=} [{type(named_axis)=}]" + ) + _check_valid_named_axis_mapping(_named_axis) + pos_axes = set(_named_axis.values()) + if max(pos_axes, default=0) >= ndim or min(pos_axes, default=0) < -ndim: + raise ValueError( + f"Named axes must point to positional axes matching {ndim} dimensions, got {named_axis=}, {ndim=}" + ) + return _named_axis + + +def _make_named_int_class(name: tp.Any) -> type[int]: + class NamedInt(int): + def __repr__(self): + value_repr = super().__repr__() + return f"{name!r} (named axis) -> {value_repr} (pos. axis)" + + __str__ = __repr__ + + return NamedInt + + +def _named_axis_to_positional_axis( + named_axis: AxisMapping, + axis: AxisName, +) -> int | None: + """ + Converts a single named axis to a positional axis. + + Args: + axis (AxisName): The named axis to convert. + named_axis (AxisMapping): The mapping from named axes to positional axes. + + Returns: + int | None: The positional axis corresponding to the given named axis. If the named axis is not found in the mapping, returns None. + + Raises: + ValueError: If the named axis is not found in the named axis mapping. + + Examples: + >>> _named_axis_to_positional_axis({"x": 0, "y": 1, "z": 2}, "x") + 0 + """ + if _is_valid_named_axis(axis): + if axis not in named_axis: + raise ValueError(f"{axis=} not found in {named_axis=} mapping.") + + # we wrap it to preserve the original name in its __repr__ and __str__ + # in order to properly display it in error messages. This is useful for cases + # where the positional axis is pointing to a non-existing axis. The error message + # will then show the original (named) axis together with the positional axis. + cls = _make_named_int_class(axis) + return cls(named_axis[axis]) + + if is_integer(axis): + # TODO: is_integer is an external helper function that doesn't specify types + return int(tp.cast(tp.Any, axis)) + elif axis is None: + return None + else: + raise ValueError(f"Invalid {axis=} [{type(axis)=}]") + + +# These are the strategies to handle named axis for the +# output array when performing operations along an axis. +# See studies/named_axis.md#named-axis-in-high-level-functions and +# https://pytorch.org/docs/stable/name_inference.html. +# +# The possible strategies are: +# - "keep all" (_keep_named_axis(..., None)): Keep all named axes in the output array, e.g.: `ak.drop_none` +# - "keep one" (_keep_named_axis(..., int)): Keep one named axes in the output array, e.g.: `ak.firsts` +# - "keep up to" (_keep_named_axis_up_to(..., int)): Keep all named axes up to a certain positional axis in the output array, e.g.: `ak.local_index` +# - "remove all" (_remove_all_named_axis): Removes all named axis, e.g.: `ak.categories` +# - "remove one" (_remove_named_axis): Remove the named axis from the output array, e.g.: `ak.sum` +# - "add one" (_add_named_axis): Add a new named axis to the output array, e.g.: `ak.concatenate` +# - "unify" (_unify_named_axis): Unify the named axis in the output array given two input arrays, e.g.: `ak.broadcast_arrays` + + +def _keep_named_axis( + named_axis: AxisMapping, + axis: int | None = None, +) -> AxisMapping: + """ + Determines the new named axis after keeping the specified axis. This function is useful when an operation + is applied that retains only one axis. + + Args: + named_axis (AxisMapping): The current named axis. + axis (int | None, optional): The index of the axis to keep. If None, all axes are kept. Default is None. + + Returns: + AxisMapping: The new named axis after keeping the specified axis. + + Examples: + >>> _keep_named_axis({"x": 0, "y": 1, "z": 2}, 1) + {"y": 0} + >>> _keep_named_axis({"x": 0, "y": 1, "z": 2}, None) + {"x": 0, "y": 1, "z": 2} + """ + if axis is None: + return named_axis + return {k: 0 for k, v in named_axis.items() if v == axis} + + +def _keep_named_axis_up_to( + named_axis: AxisMapping, + axis: int, + total: int, +) -> AxisMapping: + """ + Determines the new named axis after keeping all axes up to the specified axis. This function is useful when an operation + is applied that retains all axes up to a certain axis. + + Args: + named_axis (AxisMapping): The current named axis. + axis (int): The index of the axis up to which to keep. + total (int): The total number of axes. + + Returns: + AxisMapping: The new named axis after keeping all axes up to the specified axis. + + Examples: + >>> _keep_named_axis_up_to({"x": 0, "y": 1, "z": 2}, 1, 3) + {"x": 0, "y": 1} + >>> _keep_named_axis_up_to({"x": 0, "y": 1, "z": 2}, -1, 3) + {"x": 0, "y": 1, "z": 2} + >>> _keep_named_axis_up_to({"x": 0, "y": 1, "z": 2}, 0, 3) + {"x": 0} + """ + axis = _neg2pos_axis(axis, total) + out = {} + for k, v in named_axis.items(): + if v >= 0 and v <= axis: + out[k] = v + elif v < 0 and v >= -axis - 1: + out[k] = v + return out + + +def _remove_all_named_axis( + named_axis: AxisMapping, +) -> AxisMapping: + """ + Returns an empty named axis mapping after removing all axes from the given named axis mapping. + This function is typically used when an operation that eliminates all axes is applied. + + Args: + named_axis (AxisMapping): The current named axis mapping. + + Returns: + AxisMapping: An empty named axis mapping. + + Examples: + >>> _remove_all_named_axis({"x": 0, "y": 1, "z": 2}) + {} + """ + return _remove_named_axis(named_axis=named_axis, axis=None) + + +def _remove_named_axis( + named_axis: AxisMapping, + axis: int | None = None, + total: int | None = None, +) -> AxisMapping: + """ + Determines the new named axis after removing the specified axis. This is useful, for example, + when applying an operation that removes one axis. + + Args: + named_axis (AxisMapping): The current named axis. + axis (int | None, optional): The index of the axis to remove. If None, no axes are removed. Default is None. + total (int | None, optional): The total number of axes. If None, it is calculated as the length of the named axis. Default is None. + + Returns: + AxisMapping: The new named axis after removing the specified axis. + + Examples: + >>> _remove_named_axis({"x": 0, "y": 1}, None) + {} + >>> _remove_named_axis({"x": 0, "y": 1}, 0) + {"y": 0} + >>> _remove_named_axis({"x": 0, "y": 1, "z": 2}, 1) + {"x": 0, "z": 1} + >>> _remove_named_axis({"x": 0, "y": 1, "z": -1}, 1) + {"x": 0, "z": -1} + >>> _remove_named_axis({"x": 0, "y": 1, "z": -3}, 1) + {"x": 0, "z": -2} + """ + if axis is None: + return {} + + if total is None: + total = len(named_axis) + + # remove the specified axis + out = { + ax: pos + for ax, pos in named_axis.items() + if _neg2pos_axis(pos, total) != _neg2pos_axis(axis, total) + } + + return _adjust_pos_axis(out, axis, total, direction=-1) + + +def _adjust_pos_axis( + named_axis: AxisMapping, + axis: int, + total: int, + direction: int, +) -> AxisMapping: + """ + Adjusts the positions of the axes in the named axis mapping after an axis has been removed or added. + + Args: + named_axis (AxisMapping): The current named axis mapping. + axis (int): The position of the removed/added axis. + total (int): The total number of axes. + direction (int): The direction of the adjustment. -1 means axis is removed; +1 means axis is added. Default is +1. + + Returns: + AxisMapping: The adjusted named axis mapping. + + Examples: + # axis=1 removed + >>> _adjust_pos_axis({"x": 0, "z": 2}, 1, 3, -1) + {"x": 0, "z": 1} + # axis=1 added + >>> _adjust_pos_axis({"x": 0, "z": 2}, 1, 3, +1) + {"x": 0, "z": 3} + # axis=1 removed + >>> _adjust_pos_axis({"x": 0, "z": -1}, 1, 3, -1) + {"x": 0, "z": -1} + # axis=1 added + >>> _adjust_pos_axis({"x": 0, "z": -1}, 1, 3, +1) + {"x": 0, "z": -1} + """ + assert direction in (-1, +1), f"Invalid direction: {direction}" + + def _adjust(pos: int, axis: int, direction: int) -> int: + # positive axis + if axis >= 0: + # positive axis and position greater than or equal to the removed/added (positive) axis + # -> change position by direction + if pos >= axis: + return pos + direction + # positive axis and negative position + # -> change position by direction + elif pos < 0 and pos + total < axis: + return pos - direction + # positive axis and position smaller than the removed/added (positive) axis, but greater than 0 + # -> keep position + else: + return pos + # negative axis + else: + # negative axis and position smaller than the removed/added (negative) axis + # -> change position by inverse direction + if pos <= axis: + return pos - direction + # negative axis and positive position + # -> change position by inverse direction + elif pos > 0 and pos > axis + total: + return pos + direction + # negative axis and position greater than the removed/added (negative) axis, but smaller than 0 + # -> keep position + else: + return pos + + return {k: _adjust(v, axis, direction) for k, v in named_axis.items()} + + +def _add_named_axis( + named_axis: AxisMapping, + axis: int, + total: int | None = None, +) -> AxisMapping: + """ + Adds a new axis to the named_axis at the specified position. + + Args: + named_axis (AxisMapping): The current named axis mapping. + axis (int): The position at which to add the new axis. + total (int | None): The total number of axes. + + Returns: + AxisMapping: The updated named axis mapping after adding the new axis. + + Examples: + >>> _add_named_axis({"x": 0, "y": 1, "z": 2}, 0) + {"x": 1, "y": 2, "z": 3} + >>> _add_named_axis({"x": 0, "y": 1, "z": 2}, 1) + {"x": 0, "y": 2, "z": 3} + """ + if total is None: + total = len(named_axis) + + return _adjust_pos_axis(named_axis, axis, total, direction=+1) + + +def _unify_named_axis( + named_axis1: AxisMapping, + named_axis2: AxisMapping, +) -> AxisMapping: + """ + Unifies two named axes into a single named axis. The function iterates over all positional axes present in either of the input named axes. + For each positional axis, it checks the corresponding axis names in both input named axes. If the axis names are the same or if one of them is None, + the unified axis will be the non-None axis. If the axis names are different and neither of them is None, a ValueError is raised. + + Args: + named_axis1 (AxisMapping): The first named axis to unify. + named_axis2 (AxisMapping): The second named axis to unify. + + Returns: + AxisMapping: The unified named axis. + + Raises: + ValueError: If the axes are different and neither of them is None. + + Examples: + >>> _unify_named_axis({"x": 0, "y": 1, "z": 2}, {"x": 0, "y": 1, "z": 2}) + {"x": 0, "y": 1, "z": 2} + + >>> _unify_named_axis({"x": 0, "y": 1}, {"x": 0, "y": 1, "z": 2}) + {"x": 0, "y": 1, "z": 2} + + >>> _unify_named_axis({"x": 0, "y": 1, "z": 2}, {"a": 0, "b": 1, "c": 2}) + Traceback (most recent call last): + ... + ValueError: The named axes are different. Got: 'x' and 'a' for positional axis 0 + + >>> _unify_named_axis({"x": 0, "y": 1, "z": 2}, {"x": 0, "y": 1, "z": 3}) + {"x": 0, "y": 1, "z": 2} + + >>> _unify_named_axis({"x": 0, "y": 1, "z": 2}, {}) + {"x": 0, "y": 1, "z": 2} + + >>> _unify_named_axis({}, {"x": 0, "y": 1, "z": 2}) + {"x": 0, "y": 1, "z": 2} + + >>> _unify_named_axis({}, {}) + {} + """ + + def _get_axis_name( + axis_mapping: AxisMapping, positional_axis: int + ) -> AxisName | None: + for name, position in axis_mapping.items(): + if position == positional_axis: + return name + return None + + unified_named_axis = {} + all_positional_axes = set(named_axis1.values()) | set(named_axis2.values()) + for position in all_positional_axes: + axis_name1 = _get_axis_name(named_axis1, position) + axis_name2 = _get_axis_name(named_axis2, position) + if axis_name1 is not None and axis_name2 is not None: + if axis_name1 != axis_name2: + raise ValueError( + f"The named axes are incompatible. Got: {axis_name1} and {axis_name2} for positional axis {position}" + ) + unified_named_axis[axis_name1] = position + elif axis_name1 is not None: # axis_name2 is None + unified_named_axis[axis_name1] = position + elif axis_name2 is not None: # axis_name1 is None + unified_named_axis[axis_name2] = position + return unified_named_axis + + +@dataclass +class NamedAxesWithDims: + """ + A dataclass that stores the named axis and their corresponding dimensions. + This is a helper class to store the named axis mapping and the number of + dimensions of each named axis, which is useful for broadcasting. + + Attributes: + named_axis (AxisMapping): The named axis mapping. + ndims (Tuple[int]): The number of dimensions of the named axis. + """ + + named_axis: list[AxisMapping] + ndims: list[int] + + def __post_init__(self): + if len(self.named_axis) != len(self.ndims): + raise ValueError( + "The number of dimensions must match the number of named axis mappings." + ) + + def __iter__(self) -> tp.Iterator[tuple[AxisMapping, int]]: + yield from zip(self.named_axis, self.ndims) + + @classmethod + def prepare_contexts( + cls, arrays: tp.Sequence, unwrap_kwargs: dict | None = None + ) -> tuple[dict, dict]: + from awkward._layout import HighLevelContext + from awkward._typetracer import MaybeNone + + # unwrap options + arrays = [x.content if isinstance(x, MaybeNone) else x for x in arrays] + + _unwrap_kwargs = {"allow_unknown": True} + if unwrap_kwargs is not None: + _unwrap_kwargs.update(unwrap_kwargs) + + _named_axes = [] + _ndims = [] + for array in arrays: + with HighLevelContext() as ctx: + layout = ctx.unwrap(array, **_unwrap_kwargs) + _named_axes.append(_get_named_axis(array)) + _ndims.append(layout.minmax_depth[1]) + + depth_context = {NAMED_AXIS_KEY: cls(_named_axes, _ndims)} + lateral_context = {NAMED_AXIS_KEY: cls(_named_axes, _ndims)} + return depth_context, lateral_context + + def __setitem__(self, index: int, named_axis_with_ndim: tuple[AxisMapping, int]): + named_axis, ndim = named_axis_with_ndim + self.named_axis[index] = named_axis + self.ndims[index] = ndim + + def __getitem__(self, index: int) -> tuple[AxisMapping, int]: + return self.named_axis[index], self.ndims[index] + + def __len__(self) -> int: + return len(self.named_axis) + + +# Define a type alias for a slice or int (can be a single axis or a sequence of axes) +AxisSlice: tp.TypeAlias = tp.Union[tuple, slice, int, tp.EllipsisType, None] +NamedAxisSlice: tp.TypeAlias = tp.Dict[AxisName, AxisSlice] + + +def _normalize_named_slice( + named_axis: AxisMapping, + where: AxisSlice | NamedAxisSlice, + total: int, +) -> AxisSlice: + """ + Normalizes a named slice into a positional slice. + + This function takes a named slice (a dictionary mapping axis names to slices) and converts it into a positional slice + (a tuple of slices). The positional slice can then be used to index an array. + + Args: + named_axis (AxisMapping): The current named axis mapping. + where (AxisSlice | NamedAxisSlice): The slice to normalize. Can be a single slice, a tuple of slices, or a dictionary mapping axis names to slices. + total (int): The total number of axes. + + Returns: + AxisSlice: The normalized slice. + + Raises: + ValueError: If an invalid axis name is provided in the slice. + + Examples: + >>> _normalize_named_slice({"x": 0, "y": 1, "z": 2}, {0: 0}, 3) + (0, slice(None), slice(None)) + >>> _normalize_named_slice({"x": 0, "y": 1, "z": 2}, {-1: 0}, 3) + (slice(None), slice(None), 0) + >>> _normalize_named_slice({"x": 0, "y": 1, "z": 2}, {"x": 0}, 3) + (0, slice(None), slice(None)) + >>> _normalize_named_slice({"x": 0, "y": 1, "z": 2}, {"x": 0, "y": 1}, 3) + (0, 1, slice(None)) + >>> _normalize_named_slice({"x": 0, "y": 1, "z": 2}, {"x": 0, "y": 1, "z": ...}, 3) + (0, 1, ...) + >>> _normalize_named_slice({"x": 0, "y": 1, "z": 2}, {"x": 0, "y": 1, "z": slice(0, 1)}, 3) + (0, 1, slice(0, 1)) + >>> _normalize_named_slice({"x": 0, "y": 1, "z": 2}, {"x": (0, 1)}, 3) + ((0, 1), slice(None), slice(None)) + >>> _normalize_named_slice({"x": 0, "y": 1, "z": 2}, {"x": [0, 1]}, 3) + ([0, 1], slice(None), slice(None)) + """ + if isinstance(where, dict): + out_where: list[AxisSlice] = [slice(None)] * total + for ax_name, ax_where in where.items(): + slice_ = ax_where if ax_where is not ... else slice(None) + if is_integer(ax_name): + # it's an integer, pyright doesn't get this + idx = tp.cast(int, ax_name) + out_where[idx] = slice_ + elif _is_valid_named_axis(ax_name): + # it's an integer, pyright doesn't get this + idx = tp.cast(int, _named_axis_to_positional_axis(named_axis, ax_name)) + out_where[idx] = slice_ + else: + raise ValueError(f"Invalid axis name: {ax_name} in slice {where}") + where = tuple(out_where) + return where diff --git a/src/awkward/_nplikes/array_like.py b/src/awkward/_nplikes/array_like.py index d82611ae5d..d75fe6cfcb 100644 --- a/src/awkward/_nplikes/array_like.py +++ b/src/awkward/_nplikes/array_like.py @@ -8,6 +8,7 @@ from awkward._typing import ( TYPE_CHECKING, DType, + EllipsisType, Protocol, Self, SupportsIndex, @@ -15,8 +16,6 @@ ) if TYPE_CHECKING: - from types import EllipsisType - from numpy.typing import DTypeLike diff --git a/src/awkward/_nplikes/typetracer.py b/src/awkward/_nplikes/typetracer.py index 64b99abf01..548071c0ee 100644 --- a/src/awkward/_nplikes/typetracer.py +++ b/src/awkward/_nplikes/typetracer.py @@ -26,6 +26,7 @@ TYPE_CHECKING, Any, DType, + EllipsisType, Final, Literal, Self, @@ -36,8 +37,6 @@ ) if TYPE_CHECKING: - from types import EllipsisType - from numpy.typing import DTypeLike from awkward.contents.content import Content diff --git a/src/awkward/_operators.py b/src/awkward/_operators.py index 2c58330492..f5a2cf90da 100644 --- a/src/awkward/_operators.py +++ b/src/awkward/_operators.py @@ -50,6 +50,7 @@ def _binary_method(ufunc, name): def func(self, other): if _disables_array_ufunc(other): return NotImplemented + return ufunc(self, other) func.__name__ = f"__{name}__" diff --git a/src/awkward/_regularize.py b/src/awkward/_regularize.py index 663d9eb01a..6f78a18409 100644 --- a/src/awkward/_regularize.py +++ b/src/awkward/_regularize.py @@ -7,7 +7,7 @@ from collections.abc import Iterable, Sequence, Sized from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._typing import AxisMaybeNone, SupportsInt +from awkward._typing import Any np = NumpyMetadata.instance() @@ -51,8 +51,19 @@ def is_non_string_like_sequence(obj) -> bool: return not isinstance(obj, (str, bytes)) and isinstance(obj, Sequence) -def regularize_axis(axis: SupportsInt | None) -> AxisMaybeNone: - if axis is None: - return None +def regularize_axis(axis: Any, none_allowed: bool = True) -> int | None: + """ + This function's main purpose is to convert [np,cp,...].array(0) to 0. + """ + if is_integer_like(axis): + regularized_axis = int(axis) else: - return int(axis) + regularized_axis = axis + cond = is_integer(regularized_axis) + msg = f"'axis' must be an integer, not {axis!r}" + if none_allowed: + cond = cond or regularized_axis is None + msg = f"'axis' must be an integer or None, not {axis!r}" + if not cond: + raise TypeError(msg) + return regularized_axis diff --git a/src/awkward/_typing.py b/src/awkward/_typing.py index 0e987b4399..be474a37d9 100644 --- a/src/awkward/_typing.py +++ b/src/awkward/_typing.py @@ -26,6 +26,7 @@ "Literal", "SupportsIndex", "ParamSpec", + "EllipsisType", *typing.__all__, } ) @@ -46,7 +47,10 @@ TypeGuard, Unpack, ) + + EllipsisType = type(...) else: + from types import EllipsisType from typing import ( ClassVar, Final, diff --git a/src/awkward/contents/content.py b/src/awkward/contents/content.py index d0169ee2eb..f324d9dac5 100644 --- a/src/awkward/contents/content.py +++ b/src/awkward/contents/content.py @@ -16,8 +16,14 @@ ) from awkward._behavior import get_array_class, get_record_class from awkward._kernels import KernelError -from awkward._layout import wrap_layout +from awkward._layout import maybe_posaxis, wrap_layout from awkward._meta.meta import Meta +from awkward._namedaxis import ( + NamedAxis, + _add_named_axis, + _keep_named_axis, + _remove_named_axis, +) from awkward._nplikes import to_nplike from awkward._nplikes.dispatch import nplike_of_obj from awkward._nplikes.numpy import Numpy @@ -27,7 +33,12 @@ parameters_are_equal, type_parameters_equal, ) -from awkward._regularize import is_integer_like, is_sized_iterable +from awkward._regularize import ( + is_array_like, + is_integer, + is_integer_like, + is_sized_iterable, +) from awkward._slicing import normalize_slice from awkward._typing import ( TYPE_CHECKING, @@ -38,6 +49,7 @@ Protocol, Self, SupportsIndex, + Type, TypeAlias, TypedDict, ) @@ -509,10 +521,14 @@ def _getitem_next_missing( ) def __getitem__(self, where): - return self._getitem(where) + return self._getitem(where, NamedAxis) - def _getitem(self, where): + def _getitem(self, where, named_axis: Type[NamedAxis] = NamedAxis): if is_integer_like(where): + # propagate named_axis to output + named_axis.mapping = _remove_named_axis( + named_axis.mapping, where, self.purelist_depth + ) return self._getitem_at(ak._slicing.normalize_integer_like(where)) elif isinstance(where, slice) and where.step is None: @@ -523,21 +539,35 @@ def _getitem(self, where): return self._getitem_range(start, stop) elif isinstance(where, slice): - return self._getitem((where,)) + return self._getitem((where,), named_axis) elif isinstance(where, str): return self._getitem_field(where) elif where is np.newaxis: - return self._getitem((where,)) + return self._getitem((where,), named_axis) elif where is Ellipsis: - return self._getitem((where,)) + return self._getitem((where,), named_axis) elif isinstance(where, tuple): if len(where) == 0: return self + # count number of ellipsis + # Need to use a little trick here: + # where.count(Ellipsis) does not work, because it will do a == comparison against Ellipsis, + # and this will fail in the case of typetracers where == is dispatched to np.equal ufunc. + # In this dispatch we encounter an assertion that the type of the Ellipsis is not allowed. + # ...but luckily we can use the fact that Ellipsis is a singleton and use the 'is' operator + n_ellipsis = 0 + for w in where: + if w is ...: + n_ellipsis += 1 + + if n_ellipsis > 1: + raise IndexError("an index can only have a single ellipsis ('...')") + # Backend may change if index contains typetracers backend = backend_of(self, *where, coerce_to_common=True) this = self.to_backend(backend) @@ -547,6 +577,62 @@ def _getitem(self, where): # Prepare items for advanced indexing (e.g. via broadcasting) nextwhere = ak._slicing.prepare_advanced_indexing(items, backend) + # Handle named axis + # first expand the ellipsis to colons in nextwhere, + # copy nextwhere to not pollute the original + _nextwhere = tuple(nextwhere) + if n_ellipsis == 1: + # collect the ellipsis index + # same little trick as above for `nextwhere.index(...)` + (ellipsis_at,) = tuple(i for i, x in enumerate(nextwhere) if x is ...) + # calculate how many slice(None) we need to add + # same little trick as above for `nextwhere.count(None)` + n_newaxis = 0 + for x in nextwhere: + if x is np.newaxis or x is None: + n_newaxis += 1 + n_total = self.minmax_depth[1] + n_slice_none = n_total - (len(_nextwhere) - n_newaxis - 1) + # expand `[...]` to `[:]*n_slice_none` + _nextwhere = ( + _nextwhere[:ellipsis_at] + + (slice(None),) * n_slice_none + + _nextwhere[ellipsis_at + 1 :] + ) + + # now propagate named axis + _named_axis = _keep_named_axis(named_axis.mapping, None) + _adjust_dim = 0 + # this loop does the following: + # - remove a named axis for integer indices, e.g. `a[1, 2]` + # - add a named axis for None (or np.newaxis) indices, e.g. `a[..., None]` + # - keep named axis for any other index, e.g. `a[:]`, `a[0:1]`, or `a[a>0]` + # (these may only remove elements, but not dimensions) + for dim, nw in enumerate(_nextwhere): + dim_adjusted = dim + _adjust_dim + total_adjusted = self.minmax_depth[1] + _adjust_dim + for _, pos in _named_axis.items(): + if maybe_posaxis(self, pos, 0) == dim_adjusted: + break + + if is_integer(nw) or (is_array_like(nw) and nw.ndim == 0): + _named_axis = _remove_named_axis( + named_axis=_named_axis, + axis=dim_adjusted, + total=total_adjusted, + ) + _adjust_dim -= 1 + elif nw is None: + _named_axis = _add_named_axis( + named_axis=_named_axis, + axis=dim_adjusted, + total=total_adjusted, + ) + _adjust_dim += 1 + + # set propagated named axis + named_axis.mapping = _named_axis + next = ak.contents.RegularArray( this, this.length, @@ -562,7 +648,7 @@ def _getitem(self, where): return out._getitem_at(0) elif isinstance(where, ak.highlevel.Array): - return self._getitem(where.layout) + return self._getitem(where.layout, named_axis) # Convert between nplikes of different backends elif ( @@ -570,7 +656,9 @@ def _getitem(self, where): and where.backend is not self._backend ): backend = backend_of(self, where, coerce_to_common=True) - return self.to_backend(backend)._getitem(where.to_backend(backend)) + return self.to_backend(backend)._getitem( + where.to_backend(backend), named_axis + ) elif isinstance(where, ak.contents.NumpyArray): data_as_index = to_nplike( @@ -602,7 +690,7 @@ def _getitem(self, where): allow_lazy = "copied" # True, but also can be modified in-place else: wheres = self._backend.index_nplike.nonzero(data_as_index) - return self._getitem(wheres) + return self._getitem(wheres, named_axis) else: raise TypeError( "array slice must be an array of integers or booleans, not\n\n {}".format( @@ -621,9 +709,9 @@ def _getitem(self, where): elif isinstance(where, ak.contents.RegularArray): maybe_numpy = where.maybe_to_NumpyArray() if maybe_numpy is None: - return self._getitem((where,)) + return self._getitem((where,), named_axis) else: - return self._getitem(maybe_numpy) + return self._getitem(maybe_numpy, named_axis) # Awkward Array of strings elif ( @@ -637,7 +725,7 @@ def _getitem(self, where): return where.to_NumpyArray(np.int64) elif isinstance(where, Content): - return self._getitem((where,)) + return self._getitem((where,), named_axis) elif is_sized_iterable(where): # Do we have an array @@ -654,7 +742,7 @@ def _getitem(self, where): primitive_policy="error", string_policy="as-characters", ) - return self._getitem(layout) + return self._getitem(layout, named_axis) elif len(where) == 0: return self._carry( @@ -682,7 +770,7 @@ def _getitem(self, where): ), self._backend, ) - return self._getitem(layout) + return self._getitem(layout, named_axis) else: raise TypeError( diff --git a/src/awkward/contents/numpyarray.py b/src/awkward/contents/numpyarray.py index 315d9383b7..5c90ca0141 100644 --- a/src/awkward/contents/numpyarray.py +++ b/src/awkward/contents/numpyarray.py @@ -175,7 +175,11 @@ def shape(self) -> tuple[ShapeItem, ...]: @property def inner_shape(self) -> tuple[ShapeItem, ...]: - return self._data.shape[1:] + if hasattr(self._data, "inner_shape"): + inner_shape = self._data.inner_shape + else: + inner_shape = self._data.shape[1:] + return inner_shape @property def strides(self) -> tuple[ShapeItem, ...]: @@ -189,14 +193,9 @@ def _raw(self, nplike=None): return to_nplike(self.data, nplike, from_nplike=self._backend.nplike) def _form_with_key(self, getkey: Callable[[Content], str | None]) -> NumpyForm: - if hasattr(self._data, "inner_shape"): - inner_shape = self._data.inner_shape - else: - inner_shape = self._data.shape[1:] - return self.form_cls( ak.types.numpytype.dtype_to_primitive(self._data.dtype), - inner_shape, + self.inner_shape, parameters=self._parameters, form_key=getkey(self), ) diff --git a/src/awkward/highlevel.py b/src/awkward/highlevel.py index f315945511..6d1d6649aa 100644 --- a/src/awkward/highlevel.py +++ b/src/awkward/highlevel.py @@ -23,6 +23,16 @@ from awkward._backends.numpy import NumpyBackend from awkward._behavior import behavior_of, get_array_class, get_record_class from awkward._layout import wrap_layout +from awkward._namedaxis import ( + NAMED_AXIS_KEY, + AxisMapping, + NamedAxis, + _get_named_axis, + _make_positional_axis_tuple, + _normalize_named_slice, + _prepare_named_axis_for_attrs, + _prettify_named_axes, +) from awkward._nplikes.numpy import Numpy from awkward._nplikes.numpy_like import NumpyMetadata from awkward._operators import NDArrayOperatorsMixin @@ -32,7 +42,7 @@ unpickle_record_schema_1, ) from awkward._regularize import is_non_string_like_iterable -from awkward._typing import Any, TypeVar +from awkward._typing import Any, MutableMapping, TypeVar from awkward._util import STDOUT from awkward.prettyprint import Formatter from awkward.prettyprint import valuestr as prettyprint_valuestr @@ -278,6 +288,7 @@ def __init__( check_valid=False, backend=None, attrs=None, + named_axis=None, ): self._cpp_type = None if isinstance(data, ak.contents.Content): @@ -326,9 +337,20 @@ def __init__( if behavior is not None and not isinstance(behavior, Mapping): raise TypeError("behavior must be None or a mapping") - if attrs is not None and not isinstance(attrs, Mapping): + if attrs is not None and not isinstance(attrs, MutableMapping): raise TypeError("attrs must be None or a mapping") + if named_axis: + _named_axis = _prepare_named_axis_for_attrs( + named_axis=named_axis, + ndim=layout.minmax_depth[1], + ) + # now we're good, set the named axis + if attrs is None: + attrs = {} + # if NAMED_AXIS_KEY is already in attrs, it will be overwritten + attrs[NAMED_AXIS_KEY] = _named_axis + self._layout = layout self._behavior = behavior self._attrs = attrs @@ -357,7 +379,7 @@ def _update_class(self): self.__class__ = get_array_class(self._layout, self._behavior) @property - def attrs(self) -> Mapping[str, Any]: + def attrs(self) -> Mapping: """ The mutable mapping containing top-level metadata, which is serialised with the array during pickling. @@ -455,6 +477,15 @@ def behavior(self, behavior): else: raise TypeError("behavior must be None or a dict") + @property + def positional_axis(self) -> tuple[int, ...]: + (_, ndim) = self._layout.minmax_depth + return _make_positional_axis_tuple(ndim) + + @property + def named_axis(self) -> AxisMapping: + return _get_named_axis(self) + class Mask: def __init__(self, array): self._array = array @@ -1062,12 +1093,30 @@ def __getitem__(self, where): have the same dimension as the array being indexed. """ with ak._errors.SlicingErrorContext(self, where): - return wrap_layout( - prepare_layout(self._layout[where]), - self._behavior, - allow_other=True, - attrs=self._attrs, - ) + # Handle named axis + (_, ndim) = self._layout.minmax_depth + named_axis = _get_named_axis(self) + where = _normalize_named_slice(named_axis, where, ndim) + + NamedAxis.mapping = named_axis + + indexed_layout = prepare_layout(self._layout._getitem(where, NamedAxis)) + + if NamedAxis.mapping: + return ak.operations.ak_with_named_axis._impl( + indexed_layout, + named_axis=NamedAxis.mapping, + highlevel=True, + behavior=self._behavior, + attrs=self._attrs, + ) + else: + return wrap_layout( + indexed_layout, + self._behavior, + allow_other=True, + attrs=self._attrs, + ) def __bytes__(self) -> bytes: if isinstance(self._layout, ak.contents.NumpyArray) and self._layout.parameter( @@ -1309,6 +1358,15 @@ def _repr(self, limit_cols): else: valuestr = "-typetracer" + # prepare named_axis str for repr + axisstr = "" + if self.named_axis: + # we reserve at maximum 20 characters for the named axis string + axisstr = _prettify_named_axes(self.named_axis, delimiter=",", maxlen=20) + axisstr = f" {axisstr}" + # subtract the reserved space from the limit_cols + limit_cols -= len(axisstr) + if len(typestr) + len(pytype) + len(" type=''") + 3 < limit_cols // 2: strwidth = limit_cols - (len(typestr) + len(pytype) + len(" type=''") + 3) else: @@ -1327,13 +1385,14 @@ def _repr(self, limit_cols): else: typestr = "'" + typestr + "'" - return f"<{pytype}{valuestr} type={typestr}>" + return f"<{pytype}{valuestr}{axisstr} type={typestr}>" def show( self, limit_rows=20, limit_cols=80, type=False, + named_axis=False, stream=STDOUT, *, formatter=None, @@ -1365,25 +1424,41 @@ def show( valuestr = prettyprint_valuestr( self, limit_rows, limit_cols, formatter=formatter_impl ) + + out_io = io.StringIO() if type: - tmp = io.StringIO() - self.type.show(stream=tmp) - out = "type: " + tmp.getvalue() + valuestr - else: - out = valuestr + out_io.write("type: ") + self.type.show(stream=out_io) + if named_axis and self.named_axis: + out_io.write("axes: ") + out_io.write( + _prettify_named_axes(self.named_axis, delimiter=", ", maxlen=None) + ) + out_io.write("\n") + out_io.write(valuestr) if stream is None: - return out + return out_io else: if stream is STDOUT: stream = STDOUT.stream - stream.write(out + "\n") + stream.write(out_io.getvalue() + "\n") def _repr_mimebundle_(self, include=None, exclude=None): + # order: 1. array, 2. named_axis, 3. type value_buff = io.StringIO() - self.show(type=False, stream=value_buff) + self.show(type=False, named_axis=False, stream=value_buff) header_lines = value_buff.getvalue().splitlines() + named_axis_line = "" + if self.named_axis: + named_axis_buff = io.StringIO() + named_axis_buff.write("axes: ") + named_axis_buff.write( + _prettify_named_axes(self.named_axis, delimiter=", ", maxlen=None) + ) + named_axis_line = named_axis_buff.getvalue() + type_buff = io.StringIO() self.type.show(stream=type_buff) footer_lines = type_buff.getvalue().splitlines() @@ -1393,8 +1468,16 @@ def _repr_mimebundle_(self, include=None, exclude=None): if header_lines[-1] == "": del header_lines[-1] - n_cols = max(len(line) for line in itertools.chain(header_lines, footer_lines)) - body = "\n".join([*header_lines, "-" * n_cols, *footer_lines]) + n_cols = max( + len(line) + for line in itertools.chain(header_lines, [named_axis_line], footer_lines) + ) + body_lines = header_lines + body_lines.append("-" * n_cols) + if named_axis_line: + body_lines.append(named_axis_line) + body_lines.extend(footer_lines) + body = "\n".join(body_lines) return { "text/html": f"
{html.escape(body)}
", @@ -1719,6 +1802,7 @@ def __init__( check_valid=False, backend=None, attrs=None, + named_axis=None, ): if isinstance(data, ak.record.Record): layout = data @@ -1762,6 +1846,20 @@ def __init__( if behavior is not None and not isinstance(behavior, Mapping): raise TypeError("behavior must be None or mapping") + if attrs is not None and not isinstance(attrs, MutableMapping): + raise TypeError("attrs must be None or a mapping") + + if named_axis: + _named_axis = _prepare_named_axis_for_attrs( + named_axis=named_axis, + ndim=layout.minmax_depth[1], + ) + # now we're good, set the named axis + if attrs is None: + attrs = {} + # if NAMED_AXIS_KEY is already in attrs, it will be overwritten + attrs[NAMED_AXIS_KEY] = _named_axis + self._layout = layout self._behavior = behavior self._attrs = attrs @@ -1877,6 +1975,15 @@ def behavior(self, behavior): else: raise TypeError("behavior must be None or a dict") + @property + def positional_axis(self) -> tuple[int, ...]: + (_, ndim) = self._layout.minmax_depth + return _make_positional_axis_tuple(ndim) + + @property + def named_axis(self) -> AxisMapping: + return _get_named_axis(self) + def tolist(self): """ Converts this Record into Python objects; same as #ak.to_list @@ -2170,6 +2277,15 @@ def _repr(self, limit_cols): else: valuestr = "-typetracer" + # prepare named_axis str for repr + axisstr = "" + if self.named_axis: + # we reserve at maximum 20 characters for the named axis string + axisstr = _prettify_named_axes(self.named_axis, delimiter=",", maxlen=20) + axisstr = f" {axisstr}" + # subtract the reserved space from the limit_cols + limit_cols -= len(axisstr) + if len(typestr) + len(pytype) + len(" type=''") + 3 < limit_cols // 2: strwidth = limit_cols - (len(typestr) + len(pytype) + len(" type=''") + 3) else: @@ -2188,13 +2304,14 @@ def _repr(self, limit_cols): else: typestr = "'" + typestr + "'" - return f"<{pytype}{valuestr} type={typestr}>" + return f"<{pytype}{valuestr}{axisstr} type={typestr}>" def show( self, limit_rows=20, limit_cols=80, type=False, + named_axis=False, stream=STDOUT, *, formatter=None, @@ -2224,25 +2341,41 @@ def show( valuestr = prettyprint_valuestr( self, limit_rows, limit_cols, formatter=formatter_impl ) + + out_io = io.StringIO() if type: - tmp = io.StringIO() - self.type.show(stream=tmp) - out = "type: " + tmp.getvalue() + valuestr - else: - out = valuestr + out_io.write("type: ") + self.type.show(stream=out_io) + if named_axis and self.named_axis: + out_io.write("axes: ") + out_io.write( + _prettify_named_axes(self.named_axis, delimiter=", ", maxlen=None) + ) + out_io.write("\n") + out_io.write(valuestr) if stream is None: - return out + return out_io.getvalue() else: if stream is STDOUT: stream = STDOUT.stream - stream.write(out + "\n") + stream.write(out_io.getvalue() + "\n") def _repr_mimebundle_(self, include=None, exclude=None): + # order: 1. array, 2. named_axis, 3. type value_buff = io.StringIO() - self.show(type=False, stream=value_buff) + self.show(type=False, named_axis=False, stream=value_buff) header_lines = value_buff.getvalue().splitlines() + named_axis_line = "" + if self.named_axis: + named_axis_buff = io.StringIO() + named_axis_buff.write("axes: ") + named_axis_buff.write( + _prettify_named_axes(self.named_axis, delimiter=", ", maxlen=None) + ) + named_axis_line = named_axis_buff.getvalue() + type_buff = io.StringIO() self.type.show(stream=type_buff) footer_lines = type_buff.getvalue().splitlines() @@ -2252,8 +2385,16 @@ def _repr_mimebundle_(self, include=None, exclude=None): if header_lines[-1] == "": del header_lines[-1] - n_cols = max(len(line) for line in itertools.chain(header_lines, footer_lines)) - body = "\n".join([*header_lines, "-" * n_cols, *footer_lines]) + n_cols = max( + len(line) + for line in itertools.chain(header_lines, [named_axis_line], footer_lines) + ) + body_lines = header_lines + body_lines.append("-" * n_cols) + if named_axis_line: + body_lines.append(named_axis_line) + body_lines.extend(footer_lines) + body = "\n".join(body_lines) return { "text/html": f"
{html.escape(body)}
", diff --git a/src/awkward/operations/__init__.py b/src/awkward/operations/__init__.py index d76d8e2688..94dbd9ffac 100644 --- a/src/awkward/operations/__init__.py +++ b/src/awkward/operations/__init__.py @@ -114,8 +114,10 @@ from awkward.operations.ak_where import * from awkward.operations.ak_with_field import * from awkward.operations.ak_with_name import * +from awkward.operations.ak_with_named_axis import * from awkward.operations.ak_with_parameter import * from awkward.operations.ak_without_field import * +from awkward.operations.ak_without_named_axis import * from awkward.operations.ak_without_parameters import * from awkward.operations.ak_zeros_like import * from awkward.operations.ak_zip import * diff --git a/src/awkward/operations/ak_all.py b/src/awkward/operations/ak_all.py index 859bfd98cb..98a22520ba 100644 --- a/src/awkward/operations/ak_all.py +++ b/src/awkward/operations/ak_all.py @@ -6,6 +6,12 @@ from awkward._connect.numpy import UNSUPPORTED from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext +from awkward._namedaxis import ( + _get_named_axis, + _keep_named_axis, + _named_axis_to_positional_axis, + _remove_named_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -67,9 +73,26 @@ def all( def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + # Step 2: propagate named axis from input to output, + # keepdims=True: use strategy "keep all" (see: awkward._namedaxis) + # keepdims=False: use strategy "remove one" (see: awkward._namedaxis) + out_named_axis = _keep_named_axis(named_axis, None) + if not keepdims: + out_named_axis = _remove_named_axis( + named_axis=out_named_axis, + axis=axis, + total=layout.minmax_depth[1], + ) + + axis = regularize_axis(axis, none_allowed=True) + reducer = ak._reducers.All() out = ak._do.reduce( @@ -80,7 +103,21 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): keepdims=keepdims, behavior=ctx.behavior, ) - return ctx.wrap(out, highlevel=highlevel, allow_other=True) + + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + allow_other=True, + ) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) @ak._connect.numpy.implements("all") diff --git a/src/awkward/operations/ak_almost_equal.py b/src/awkward/operations/ak_almost_equal.py index 66f67e4d8a..949f955a45 100644 --- a/src/awkward/operations/ak_almost_equal.py +++ b/src/awkward/operations/ak_almost_equal.py @@ -7,6 +7,7 @@ from awkward._behavior import behavior_of, get_array_class, get_record_class from awkward._dispatch import high_level_function from awkward._layout import ensure_same_backend +from awkward._namedaxis import _get_named_axis from awkward._nplikes.numpy_like import NumpyMetadata from awkward._parameters import parameters_are_equal from awkward.operations.ak_to_layout import to_layout @@ -27,6 +28,7 @@ def almost_equal( dtype_exact: bool = True, check_parameters: bool = True, check_regular: bool = True, + check_named_axis: bool = True, ): """ Args: @@ -39,6 +41,7 @@ def almost_equal( check_parameters: whether to compare parameters. check_regular: whether to consider ragged and regular dimensions as unequal. + check_named_axis: bool (default=True) whether to consider named axes as unequal. Return True if the two array-like arguments are considered equal for the given options. Otherwise, return False. @@ -61,6 +64,7 @@ def almost_equal( dtype_exact=dtype_exact, check_parameters=check_parameters, check_regular=check_regular, + check_named_axis=check_named_axis, exact_eq=False, same_content_types=False, equal_nan=False, @@ -75,6 +79,7 @@ def _impl( dtype_exact: bool, check_parameters: bool, check_regular: bool, + check_named_axis: bool, exact_eq: bool, same_content_types: bool, equal_nan: bool, @@ -91,6 +96,10 @@ def _impl( right_layout = layouts[1].to_packed() backend = backend_of(left_layout) + if check_named_axis and _get_named_axis(left) and _get_named_axis(right): + if left.named_axis != right.named_axis: + return False + if not backend.nplike.known_data: raise NotImplementedError( "Awkward Arrays with typetracer backends cannot yet be compared with `ak.almost_equal`." diff --git a/src/awkward/operations/ak_any.py b/src/awkward/operations/ak_any.py index 79c9cc6b83..e99065d97c 100644 --- a/src/awkward/operations/ak_any.py +++ b/src/awkward/operations/ak_any.py @@ -6,6 +6,12 @@ from awkward._connect.numpy import UNSUPPORTED from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext +from awkward._namedaxis import ( + _get_named_axis, + _keep_named_axis, + _named_axis_to_positional_axis, + _remove_named_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -67,9 +73,26 @@ def any( def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + # Step 2: propagate named axis from input to output, + # keepdims=True: use strategy "keep all" (see: awkward._namedaxis) + # keepdims=False: use strategy "remove one" (see: awkward._namedaxis) + out_named_axis = _keep_named_axis(named_axis, None) + if not keepdims: + out_named_axis = _remove_named_axis( + named_axis=out_named_axis, + axis=axis, + total=layout.minmax_depth[1], + ) + + axis = regularize_axis(axis, none_allowed=True) + reducer = ak._reducers.Any() out = ak._do.reduce( @@ -80,7 +103,21 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): keepdims=keepdims, behavior=ctx.behavior, ) - return ctx.wrap(out, highlevel=highlevel, allow_other=True) + + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + allow_other=True, + ) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) @ak._connect.numpy.implements("any") diff --git a/src/awkward/operations/ak_argcartesian.py b/src/awkward/operations/ak_argcartesian.py index 12deed5749..f012290cbe 100644 --- a/src/awkward/operations/ak_argcartesian.py +++ b/src/awkward/operations/ak_argcartesian.py @@ -7,7 +7,6 @@ import awkward as ak from awkward._dispatch import high_level_function from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import regularize_axis __all__ = ("argcartesian",) @@ -107,8 +106,6 @@ def argcartesian( def _impl(arrays, axis, nested, parameters, with_name, highlevel, behavior, attrs): - axis = regularize_axis(axis) - if isinstance(arrays, Mapping): index_arrays = {n: ak.local_index(x, axis) for n, x in arrays.items()} else: diff --git a/src/awkward/operations/ak_argcombinations.py b/src/awkward/operations/ak_argcombinations.py index 98a2643855..337f77cec1 100644 --- a/src/awkward/operations/ak_argcombinations.py +++ b/src/awkward/operations/ak_argcombinations.py @@ -5,6 +5,7 @@ import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext +from awkward._namedaxis import _get_named_axis, _named_axis_to_positional_axis from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -93,7 +94,6 @@ def _impl( behavior, attrs, ): - axis = regularize_axis(axis) if parameters is None: parameters = {} else: @@ -101,6 +101,13 @@ def _impl( if with_name is not None: parameters["__record__"] = with_name + # Handle named axis + named_axis = _get_named_axis(array) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + + axis = regularize_axis(axis, none_allowed=False) + if axis < 0: raise ValueError("the 'axis' for argcombinations must be non-negative") else: diff --git a/src/awkward/operations/ak_argmax.py b/src/awkward/operations/ak_argmax.py index a4dbe947bd..ef9b37e57c 100644 --- a/src/awkward/operations/ak_argmax.py +++ b/src/awkward/operations/ak_argmax.py @@ -6,6 +6,12 @@ from awkward._connect.numpy import UNSUPPORTED from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext +from awkward._namedaxis import ( + _get_named_axis, + _keep_named_axis, + _named_axis_to_positional_axis, + _remove_named_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -132,9 +138,26 @@ def nanargmax( def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + # Step 2: propagate named axis from input to output, + # keepdims=True: use strategy "keep all" (see: awkward._namedaxis) + # keepdims=False: use strategy "remove one" (see: awkward._namedaxis) + out_named_axis = _keep_named_axis(named_axis, None) + if not keepdims: + out_named_axis = _remove_named_axis( + named_axis=out_named_axis, + axis=axis, + total=layout.minmax_depth[1], + ) + + axis = regularize_axis(axis, none_allowed=True) + reducer = ak._reducers.ArgMax() out = ak._do.reduce( @@ -145,7 +168,21 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): keepdims=keepdims, behavior=ctx.behavior, ) - return ctx.wrap(out, highlevel=highlevel, allow_other=True) + + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + allow_other=True, + ) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) @ak._connect.numpy.implements("argmax") diff --git a/src/awkward/operations/ak_argmin.py b/src/awkward/operations/ak_argmin.py index 7f21fb3aa8..6982a4d407 100644 --- a/src/awkward/operations/ak_argmin.py +++ b/src/awkward/operations/ak_argmin.py @@ -6,6 +6,12 @@ from awkward._connect.numpy import UNSUPPORTED from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext +from awkward._namedaxis import ( + _get_named_axis, + _keep_named_axis, + _named_axis_to_positional_axis, + _remove_named_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -129,10 +135,26 @@ def nanargmin( def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + # Step 2: propagate named axis from input to output, + # keepdims=True: use strategy "keep all" (see: awkward._namedaxis) + # keepdims=False: use strategy "remove one" (see: awkward._namedaxis) + out_named_axis = _keep_named_axis(named_axis, None) + if not keepdims: + out_named_axis = _remove_named_axis( + named_axis=out_named_axis, + axis=axis, + total=layout.minmax_depth[1], + ) + + axis = regularize_axis(axis, none_allowed=True) + reducer = ak._reducers.ArgMin() out = ak._do.reduce( @@ -143,7 +165,21 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): keepdims=keepdims, behavior=ctx.behavior, ) - return ctx.wrap(out, highlevel=highlevel, allow_other=True) + + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + allow_other=True, + ) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) @ak._connect.numpy.implements("argmin") diff --git a/src/awkward/operations/ak_argsort.py b/src/awkward/operations/ak_argsort.py index bade378b20..7c92d6a645 100644 --- a/src/awkward/operations/ak_argsort.py +++ b/src/awkward/operations/ak_argsort.py @@ -6,6 +6,10 @@ from awkward._connect.numpy import UNSUPPORTED from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext +from awkward._namedaxis import ( + _get_named_axis, + _named_axis_to_positional_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -70,11 +74,22 @@ def argsort( def _impl(array, axis, ascending, stable, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + + axis = regularize_axis(axis, none_allowed=False) + out = ak._do.argsort(layout, axis, ascending, stable) - return ctx.wrap(out, highlevel=highlevel) + + return ctx.wrap( + out, + highlevel=highlevel, + ) @ak._connect.numpy.implements("argsort") diff --git a/src/awkward/operations/ak_array_equal.py b/src/awkward/operations/ak_array_equal.py index 398db6b2a6..1dabd60f31 100644 --- a/src/awkward/operations/ak_array_equal.py +++ b/src/awkward/operations/ak_array_equal.py @@ -18,6 +18,7 @@ def array_equal( same_content_types: bool = True, check_parameters: bool = True, check_regular: bool = True, + check_named_axis: bool = True, ): """ True if two arrays have the same shape and elements, False otherwise. @@ -34,6 +35,7 @@ def array_equal( check_parameters: bool (default=True) whether to compare parameters. check_regular: bool (default=True) whether to consider ragged and regular dimensions as unequal. + check_named_axis: bool (default=True) whether to consider named axes as unequal. TypeTracer arrays are not supported, as there is very little information to be compared. @@ -49,6 +51,7 @@ def array_equal( dtype_exact=dtype_exact, check_parameters=check_parameters, check_regular=check_regular, + check_named_axis=check_named_axis, exact_eq=True, same_content_types=same_content_types and check_regular, equal_nan=equal_nan, diff --git a/src/awkward/operations/ak_broadcast_arrays.py b/src/awkward/operations/ak_broadcast_arrays.py index 877c69f9c0..feef9b5138 100644 --- a/src/awkward/operations/ak_broadcast_arrays.py +++ b/src/awkward/operations/ak_broadcast_arrays.py @@ -2,6 +2,8 @@ from __future__ import annotations +from functools import reduce + import awkward as ak from awkward._attrs import attrs_of_obj from awkward._backends.dispatch import backend_of @@ -10,6 +12,11 @@ from awkward._connect.numpy import UNSUPPORTED from awkward._dispatch import high_level_function from awkward._layout import wrap_layout +from awkward._namedaxis import ( + NAMED_AXIS_KEY, + NamedAxesWithDims, + _unify_named_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata __all__ = ("broadcast_arrays",) @@ -243,24 +250,43 @@ def action(inputs, depth, **kwargs): else: return None + depth_context, lateral_context = NamedAxesWithDims.prepare_contexts(arrays) out = ak._broadcasting.broadcast_and_apply( inputs, action, + depth_context=depth_context, + lateral_context=lateral_context, left_broadcast=left_broadcast, right_broadcast=right_broadcast, broadcast_parameters_rule=broadcast_parameters_rule, numpy_to_regular=True, ) assert isinstance(out, tuple) - return [ - wrap_layout( + + # unify named axes + out_named_axis = reduce( + _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis + ) + wrapped_out = [] + for layout_out, array_in in zip(out, arrays): + _behavior = behavior_of_obj(array_in, behavior=behavior) + _attrs = attrs_of_obj(array_in, attrs=attrs) + wrapped = wrap_layout( layout_out, - behavior=behavior_of_obj(array_in, behavior=behavior), + behavior=_behavior, highlevel=highlevel, - attrs=attrs_of_obj(array_in, attrs=attrs), + attrs=_attrs, ) - for layout_out, array_in in zip(out, arrays) - ] + wrapped_out.append( + ak.operations.ak_with_named_axis._impl( + wrapped, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=_behavior, + attrs=_attrs, + ) + ) + return wrapped_out @ak._connect.numpy.implements("broadcast_arrays") diff --git a/src/awkward/operations/ak_cartesian.py b/src/awkward/operations/ak_cartesian.py index 91767d27d8..0f46f449c9 100644 --- a/src/awkward/operations/ak_cartesian.py +++ b/src/awkward/operations/ak_cartesian.py @@ -3,11 +3,20 @@ from __future__ import annotations from collections.abc import Mapping +from functools import reduce import awkward as ak from awkward._backends.numpy import NumpyBackend from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, ensure_same_backend, maybe_posaxis +from awkward._namedaxis import ( + NAMED_AXIS_KEY, + NamedAxesWithDims, + _add_named_axis, + _get_named_axis, + _named_axis_to_positional_axis, + _unify_named_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis from awkward.errors import AxisError @@ -214,7 +223,6 @@ def cartesian( def _impl(arrays, axis, nested, parameters, with_name, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: if isinstance(arrays, Mapping): layouts = ensure_same_backend( @@ -226,6 +234,11 @@ def _impl(arrays, axis, nested, parameters, with_name, highlevel, behavior, attr fields = list(arrays.keys()) array_layouts = dict(zip(fields, layouts)) + # propagate named axis from input to output, + # use strategy "unify" (see: awkward._namedaxis) + out_named_axis = reduce( + _unify_named_axis, map(_get_named_axis, arrays.values()) + ) else: layouts = array_layouts = ensure_same_backend( *( @@ -234,6 +247,15 @@ def _impl(arrays, axis, nested, parameters, with_name, highlevel, behavior, attr ) ) fields = None + # propagate named axis from input to output, + # use strategy "unify" (see: awkward._namedaxis) + out_named_axis = reduce(_unify_named_axis, map(_get_named_axis, arrays)) + + # Handle named axis + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(out_named_axis, axis) + axis = regularize_axis(axis, none_allowed=False) + max_ndim = max(layout.minmax_depth[1] for layout in layouts) if with_name is not None: if parameters is None: @@ -262,6 +284,7 @@ def _impl(arrays, axis, nested, parameters, with_name, highlevel, behavior, attr if nested is None or nested is False: nested = [] elif nested is True: + out_named_axis = _add_named_axis(out_named_axis, 0, max_ndim) if fields is not None: nested = list(fields)[:-1] else: @@ -287,6 +310,8 @@ def _impl(arrays, axis, nested, parameters, with_name, highlevel, behavior, attr "the 'nested' parameter of cartesian must be integers in " "[0, len(arrays) - 1) for an iterable of arrays" ) + for n in nested: + out_named_axis = _add_named_axis(out_named_axis, n, max_ndim) backend = next((layout.backend for layout in layouts), cpu) if posaxis == 0: @@ -398,16 +423,48 @@ def apply_build_record(inputs, depth, **kwargs): else: return None + depth_context, lateral_context = NamedAxesWithDims.prepare_contexts( + list(arrays.values()) if isinstance(arrays, Mapping) else list(arrays) + ) out = ak._broadcasting.broadcast_and_apply( - new_layouts, apply_build_record, right_broadcast=False + new_layouts, + apply_build_record, + depth_context=depth_context, + lateral_context=lateral_context, + right_broadcast=False, ) assert isinstance(out, tuple) and len(out) == 1 result = out[0] + # Unify named axes propagated through the broadcast + out_named_axis = reduce( + _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis + ) + wrapped_out = ctx.wrap(result, highlevel=highlevel) + # propagate named axis to output + result = ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) + # Remove surplus dimensions, iterating from smallest to greatest for axis_to_flatten in axes_to_flatten: result = ak.operations.flatten( - result, axis=axis_to_flatten, highlevel=False, behavior=behavior + result, axis=axis_to_flatten, highlevel=highlevel, behavior=behavior ) - return ctx.wrap(result, highlevel=highlevel) + return result + + wrapped_out = ctx.wrap(result, highlevel=highlevel) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) diff --git a/src/awkward/operations/ak_categories.py b/src/awkward/operations/ak_categories.py index cd7f6ccf4c..e723d098da 100644 --- a/src/awkward/operations/ak_categories.py +++ b/src/awkward/operations/ak_categories.py @@ -49,6 +49,16 @@ def action(layout, **kwargs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + ak._do.recursively_apply(layout, action) - return ctx.wrap(output, highlevel=highlevel) + wrapped_out = ctx.wrap(output, highlevel=highlevel) + + # propagate named axis from input to output, + # use strategy "drop all" (see: awkward._namedaxis) + return ak.operations.ak_without_named_axis._impl( + wrapped_out, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) diff --git a/src/awkward/operations/ak_combinations.py b/src/awkward/operations/ak_combinations.py index d22708cb4a..284023f2cd 100644 --- a/src/awkward/operations/ak_combinations.py +++ b/src/awkward/operations/ak_combinations.py @@ -5,6 +5,10 @@ import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext +from awkward._namedaxis import ( + _get_named_axis, + _named_axis_to_positional_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -214,7 +218,15 @@ def _impl( behavior, attrs, ): - axis = regularize_axis(axis) + with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: + layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + + axis = regularize_axis(axis, none_allowed=False) if with_name is None: pass @@ -223,8 +235,6 @@ def _impl( else: parameters = {**parameters, "__record__": with_name} - with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: - layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") out = ak._do.combinations( layout, n, diff --git a/src/awkward/operations/ak_concatenate.py b/src/awkward/operations/ak_concatenate.py index fb8fcf94ae..3e086f7e8c 100644 --- a/src/awkward/operations/ak_concatenate.py +++ b/src/awkward/operations/ak_concatenate.py @@ -2,6 +2,7 @@ from __future__ import annotations +from functools import reduce from itertools import permutations import awkward as ak @@ -9,6 +10,13 @@ from awkward._dispatch import high_level_function from awkward._do import mergeable from awkward._layout import HighLevelContext, ensure_same_backend, maybe_posaxis +from awkward._namedaxis import ( + NAMED_AXIS_KEY, + NamedAxesWithDims, + _get_named_axis, + _named_axis_to_positional_axis, + _unify_named_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._nplikes.shape import unknown_length from awkward._parameters import type_parameters_equal @@ -92,7 +100,6 @@ def _merge_as_union( def _impl(arrays, axis, mergebool, highlevel, behavior, attrs): - axis = regularize_axis(axis) # Simple single-array, axis=0 fast-path if ( # Is an array with a known backend @@ -121,6 +128,15 @@ def _impl(arrays, axis, mergebool, highlevel, behavior, attrs): ) ) + # Handle named axis + merged_named_axis = reduce(_unify_named_axis, map(_get_named_axis, arrays)) + # Step 1: normalize named axis to positional axis + axis = _named_axis_to_positional_axis(merged_named_axis, axis) + axis = regularize_axis(axis, none_allowed=False) + # Step 2: propagate named axis from input to output, + # use strategy "unify" (see: awkward._namedaxis) + out_named_axis = merged_named_axis + contents = [x for x in content_or_others if isinstance(x, ak.contents.Content)] if len(contents) == 0: raise ValueError("need at least one array to concatenate") @@ -342,11 +358,35 @@ def action(inputs, depth, backend, **kwargs): else: return None + depth_context, lateral_context = NamedAxesWithDims.prepare_contexts( + list(arrays) + ) out = ak._broadcasting.broadcast_and_apply( - content_or_others, action, allow_records=True, right_broadcast=False + content_or_others, + action, + depth_context=depth_context, + lateral_context=lateral_context, + allow_records=True, + right_broadcast=False, )[0] + # Unify named axes + out_named_axis = reduce( + _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis + ) - return ctx.wrap(out, highlevel=highlevel) + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + ) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) def _form_has_type(form, type_): diff --git a/src/awkward/operations/ak_corr.py b/src/awkward/operations/ak_corr.py index 74d148831d..e646a43b0f 100644 --- a/src/awkward/operations/ak_corr.py +++ b/src/awkward/operations/ak_corr.py @@ -3,12 +3,14 @@ from __future__ import annotations import awkward as ak +from awkward._attrs import attrs_of_obj from awkward._dispatch import high_level_function from awkward._layout import ( HighLevelContext, ensure_same_backend, maybe_highlevel_to_lowlevel, ) +from awkward._namedaxis import _get_named_axis, _is_valid_named_axis from awkward._nplikes import ufuncs from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -86,7 +88,10 @@ def corr( def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) + if _is_valid_named_axis(axis): + raise NotImplementedError("named axis not yet supported for ak.corr") + + axis = regularize_axis(axis, none_allowed=True) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: x_layout, y_layout, weight_layout = ensure_same_backend( @@ -110,7 +115,7 @@ def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attr x, weight, axis, - False, + True, mask_identity, highlevel=True, behavior=ctx.behavior, @@ -120,7 +125,7 @@ def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attr y, weight, axis, - False, + True, mask_identity, highlevel=True, behavior=ctx.behavior, @@ -184,8 +189,19 @@ def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attr behavior=ctx.behavior, attrs=ctx.attrs, ) - return ctx.wrap( - maybe_highlevel_to_lowlevel(sumwxy / ufuncs.sqrt(sumwxx * sumwyy)), + + out = sumwxy / ufuncs.sqrt(sumwxx * sumwyy) + + wrapped = ctx.wrap( + maybe_highlevel_to_lowlevel(out), highlevel=highlevel, allow_other=True, ) + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped, + named_axis=_get_named_axis(attrs_of_obj(out)), + highlevel=highlevel, + behavior=None, + attrs=None, + ) diff --git a/src/awkward/operations/ak_count.py b/src/awkward/operations/ak_count.py index 85f43a27ee..f9b8c48481 100644 --- a/src/awkward/operations/ak_count.py +++ b/src/awkward/operations/ak_count.py @@ -5,6 +5,12 @@ import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext +from awkward._namedaxis import ( + _get_named_axis, + _keep_named_axis, + _named_axis_to_positional_axis, + _remove_named_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -109,9 +115,26 @@ def count( def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + # Step 2: propagate named axis from input to output, + # keepdims=True: use strategy "keep all" (see: awkward._namedaxis) + # keepdims=False: use strategy "remove one" (see: awkward._namedaxis) + out_named_axis = _keep_named_axis(named_axis, None) + if not keepdims: + out_named_axis = _remove_named_axis( + named_axis=out_named_axis, + axis=axis, + total=layout.minmax_depth[1], + ) + + axis = regularize_axis(axis, none_allowed=True) + reducer = ak._reducers.Count() out = ak._do.reduce( @@ -122,4 +145,18 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): keepdims=keepdims, behavior=ctx.behavior, ) - return ctx.wrap(out, highlevel=highlevel, allow_other=True) + + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + allow_other=True, + ) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) diff --git a/src/awkward/operations/ak_count_nonzero.py b/src/awkward/operations/ak_count_nonzero.py index 919a6abf22..74a8b23033 100644 --- a/src/awkward/operations/ak_count_nonzero.py +++ b/src/awkward/operations/ak_count_nonzero.py @@ -5,6 +5,12 @@ import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext +from awkward._namedaxis import ( + _get_named_axis, + _keep_named_axis, + _named_axis_to_positional_axis, + _remove_named_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -68,7 +74,26 @@ def count_nonzero( def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) + with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: + layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + # Step 2: propagate named axis from input to output, + # keepdims=True: use strategy "keep all" (see: awkward._namedaxis) + # keepdims=False: use strategy "remove one" (see: awkward._namedaxis) + out_named_axis = _keep_named_axis(named_axis, None) + if not keepdims: + out_named_axis = _remove_named_axis( + named_axis=out_named_axis, + axis=axis, + total=layout.minmax_depth[1], + ) + + axis = regularize_axis(axis, none_allowed=True) + with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") reducer = ak._reducers.CountNonzero() @@ -81,7 +106,21 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): keepdims=keepdims, behavior=ctx.behavior, ) - return ctx.wrap(out, highlevel=highlevel, allow_other=True) + + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + allow_other=True, + ) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) @ak._connect.numpy.implements("count_nonzero") diff --git a/src/awkward/operations/ak_covar.py b/src/awkward/operations/ak_covar.py index a070ac6895..7c8fe930fe 100644 --- a/src/awkward/operations/ak_covar.py +++ b/src/awkward/operations/ak_covar.py @@ -3,12 +3,14 @@ from __future__ import annotations import awkward as ak +from awkward._attrs import attrs_of_obj from awkward._dispatch import high_level_function from awkward._layout import ( HighLevelContext, ensure_same_backend, maybe_highlevel_to_lowlevel, ) +from awkward._namedaxis import _get_named_axis, _is_valid_named_axis from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -83,7 +85,9 @@ def covar( def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) + if _is_valid_named_axis(axis): + raise NotImplementedError("named axis not yet supported for ak.covar") + axis = regularize_axis(axis, none_allowed=True) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: x_layout, y_layout, weight_layout = ensure_same_backend( @@ -107,7 +111,7 @@ def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attr x, weight, axis, - False, + True, mask_identity, highlevel=True, behavior=None, @@ -117,7 +121,7 @@ def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attr y, weight, axis, - False, + True, mask_identity, highlevel=True, behavior=None, @@ -161,8 +165,18 @@ def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attr behavior=None, attrs=None, ) - return ctx.wrap( - maybe_highlevel_to_lowlevel(sumwxy / sumw), + + out = sumwxy / sumw + + wrapped = ctx.wrap( + maybe_highlevel_to_lowlevel(out), highlevel=highlevel, allow_other=True, ) + return ak.operations.ak_with_named_axis._impl( + wrapped, + named_axis=_get_named_axis(attrs_of_obj(out)), + highlevel=highlevel, + behavior=None, + attrs=None, + ) diff --git a/src/awkward/operations/ak_drop_none.py b/src/awkward/operations/ak_drop_none.py index c6c06014db..d81770f78f 100644 --- a/src/awkward/operations/ak_drop_none.py +++ b/src/awkward/operations/ak_drop_none.py @@ -5,6 +5,10 @@ import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, maybe_posaxis +from awkward._namedaxis import ( + _get_named_axis, + _named_axis_to_positional_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis from awkward.errors import AxisError @@ -65,10 +69,16 @@ def _drop_none_if_list(layout): def _impl(array, axis, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + + axis = regularize_axis(axis, none_allowed=True) + if axis is None: # if the outer layout is_option, drop_nones without affecting offsets if layout.is_option: @@ -120,4 +130,7 @@ def action(layout, depth, **kwargs): if len(options["none_indexes"]) > 0: out = ak._do.recursively_apply(out, recompute_offsets, depth_context=options) - return ctx.wrap(out, highlevel=highlevel) + return ctx.wrap( + out, + highlevel=highlevel, + ) diff --git a/src/awkward/operations/ak_fill_none.py b/src/awkward/operations/ak_fill_none.py index 89834689cd..fb3dbfd019 100644 --- a/src/awkward/operations/ak_fill_none.py +++ b/src/awkward/operations/ak_fill_none.py @@ -5,6 +5,10 @@ import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, ensure_same_backend, maybe_posaxis +from awkward._namedaxis import ( + _get_named_axis, + _named_axis_to_positional_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis from awkward.errors import AxisError @@ -69,8 +73,6 @@ def fill_none(array, value, axis=-1, *, highlevel=True, behavior=None, attrs=Non def _impl(array, value, axis, highlevel, behavior, attrs): - axis = regularize_axis(axis) - with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: array_layout, value_layout = ensure_same_backend( ctx.unwrap(array, allow_record=True, allow_unknown=False), @@ -84,6 +86,13 @@ def _impl(array, value, axis, highlevel, behavior, attrs): ), ) + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + + axis = regularize_axis(axis, none_allowed=True) + if isinstance(value_layout, ak.record.Record): value_layout = value_layout.array[value_layout.at : value_layout.at + 1] elif isinstance(value_layout, ak.contents.Content): diff --git a/src/awkward/operations/ak_firsts.py b/src/awkward/operations/ak_firsts.py index f67da6dde1..79fba6eb51 100644 --- a/src/awkward/operations/ak_firsts.py +++ b/src/awkward/operations/ak_firsts.py @@ -5,8 +5,13 @@ import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, maybe_posaxis +from awkward._namedaxis import ( + _get_named_axis, + _named_axis_to_positional_axis, + _remove_named_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import is_integer, regularize_axis +from awkward._regularize import regularize_axis from awkward.errors import AxisError __all__ = ("firsts",) @@ -58,10 +63,20 @@ def firsts(array, axis=1, *, highlevel=True, behavior=None, attrs=None): def _impl(array, axis, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False) - axis = regularize_axis(axis) - if not is_integer(axis): - raise TypeError(f"'axis' must be an integer, not {axis!r}") + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + # Step 2: propagate named axis from input to output, + # use strategy "remove one" (see: awkward._namedaxis) + out_named_axis = _remove_named_axis( + named_axis=named_axis, + axis=axis, + total=layout.minmax_depth[1], + ) + + axis = regularize_axis(axis, none_allowed=False) if maybe_posaxis(layout, axis, 1) == 0: # specialized logic; it's tested in test_0582-propagate-context-in-broadcast_and_apply.py @@ -103,4 +118,17 @@ def action(layout, depth, backend, **kwargs): out = ak._do.recursively_apply(layout, action, numpy_to_regular=True) - return ctx.wrap(out, highlevel=highlevel, allow_other=True) + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + allow_other=True, + ) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) diff --git a/src/awkward/operations/ak_flatten.py b/src/awkward/operations/ak_flatten.py index b246870463..3805d28e71 100644 --- a/src/awkward/operations/ak_flatten.py +++ b/src/awkward/operations/ak_flatten.py @@ -5,6 +5,12 @@ import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, maybe_posaxis +from awkward._namedaxis import ( + _get_named_axis, + _keep_named_axis, + _named_axis_to_positional_axis, + _remove_named_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -173,10 +179,25 @@ def flatten(array, axis=1, *, highlevel=True, behavior=None, attrs=None): def _impl(array, axis, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + axis = regularize_axis(axis, none_allowed=True) + # Step 2: propagate named axis from input to output, + # if axis == None: use strategy "remove all" (see: awkward._namedaxis) + # if axis == 0: use strategy "keep all" (see: awkward._namedaxis) + # if axis != 0: use strategy "remove one" (see: awkward._namedaxis) + if axis is None: + pass + elif axis == 0 or maybe_posaxis(layout, axis, 1) == 0: + out_named_axis = _keep_named_axis(named_axis, None) + else: + out_named_axis = _remove_named_axis(named_axis, axis, layout.minmax_depth[1]) + if axis is None: out = ak._do.remove_structure(layout, function_name="ak.flatten") assert isinstance(out, tuple) and all( @@ -234,4 +255,27 @@ def apply(layout): out = apply(layout) else: out = ak._do.flatten(layout, axis) - return ctx.wrap(out, highlevel=highlevel) + + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + ) + + # propagate named axis to output + # if axis == None: use strategy "remove all" (see: awkward._namedaxis) + if axis is None: + return ak.operations.ak_without_named_axis._impl( + wrapped_out, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) + # if axis == 0: use strategy "keep all" (see: awkward._namedaxis) + # if axis != 0: use strategy "remove one" (see: awkward._namedaxis) + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) diff --git a/src/awkward/operations/ak_from_regular.py b/src/awkward/operations/ak_from_regular.py index b3f840ef31..9fe2800a2b 100644 --- a/src/awkward/operations/ak_from_regular.py +++ b/src/awkward/operations/ak_from_regular.py @@ -55,7 +55,8 @@ def from_regular(array, axis=1, *, highlevel=True, behavior=None, attrs=None): def _impl(array, axis, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False) - axis = regularize_axis(axis) + + axis = regularize_axis(axis, none_allowed=True) if axis is None: diff --git a/src/awkward/operations/ak_is_none.py b/src/awkward/operations/ak_is_none.py index 078c86bde6..d9a58a5478 100644 --- a/src/awkward/operations/ak_is_none.py +++ b/src/awkward/operations/ak_is_none.py @@ -5,8 +5,13 @@ import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, maybe_posaxis +from awkward._namedaxis import ( + _get_named_axis, + _keep_named_axis_up_to, + _named_axis_to_positional_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import is_integer, regularize_axis +from awkward._regularize import regularize_axis from awkward.errors import AxisError __all__ = ("is_none",) @@ -41,12 +46,19 @@ def is_none(array, axis=0, *, highlevel=True, behavior=None, attrs=None): def _impl(array, axis, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") - if not is_integer(axis): - raise TypeError(f"'axis' must be an integer, not {axis!r}") + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + + axis = regularize_axis(axis, none_allowed=False) + + # Step 2: propagate named axis from input to output, + # use strategy "keep up to" (see: awkward._namedaxis) + out_named_axis = _keep_named_axis_up_to(named_axis, axis, layout.minmax_depth[1]) def action(layout, depth, backend, lateral_context, **kwargs): posaxis = maybe_posaxis(layout, axis, depth) @@ -68,4 +80,16 @@ def action(layout, depth, backend, lateral_context, **kwargs): out = ak._do.recursively_apply(layout, action, numpy_to_regular=True) - return ctx.wrap(out, highlevel=highlevel) + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + ) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) diff --git a/src/awkward/operations/ak_isclose.py b/src/awkward/operations/ak_isclose.py index 8797c36752..d5ff825c61 100644 --- a/src/awkward/operations/ak_isclose.py +++ b/src/awkward/operations/ak_isclose.py @@ -2,9 +2,12 @@ from __future__ import annotations +from functools import reduce + import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, ensure_same_backend +from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis from awkward._nplikes.numpy_like import NumpyMetadata __all__ = ("isclose",) @@ -70,10 +73,26 @@ def action(inputs, backend, **kwargs): ), ) - out = ak._broadcasting.broadcast_and_apply(layouts, action) + depth_context, lateral_context = NamedAxesWithDims.prepare_contexts([a, b]) + out = ak._broadcasting.broadcast_and_apply( + layouts, + action, + depth_context=depth_context, + lateral_context=lateral_context, + ) assert isinstance(out, tuple) and len(out) == 1 - return ctx.wrap(out[0], highlevel=highlevel) + out_named_axis = reduce( + _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis + ) + wrapped_out = ctx.wrap(out[0], highlevel=highlevel) + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) @ak._connect.numpy.implements("isclose") diff --git a/src/awkward/operations/ak_linear_fit.py b/src/awkward/operations/ak_linear_fit.py index 971fea64fe..01ac0f3297 100644 --- a/src/awkward/operations/ak_linear_fit.py +++ b/src/awkward/operations/ak_linear_fit.py @@ -7,7 +7,6 @@ from awkward._layout import HighLevelContext, ensure_same_backend from awkward._nplikes import ufuncs from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import regularize_axis __all__ = ("linear_fit",) @@ -95,8 +94,6 @@ def linear_fit( def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) - with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: x_layout, y_layout, weight_layout = ensure_same_backend( ctx.unwrap(x, allow_record=False, primitive_policy="error"), @@ -231,4 +228,13 @@ def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attr if is_scalar: out = out[0] - return ctx.wrap(out, highlevel=highlevel, allow_other=is_scalar) + wrapped_out = ctx.wrap(out, highlevel=highlevel, allow_other=is_scalar) + + # propagate named axis + # use strategy "remove all" (see: awkward._namedaxis) + return ak.operations.ak_without_named_axis._impl( + wrapped_out, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) diff --git a/src/awkward/operations/ak_local_index.py b/src/awkward/operations/ak_local_index.py index 2231ac229f..d5e7089dbc 100644 --- a/src/awkward/operations/ak_local_index.py +++ b/src/awkward/operations/ak_local_index.py @@ -5,6 +5,11 @@ import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext +from awkward._namedaxis import ( + _get_named_axis, + _keep_named_axis_up_to, + _named_axis_to_positional_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -88,8 +93,32 @@ def local_index(array, axis=-1, *, highlevel=True, behavior=None, attrs=None): def _impl(array, axis, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + + axis = regularize_axis(axis, none_allowed=False) + + # Step 2: propagate named axis from input to output, + # use strategy "keep up to" (see: awkward._namedaxis) + out_named_axis = _keep_named_axis_up_to(named_axis, axis, layout.minmax_depth[1]) + out = ak._do.local_index(layout, axis) - return ctx.wrap(out, highlevel=highlevel) + + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + ) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) diff --git a/src/awkward/operations/ak_mask.py b/src/awkward/operations/ak_mask.py index 54d9a5e04b..b18273047a 100644 --- a/src/awkward/operations/ak_mask.py +++ b/src/awkward/operations/ak_mask.py @@ -2,9 +2,12 @@ from __future__ import annotations +from functools import reduce + import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, ensure_same_backend +from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis from awkward._nplikes.numpy_like import NumpyMetadata __all__ = ("mask",) @@ -124,8 +127,26 @@ def action(inputs, backend, **kwargs): ctx.unwrap(mask, allow_record=False, primitive_policy="error"), ) + depth_context, lateral_context = NamedAxesWithDims.prepare_contexts([array, mask]) out = ak._broadcasting.broadcast_and_apply( - layouts, action, numpy_to_regular=True, right_broadcast=False + layouts, + action, + depth_context=depth_context, + lateral_context=lateral_context, + numpy_to_regular=True, + right_broadcast=False, ) assert isinstance(out, tuple) and len(out) == 1 - return ctx.wrap(out[0], highlevel=highlevel) + + # Unify named axes propagated through the broadcast + out_named_axis = reduce( + _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis + ) + wrapped_out = ctx.wrap(out[0], highlevel=highlevel) + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) diff --git a/src/awkward/operations/ak_max.py b/src/awkward/operations/ak_max.py index a01a0d64c5..319b2c7bed 100644 --- a/src/awkward/operations/ak_max.py +++ b/src/awkward/operations/ak_max.py @@ -6,6 +6,12 @@ from awkward._connect.numpy import UNSUPPORTED from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext +from awkward._namedaxis import ( + _get_named_axis, + _keep_named_axis, + _named_axis_to_positional_axis, + _remove_named_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -142,9 +148,26 @@ def nanmax( def _impl(array, axis, keepdims, initial, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + # Step 2: propagate named axis from input to output, + # keepdims=True: use strategy "keep all" (see: awkward._namedaxis) + # keepdims=False: use strategy "remove one" (see: awkward._namedaxis) + out_named_axis = _keep_named_axis(named_axis, None) + if not keepdims: + out_named_axis = _remove_named_axis( + named_axis=out_named_axis, + axis=axis, + total=layout.minmax_depth[1], + ) + + axis = regularize_axis(axis, none_allowed=True) + reducer = ak._reducers.Max(initial) out = ak._do.reduce( @@ -155,7 +178,21 @@ def _impl(array, axis, keepdims, initial, mask_identity, highlevel, behavior, at keepdims=keepdims, behavior=ctx.behavior, ) - return ctx.wrap(out, highlevel=highlevel, allow_other=True) + + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + allow_other=True, + ) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) @ak._connect.numpy.implements("amax") diff --git a/src/awkward/operations/ak_mean.py b/src/awkward/operations/ak_mean.py index fa74a89b61..a9b38ce1f0 100644 --- a/src/awkward/operations/ak_mean.py +++ b/src/awkward/operations/ak_mean.py @@ -3,6 +3,7 @@ from __future__ import annotations import awkward as ak +from awkward._attrs import attrs_of_obj from awkward._connect.numpy import UNSUPPORTED from awkward._dispatch import high_level_function from awkward._layout import ( @@ -11,6 +12,10 @@ maybe_highlevel_to_lowlevel, maybe_posaxis, ) +from awkward._namedaxis import ( + _get_named_axis, + _named_axis_to_positional_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -174,8 +179,6 @@ def nanmean( def _impl(x, weight, axis, keepdims, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) - with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: x_layout, weight_layout = ensure_same_backend( ctx.unwrap(x, allow_record=False, primitive_policy="error"), @@ -191,6 +194,13 @@ def _impl(x, weight, axis, keepdims, mask_identity, highlevel, behavior, attrs): x = ctx.wrap(x_layout) weight = ctx.wrap(weight_layout, allow_other=True) + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + + axis = regularize_axis(axis, none_allowed=True) + with np.errstate(invalid="ignore", divide="ignore"): if weight is None: sumw = ak.operations.ak_count._impl( @@ -245,14 +255,25 @@ def _impl(x, weight, axis, keepdims, mask_identity, highlevel, behavior, attrs): if axis is None: if not keepdims: + # remove all dimensions out = out[(0,) * out.ndim] else: if not keepdims: + # remove reduced dimension posaxis = maybe_posaxis(out.layout, axis, 1) out = out[(slice(None, None),) * posaxis + (0,)] - return ctx.wrap( - maybe_highlevel_to_lowlevel(out), highlevel=highlevel, allow_other=True + wrapped = ctx.wrap( + maybe_highlevel_to_lowlevel(out), + highlevel=highlevel, + allow_other=True, + ) + return ak.operations.ak_with_named_axis._impl( + wrapped, + named_axis=_get_named_axis(attrs_of_obj(out) or {}), + highlevel=highlevel, + behavior=None, + attrs=None, ) diff --git a/src/awkward/operations/ak_merge_option_of_records.py b/src/awkward/operations/ak_merge_option_of_records.py index c3e1095ba4..17402e77a6 100644 --- a/src/awkward/operations/ak_merge_option_of_records.py +++ b/src/awkward/operations/ak_merge_option_of_records.py @@ -5,6 +5,10 @@ import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, maybe_posaxis +from awkward._namedaxis import ( + _get_named_axis, + _named_axis_to_positional_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis from awkward.errors import AxisError @@ -49,10 +53,15 @@ def merge_option_of_records( def _impl(array, axis, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + + axis = regularize_axis(axis, none_allowed=False) + # First, normalise type-invsible "index-of-records" to "record-of-index" def apply_displace_index(layout, backend, **kwargs): if (layout.is_indexed and not layout.is_option) and layout.content.is_record: diff --git a/src/awkward/operations/ak_merge_union_of_records.py b/src/awkward/operations/ak_merge_union_of_records.py index d523c0b5f8..0094203947 100644 --- a/src/awkward/operations/ak_merge_union_of_records.py +++ b/src/awkward/operations/ak_merge_union_of_records.py @@ -5,6 +5,10 @@ import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, maybe_posaxis +from awkward._namedaxis import ( + _get_named_axis, + _named_axis_to_positional_axis, +) from awkward._nplikes.numpy_like import ArrayLike, NumpyMetadata from awkward._regularize import regularize_axis from awkward.errors import AxisError @@ -59,10 +63,15 @@ def merge_union_of_records( def _impl(array, axis, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + + axis = regularize_axis(axis, none_allowed=False) + def invert_record_union( tags: ArrayLike, index: ArrayLike, contents ) -> ak.contents.RecordArray: diff --git a/src/awkward/operations/ak_min.py b/src/awkward/operations/ak_min.py index 05e583d430..1b9189f740 100644 --- a/src/awkward/operations/ak_min.py +++ b/src/awkward/operations/ak_min.py @@ -6,6 +6,12 @@ from awkward._connect.numpy import UNSUPPORTED from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext +from awkward._namedaxis import ( + _get_named_axis, + _keep_named_axis, + _named_axis_to_positional_axis, + _remove_named_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -142,9 +148,26 @@ def nanmin( def _impl(array, axis, keepdims, initial, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + # Step 2: propagate named axis from input to output, + # keepdims=True: use strategy "keep all" (see: awkward._namedaxis) + # keepdims=False: use strategy "remove one" (see: awkward._namedaxis) + out_named_axis = _keep_named_axis(named_axis, None) + if not keepdims: + out_named_axis = _remove_named_axis( + named_axis=out_named_axis, + axis=axis, + total=layout.minmax_depth[1], + ) + + axis = regularize_axis(axis, none_allowed=True) + reducer = ak._reducers.Min(initial) out = ak._do.reduce( @@ -155,7 +178,21 @@ def _impl(array, axis, keepdims, initial, mask_identity, highlevel, behavior, at keepdims=keepdims, behavior=ctx.behavior, ) - return ctx.wrap(out, highlevel=highlevel, allow_other=True) + + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + allow_other=True, + ) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) @ak._connect.numpy.implements("amin") diff --git a/src/awkward/operations/ak_moment.py b/src/awkward/operations/ak_moment.py index 7cac2498ee..2c8e29adb1 100644 --- a/src/awkward/operations/ak_moment.py +++ b/src/awkward/operations/ak_moment.py @@ -3,14 +3,19 @@ from __future__ import annotations import awkward as ak +from awkward._attrs import attrs_of_obj from awkward._dispatch import high_level_function from awkward._layout import ( HighLevelContext, ensure_same_backend, maybe_highlevel_to_lowlevel, ) +from awkward._namedaxis import ( + AxisName, + _get_named_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import regularize_axis +from awkward._typing import Mapping __all__ = ("moment",) @@ -22,13 +27,13 @@ def moment( x, n, weight=None, - axis=None, + axis: AxisName = None, *, - keepdims=False, - mask_identity=False, - highlevel=True, - behavior=None, - attrs=None, + keepdims: bool = False, + mask_identity: bool = False, + highlevel: bool = True, + behavior: Mapping | None = None, + attrs: Mapping | None = None, ): """ Args: @@ -86,9 +91,17 @@ def moment( ) -def _impl(x, n, weight, axis, keepdims, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) - +def _impl( + x, + n, + weight, + axis: AxisName, + keepdims: bool, + mask_identity: bool, + highlevel: bool, + behavior: Mapping | None, + attrs: Mapping | None, +): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: x_layout, weight_layout = ensure_same_backend( ctx.unwrap(x, allow_record=False, primitive_policy="error"), @@ -143,8 +156,20 @@ def _impl(x, n, weight, axis, keepdims, mask_identity, highlevel, behavior, attr behavior=ctx.behavior, attrs=ctx.attrs, ) - return ctx.wrap( - maybe_highlevel_to_lowlevel(sumwxn / sumw), + + out = sumwxn / sumw + + # propagate named axis to output + wrapped = ctx.wrap( + maybe_highlevel_to_lowlevel(out), highlevel=highlevel, allow_other=True, ) + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped, + named_axis=_get_named_axis(attrs_of_obj(out)), + highlevel=highlevel, + behavior=None, + attrs=None, + ) diff --git a/src/awkward/operations/ak_nan_to_none.py b/src/awkward/operations/ak_nan_to_none.py index 7dabbfe828..23ef938dbe 100644 --- a/src/awkward/operations/ak_nan_to_none.py +++ b/src/awkward/operations/ak_nan_to_none.py @@ -6,6 +6,7 @@ from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext from awkward._nplikes.numpy_like import NumpyMetadata +from awkward._typing import Mapping __all__ = ("nan_to_none",) @@ -13,7 +14,13 @@ @high_level_function() -def nan_to_none(array, *, highlevel=True, behavior=None, attrs=None): +def nan_to_none( + array, + *, + highlevel: bool = True, + behavior: Mapping | None = None, + attrs: Mapping | None = None, +): """ Args: array: Array-like data (anything #ak.to_layout recognizes). @@ -35,7 +42,7 @@ def nan_to_none(array, *, highlevel=True, behavior=None, attrs=None): return _impl(array, highlevel, behavior, attrs) -def _impl(array, highlevel, behavior, attrs): +def _impl(array, highlevel: bool, behavior: Mapping | None, attrs: Mapping | None): def action(layout, continuation, backend, **kwargs): if layout.is_numpy and np.issubdtype(layout.dtype, np.floating): mask = backend.nplike.isnan(layout.data) @@ -55,5 +62,6 @@ def action(layout, continuation, backend, **kwargs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + out = ak._do.recursively_apply(layout, action) return ctx.wrap(out, highlevel=highlevel) diff --git a/src/awkward/operations/ak_nan_to_num.py b/src/awkward/operations/ak_nan_to_num.py index 4c7472a06f..69e2617c00 100644 --- a/src/awkward/operations/ak_nan_to_num.py +++ b/src/awkward/operations/ak_nan_to_num.py @@ -2,10 +2,14 @@ from __future__ import annotations +from functools import reduce + import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, ensure_same_backend +from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis from awkward._nplikes.numpy_like import NumpyMetadata +from awkward._typing import Mapping __all__ = ("nan_to_num",) @@ -15,14 +19,14 @@ @high_level_function() def nan_to_num( array, - copy=True, + copy: bool = True, nan=0.0, posinf=None, neginf=None, *, - highlevel=True, - behavior=None, - attrs=None, + highlevel: bool = True, + behavior: Mapping | None = None, + attrs: Mapping | None = None, ): """ Args: @@ -52,7 +56,16 @@ def nan_to_num( return _impl(array, copy, nan, posinf, neginf, highlevel, behavior, attrs) -def _impl(array, copy, nan, posinf, neginf, highlevel, behavior, attrs): +def _impl( + array, + copy: bool, + nan, + posinf, + neginf, + highlevel: bool, + behavior: Mapping | None, + attrs: Mapping | None, +): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout, nan_layout, posinf_layout, neginf_layout = ensure_same_backend( ctx.unwrap(array), @@ -81,15 +94,19 @@ def _impl(array, copy, nan, posinf, neginf, highlevel, behavior, attrs): broadcasting_ids = {} broadcasting = [layout] + arrays_to_broadcast = [array] if isinstance(nan_layout, ak.contents.Content): broadcasting_ids[id(nan)] = len(broadcasting) broadcasting.append(nan_layout) + arrays_to_broadcast.append(nan) if isinstance(posinf_layout, ak.contents.Content): broadcasting_ids[id(posinf)] = len(broadcasting) broadcasting.append(posinf_layout) + arrays_to_broadcast.append(posinf) if isinstance(neginf_layout, ak.contents.Content): broadcasting_ids[id(neginf)] = len(broadcasting) broadcasting.append(neginf_layout) + arrays_to_broadcast.append(neginf) if len(broadcasting) == 1: @@ -138,9 +155,29 @@ def action(inputs, backend, **kwargs): else: return None - out = ak._broadcasting.broadcast_and_apply(broadcasting, action) + depth_context, lateral_context = NamedAxesWithDims.prepare_contexts( + arrays_to_broadcast + ) + out = ak._broadcasting.broadcast_and_apply( + broadcasting, + action, + depth_context=depth_context, + lateral_context=lateral_context, + ) assert isinstance(out, tuple) and len(out) == 1 - out = out[0] + + # Unify named axes propagated through the broadcast + out_named_axis = reduce( + _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis + ) + wrapped_out = ctx.wrap(out[0], highlevel=highlevel) + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) return ctx.wrap(out, highlevel=highlevel) diff --git a/src/awkward/operations/ak_num.py b/src/awkward/operations/ak_num.py index ad9b4e746c..705a1e1c63 100644 --- a/src/awkward/operations/ak_num.py +++ b/src/awkward/operations/ak_num.py @@ -5,8 +5,14 @@ import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, maybe_posaxis +from awkward._namedaxis import ( + _get_named_axis, + _keep_named_axis, + _named_axis_to_positional_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import is_integer, regularize_axis +from awkward._regularize import regularize_axis +from awkward._typing import Mapping from awkward.errors import AxisError __all__ = ("num",) @@ -15,7 +21,14 @@ @high_level_function() -def num(array, axis=1, *, highlevel=True, behavior=None, attrs=None): +def num( + array, + axis=1, + *, + highlevel: bool = True, + behavior: Mapping | None = None, + attrs: Mapping | None = None, +): """ Args: array: Array-like data (anything #ak.to_layout recognizes). @@ -83,13 +96,25 @@ def num(array, axis=1, *, highlevel=True, behavior=None, attrs=None): return _impl(array, axis, highlevel, behavior, attrs) -def _impl(array, axis, highlevel, behavior, attrs): - axis = regularize_axis(axis) +def _impl( + array, + axis, + highlevel: bool, + behavior: Mapping | None, + attrs: Mapping | None, +): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") - if not is_integer(axis): - raise TypeError(f"'axis' must be an integer, not {axis!r}") + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + # Step 2: propagate named axis from input to output, + # use strategy "keep one" (see: awkward._namedaxis) + out_named_axis = _keep_named_axis(named_axis, axis) + + axis = regularize_axis(axis, none_allowed=False) if maybe_posaxis(layout, axis, 1) == 0: index_nplike = layout.backend.index_nplike @@ -109,4 +134,16 @@ def action(layout, depth, **kwargs): out = ak._do.recursively_apply(layout, action, numpy_to_regular=True) - return ctx.wrap(out, highlevel=highlevel) + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + ) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) diff --git a/src/awkward/operations/ak_pad_none.py b/src/awkward/operations/ak_pad_none.py index 34355a8546..17bb3035ac 100644 --- a/src/awkward/operations/ak_pad_none.py +++ b/src/awkward/operations/ak_pad_none.py @@ -5,6 +5,10 @@ import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext +from awkward._namedaxis import ( + _get_named_axis, + _named_axis_to_positional_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -113,9 +117,15 @@ def pad_none( def _impl(array, target, axis, clip, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + + axis = regularize_axis(axis, none_allowed=False) + out = ak._do.pad_none(layout, target, axis, clip=clip) return ctx.wrap(out, highlevel=highlevel) diff --git a/src/awkward/operations/ak_prod.py b/src/awkward/operations/ak_prod.py index cde898f174..d3d1a050c3 100644 --- a/src/awkward/operations/ak_prod.py +++ b/src/awkward/operations/ak_prod.py @@ -6,6 +6,12 @@ from awkward._connect.numpy import UNSUPPORTED from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext +from awkward._namedaxis import ( + _get_named_axis, + _keep_named_axis, + _named_axis_to_positional_axis, + _remove_named_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -119,9 +125,26 @@ def nanprod( def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + # Step 2: propagate named axis from input to output, + # keepdims=True: use strategy "keep all" (see: awkward._namedaxis) + # keepdims=False: use strategy "remove one" (see: awkward._namedaxis) + out_named_axis = _keep_named_axis(named_axis, None) + if not keepdims: + out_named_axis = _remove_named_axis( + named_axis=out_named_axis, + axis=axis, + total=layout.minmax_depth[1], + ) + + axis = regularize_axis(axis, none_allowed=True) + reducer = ak._reducers.Prod() out = ak._do.reduce( @@ -132,7 +155,21 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): keepdims=keepdims, behavior=ctx.behavior, ) - return ctx.wrap(out, highlevel=highlevel, allow_other=True) + + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + allow_other=True, + ) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) @ak._connect.numpy.implements("prod") diff --git a/src/awkward/operations/ak_ptp.py b/src/awkward/operations/ak_ptp.py index 56daaa6980..6d4beafbd5 100644 --- a/src/awkward/operations/ak_ptp.py +++ b/src/awkward/operations/ak_ptp.py @@ -3,6 +3,7 @@ from __future__ import annotations import awkward as ak +from awkward._attrs import attrs_of_obj from awkward._connect.numpy import UNSUPPORTED from awkward._dispatch import high_level_function from awkward._layout import ( @@ -10,6 +11,10 @@ maybe_highlevel_to_lowlevel, maybe_posaxis, ) +from awkward._namedaxis import ( + _get_named_axis, + _named_axis_to_positional_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -83,10 +88,16 @@ def ptp( def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + + axis = regularize_axis(axis, none_allowed=True) + with np.errstate(invalid="ignore", divide="ignore"): maxi = ak.operations.ak_max._impl( layout, @@ -126,8 +137,18 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): posaxis = maybe_posaxis(out.layout, axis, 1) out = out[(slice(None, None),) * posaxis + (0,)] - return ctx.wrap( - maybe_highlevel_to_lowlevel(out), highlevel=highlevel, allow_other=True + wrapped = ctx.wrap( + maybe_highlevel_to_lowlevel(out), + highlevel=highlevel, + allow_other=True, + ) + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped, + named_axis=_get_named_axis(attrs_of_obj(out)), + highlevel=highlevel, + behavior=None, + attrs=None, ) diff --git a/src/awkward/operations/ak_ravel.py b/src/awkward/operations/ak_ravel.py index 66a3e3a55d..062601eff4 100644 --- a/src/awkward/operations/ak_ravel.py +++ b/src/awkward/operations/ak_ravel.py @@ -75,7 +75,16 @@ def _impl(array, highlevel, behavior, attrs): result = ak._do.mergemany(out) - return ctx.wrap(result, highlevel=highlevel) + wrapped_out = ctx.wrap(result, highlevel=highlevel) + + # propagate named axis to output + # use strategy "remove all" (see: awkward._namedaxis) + return ak.operations.ak_without_named_axis._impl( + wrapped_out, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) @ak._connect.numpy.implements("ravel") diff --git a/src/awkward/operations/ak_real.py b/src/awkward/operations/ak_real.py index 6d52971dab..655e4e8007 100644 --- a/src/awkward/operations/ak_real.py +++ b/src/awkward/operations/ak_real.py @@ -14,10 +14,10 @@ @ak._connect.numpy.implements("real") @high_level_function() -def real(val, highlevel=True, behavior=None, attrs=None): +def real(array, highlevel=True, behavior=None, attrs=None): """ Args: - val : array_like + array : array_like Input array. highlevel (bool, default is True): If True, return an #ak.Array; otherwise, return a low-level #ak.contents.Content subclass. @@ -30,15 +30,15 @@ def real(val, highlevel=True, behavior=None, attrs=None): If the arrays have complex elements, the returned arrays are floats. """ # Dispatch - yield (val,) + yield (array,) # Implementation - return _impl_real(val, highlevel, behavior, attrs) + return _impl(array, highlevel, behavior, attrs) -def _impl_real(val, highlevel, behavior, attrs): +def _impl(array, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: - layout = ctx.unwrap(val, allow_record=False, primitive_policy="error") + layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") out = ak._do.recursively_apply(layout, _action_real) return ctx.wrap(out, highlevel=highlevel) diff --git a/src/awkward/operations/ak_singletons.py b/src/awkward/operations/ak_singletons.py index 35f60d5c97..4de6a59151 100644 --- a/src/awkward/operations/ak_singletons.py +++ b/src/awkward/operations/ak_singletons.py @@ -5,8 +5,13 @@ import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, maybe_posaxis +from awkward._namedaxis import ( + _add_named_axis, + _get_named_axis, + _named_axis_to_positional_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import is_integer, regularize_axis +from awkward._regularize import regularize_axis from awkward.errors import AxisError __all__ = ("singletons",) @@ -56,12 +61,21 @@ def singletons(array, axis=0, *, highlevel=True, behavior=None, attrs=None): def _impl(array, axis, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") - if not is_integer(axis): - raise TypeError(f"'axis' must be an integer, not {axis!r}") + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + + axis = regularize_axis(axis, none_allowed=False) + + # Step 2: propagate named axis from input to output, + # use strategy "add one" (see: awkward._namedaxis) + out_named_axis = _add_named_axis( + named_axis, (axis + 1) if axis >= 0 else axis, layout.minmax_depth[1] + ) def action(layout, depth, backend, **kwargs): posaxis = maybe_posaxis(layout, axis, depth) @@ -90,4 +104,16 @@ def action(layout, depth, backend, **kwargs): out = ak._do.recursively_apply(layout, action, numpy_to_regular=True) - return ctx.wrap(out, highlevel=highlevel) + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + ) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) diff --git a/src/awkward/operations/ak_softmax.py b/src/awkward/operations/ak_softmax.py index e86cbe9cf0..b2cb11bff0 100644 --- a/src/awkward/operations/ak_softmax.py +++ b/src/awkward/operations/ak_softmax.py @@ -3,12 +3,17 @@ from __future__ import annotations import awkward as ak +from awkward._attrs import attrs_of_obj from awkward._dispatch import high_level_function from awkward._layout import ( HighLevelContext, maybe_highlevel_to_lowlevel, maybe_posaxis, ) +from awkward._namedaxis import ( + _get_named_axis, + _named_axis_to_positional_axis, +) from awkward._nplikes import ufuncs from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -75,10 +80,16 @@ def softmax( def _impl(x, axis, keepdims, mask_identity, highlevel, behavior, attrs): original_axis = axis - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: x_layout = ctx.unwrap(x, allow_record=False, primitive_policy="error") + + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + axis = regularize_axis(axis, none_allowed=True) + x = ctx.wrap(x_layout) if maybe_posaxis(x_layout, axis, 1) != maybe_posaxis(x_layout, -1, 1): @@ -97,8 +108,19 @@ def _impl(x, axis, keepdims, mask_identity, highlevel, behavior, attrs): behavior=ctx.behavior, attrs=ctx.attrs, ) - return ctx.wrap( - maybe_highlevel_to_lowlevel(expx / denom), + + out = expx / denom + + wrapped = ctx.wrap( + maybe_highlevel_to_lowlevel(out), highlevel=highlevel, allow_other=True, ) + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped, + named_axis=_get_named_axis(attrs_of_obj(out)), + highlevel=highlevel, + behavior=None, + attrs=None, + ) diff --git a/src/awkward/operations/ak_sort.py b/src/awkward/operations/ak_sort.py index 5e82e91604..0864fc5d98 100644 --- a/src/awkward/operations/ak_sort.py +++ b/src/awkward/operations/ak_sort.py @@ -6,6 +6,10 @@ from awkward._connect.numpy import UNSUPPORTED from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext +from awkward._namedaxis import ( + _get_named_axis, + _named_axis_to_positional_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -59,11 +63,22 @@ def sort( def _impl(array, axis, ascending, stable, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + + axis = regularize_axis(axis, none_allowed=False) + out = ak._do.sort(layout, axis, ascending, stable) - return ctx.wrap(out, highlevel=highlevel) + + return ctx.wrap( + out, + highlevel=highlevel, + ) @ak._connect.numpy.implements("sort") diff --git a/src/awkward/operations/ak_std.py b/src/awkward/operations/ak_std.py index 0385032440..7926b341fe 100644 --- a/src/awkward/operations/ak_std.py +++ b/src/awkward/operations/ak_std.py @@ -3,6 +3,7 @@ from __future__ import annotations import awkward as ak +from awkward._attrs import attrs_of_obj from awkward._connect.numpy import UNSUPPORTED from awkward._dispatch import high_level_function from awkward._layout import ( @@ -11,6 +12,10 @@ maybe_highlevel_to_lowlevel, maybe_posaxis, ) +from awkward._namedaxis import ( + _get_named_axis, + _named_axis_to_positional_axis, +) from awkward._nplikes import ufuncs from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -165,8 +170,6 @@ def nanstd( def _impl(x, weight, ddof, axis, keepdims, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) - with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: x_layout, weight_layout = ensure_same_backend( ctx.unwrap(x, allow_record=False, primitive_policy="error"), @@ -182,6 +185,13 @@ def _impl(x, weight, ddof, axis, keepdims, mask_identity, highlevel, behavior, a x = ctx.wrap(x_layout) weight = ctx.wrap(weight_layout, allow_other=True) + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + + axis = regularize_axis(axis, none_allowed=True) + with np.errstate(invalid="ignore", divide="ignore"): out = ufuncs.sqrt( ak.operations.ak_var._impl( @@ -215,8 +225,18 @@ def _impl(x, weight, ddof, axis, keepdims, mask_identity, highlevel, behavior, a posaxis = maybe_posaxis(out.layout, axis, 1) out = out[(slice(None, None),) * posaxis + (0,)] - return ctx.wrap( - maybe_highlevel_to_lowlevel(out), highlevel=highlevel, allow_other=True + wrapped = ctx.wrap( + maybe_highlevel_to_lowlevel(out), + highlevel=highlevel, + allow_other=True, + ) + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped, + named_axis=_get_named_axis(attrs_of_obj(out)), + highlevel=highlevel, + behavior=None, + attrs=None, ) diff --git a/src/awkward/operations/ak_strings_astype.py b/src/awkward/operations/ak_strings_astype.py index b0834db3a6..479232cf01 100644 --- a/src/awkward/operations/ak_strings_astype.py +++ b/src/awkward/operations/ak_strings_astype.py @@ -82,5 +82,6 @@ def action(layout, **kwargs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + out = ak._do.recursively_apply(layout, action) return ctx.wrap(out, highlevel=highlevel) diff --git a/src/awkward/operations/ak_sum.py b/src/awkward/operations/ak_sum.py index f00434083e..ae6a40aef8 100644 --- a/src/awkward/operations/ak_sum.py +++ b/src/awkward/operations/ak_sum.py @@ -6,6 +6,12 @@ from awkward._connect.numpy import UNSUPPORTED from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext +from awkward._namedaxis import ( + _get_named_axis, + _keep_named_axis, + _named_axis_to_positional_axis, + _remove_named_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -269,9 +275,26 @@ def nansum( def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + # Step 2: propagate named axis from input to output, + # keepdims=True: use strategy "keep all" (see: awkward._namedaxis) + # keepdims=False: use strategy "remove one" (see: awkward._namedaxis) + out_named_axis = _keep_named_axis(named_axis, None) + if not keepdims: + out_named_axis = _remove_named_axis( + named_axis=out_named_axis, + axis=axis, + total=layout.minmax_depth[1], + ) + + axis = regularize_axis(axis, none_allowed=True) + reducer = ak._reducers.Sum() out = ak._do.reduce( @@ -282,7 +305,21 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): keepdims=keepdims, behavior=ctx.behavior, ) - return ctx.wrap(out, highlevel=highlevel, allow_other=True) + + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + allow_other=True, + ) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) @ak._connect.numpy.implements("sum") diff --git a/src/awkward/operations/ak_to_backend.py b/src/awkward/operations/ak_to_backend.py index f65a2c0a81..8d93e2de94 100644 --- a/src/awkward/operations/ak_to_backend.py +++ b/src/awkward/operations/ak_to_backend.py @@ -17,7 +17,7 @@ def to_backend(array, backend, *, highlevel=True, behavior=None, attrs=None): """ Args: array: Array-like data (anything #ak.to_layout recognizes). - backend (`"cpu"`, `"cuda"`, or `"jax"`): If `"cpu"`, the array structure is + backend (`"cpu"`, `"cuda"`, `"jax"`, or `"typetracer"`): If `"cpu"`, the array structure is recursively copied (if need be) to main memory for use with the default Numpy backend; if `"cuda"`, the structure is copied to the GPU(s) for use with CuPy. If `"jax"`, the structure is diff --git a/src/awkward/operations/ak_to_regular.py b/src/awkward/operations/ak_to_regular.py index b72e48d7c5..ae9f9cc3da 100644 --- a/src/awkward/operations/ak_to_regular.py +++ b/src/awkward/operations/ak_to_regular.py @@ -66,7 +66,7 @@ def to_regular(array, axis=1, *, highlevel=True, behavior=None, attrs=None): def _impl(array, axis, highlevel, behavior, attrs): - axis = regularize_axis(axis) + axis = regularize_axis(axis, none_allowed=True) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") diff --git a/src/awkward/operations/ak_transform.py b/src/awkward/operations/ak_transform.py index 23b4dbfd4e..93a4911914 100644 --- a/src/awkward/operations/ak_transform.py +++ b/src/awkward/operations/ak_transform.py @@ -3,6 +3,7 @@ from __future__ import annotations import copy +from functools import reduce import awkward as ak from awkward._backends.numpy import NumpyBackend @@ -15,6 +16,7 @@ ) from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, ensure_same_backend +from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis __all__ = ("transform",) @@ -580,6 +582,17 @@ def action(inputs, **kwargs): f"transformation must return a Content, tuple of Contents, or None, not {type(out)}\n\n{out!r}" ) + if depth_context is None: + depth_context = {} + if lateral_context is None: + lateral_context = {} + assert NAMED_AXIS_KEY not in depth_context + assert NAMED_AXIS_KEY not in lateral_context + _depth_context, _lateral_context = NamedAxesWithDims.prepare_contexts( + [array, *more_arrays] + ) + depth_context.update(_depth_context) + lateral_context.update(_lateral_context) backend = next((layout.backend for layout in layouts), cpu) isscalar = [] out = apply_broadcasting_step( @@ -594,6 +607,11 @@ def action(inputs, **kwargs): assert isinstance(out, tuple) out = [broadcast_unpack(x, isscalar) for x in out] + # Unify named axes propagated through the broadcast + out_named_axis = reduce( + _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis + ) + if return_value == "none": return elif expect_return_value and not transformer_did_terminate: @@ -602,6 +620,25 @@ def action(inputs, **kwargs): "or tuple of Contents, but instead only returned None." ) elif len(out) == 1: - return ctx.wrap(out[0], highlevel=highlevel) + wrapped_out = ctx.wrap(out[0], highlevel=highlevel) + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) else: - return tuple(ctx.wrap(x, highlevel=highlevel) for x in out) + wrapped_out = [] + for x in out: + wrapped = ctx.wrap(x, highlevel=highlevel) + wrapped_out.append( + ak.operations.ak_with_named_axis._impl( + wrapped, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) + ) + return tuple(wrapped_out) diff --git a/src/awkward/operations/ak_unflatten.py b/src/awkward/operations/ak_unflatten.py index 78c2631e31..83a3b8f2b4 100644 --- a/src/awkward/operations/ak_unflatten.py +++ b/src/awkward/operations/ak_unflatten.py @@ -6,6 +6,10 @@ from awkward._backends.numpy import NumpyBackend from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, ensure_same_backend, maybe_posaxis +from awkward._namedaxis import ( + _get_named_axis, + _named_axis_to_positional_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._nplikes.shape import unknown_length from awkward._nplikes.typetracer import is_unknown_scalar @@ -91,8 +95,6 @@ def unflatten(array, counts, axis=0, *, highlevel=True, behavior=None, attrs=Non def _impl(array, counts, axis, highlevel, behavior, attrs): - axis = regularize_axis(axis) - with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout, maybe_counts_layout = ensure_same_backend( ctx.unwrap(array, allow_record=False, primitive_policy="error"), @@ -105,6 +107,13 @@ def _impl(array, counts, axis, highlevel, behavior, attrs): ), ) + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + + axis = regularize_axis(axis, none_allowed=False) + if is_integer_like(maybe_counts_layout): # Regularize unknown values to unknown lengths if ( @@ -292,4 +301,16 @@ def apply(layout, depth, backend, **kwargs): f"at axis={axis}" ) - return ctx.wrap(out, highlevel=highlevel) + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + ) + + # Step 2: propagate named axis from input to output, + # use strategy "remove all" (see: awkward._namedaxis) + return ak.operations.ak_without_named_axis._impl( + wrapped_out, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) diff --git a/src/awkward/operations/ak_unzip.py b/src/awkward/operations/ak_unzip.py index 8d0bfc229a..8c19380133 100644 --- a/src/awkward/operations/ak_unzip.py +++ b/src/awkward/operations/ak_unzip.py @@ -51,6 +51,7 @@ def unzip(array, *, highlevel=True, behavior=None, attrs=None): def _impl(array, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=True, primitive_policy="error") + fields = ak.operations.fields(layout) def check_for_union(layout, **kwargs): @@ -70,5 +71,10 @@ def check_for_union(layout, **kwargs): return (ctx.wrap(layout, highlevel=highlevel, allow_other=True),) else: return tuple( - ctx.wrap(layout[n], highlevel=highlevel, allow_other=True) for n in fields + ctx.wrap( + layout[n], + highlevel=highlevel, + allow_other=True, + ) + for n in fields ) diff --git a/src/awkward/operations/ak_values_astype.py b/src/awkward/operations/ak_values_astype.py index 714a4320d9..fa25ca5a35 100644 --- a/src/awkward/operations/ak_values_astype.py +++ b/src/awkward/operations/ak_values_astype.py @@ -72,6 +72,7 @@ def values_astype( def _impl(array, to, including_unknown, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + to_str = ak.types.numpytype.dtype_to_primitive(np.dtype(to)) out = ak._do.numbers_to_type(layout, to_str, including_unknown) return ctx.wrap(out, highlevel=highlevel) diff --git a/src/awkward/operations/ak_var.py b/src/awkward/operations/ak_var.py index d1139d8b4c..759f5edf1c 100644 --- a/src/awkward/operations/ak_var.py +++ b/src/awkward/operations/ak_var.py @@ -3,6 +3,7 @@ from __future__ import annotations import awkward as ak +from awkward._attrs import attrs_of_obj from awkward._connect.numpy import UNSUPPORTED from awkward._dispatch import high_level_function from awkward._layout import ( @@ -11,6 +12,10 @@ maybe_highlevel_to_lowlevel, maybe_posaxis, ) +from awkward._namedaxis import ( + _get_named_axis, + _named_axis_to_positional_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -170,8 +175,6 @@ def nanvar( def _impl(x, weight, ddof, axis, keepdims, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) - with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: x_layout, weight_layout = ensure_same_backend( ctx.unwrap(x, allow_record=False, primitive_policy="error"), @@ -187,6 +190,12 @@ def _impl(x, weight, ddof, axis, keepdims, mask_identity, highlevel, behavior, a x = ctx.wrap(x_layout) weight = ctx.wrap(weight_layout, allow_other=True) + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + axis = regularize_axis(axis, none_allowed=True) + with np.errstate(invalid="ignore", divide="ignore"): if weight is None: sumw = ak.operations.ak_count._impl( @@ -267,8 +276,19 @@ def _impl(x, weight, ddof, axis, keepdims, mask_identity, highlevel, behavior, a posaxis = maybe_posaxis(out.layout, axis, 1) out = out[(slice(None, None),) * posaxis + (0,)] - return ctx.wrap( - maybe_highlevel_to_lowlevel(out), highlevel=highlevel, allow_other=True + wrapped = ctx.wrap( + maybe_highlevel_to_lowlevel(out), + highlevel=highlevel, + allow_other=True, + ) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped, + named_axis=_get_named_axis(attrs_of_obj(out)), + highlevel=highlevel, + behavior=None, + attrs=None, ) diff --git a/src/awkward/operations/ak_where.py b/src/awkward/operations/ak_where.py index dda7d99f42..07f5f2f7bd 100644 --- a/src/awkward/operations/ak_where.py +++ b/src/awkward/operations/ak_where.py @@ -2,9 +2,12 @@ from __future__ import annotations +from functools import reduce + import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, ensure_same_backend +from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis from awkward._nplikes.numpy_like import NumpyMetadata __all__ = ("where",) @@ -121,8 +124,26 @@ def action(inputs, backend, **kwargs): else: return None + depth_context, lateral_context = NamedAxesWithDims.prepare_contexts( + [x, y, condition] + ) out = ak._broadcasting.broadcast_and_apply( - layouts, action, numpy_to_regular=True, function_name="ak.where" + layouts, + action, + depth_context=depth_context, + lateral_context=lateral_context, + numpy_to_regular=True, + function_name="ak.where", + ) + # Unify named axes propagated through the broadcast + out_named_axis = reduce( + _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis + ) + wrapped_out = ctx.wrap(out[0], highlevel=highlevel) + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, ) - - return ctx.wrap(out[0], highlevel=highlevel) diff --git a/src/awkward/operations/ak_with_field.py b/src/awkward/operations/ak_with_field.py index 671a061978..3adb5c33a1 100644 --- a/src/awkward/operations/ak_with_field.py +++ b/src/awkward/operations/ak_with_field.py @@ -3,10 +3,12 @@ from __future__ import annotations import copy +from functools import reduce import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, ensure_same_backend +from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import is_non_string_like_sequence @@ -76,6 +78,11 @@ def _impl(base, what, where, highlevel, behavior, attrs): if is_non_string_like_sequence(where): where = where[0] + depth_context, lateral_context = NamedAxesWithDims.prepare_contexts( + [base, what], + unwrap_kwargs={"none_policy": "promote"}, + ) + with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: base, what = ensure_same_backend( ctx.unwrap(base, allow_record=True, primitive_policy="error"), @@ -156,9 +163,24 @@ def action(inputs, **kwargs): return None out = ak._broadcasting.broadcast_and_apply( - [base, what], action, right_broadcast=False + [base, what], + action, + depth_context=depth_context, + lateral_context=lateral_context, + right_broadcast=False, ) assert isinstance(out, tuple) and len(out) == 1 - return ctx.wrap(out[0], highlevel=highlevel) + # Unify named axes propagated through the broadcast + out_named_axis = reduce( + _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis + ) + wrapped_out = ctx.wrap(out[0], highlevel=highlevel) + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) diff --git a/src/awkward/operations/ak_with_named_axis.py b/src/awkward/operations/ak_with_named_axis.py new file mode 100644 index 0000000000..507acc485c --- /dev/null +++ b/src/awkward/operations/ak_with_named_axis.py @@ -0,0 +1,72 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE + +from __future__ import annotations + +from awkward._dispatch import high_level_function +from awkward._layout import HighLevelContext +from awkward._namedaxis import ( + NAMED_AXIS_KEY, + AxisMapping, + AxisTuple, + _prepare_named_axis_for_attrs, +) +from awkward._nplikes.numpy_like import NumpyMetadata + +__all__ = ("with_named_axis",) + +np = NumpyMetadata.instance() + + +@high_level_function() +def with_named_axis( + array, + named_axis: AxisTuple | AxisMapping, + *, + highlevel=True, + behavior=None, + attrs=None, +): + """ + Args: + array: Array-like data (anything #ak.to_layout recognizes). + named_axis: AxisTuple | AxisMapping: Names to give to the array axis; this assigns + the `"__named_axis__"` attr. If None, any existing name is unset. + highlevel (bool): If True, return an #ak.Array; otherwise, return + a low-level #ak.contents.Content subclass. + behavior (None or dict): Custom #ak.behavior for the output array, if + high-level. + attrs (None or dict): Custom attributes for the output array, if + high-level. + + Returns an #ak.Array or #ak.Record (or low-level equivalent, if + `highlevel=False`) with a new name. This function does not change the + array in-place. If the new name is None, then the array is returned as it is. + """ + # Dispatch + yield (array,) + + # Implementation + return _impl(array, named_axis, highlevel, behavior, attrs) + + +def _impl(array, named_axis, highlevel, behavior, attrs): + # Named axis handling + if not named_axis: # no-op, e.g. named_axis is None, (), {} + return array + + with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: + layout = ctx.unwrap(array, allow_record=True) + + _named_axis = _prepare_named_axis_for_attrs( + named_axis=named_axis, + ndim=layout.minmax_depth[1], + ) + # now we're good, set the named axis + return ctx.with_attr( + key=NAMED_AXIS_KEY, + value=_named_axis, + ).wrap( + layout, + highlevel=highlevel, + allow_other=True, + ) diff --git a/src/awkward/operations/ak_without_named_axis.py b/src/awkward/operations/ak_without_named_axis.py new file mode 100644 index 0000000000..3697344a4b --- /dev/null +++ b/src/awkward/operations/ak_without_named_axis.py @@ -0,0 +1,54 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE + +from __future__ import annotations + +from awkward._dispatch import high_level_function +from awkward._layout import HighLevelContext +from awkward._namedaxis import ( + NAMED_AXIS_KEY, +) +from awkward._nplikes.numpy_like import NumpyMetadata + +__all__ = ("without_named_axis",) + +np = NumpyMetadata.instance() + + +@high_level_function() +def without_named_axis( + array, + *, + highlevel=True, + behavior=None, + attrs=None, +): + """ + Args: + array: Array-like data (anything #ak.to_layout recognizes). + highlevel (bool): If True, return an #ak.Array; otherwise, return + a low-level #ak.contents.Content subclass. + behavior (None or dict): Custom #ak.behavior for the output array, if + high-level. + attrs (None or dict): Custom attributes for the output array, if + high-level. + + Returns an #ak.Array or #ak.Record (or low-level equivalent, if + `highlevel=False`) without named axes. This function does not change the + array in-place. + """ + # Dispatch + yield (array,) + + # Implementation + return _impl(array, highlevel, behavior, attrs) + + +def _impl(array, highlevel, behavior, attrs): + with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: + layout = ctx.unwrap(array, allow_record=True) + + return ctx.without_attr(key=NAMED_AXIS_KEY).wrap( + layout, + highlevel=highlevel, + allow_other=True, + ) diff --git a/src/awkward/operations/ak_zip.py b/src/awkward/operations/ak_zip.py index bed5c233e5..5ce58f8b1a 100644 --- a/src/awkward/operations/ak_zip.py +++ b/src/awkward/operations/ak_zip.py @@ -3,10 +3,12 @@ from __future__ import annotations from collections.abc import Mapping +from functools import reduce import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, ensure_same_backend +from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis from awkward._nplikes.numpy_like import NumpyMetadata __all__ = ("zip",) @@ -174,6 +176,7 @@ def _impl( ): if depth_limit is not None and depth_limit <= 0: raise ValueError("depth_limit must be None or at least 1") + with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: if isinstance(arrays, Mapping): layouts = ensure_same_backend( @@ -238,8 +241,15 @@ def action(inputs, depth, backend, **ignore): else: return None + depth_context, lateral_context = NamedAxesWithDims.prepare_contexts( + list(arrays.values()) if isinstance(arrays, Mapping) else list(arrays) + ) out = ak._broadcasting.broadcast_and_apply( - layouts, action, right_broadcast=right_broadcast + layouts, + action, + depth_context=depth_context, + lateral_context=lateral_context, + right_broadcast=right_broadcast, ) assert isinstance(out, tuple) and len(out) == 1 out = out[0] @@ -248,4 +258,15 @@ def action(inputs, depth, backend, **ignore): out = out[0] assert isinstance(out, ak.record.Record) - return ctx.wrap(out, highlevel=highlevel) + # Unify named axes propagated through the broadcast + out_named_axis = reduce( + _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis + ) + wrapped_out = ctx.wrap(out, highlevel=highlevel) + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) diff --git a/src/awkward/operations/str/akstr_join.py b/src/awkward/operations/str/akstr_join.py index a5ab638ba5..d18dc0174e 100644 --- a/src/awkward/operations/str/akstr_join.py +++ b/src/awkward/operations/str/akstr_join.py @@ -2,10 +2,15 @@ from __future__ import annotations +from functools import reduce + import awkward as ak +from awkward._attrs import attrs_of_obj from awkward._backends.typetracer import TypeTracerBackend +from awkward._behavior import behavior_of_obj from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, ensure_same_backend +from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis __all__ = ("join",) @@ -95,6 +100,7 @@ def apply_unary(layout, **kwargs): ) out = ak._do.recursively_apply(layout, apply_unary) + return ctx.wrap(out, highlevel=highlevel) else: def apply_binary(layouts, **kwargs): @@ -123,8 +129,24 @@ def apply_binary(layouts, **kwargs): ), ) + depth_context, lateral_context = NamedAxesWithDims.prepare_contexts( + [array, separator] + ) (out,) = ak._broadcasting.broadcast_and_apply( - (layout, maybe_separator_layout), apply_binary + (layout, maybe_separator_layout), + apply_binary, + depth_context=depth_context, + lateral_context=lateral_context, ) - return ctx.wrap(out, highlevel=highlevel) + out_named_axis = reduce( + _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis + ) + wrapped = ctx.wrap(out, highlevel=highlevel) + return ak.operations.ak_with_named_axis._impl( + wrapped, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=behavior_of_obj(wrapped), + attrs=attrs_of_obj(wrapped), + ) diff --git a/src/awkward/operations/str/akstr_join_element_wise.py b/src/awkward/operations/str/akstr_join_element_wise.py index 98f4e42f91..cd2ed0184f 100644 --- a/src/awkward/operations/str/akstr_join_element_wise.py +++ b/src/awkward/operations/str/akstr_join_element_wise.py @@ -2,10 +2,15 @@ from __future__ import annotations +from functools import reduce + import awkward as ak +from awkward._attrs import attrs_of_obj from awkward._backends.typetracer import TypeTracerBackend +from awkward._behavior import behavior_of_obj from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, ensure_same_backend +from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis __all__ = ("join_element_wise",) @@ -66,6 +71,22 @@ def action(layouts, **kwargs): ): return (_apply_through_arrow(pc.binary_join_element_wise, *layouts),) - (out,) = ak._broadcasting.broadcast_and_apply(layouts, action) - - return ctx.wrap(out, highlevel=highlevel) + depth_context, lateral_context = NamedAxesWithDims.prepare_contexts(arrays) + (out,) = ak._broadcasting.broadcast_and_apply( + layouts, + action, + depth_context=depth_context, + lateral_context=lateral_context, + ) + + out_named_axis = reduce( + _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis + ) + wrapped = ctx.wrap(out, highlevel=highlevel) + return ak.operations.ak_with_named_axis._impl( + wrapped, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=behavior_of_obj(wrapped), + attrs=attrs_of_obj(wrapped), + ) diff --git a/src/awkward/operations/str/akstr_repeat.py b/src/awkward/operations/str/akstr_repeat.py index de929c57b7..49bec96569 100644 --- a/src/awkward/operations/str/akstr_repeat.py +++ b/src/awkward/operations/str/akstr_repeat.py @@ -3,11 +3,15 @@ from __future__ import annotations import numbers +from functools import reduce import awkward as ak +from awkward._attrs import attrs_of_obj from awkward._backends.typetracer import TypeTracerBackend +from awkward._behavior import behavior_of_obj from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, ensure_same_backend +from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis from awkward._nplikes.numpy_like import NumpyMetadata __all__ = ("repeat",) @@ -79,8 +83,26 @@ def action(inputs, **kwargs): return (_apply_through_arrow(pc.binary_repeat, *inputs),) + depth_context, lateral_context = NamedAxesWithDims.prepare_contexts( + [array, num_repeats] + ) (out,) = ak._broadcasting.broadcast_and_apply( - (layout, num_repeats_layout), action + (layout, num_repeats_layout), + action, + depth_context=depth_context, + lateral_context=lateral_context, + ) + + out_named_axis = reduce( + _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis + ) + wrapped = ctx.wrap(out, highlevel=highlevel) + return ak.operations.ak_with_named_axis._impl( + wrapped, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=behavior_of_obj(wrapped), + attrs=attrs_of_obj(wrapped), ) else: @@ -98,4 +120,4 @@ def action(layout, **kwargs): out = ak._do.recursively_apply(layout, action) - return ctx.wrap(out, highlevel=highlevel) + return ctx.wrap(out, highlevel=highlevel) diff --git a/tests/test_2596_named_axis.py b/tests/test_2596_named_axis.py new file mode 100644 index 0000000000..acfa1e9e34 --- /dev/null +++ b/tests/test_2596_named_axis.py @@ -0,0 +1,2243 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE + +from __future__ import annotations + +import sys + +import numpy as np +import pytest + +import awkward as ak +from awkward._namedaxis import _get_named_axis + + +def test_constructor(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]], named_axis=("x", "y")) + assert _get_named_axis(array) + assert array.named_axis == {"x": 0, "y": 1} + assert array.positional_axis == (0, 1) + + array = ak.Array([[1, 2], [3], [], [4, 5, 6]], named_axis={"x": 0, "y": 1}) + assert _get_named_axis(array) + assert array.named_axis == {"x": 0, "y": 1} + assert array.positional_axis == (0, 1) + + +def test_with_named_axis(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + assert not _get_named_axis(array) + assert array.named_axis == {} + assert array.positional_axis == (0, 1) + + array = ak.with_named_axis(array, named_axis=("x", "y")) + assert _get_named_axis(array) + assert array.named_axis == {"x": 0, "y": 1} + assert array.positional_axis == (0, 1) + + array = ak.with_named_axis(array, named_axis=("x", None)) + assert _get_named_axis(array) + assert array.named_axis == {"x": 0} + assert array.positional_axis == (0, 1) + + array = ak.with_named_axis(array, named_axis=(None, "x")) + assert _get_named_axis(array) + assert array.named_axis == {"x": 1} + assert array.positional_axis == (0, 1) + + array = ak.with_named_axis(array, named_axis={"x": 0, "y": 1}) + assert _get_named_axis(array) + assert array.named_axis == {"x": 0, "y": 1} + assert array.positional_axis == (0, 1) + + array = ak.with_named_axis(array, named_axis={"x": 1}) + assert _get_named_axis(array) + assert array.named_axis == {"x": 1} + assert array.positional_axis == (0, 1) + + array = ak.with_named_axis(array, named_axis={"y": -1}) + assert _get_named_axis(array) + assert array.named_axis == {"y": -1} + assert array.positional_axis == (0, 1) + + # This is possible in a future version of named axis, but currently only strings are supported + # from dataclasses import dataclass + + # @dataclass(frozen=True) + # class exotic_axis: + # attr: str + + # ax1 = exotic_axis(attr="I'm not the type of axis that you're used to") + # ax2 = exotic_axis(attr="...me neither!") + + # array = ak.with_named_axis(array, named_axis=(ax1, ax2)) + # assert array.named_axis == (ax1, ax2) + # assert array.positional_axis == (0, 1) + + +def test_named_axis_indexing(): + array = ak.Array([[[1, 2]], [[3]], [[4]], [[5, 6], [7]]]) + + named_array = ak.with_named_axis(array, named_axis=("x", "y", "z")) + + # test indexing + assert ak.all(array[...] == named_array[...]) + assert ak.all(array[()] == named_array[()]) + + assert ak.all(array[None, :, :, :] == named_array[None, :, :, :]) + assert ak.all(array[:, None, :, :] == named_array[:, None, :, :]) + assert ak.all(array[:, :, None, :] == named_array[:, :, None, :]) + assert ak.all(array[:, :, :, None] == named_array[:, :, :, None]) + + assert ak.all(array[0, :, :] == named_array[{"x": 0}]) + assert ak.all(array[:, 0, :] == named_array[{"y": 0}]) + assert ak.all(array[:, :, 0] == named_array[{"z": 0}]) + + assert ak.all(array[0, :, :] == named_array[{0: 0}]) + assert ak.all(array[:, 0, :] == named_array[{1: 0}]) + assert ak.all(array[:, :, 0] == named_array[{2: 0}]) + + assert ak.all(array[0, :, :] == named_array[{-3: 0}]) + assert ak.all(array[:, 0, :] == named_array[{-2: 0}]) + assert ak.all(array[:, :, 0] == named_array[{-1: 0}]) + + assert ak.all(array[0, 0, :] == named_array[{"x": 0, "y": 0}]) + assert ak.all(array[0, :, 0] == named_array[{"x": 0, "z": 0}]) + assert ak.all(array[:, 0, 0] == named_array[{"y": 0, "z": 0}]) + assert array[0, 0, 0] == named_array[{"x": 0, "y": 0, "z": 0}] + + assert ak.all(array[slice(0, 1), :, :] == named_array[{"x": slice(0, 1)}]) + assert ak.all(array[:, slice(0, 1), :] == named_array[{"y": slice(0, 1)}]) + assert ak.all(array[:, :, slice(0, 1)] == named_array[{"z": slice(0, 1)}]) + + assert ak.all(array[0, :, slice(0, 1)] == named_array[{"x": 0, "z": slice(0, 1)}]) + assert ak.all(array[:, 0, slice(0, 1)] == named_array[{"y": 0, "z": slice(0, 1)}]) + assert ak.all(array[slice(0, 1), 0, :] == named_array[{"x": slice(0, 1), "y": 0}]) + + assert ak.all(array[array > 3] == named_array[named_array > 3]) + + # test naming propagation + assert ( + named_array[...].named_axis + == named_array.named_axis + == {"x": 0, "y": 1, "z": 2} + ) + assert ( + named_array[()].named_axis == named_array.named_axis == {"x": 0, "y": 1, "z": 2} + ) + + assert named_array[None, :, :, :].named_axis == {"x": 1, "y": 2, "z": 3} + assert named_array[:, None, :, :].named_axis == {"x": 0, "y": 2, "z": 3} + assert named_array[:, :, None, :].named_axis == {"x": 0, "y": 1, "z": 3} + assert named_array[:, :, :, None].named_axis == {"x": 0, "y": 1, "z": 2} + + assert named_array[None, ...].named_axis == {"x": 1, "y": 2, "z": 3} + assert named_array[:, None, ...].named_axis == {"x": 0, "y": 2, "z": 3} + assert named_array[..., None, :].named_axis == {"x": 0, "y": 1, "z": 3} + assert named_array[..., None].named_axis == {"x": 0, "y": 1, "z": 2} + + assert ( + named_array[0, :, :].named_axis + == named_array[{"x": 0}].named_axis + == {"y": 0, "z": 1} + ) + assert ( + named_array[:, 0, :].named_axis + == named_array[{"y": 0}].named_axis + == {"x": 0, "z": 1} + ) + assert ( + named_array[:, :, 0].named_axis + == named_array[{"z": 0}].named_axis + == {"x": 0, "y": 1} + ) + + assert ( + named_array[0, ...].named_axis + == named_array[{"x": 0}].named_axis + == {"y": 0, "z": 1} + ) + assert ( + named_array[:, 0, :].named_axis + == named_array[{"y": 0}].named_axis + == {"x": 0, "z": 1} + ) + assert ( + named_array[..., 0].named_axis + == named_array[{"z": 0}].named_axis + == {"x": 0, "y": 1} + ) + + assert named_array[{0: 0}].named_axis == {"y": 0, "z": 1} + assert named_array[{1: 0}].named_axis == {"x": 0, "z": 1} + assert named_array[{2: 0}].named_axis == {"x": 0, "y": 1} + + assert named_array[{-3: 0}].named_axis == {"y": 0, "z": 1} + assert named_array[{-2: 0}].named_axis == {"x": 0, "z": 1} + assert named_array[{-1: 0}].named_axis == {"x": 0, "y": 1} + + assert ( + named_array[0, 0, :].named_axis + == named_array[{"x": 0, "y": 0}].named_axis + == {"z": 0} + ) + assert ( + named_array[0, :, 0].named_axis + == named_array[{"x": 0, "z": 0}].named_axis + == {"y": 0} + ) + assert ( + named_array[:, 0, 0].named_axis + == named_array[{"y": 0, "z": 0}].named_axis + == {"x": 0} + ) + assert not _get_named_axis(named_array[0, 0, 0]) + assert not _get_named_axis(named_array[{"x": 0, "y": 0, "z": 0}]) + + assert ( + named_array[slice(0, 1), :, :].named_axis + == named_array[{"x": slice(0, 1)}].named_axis + == {"x": 0, "y": 1, "z": 2} + ) + assert ( + named_array[:, slice(0, 1), :].named_axis + == named_array[{"y": slice(0, 1)}].named_axis + == {"x": 0, "y": 1, "z": 2} + ) + assert ( + named_array[:, :, slice(0, 1)].named_axis + == named_array[{"z": slice(0, 1)}].named_axis + == {"x": 0, "y": 1, "z": 2} + ) + + assert ( + named_array[0, :, slice(0, 1)].named_axis + == named_array[{"x": 0, "z": slice(0, 1)}].named_axis + == {"y": 0, "z": 1} + ) + assert ( + named_array[:, 0, slice(0, 1)].named_axis + == named_array[{"y": 0, "z": slice(0, 1)}].named_axis + == {"x": 0, "z": 1} + ) + assert ( + named_array[slice(0, 1), 0, :].named_axis + == named_array[{"x": slice(0, 1), "y": 0}].named_axis + == {"x": 0, "z": 1} + ) + + +def test_negative_named_axis_indexing(): + array = ak.Array([[[1, 2]], [[3]], [[4]], [[5, 6], [7]]]) + + named_array = ak.with_named_axis(array, named_axis={"x": -3, "y": -2, "z": -1}) + + # test indexing + assert ak.all(array[...] == named_array[...]) + assert ak.all(array[()] == named_array[()]) + + assert ak.all(array[None, :, :, :] == named_array[None, :, :, :]) + assert ak.all(array[:, None, :, :] == named_array[:, None, :, :]) + assert ak.all(array[:, :, None, :] == named_array[:, :, None, :]) + assert ak.all(array[:, :, :, None] == named_array[:, :, :, None]) + + assert ak.all(array[0, :, :] == named_array[{"x": 0}]) + assert ak.all(array[:, 0, :] == named_array[{"y": 0}]) + assert ak.all(array[:, :, 0] == named_array[{"z": 0}]) + + assert ak.all(array[0, :, :] == named_array[{0: 0}]) + assert ak.all(array[:, 0, :] == named_array[{1: 0}]) + assert ak.all(array[:, :, 0] == named_array[{2: 0}]) + + assert ak.all(array[0, :, :] == named_array[{-3: 0}]) + assert ak.all(array[:, 0, :] == named_array[{-2: 0}]) + assert ak.all(array[:, :, 0] == named_array[{-1: 0}]) + + assert ak.all(array[0, 0, :] == named_array[{"x": 0, "y": 0}]) + assert ak.all(array[0, :, 0] == named_array[{"x": 0, "z": 0}]) + assert ak.all(array[:, 0, 0] == named_array[{"y": 0, "z": 0}]) + assert array[0, 0, 0] == named_array[{"x": 0, "y": 0, "z": 0}] + + assert ak.all(array[slice(0, 1), :, :] == named_array[{"x": slice(0, 1)}]) + assert ak.all(array[:, slice(0, 1), :] == named_array[{"y": slice(0, 1)}]) + assert ak.all(array[:, :, slice(0, 1)] == named_array[{"z": slice(0, 1)}]) + + assert ak.all(array[0, :, slice(0, 1)] == named_array[{"x": 0, "z": slice(0, 1)}]) + assert ak.all(array[:, 0, slice(0, 1)] == named_array[{"y": 0, "z": slice(0, 1)}]) + assert ak.all(array[slice(0, 1), 0, :] == named_array[{"x": slice(0, 1), "y": 0}]) + + assert ak.all(array[array > 3] == named_array[named_array > 3]) + + # test naming propagation + assert ( + named_array[...].named_axis + == named_array.named_axis + == {"x": -3, "y": -2, "z": -1} + ) + assert ( + named_array[()].named_axis + == named_array.named_axis + == {"x": -3, "y": -2, "z": -1} + ) + + assert named_array[None, :, :, :].named_axis == {"x": -3, "y": -2, "z": -1} + assert named_array[:, None, :, :].named_axis == {"x": -4, "y": -2, "z": -1} + assert named_array[:, :, None, :].named_axis == {"x": -4, "y": -3, "z": -1} + assert named_array[:, :, :, None].named_axis == {"x": -4, "y": -3, "z": -2} + + assert named_array[None, ...].named_axis == {"x": -3, "y": -2, "z": -1} + assert named_array[:, None, ...].named_axis == {"x": -4, "y": -2, "z": -1} + assert named_array[..., None, :].named_axis == {"x": -4, "y": -3, "z": -1} + assert named_array[..., None].named_axis == {"x": -4, "y": -3, "z": -2} + + assert ( + named_array[0, :, :].named_axis + == named_array[{"x": 0}].named_axis + == {"y": -2, "z": -1} + ) + assert ( + named_array[:, 0, :].named_axis + == named_array[{"y": 0}].named_axis + == {"x": -2, "z": -1} + ) + assert ( + named_array[:, :, 0].named_axis + == named_array[{"z": 0}].named_axis + == {"x": -2, "y": -1} + ) + + assert ( + named_array[0, ...].named_axis + == named_array[{"x": 0}].named_axis + == {"y": -2, "z": -1} + ) + assert ( + named_array[..., 0].named_axis + == named_array[{"z": 0}].named_axis + == {"x": -2, "y": -1} + ) + + assert named_array[{0: 0}].named_axis == {"y": -2, "z": -1} + assert named_array[{1: 0}].named_axis == {"x": -2, "z": -1} + assert named_array[{2: 0}].named_axis == {"x": -2, "y": -1} + + assert named_array[{-3: 0}].named_axis == {"y": -2, "z": -1} + assert named_array[{-2: 0}].named_axis == {"x": -2, "z": -1} + assert named_array[{-1: 0}].named_axis == {"x": -2, "y": -1} + + assert ( + named_array[0, 0, :].named_axis + == named_array[{"x": 0, "y": 0}].named_axis + == {"z": -1} + ) + assert ( + named_array[0, :, 0].named_axis + == named_array[{"x": 0, "z": 0}].named_axis + == {"y": -1} + ) + assert ( + named_array[:, 0, 0].named_axis + == named_array[{"y": 0, "z": 0}].named_axis + == {"x": -1} + ) + assert not _get_named_axis(named_array[0, 0, 0]) + assert not _get_named_axis(named_array[{"x": 0, "y": 0, "z": 0}]) + + assert ( + named_array[slice(0, 1), :, :].named_axis + == named_array[{"x": slice(0, 1)}].named_axis + == {"x": -3, "y": -2, "z": -1} + ) + assert ( + named_array[:, slice(0, 1), :].named_axis + == named_array[{"y": slice(0, 1)}].named_axis + == {"x": -3, "y": -2, "z": -1} + ) + assert ( + named_array[:, :, slice(0, 1)].named_axis + == named_array[{"z": slice(0, 1)}].named_axis + == {"x": -3, "y": -2, "z": -1} + ) + + assert ( + named_array[0, :, slice(0, 1)].named_axis + == named_array[{"x": 0, "z": slice(0, 1)}].named_axis + == {"y": -2, "z": -1} + ) + assert ( + named_array[:, 0, slice(0, 1)].named_axis + == named_array[{"y": 0, "z": slice(0, 1)}].named_axis + == {"x": -2, "z": -1} + ) + assert ( + named_array[slice(0, 1), 0, :].named_axis + == named_array[{"x": slice(0, 1), "y": 0}].named_axis + == {"x": -2, "z": -1} + ) + + +@pytest.mark.xfail( + sys.platform == "win32", + reason="right-broadcasting (NumPy-style) behaves differently for 32-bit windows", + strict=False, +) +def test_named_axis_right_broadcasting(): + # [NumPy-style] rightbroadcasting: (n, m) -> (1, n, m) + a = ak.Array([1]) # (1,) + b = ak.Array([[10, 20], [30, 40], [50, 60]]) # (3, 2) + + na = ak.with_named_axis(a, named_axis={"y": 0}) + nb = ak.with_named_axis(b, named_axis={"x": 0, "y": 1}) + + naa, nbb = ak.broadcast_arrays(na, nb) + + assert naa.named_axis == nbb.named_axis == {"x": 0, "y": 1} + + na = ak.with_named_axis(a, named_axis={"y": -1}) + nb = ak.with_named_axis(b, named_axis={"y": -2, "x": -1}) + + naa, nbb = ak.broadcast_arrays(na, nb) + + assert naa.named_axis == nbb.named_axis == {"y": -2, "x": -1} + + +def test_named_axis_left_broadcasting(): + # [Awkward-style] leftbroadcasting: (n, m) -> (n, m, 1) + a = ak.Array([[[0, 1, 2], [], [3, 4]], [], [[5], [6, 7, 8, 9]]]) # (3, var, var) + b = ak.Array([[10, 20, 30], [], [40, 50]]) # (3, var) + + na = ak.with_named_axis(a, named_axis=("x", "y", "z")) + nb = ak.with_named_axis(b, named_axis=("x", "y")) + + naa, nbb = ak.broadcast_arrays(na, nb) + + assert naa.named_axis == nbb.named_axis == {"x": 0, "y": 1, "z": 2} + + na = ak.with_named_axis(a, named_axis={"x": -3, "y": -2, "z": -1}) + nb = ak.with_named_axis(b, named_axis={"x": -2, "y": -1}) + + naa, nbb = ak.broadcast_arrays(na, nb) + + assert naa.named_axis == nbb.named_axis == {"x": -3, "y": -2, "z": -1} + + # this is not allowed! + a = ak.with_named_axis(ak.Array([[1, 2], [3, 4]]), ("x", "y")) # {"x": 0, "y": 1} + asum = ak.sum(a, axis="x") # {"y": 0} + + with pytest.raises(ValueError): + _ = a + asum + + # this is allowed! + a = ak.with_named_axis(ak.Array([[1, 2], [3, 4]]), ("x", "y")) # {"x": 0, "y": 1} + asum = ak.sum(a, axis="y") # {"x": 0} + + assert (a + asum).named_axis == {"x": 0, "y": 1} + + +def test_named_axis_unary_ufuncs(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis=("x", "y")) + + assert (-named_array).named_axis == named_array.named_axis + assert (+named_array).named_axis == named_array.named_axis + assert (~named_array).named_axis == named_array.named_axis + assert abs(named_array).named_axis == named_array.named_axis + + +def test_named_axis_binary_ufuncs(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array1 = ak.with_named_axis(array, named_axis=(None, "y")) + named_array2 = ak.with_named_axis(array, named_axis=("x", None)) + named_array3 = ak.with_named_axis(array, named_axis=("x", "y")) + + # just for addition, the rest is the same + # __add__ + assert (array + array).named_axis == {} + assert (named_array1 + array).named_axis == {"y": 1} + assert (named_array2 + array).named_axis == {"x": 0} + assert (named_array3 + array).named_axis == {"x": 0, "y": 1} + + assert (named_array1 + named_array2).named_axis == {"x": 0, "y": 1} + assert (named_array3 + named_array3).named_axis == {"x": 0, "y": 1} + + # __radd__ + assert (array + named_array1).named_axis == {"y": 1} + assert (array + named_array2).named_axis == {"x": 0} + assert (array + named_array3).named_axis == {"x": 0, "y": 1} + + a = ak.with_named_axis(array, named_axis=("x", None)) + b = ak.with_named_axis(array, named_axis=("y", None)) + with pytest.raises( + ValueError, + match="The named axes are incompatible. Got: x and y for positional axis 0", + ): + _ = a + b + + a = ak.with_named_axis(array, named_axis=(None, "x")) + b = ak.with_named_axis(array, named_axis=(None, "y")) + with pytest.raises( + ValueError, + match="The named axes are incompatible. Got: x and y for positional axis 1", + ): + _ = a + b + + +def test_named_axis_ak_all(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis=("x", "y")) + + # first check that they work the same + assert ak.all(ak.all(array < 4, axis=0) == ak.all(named_array < 4, axis="x")) + assert ak.all(ak.all(array < 4, axis=1) == ak.all(named_array < 4, axis="y")) + + # check that result axis names are correctly propagated + assert ( + ak.all(named_array < 4, axis=0).named_axis + == ak.all(named_array < 4, axis="x").named_axis + == {"y": 0} + ) + assert ( + ak.all(named_array < 4, axis=1).named_axis + == ak.all(named_array < 4, axis="y").named_axis + == {"x": 0} + ) + assert ( + ak.all(named_array < 4, axis=0, keepdims=True).named_axis + == ak.all(named_array < 4, axis="x", keepdims=True).named_axis + == {"x": 0, "y": 1} + ) + assert ( + ak.all(named_array < 4, axis=1, keepdims=True).named_axis + == ak.all(named_array < 4, axis="y", keepdims=True).named_axis + == {"x": 0, "y": 1} + ) + assert not _get_named_axis(ak.all(named_array < 4, axis=None)) + + +def test_negative_named_axis_ak_all(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis={"x": -2, "y": -1}) + + # first check that they work the same + assert ak.all(ak.all(array < 4, axis=-2) == ak.all(named_array < 4, axis="x")) + assert ak.all(ak.all(array < 4, axis=-1) == ak.all(named_array < 4, axis="y")) + + # check that result axis names are correctly propagated + assert ( + ak.all(named_array < 4, axis=-2).named_axis + == ak.all(named_array < 4, axis="x").named_axis + == {"y": -1} + ) + assert ( + ak.all(named_array < 4, axis=-1).named_axis + == ak.all(named_array < 4, axis="y").named_axis + == {"x": -1} + ) + assert ( + ak.all(named_array < 4, axis=-2, keepdims=True).named_axis + == ak.all(named_array < 4, axis="x", keepdims=True).named_axis + == {"x": -2, "y": -1} + ) + assert ( + ak.all(named_array < 4, axis=-1, keepdims=True).named_axis + == ak.all(named_array < 4, axis="y", keepdims=True).named_axis + == {"x": -2, "y": -1} + ) + assert not _get_named_axis(ak.all(named_array < 4, axis=None)) + + +def test_named_axis_ak_almost_equal(): + array1 = array2 = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array1 = named_array2 = ak.with_named_axis(array1, named_axis=("x", "y")) + + assert ak.almost_equal(array1, array2, check_named_axis=False) == ak.almost_equal( + named_array1, named_array2, check_named_axis=False + ) + assert ak.almost_equal(array1, array2, check_named_axis=True) == ak.almost_equal( + named_array1, named_array2, check_named_axis=True + ) + + assert ak.almost_equal(named_array1, array1, check_named_axis=False) + assert ak.almost_equal(named_array1, array1, check_named_axis=True) + + named_array3 = ak.with_named_axis(array1, named_axis=("x", "muons")) + assert ak.almost_equal(named_array1, named_array3, check_named_axis=False) + assert not ak.almost_equal(named_array1, named_array3, check_named_axis=True) + + +def test_negative_named_axis_ak_almost_equal(): + array1 = array2 = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array1 = named_array2 = ak.with_named_axis( + array1, named_axis={"x": -2, "y": -1} + ) + + assert ak.almost_equal(array1, array2, check_named_axis=False) == ak.almost_equal( + named_array1, named_array2, check_named_axis=False + ) + assert ak.almost_equal(array1, array2, check_named_axis=True) == ak.almost_equal( + named_array1, named_array2, check_named_axis=True + ) + + assert ak.almost_equal(named_array1, array1, check_named_axis=False) + assert ak.almost_equal(named_array1, array1, check_named_axis=True) + + named_array3 = ak.with_named_axis(array1, named_axis={"x": -2, "z": -1}) + assert ak.almost_equal(named_array1, named_array3, check_named_axis=False) + assert not ak.almost_equal(named_array1, named_array3, check_named_axis=True) + + +def test_named_axis_ak_angle(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis=("x", "y")) + + # first check that they work the same + assert ak.all(ak.angle(array) == ak.angle(named_array)) + + # check that result axis names are correctly propagated + assert ak.angle(named_array).named_axis == {"x": 0, "y": 1} + + +def test_negative_named_axis_ak_angle(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis={"x": -2, "y": -1}) + + # first check that they work the same + assert ak.all(ak.angle(array) == ak.angle(named_array)) + + # check that result axis names are correctly propagated + assert ak.angle(named_array).named_axis == {"x": -2, "y": -1} + + +def test_named_axis_ak_any(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis=("x", "y")) + + # first check that they work the same + assert ak.all(ak.any(array < 4, axis=0) == ak.any(named_array < 4, axis="x")) + assert ak.all(ak.any(array < 4, axis=1) == ak.any(named_array < 4, axis="y")) + + # check that result axis names are correctly propagated + assert ( + ak.any(named_array < 4, axis=0).named_axis + == ak.any(named_array < 4, axis="x").named_axis + == {"y": 0} + ) + assert ( + ak.any(named_array < 4, axis=1).named_axis + == ak.any(named_array < 4, axis="y").named_axis + == {"x": 0} + ) + assert ( + ak.any(named_array < 4, axis=0, keepdims=True).named_axis + == ak.any(named_array < 4, axis="x", keepdims=True).named_axis + == {"x": 0, "y": 1} + ) + assert ( + ak.any(named_array < 4, axis=1, keepdims=True).named_axis + == ak.any(named_array < 4, axis="y", keepdims=True).named_axis + == {"x": 0, "y": 1} + ) + assert not _get_named_axis(ak.all(named_array < 4, axis=None)) + + +def test_negative_named_axis_ak_any(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis={"x": -2, "y": -1}) + + # first check that they work the same + assert ak.all(ak.any(array < 4, axis=-2) == ak.any(named_array < 4, axis="x")) + assert ak.all(ak.any(array < 4, axis=-1) == ak.any(named_array < 4, axis="y")) + + # check that result axis names are correctly propagated + assert ( + ak.any(named_array < 4, axis=-2).named_axis + == ak.any(named_array < 4, axis="x").named_axis + == {"y": -1} + ) + assert ( + ak.any(named_array < 4, axis=-1).named_axis + == ak.any(named_array < 4, axis="y").named_axis + == {"x": -1} + ) + assert ( + ak.any(named_array < 4, axis=-2, keepdims=True).named_axis + == ak.any(named_array < 4, axis="x", keepdims=True).named_axis + == {"x": -2, "y": -1} + ) + assert ( + ak.any(named_array < 4, axis=-1, keepdims=True).named_axis + == ak.any(named_array < 4, axis="y", keepdims=True).named_axis + == {"x": -2, "y": -1} + ) + assert not _get_named_axis(ak.all(named_array < 4, axis=None)) + + +def test_named_axis_ak_argcartesian(): + one = ak.Array([[1], [2], [3]]) + two = ak.Array([[4, 5]]) + three = ak.Array([[6, 7]]) + + named_one = ak.with_named_axis(one, named_axis=("x", "y")) + named_two = ak.with_named_axis(two, named_axis=("x", "y")) + named_three = ak.with_named_axis(three, named_axis=("x", "y")) + + assert ak.argcartesian( + [named_one, named_two, named_three], axis="x", nested=False + ).named_axis == {"x": 0, "y": 1} + assert ak.argcartesian( + [named_one, named_two, named_three], axis="x", nested=True + ).named_axis == {"x": 1, "y": 2} + assert ak.argcartesian( + [named_one, named_two, named_three], axis="x", nested=[0] + ).named_axis == {"x": 1, "y": 2} + assert ak.argcartesian( + [named_one, named_two, named_three], axis="x", nested=[1] + ).named_axis == {"x": 0, "y": 2} + assert ak.argcartesian( + [named_one, named_two, named_three], axis="x", nested=[0, 1] + ).named_axis == {"x": 2, "y": 3} + + +def test_negative_named_axis_ak_argcartesian(): + one = ak.Array([[1], [2], [3]]) + two = ak.Array([[4, 5]]) + three = ak.Array([[6, 7]]) + + named_one = ak.with_named_axis(one, named_axis={"x": -2, "y": -1}) + named_two = ak.with_named_axis(two, named_axis={"x": -2, "y": -1}) + named_three = ak.with_named_axis(three, named_axis={"x": -2, "y": -1}) + + assert ak.argcartesian( + [named_one, named_two, named_three], axis="y", nested=False + ).named_axis == {"x": -1} + assert ak.argcartesian( + [named_one, named_two, named_three], axis="y", nested=True + ).named_axis == {"x": -2, "y": -1} + assert ak.argcartesian( + [named_one, named_two, named_three], axis="y", nested=[0] + ).named_axis == {"x": -1} + assert ak.argcartesian( + [named_one, named_two, named_three], axis="y", nested=[1] + ).named_axis == {"y": -1} + assert ak.argcartesian( + [named_one, named_two, named_three], axis="y", nested=[0, 1] + ).named_axis == {"x": -2, "y": -1} + + +def test_named_axis_ak_argcombinations(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis=("x", "y")) + + assert ( + ak.argcombinations(named_array, 2, axis=0).named_axis == named_array.named_axis + ) + assert ( + ak.argcombinations(named_array, 2, axis=1).named_axis == named_array.named_axis + ) + + +def test_negative_named_axis_ak_argcombinations(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis={"x": -2, "y": -1}) + + assert ( + ak.argcombinations(named_array, 2, axis=0).named_axis == named_array.named_axis + ) + assert ( + ak.argcombinations(named_array, 2, axis=1).named_axis == named_array.named_axis + ) + + +def test_named_axis_ak_argmax(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis=("x", "y")) + + # first check that they work the same + assert ak.all(ak.argmax(array, axis=0) == ak.argmax(named_array, axis="x")) + assert ak.all(ak.argmax(array, axis=1) == ak.argmax(named_array, axis="y")) + assert ak.all( + ak.argmax(array, axis=0, keepdims=True) + == ak.argmax(named_array, axis="x", keepdims=True) + ) + assert ak.all( + ak.argmax(array, axis=1, keepdims=True) + == ak.argmax(named_array, axis="y", keepdims=True) + ) + assert ak.argmax(array, axis=None) == ak.argmax(named_array, axis=None) + + # check that result axis names are correctly propagated + assert ( + ak.argmax(named_array, axis=0).named_axis + == ak.argmax(named_array, axis="x").named_axis + == {"y": 0} + ) + assert ( + ak.argmax(named_array, axis=1).named_axis + == ak.argmax(named_array, axis="y").named_axis + == {"x": 0} + ) + assert ( + ak.argmax(named_array, axis=0, keepdims=True).named_axis + == ak.argmax(named_array, axis="x", keepdims=True).named_axis + == {"x": 0, "y": 1} + ) + assert ( + ak.argmax(named_array, axis=1, keepdims=True).named_axis + == ak.argmax(named_array, axis="y", keepdims=True).named_axis + == {"x": 0, "y": 1} + ) + assert not _get_named_axis(ak.argmax(named_array, axis=None)) + + +def test_negative_named_axis_ak_argmax(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis={"x": -2, "y": -1}) + + # first check that they work the same + assert ak.all(ak.argmax(array, axis=-2) == ak.argmax(named_array, axis="x")) + assert ak.all(ak.argmax(array, axis=-1) == ak.argmax(named_array, axis="y")) + assert ak.all( + ak.argmax(array, axis=-2, keepdims=True) + == ak.argmax(named_array, axis="x", keepdims=True) + ) + assert ak.all( + ak.argmax(array, axis=-1, keepdims=True) + == ak.argmax(named_array, axis="y", keepdims=True) + ) + assert ak.argmax(array, axis=None) == ak.argmax(named_array, axis=None) + + # check that result axis names are correctly propagated + assert ( + ak.argmax(named_array, axis=-2).named_axis + == ak.argmax(named_array, axis="x").named_axis + == {"y": -1} + ) + assert ( + ak.argmax(named_array, axis=-1).named_axis + == ak.argmax(named_array, axis="y").named_axis + == {"x": -1} + ) + assert ( + ak.argmax(named_array, axis=-2, keepdims=True).named_axis + == ak.argmax(named_array, axis="x", keepdims=True).named_axis + == {"x": -2, "y": -1} + ) + assert ( + ak.argmax(named_array, axis=-1, keepdims=True).named_axis + == ak.argmax(named_array, axis="y", keepdims=True).named_axis + == {"x": -2, "y": -1} + ) + assert not _get_named_axis(ak.argmax(named_array, axis=None)) + + +def test_named_axis_ak_argmin(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis=("x", "y")) + + # first check that they work the same + assert ak.all(ak.argmin(array, axis=0) == ak.argmin(named_array, axis="x")) + assert ak.all(ak.argmin(array, axis=1) == ak.argmin(named_array, axis="y")) + assert ak.all( + ak.argmin(array, axis=0, keepdims=True) + == ak.argmin(named_array, axis="x", keepdims=True) + ) + assert ak.all( + ak.argmin(array, axis=1, keepdims=True) + == ak.argmin(named_array, axis="y", keepdims=True) + ) + assert ak.argmin(array, axis=None) == ak.argmin(named_array, axis=None) + + # check that result axis names are correctly propagated + assert ( + ak.argmin(named_array, axis=0).named_axis + == ak.argmin(named_array, axis="x").named_axis + == {"y": 0} + ) + assert ( + ak.argmin(named_array, axis=1).named_axis + == ak.argmin(named_array, axis="y").named_axis + == {"x": 0} + ) + assert ( + ak.argmin(named_array, axis=0, keepdims=True).named_axis + == ak.argmin(named_array, axis="x", keepdims=True).named_axis + == {"x": 0, "y": 1} + ) + assert ( + ak.argmin(named_array, axis=1, keepdims=True).named_axis + == ak.argmin(named_array, axis="y", keepdims=True).named_axis + == {"x": 0, "y": 1} + ) + assert not _get_named_axis(ak.argmin(named_array, axis=None)) + + +def test_negative_named_axis_ak_argmin(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis={"x": -2, "y": -1}) + + # first check that they work the same + assert ak.all(ak.argmin(array, axis=-2) == ak.argmin(named_array, axis="x")) + assert ak.all(ak.argmin(array, axis=-1) == ak.argmin(named_array, axis="y")) + assert ak.all( + ak.argmin(array, axis=-2, keepdims=True) + == ak.argmin(named_array, axis="x", keepdims=True) + ) + assert ak.all( + ak.argmin(array, axis=-1, keepdims=True) + == ak.argmin(named_array, axis="y", keepdims=True) + ) + assert ak.argmin(array, axis=None) == ak.argmin(named_array, axis=None) + + # check that result axis names are correctly propagated + assert ( + ak.argmin(named_array, axis=-2).named_axis + == ak.argmin(named_array, axis="x").named_axis + == {"y": -1} + ) + assert ( + ak.argmin(named_array, axis=-1).named_axis + == ak.argmin(named_array, axis="y").named_axis + == {"x": -1} + ) + assert ( + ak.argmin(named_array, axis=-2, keepdims=True).named_axis + == ak.argmin(named_array, axis="x", keepdims=True).named_axis + == {"x": -2, "y": -1} + ) + assert ( + ak.argmin(named_array, axis=-1, keepdims=True).named_axis + == ak.argmin(named_array, axis="y", keepdims=True).named_axis + == {"x": -2, "y": -1} + ) + assert not _get_named_axis(ak.argmin(named_array, axis=None)) + + +def test_named_axis_ak_argsort(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis=("x", "y")) + + # first check that they work the same + assert ak.all(ak.argsort(array, axis=0) == ak.argsort(named_array, axis="x")) + assert ak.all(ak.argsort(array, axis=1) == ak.argsort(named_array, axis="y")) + + # check that result axis names are correctly propagated + assert ( + ak.argsort(named_array, axis=0).named_axis + == ak.argsort(named_array, axis="x").named_axis + == {"x": 0, "y": 1} + ) + assert ( + ak.argsort(named_array, axis=1).named_axis + == ak.argsort(named_array, axis="y").named_axis + == {"x": 0, "y": 1} + ) + + +def test_negative_named_axis_ak_argsort(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis={"x": -2, "y": -1}) + + # first check that they work the same + assert ak.all(ak.argsort(array, axis=-2) == ak.argsort(named_array, axis="x")) + assert ak.all(ak.argsort(array, axis=-1) == ak.argsort(named_array, axis="y")) + + # check that result axis names are correctly propagated + assert ( + ak.argsort(named_array, axis=-2).named_axis + == ak.argsort(named_array, axis="x").named_axis + == {"x": -2, "y": -1} + ) + assert ( + ak.argsort(named_array, axis=-1).named_axis + == ak.argsort(named_array, axis="y").named_axis + == {"x": -2, "y": -1} + ) + + +def test_named_axis_ak_array_equal(): + array1 = array2 = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array1 = named_array2 = ak.with_named_axis(array1, named_axis=("x", "y")) + + assert ak.array_equal(array1, array2, check_named_axis=False) == ak.array_equal( + named_array1, named_array2, check_named_axis=False + ) + assert ak.array_equal(array1, array2, check_named_axis=True) == ak.array_equal( + named_array1, named_array2, check_named_axis=True + ) + + assert ak.array_equal(named_array1, array1, check_named_axis=False) + assert ak.array_equal(named_array1, array1, check_named_axis=True) + + named_array3 = ak.with_named_axis(array1, named_axis=("x", "z")) + assert ak.array_equal(named_array1, named_array3, check_named_axis=False) + assert not ak.array_equal(named_array1, named_array3, check_named_axis=True) + + +def test_negative_named_axis_ak_array_equal(): + array1 = array2 = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array1 = named_array2 = ak.with_named_axis( + array1, named_axis={"x": -2, "y": -1} + ) + + assert ak.array_equal(array1, array2, check_named_axis=False) == ak.array_equal( + named_array1, named_array2, check_named_axis=False + ) + assert ak.array_equal(array1, array2, check_named_axis=True) == ak.array_equal( + named_array1, named_array2, check_named_axis=True + ) + + assert ak.array_equal(named_array1, array1, check_named_axis=False) + assert ak.array_equal(named_array1, array1, check_named_axis=True) + + named_array3 = ak.with_named_axis(array1, named_axis={"x": -2, "z": -1}) + assert ak.array_equal(named_array1, named_array3, check_named_axis=False) + assert not ak.array_equal(named_array1, named_array3, check_named_axis=True) + + +def test_named_axis_ak_backend(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis=("x", "y")) + + assert ak.backend(array) == ak.backend(named_array) + + +def test_named_axis_ak_broadcast_fields(): + x = ak.Array([{"x": {"y": 1, "z": 2, "w": [1]}}]) + y = ak.Array([{"x": [{"y": 1}]}]) + + nx = ak.with_named_axis(x, named_axis=("x", "y")) + ny = ak.with_named_axis(y, named_axis=("a", "b")) + + na, nb = ak.broadcast_fields(nx, ny) + assert na.named_axis == {"x": 0, "y": 1} + assert nb.named_axis == {"a": 0, "b": 1} + + +def test_named_axis_ak_cartesian(): + one = ak.Array([[1], [2], [3]]) + two = ak.Array([[4, 5]]) + three = ak.Array([[6, 7]]) + + named_one = ak.with_named_axis(one, named_axis=("x", "y")) + named_two = ak.with_named_axis(two, named_axis=("x", "y")) + named_three = ak.with_named_axis(three, named_axis=("x", "y")) + + assert ak.cartesian( + [named_one, named_two, named_three], axis="x", nested=False + ).named_axis == {"x": 0, "y": 1} + assert ak.cartesian( + [named_one, named_two, named_three], axis="x", nested=True + ).named_axis == {"x": 1, "y": 2} + assert ak.cartesian( + [named_one, named_two, named_three], axis="x", nested=[0] + ).named_axis == {"x": 1, "y": 2} + assert ak.cartesian( + [named_one, named_two, named_three], axis="x", nested=[1] + ).named_axis == {"x": 0, "y": 2} + assert ak.cartesian( + [named_one, named_two, named_three], axis="x", nested=[0, 1] + ).named_axis == {"x": 2, "y": 3} + + +def test_negative_named_axis_ak_cartesian(): + one = ak.Array([[1], [2], [3]]) + two = ak.Array([[4, 5]]) + three = ak.Array([[6, 7]]) + + named_one = ak.with_named_axis(one, named_axis={"x": -2, "y": -1}) + named_two = ak.with_named_axis(two, named_axis={"x": -2, "y": -1}) + named_three = ak.with_named_axis(three, named_axis={"x": -2, "y": -1}) + + assert ak.cartesian( + [named_one, named_two, named_three], axis="y", nested=False + ).named_axis == {"x": -1} + assert ak.cartesian( + [named_one, named_two, named_three], axis="y", nested=True + ).named_axis == {"x": -2, "y": -1} + assert ak.cartesian( + [named_one, named_two, named_three], axis="y", nested=[0] + ).named_axis == {"x": -1} + assert ak.cartesian( + [named_one, named_two, named_three], axis="y", nested=[1] + ).named_axis == {"y": -1} + assert ak.cartesian( + [named_one, named_two, named_three], axis="y", nested=[0, 1] + ).named_axis == {"x": -2, "y": -1} + + +def test_named_axis_ak_categories(): + pyarrow = pytest.importorskip("pyarrow") # noqa: F841 + + array = ak.str.to_categorical([["one", "two"], ["one", "three"], ["one", "four"]]) + + named_array = ak.with_named_axis(array, named_axis=("a", "b")) + + assert ak.all(ak.categories(array) == ak.categories(named_array)) # FIX: ufuncs + assert ( + ak.categories(array).named_axis == ak.categories(named_array).named_axis == {} + ) + + +def test_named_axis_ak_combinations(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis=("x", "y")) + + assert ak.combinations(named_array, 2, axis=0).named_axis == named_array.named_axis + assert ak.combinations(named_array, 2, axis=1).named_axis == named_array.named_axis + + +def test_negative_named_axis_ak_combinations(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis={"x": -2, "y": -1}) + + assert ak.combinations(named_array, 2, axis=-2).named_axis == named_array.named_axis + assert ak.combinations(named_array, 2, axis=-1).named_axis == named_array.named_axis + + +def test_named_axis_ak_concatenate(): + array1 = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + array2 = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + array3 = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + array4 = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + all_arrays = [array1, array2, array3, array4] + + named_array1 = ak.with_named_axis(array1, named_axis=(None, None)) + named_array2 = ak.with_named_axis(array1, named_axis=(None, "y")) + named_array3 = ak.with_named_axis(array1, named_axis=("x", None)) + named_array4 = ak.with_named_axis(array1, named_axis=("x", "y")) + + all_named_arrays = [named_array1, named_array2, named_array3, named_array4] + + assert ak.all( + ak.concatenate(all_arrays, axis=0) == ak.concatenate(all_named_arrays, axis="x") + ) + assert ak.all( + ak.concatenate(all_arrays, axis=1) == ak.concatenate(all_named_arrays, axis="y") + ) + + assert ak.concatenate(all_named_arrays, axis="x").named_axis == {"x": 0, "y": 1} + assert ak.concatenate(all_named_arrays, axis="y").named_axis == {"x": 0, "y": 1} + + with pytest.raises( + ValueError, + match="The named axes are incompatible. Got: x and y for positional axis 0", + ): + ak.concatenate( + [ + ak.with_named_axis(array1, named_axis=("x", None)), + ak.with_named_axis(array2, named_axis=("y", None)), + ], + axis=0, + ) + + with pytest.raises( + ValueError, + match="The named axes are incompatible. Got: x and y for positional axis 1", + ): + ak.concatenate( + [ + ak.with_named_axis(array1, named_axis=(None, "x")), + ak.with_named_axis(array2, named_axis=(None, "y")), + ], + axis=1, + ) + + +def test_negative_named_axis_ak_concatenate(): + array1 = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + array2 = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + array3 = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + array4 = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + all_arrays = [array1, array2, array3, array4] + + named_array1 = ak.with_named_axis(array1, named_axis={}) + named_array2 = ak.with_named_axis(array1, named_axis={"y": -1}) + named_array3 = ak.with_named_axis(array1, named_axis={"x": -2}) + named_array4 = ak.with_named_axis(array1, named_axis={"x": -2, "y": -1}) + + all_named_arrays = [named_array1, named_array2, named_array3, named_array4] + + assert ak.all( + ak.concatenate(all_arrays, axis=-2) + == ak.concatenate(all_named_arrays, axis="x") + ) + assert ak.all( + ak.concatenate(all_arrays, axis=-1) + == ak.concatenate(all_named_arrays, axis="y") + ) + + assert ak.concatenate(all_named_arrays, axis="x").named_axis == {"x": -2, "y": -1} + assert ak.concatenate(all_named_arrays, axis="y").named_axis == {"x": -2, "y": -1} + + +def test_named_axis_ak_copy(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis=("x", "y")) + + assert ak.copy(named_array).named_axis == {"x": 0, "y": 1} + + +# def test_named_axis_ak_corr(): +# array_x = ak.Array([[0, 1.1], [3.3, 4.4]]) +# array_y = ak.Array([[0, 1], [3, 4]]) + +# named_array_x = ak.with_named_axis(array_x, ("x", "y")) +# named_array_y = ak.with_named_axis(array_y, ("x", "y")) + +# assert ak.all( +# ak.corr(array_x, array_y, axis=0) +# == ak.corr(named_array_x, named_array_y, axis="x") +# ) +# assert ak.all( +# ak.corr(array_x, array_y, axis=1) +# == ak.corr(named_array_x, named_array_y, axis="y") +# ) +# assert ak.corr(array_x, array_y, axis=None) == ak.corr( +# named_array_x, named_array_y, axis=None +# ) + +# assert ak.corr(named_array_x, named_array_y, axis="x").named_axis == {"y": 0} +# assert ak.corr(named_array_x, named_array_y, axis="y").named_axis == {"x": 0} +# assert not _get_named_axis(ak.corr(named_array_x, named_array_y, axis=None)) + + +def test_named_axis_ak_count(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.count(array, axis=0) == ak.count(named_array, axis="x")) + assert ak.all(ak.count(array, axis=1) == ak.count(named_array, axis="y")) + assert ak.count(array, axis=None) == ak.count(named_array, axis=None) + + assert ak.count(named_array, axis="x").named_axis == {"y": 0} + assert ak.count(named_array, axis="y").named_axis == {"x": 0} + assert not _get_named_axis(ak.count(named_array, axis=None)) + + +def test_negative_named_axis_ak_count(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, {"x": -2, "y": -1}) + + assert ak.all(ak.count(array, axis=-2) == ak.count(named_array, axis="x")) + assert ak.all(ak.count(array, axis=-1) == ak.count(named_array, axis="y")) + assert ak.count(array, axis=None) == ak.count(named_array, axis=None) + + assert ak.count(named_array, axis="x").named_axis == {"y": -1} + assert ak.count(named_array, axis="y").named_axis == {"x": -1} + assert not _get_named_axis(ak.count(named_array, axis=None)) + + +def test_named_axis_ak_count_nonzero(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all( + ak.count_nonzero(array, axis=0) == ak.count_nonzero(named_array, axis="x") + ) + assert ak.all( + ak.count_nonzero(array, axis=1) == ak.count_nonzero(named_array, axis="y") + ) + assert ak.count_nonzero(array, axis=None) == ak.count_nonzero( + named_array, axis=None + ) + + assert ak.count_nonzero(named_array, axis="x").named_axis == {"y": 0} + assert ak.count_nonzero(named_array, axis="y").named_axis == {"x": 0} + assert not _get_named_axis(ak.count_nonzero(named_array, axis=None)) + + +def test_negative_named_axis_ak_count_nonzero(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, {"x": -2, "y": -1}) + + assert ak.all( + ak.count_nonzero(array, axis=-2) == ak.count_nonzero(named_array, axis="x") + ) + assert ak.all( + ak.count_nonzero(array, axis=-1) == ak.count_nonzero(named_array, axis="y") + ) + assert ak.count_nonzero(array, axis=None) == ak.count_nonzero( + named_array, axis=None + ) + + assert ak.count_nonzero(named_array, axis="x").named_axis == {"y": -1} + assert ak.count_nonzero(named_array, axis="y").named_axis == {"x": -1} + assert not _get_named_axis(ak.count_nonzero(named_array, axis=None)) + + +# def test_named_axis_ak_covar(): +# array_x = ak.Array([[0, 1.1], [3.3, 4.4]]) +# array_y = ak.Array([[0, 1], [3, 4]]) + +# named_array_x = ak.with_named_axis(array_x, ("x", "y")) +# named_array_y = ak.with_named_axis(array_y, ("x", "y")) + +# assert ak.all( +# ak.covar(array_x, array_y, axis=0) +# == ak.covar(named_array_x, named_array_y, axis="x") +# ) +# assert ak.all( +# ak.covar(array_x, array_y, axis=1) +# == ak.covar(named_array_x, named_array_y, axis="y") +# ) +# assert ak.covar(array_x, array_y, axis=None) == ak.covar( +# named_array_x, named_array_y, axis=None +# ) + +# assert ak.covar(named_array_x, named_array_y, axis="x").named_axis == {"y": 0} +# assert ak.covar(named_array_x, named_array_y, axis="y").named_axis == {"x": 0} +# assert not _get_named_axis(ak.covar(named_array_x, named_array_y, axis=None)) + + +def test_named_axis_ak_drop_none(): + array = ak.Array([[1, None], [3], [None], [4, None, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.drop_none(array, axis=0) == ak.drop_none(named_array, axis="x")) + assert ak.all(ak.drop_none(array, axis=1) == ak.drop_none(named_array, axis="y")) + assert ak.all( + ak.drop_none(array, axis=None) == ak.drop_none(named_array, axis=None) + ) + + assert ak.drop_none(named_array, axis="x").named_axis == {"x": 0, "y": 1} + assert ak.drop_none(named_array, axis="y").named_axis == {"x": 0, "y": 1} + assert ak.drop_none(named_array, axis=None).named_axis == {"x": 0, "y": 1} + + +def test_negative_named_axis_ak_drop_none(): + array = ak.Array([[1, None], [3], [None], [4, None, 6]]) + + named_array = ak.with_named_axis(array, {"x": -2, "y": -1}) + + assert ak.all(ak.drop_none(array, axis=-2) == ak.drop_none(named_array, axis="x")) + assert ak.all(ak.drop_none(array, axis=-1) == ak.drop_none(named_array, axis="y")) + assert ak.all( + ak.drop_none(array, axis=None) == ak.drop_none(named_array, axis=None) + ) + + assert ak.drop_none(named_array, axis="x").named_axis == {"x": -2, "y": -1} + assert ak.drop_none(named_array, axis="y").named_axis == {"x": -2, "y": -1} + assert ak.drop_none(named_array, axis=None).named_axis == {"x": -2, "y": -1} + + +def test_named_axis_ak_enforce_type(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.enforce_type(named_array, "var * ?int64").named_axis == {"x": 0, "y": 1} + + +def test_named_axis_ak_fill_none(): + array = ak.Array([[1.1, None, 2.2], [], [None, 3.3, 4.4]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all( + ak.fill_none(array, 0, axis=0) == ak.fill_none(named_array, 0, axis="x") + ) + assert ak.all( + ak.fill_none(array, 0, axis=1) == ak.fill_none(named_array, 0, axis="y") + ) + assert ak.all( + ak.fill_none(array, 0, axis=None) == ak.fill_none(named_array, 0, axis=None) + ) + + assert ak.fill_none(named_array, 0, axis="x").named_axis == {"x": 0, "y": 1} + assert ak.fill_none(named_array, 0, axis="y").named_axis == {"x": 0, "y": 1} + assert ak.fill_none(named_array, 0, axis=None).named_axis == {"x": 0, "y": 1} + + +def test_negative_named_axis_ak_fill_none(): + array = ak.Array([[1.1, None, 2.2], [], [None, 3.3, 4.4]]) + + named_array = ak.with_named_axis(array, {"x": -2, "y": -1}) + + assert ak.all( + ak.fill_none(array, 0, axis=-2) == ak.fill_none(named_array, 0, axis="x") + ) + assert ak.all( + ak.fill_none(array, 0, axis=-1) == ak.fill_none(named_array, 0, axis="y") + ) + assert ak.all( + ak.fill_none(array, 0, axis=None) == ak.fill_none(named_array, 0, axis=None) + ) + + assert ak.fill_none(named_array, 0, axis="x").named_axis == {"x": -2, "y": -1} + assert ak.fill_none(named_array, 0, axis="y").named_axis == {"x": -2, "y": -1} + assert ak.fill_none(named_array, 0, axis=None).named_axis == {"x": -2, "y": -1} + + +def test_named_axis_ak_firsts(): + array = ak.Array([[1.1], [2.2], [], [3.3], [], [], [4.4], [5.5]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.firsts(array, axis=0) == ak.firsts(named_array, axis="x")) + assert ak.all(ak.firsts(array, axis=1) == ak.firsts(named_array, axis="y")) + + assert ak.firsts(named_array, axis="x").named_axis == {"y": 0} + assert ak.firsts(named_array, axis="y").named_axis == {"x": 0} + + +def test_negative_named_axis_ak_firsts(): + array = ak.Array([[1.1], [2.2], [], [3.3], [], [], [4.4], [5.5]]) + + named_array = ak.with_named_axis(array, {"x": -2, "y": -1}) + + assert ak.all(ak.firsts(array, axis=-2) == ak.firsts(named_array, axis="x")) + assert ak.all(ak.firsts(array, axis=-1) == ak.firsts(named_array, axis="y")) + + assert ak.firsts(named_array, axis="x").named_axis == {"y": -1} + assert ak.firsts(named_array, axis="y").named_axis == {"x": -1} + + +def test_named_axis_ak_flatten(): + array = ak.Array([[[1.1, 2.2]], [[]], [[3.3]], [[]], [[]], [[4.4, 5.5]]]) + + named_array = ak.with_named_axis(array, ("x", "y", "z")) + + assert ak.all(ak.flatten(array, axis=0) == ak.flatten(named_array, axis="x")) + assert ak.all(ak.flatten(array, axis=1) == ak.flatten(named_array, axis="y")) + assert ak.all(ak.flatten(array, axis=2) == ak.flatten(named_array, axis="z")) + assert ak.all(ak.flatten(array, axis=None) == ak.flatten(named_array, axis=None)) + + assert ak.flatten(named_array, axis="x").named_axis == {"x": 0, "y": 1, "z": 2} + assert ak.flatten(named_array, axis="y").named_axis == {"x": 0, "z": 1} + assert ak.flatten(named_array, axis="z").named_axis == {"x": 0, "y": 1} + assert not _get_named_axis(ak.flatten(named_array, axis=None)) + + +def test_negative_named_axis_ak_flatten(): + array = ak.Array([[[1.1, 2.2]], [[]], [[3.3]], [[]], [[]], [[4.4, 5.5]]]) + + named_array = ak.with_named_axis(array, named_axis={"x": -3, "y": -2, "z": -1}) + + assert ak.all(ak.flatten(array, axis=-3) == ak.flatten(named_array, axis="x")) + assert ak.all(ak.flatten(array, axis=-2) == ak.flatten(named_array, axis="y")) + assert ak.all(ak.flatten(array, axis=-1) == ak.flatten(named_array, axis="z")) + assert ak.all(ak.flatten(array, axis=None) == ak.flatten(named_array, axis=None)) + + assert ak.flatten(named_array, axis="x").named_axis == {"x": -3, "y": -2, "z": -1} + assert ak.flatten(named_array, axis="y").named_axis == {"x": -2, "z": -1} + assert ak.flatten(named_array, axis="z").named_axis == {"x": -2, "y": -1} + assert not _get_named_axis(ak.flatten(named_array, axis=None)) + + +def test_named_axis_ak_imag(): + array = ak.Array([[1 + 2j], [2 + 1j], []]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.imag(array) == ak.imag(named_array)) + assert ak.imag(named_array).named_axis == {"x": 0, "y": 1} + + +def test_named_axis_ak_is_none(): + array = ak.Array([[[1, None]], [[3]], [[None]], [[4, None, 6]]]) + + named_array = ak.with_named_axis(array, ("x", "y", "z")) + + assert ak.all(ak.is_none(array, axis=0) == ak.is_none(named_array, axis="x")) + assert ak.all(ak.is_none(array, axis=1) == ak.is_none(named_array, axis="y")) + assert ak.all(ak.is_none(array, axis=2) == ak.is_none(named_array, axis="z")) + + assert ak.is_none(named_array, axis="x").named_axis == {"x": 0} + assert ak.is_none(named_array, axis="y").named_axis == {"x": 0, "y": 1} + assert ak.is_none(named_array, axis="z").named_axis == {"x": 0, "y": 1, "z": 2} + + +def test_negative_named_axis_ak_is_none(): + array = ak.Array([[[1, None]], [[3]], [[None]], [[4, None, 6]]]) + + named_array = ak.with_named_axis(array, named_axis={"x": -3, "y": -2, "z": -1}) + + assert ak.all(ak.is_none(array, axis=-3) == ak.is_none(named_array, axis="x")) + assert ak.all(ak.is_none(array, axis=-2) == ak.is_none(named_array, axis="y")) + assert ak.all(ak.is_none(array, axis=-1) == ak.is_none(named_array, axis="z")) + + assert ak.is_none(named_array, axis="x").named_axis == {"z": -1} + assert ak.is_none(named_array, axis="y").named_axis == {"y": -2, "z": -1} + assert ak.is_none(named_array, axis="z").named_axis == {"x": -3, "y": -2, "z": -1} + + +def test_named_axis_ak_isclose(): + a = b = ak.Array( + [[[0.0, 1.1, 2.2], []], [[3.3, 4.4]], [], [[5.5], [], [6.6, 7.7, 8.8, 9.9]]] + ) + + na = ak.with_named_axis(a, ("x", "y", "z")) + nb = ak.with_named_axis(b, ("x", "y", "z")) + assert ak.all(ak.isclose(a, b) == ak.isclose(na, nb)) + + na = ak.with_named_axis(a, (None, "y", "z")) + nb = ak.with_named_axis(b, ("x", "y", None)) + assert ak.isclose(na, nb).named_axis == {"x": 0, "y": 1, "z": 2} + + +def test_named_axis_ak_local_index(): + array = ak.Array( + [[[0.0, 1.1, 2.2], []], [[3.3, 4.4]], [], [[5.5], [], [6.6, 7.7, 8.8, 9.9]]] + ) + + named_array = ak.with_named_axis(array, ("x", "y", "z")) + + assert ak.all( + ak.local_index(array, axis=0) == ak.local_index(named_array, axis="x") + ) + assert ak.all( + ak.local_index(array, axis=1) == ak.local_index(named_array, axis="y") + ) + assert ak.all( + ak.local_index(array, axis=2) == ak.local_index(named_array, axis="z") + ) + + assert ak.local_index(named_array, axis="x").named_axis == {"x": 0} + assert ak.local_index(named_array, axis="y").named_axis == {"x": 0, "y": 1} + assert ak.local_index(named_array, axis="z").named_axis == {"x": 0, "y": 1, "z": 2} + + +def test_negative_named_axis_ak_local_index(): + array = ak.Array( + [[[0.0, 1.1, 2.2], []], [[3.3, 4.4]], [], [[5.5], [], [6.6, 7.7, 8.8, 9.9]]] + ) + named_array = ak.with_named_axis(array, {"x": -3, "y": -2, "z": -1}) + + assert ak.local_index(named_array, axis="x").named_axis == {"z": -1} + assert ak.local_index(named_array, axis="y").named_axis == {"y": -2, "z": -1} + assert ak.local_index(named_array, axis="z").named_axis == { + "x": -3, + "y": -2, + "z": -1, + } + + +def test_named_axis_ak_mask(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + mask = array > 3 + + named_array = ak.with_named_axis(array, ("x", "y")) + named_mask = named_array > 3 + + assert ak.all(ak.mask(array, mask) == ak.mask(named_array, mask)) + assert ak.all(ak.mask(array, mask) == ak.mask(named_array, named_mask)) + + assert ak.mask(named_array, mask).named_axis == named_array.named_axis + assert ak.mask(named_array, named_mask).named_axis == named_array.named_axis + + +def test_named_axis_ak_max(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis=("x", "y")) + + # first check that they work the same + assert ak.all(ak.max(array, axis=0) == ak.max(named_array, axis="x")) + assert ak.all(ak.max(array, axis=1) == ak.max(named_array, axis="y")) + + # check that result axis names are correctly propagated + assert ( + ak.max(named_array, axis=0).named_axis + == ak.max(named_array, axis="x").named_axis + == {"y": 0} + ) + assert ( + ak.max(named_array, axis=1).named_axis + == ak.max(named_array, axis="y").named_axis + == {"x": 0} + ) + assert ( + ak.max(named_array, axis=0, keepdims=True).named_axis + == ak.max(named_array, axis="x", keepdims=True).named_axis + == {"x": 0, "y": 1} + ) + assert ( + ak.max(named_array, axis=1, keepdims=True).named_axis + == ak.max(named_array, axis="y", keepdims=True).named_axis + == {"x": 0, "y": 1} + ) + assert not _get_named_axis(ak.max(named_array, axis=None)) + + +def test_negative_named_axis_ak_max(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis={"x": -2, "y": -1}) + + # first check that they work the same + assert ak.all(ak.max(array, axis=-2) == ak.max(named_array, axis="x")) + assert ak.all(ak.max(array, axis=-1) == ak.max(named_array, axis="y")) + + # check that result axis names are correctly propagated + assert ( + ak.max(named_array, axis=-2).named_axis + == ak.max(named_array, axis="x").named_axis + == {"y": -1} + ) + assert ( + ak.max(named_array, axis=-1).named_axis + == ak.max(named_array, axis="y").named_axis + == {"x": -1} + ) + assert ( + ak.max(named_array, axis=-2, keepdims=True).named_axis + == ak.max(named_array, axis="x", keepdims=True).named_axis + == {"x": -2, "y": -1} + ) + assert ( + ak.max(named_array, axis=-1, keepdims=True).named_axis + == ak.max(named_array, axis="y", keepdims=True).named_axis + == {"x": -2, "y": -1} + ) + assert not _get_named_axis(ak.max(named_array, axis=None)) + + +def test_named_axis_ak_mean(): + array = ak.Array([[1, 2], [3], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.mean(array, axis=0) == ak.mean(named_array, axis="x")) + assert ak.all(ak.mean(array, axis=1) == ak.mean(named_array, axis="y")) + assert ak.mean(array, axis=None) == ak.mean(named_array, axis=None) + + assert ak.mean(named_array, axis="x").named_axis == {"y": 0} + assert ak.mean(named_array, axis="y").named_axis == {"x": 0} + assert ak.mean(named_array, axis="x", keepdims=True).named_axis == {"x": 0, "y": 1} + assert ak.mean(named_array, axis="y", keepdims=True).named_axis == {"x": 0, "y": 1} + assert not _get_named_axis(ak.mean(named_array, axis=None)) + + +def test_negative_named_axis_ak_mean(): + array = ak.Array([[1, 2], [3], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis={"x": -2, "y": -1}) + + assert ak.all(ak.mean(array, axis=-2) == ak.mean(named_array, axis="x")) + assert ak.all(ak.mean(array, axis=-1) == ak.mean(named_array, axis="y")) + assert ak.mean(array, axis=None) == ak.mean(named_array, axis=None) + + assert ak.mean(named_array, axis="x").named_axis == {"y": -1} + assert ak.mean(named_array, axis="y").named_axis == {"x": -1} + assert ak.mean(named_array, axis="x", keepdims=True).named_axis == { + "x": -2, + "y": -1, + } + assert ak.mean(named_array, axis="y", keepdims=True).named_axis == { + "x": -2, + "y": -1, + } + assert not _get_named_axis(ak.mean(named_array, axis=None)) + + +def test_named_axis_ak_merge_option_of_records(): + array = ak.Array([None, {"a": 1}, {"a": 2}]) + + named_array = ak.with_named_axis(array, named_axis=("x",)) + + assert ( + ak.merge_option_of_records(named_array, axis="x").named_axis + == named_array.named_axis + ) + + +def test_named_axis_ak_merge_union_of_records(): + array = ak.concatenate(([{"a": 1}], [{"b": 2}])) + + named_array = ak.with_named_axis(array, named_axis=("x",)) + + assert ( + ak.merge_union_of_records(named_array, axis="x").named_axis + == named_array.named_axis + ) + + +def test_named_axis_ak_min(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis=("x", "y")) + + # first check that they work the same + assert ak.all(ak.min(array, axis=0) == ak.min(named_array, axis="x")) + assert ak.all(ak.min(array, axis=1) == ak.min(named_array, axis="y")) + + # check that result axis names are correctly propagated + assert ( + ak.min(named_array, axis=0).named_axis + == ak.min(named_array, axis="x").named_axis + == {"y": 0} + ) + assert ( + ak.min(named_array, axis=1).named_axis + == ak.min(named_array, axis="y").named_axis + == {"x": 0} + ) + assert ( + ak.min(named_array, axis=0, keepdims=True).named_axis + == ak.min(named_array, axis="x", keepdims=True).named_axis + == {"x": 0, "y": 1} + ) + assert ( + ak.min(named_array, axis=1, keepdims=True).named_axis + == ak.min(named_array, axis="y", keepdims=True).named_axis + == {"x": 0, "y": 1} + ) + assert not _get_named_axis(ak.min(named_array, axis=None)) + + +def test_negative_named_axis_ak_min(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis={"x": -2, "y": -1}) + + # first check that they work the same + assert ak.all(ak.min(array, axis=-2) == ak.min(named_array, axis="x")) + assert ak.all(ak.min(array, axis=-1) == ak.min(named_array, axis="y")) + + # check that result axis names are correctly propagated + assert ( + ak.min(named_array, axis=-2).named_axis + == ak.min(named_array, axis="x").named_axis + == {"y": -1} + ) + assert ( + ak.min(named_array, axis=-1).named_axis + == ak.min(named_array, axis="y").named_axis + == {"x": -1} + ) + assert ( + ak.min(named_array, axis=-2, keepdims=True).named_axis + == ak.min(named_array, axis="x", keepdims=True).named_axis + == {"x": -2, "y": -1} + ) + assert ( + ak.min(named_array, axis=-1, keepdims=True).named_axis + == ak.min(named_array, axis="y", keepdims=True).named_axis + == {"x": -2, "y": -1} + ) + assert not _get_named_axis(ak.min(named_array, axis=None)) + + +def test_named_axis_ak_moment(): + array = ak.Array([[0, 1.1], [3.3, 4.4]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.moment(array, 0, axis=0) == ak.moment(named_array, 0, axis="x")) + assert ak.all(ak.moment(array, 0, axis=1) == ak.moment(named_array, 0, axis="y")) + assert ak.moment(array, 0, axis=None) == ak.moment(named_array, 0, axis=None) + + assert ak.moment(named_array, 0, axis="x").named_axis == {"y": 0} + assert ak.moment(named_array, 0, axis="y").named_axis == {"x": 0} + assert not _get_named_axis(ak.moment(named_array, 0, axis=None)) + + +def test_negative_named_axis_ak_moment(): + array = ak.Array([[0, 1.1], [3.3, 4.4]]) + + named_array = ak.with_named_axis(array, {"x": -2, "y": -1}) + + assert ak.all(ak.moment(array, 0, axis=-2) == ak.moment(named_array, 0, axis="x")) + assert ak.all(ak.moment(array, 0, axis=-1) == ak.moment(named_array, 0, axis="y")) + assert ak.moment(array, 0, axis=None) == ak.moment(named_array, 0, axis=None) + + assert ak.moment(named_array, 0, axis="x").named_axis == {"y": -1} + assert ak.moment(named_array, 0, axis="y").named_axis == {"x": -1} + assert not _get_named_axis(ak.moment(named_array, 0, axis=None)) + + +def test_named_axis_ak_nan_to_none(): + array = ak.Array([[0, np.nan], [np.nan], [3.3, 4.4]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.nan_to_none(array) == ak.nan_to_none(named_array)) + assert ak.nan_to_none(named_array).named_axis == named_array.named_axis + + +def test_named_axis_ak_nan_to_num(): + array = ak.Array([[0, np.nan], [np.nan], [3.3, 4.4]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.nan_to_num(array, nan=0.0) == ak.nan_to_num(named_array, nan=0.0)) + assert ak.nan_to_num(named_array, nan=0.0).named_axis == named_array.named_axis + + +def test_named_axis_ak_num(): + array = ak.Array([[1, 2], [3], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.num(array, axis=0) == ak.num(named_array, axis="x") + assert ak.all(ak.num(array, axis=1) == ak.num(named_array, axis="y")) + + assert ak.num(named_array, axis="y").named_axis == {"y": 0} + assert not _get_named_axis(ak.num(named_array, axis="x")) + + +def test_negative_named_axis_ak_num(): + array = ak.Array([[1, 2], [3], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, {"x": -2, "y": -1}) + + assert ak.num(array, axis=-2) == ak.num(named_array, axis="x") + assert ak.all(ak.num(array, axis=-1) == ak.num(named_array, axis="y")) + + assert ak.num(named_array, axis="y").named_axis == {"y": 0} + assert not _get_named_axis(ak.num(named_array, axis="x")) + + +def test_named_axis_ak_ones_like(): + array = ak.Array([[1, 2], [3], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.ones_like(array) == ak.ones_like(named_array)) + + assert ak.ones_like(named_array).named_axis == named_array.named_axis + + +def test_named_axis_ak_pad_none(): + array = ak.Array([[1, 2], [3], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.pad_none(array, 3, axis=0) == ak.pad_none(named_array, 3, axis=0)) + assert ak.all(ak.pad_none(array, 3, axis=1) == ak.pad_none(named_array, 3, axis=1)) + + assert ak.pad_none(named_array, 3, axis=0).named_axis == named_array.named_axis + assert ak.pad_none(named_array, 3, axis=1).named_axis == named_array.named_axis + + +def test_named_axis_ak_prod(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.prod(array, axis=0) == ak.prod(named_array, axis="x")) + assert ak.all(ak.prod(array, axis=1) == ak.prod(named_array, axis="y")) + assert ak.prod(array, axis=None) == ak.prod(named_array, axis=None) + + assert ak.prod(named_array, axis="x").named_axis == {"y": 0} + assert ak.prod(named_array, axis="y").named_axis == {"x": 0} + assert not _get_named_axis(ak.prod(named_array, axis=None)) + + +def test_negative_named_axis_ak_prod(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, {"x": -2, "y": -1}) + + assert ak.all(ak.prod(array, axis=-2) == ak.prod(named_array, axis="x")) + assert ak.all(ak.prod(array, axis=-1) == ak.prod(named_array, axis="y")) + assert ak.prod(array, axis=None) == ak.prod(named_array, axis=None) + + assert ak.prod(named_array, axis="x").named_axis == {"y": -1} + assert ak.prod(named_array, axis="y").named_axis == {"x": -1} + assert not _get_named_axis(ak.prod(named_array, axis=None)) + + +def test_named_axis_ak_ptp(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.ptp(array, axis=0) == ak.ptp(named_array, axis="x")) + assert ak.all(ak.ptp(array, axis=1) == ak.ptp(named_array, axis="y")) + assert ak.ptp(array, axis=None) == ak.ptp(named_array, axis=None) + + assert ak.ptp(named_array, axis="x").named_axis == {"y": 0} + assert ak.ptp(named_array, axis="y").named_axis == {"x": 0} + assert not _get_named_axis(ak.ptp(named_array, axis=None)) + + +def test_negative_named_axis_ak_ptp(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, {"x": -2, "y": -1}) + + assert ak.all(ak.ptp(array, axis=-2) == ak.ptp(named_array, axis="x")) + assert ak.all(ak.ptp(array, axis=-1) == ak.ptp(named_array, axis="y")) + assert ak.ptp(array, axis=None) == ak.ptp(named_array, axis=None) + + assert ak.ptp(named_array, axis="x").named_axis == {"y": -1} + assert ak.ptp(named_array, axis="y").named_axis == {"x": -1} + assert not _get_named_axis(ak.ptp(named_array, axis=None)) + + +def test_named_axis_ak_ravel(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.ravel(array) == ak.ravel(named_array)) + + assert not _get_named_axis(ak.ravel(named_array)) + + +def test_named_axis_ak_real(): + array = ak.Array([[1 + 2j], [2 + 1j], []]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.real(array) == ak.real(named_array)) + assert ak.real(named_array).named_axis == {"x": 0, "y": 1} + + +def test_named_axis_ak_round(): + array = ak.Array([[1.234], [2.345, 3.456], []]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.round(array) == ak.round(named_array)) + assert ak.round(named_array).named_axis == {"x": 0, "y": 1} + + +def test_named_axis_ak_run_lengths(): + array = ak.Array([[1.1, 1.1, 1.1, 2.2, 3.3], [3.3, 4.4], [4.4, 5.5]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.run_lengths(array) == ak.run_lengths(named_array)) + + assert ak.run_lengths(named_array).named_axis == named_array.named_axis + + +def test_named_axis_ak_singletons(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.singletons(array, axis=0) == ak.singletons(named_array, axis="x")) + assert ak.all(ak.singletons(array, axis=1) == ak.singletons(named_array, axis="y")) + + assert ak.singletons(named_array, axis=0).named_axis == {"x": 0, "y": 2} + assert ak.singletons(named_array, axis=1).named_axis == {"x": 0, "y": 1} + + +def test_negative_named_axis_ak_singletons(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, {"x": -2, "y": -1}) + + assert ak.all(ak.singletons(array, axis=-2) == ak.singletons(named_array, axis="x")) + assert ak.all(ak.singletons(array, axis=-1) == ak.singletons(named_array, axis="y")) + + assert ak.singletons(named_array, axis=-2).named_axis == {"x": -3, "y": -1} + assert ak.singletons(named_array, axis=-1).named_axis == {"x": -3, "y": -2} + + +def test_named_axis_ak_softmax(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.softmax(array, axis=-1) == ak.softmax(named_array, axis="y")) + + assert ak.softmax(named_array, axis="y").named_axis == {"x": 0, "y": 1} + + +def test_named_axis_ak_sort(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis=("x", "y")) + + # first check that they work the same + assert ak.all(ak.sort(array, axis=0) == ak.sort(named_array, axis="x")) + assert ak.all(ak.sort(array, axis=1) == ak.sort(named_array, axis="y")) + + # check that result axis names are correctly propagated + assert ( + ak.sort(named_array, axis=0).named_axis + == ak.sort(named_array, axis="x").named_axis + == {"x": 0, "y": 1} + ) + assert ( + ak.sort(named_array, axis=1).named_axis + == ak.sort(named_array, axis="y").named_axis + == {"x": 0, "y": 1} + ) + + +def test_named_axis_ak_std(): + array = ak.Array([[1, 2], [3], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.std(array, axis=0) == ak.std(named_array, axis="x")) + assert ak.all(ak.std(array, axis=1) == ak.std(named_array, axis="y")) + assert ak.std(array, axis=None) == ak.std(named_array, axis=None) + + assert ak.std(named_array, axis="x").named_axis == {"y": 0} + assert ak.std(named_array, axis="y").named_axis == {"x": 0} + assert not _get_named_axis(ak.std(named_array, axis=None)) + + +def test_negative_named_axis_ak_std(): + array = ak.Array([[1, 2], [3], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, {"x": -2, "y": -1}) + + assert ak.all(ak.std(array, axis=-2) == ak.std(named_array, axis="x")) + assert ak.all(ak.std(array, axis=-1) == ak.std(named_array, axis="y")) + assert ak.std(array, axis=None) == ak.std(named_array, axis=None) + + assert ak.std(named_array, axis="x").named_axis == {"y": -1} + assert ak.std(named_array, axis="y").named_axis == {"x": -1} + assert not _get_named_axis(ak.std(named_array, axis=None)) + + +def test_named_axis_ak_strings_astype(): + array = ak.Array([["1", "2"], ["3"], ["4", "5", "6"]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all( + ak.strings_astype(array, np.int32) == ak.strings_astype(named_array, np.int32) + ) + + assert ak.strings_astype(named_array, np.int32).named_axis == named_array.named_axis + + +def test_named_axis_ak_sum(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.sum(array, axis=0) == ak.sum(named_array, axis="x")) + assert ak.all(ak.sum(array, axis=1) == ak.sum(named_array, axis="y")) + assert ak.sum(array, axis=None) == ak.sum(named_array, axis=None) + + assert ak.sum(named_array, axis="x").named_axis == {"y": 0} + assert ak.sum(named_array, axis="y").named_axis == {"x": 0} + assert not _get_named_axis(ak.sum(named_array, axis=None)) + + +def test_negative_named_axis_ak_sum(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, {"x": -2, "y": -1}) + + assert ak.all(ak.sum(array, axis=-2) == ak.sum(named_array, axis="x")) + assert ak.all(ak.sum(array, axis=-1) == ak.sum(named_array, axis="y")) + assert ak.sum(array, axis=None) == ak.sum(named_array, axis=None) + + assert ak.sum(named_array, axis="x").named_axis == {"y": -1} + assert ak.sum(named_array, axis="y").named_axis == {"x": -1} + assert not _get_named_axis(ak.sum(named_array, axis=None)) + + +def test_named_axis_ak_to_backend(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.to_backend(named_array, "typetracer").named_axis == named_array.named_axis + + +def test_named_axis_ak_to_packed(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.to_packed(array) == ak.to_packed(named_array)) + + assert ak.to_packed(named_array).named_axis == named_array.named_axis + + +def test_named_axis_ak_unflatten(): + array = ak.Array([[1, 2, 3, 4], [], [5, 6, 7], [8, 9]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + counts = ak.Array([2, 2, 1, 2, 1, 1]) + + assert ak.all( + ak.unflatten(array, counts, axis=1) + == ak.unflatten(named_array, counts, axis="y") + ) + assert not _get_named_axis(ak.unflatten(named_array, counts, axis="y")) + + +def test_named_axis_ak_unzip(): + array = ak.Array( + [ + {"x": 1.1, "y": [1]}, + {"x": 2.2, "y": [2, 2]}, + {"x": 3.3, "y": [3, 3, 3]}, + ] + ) + named_array = ak.with_named_axis(array, ("x", "y")) + x, y = ak.unzip(named_array) + assert x.named_axis == y.named_axis == {"x": 0, "y": 1} + + +def test_named_axis_ak_values_astype(): + array = ak.Array([[1, 2, 3, 4], [], [5, 6, 7], [8, 9]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all( + ak.values_astype(array, np.float32) == ak.values_astype(named_array, np.float32) + ) + + assert ( + ak.values_astype(named_array, np.float32).named_axis == named_array.named_axis + ) + + +def test_named_axis_ak_var(): + array = ak.Array([[1, 2], [3], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.var(array, axis=0) == ak.var(named_array, axis="x")) + assert ak.all(ak.var(array, axis=1) == ak.var(named_array, axis="y")) + assert ak.var(array, axis=None) == ak.var(named_array, axis=None) + + assert ak.var(named_array, axis="x").named_axis == {"y": 0} + assert ak.var(named_array, axis="y").named_axis == {"x": 0} + assert not _get_named_axis(ak.var(named_array, axis=None)) + + +def test_negative_named_axis_ak_var(): + array = ak.Array([[1, 2], [3], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, {"x": -2, "y": -1}) + + assert ak.all(ak.var(array, axis=-2) == ak.var(named_array, axis="x")) + assert ak.all(ak.var(array, axis=-1) == ak.var(named_array, axis="y")) + assert ak.var(array, axis=None) == ak.var(named_array, axis=None) + + assert ak.var(named_array, axis="x").named_axis == {"y": -1} + assert ak.var(named_array, axis="y").named_axis == {"x": -1} + assert not _get_named_axis(ak.var(named_array, axis=None)) + + +def test_named_axis_ak_where(): + a = ak.Array([[1, 2], [3, 4]]) + na = ak.with_named_axis(a, ("x", "y")) + + assert ak.all(ak.where(a > 2, 0, 1) == ak.where(na > 2, 0, 1)) + assert ak.where(na > 2, 0, 1).named_axis == {"x": 0, "y": 1} + assert ak.where(na > 2, na, 1).named_axis == {"x": 0, "y": 1} + + nb = ak.with_named_axis(a, ("a", "b")) + with pytest.raises(ValueError): + _ = ak.where(na > 2, nb, 1) + + +def test_named_axis_ak_with_field(): + array = ak.Array( + [ + {"x": 1.1, "y": [1]}, + {"x": 2.2, "y": [2, 2]}, + {"x": 3.3, "y": [3, 3, 3]}, + ] + ) + named_array = ak.with_named_axis(array, ("x", "y")) + xyz = ak.with_field(named_array, ak.Array([[1], [2], [3]]), "z") + x, y, z = ak.unzip(xyz) + assert x.named_axis == y.named_axis == z.named_axis == {"x": 0, "y": 1} + + named_z = ak.with_named_axis(ak.Array([[1], [2], [3]]), ("a", "b")) + with pytest.raises(ValueError): + ak.with_field(named_array, named_z, "z") + + +def test_named_axis_ak_with_name(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.with_name(named_array, "new_name").named_axis == named_array.named_axis + + +def test_named_axis_ak_with_named_axis(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + # tuple + named_array = ak.with_named_axis(array, ("x", "y")) + assert named_array.named_axis == {"x": 0, "y": 1} + + # dict + named_array = ak.with_named_axis(array, {"x": 0, "y": -1}) + assert named_array.named_axis == {"x": 0, "y": -1} + + +def test_named_axis_ak_with_parameter(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ( + ak.with_parameter(named_array, "param", 1.0).named_axis + == named_array.named_axis + ) + + +def test_named_axis_ak_without_parameters(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + named_array_with_parameteter = ak.with_parameter(named_array, "param", 1.0) + + assert ( + ak.without_parameters(named_array_with_parameteter).named_axis + == named_array.named_axis + ) + + +def test_named_axis_ak_zeros_like(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.zeros_like(array) == ak.zeros_like(named_array)) + + assert ak.zeros_like(named_array).named_axis == named_array.named_axis + + +def test_named_axis_ak_zip(): + named_array1 = ak.with_named_axis(ak.Array([1, 2, 3]), ("x",)) + named_array2 = ak.with_named_axis(ak.Array([[4, 5, 6], [], [7]]), ("x", "y")) + + assert ak.zip({"x": named_array1, "y": named_array2}).named_axis == {"x": 0, "y": 1} + + named_array1 = ak.with_named_axis(ak.Array([1, 2, 3]), ("a",)) + named_array2 = ak.with_named_axis(ak.Array([[4, 5, 6], [], [7]]), ("x", "y")) + + with pytest.raises(ValueError): + _ = ak.zip({"x": named_array1, "y": named_array2}) From 46509eff9dffb4d940fa2c81f1dbd6ba6cc707de Mon Sep 17 00:00:00 2001 From: "allcontributors[bot]" <46447321+allcontributors[bot]@users.noreply.github.com> Date: Fri, 25 Oct 2024 02:09:32 -0500 Subject: [PATCH 14/21] docs: add basnijholt as a contributor for maintenance (#3287) * docs: update README.md [skip ci] * docs: update .all-contributorsrc [skip ci] --------- Co-authored-by: allcontributors[bot] <46447321+allcontributors[bot]@users.noreply.github.com> --- .all-contributorsrc | 9 +++++++++ README.md | 1 + 2 files changed, 10 insertions(+) diff --git a/.all-contributorsrc b/.all-contributorsrc index 9d3fd353bc..4f6f90c52b 100644 --- a/.all-contributorsrc +++ b/.all-contributorsrc @@ -492,6 +492,15 @@ "contributions": [ "maintenance" ] + }, + { + "login": "basnijholt", + "name": "Bas Nijholt", + "avatar_url": "https://avatars.githubusercontent.com/u/6897215?v=4", + "profile": "http://www.nijho.lt", + "contributions": [ + "maintenance" + ] } ], "contributorsPerLine": 7, diff --git a/README.md b/README.md index ca4edcc01e..e4ec709a8d 100644 --- a/README.md +++ b/README.md @@ -231,6 +231,7 @@ Thanks especially to the gracious help of Awkward Array contributors (including Andres Rios Tascon
Andres Rios Tascon

💻 maxymnaumchyk
maxymnaumchyk

💻 Thomas A Caswell
Thomas A Caswell

🚧 + Bas Nijholt
Bas Nijholt

🚧 From 3bb2661acde33907dc6ffb1758fa6e054f2d435e Mon Sep 17 00:00:00 2001 From: Bas Nijholt Date: Fri, 25 Oct 2024 00:20:26 -0700 Subject: [PATCH 15/21] chore: delete `.DS_Store` and add it to .gitignore (#3286) * Delete .DS_Store * Add .DS_Store to .gitignore --- .DS_Store | Bin 8196 -> 0 bytes .gitignore | 3 +++ 2 files changed, 3 insertions(+) delete mode 100644 .DS_Store diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index f0c6683eef7b709a38f347afd2bb5a398c6cbc52..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8196 zcmeHM&2G~`5S~q1;xrTrqM|}7vc$D5rThtT327r#35wJNH~`QcnocU1@hV>-lDOznNWkT_O^tZsQKo0ugDbEawVn#uR?e zb);mXk(-bLc%nKPsJmoQZA4oQCK7szDuU&@#|%YNt@&&W36>{JQCR4jaA#LpGVzMeD$DPkvCa&|jf| zJZh1XbRrg0PA<|qZBi9eu0p<`Eg4`3 zyp`ys^twe)fp1YG=`*vFLT9DFgOt{f@mg`3rU3LXQ9lWlarzpIlQz^tgltS`C#BC3 zit867-G-c22N;iBP~T6YB;|3`EbSp+4m=phf_>V-;`hjBbPCC&$B_%rwva(CS_l4o z;5h&`0%k)R2)Zq2F*zJ@d-ExK0j&A3v5Xb5Eo0KmWN6}+a`X%VU8gdFTcL7t9T^T$ zTzdv9@aRy^+JVm%-cd;m$3|R=#pV+0pmu=VldBGZ8=K^WI3=)oq+1ylH|RO!m{j8! zB@j2@$D$8%V{xyA>cX23ZcbTh$~cCb?T~`x3Iw8;u(^09n=M4pMQwroNtUs#5|+<4 z+mOG$WNl=&9MOc$EY{kgdU8!1$%Ix;=1Uj(VuPVE`%8R`#c=wByR;@pY=WlQ^V`wc z`rT|hwjZpn{uG())bxd%meUH_yUMQDs{}^S^-D(Q4gad*d^CrB!+d31z1rNhb5U?yzqe>M{C(1XMlZJX;mO46vEhSC$8|iH4?TV|bIKKRu8vH@ zo=iqVW^J6ctK%?0$4`wE^D;~0HsD&75eBm1$O7O0%agzVkKojRVn8wQ|1cn?HYyt> ztY>)Er Date: Mon, 28 Oct 2024 14:08:36 +0200 Subject: [PATCH 16/21] fix: add cuda backend support for `to_raggedtensor` and `from_raggedtensor` functions (#3263) * add cuda backend support * style changes * keep gpu id the same * style changes * fix device id selection * add new functions to the documentation * add cuda backend support for ak.from_raggedtensor * add suggestions from Jim * add suggestions from Jim --------- Co-authored-by: Ianna Osborne --- docs/reference/toctree.txt | 8 +++ .../operations/ak_from_raggedtensor.py | 42 +++++++++++++- src/awkward/operations/ak_to_raggedtensor.py | 56 ++++++++++++++++--- 3 files changed, 96 insertions(+), 10 deletions(-) diff --git a/docs/reference/toctree.txt b/docs/reference/toctree.txt index 4b6ae1154d..0175459441 100644 --- a/docs/reference/toctree.txt +++ b/docs/reference/toctree.txt @@ -39,6 +39,14 @@ generated/ak.to_feather generated/ak.from_avro_file +.. toctree:: + :caption: Conversions for machine learning + + generated/ak.from_raggedtensor + generated/ak.to_raggedtensor + generated/ak.from_torch + generated/ak.to_torch + .. toctree:: :caption: Converting to Pandas DataFrames diff --git a/src/awkward/operations/ak_from_raggedtensor.py b/src/awkward/operations/ak_from_raggedtensor.py index 1c895506c2..cadf1edb12 100644 --- a/src/awkward/operations/ak_from_raggedtensor.py +++ b/src/awkward/operations/ak_from_raggedtensor.py @@ -2,6 +2,8 @@ from __future__ import annotations +import re + import awkward as ak from awkward._dispatch import high_level_function @@ -30,18 +32,25 @@ def from_raggedtensor(array): def _impl(array): try: # get the flat values - content = array.flat_values.numpy() + content = array.flat_values except AttributeError as err: raise TypeError( """only RaggedTensor can be converted to awkward array""" ) from err - # convert them to ak.contents right away + + # handle gpu and cpu instances separately + device = content.backing_device + + content = _tensor_to_np_or_cp(content, device) + + # convert flat_values to ak.contents right away content = ak.contents.NumpyArray(content) # get the offsets offsets_arr = [] for splits in array.nested_row_splits: - split = splits.numpy() + # handle gpu and cpu instances separately + split = _tensor_to_np_or_cp(splits, device) # convert to ak.index offset = ak.index.Index64(split) offsets_arr.append(offset) @@ -55,6 +64,33 @@ def _impl(array): return ak.Array(_recursive_call(content, offsets_arr, 0)) +def _tensor_to_np_or_cp(array, device): + matched_device = re.match(".*:(CPU|GPU):[0-9]+", device) + + if matched_device is None: + raise NotImplementedError( + f"TensorFlow device has an unexpected format: {device!r}" + ) + elif matched_device.groups()[0] == "GPU": + try: + import tensorflow as tf + except ImportError as err: + raise ImportError( + """to use ak.from_raggedtensor, you must install the 'tensorflow' package with: + + pip install tensorflow + or + conda install tensorflow""" + ) from err + + from awkward._nplikes.cupy import Cupy + + cp = Cupy.instance() + return cp.from_dlpack(tf.experimental.dlpack.to_dlpack(array)) + elif matched_device.groups()[0] == "CPU": + return array.numpy() + + def _recursive_call(content, offsets_arr, count): if count == len(offsets_arr) - 2: return ak.contents.ListOffsetArray( diff --git a/src/awkward/operations/ak_to_raggedtensor.py b/src/awkward/operations/ak_to_raggedtensor.py index 5fcb2e2d5f..0a8c797c63 100644 --- a/src/awkward/operations/ak_to_raggedtensor.py +++ b/src/awkward/operations/ak_to_raggedtensor.py @@ -4,9 +4,12 @@ import awkward as ak from awkward._dispatch import high_level_function +from awkward._nplikes.numpy_like import NumpyMetadata __all__ = ("to_raggedtensor",) +np = NumpyMetadata.instance() + @high_level_function() def to_raggedtensor(array): @@ -45,14 +48,49 @@ def _impl(array): # also transforms a python list to awkward array array = ak.to_layout(array, allow_record=False) + # keep the same device + ak_device = ak.backend(array) + if ak_device not in ["cuda", "cpu"]: + raise ValueError("""Only 'cpu' and 'cuda' backend conversions are allowed""") + + if ak_device == "cpu": + device = "CPU:0" + else: + id = _find_innermost_content(array).data.device.id + device = "GPU:" + str(id) + + with tf.device(device): + if isinstance(array, ak.contents.numpyarray.NumpyArray): + values = array.data + # handle cupy separately + values = _convert_to_tensor_if_cupy(values) + return tf.RaggedTensor.from_row_splits( + values=values, row_splits=[0, array.__len__()] + ) + + else: + flat_values, nested_row_splits = _recursive_call(array, ()) + return tf.RaggedTensor.from_nested_row_splits( + flat_values, nested_row_splits + ) + + +def _find_innermost_content(array): if isinstance(array, ak.contents.numpyarray.NumpyArray): - return tf.RaggedTensor.from_row_splits( - values=array.data, row_splits=[0, array.__len__()] - ) + return array + else: + return _find_innermost_content(array.content) + + +def _convert_to_tensor_if_cupy(array): + if isinstance(array, np.ndarray): + return array else: - flat_values, nested_row_splits = _recursive_call(array, ()) + # converts cupy directly to tensor, + # since `tf.RaggedTensor.from_nested_row_splits` can not work with Cupy arrays + import tensorflow as tf - return tf.RaggedTensor.from_nested_row_splits(flat_values, nested_row_splits) + return tf.experimental.dlpack.from_dlpack(array.toDlpack()) def _recursive_call(layout, offsets_arr): @@ -75,10 +113,14 @@ def _recursive_call(layout, offsets_arr): ) # recursively gather all of the offsets of an array - offsets_arr += (layout.offsets.data,) + offset = layout.offsets.data + offset = _convert_to_tensor_if_cupy(offset) + offsets_arr += (offset,) except AttributeError: # at the last iteration form a ragged tensor from the # accumulated offsets and flattened values of the array - return layout.data, offsets_arr + data = layout.data + data = _convert_to_tensor_if_cupy(data) + return data, offsets_arr return _recursive_call(layout.content, offsets_arr) From 3fd561a125ca74c98b3aefd200112acabcf71227 Mon Sep 17 00:00:00 2001 From: Ianna Osborne Date: Thu, 31 Oct 2024 16:56:13 +0100 Subject: [PATCH 17/21] fix: test `from_raggedtensor` on GPU (#3288) * fix: test from_raggedtensor on GPU * style: pre-commit fixes --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/test_3210_to_raggedtensor_from_raggedtensor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_3210_to_raggedtensor_from_raggedtensor.py b/tests/test_3210_to_raggedtensor_from_raggedtensor.py index d250910c09..3546279146 100644 --- a/tests/test_3210_to_raggedtensor_from_raggedtensor.py +++ b/tests/test_3210_to_raggedtensor_from_raggedtensor.py @@ -107,7 +107,9 @@ def test_convert_from_raggedtensor(): ak_array1 = ak.contents.ListOffsetArray(offsets1, content1) result1 = ak.to_layout(from_raggedtensor(tf_array1), allow_record=False) - assert (result1.content.data == np_array1).all() + assert ( + result1.content.data == ak.to_backend(np_array1, result1.backend).layout.data + ).all() assert (result1.offsets.data == [0, 2, 3, 3, 5]).all() assert from_raggedtensor(tf_array1).to_list() == ak_array1.to_list() From 8a0bcb5d5e687c35e18196c0fe1dfa4c62e58740 Mon Sep 17 00:00:00 2001 From: "allcontributors[bot]" <46447321+allcontributors[bot]@users.noreply.github.com> Date: Wed, 6 Nov 2024 10:42:19 -0600 Subject: [PATCH 18/21] docs: add nj-vs-vh as a contributor for code (#3293) * docs: update README.md [skip ci] * docs: update .all-contributorsrc [skip ci] --------- Co-authored-by: allcontributors[bot] <46447321+allcontributors[bot]@users.noreply.github.com> --- .all-contributorsrc | 9 +++++++++ README.md | 1 + 2 files changed, 10 insertions(+) diff --git a/.all-contributorsrc b/.all-contributorsrc index 4f6f90c52b..9d4caf4356 100644 --- a/.all-contributorsrc +++ b/.all-contributorsrc @@ -501,6 +501,15 @@ "contributions": [ "maintenance" ] + }, + { + "login": "nj-vs-vh", + "name": "Igor Vaiman", + "avatar_url": "https://avatars.githubusercontent.com/u/30616208?v=4", + "profile": "https://nj-vs-vh.name/", + "contributions": [ + "code" + ] } ], "contributorsPerLine": 7, diff --git a/README.md b/README.md index e4ec709a8d..7cc935cc3e 100644 --- a/README.md +++ b/README.md @@ -232,6 +232,7 @@ Thanks especially to the gracious help of Awkward Array contributors (including maxymnaumchyk
maxymnaumchyk

💻 Thomas A Caswell
Thomas A Caswell

🚧 Bas Nijholt
Bas Nijholt

🚧 + Igor Vaiman
Igor Vaiman

💻 From 5145bcee480ba748e22fc198ec5971db7c5800be Mon Sep 17 00:00:00 2001 From: Igor Vaiman Date: Wed, 6 Nov 2024 17:42:27 +0100 Subject: [PATCH 19/21] fix: correct handling of keepdims and mask_identity for weighted mean (#3291) * fix + tests * style: pre-commit fixes --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/awkward/operations/ak_mean.py | 4 +- tests/test_3285_ak_mean_weighted_row_wise.py | 65 ++++++++++++++++++++ 2 files changed, 67 insertions(+), 2 deletions(-) create mode 100644 tests/test_3285_ak_mean_weighted_row_wise.py diff --git a/src/awkward/operations/ak_mean.py b/src/awkward/operations/ak_mean.py index a9b38ce1f0..3b6552c521 100644 --- a/src/awkward/operations/ak_mean.py +++ b/src/awkward/operations/ak_mean.py @@ -225,8 +225,8 @@ def _impl(x, weight, axis, keepdims, mask_identity, highlevel, behavior, attrs): sumw = ak.operations.ak_sum._impl( x * 0 + weight, axis, - keepdims, - mask_identity, + keepdims=True, + mask_identity=True, highlevel=True, behavior=ctx.behavior, attrs=ctx.attrs, diff --git a/tests/test_3285_ak_mean_weighted_row_wise.py b/tests/test_3285_ak_mean_weighted_row_wise.py new file mode 100644 index 0000000000..2c1bc1db9b --- /dev/null +++ b/tests/test_3285_ak_mean_weighted_row_wise.py @@ -0,0 +1,65 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE + +from __future__ import annotations + +import math + +import pytest + +import awkward as ak + + +@pytest.mark.parametrize( + "keepdims, expected_result", + [ + pytest.param(False, ak.Array([2.25, 6.5])), + pytest.param(True, ak.Array([[2.25], [6.5]])), + ], +) +def test_keepdims(keepdims: bool, expected_result: ak.Array): + data = ak.Array( + [ + [1, 2, 3], + [4, 7], + ] + ) + weight = ak.Array( + [ + [1, 1, 2], + [1, 5], + ] + ) + assert ak.all( + ak.mean(data, weight=weight, axis=1, keepdims=keepdims) == expected_result + ) + + +@pytest.mark.parametrize( + "mask_identity, expected_result", + [ + pytest.param(False, ak.Array([1.5, math.nan, 8])), + pytest.param(True, ak.Array([1.5, None, 8])), + ], +) +def test_mask_identity(mask_identity: bool, expected_result: ak.Array): + data = ak.Array( + [ + [1, 2], + [], + [6, 9], + ] + ) + weight = ak.Array( + [ + [1, 1], + [], + [1, 2], + ] + ) + result = ak.mean(data, weight=weight, axis=1, mask_identity=mask_identity) + assert result[0] == expected_result[0] + if mask_identity: + assert result[1] is None + else: + assert math.isnan(result[1]) # NaN is not equal to itself per IEEE! + assert result[2] == expected_result[2] From 00822da9dd6d40e07451d58c085f4cc3016349e3 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 7 Nov 2024 10:46:18 -0600 Subject: [PATCH 20/21] chore(deps): bump the actions group across 1 directory with 2 updates (#3290) Bumps the actions group with 2 updates in the / directory: [pypa/gh-action-pypi-publish](https://github.com/pypa/gh-action-pypi-publish) and [mamba-org/setup-micromamba](https://github.com/mamba-org/setup-micromamba). Updates `pypa/gh-action-pypi-publish` from 1.10.2 to 1.11.0 - [Release notes](https://github.com/pypa/gh-action-pypi-publish/releases) - [Commits](https://github.com/pypa/gh-action-pypi-publish/compare/v1.10.2...v1.11.0) Updates `mamba-org/setup-micromamba` from 1 to 2 - [Release notes](https://github.com/mamba-org/setup-micromamba/releases) - [Commits](https://github.com/mamba-org/setup-micromamba/compare/v1...v2) --- updated-dependencies: - dependency-name: pypa/gh-action-pypi-publish dependency-type: direct:production update-type: version-update:semver-minor dependency-group: actions - dependency-name: mamba-org/setup-micromamba dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Jim Pivarski --- .github/workflows/deploy-cpp.yml | 2 +- .github/workflows/deploy.yml | 2 +- .github/workflows/docs.yml | 4 ++-- .github/workflows/test.yml | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/deploy-cpp.yml b/.github/workflows/deploy-cpp.yml index c23a63179c..c5011af995 100644 --- a/.github/workflows/deploy-cpp.yml +++ b/.github/workflows/deploy-cpp.yml @@ -39,4 +39,4 @@ jobs: with: subject-path: "dist/awkward*cpp-*" - - uses: pypa/gh-action-pypi-publish@v1.10.2 + - uses: pypa/gh-action-pypi-publish@v1.11.0 diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index a45da9b9b2..931eb47c63 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -135,7 +135,7 @@ jobs: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: gh attestation verify dist/awkward-*.whl --repo ${{ github.repository }} - - uses: pypa/gh-action-pypi-publish@v1.10.2 + - uses: pypa/gh-action-pypi-publish@v1.11.0 publish-headers: name: "Publish header-only libraries alongside release" diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 4e199a87c2..8305708b10 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -141,7 +141,7 @@ jobs: # solve with different external library versions. By default, # ROOT uses cxx-compiler too, so hopefully this won't be an issue - name: Setup Python via Conda - uses: mamba-org/setup-micromamba@v1 + uses: mamba-org/setup-micromamba@v2 with: # Cache invalidates daily by default cache-environment: true @@ -264,7 +264,7 @@ jobs: # solve with different external library versions. By default, # ROOT uses cxx-compiler too, so hopefully this won't be an issue - name: Setup Python via Conda - uses: mamba-org/setup-micromamba@v1 + uses: mamba-org/setup-micromamba@v2 with: # Cache invalidates daily by default cache-environment: true diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2c58f66746..f896831328 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -175,7 +175,7 @@ jobs: submodules: true - name: Setup Python via Conda - uses: mamba-org/setup-micromamba@v1 + uses: mamba-org/setup-micromamba@v2 with: # Cache invalidates daily by default cache-environment: true From 9bab1177066833cca50b79667fa5057cd77de0ca Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 7 Nov 2024 10:47:07 -0600 Subject: [PATCH 21/21] chore: update pre-commit hooks (#3284) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.6.9 → v0.7.2](https://github.com/astral-sh/ruff-pre-commit/compare/v0.6.9...v0.7.2) - [github.com/python-jsonschema/check-jsonschema: 0.29.3 → 0.29.4](https://github.com/python-jsonschema/check-jsonschema/compare/0.29.3...0.29.4) - [github.com/pre-commit/mirrors-mypy: v1.11.2 → v1.13.0](https://github.com/pre-commit/mirrors-mypy/compare/v1.11.2...v1.13.0) - [github.com/abravalheri/validate-pyproject: v0.20.2 → v0.22](https://github.com/abravalheri/validate-pyproject/compare/v0.20.2...v0.22) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jim Pivarski --- .pre-commit-config.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4ffeb0ed6f..9c06c85b80 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,7 +27,7 @@ repos: additional_dependencies: [pyyaml] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.9 + rev: v0.7.2 hooks: - id: ruff args: ["--fix", "--show-fixes"] @@ -62,13 +62,13 @@ repos: files: ^tests/ - repo: https://github.com/python-jsonschema/check-jsonschema - rev: 0.29.3 + rev: 0.29.4 hooks: - id: check-github-workflows args: ["--verbose"] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.11.2 + rev: v1.13.0 hooks: - id: mypy files: src @@ -76,6 +76,6 @@ repos: - numpy>=1.24 - repo: https://github.com/abravalheri/validate-pyproject - rev: v0.20.2 + rev: v0.22 hooks: - id: validate-pyproject